Playing with GAN

Signed-off-by: brian <brian@brutex.de>
master
Brian Rosenberger 2023-03-22 17:34:43 +01:00
parent aab7b423d1
commit 42fb4bd48e
16 changed files with 421 additions and 103 deletions

View File

@ -34,6 +34,8 @@ ext {
}
dependencies {
implementation platform(projects.cavisCommonPlatform)
implementation "com.fasterxml.jackson.core:jackson-databind"
implementation "com.google.guava:guava"
implementation projects.cavisDnn.cavisDnnCore
@ -52,6 +54,16 @@ dependencies {
testImplementation "org.apache.spark:spark-sql_${scalaVersion}"
testCompileOnly "org.scala-lang:scala-library"
//Rest Client
// define any required OkHttp artifacts without version
implementation("com.squareup.okhttp3:okhttp")
implementation("com.squareup.okhttp3:logging-interceptor")
implementation "org.bytedeco:javacv"
implementation "org.bytedeco:opencv"
implementation group: "org.bytedeco", name: "opencv", classifier: buildTarget
implementation "it.unimi.dsi:fastutil-core:8.5.8"
implementation projects.cavisDnn.cavisDnnSpark.cavisDnnSparkCore

View File

@ -21,49 +21,90 @@
package net.brutex.gan;
import java.util.Random;
import javax.ws.rs.client.ClientBuilder;
import lombok.extern.slf4j.Slf4j;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.Response;
import org.apache.commons.lang3.ArrayUtils;
import org.datavec.api.Writable;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.split.FileSplit;
import org.datavec.image.loader.NativeImageLoader;
import org.datavec.image.recordreader.ImageRecordReader;
import org.datavec.image.transform.ColorConversionTransform;
import org.datavec.image.transform.ImageTransform;
import org.datavec.image.transform.PipelineImageTransform;
import org.datavec.image.transform.ResizeImageTransform;
import org.datavec.image.transform.ScaleImageTransform;
import org.datavec.image.transform.ShowImageTransform;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.PerformanceListener;
import org.deeplearning4j.optimize.listeners.ScoreToChartListener;
import org.glassfish.jersey.client.JerseyClient;
import org.glassfish.jersey.client.JerseyClientBuilder;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.impl.ActivationLReLU;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerMinMaxScaler;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import javax.swing.*;
import java.awt.*;
import java.awt.image.BufferedImage;
import java.io.File;
import java.util.Arrays;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
@Slf4j
public class App {
private static final double LEARNING_RATE = 0.0002;
private static final double LEARNING_RATE = 0.000002;
private static final double GRADIENT_THRESHOLD = 100.0;
private static final int X_DIM = 28;
private static final int Y_DIM = 28;
private static final int CHANNELS = 1;
private static final int batchSize = 9;
private static final int INPUT = 128;
private static final int OUTPUT_PER_PANEL = 4;
private static final int ARRAY_SIZE_PER_SAMPLE = X_DIM*Y_DIM*CHANNELS;
private static final IUpdater UPDATER = Adam.builder().learningRate(LEARNING_RATE).beta1(0.5).build();
private static JFrame frame;
private static JFrame frame2;
private static JPanel panel;
private static JPanel panel2;
private static Layer[] genLayers() {
return new Layer[] {
new DenseLayer.Builder().nIn(100).nOut(256).weightInit(WeightInit.NORMAL).build(),
new DenseLayer.Builder().nIn(INPUT).nOut(X_DIM*Y_DIM*CHANNELS).weightInit(WeightInit.NORMAL).build(),
new ActivationLayer.Builder(new ActivationLReLU(0.2)).build(),
new DenseLayer.Builder().nIn(256).nOut(512).build(),
new DenseLayer.Builder().nIn(X_DIM*Y_DIM*CHANNELS).nOut(X_DIM*Y_DIM).build(),
new ActivationLayer.Builder(new ActivationLReLU(0.2)).build(),
new DenseLayer.Builder().nIn(512).nOut(1024).build(),
new DenseLayer.Builder().nIn(X_DIM*Y_DIM).nOut(X_DIM*Y_DIM).build(),
new ActivationLayer.Builder(new ActivationLReLU(0.2)).build(),
new DenseLayer.Builder().nIn(1024).nOut(784).activation(Activation.TANH).build()
new DenseLayer.Builder().nIn(X_DIM*Y_DIM).nOut(X_DIM*Y_DIM*CHANNELS).activation(Activation.TANH)
.build()
};
}
@ -81,6 +122,7 @@ public class App {
.weightInit(WeightInit.XAVIER)
.activation(Activation.IDENTITY)
.list(genLayers())
.setInputType(InputType.convolutional(X_DIM, Y_DIM, CHANNELS))
.build();
return conf;
@ -88,16 +130,19 @@ public class App {
private static Layer[] disLayers() {
return new Layer[]{
new DenseLayer.Builder().nIn(784).nOut(1024).build(),
new DenseLayer.Builder().nOut(X_DIM*Y_DIM*CHANNELS*2).build(), //input is set by setInputType on the network
new ActivationLayer.Builder(new ActivationLReLU(0.2)).build(),
new DropoutLayer.Builder(1 - 0.5).build(),
new DenseLayer.Builder().nIn(1024).nOut(512).build(),
new DenseLayer.Builder().nIn(X_DIM * Y_DIM*CHANNELS*2).nOut(X_DIM*Y_DIM*CHANNELS*4).build(), //HxBxC
new ActivationLayer.Builder(new ActivationLReLU(0.2)).build(),
new DropoutLayer.Builder(1 - 0.5).build(),
new DenseLayer.Builder().nIn(512).nOut(256).build(),
new DenseLayer.Builder().nIn(X_DIM*Y_DIM*CHANNELS*4).nOut(X_DIM*Y_DIM*CHANNELS).build(),
new ActivationLayer.Builder(new ActivationLReLU(0.2)).build(),
new DropoutLayer.Builder(1 - 0.5).build(),
new OutputLayer.Builder(LossFunctions.LossFunction.XENT).nIn(256).nOut(1).activation(Activation.SIGMOID).build()
new DenseLayer.Builder().nIn(X_DIM*Y_DIM*CHANNELS).nOut(X_DIM*Y_DIM).build(),
new ActivationLayer.Builder(new ActivationLReLU(0.2)).build(),
new DropoutLayer.Builder(1 - 0.5).build(),
new OutputLayer.Builder(LossFunction.XENT).nIn(X_DIM*Y_DIM).nOut(1).activation(Activation.SIGMOID).build()
};
}
@ -110,6 +155,7 @@ public class App {
.weightInit(WeightInit.XAVIER)
.activation(Activation.IDENTITY)
.list(disLayers())
.setInputType(InputType.convolutional(X_DIM, Y_DIM, CHANNELS))
.build();
return conf;
@ -135,6 +181,7 @@ public class App {
.weightInit(WeightInit.XAVIER)
.activation(Activation.IDENTITY)
.list(layers)
.setInputType(InputType.convolutional(X_DIM, Y_DIM, CHANNELS))
.build();
return conf;
@ -149,7 +196,25 @@ public class App {
public static void main(String... args) throws Exception {
Nd4j.getMemoryManager().setAutoGcWindow(15 * 1000);
MnistDataSetIterator trainData = new MnistDataSetIterator(128, true, 42);
// MnistDataSetIterator trainData = new MnistDataSetIterator(128, true, 45);
// FileSplit fileSplit = new FileSplit(new File("c:/users/brian/downloads/flowers"), NativeImageLoader.getALLOWED_FORMATS());
FileSplit fileSplit = new FileSplit(new File("c:/users/brian/downloads/humans"), NativeImageLoader.getALLOWED_FORMATS());
ImageTransform transform = new ColorConversionTransform(new Random(42), 7 );
ImageTransform transform2 = new ShowImageTransform("Tester", 30);
ImageTransform transform3 = new ResizeImageTransform(X_DIM, Y_DIM);
ImageTransform tr = new PipelineImageTransform.Builder()
.addImageTransform(transform) //convert to GREY SCALE
.addImageTransform(transform3)
//.addImageTransform(transform2)
.build();
ImageRecordReader imageRecordReader = new ImageRecordReader(X_DIM, Y_DIM, CHANNELS);
imageRecordReader.initialize(fileSplit, tr);
DataSetIterator trainData = new RecordReaderDataSetIterator(imageRecordReader, batchSize );
MultiLayerNetwork gen = new MultiLayerNetwork(generator());
MultiLayerNetwork dis = new MultiLayerNetwork(discriminator());
@ -160,27 +225,50 @@ public class App {
copyParams(gen, dis, gan);
gen.setListeners(new PerformanceListener(10, true));
dis.setListeners(new PerformanceListener(10, true));
gan.setListeners(new PerformanceListener(10, true));
//gen.setListeners(new PerformanceListener(10, true));
//dis.setListeners(new PerformanceListener(10, true));
//gan.setListeners(new PerformanceListener(10, true));
gan.setListeners(new ScoreToChartListener("gan"));
//dis.setListeners(new ScoreToChartListener("dis"));
trainData.reset();
gan.fit(Nd4j.rand(batchSize, CHANNELS, X_DIM, Y_DIM), Nd4j.zeros(batchSize, 1));
//gan.fit(new DataSet(trainData.next().getFeatures(), Nd4j.zeros(batchSize, 1)));
//trainData.reset();
int j = 0;
for (int i = 0; i < 20; i++) {
for (int i = 0; i < 201; i++) { //epoch
while (trainData.hasNext()) {
j++;
DataSet next = trainData.next();
// generate data
INDArray real = trainData.next().getFeatures().muli(2).subi(1);
int batchSize = (int) real.shape()[0];
INDArray real = next.getFeatures();//.div(255f);
INDArray fakeIn = Nd4j.rand(batchSize, 100);
//start next round if there are not enough images left to have a full batchsize dataset
if(real.length() < ARRAY_SIZE_PER_SAMPLE*batchSize) {
log.warn("Your total number of input images is not a multiple of {}, "
+ "thus skipping {} images to make it fit", batchSize, real.length()/ARRAY_SIZE_PER_SAMPLE);
break;
}
if(i%20 == 0) {
// frame2 = visualize(new INDArray[]{real}, batchSize,
// frame2 == null ? new JFrame() : frame2, true); //real has batchsize number of images
}
real.divi(255f);
// int batchSize = (int) real.shape()[0];
INDArray fakeIn = Nd4j.rand(batchSize, CHANNELS, X_DIM, Y_DIM);
INDArray fake = gan.activateSelectedLayers(0, gen.getLayers().length - 1, fakeIn);
fake = fake.reshape(batchSize, CHANNELS, X_DIM, Y_DIM);
//log.info("real has {} items.", real.length());
DataSet realSet = new DataSet(real, Nd4j.zeros(batchSize, 1));
DataSet fakeSet = new DataSet(fake, Nd4j.ones(batchSize, 1));
DataSet data = DataSet.merge(Arrays.asList(realSet, fakeSet));
dis.fit(data);
@ -189,21 +277,29 @@ public class App {
// Update the discriminator in the GAN network
updateGan(gen, dis, gan);
gan.fit(new DataSet(Nd4j.rand(batchSize, 100), Nd4j.zeros(batchSize, 1)));
//gan.fit(new DataSet(Nd4j.rand(batchSize, INPUT), Nd4j.zeros(batchSize, 1)));
gan.fit(new DataSet(Nd4j.rand(batchSize, CHANNELS, X_DIM, Y_DIM), Nd4j.zeros(batchSize, 1)));
if (j % 10 == 1) {
System.out.println("Iteration " + j + " Visualizing...");
INDArray[] samples = new INDArray[9];
DataSet fakeSet2 = new DataSet(fakeIn, Nd4j.ones(batchSize, 1));
INDArray[] samples = batchSize > OUTPUT_PER_PANEL ? new INDArray[OUTPUT_PER_PANEL] : new INDArray[batchSize];
for (int k = 0; k < 9; k++) {
for (int k = 0; k < samples.length; k++) {
//INDArray input = fakeSet2.get(k).getFeatures();
DataSet fakeSet2 = new DataSet(fakeIn, Nd4j.ones(batchSize, 1));
INDArray input = fakeSet2.get(k).getFeatures();
input = input.reshape(1,CHANNELS, X_DIM, Y_DIM); //batch size will be 1 here
//samples[k] = gen.output(input, false);
samples[k] = gan.activateSelectedLayers(0, gen.getLayers().length - 1, input);
samples[k] = samples[k].reshape(1, CHANNELS, X_DIM, Y_DIM);
//samples[k] =
samples[k].addi(1f).divi(2f).muli(255f);
}
visualize(samples);
frame = visualize(samples, 1, frame == null ? new JFrame() : frame, false); //each samples only has 1 image, thus batchElements=1
}
}
trainData.reset();
@ -239,41 +335,57 @@ public class App {
}
}
private static void visualize(INDArray[] samples) {
if (frame == null) {
frame = new JFrame();
frame.setTitle("Viz");
private static JFrame visualize(INDArray[] samples, int batchElements, JFrame frame, boolean isOrig) {
if (isOrig) {
frame.setTitle("Viz Original");
} else {
frame.setTitle("Generated");
}
frame.setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE);
frame.setLayout(new BorderLayout());
panel = new JPanel();
panel.setLayout(new GridLayout(samples.length / 3, 1, 8, 8));
frame.add(panel, BorderLayout.CENTER);
frame.setVisible(true);
}
panel.removeAll();
JPanel panelx = new JPanel();
panelx.setLayout(new GridLayout(4, 4, 8, 8));
for (INDArray sample : samples) {
panel.add(getImage(sample));
for(int i = 0; i<batchElements; i++) {
panelx.add(getImage(sample, i, isOrig));
}
}
frame.add(panelx, BorderLayout.CENTER);
frame.setVisible(true);
frame.revalidate();
frame.setMinimumSize(new Dimension(300, 20));
frame.pack();
return frame;
}
private static JLabel getImage(INDArray tensor) {
BufferedImage bi = new BufferedImage(28, 28, BufferedImage.TYPE_BYTE_GRAY);
for (int i = 0; i < 784; i++) {
int pixel = (int)(((tensor.getDouble(i) + 1) * 2) * 255);
bi.getRaster().setSample(i % 28, i / 28, 0, pixel);
private static JLabel getImage(INDArray tensor, int batchElement, boolean isOrig) {
final BufferedImage bi = new BufferedImage(X_DIM, Y_DIM, BufferedImage.TYPE_BYTE_GRAY);
final int imageSize = X_DIM * Y_DIM;
final int offset = batchElement * imageSize;
int pxl = offset * CHANNELS; //where to start in the INDArray
//Image in NCHW - channels first format
for (int c = 0; c < CHANNELS; c++) { //step through the num channels for each pixel
for (int y = 0; y < Y_DIM; y++) { // step through the columns x
for (int x = 0; x < X_DIM; x++) { //step through the rows y
if(isOrig) log.trace("'{}.' Image (x,y,c): ({}, {}, {}) with INDArray with index {} and value '{}'", batchElement, x, y, c, pxl, tensor.getFloat(pxl));
bi.getRaster().setSample(x, y, c, tensor.getFloat(pxl));
pxl++; //next item in INDArray
}
}
}
ImageIcon orig = new ImageIcon(bi);
Image imageScaled = orig.getImage().getScaledInstance((8 * 28), (8 * 28), Image.SCALE_REPLICATE);
Image imageScaled = orig.getImage().getScaledInstance((4 * X_DIM), (4 * Y_DIM), Image.SCALE_DEFAULT);
ImageIcon scaled = new ImageIcon(imageScaled);
return new JLabel(scaled);
}
}

View File

@ -0,0 +1,49 @@
#
#
# ******************************************************************************
# *
# * This program and the accompanying materials are made available under the
# * terms of the Apache License, Version 2.0 which is available at
# * https://www.apache.org/licenses/LICENSE-2.0.
# *
# * See the NOTICE file distributed with this work for additional
# * information regarding copyright ownership.
# * Unless required by applicable law or agreed to in writing, software
# * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# * License for the specific language governing permissions and limitations
# * under the License.
# *
# * SPDX-License-Identifier: Apache-2.0
# *****************************************************************************
#
#
# SLF4J's SimpleLogger configuration file
# Simple implementation of Logger that sends all enabled log messages, for all defined loggers, to System.err.
# Default logging detail level for all instances of SimpleLogger.
# Must be one of ("trace", "debug", "info", "warn", or "error").
# If not specified, defaults to "info".
org.slf4j.simpleLogger.defaultLogLevel=trace
# Logging detail level for a SimpleLogger instance named "xxxxx".
# Must be one of ("trace", "debug", "info", "warn", or "error").
# If not specified, the default logging detail level is used.
#org.slf4j.simpleLogger.log.xxxxx=
#org.slf4j.simpleLogger.log.net.brutex.cavis.backend.cavisrest.JWTAuthenticationFilter=warn
# Set to true if you want the current date and time to be included in output messages.
# Default is false, and will output the number of milliseconds elapsed since startup.
#org.slf4j.simpleLogger.showDateTime=false
# The date and time format to be used in the output messages.
# The pattern describing the date and time format is the same that is used in java.text.SimpleDateFormat.
# If the format is not specified or is invalid, the default format is used.
# The default format is yyyy-MM-dd HH:mm:ss:SSS Z.
#org.slf4j.simpleLogger.dateTimeFormat=yyyy-MM-dd HH:mm:ss:SSS Z
org.slf4j.simpleLogger.dateTimeFormat=yyyy-MM-dd HH:mm:ss
# Set to true if you want to output the current thread name.
# Defaults to true.
org.slf4j.simpleLogger.showThreadName=true

View File

@ -30,6 +30,8 @@ ext {
def netty = [version: "4.1.68.Final"]
def okhttp3 = [version: "4.10.0"]
javaPlatform {
allowDependencies()
@ -40,12 +42,16 @@ dependencies {
api enforcedPlatform("io.netty:netty-bom:${netty.version}")
api enforcedPlatform("com.fasterxml.jackson:jackson-bom:${jackson.version}")
//api enforcedPlatform("com.fasterxml.jackson.core:jackson-annotations:${jackson.version}")
api enforcedPlatform("com.squareup.okhttp3:okhttp-bom:${okhttp3.version}")
constraints {
api enforcedPlatform("io.netty:netty-bom:${netty.version}")
api enforcedPlatform("com.fasterxml.jackson:jackson-bom:${jackson.version}")
api enforcedPlatform("com.squareup.okhttp3:okhttp-bom:${okhttp3.version}")
//api enforcedPlatform("com.fasterxml.jackson.core:jackson-annotations:${jackson.version}")
//api "com.squareup.okhttp3:okhttp:${okhttp3}.version"
//api "com.squareup.okhttp3:logging-interceptor:${okhttp3}.version"
api 'com.google.guava:guava:30.1-jre'
api "com.google.protobuf:protobuf-java:3.15.6"
@ -157,6 +163,7 @@ dependencies {
api "org.agrona:agrona:1.12.0"
}
}

View File

@ -22,6 +22,7 @@ package org.datavec.image.transform;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import org.bytedeco.javacv.OpenCVFrameConverter;
import org.datavec.image.data.ImageWritable;
import com.fasterxml.jackson.annotation.JsonInclude;
@ -35,6 +36,7 @@ import static org.bytedeco.opencv.global.opencv_imgproc.*;
@JsonInclude(JsonInclude.Include.NON_NULL)
@Data
@EqualsAndHashCode(callSuper = false)
@Slf4j
public class ColorConversionTransform extends BaseImageTransform {
/**
@ -85,14 +87,18 @@ public class ColorConversionTransform extends BaseImageTransform {
return null;
}
Mat mat = (Mat) converter.convert(image.getFrame());
Mat result = new Mat();
if(mat.type() != result.type() ) {
try {
cvtColor(mat, result, conversionCode);
} catch (Exception e) {
throw new RuntimeException(e);
}
} else {
log.debug("Image is already at type {}. No conversion done.", mat.type());
return image;
}
return new ImageWritable(converter.convert(result));
}

View File

@ -85,6 +85,7 @@ public class ShowImageTransform extends BaseImageTransform {
if (!canvas.isVisible()) {
return image;
}
Frame frame = image.getFrame();
canvas.setCanvasSize(frame.imageWidth, frame.imageHeight);
canvas.showImage(frame);

View File

@ -5171,22 +5171,22 @@ public class Nd4j {
Class<? extends DistributionFactory> distributionFactoryClazz = ND4JClassLoading.loadClassByName(clazzName);
memoryManager = memoryManagerClazz.newInstance();
constantHandler = constantProviderClazz.newInstance();
shapeInfoProvider = shapeInfoProviderClazz.newInstance();
workspaceManager = workspaceManagerClazz.newInstance();
memoryManager = memoryManagerClazz.getDeclaredConstructor().newInstance();
constantHandler = constantProviderClazz.getDeclaredConstructor().newInstance();
shapeInfoProvider = shapeInfoProviderClazz.getDeclaredConstructor().newInstance();
workspaceManager = workspaceManagerClazz.getDeclaredConstructor().newInstance();
Class<? extends OpExecutioner> opExecutionerClazz = ND4JClassLoading
.loadClassByName(pp.toString(OP_EXECUTIONER, DefaultOpExecutioner.class.getName()));
OP_EXECUTIONER_INSTANCE = opExecutionerClazz.newInstance();
OP_EXECUTIONER_INSTANCE = opExecutionerClazz.getDeclaredConstructor().newInstance();
Constructor c2 = ndArrayFactoryClazz.getConstructor(DataType.class, char.class);
INSTANCE = (NDArrayFactory) c2.newInstance(dtype, ORDER);
CONVOLUTION_INSTANCE = convolutionInstanceClazz.newInstance();
BLAS_WRAPPER_INSTANCE = blasWrapperClazz.newInstance();
DATA_BUFFER_FACTORY_INSTANCE = dataBufferFactoryClazz.newInstance();
CONVOLUTION_INSTANCE = convolutionInstanceClazz.getDeclaredConstructor().newInstance();
BLAS_WRAPPER_INSTANCE = blasWrapperClazz.getDeclaredConstructor().newInstance();
DATA_BUFFER_FACTORY_INSTANCE = dataBufferFactoryClazz.getDeclaredConstructor().newInstance();
DISTRIBUTION_FACTORY = distributionFactoryClazz.newInstance();
DISTRIBUTION_FACTORY = distributionFactoryClazz.getDeclaredConstructor().newInstance();
if (isFallback()) {
fallbackMode.set(true);

View File

@ -58,11 +58,13 @@ public final class ND4JClassLoading {
@SuppressWarnings("unchecked")
public static <T> Class<T> loadClassByName(String className, boolean initialize, ClassLoader classLoader) {
try {
log.info(String.format("Trying to load class [%s]", className));
return (Class<T>) Class.forName(className, initialize, classLoader);
Class<T> clazz = (Class<T>) Class.forName(className, initialize, classLoader);
log.info(String.format("Trying to load class [%s] - Success", className));
return clazz;
} catch (ClassNotFoundException classNotFoundException) {
log.error(String.format("Cannot find class [%s] of provided class-loader.", className));
log.error(String.format("Trying to load class [%s] - Failure: Cannot find class with provided class-loader.", className));
return null;
}
}

View File

@ -21,6 +21,8 @@
apply from: "${project.rootProject.projectDir}/createTestBackends.gradle"
dependencies {
implementation platform(projects.cavisCommonPlatform)
implementation projects.cavisDnn.cavisDnnData.cavisDnnDataUtilityIterators
implementation 'org.lucee:oswego-concurrent:1.3.4'
implementation projects.cavisDnn.cavisDnnCommon
@ -50,4 +52,9 @@ dependencies {
implementation "com.fasterxml.jackson.core:jackson-databind"
implementation "com.fasterxml.jackson.dataformat:jackson-dataformat-yaml"
implementation "com.jakewharton.byteunits:byteunits:0.9.1"
//Rest Client
// define any required OkHttp artifacts without version
implementation "com.squareup.okhttp3:okhttp"
implementation "com.squareup.okhttp3:logging-interceptor"
}

View File

@ -215,7 +215,7 @@ public abstract class Layer implements TrainingConfig, Serializable, Cloneable {
/**
* Get the updater for the given parameter. Typically the same updater will be used for all
* updaters, but this is not necessarily the case
* parameters, but this is not necessarily the case
*
* @param paramName Parameter name
* @return IUpdater for the parameter

View File

@ -30,6 +30,7 @@ import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
* @author Adam Gibson
*/
public class DenseLayer extends BaseLayer<org.deeplearning4j.nn.conf.layers.DenseLayer> {
public DenseLayer(NeuralNetConfiguration conf, DataType dataType) {
super(conf, dataType);
}

View File

@ -0,0 +1,62 @@
/*
*
* ******************************************************************************
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*
*/
package org.deeplearning4j.optimize.listeners;
import java.io.IOException;
import lombok.extern.slf4j.Slf4j;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.Response;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.optimize.api.BaseTrainingListener;
@Slf4j
public class ScoreToChartListener extends BaseTrainingListener {
final String url = "http://bru5:8080/cavis-rest-1.0-SNAPSHOT.war/hello/hello-world?";
final String seriesName;
public ScoreToChartListener(String seriesName) {
this.seriesName = seriesName;
}
@Override
public void iterationDone(Model model, int iteration, int epoch) {
double score = model.score();
String nurl = url+"s="+score+"&n="+seriesName;
OkHttpClient client = new OkHttpClient();
Request request = new Request.Builder()
.url(nurl)
.build();
try {
Response response = client.newCall(request).execute();
log.debug(String.format("Did send score to chart at '%s'.", nurl));
response.body().close();
} catch (IOException e) {
log.warn(String.format("Could not send score to chart at '%s' because %s", nurl, e.getMessage()));
}
//response.body().string();
}
}

View File

@ -37,7 +37,6 @@ public class NativeOpsGPUInfoProvider implements GPUInfoProvider {
List<GPUInfo> gpus = new ArrayList<>();
int nDevices = nativeOps.getAvailableDevices();
if (nDevices > 0) {
for (int i = 0; i < nDevices; i++) {

View File

@ -83,7 +83,7 @@ import java.util.concurrent.locks.ReentrantReadWriteLock;
*/
@Slf4j
public class AtomicAllocator implements Allocator {
private static final AtomicAllocator INSTANCE = new AtomicAllocator();
private static AtomicAllocator INSTANCE = new AtomicAllocator();
private Configuration configuration;
@ -122,6 +122,7 @@ public class AtomicAllocator implements Allocator {
private final AtomicLong useTracker = new AtomicLong(System.currentTimeMillis());
public static AtomicAllocator getInstance() {
if(INSTANCE == null) INSTANCE = new AtomicAllocator();
if (INSTANCE == null)
throw new RuntimeException("AtomicAllocator is NULL");
return INSTANCE;

View File

@ -402,6 +402,10 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda
val ctx = AtomicAllocator.getInstance().getDeviceContext();
val devicePtr = allocationPoint.getDevicePointer();
NativeOpsHolder.getInstance().getDeviceNativeOps().memsetAsync(devicePtr, 0, length * elementSize, 0, ctx.getSpecialStream());
int ec = NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorCode();
if(ec != 0) {
throw new RuntimeException(NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorMessage());
}
ctx.getSpecialStream().synchronize();
}

View File

@ -0,0 +1,55 @@
/*
*
* ******************************************************************************
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*
*/
package org.nd4j.linalg.jcublas.buffer;
import static org.junit.jupiter.api.Assertions.*;
import lombok.extern.slf4j.Slf4j;
import org.bytedeco.javacpp.BytePointer;
import org.bytedeco.javacpp.Pointer;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.environment.Nd4jEnvironment;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.nativeblas.NativeOpsHolder;
@Slf4j
class BaseCudaDataBufferTest {
@Test
public void testMemoryAlloc() throws InterruptedException {
BaseCudaDataBuffer cuBuffer = new CudaLongDataBuffer(16l);
log.info(
"Allocation Status: " + cuBuffer.getAllocationPoint().getAllocationStatus().toString());
Thread.sleep(3000);
cuBuffer.getAllocationPoint().tickDeviceWrite();
DataBuffer buf = Nd4j.rand(8,1).shapeInfoDataBuffer();
NativeOpsHolder.getInstance().getDeviceNativeOps().memcpySync(cuBuffer.pointer(), buf.pointer(), 8, 0, new Pointer() );
log.info(
"Allocation Status: " + cuBuffer.getAllocationPoint().getAllocationStatus().toString());
cuBuffer.release();
}
}