parent
aab7b423d1
commit
42fb4bd48e
|
@ -34,6 +34,8 @@ ext {
|
||||||
}
|
}
|
||||||
|
|
||||||
dependencies {
|
dependencies {
|
||||||
|
implementation platform(projects.cavisCommonPlatform)
|
||||||
|
|
||||||
implementation "com.fasterxml.jackson.core:jackson-databind"
|
implementation "com.fasterxml.jackson.core:jackson-databind"
|
||||||
implementation "com.google.guava:guava"
|
implementation "com.google.guava:guava"
|
||||||
implementation projects.cavisDnn.cavisDnnCore
|
implementation projects.cavisDnn.cavisDnnCore
|
||||||
|
@ -52,6 +54,16 @@ dependencies {
|
||||||
testImplementation "org.apache.spark:spark-sql_${scalaVersion}"
|
testImplementation "org.apache.spark:spark-sql_${scalaVersion}"
|
||||||
testCompileOnly "org.scala-lang:scala-library"
|
testCompileOnly "org.scala-lang:scala-library"
|
||||||
|
|
||||||
|
//Rest Client
|
||||||
|
// define any required OkHttp artifacts without version
|
||||||
|
implementation("com.squareup.okhttp3:okhttp")
|
||||||
|
implementation("com.squareup.okhttp3:logging-interceptor")
|
||||||
|
|
||||||
|
|
||||||
|
implementation "org.bytedeco:javacv"
|
||||||
|
implementation "org.bytedeco:opencv"
|
||||||
|
implementation group: "org.bytedeco", name: "opencv", classifier: buildTarget
|
||||||
|
|
||||||
implementation "it.unimi.dsi:fastutil-core:8.5.8"
|
implementation "it.unimi.dsi:fastutil-core:8.5.8"
|
||||||
|
|
||||||
implementation projects.cavisDnn.cavisDnnSpark.cavisDnnSparkCore
|
implementation projects.cavisDnn.cavisDnnSpark.cavisDnnSparkCore
|
||||||
|
|
|
@ -21,49 +21,90 @@
|
||||||
|
|
||||||
package net.brutex.gan;
|
package net.brutex.gan;
|
||||||
|
|
||||||
|
import java.util.Random;
|
||||||
|
import javax.ws.rs.client.ClientBuilder;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import okhttp3.OkHttpClient;
|
||||||
|
import okhttp3.Request;
|
||||||
|
import okhttp3.Response;
|
||||||
import org.apache.commons.lang3.ArrayUtils;
|
import org.apache.commons.lang3.ArrayUtils;
|
||||||
|
import org.datavec.api.Writable;
|
||||||
|
import org.datavec.api.records.reader.RecordReader;
|
||||||
|
import org.datavec.api.split.FileSplit;
|
||||||
|
import org.datavec.image.loader.NativeImageLoader;
|
||||||
|
import org.datavec.image.recordreader.ImageRecordReader;
|
||||||
|
import org.datavec.image.transform.ColorConversionTransform;
|
||||||
|
import org.datavec.image.transform.ImageTransform;
|
||||||
|
import org.datavec.image.transform.PipelineImageTransform;
|
||||||
|
import org.datavec.image.transform.ResizeImageTransform;
|
||||||
|
import org.datavec.image.transform.ScaleImageTransform;
|
||||||
|
import org.datavec.image.transform.ShowImageTransform;
|
||||||
|
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
|
||||||
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
|
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
|
||||||
import org.deeplearning4j.nn.conf.GradientNormalization;
|
import org.deeplearning4j.nn.conf.GradientNormalization;
|
||||||
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
||||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
|
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||||
import org.deeplearning4j.nn.conf.layers.*;
|
import org.deeplearning4j.nn.conf.layers.*;
|
||||||
import org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop;
|
import org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop;
|
||||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||||
import org.deeplearning4j.nn.weights.WeightInit;
|
import org.deeplearning4j.nn.weights.WeightInit;
|
||||||
import org.deeplearning4j.optimize.listeners.PerformanceListener;
|
import org.deeplearning4j.optimize.listeners.PerformanceListener;
|
||||||
|
import org.deeplearning4j.optimize.listeners.ScoreToChartListener;
|
||||||
|
import org.glassfish.jersey.client.JerseyClient;
|
||||||
|
import org.glassfish.jersey.client.JerseyClientBuilder;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.linalg.activations.Activation;
|
import org.nd4j.linalg.activations.Activation;
|
||||||
import org.nd4j.linalg.activations.impl.ActivationLReLU;
|
import org.nd4j.linalg.activations.impl.ActivationLReLU;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.dataset.DataSet;
|
import org.nd4j.linalg.dataset.DataSet;
|
||||||
|
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||||
|
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
|
||||||
|
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerMinMaxScaler;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.learning.config.Adam;
|
import org.nd4j.linalg.learning.config.Adam;
|
||||||
import org.nd4j.linalg.learning.config.IUpdater;
|
import org.nd4j.linalg.learning.config.IUpdater;
|
||||||
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
||||||
|
|
||||||
|
|
||||||
import javax.swing.*;
|
import javax.swing.*;
|
||||||
import java.awt.*;
|
import java.awt.*;
|
||||||
import java.awt.image.BufferedImage;
|
import java.awt.image.BufferedImage;
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
|
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
|
||||||
|
|
||||||
|
@Slf4j
|
||||||
public class App {
|
public class App {
|
||||||
private static final double LEARNING_RATE = 0.0002;
|
private static final double LEARNING_RATE = 0.000002;
|
||||||
private static final double GRADIENT_THRESHOLD = 100.0;
|
private static final double GRADIENT_THRESHOLD = 100.0;
|
||||||
|
|
||||||
|
private static final int X_DIM = 28;
|
||||||
|
private static final int Y_DIM = 28;
|
||||||
|
private static final int CHANNELS = 1;
|
||||||
|
private static final int batchSize = 9;
|
||||||
|
private static final int INPUT = 128;
|
||||||
|
|
||||||
|
private static final int OUTPUT_PER_PANEL = 4;
|
||||||
|
|
||||||
|
private static final int ARRAY_SIZE_PER_SAMPLE = X_DIM*Y_DIM*CHANNELS;
|
||||||
private static final IUpdater UPDATER = Adam.builder().learningRate(LEARNING_RATE).beta1(0.5).build();
|
private static final IUpdater UPDATER = Adam.builder().learningRate(LEARNING_RATE).beta1(0.5).build();
|
||||||
|
|
||||||
private static JFrame frame;
|
private static JFrame frame;
|
||||||
|
private static JFrame frame2;
|
||||||
private static JPanel panel;
|
private static JPanel panel;
|
||||||
|
private static JPanel panel2;
|
||||||
|
|
||||||
private static Layer[] genLayers() {
|
private static Layer[] genLayers() {
|
||||||
return new Layer[] {
|
return new Layer[] {
|
||||||
new DenseLayer.Builder().nIn(100).nOut(256).weightInit(WeightInit.NORMAL).build(),
|
new DenseLayer.Builder().nIn(INPUT).nOut(X_DIM*Y_DIM*CHANNELS).weightInit(WeightInit.NORMAL).build(),
|
||||||
new ActivationLayer.Builder(new ActivationLReLU(0.2)).build(),
|
new ActivationLayer.Builder(new ActivationLReLU(0.2)).build(),
|
||||||
new DenseLayer.Builder().nIn(256).nOut(512).build(),
|
new DenseLayer.Builder().nIn(X_DIM*Y_DIM*CHANNELS).nOut(X_DIM*Y_DIM).build(),
|
||||||
new ActivationLayer.Builder(new ActivationLReLU(0.2)).build(),
|
new ActivationLayer.Builder(new ActivationLReLU(0.2)).build(),
|
||||||
new DenseLayer.Builder().nIn(512).nOut(1024).build(),
|
new DenseLayer.Builder().nIn(X_DIM*Y_DIM).nOut(X_DIM*Y_DIM).build(),
|
||||||
new ActivationLayer.Builder(new ActivationLReLU(0.2)).build(),
|
new ActivationLayer.Builder(new ActivationLReLU(0.2)).build(),
|
||||||
new DenseLayer.Builder().nIn(1024).nOut(784).activation(Activation.TANH).build()
|
new DenseLayer.Builder().nIn(X_DIM*Y_DIM).nOut(X_DIM*Y_DIM*CHANNELS).activation(Activation.TANH)
|
||||||
|
.build()
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -81,6 +122,7 @@ public class App {
|
||||||
.weightInit(WeightInit.XAVIER)
|
.weightInit(WeightInit.XAVIER)
|
||||||
.activation(Activation.IDENTITY)
|
.activation(Activation.IDENTITY)
|
||||||
.list(genLayers())
|
.list(genLayers())
|
||||||
|
.setInputType(InputType.convolutional(X_DIM, Y_DIM, CHANNELS))
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
return conf;
|
return conf;
|
||||||
|
@ -88,16 +130,19 @@ public class App {
|
||||||
|
|
||||||
private static Layer[] disLayers() {
|
private static Layer[] disLayers() {
|
||||||
return new Layer[]{
|
return new Layer[]{
|
||||||
new DenseLayer.Builder().nIn(784).nOut(1024).build(),
|
new DenseLayer.Builder().nOut(X_DIM*Y_DIM*CHANNELS*2).build(), //input is set by setInputType on the network
|
||||||
new ActivationLayer.Builder(new ActivationLReLU(0.2)).build(),
|
new ActivationLayer.Builder(new ActivationLReLU(0.2)).build(),
|
||||||
new DropoutLayer.Builder(1 - 0.5).build(),
|
new DropoutLayer.Builder(1 - 0.5).build(),
|
||||||
new DenseLayer.Builder().nIn(1024).nOut(512).build(),
|
new DenseLayer.Builder().nIn(X_DIM * Y_DIM*CHANNELS*2).nOut(X_DIM*Y_DIM*CHANNELS*4).build(), //HxBxC
|
||||||
new ActivationLayer.Builder(new ActivationLReLU(0.2)).build(),
|
new ActivationLayer.Builder(new ActivationLReLU(0.2)).build(),
|
||||||
new DropoutLayer.Builder(1 - 0.5).build(),
|
new DropoutLayer.Builder(1 - 0.5).build(),
|
||||||
new DenseLayer.Builder().nIn(512).nOut(256).build(),
|
new DenseLayer.Builder().nIn(X_DIM*Y_DIM*CHANNELS*4).nOut(X_DIM*Y_DIM*CHANNELS).build(),
|
||||||
new ActivationLayer.Builder(new ActivationLReLU(0.2)).build(),
|
new ActivationLayer.Builder(new ActivationLReLU(0.2)).build(),
|
||||||
new DropoutLayer.Builder(1 - 0.5).build(),
|
new DropoutLayer.Builder(1 - 0.5).build(),
|
||||||
new OutputLayer.Builder(LossFunctions.LossFunction.XENT).nIn(256).nOut(1).activation(Activation.SIGMOID).build()
|
new DenseLayer.Builder().nIn(X_DIM*Y_DIM*CHANNELS).nOut(X_DIM*Y_DIM).build(),
|
||||||
|
new ActivationLayer.Builder(new ActivationLReLU(0.2)).build(),
|
||||||
|
new DropoutLayer.Builder(1 - 0.5).build(),
|
||||||
|
new OutputLayer.Builder(LossFunction.XENT).nIn(X_DIM*Y_DIM).nOut(1).activation(Activation.SIGMOID).build()
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -110,6 +155,7 @@ public class App {
|
||||||
.weightInit(WeightInit.XAVIER)
|
.weightInit(WeightInit.XAVIER)
|
||||||
.activation(Activation.IDENTITY)
|
.activation(Activation.IDENTITY)
|
||||||
.list(disLayers())
|
.list(disLayers())
|
||||||
|
.setInputType(InputType.convolutional(X_DIM, Y_DIM, CHANNELS))
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
return conf;
|
return conf;
|
||||||
|
@ -135,6 +181,7 @@ public class App {
|
||||||
.weightInit(WeightInit.XAVIER)
|
.weightInit(WeightInit.XAVIER)
|
||||||
.activation(Activation.IDENTITY)
|
.activation(Activation.IDENTITY)
|
||||||
.list(layers)
|
.list(layers)
|
||||||
|
.setInputType(InputType.convolutional(X_DIM, Y_DIM, CHANNELS))
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
return conf;
|
return conf;
|
||||||
|
@ -149,7 +196,25 @@ public class App {
|
||||||
public static void main(String... args) throws Exception {
|
public static void main(String... args) throws Exception {
|
||||||
Nd4j.getMemoryManager().setAutoGcWindow(15 * 1000);
|
Nd4j.getMemoryManager().setAutoGcWindow(15 * 1000);
|
||||||
|
|
||||||
MnistDataSetIterator trainData = new MnistDataSetIterator(128, true, 42);
|
// MnistDataSetIterator trainData = new MnistDataSetIterator(128, true, 45);
|
||||||
|
// FileSplit fileSplit = new FileSplit(new File("c:/users/brian/downloads/flowers"), NativeImageLoader.getALLOWED_FORMATS());
|
||||||
|
FileSplit fileSplit = new FileSplit(new File("c:/users/brian/downloads/humans"), NativeImageLoader.getALLOWED_FORMATS());
|
||||||
|
|
||||||
|
|
||||||
|
ImageTransform transform = new ColorConversionTransform(new Random(42), 7 );
|
||||||
|
|
||||||
|
ImageTransform transform2 = new ShowImageTransform("Tester", 30);
|
||||||
|
ImageTransform transform3 = new ResizeImageTransform(X_DIM, Y_DIM);
|
||||||
|
|
||||||
|
ImageTransform tr = new PipelineImageTransform.Builder()
|
||||||
|
.addImageTransform(transform) //convert to GREY SCALE
|
||||||
|
.addImageTransform(transform3)
|
||||||
|
//.addImageTransform(transform2)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
ImageRecordReader imageRecordReader = new ImageRecordReader(X_DIM, Y_DIM, CHANNELS);
|
||||||
|
imageRecordReader.initialize(fileSplit, tr);
|
||||||
|
DataSetIterator trainData = new RecordReaderDataSetIterator(imageRecordReader, batchSize );
|
||||||
|
|
||||||
MultiLayerNetwork gen = new MultiLayerNetwork(generator());
|
MultiLayerNetwork gen = new MultiLayerNetwork(generator());
|
||||||
MultiLayerNetwork dis = new MultiLayerNetwork(discriminator());
|
MultiLayerNetwork dis = new MultiLayerNetwork(discriminator());
|
||||||
|
@ -160,27 +225,50 @@ public class App {
|
||||||
|
|
||||||
copyParams(gen, dis, gan);
|
copyParams(gen, dis, gan);
|
||||||
|
|
||||||
gen.setListeners(new PerformanceListener(10, true));
|
//gen.setListeners(new PerformanceListener(10, true));
|
||||||
dis.setListeners(new PerformanceListener(10, true));
|
//dis.setListeners(new PerformanceListener(10, true));
|
||||||
gan.setListeners(new PerformanceListener(10, true));
|
//gan.setListeners(new PerformanceListener(10, true));
|
||||||
|
gan.setListeners(new ScoreToChartListener("gan"));
|
||||||
|
//dis.setListeners(new ScoreToChartListener("dis"));
|
||||||
|
|
||||||
trainData.reset();
|
gan.fit(Nd4j.rand(batchSize, CHANNELS, X_DIM, Y_DIM), Nd4j.zeros(batchSize, 1));
|
||||||
|
|
||||||
|
//gan.fit(new DataSet(trainData.next().getFeatures(), Nd4j.zeros(batchSize, 1)));
|
||||||
|
//trainData.reset();
|
||||||
|
|
||||||
int j = 0;
|
int j = 0;
|
||||||
for (int i = 0; i < 20; i++) {
|
for (int i = 0; i < 201; i++) { //epoch
|
||||||
while (trainData.hasNext()) {
|
while (trainData.hasNext()) {
|
||||||
j++;
|
j++;
|
||||||
|
|
||||||
|
DataSet next = trainData.next();
|
||||||
// generate data
|
// generate data
|
||||||
INDArray real = trainData.next().getFeatures().muli(2).subi(1);
|
INDArray real = next.getFeatures();//.div(255f);
|
||||||
int batchSize = (int) real.shape()[0];
|
|
||||||
|
|
||||||
INDArray fakeIn = Nd4j.rand(batchSize, 100);
|
//start next round if there are not enough images left to have a full batchsize dataset
|
||||||
|
if(real.length() < ARRAY_SIZE_PER_SAMPLE*batchSize) {
|
||||||
|
log.warn("Your total number of input images is not a multiple of {}, "
|
||||||
|
+ "thus skipping {} images to make it fit", batchSize, real.length()/ARRAY_SIZE_PER_SAMPLE);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
if(i%20 == 0) {
|
||||||
|
// frame2 = visualize(new INDArray[]{real}, batchSize,
|
||||||
|
// frame2 == null ? new JFrame() : frame2, true); //real has batchsize number of images
|
||||||
|
}
|
||||||
|
real.divi(255f);
|
||||||
|
|
||||||
|
// int batchSize = (int) real.shape()[0];
|
||||||
|
|
||||||
|
INDArray fakeIn = Nd4j.rand(batchSize, CHANNELS, X_DIM, Y_DIM);
|
||||||
INDArray fake = gan.activateSelectedLayers(0, gen.getLayers().length - 1, fakeIn);
|
INDArray fake = gan.activateSelectedLayers(0, gen.getLayers().length - 1, fakeIn);
|
||||||
|
fake = fake.reshape(batchSize, CHANNELS, X_DIM, Y_DIM);
|
||||||
|
|
||||||
|
//log.info("real has {} items.", real.length());
|
||||||
DataSet realSet = new DataSet(real, Nd4j.zeros(batchSize, 1));
|
DataSet realSet = new DataSet(real, Nd4j.zeros(batchSize, 1));
|
||||||
DataSet fakeSet = new DataSet(fake, Nd4j.ones(batchSize, 1));
|
DataSet fakeSet = new DataSet(fake, Nd4j.ones(batchSize, 1));
|
||||||
|
|
||||||
|
|
||||||
DataSet data = DataSet.merge(Arrays.asList(realSet, fakeSet));
|
DataSet data = DataSet.merge(Arrays.asList(realSet, fakeSet));
|
||||||
|
|
||||||
dis.fit(data);
|
dis.fit(data);
|
||||||
|
@ -189,21 +277,29 @@ public class App {
|
||||||
// Update the discriminator in the GAN network
|
// Update the discriminator in the GAN network
|
||||||
updateGan(gen, dis, gan);
|
updateGan(gen, dis, gan);
|
||||||
|
|
||||||
gan.fit(new DataSet(Nd4j.rand(batchSize, 100), Nd4j.zeros(batchSize, 1)));
|
//gan.fit(new DataSet(Nd4j.rand(batchSize, INPUT), Nd4j.zeros(batchSize, 1)));
|
||||||
|
gan.fit(new DataSet(Nd4j.rand(batchSize, CHANNELS, X_DIM, Y_DIM), Nd4j.zeros(batchSize, 1)));
|
||||||
|
|
||||||
|
|
||||||
if (j % 10 == 1) {
|
if (j % 10 == 1) {
|
||||||
System.out.println("Iteration " + j + " Visualizing...");
|
System.out.println("Iteration " + j + " Visualizing...");
|
||||||
INDArray[] samples = new INDArray[9];
|
INDArray[] samples = batchSize > OUTPUT_PER_PANEL ? new INDArray[OUTPUT_PER_PANEL] : new INDArray[batchSize];
|
||||||
DataSet fakeSet2 = new DataSet(fakeIn, Nd4j.ones(batchSize, 1));
|
|
||||||
|
|
||||||
for (int k = 0; k < 9; k++) {
|
|
||||||
|
for (int k = 0; k < samples.length; k++) {
|
||||||
|
//INDArray input = fakeSet2.get(k).getFeatures();
|
||||||
|
DataSet fakeSet2 = new DataSet(fakeIn, Nd4j.ones(batchSize, 1));
|
||||||
INDArray input = fakeSet2.get(k).getFeatures();
|
INDArray input = fakeSet2.get(k).getFeatures();
|
||||||
|
input = input.reshape(1,CHANNELS, X_DIM, Y_DIM); //batch size will be 1 here
|
||||||
|
|
||||||
//samples[k] = gen.output(input, false);
|
//samples[k] = gen.output(input, false);
|
||||||
samples[k] = gan.activateSelectedLayers(0, gen.getLayers().length - 1, input);
|
samples[k] = gan.activateSelectedLayers(0, gen.getLayers().length - 1, input);
|
||||||
|
samples[k] = samples[k].reshape(1, CHANNELS, X_DIM, Y_DIM);
|
||||||
|
//samples[k] =
|
||||||
|
samples[k].addi(1f).divi(2f).muli(255f);
|
||||||
|
|
||||||
}
|
}
|
||||||
visualize(samples);
|
frame = visualize(samples, 1, frame == null ? new JFrame() : frame, false); //each samples only has 1 image, thus batchElements=1
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
trainData.reset();
|
trainData.reset();
|
||||||
|
@ -239,41 +335,57 @@ public class App {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private static void visualize(INDArray[] samples) {
|
private static JFrame visualize(INDArray[] samples, int batchElements, JFrame frame, boolean isOrig) {
|
||||||
if (frame == null) {
|
if (isOrig) {
|
||||||
frame = new JFrame();
|
frame.setTitle("Viz Original");
|
||||||
frame.setTitle("Viz");
|
} else {
|
||||||
|
frame.setTitle("Generated");
|
||||||
|
}
|
||||||
|
|
||||||
frame.setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE);
|
frame.setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE);
|
||||||
frame.setLayout(new BorderLayout());
|
frame.setLayout(new BorderLayout());
|
||||||
|
|
||||||
panel = new JPanel();
|
JPanel panelx = new JPanel();
|
||||||
|
|
||||||
panel.setLayout(new GridLayout(samples.length / 3, 1, 8, 8));
|
|
||||||
frame.add(panel, BorderLayout.CENTER);
|
|
||||||
frame.setVisible(true);
|
|
||||||
}
|
|
||||||
|
|
||||||
panel.removeAll();
|
|
||||||
|
|
||||||
|
panelx.setLayout(new GridLayout(4, 4, 8, 8));
|
||||||
for (INDArray sample : samples) {
|
for (INDArray sample : samples) {
|
||||||
panel.add(getImage(sample));
|
for(int i = 0; i<batchElements; i++) {
|
||||||
|
panelx.add(getImage(sample, i, isOrig));
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
frame.add(panelx, BorderLayout.CENTER);
|
||||||
|
frame.setVisible(true);
|
||||||
|
|
||||||
frame.revalidate();
|
frame.revalidate();
|
||||||
|
frame.setMinimumSize(new Dimension(300, 20));
|
||||||
frame.pack();
|
frame.pack();
|
||||||
|
return frame;
|
||||||
}
|
}
|
||||||
|
|
||||||
private static JLabel getImage(INDArray tensor) {
|
private static JLabel getImage(INDArray tensor, int batchElement, boolean isOrig) {
|
||||||
BufferedImage bi = new BufferedImage(28, 28, BufferedImage.TYPE_BYTE_GRAY);
|
final BufferedImage bi = new BufferedImage(X_DIM, Y_DIM, BufferedImage.TYPE_BYTE_GRAY);
|
||||||
for (int i = 0; i < 784; i++) {
|
final int imageSize = X_DIM * Y_DIM;
|
||||||
int pixel = (int)(((tensor.getDouble(i) + 1) * 2) * 255);
|
final int offset = batchElement * imageSize;
|
||||||
bi.getRaster().setSample(i % 28, i / 28, 0, pixel);
|
int pxl = offset * CHANNELS; //where to start in the INDArray
|
||||||
|
|
||||||
|
//Image in NCHW - channels first format
|
||||||
|
for (int c = 0; c < CHANNELS; c++) { //step through the num channels for each pixel
|
||||||
|
for (int y = 0; y < Y_DIM; y++) { // step through the columns x
|
||||||
|
for (int x = 0; x < X_DIM; x++) { //step through the rows y
|
||||||
|
if(isOrig) log.trace("'{}.' Image (x,y,c): ({}, {}, {}) with INDArray with index {} and value '{}'", batchElement, x, y, c, pxl, tensor.getFloat(pxl));
|
||||||
|
bi.getRaster().setSample(x, y, c, tensor.getFloat(pxl));
|
||||||
|
pxl++; //next item in INDArray
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
ImageIcon orig = new ImageIcon(bi);
|
ImageIcon orig = new ImageIcon(bi);
|
||||||
Image imageScaled = orig.getImage().getScaledInstance((8 * 28), (8 * 28), Image.SCALE_REPLICATE);
|
|
||||||
|
Image imageScaled = orig.getImage().getScaledInstance((4 * X_DIM), (4 * Y_DIM), Image.SCALE_DEFAULT);
|
||||||
|
|
||||||
ImageIcon scaled = new ImageIcon(imageScaled);
|
ImageIcon scaled = new ImageIcon(imageScaled);
|
||||||
|
|
||||||
return new JLabel(scaled);
|
return new JLabel(scaled);
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
|
@ -0,0 +1,49 @@
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# ******************************************************************************
|
||||||
|
# *
|
||||||
|
# * This program and the accompanying materials are made available under the
|
||||||
|
# * terms of the Apache License, Version 2.0 which is available at
|
||||||
|
# * https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
# *
|
||||||
|
# * See the NOTICE file distributed with this work for additional
|
||||||
|
# * information regarding copyright ownership.
|
||||||
|
# * Unless required by applicable law or agreed to in writing, software
|
||||||
|
# * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
# * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
# * License for the specific language governing permissions and limitations
|
||||||
|
# * under the License.
|
||||||
|
# *
|
||||||
|
# * SPDX-License-Identifier: Apache-2.0
|
||||||
|
# *****************************************************************************
|
||||||
|
#
|
||||||
|
#
|
||||||
|
|
||||||
|
# SLF4J's SimpleLogger configuration file
|
||||||
|
# Simple implementation of Logger that sends all enabled log messages, for all defined loggers, to System.err.
|
||||||
|
|
||||||
|
# Default logging detail level for all instances of SimpleLogger.
|
||||||
|
# Must be one of ("trace", "debug", "info", "warn", or "error").
|
||||||
|
# If not specified, defaults to "info".
|
||||||
|
org.slf4j.simpleLogger.defaultLogLevel=trace
|
||||||
|
|
||||||
|
# Logging detail level for a SimpleLogger instance named "xxxxx".
|
||||||
|
# Must be one of ("trace", "debug", "info", "warn", or "error").
|
||||||
|
# If not specified, the default logging detail level is used.
|
||||||
|
#org.slf4j.simpleLogger.log.xxxxx=
|
||||||
|
#org.slf4j.simpleLogger.log.net.brutex.cavis.backend.cavisrest.JWTAuthenticationFilter=warn
|
||||||
|
|
||||||
|
# Set to true if you want the current date and time to be included in output messages.
|
||||||
|
# Default is false, and will output the number of milliseconds elapsed since startup.
|
||||||
|
#org.slf4j.simpleLogger.showDateTime=false
|
||||||
|
|
||||||
|
# The date and time format to be used in the output messages.
|
||||||
|
# The pattern describing the date and time format is the same that is used in java.text.SimpleDateFormat.
|
||||||
|
# If the format is not specified or is invalid, the default format is used.
|
||||||
|
# The default format is yyyy-MM-dd HH:mm:ss:SSS Z.
|
||||||
|
#org.slf4j.simpleLogger.dateTimeFormat=yyyy-MM-dd HH:mm:ss:SSS Z
|
||||||
|
org.slf4j.simpleLogger.dateTimeFormat=yyyy-MM-dd HH:mm:ss
|
||||||
|
|
||||||
|
# Set to true if you want to output the current thread name.
|
||||||
|
# Defaults to true.
|
||||||
|
org.slf4j.simpleLogger.showThreadName=true
|
|
@ -30,6 +30,8 @@ ext {
|
||||||
|
|
||||||
def netty = [version: "4.1.68.Final"]
|
def netty = [version: "4.1.68.Final"]
|
||||||
|
|
||||||
|
def okhttp3 = [version: "4.10.0"]
|
||||||
|
|
||||||
|
|
||||||
javaPlatform {
|
javaPlatform {
|
||||||
allowDependencies()
|
allowDependencies()
|
||||||
|
@ -40,12 +42,16 @@ dependencies {
|
||||||
api enforcedPlatform("io.netty:netty-bom:${netty.version}")
|
api enforcedPlatform("io.netty:netty-bom:${netty.version}")
|
||||||
api enforcedPlatform("com.fasterxml.jackson:jackson-bom:${jackson.version}")
|
api enforcedPlatform("com.fasterxml.jackson:jackson-bom:${jackson.version}")
|
||||||
//api enforcedPlatform("com.fasterxml.jackson.core:jackson-annotations:${jackson.version}")
|
//api enforcedPlatform("com.fasterxml.jackson.core:jackson-annotations:${jackson.version}")
|
||||||
|
api enforcedPlatform("com.squareup.okhttp3:okhttp-bom:${okhttp3.version}")
|
||||||
|
|
||||||
|
|
||||||
constraints {
|
constraints {
|
||||||
api enforcedPlatform("io.netty:netty-bom:${netty.version}")
|
api enforcedPlatform("io.netty:netty-bom:${netty.version}")
|
||||||
api enforcedPlatform("com.fasterxml.jackson:jackson-bom:${jackson.version}")
|
api enforcedPlatform("com.fasterxml.jackson:jackson-bom:${jackson.version}")
|
||||||
|
api enforcedPlatform("com.squareup.okhttp3:okhttp-bom:${okhttp3.version}")
|
||||||
//api enforcedPlatform("com.fasterxml.jackson.core:jackson-annotations:${jackson.version}")
|
//api enforcedPlatform("com.fasterxml.jackson.core:jackson-annotations:${jackson.version}")
|
||||||
|
//api "com.squareup.okhttp3:okhttp:${okhttp3}.version"
|
||||||
|
//api "com.squareup.okhttp3:logging-interceptor:${okhttp3}.version"
|
||||||
|
|
||||||
api 'com.google.guava:guava:30.1-jre'
|
api 'com.google.guava:guava:30.1-jre'
|
||||||
api "com.google.protobuf:protobuf-java:3.15.6"
|
api "com.google.protobuf:protobuf-java:3.15.6"
|
||||||
|
@ -157,6 +163,7 @@ dependencies {
|
||||||
|
|
||||||
api "org.agrona:agrona:1.12.0"
|
api "org.agrona:agrona:1.12.0"
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -22,6 +22,7 @@ package org.datavec.image.transform;
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.EqualsAndHashCode;
|
import lombok.EqualsAndHashCode;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.bytedeco.javacv.OpenCVFrameConverter;
|
import org.bytedeco.javacv.OpenCVFrameConverter;
|
||||||
import org.datavec.image.data.ImageWritable;
|
import org.datavec.image.data.ImageWritable;
|
||||||
import com.fasterxml.jackson.annotation.JsonInclude;
|
import com.fasterxml.jackson.annotation.JsonInclude;
|
||||||
|
@ -35,6 +36,7 @@ import static org.bytedeco.opencv.global.opencv_imgproc.*;
|
||||||
@JsonInclude(JsonInclude.Include.NON_NULL)
|
@JsonInclude(JsonInclude.Include.NON_NULL)
|
||||||
@Data
|
@Data
|
||||||
@EqualsAndHashCode(callSuper = false)
|
@EqualsAndHashCode(callSuper = false)
|
||||||
|
@Slf4j
|
||||||
public class ColorConversionTransform extends BaseImageTransform {
|
public class ColorConversionTransform extends BaseImageTransform {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -85,14 +87,18 @@ public class ColorConversionTransform extends BaseImageTransform {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
Mat mat = (Mat) converter.convert(image.getFrame());
|
Mat mat = (Mat) converter.convert(image.getFrame());
|
||||||
|
|
||||||
Mat result = new Mat();
|
Mat result = new Mat();
|
||||||
|
|
||||||
|
if(mat.type() != result.type() ) {
|
||||||
try {
|
try {
|
||||||
cvtColor(mat, result, conversionCode);
|
cvtColor(mat, result, conversionCode);
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
throw new RuntimeException(e);
|
throw new RuntimeException(e);
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
log.debug("Image is already at type {}. No conversion done.", mat.type());
|
||||||
|
return image;
|
||||||
|
}
|
||||||
|
|
||||||
return new ImageWritable(converter.convert(result));
|
return new ImageWritable(converter.convert(result));
|
||||||
}
|
}
|
||||||
|
|
|
@ -85,6 +85,7 @@ public class ShowImageTransform extends BaseImageTransform {
|
||||||
if (!canvas.isVisible()) {
|
if (!canvas.isVisible()) {
|
||||||
return image;
|
return image;
|
||||||
}
|
}
|
||||||
|
|
||||||
Frame frame = image.getFrame();
|
Frame frame = image.getFrame();
|
||||||
canvas.setCanvasSize(frame.imageWidth, frame.imageHeight);
|
canvas.setCanvasSize(frame.imageWidth, frame.imageHeight);
|
||||||
canvas.showImage(frame);
|
canvas.showImage(frame);
|
||||||
|
|
|
@ -5171,22 +5171,22 @@ public class Nd4j {
|
||||||
Class<? extends DistributionFactory> distributionFactoryClazz = ND4JClassLoading.loadClassByName(clazzName);
|
Class<? extends DistributionFactory> distributionFactoryClazz = ND4JClassLoading.loadClassByName(clazzName);
|
||||||
|
|
||||||
|
|
||||||
memoryManager = memoryManagerClazz.newInstance();
|
memoryManager = memoryManagerClazz.getDeclaredConstructor().newInstance();
|
||||||
constantHandler = constantProviderClazz.newInstance();
|
constantHandler = constantProviderClazz.getDeclaredConstructor().newInstance();
|
||||||
shapeInfoProvider = shapeInfoProviderClazz.newInstance();
|
shapeInfoProvider = shapeInfoProviderClazz.getDeclaredConstructor().newInstance();
|
||||||
workspaceManager = workspaceManagerClazz.newInstance();
|
workspaceManager = workspaceManagerClazz.getDeclaredConstructor().newInstance();
|
||||||
|
|
||||||
Class<? extends OpExecutioner> opExecutionerClazz = ND4JClassLoading
|
Class<? extends OpExecutioner> opExecutionerClazz = ND4JClassLoading
|
||||||
.loadClassByName(pp.toString(OP_EXECUTIONER, DefaultOpExecutioner.class.getName()));
|
.loadClassByName(pp.toString(OP_EXECUTIONER, DefaultOpExecutioner.class.getName()));
|
||||||
|
|
||||||
OP_EXECUTIONER_INSTANCE = opExecutionerClazz.newInstance();
|
OP_EXECUTIONER_INSTANCE = opExecutionerClazz.getDeclaredConstructor().newInstance();
|
||||||
Constructor c2 = ndArrayFactoryClazz.getConstructor(DataType.class, char.class);
|
Constructor c2 = ndArrayFactoryClazz.getConstructor(DataType.class, char.class);
|
||||||
INSTANCE = (NDArrayFactory) c2.newInstance(dtype, ORDER);
|
INSTANCE = (NDArrayFactory) c2.newInstance(dtype, ORDER);
|
||||||
CONVOLUTION_INSTANCE = convolutionInstanceClazz.newInstance();
|
CONVOLUTION_INSTANCE = convolutionInstanceClazz.getDeclaredConstructor().newInstance();
|
||||||
BLAS_WRAPPER_INSTANCE = blasWrapperClazz.newInstance();
|
BLAS_WRAPPER_INSTANCE = blasWrapperClazz.getDeclaredConstructor().newInstance();
|
||||||
DATA_BUFFER_FACTORY_INSTANCE = dataBufferFactoryClazz.newInstance();
|
DATA_BUFFER_FACTORY_INSTANCE = dataBufferFactoryClazz.getDeclaredConstructor().newInstance();
|
||||||
|
|
||||||
DISTRIBUTION_FACTORY = distributionFactoryClazz.newInstance();
|
DISTRIBUTION_FACTORY = distributionFactoryClazz.getDeclaredConstructor().newInstance();
|
||||||
|
|
||||||
if (isFallback()) {
|
if (isFallback()) {
|
||||||
fallbackMode.set(true);
|
fallbackMode.set(true);
|
||||||
|
|
|
@ -58,11 +58,13 @@ public final class ND4JClassLoading {
|
||||||
|
|
||||||
@SuppressWarnings("unchecked")
|
@SuppressWarnings("unchecked")
|
||||||
public static <T> Class<T> loadClassByName(String className, boolean initialize, ClassLoader classLoader) {
|
public static <T> Class<T> loadClassByName(String className, boolean initialize, ClassLoader classLoader) {
|
||||||
|
|
||||||
try {
|
try {
|
||||||
log.info(String.format("Trying to load class [%s]", className));
|
Class<T> clazz = (Class<T>) Class.forName(className, initialize, classLoader);
|
||||||
return (Class<T>) Class.forName(className, initialize, classLoader);
|
log.info(String.format("Trying to load class [%s] - Success", className));
|
||||||
|
return clazz;
|
||||||
} catch (ClassNotFoundException classNotFoundException) {
|
} catch (ClassNotFoundException classNotFoundException) {
|
||||||
log.error(String.format("Cannot find class [%s] of provided class-loader.", className));
|
log.error(String.format("Trying to load class [%s] - Failure: Cannot find class with provided class-loader.", className));
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -21,6 +21,8 @@
|
||||||
apply from: "${project.rootProject.projectDir}/createTestBackends.gradle"
|
apply from: "${project.rootProject.projectDir}/createTestBackends.gradle"
|
||||||
|
|
||||||
dependencies {
|
dependencies {
|
||||||
|
implementation platform(projects.cavisCommonPlatform)
|
||||||
|
|
||||||
implementation projects.cavisDnn.cavisDnnData.cavisDnnDataUtilityIterators
|
implementation projects.cavisDnn.cavisDnnData.cavisDnnDataUtilityIterators
|
||||||
implementation 'org.lucee:oswego-concurrent:1.3.4'
|
implementation 'org.lucee:oswego-concurrent:1.3.4'
|
||||||
implementation projects.cavisDnn.cavisDnnCommon
|
implementation projects.cavisDnn.cavisDnnCommon
|
||||||
|
@ -50,4 +52,9 @@ dependencies {
|
||||||
implementation "com.fasterxml.jackson.core:jackson-databind"
|
implementation "com.fasterxml.jackson.core:jackson-databind"
|
||||||
implementation "com.fasterxml.jackson.dataformat:jackson-dataformat-yaml"
|
implementation "com.fasterxml.jackson.dataformat:jackson-dataformat-yaml"
|
||||||
implementation "com.jakewharton.byteunits:byteunits:0.9.1"
|
implementation "com.jakewharton.byteunits:byteunits:0.9.1"
|
||||||
|
|
||||||
|
//Rest Client
|
||||||
|
// define any required OkHttp artifacts without version
|
||||||
|
implementation "com.squareup.okhttp3:okhttp"
|
||||||
|
implementation "com.squareup.okhttp3:logging-interceptor"
|
||||||
}
|
}
|
|
@ -215,7 +215,7 @@ public abstract class Layer implements TrainingConfig, Serializable, Cloneable {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get the updater for the given parameter. Typically the same updater will be used for all
|
* Get the updater for the given parameter. Typically the same updater will be used for all
|
||||||
* updaters, but this is not necessarily the case
|
* parameters, but this is not necessarily the case
|
||||||
*
|
*
|
||||||
* @param paramName Parameter name
|
* @param paramName Parameter name
|
||||||
* @return IUpdater for the parameter
|
* @return IUpdater for the parameter
|
||||||
|
|
|
@ -30,6 +30,7 @@ import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||||
* @author Adam Gibson
|
* @author Adam Gibson
|
||||||
*/
|
*/
|
||||||
public class DenseLayer extends BaseLayer<org.deeplearning4j.nn.conf.layers.DenseLayer> {
|
public class DenseLayer extends BaseLayer<org.deeplearning4j.nn.conf.layers.DenseLayer> {
|
||||||
|
|
||||||
public DenseLayer(NeuralNetConfiguration conf, DataType dataType) {
|
public DenseLayer(NeuralNetConfiguration conf, DataType dataType) {
|
||||||
super(conf, dataType);
|
super(conf, dataType);
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,62 @@
|
||||||
|
/*
|
||||||
|
*
|
||||||
|
* ******************************************************************************
|
||||||
|
* *
|
||||||
|
* * This program and the accompanying materials are made available under the
|
||||||
|
* * terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
* *
|
||||||
|
* * See the NOTICE file distributed with this work for additional
|
||||||
|
* * information regarding copyright ownership.
|
||||||
|
* * Unless required by applicable law or agreed to in writing, software
|
||||||
|
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* * License for the specific language governing permissions and limitations
|
||||||
|
* * under the License.
|
||||||
|
* *
|
||||||
|
* * SPDX-License-Identifier: Apache-2.0
|
||||||
|
* *****************************************************************************
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.deeplearning4j.optimize.listeners;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import okhttp3.OkHttpClient;
|
||||||
|
import okhttp3.Request;
|
||||||
|
import okhttp3.Response;
|
||||||
|
import org.deeplearning4j.nn.api.Model;
|
||||||
|
import org.deeplearning4j.optimize.api.BaseTrainingListener;
|
||||||
|
|
||||||
|
@Slf4j
|
||||||
|
public class ScoreToChartListener extends BaseTrainingListener {
|
||||||
|
|
||||||
|
final String url = "http://bru5:8080/cavis-rest-1.0-SNAPSHOT.war/hello/hello-world?";
|
||||||
|
final String seriesName;
|
||||||
|
|
||||||
|
public ScoreToChartListener(String seriesName) {
|
||||||
|
this.seriesName = seriesName;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void iterationDone(Model model, int iteration, int epoch) {
|
||||||
|
double score = model.score();
|
||||||
|
String nurl = url+"s="+score+"&n="+seriesName;
|
||||||
|
OkHttpClient client = new OkHttpClient();
|
||||||
|
|
||||||
|
Request request = new Request.Builder()
|
||||||
|
.url(nurl)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
try {
|
||||||
|
Response response = client.newCall(request).execute();
|
||||||
|
log.debug(String.format("Did send score to chart at '%s'.", nurl));
|
||||||
|
response.body().close();
|
||||||
|
} catch (IOException e) {
|
||||||
|
log.warn(String.format("Could not send score to chart at '%s' because %s", nurl, e.getMessage()));
|
||||||
|
}
|
||||||
|
//response.body().string();
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -37,7 +37,6 @@ public class NativeOpsGPUInfoProvider implements GPUInfoProvider {
|
||||||
|
|
||||||
List<GPUInfo> gpus = new ArrayList<>();
|
List<GPUInfo> gpus = new ArrayList<>();
|
||||||
|
|
||||||
|
|
||||||
int nDevices = nativeOps.getAvailableDevices();
|
int nDevices = nativeOps.getAvailableDevices();
|
||||||
if (nDevices > 0) {
|
if (nDevices > 0) {
|
||||||
for (int i = 0; i < nDevices; i++) {
|
for (int i = 0; i < nDevices; i++) {
|
||||||
|
|
|
@ -83,7 +83,7 @@ import java.util.concurrent.locks.ReentrantReadWriteLock;
|
||||||
*/
|
*/
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class AtomicAllocator implements Allocator {
|
public class AtomicAllocator implements Allocator {
|
||||||
private static final AtomicAllocator INSTANCE = new AtomicAllocator();
|
private static AtomicAllocator INSTANCE = new AtomicAllocator();
|
||||||
|
|
||||||
private Configuration configuration;
|
private Configuration configuration;
|
||||||
|
|
||||||
|
@ -122,6 +122,7 @@ public class AtomicAllocator implements Allocator {
|
||||||
private final AtomicLong useTracker = new AtomicLong(System.currentTimeMillis());
|
private final AtomicLong useTracker = new AtomicLong(System.currentTimeMillis());
|
||||||
|
|
||||||
public static AtomicAllocator getInstance() {
|
public static AtomicAllocator getInstance() {
|
||||||
|
if(INSTANCE == null) INSTANCE = new AtomicAllocator();
|
||||||
if (INSTANCE == null)
|
if (INSTANCE == null)
|
||||||
throw new RuntimeException("AtomicAllocator is NULL");
|
throw new RuntimeException("AtomicAllocator is NULL");
|
||||||
return INSTANCE;
|
return INSTANCE;
|
||||||
|
|
|
@ -402,6 +402,10 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda
|
||||||
val ctx = AtomicAllocator.getInstance().getDeviceContext();
|
val ctx = AtomicAllocator.getInstance().getDeviceContext();
|
||||||
val devicePtr = allocationPoint.getDevicePointer();
|
val devicePtr = allocationPoint.getDevicePointer();
|
||||||
NativeOpsHolder.getInstance().getDeviceNativeOps().memsetAsync(devicePtr, 0, length * elementSize, 0, ctx.getSpecialStream());
|
NativeOpsHolder.getInstance().getDeviceNativeOps().memsetAsync(devicePtr, 0, length * elementSize, 0, ctx.getSpecialStream());
|
||||||
|
int ec = NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorCode();
|
||||||
|
if(ec != 0) {
|
||||||
|
throw new RuntimeException(NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorMessage());
|
||||||
|
}
|
||||||
ctx.getSpecialStream().synchronize();
|
ctx.getSpecialStream().synchronize();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,55 @@
|
||||||
|
/*
|
||||||
|
*
|
||||||
|
* ******************************************************************************
|
||||||
|
* *
|
||||||
|
* * This program and the accompanying materials are made available under the
|
||||||
|
* * terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
* *
|
||||||
|
* * See the NOTICE file distributed with this work for additional
|
||||||
|
* * information regarding copyright ownership.
|
||||||
|
* * Unless required by applicable law or agreed to in writing, software
|
||||||
|
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* * License for the specific language governing permissions and limitations
|
||||||
|
* * under the License.
|
||||||
|
* *
|
||||||
|
* * SPDX-License-Identifier: Apache-2.0
|
||||||
|
* *****************************************************************************
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.nd4j.linalg.jcublas.buffer;
|
||||||
|
|
||||||
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.bytedeco.javacpp.BytePointer;
|
||||||
|
import org.bytedeco.javacpp.Pointer;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
import org.nd4j.linalg.api.buffer.DataBuffer;
|
||||||
|
import org.nd4j.linalg.api.environment.Nd4jEnvironment;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
import org.nd4j.nativeblas.NativeOpsHolder;
|
||||||
|
|
||||||
|
@Slf4j
|
||||||
|
class BaseCudaDataBufferTest {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testMemoryAlloc() throws InterruptedException {
|
||||||
|
BaseCudaDataBuffer cuBuffer = new CudaLongDataBuffer(16l);
|
||||||
|
log.info(
|
||||||
|
"Allocation Status: " + cuBuffer.getAllocationPoint().getAllocationStatus().toString());
|
||||||
|
Thread.sleep(3000);
|
||||||
|
cuBuffer.getAllocationPoint().tickDeviceWrite();
|
||||||
|
DataBuffer buf = Nd4j.rand(8,1).shapeInfoDataBuffer();
|
||||||
|
NativeOpsHolder.getInstance().getDeviceNativeOps().memcpySync(cuBuffer.pointer(), buf.pointer(), 8, 0, new Pointer() );
|
||||||
|
|
||||||
|
log.info(
|
||||||
|
"Allocation Status: " + cuBuffer.getAllocationPoint().getAllocationStatus().toString());
|
||||||
|
|
||||||
|
cuBuffer.release();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
}
|
Loading…
Reference in New Issue