diff --git a/brutex-extended-tests/build.gradle b/brutex-extended-tests/build.gradle
index bd53f61bd..c15f6d325 100644
--- a/brutex-extended-tests/build.gradle
+++ b/brutex-extended-tests/build.gradle
@@ -19,8 +19,12 @@
*
*/
-apply plugin: 'java'
-apply plugin: 'maven-publish'
+plugins {
+ // 'java-library' takes the place of the previously applied 'java' plugin
+ id 'java-library'
+ id 'maven-publish'
+ // provides the shadowJar task that is configured further down in this script
+ id 'com.github.johnrengelman.shadow' version '7.1.2'
+}
+
apply from: "${project.rootProject.projectDir}/createTestBackends.gradle"
@@ -54,6 +58,7 @@ dependencies {
implementation projects.cavisDnn.cavisDnnSpark.cavisDnnSparkParameterserver
implementation projects.cavisDnn.cavisDnnNnParent.cavisDnnNnCore
implementation projects.cavisDnn.cavisDnnNn
+
implementation projects.cavisUi.cavisUiCommon
implementation projects.cavisUi.cavisUiVertx
implementation projects.cavisUi.cavisUiModel
@@ -66,11 +71,21 @@ dependencies {
implementation projects.cavisDnn.cavisDnnParallelwrapper
implementation projects.cavisZoo.cavisZooModels
-
testRuntimeOnly "net.brutex.ai:dl4j-test-resources:1.0.1-SNAPSHOT"
}
+
test {
- dependsOn jar
+ // the shadowJar task has to run before the tests are executed
+ enabled true
+ dependsOn shadowJar
}
+
+shadowJar {
+    enabled = true
+    // zip64 is needed to support jars with more than 65535 entries
+    zip64 = true
+    archiveClassifier = 'all'
+    // also package the compiled test classes into the shaded jar
+    from(sourceSets.test.output)
+}
+
+
diff --git a/brutex-extended-tests/src/test/java/net/brutex/gan/App.java b/brutex-extended-tests/src/test/java/net/brutex/gan/App.java
new file mode 100644
index 000000000..f4feb6fdf
--- /dev/null
+++ b/brutex-extended-tests/src/test/java/net/brutex/gan/App.java
@@ -0,0 +1,279 @@
+/*
+ *
+ * ******************************************************************************
+ * *
+ * * This program and the accompanying materials are made available under the
+ * * terms of the Apache License, Version 2.0 which is available at
+ * * https://www.apache.org/licenses/LICENSE-2.0.
+ * *
+ * * See the NOTICE file distributed with this work for additional
+ * * information regarding copyright ownership.
+ * * Unless required by applicable law or agreed to in writing, software
+ * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * * License for the specific language governing permissions and limitations
+ * * under the License.
+ * *
+ * * SPDX-License-Identifier: Apache-2.0
+ * *****************************************************************************
+ *
+ */
+
+package net.brutex.gan;
+
+import org.apache.commons.lang3.ArrayUtils;
+import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
+import org.deeplearning4j.nn.conf.GradientNormalization;
+import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
+import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
+import org.deeplearning4j.nn.conf.layers.*;
+import org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop;
+import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
+import org.deeplearning4j.nn.weights.WeightInit;
+import org.deeplearning4j.optimize.listeners.PerformanceListener;
+import org.junit.jupiter.api.Test;
+import org.nd4j.linalg.activations.Activation;
+import org.nd4j.linalg.activations.impl.ActivationLReLU;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.dataset.DataSet;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.learning.config.Adam;
+import org.nd4j.linalg.learning.config.IUpdater;
+import org.nd4j.linalg.lossfunctions.LossFunctions;
+
+import javax.swing.*;
+import java.awt.*;
+import java.awt.image.BufferedImage;
+import java.io.File;
+import java.util.Arrays;
+
+public class App {
+ private static final double LEARNING_RATE = 0.0002;
+ private static final double GRADIENT_THRESHOLD = 100.0;
+ private static final IUpdater UPDATER = Adam.builder().learningRate(LEARNING_RATE).beta1(0.5).build();
+
+ private static JFrame frame;
+ private static JPanel panel;
+
+ private static Layer[] genLayers() {
+ return new Layer[] {
+ new DenseLayer.Builder().nIn(100).nOut(256).weightInit(WeightInit.NORMAL).build(),
+ new ActivationLayer.Builder(new ActivationLReLU(0.2)).build(),
+ new DenseLayer.Builder().nIn(256).nOut(512).build(),
+ new ActivationLayer.Builder(new ActivationLReLU(0.2)).build(),
+ new DenseLayer.Builder().nIn(512).nOut(1024).build(),
+ new ActivationLayer.Builder(new ActivationLReLU(0.2)).build(),
+ new DenseLayer.Builder().nIn(1024).nOut(784).activation(Activation.TANH).build()
+ };
+ }
+
+    /**
+     * Creates the configuration of the stand-alone generator network, which takes a 100-element
+     * input vector and produces 784 output values (one flattened 28x28 grayscale image).
+     *
+     * @return the generator configuration built from {@link #genLayers()}
+     */
+    private static MultiLayerConfiguration generator() {
+        return new NeuralNetConfiguration.Builder()
+                .seed(42)
+                .updater(UPDATER)
+                .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
+                .gradientNormalizationThreshold(GRADIENT_THRESHOLD)
+                .weightInit(WeightInit.XAVIER)
+                .activation(Activation.IDENTITY)
+                .list(genLayers())
+                .build();
+    }
+
+ private static Layer[] disLayers() {
+ return new Layer[]{
+ new DenseLayer.Builder().nIn(784).nOut(1024).build(),
+ new ActivationLayer.Builder(new ActivationLReLU(0.2)).build(),
+ new DropoutLayer.Builder(1 - 0.5).build(),
+ new DenseLayer.Builder().nIn(1024).nOut(512).build(),
+ new ActivationLayer.Builder(new ActivationLReLU(0.2)).build(),
+ new DropoutLayer.Builder(1 - 0.5).build(),
+ new DenseLayer.Builder().nIn(512).nOut(256).build(),
+ new ActivationLayer.Builder(new ActivationLReLU(0.2)).build(),
+ new DropoutLayer.Builder(1 - 0.5).build(),
+ new OutputLayer.Builder(LossFunctions.LossFunction.XENT).nIn(256).nOut(1).activation(Activation.SIGMOID).build()
+ };
+ }
+
+    /**
+     * Creates the configuration of the stand-alone discriminator network.
+     *
+     * @return the discriminator configuration built from {@link #disLayers()}
+     */
+    private static MultiLayerConfiguration discriminator() {
+        return new NeuralNetConfiguration.Builder()
+                .seed(42)
+                .updater(UPDATER)
+                .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
+                .gradientNormalizationThreshold(GRADIENT_THRESHOLD)
+                .weightInit(WeightInit.XAVIER)
+                .activation(Activation.IDENTITY)
+                .list(disLayers())
+                .build();
+    }
+
+ private static MultiLayerConfiguration gan() {
+ Layer[] genLayers = genLayers();
+ Layer[] disLayers = Arrays.stream(disLayers())
+ .map((layer) -> {
+ if (layer instanceof DenseLayer || layer instanceof OutputLayer) {
+ return new FrozenLayerWithBackprop(layer);
+ } else {
+ return layer;
+ }
+ }).toArray(Layer[]::new);
+ Layer[] layers = ArrayUtils.addAll(genLayers, disLayers);
+
+ MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
+ .seed(42)
+ .updater(UPDATER)
+ .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
+ .gradientNormalizationThreshold(GRADIENT_THRESHOLD)
+ .weightInit(WeightInit.XAVIER)
+ .activation(Activation.IDENTITY)
+ .list(layers)
+ .build();
+
+ return conf;
+ }
+
+
+    /** Test entry point; runs {@link #main(String...)} without arguments. */
+    @Test
+    public void runTest() throws Exception {
+        App.main();
+    }
+
+ /**
+ * Trains the MLP GAN on MNIST for 10 passes over the data, periodically visualizes nine generated
+ * samples in a Swing window, and finally saves the generator to "mnist-mlp-generator.dlj".
+ *
+ * @param args not used
+ * @throws Exception if data loading, training or saving the model fails
+ */
+ public static void main(String... args) throws Exception {
+ // NOTE(review): presumably the auto-GC window is given in milliseconds (15 s) - confirm
+ Nd4j.getMemoryManager().setAutoGcWindow(15 * 1000);
+
+ MnistDataSetIterator trainData = new MnistDataSetIterator(128, true, 42);
+
+ MultiLayerNetwork gen = new MultiLayerNetwork(generator());
+ MultiLayerNetwork dis = new MultiLayerNetwork(discriminator());
+ MultiLayerNetwork gan = new MultiLayerNetwork(gan());
+ gen.init();
+ dis.init();
+ gan.init();
+
+ // Overwrite the parameters of gen and dis with those of the matching gan layers
+ copyParams(gen, dis, gan);
+
+ gen.setListeners(new PerformanceListener(10, true));
+ dis.setListeners(new PerformanceListener(10, true));
+ gan.setListeners(new PerformanceListener(10, true));
+
+ trainData.reset();
+
+ // j counts the processed batches across all passes
+ int j = 0;
+ for (int i = 0; i < 10; i++) {
+ while (trainData.hasNext()) {
+ j++;
+
+ // generate data
+ // Real features are rescaled in place (x * 2 - 1); presumably from [0, 1] to [-1, 1]
+ INDArray real = trainData.next().getFeatures().muli(2).subi(1);
+ int batchSize = (int) real.shape()[0];
+
+ // Fake images come from the generator part of gan (its first gen.getLayers().length layers)
+ INDArray fakeIn = Nd4j.rand(batchSize, 100);
+ INDArray fake = gan.activateSelectedLayers(0, gen.getLayers().length - 1, fakeIn);
+
+ // Labels: real images are 0, fake images are 1
+ DataSet realSet = new DataSet(real, Nd4j.zeros(batchSize, 1));
+ DataSet fakeSet = new DataSet(fake, Nd4j.ones(batchSize, 1));
+
+ DataSet data = DataSet.merge(Arrays.asList(realSet, fakeSet));
+
+ // The discriminator is fitted twice on the same merged batch
+ dis.fit(data);
+ dis.fit(data);
+
+ // Update the discriminator in the GAN network
+ updateGan(gen, dis, gan);
+
+ // Train the combined network on fresh noise labelled 0, i.e. with the label used for real images
+ gan.fit(new DataSet(Nd4j.rand(batchSize, 100), Nd4j.zeros(batchSize, 1)));
+
+
+ // Visualize on batches 1, 11, 21, ...
+ if (j % 10 == 1) {
+ System.out.println("Iteration " + j + " Visualizing...");
+ INDArray[] samples = new INDArray[9];
+ DataSet fakeSet2 = new DataSet(fakeIn, Nd4j.ones(batchSize, 1));
+
+ // Re-use the first nine noise rows of this batch as generator input
+ for (int k = 0; k < 9; k++) {
+ INDArray input = fakeSet2.get(k).getFeatures();
+ //samples[k] = gen.output(input, false);
+ samples[k] = gan.activateSelectedLayers(0, gen.getLayers().length - 1, input);
+
+ }
+ visualize(samples);
+ }
+ }
+ trainData.reset();
+ }
+
+ // Copy the GANs generator to gen.
+ updateGen(gen, gan);
+
+ gen.save(new File("mnist-mlp-generator.dlj"));
+ }
+
+ private static void copyParams(MultiLayerNetwork gen, MultiLayerNetwork dis, MultiLayerNetwork gan) {
+ int genLayerCount = gen.getLayers().length;
+ for (int i = 0; i < gan.getLayers().length; i++) {
+ if (i < genLayerCount) {
+ gen.getLayer(i).setParams(gan.getLayer(i).params());
+ } else {
+ dis.getLayer(i - genLayerCount).setParams(gan.getLayer(i).params());
+ }
+ }
+ }
+
+    /**
+     * Copies the parameters of the leading {@code gan} layers into the layers of {@code gen}.
+     *
+     * @param gen stand-alone generator whose parameters are overwritten
+     * @param gan combined network the parameters are read from
+     */
+    private static void updateGen(MultiLayerNetwork gen, MultiLayerNetwork gan) {
+        int idx = 0;
+        while (idx < gen.getLayers().length) {
+            gen.getLayer(idx).setParams(gan.getLayer(idx).params());
+            idx++;
+        }
+    }
+
+ private static void updateGan(MultiLayerNetwork gen, MultiLayerNetwork dis, MultiLayerNetwork gan) {
+ int genLayerCount = gen.getLayers().length;
+ for (int i = genLayerCount; i < gan.getLayers().length; i++) {
+ gan.getLayer(i).setParams(dis.getLayer(i - genLayerCount).params());
+ }
+ }
+
+ private static void visualize(INDArray[] samples) {
+ if (frame == null) {
+ frame = new JFrame();
+ frame.setTitle("Viz");
+ frame.setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE);
+ frame.setLayout(new BorderLayout());
+
+ panel = new JPanel();
+
+ panel.setLayout(new GridLayout(samples.length / 3, 1, 8, 8));
+ frame.add(panel, BorderLayout.CENTER);
+ frame.setVisible(true);
+ }
+
+ panel.removeAll();
+
+ for (INDArray sample : samples) {
+ panel.add(getImage(sample));
+ }
+
+ frame.revalidate();
+ frame.pack();
+ }
+
+    /**
+     * Renders a flattened 28x28 sample as a grayscale image label, scaled up by a factor of 8.
+     *
+     * <p>Sample values are expected in [-1, 1] (the generator ends in a tanh activation) and are
+     * mapped linearly to [0, 255]. The earlier mapping multiplied by 2 instead of dividing by 2,
+     * yielding values up to 1020 that do not fit into the 8-bit gray raster.
+     *
+     * @param tensor sample from which the first 784 values are read by linear index
+     * @return a label showing the 224x224 scaled image
+     */
+    private static JLabel getImage(INDArray tensor) {
+        BufferedImage bi = new BufferedImage(28, 28, BufferedImage.TYPE_BYTE_GRAY);
+        for (int i = 0; i < 784; i++) {
+            // (v + 1) / 2 maps [-1, 1] to [0, 1]; clamp to guard against out-of-range values
+            int pixel = (int) ((tensor.getDouble(i) + 1) / 2 * 255);
+            pixel = Math.max(0, Math.min(255, pixel));
+            bi.getRaster().setSample(i % 28, i / 28, 0, pixel);
+        }
+        ImageIcon orig = new ImageIcon(bi);
+        Image imageScaled = orig.getImage().getScaledInstance((8 * 28), (8 * 28), Image.SCALE_REPLICATE);
+
+        ImageIcon scaled = new ImageIcon(imageScaled);
+
+        return new JLabel(scaled);
+    }
+}
\ No newline at end of file
diff --git a/brutex-extended-tests/src/test/java/net/brutex/gan/GAN.java b/brutex-extended-tests/src/test/java/net/brutex/gan/GAN.java
new file mode 100644
index 000000000..25473fc9e
--- /dev/null
+++ b/brutex-extended-tests/src/test/java/net/brutex/gan/GAN.java
@@ -0,0 +1,411 @@
+/*
+ *
+ * ******************************************************************************
+ * *
+ * * This program and the accompanying materials are made available under the
+ * * terms of the Apache License, Version 2.0 which is available at
+ * * https://www.apache.org/licenses/LICENSE-2.0.
+ * *
+ * * See the NOTICE file distributed with this work for additional
+ * * information regarding copyright ownership.
+ * * Unless required by applicable law or agreed to in writing, software
+ * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * * License for the specific language governing permissions and limitations
+ * * under the License.
+ * *
+ * * SPDX-License-Identifier: Apache-2.0
+ * *****************************************************************************
+ *
+ */
+
+package net.brutex.gan;
+
+import org.apache.commons.lang3.ArrayUtils;
+import org.deeplearning4j.nn.api.Layer;
+import org.deeplearning4j.nn.api.OptimizationAlgorithm;
+import org.deeplearning4j.nn.conf.*;
+import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
+import org.deeplearning4j.optimize.api.BaseTrainingListener;
+import org.nd4j.evaluation.classification.Evaluation;
+import org.nd4j.linalg.activations.Activation;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.dataset.DataSet;
+import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.learning.config.IUpdater;
+import org.nd4j.linalg.learning.config.Sgd;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Supplier;
+
+
+/**
+ * Implementation of vanilla Generative Adversarial Networks as introduced in https://arxiv.org/pdf/1406.2661.pdf.
+ *
+ * A DL4J GAN is initialized from two networks: a generator and a discriminator and will build a third network,
+ * the GAN network, from the first two.
+ *
+ * @author Max Pumperla
+ */
+public class GAN {
+ private static final IUpdater UPDATER_ZERO = Sgd.builder().learningRate(0.0).build();
+
+    /**
+     * Factory for discriminator networks. While the GAN is assembled it is invoked twice: once with
+     * the configured updater for the stand-alone discriminator, and once with a zero-learning-rate
+     * updater for the discriminator copy whose layers go into the combined network.
+     */
+    @FunctionalInterface
+    public interface DiscriminatorProvider {
+        /**
+         * Creates a new, not yet initialized discriminator network.
+         *
+         * @param updater updater the caller asks the discriminator to be built with
+         * @return the discriminator network
+         */
+        MultiLayerNetwork provide(IUpdater updater);
+    }
+
+    /** Creates the generator network; typed so that {@code get()} yields a MultiLayerNetwork. */
+    protected Supplier<MultiLayerNetwork> generatorSupplier;
+    /** Creates discriminator networks for a given updater. */
+    protected DiscriminatorProvider discriminatorSupplier;
+
+ protected MultiLayerNetwork generator;
+ protected MultiLayerNetwork discriminator;
+ protected MultiLayerNetwork gan;
+ protected int latentDim;
+
+ protected IUpdater updater;
+ protected IUpdater biasUpdater;
+ protected OptimizationAlgorithm optimizer;
+ protected GradientNormalization gradientNormalizer;
+ protected double gradientNormalizationThreshold;
+ protected WorkspaceMode trainingWorkSpaceMode;
+ protected WorkspaceMode inferenceWorkspaceMode;
+ protected CacheMode cacheMode;
+ protected long seed;
+
+ private Double[] discriminatorLearningRates;
+
+
+    /**
+     * Creates a GAN from the settings collected in the builder and immediately assembles the
+     * generator, discriminator and combined networks.
+     *
+     * @param builder source of the network factories and training settings
+     */
+    public GAN(Builder builder) {
+        // network factories and latent space size
+        this.generatorSupplier = builder.generator;
+        this.discriminatorSupplier = builder.discriminator;
+        this.latentDim = builder.latentDimension;
+
+        // training settings applied to the combined network
+        this.seed = builder.seed;
+        this.updater = builder.iUpdater;
+        this.biasUpdater = builder.biasUpdater;
+        this.optimizer = builder.optimizationAlgo;
+        this.gradientNormalizer = builder.gradientNormalization;
+        this.gradientNormalizationThreshold = builder.gradientNormalizationThreshold;
+        this.trainingWorkSpaceMode = builder.trainingWorkspaceMode;
+        this.inferenceWorkspaceMode = builder.inferenceWorkspaceMode;
+        this.cacheMode = builder.cacheMode;
+
+        defineGan();
+    }
+
+    /** @return the stand-alone generator network */
+    public MultiLayerNetwork getGenerator() {
+        return this.generator;
+    }
+
+    /** @return the stand-alone discriminator network */
+    public MultiLayerNetwork getDiscriminator() {
+        return this.discriminator;
+    }
+
+    /**
+     * Evaluates the combined GAN network on the given data.
+     *
+     * @param data data to evaluate on
+     * @return the evaluation produced by the combined network
+     */
+    public Evaluation evaluateGan(DataSetIterator data) {
+        final Evaluation result = gan.evaluate(data);
+        return result;
+    }
+
+    /**
+     * Evaluates the combined GAN network on the given data, passing the label names through to the
+     * network's evaluation.
+     *
+     * @param data       data to evaluate on
+     * @param labelsList label names handed to the evaluation
+     * @return the evaluation produced by the combined network
+     */
+    public Evaluation evaluateGan(DataSetIterator data, List<String> labelsList) {
+        return gan.evaluate(data, labelsList);
+    }
+
+
+    /**
+     * Sets the training listeners of the stand-alone generator network.
+     *
+     * @param listeners listeners to set
+     */
+    public void setGeneratorListeners(BaseTrainingListener[] listeners) {
+        this.generator.setListeners(listeners);
+    }
+
+    /**
+     * Sets the training listeners of the stand-alone discriminator network.
+     *
+     * @param listeners listeners to set
+     */
+    public void setDiscriminatorListeners(BaseTrainingListener[] listeners) {
+        this.discriminator.setListeners(listeners);
+    }
+
+    /**
+     * Sets the training listeners of the combined GAN network.
+     *
+     * @param listeners listeners to set
+     */
+    public void setGanListeners(BaseTrainingListener[] listeners) {
+        this.gan.setListeners(listeners);
+    }
+
+ public void fit(DataSetIterator realData, int numEpochs) {
+ for (int i = 0; i < numEpochs; i++) {
+ while (realData.hasNext()) {
+ // Get real images as features
+ DataSet next = realData.next();
+ fit(next);
+ }
+ realData.reset();
+ }
+ }
+
+ /**
+ * Runs one GAN training step on a single batch of real images: trains the discriminator on real
+ * and generated images, copies its parameters into the combined network, trains the combined
+ * network on fresh latent samples, and copies the generator part back into the generator.
+ *
+ * @param next batch whose features are the real images; the features are rescaled in place
+ */
+ public void fit(DataSet next) {
+ int batchSize;
+ // Rescale in place (x * 2 - 1); presumably maps features from [0, 1] to [-1, 1] - TODO confirm
+ INDArray realImages = next.getFeatures().muli(2).subi(1);
+ batchSize = (int) realImages.shape()[0];
+
+ // Sample from latent space and let the generator create fake images.
+ INDArray randomLatentData = Nd4j.rand(new int[]{batchSize, latentDim});
+ INDArray fakeImages = generator.output(randomLatentData);
+
+ // Real images are marked as "0", fake images as "1".
+ DataSet realSet = new DataSet(realImages, Nd4j.zeros(batchSize, 1));
+ DataSet fakeSet = new DataSet(fakeImages, Nd4j.ones(batchSize, 1));
+
+ // Fit the discriminator on a combined batch of real and fake images.
+ DataSet combined = DataSet.merge(Arrays.asList(realSet, fakeSet));
+
+ /*for (int i = 0; i < discriminator.getLayers().length; i++) {
+ if (discriminatorLearningRates[i] != null) {
+ discriminator.setLearningRate(i, discriminatorLearningRates[i]);
+ }
+ }*/
+
+ discriminator.fit(combined);
+ //discriminator.fit(combined);
+
+ // Update the discriminator in the GAN network
+ updateGanWithDiscriminator();
+
+ // Generate a new set of adversarial examples and try to mislead the discriminator.
+ // By labeling the fake images as real images ("0") we reward the generator when its output
+ // tricks the discriminator.
+ INDArray adversarialExamples = Nd4j.rand(new int[]{batchSize, latentDim});
+ INDArray misleadingLabels = Nd4j.zeros(batchSize, 1);
+ DataSet adversarialSet = new DataSet(adversarialExamples, misleadingLabels);
+
+ // Set learning rate of discriminator part of gan to zero.
+ /*for (int i = generator.getLayers().length; i < gan.getLayers().length; i++) {
+ gan.setLearningRate(i, 0.0);
+ }*/
+
+ // Fit the GAN on the adversarial set, trying to fool the discriminator by generating
+ // better fake images.
+ gan.fit(adversarialSet);
+
+ // Copy the GANs generator part to "generator".
+ updateGeneratorFromGan();
+ }
+
+    /**
+     * Instantiates and initializes the generator and the discriminator, then assembles the combined
+     * GAN network from the generator's layer configurations followed by those of a second
+     * discriminator instance created with a zero-learning-rate updater. Input pre-processors of both
+     * networks are carried over to the matching layer index of the combined network.
+     *
+     * <p>Finally the parameters of the freshly initialized combined network are copied into the
+     * generator and the discriminator, so all three networks start from the same values.
+     */
+    private void defineGan() {
+        generator = generatorSupplier.get();
+        generator.init();
+
+        Layer[] genLayers = generator.getLayers();
+        int numGenLayers = genLayers.length;
+
+        discriminator = discriminatorSupplier.provide(updater);
+        discriminator.init();
+
+        // Separate discriminator instance requested with a zero-learning-rate updater; only its
+        // layer configurations and pre-processors are used for the combined network.
+        MultiLayerNetwork ganDiscriminator = discriminatorSupplier.provide(UPDATER_ZERO);
+        ganDiscriminator.init();
+
+        Layer[] disLayers = ganDiscriminator.getLayers();
+        Layer[] layers = ArrayUtils.addAll(genLayers, disLayers);
+        MultiLayerConfiguration genConf = generator.getLayerWiseConfigurations();
+        MultiLayerConfiguration disConf = ganDiscriminator.getLayerWiseConfigurations();
+        org.deeplearning4j.nn.conf.layers.Layer[] confLayers = new org.deeplearning4j.nn.conf.layers.Layer[layers.length];
+
+        // Typed map (instead of a raw Map) keyed by the layer index within the combined network
+        Map<Integer, InputPreProcessor> preProcessors = new HashMap<>();
+        for (int i = 0; i < layers.length; i++) {
+            confLayers[i] = layers[i].conf().getLayer();
+            if (i < numGenLayers) {
+                preProcessors.put(i, genConf.getInputPreProcess(i));
+            } else {
+                // discriminator pre-processors are shifted by the number of generator layers
+                preProcessors.put(i, disConf.getInputPreProcess(i - numGenLayers));
+            }
+        }
+
+        MultiLayerConfiguration ganConf = new NeuralNetConfiguration.Builder()
+                .seed(seed)
+                .updater(updater)
+                .biasUpdater(biasUpdater)
+                .optimizationAlgo(optimizer)
+                .gradientNormalization(gradientNormalizer)
+                .gradientNormalizationThreshold(gradientNormalizationThreshold)
+                .activation(Activation.IDENTITY)
+                .trainingWorkspaceMode(trainingWorkSpaceMode)
+                .inferenceWorkspaceMode(inferenceWorkspaceMode)
+                .cacheMode(cacheMode)
+                .list(confLayers)
+                .inputPreProcessors(preProcessors)
+                .build();
+        gan = new MultiLayerNetwork(ganConf);
+        gan.init();
+
+        // we lose proper init here, need to copy weights after
+        copyParamsToGan();
+    }
+
+ private void copyParamsToGan() {
+ int genLayerCount = generator.getLayers().length;
+ for (int i = 0; i < gan.getLayers().length; i++) {
+ if (i < genLayerCount) {
+ generator.getLayer(i).setParams(gan.getLayer(i).params());
+ } else {
+ discriminator.getLayer(i - genLayerCount).setParams(gan.getLayer(i).params());
+ }
+ }
+ }
+
+    /**
+     * After the GAN has been trained on misleading images, copies the parameters of its generator
+     * part (the leading layers) into the stand-alone generator. The discriminator is not touched here.
+     */
+    private void updateGeneratorFromGan() {
+        int idx = 0;
+        while (idx < generator.getLayers().length) {
+            generator.getLayer(idx).setParams(gan.getLayer(idx).params());
+            idx++;
+        }
+    }
+
+ /**
+ * After the discriminator has been trained, we update the respective parts of the GAN network
+ * as well.
+ */
+ private void updateGanWithDiscriminator() {
+ int genLayerCount = generator.getLayers().length;
+ for (int i = genLayerCount; i < gan.getLayers().length; i++) {
+ gan.getLayer(i).setParams(discriminator.getLayer(i - genLayerCount).params());
+ }
+ }
+
+    /**
+     * GAN builder. Collects the generator and discriminator factories, the latent dimension and the
+     * training settings that are applied to the combined network, and creates the {@link GAN}.
+     */
+    public static class Builder implements Cloneable {
+        // Typed supplier (instead of a raw Supplier) so the generator can be used without casts.
+        protected Supplier<MultiLayerNetwork> generator;
+        protected DiscriminatorProvider discriminator;
+        protected int latentDimension;
+
+        protected IUpdater iUpdater = new Sgd();
+        protected IUpdater biasUpdater = null;
+        protected long seed = System.currentTimeMillis();
+        protected OptimizationAlgorithm optimizationAlgo = OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT;
+        protected GradientNormalization gradientNormalization = GradientNormalization.None;
+        protected double gradientNormalizationThreshold = 1.0;
+
+        protected WorkspaceMode trainingWorkspaceMode = WorkspaceMode.ENABLED;
+        protected WorkspaceMode inferenceWorkspaceMode = WorkspaceMode.ENABLED;
+        protected CacheMode cacheMode = CacheMode.NONE;
+
+
+        public Builder() {
+        }
+
+
+        /**
+         * Set the (fake) image generator of the GAN.
+         *
+         * @param generator supplier that creates the generator MultiLayerNetwork
+         * @return Builder
+         */
+        public GAN.Builder generator(Supplier<MultiLayerNetwork> generator) {
+            this.generator = generator;
+            return this;
+        }
+
+        /**
+         * Set the image discriminator of the GAN.
+         *
+         * @param discriminator provider that creates the discriminator MultiLayerNetwork
+         * @return Builder
+         */
+        public GAN.Builder discriminator(DiscriminatorProvider discriminator) {
+            this.discriminator = discriminator;
+            return this;
+        }
+
+        /**
+         * Set the latent dimension, i.e. the input vector space dimension of the generator.
+         *
+         * @param latentDimension latent space input dimension.
+         * @return Builder
+         */
+        public GAN.Builder latentDimension(int latentDimension) {
+            this.latentDimension = latentDimension;
+            return this;
+        }
+
+
+        /**
+         * Random number generator seed. Used for reproducibility between runs.
+         * Note: besides storing the seed, this call immediately seeds the global ND4J random generator.
+         *
+         * @param seed seed value
+         * @return Builder
+         */
+        public GAN.Builder seed(long seed) {
+            this.seed = seed;
+            Nd4j.getRandom().setSeed(seed);
+            return this;
+        }
+
+        /**
+         * Optimization algorithm to use. Most common: OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT
+         *
+         * @param optimizationAlgo Optimization algorithm to use when training
+         * @return Builder
+         */
+        public GAN.Builder optimizationAlgo(OptimizationAlgorithm optimizationAlgo) {
+            this.optimizationAlgo = optimizationAlgo;
+            return this;
+        }
+
+
+        /**
+         * Gradient updater configuration. For example, {@link org.nd4j.linalg.learning.config.Adam}
+         * or {@link org.nd4j.linalg.learning.config.Nesterovs}
+         * Note: values set by this method will be applied to all applicable layers in the network, unless a different
+         * value is explicitly set on a given layer. In other words: values set via this method are used as the default
+         * value, and can be overridden on a per-layer basis.
+         *
+         * @param updater Updater to use
+         * @return Builder
+         */
+        public GAN.Builder updater(IUpdater updater) {
+            this.iUpdater = updater;
+            return this;
+        }
+
+        /**
+         * Gradient updater configuration, for the biases only. If not set, biases will use the updater as
+         * set by {@link #updater(IUpdater)}
+         * Note: values set by this method will be applied to all applicable layers in the network, unless a different
+         * value is explicitly set on a given layer. In other words: values set via this method are used as the default
+         * value, and can be overridden on a per-layer basis.
+         *
+         * @param updater Updater to use for bias parameters
+         * @return Builder
+         */
+        public GAN.Builder biasUpdater(IUpdater updater) {
+            this.biasUpdater = updater;
+            return this;
+        }
+
+        /**
+         * Gradient normalization strategy. Used to specify gradient renormalization, gradient clipping etc.
+         * See {@link GradientNormalization} for details
+         * Note: values set by this method will be applied to all applicable layers in the network, unless a different
+         * value is explicitly set on a given layer. In other words: values set via this method are used as the default
+         * value, and can be overridden on a per-layer basis.
+         *
+         * @param gradientNormalization Type of normalization to use. Defaults to None.
+         * @return Builder
+         * @see GradientNormalization
+         */
+        public GAN.Builder gradientNormalization(GradientNormalization gradientNormalization) {
+            this.gradientNormalization = gradientNormalization;
+            return this;
+        }
+
+        /**
+         * Threshold for gradient normalization, only used for GradientNormalization.ClipL2PerLayer,
+         * GradientNormalization.ClipL2PerParamType, and GradientNormalization.ClipElementWiseAbsoluteValue
+         * Not used otherwise.
+         * L2 threshold for first two types of clipping, or absolute value threshold for last type of clipping.
+         * Note: values set by this method will be applied to all applicable layers in the network, unless a different
+         * value is explicitly set on a given layer. In other words: values set via this method are used as the default
+         * value, and can be overridden on a per-layer basis.
+         *
+         * @param threshold normalization threshold
+         * @return Builder
+         */
+        public GAN.Builder gradientNormalizationThreshold(double threshold) {
+            this.gradientNormalizationThreshold = threshold;
+            return this;
+        }
+
+        /**
+         * Creates the GAN from the current builder settings.
+         *
+         * @return a new GAN whose networks have already been assembled
+         */
+        public GAN build() {
+            return new GAN(this);
+        }
+
+    }
+
+}
\ No newline at end of file
diff --git a/brutex-extended-tests/src/test/java/net/brutex/gan/GANVisualizationUtils.java b/brutex-extended-tests/src/test/java/net/brutex/gan/GANVisualizationUtils.java
new file mode 100644
index 000000000..b88a6ae8f
--- /dev/null
+++ b/brutex-extended-tests/src/test/java/net/brutex/gan/GANVisualizationUtils.java
@@ -0,0 +1,73 @@
+/*
+ *
+ * ******************************************************************************
+ * *
+ * * This program and the accompanying materials are made available under the
+ * * terms of the Apache License, Version 2.0 which is available at
+ * * https://www.apache.org/licenses/LICENSE-2.0.
+ * *
+ * * See the NOTICE file distributed with this work for additional
+ * * information regarding copyright ownership.
+ * * Unless required by applicable law or agreed to in writing, software
+ * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * * License for the specific language governing permissions and limitations
+ * * under the License.
+ * *
+ * * SPDX-License-Identifier: Apache-2.0
+ * *****************************************************************************
+ *
+ */
+
+package net.brutex.gan;
+
+import org.nd4j.linalg.api.ndarray.INDArray;
+
+import javax.swing.*;
+import java.awt.*;
+import java.awt.image.BufferedImage;
+
+/**
+ * Swing helpers for displaying generated 28x28 grayscale samples in a simple window.
+ */
+public class GANVisualizationUtils {
+
+    /**
+     * Creates the (not yet visible) window used for visualization.
+     *
+     * @return a new frame titled "Viz" with a border layout that is disposed when closed
+     */
+    public static JFrame initFrame() {
+        JFrame frame = new JFrame();
+        frame.setTitle("Viz");
+        frame.setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE);
+        frame.setLayout(new BorderLayout());
+        return frame;
+    }
+
+    /**
+     * Creates the grid panel that holds the sample images, adds it to the frame and makes the
+     * frame visible.
+     *
+     * @param frame      window the panel is added to
+     * @param numSamples number of samples to display; the grid gets {@code numSamples / 3} rows
+     * @return the panel to pass to {@link #visualize(INDArray[], JFrame, JPanel)}
+     */
+    public static JPanel initPanel(JFrame frame, int numSamples) {
+        JPanel panel = new JPanel();
+
+        panel.setLayout(new GridLayout(numSamples / 3, 1, 8, 8));
+        frame.add(panel, BorderLayout.CENTER);
+        frame.setVisible(true);
+        return panel;
+    }
+
+    /**
+     * Replaces the content of the panel with one image per sample and re-packs the frame.
+     *
+     * @param samples samples to display
+     * @param frame   window containing the panel
+     * @param panel   panel whose content is replaced
+     */
+    public static void visualize(INDArray[] samples, JFrame frame, JPanel panel) {
+        panel.removeAll();
+
+        for (INDArray sample : samples) {
+            panel.add(getImage(sample));
+        }
+
+        frame.revalidate();
+        frame.pack();
+    }
+
+    /**
+     * Renders a flattened 28x28 sample as a grayscale image label, scaled up by a factor of 8.
+     *
+     * <p>Assumes sample values in [-1, 1] (e.g. tanh output - confirm against the caller's network)
+     * and maps them linearly to [0, 255]. The earlier mapping multiplied by 2 instead of dividing
+     * by 2, yielding values up to 1020 that do not fit into the 8-bit gray raster.
+     *
+     * @param tensor sample from which the first 784 values are read by linear index
+     * @return a label showing the 224x224 scaled image
+     */
+    private static JLabel getImage(INDArray tensor) {
+        BufferedImage bi = new BufferedImage(28, 28, BufferedImage.TYPE_BYTE_GRAY);
+        for (int i = 0; i < 784; i++) {
+            // (v + 1) / 2 maps [-1, 1] to [0, 1]; clamp to guard against out-of-range values
+            int pixel = (int) ((tensor.getDouble(i) + 1) / 2 * 255);
+            pixel = Math.max(0, Math.min(255, pixel));
+            bi.getRaster().setSample(i % 28, i / 28, 0, pixel);
+        }
+        ImageIcon orig = new ImageIcon(bi);
+        Image imageScaled = orig.getImage().getScaledInstance((8 * 28), (8 * 28), Image.SCALE_REPLICATE);
+
+        ImageIcon scaled = new ImageIcon(imageScaled);
+
+        return new JLabel(scaled);
+    }
+}
\ No newline at end of file
diff --git a/brutex-extended-tests/src/test/java/net/brutex/gan/MnistDCGANExample.java b/brutex-extended-tests/src/test/java/net/brutex/gan/MnistDCGANExample.java
new file mode 100644
index 000000000..d0e5bb73d
--- /dev/null
+++ b/brutex-extended-tests/src/test/java/net/brutex/gan/MnistDCGANExample.java
@@ -0,0 +1,193 @@
+/*
+ *
+ * ******************************************************************************
+ * *
+ * * This program and the accompanying materials are made available under the
+ * * terms of the Apache License, Version 2.0 which is available at
+ * * https://www.apache.org/licenses/LICENSE-2.0.
+ * *
+ * * See the NOTICE file distributed with this work for additional
+ * * information regarding copyright ownership.
+ * * Unless required by applicable law or agreed to in writing, software
+ * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * * License for the specific language governing permissions and limitations
+ * * under the License.
+ * *
+ * * SPDX-License-Identifier: Apache-2.0
+ * *****************************************************************************
+ *
+ */
+
+package net.brutex.gan;
+
+import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
+import org.deeplearning4j.nn.conf.ConvolutionMode;
+import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
+import org.deeplearning4j.nn.conf.inputs.InputType;
+import org.deeplearning4j.nn.conf.layers.*;
+import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor;
+import org.deeplearning4j.nn.conf.preprocessor.FeedForwardToCnnPreProcessor;
+import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
+import org.deeplearning4j.nn.weights.WeightInit;
+import org.deeplearning4j.optimize.listeners.PerformanceListener;
+import org.nd4j.linalg.activations.Activation;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.dataset.DataSet;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.learning.config.RmsProp;
+import org.nd4j.linalg.lossfunctions.LossFunctions;
+
+import javax.swing.*;
+import java.awt.*;
+import java.awt.image.BufferedImage;
+import java.util.function.Supplier;
+
+
+/**
+ * Training and visualizing a deep convolutional generative adversarial network (DCGAN) on handwritten digits.
+ *
+ * @author Max Pumperla, wmeddie
+ */
+public class MnistDCGANExample {
+
+ private static JFrame frame;
+ private static JPanel panel;
+
+ private static final int latentDim = 100;
+ private static final int height = 28;
+ private static final int width = 28;
+ private static final int channels = 1;
+
+
+ private static void visualize(INDArray[] samples) {
+ if (frame == null) {
+ frame = new JFrame();
+ frame.setTitle("Viz");
+ frame.setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE);
+ frame.setLayout(new BorderLayout());
+
+ panel = new JPanel();
+
+ panel.setLayout(new GridLayout(samples.length / 3, 1, 8, 8));
+ frame.add(panel, BorderLayout.CENTER);
+ frame.setVisible(true);
+ }
+
+ panel.removeAll();
+
+ for (int i = 0; i < samples.length; i++) {
+ panel.add(getImage(samples[i]));
+ }
+
+ frame.revalidate();
+ frame.pack();
+ }
+
+    /**
+     * Renders a flattened 28x28 sample as a grayscale image label, scaled up by a factor of 8.
+     *
+     * <p>Sample values are expected in [-1, 1] (the generator ends in a tanh activation layer) and
+     * are mapped linearly to [0, 255]. The earlier mapping multiplied by 2 instead of dividing by 2,
+     * yielding values up to 1020 that do not fit into the 8-bit gray raster.
+     *
+     * @param tensor sample from which the first 784 values are read by linear index
+     * @return a label showing the 224x224 scaled image
+     */
+    private static JLabel getImage(INDArray tensor) {
+        BufferedImage bi = new BufferedImage(28, 28, BufferedImage.TYPE_BYTE_GRAY);
+        for (int i = 0; i < 784; i++) {
+            // (v + 1) / 2 maps [-1, 1] to [0, 1]; clamp to guard against out-of-range values
+            int pixel = (int) ((tensor.getDouble(i) + 1) / 2 * 255);
+            pixel = Math.max(0, Math.min(255, pixel));
+            bi.getRaster().setSample(i % 28, i / 28, 0, pixel);
+        }
+        ImageIcon orig = new ImageIcon(bi);
+        Image imageScaled = orig.getImage().getScaledInstance((8 * 28), (8 * 28), Image.SCALE_REPLICATE);
+
+        ImageIcon scaled = new ImageIcon(imageScaled);
+
+        return new JLabel(scaled);
+    }
+
+ /** Builds the DCGAN generator and discriminator, trains for 10 passes over MNIST and shows 9 generated digits after every iteration. */
+ public static void main(String[] args) throws Exception {
+ Supplier genSupplier = () -> {
+ return new MultiLayerNetwork(new NeuralNetConfiguration.Builder().list()
+ .layer(0, new DenseLayer.Builder().nIn(latentDim).nOut(width / 2 * height / 2 * 128)
+ .activation(Activation.LEAKYRELU).weightInit(WeightInit.NORMAL).build())
+ .layer(1, new Convolution2D.Builder().nIn(128).nOut(128).kernelSize(5, 5)
+ .convolutionMode(ConvolutionMode.Same).activation(Activation.LEAKYRELU).build())
+ // Up-sampling from 14x14 to 28x28; the channel count stays at 128 (nOut is 128, not 256)
+ .layer(2, new Deconvolution2D.Builder().nIn(128).nOut(128).stride(2, 2)
+ .kernelSize(5, 5).convolutionMode(ConvolutionMode.Same)
+ .activation(Activation.LEAKYRELU).build())
+ .layer(3, new Convolution2D.Builder().nIn(128).nOut(128).kernelSize(5, 5)
+ .convolutionMode(ConvolutionMode.Same).activation(Activation.LEAKYRELU).build())
+ .layer(4, new Convolution2D.Builder().nIn(128).nOut(128).kernelSize(5, 5)
+ .convolutionMode(ConvolutionMode.Same).activation(Activation.LEAKYRELU).build())
+ .layer(5, new Convolution2D.Builder().nIn(128).nOut(channels).kernelSize(7, 7)
+ .convolutionMode(ConvolutionMode.Same).activation(Activation.LEAKYRELU).build())
+ .layer(6, new ActivationLayer.Builder().activation(Activation.TANH).build())
+ .inputPreProcessor(1,
+ new FeedForwardToCnnPreProcessor(height / 2, width / 2, 128))
+ .inputPreProcessor(6, new CnnToFeedForwardPreProcessor(height, width, channels))
+ .setInputType(InputType.feedForward(latentDim))
+ .build());
+ };
+ // NOTE: the 'updater' argument of the provider is ignored; the discriminator always uses the RmsProp configured below.
+ GAN.DiscriminatorProvider discriminatorProvider = (updater) -> {
+ return new MultiLayerNetwork(new NeuralNetConfiguration.Builder()
+ .updater(new RmsProp.Builder().learningRate(0.0008).rmsDecay(1e-8).build())
+ //.gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
+ //.gradientNormalizationThreshold(100.0)
+ .list()
+ .layer(0, new Convolution2D.Builder().nIn(channels).nOut(64).kernelSize(3, 3)
+ .activation(Activation.LEAKYRELU).build())
+ .layer(1, new Convolution2D.Builder().nIn(64).nOut(64).kernelSize(3, 3).stride(2, 2)
+ .activation(Activation.LEAKYRELU).build())
+ .layer(2, new Convolution2D.Builder().nIn(64).nOut(64).kernelSize(3, 3).stride(2, 2)
+ .activation(Activation.LEAKYRELU).build())
+ .layer(3, new Convolution2D.Builder().nIn(64).nOut(64).kernelSize(3, 3).stride(2, 2)
+ .activation(Activation.LEAKYRELU).build())
+ .layer(4, new DropoutLayer.Builder().dropOut(0.5).build())
+ .layer(5, new DenseLayer.Builder().nIn(64 * 2 * 2).nOut(1).activation(Activation.SIGMOID).build())
+ .layer(6, new LossLayer.Builder().lossFunction(LossFunctions.LossFunction.XENT).build())
+ .inputPreProcessor(0, new FeedForwardToCnnPreProcessor(height, width, channels))
+ .inputPreProcessor(4, new CnnToFeedForwardPreProcessor(2, 2, 64))
+ .setInputType(InputType.convolutionalFlat(height, width, channels))
+ .build());
+ };
+
+ GAN gan = new GAN.Builder()
+ .generator(genSupplier)
+ .discriminator(discriminatorProvider)
+ .latentDimension(latentDim)
+ //.gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
+ //.gradientNormalizationThreshold(1.0)
+ .updater(new RmsProp.Builder().learningRate(0.0008).rmsDecay(1e-8).build())
+ .build();
+
+ gan.getGenerator().setListeners(new PerformanceListener(1, true));
+ gan.getDiscriminator().setListeners(new PerformanceListener(1, true));
+
+ Nd4j.getMemoryManager().setAutoGcWindow(15 * 1000);
+ int batchSize = 64;
+ MnistDataSetIterator trainData = new MnistDataSetIterator(batchSize, true, 42);
+
+ for (int i = 0; i < 10; i++) {
+ //gan.fit(trainData, 1);
+
+ System.out.println("Starting epoch: " + (i + 1));
+
+ trainData.reset();
+ int j = 0;
+ while (trainData.hasNext()) {
+ DataSet next = trainData.next();
+ gan.fit(next);
+ // j % 1 is always 0, so samples are generated and shown after every single iteration.
+ if (j % 1 == 0) {
+ System.out.println("Epoch " + (i + 1) + " iteration " + j + " Visualizing...");
+ INDArray fakeIn = Nd4j.rand(new int[]{batchSize, latentDim});
+ // Only the first 9 rows of the batch-sized latent sample are fed to the generator.
+ INDArray[] samples = new INDArray[9];
+ for (int k = 0; k < 9; k++) {
+ samples[k] = gan.getGenerator().output(fakeIn.getRow(k), false);
+ }
+ visualize(samples);
+ }
+ j++;
+ }
+
+ System.out.println("Finished epoch: " + (i + 1));
+ }
+ }
+}
diff --git a/brutex-extended-tests/src/test/java/net/brutex/gan/MnistSimpleGAN.java b/brutex-extended-tests/src/test/java/net/brutex/gan/MnistSimpleGAN.java
new file mode 100644
index 000000000..037a0be9d
--- /dev/null
+++ b/brutex-extended-tests/src/test/java/net/brutex/gan/MnistSimpleGAN.java
@@ -0,0 +1,146 @@
+/*
+ *
+ * ******************************************************************************
+ * *
+ * * This program and the accompanying materials are made available under the
+ * * terms of the Apache License, Version 2.0 which is available at
+ * * https://www.apache.org/licenses/LICENSE-2.0.
+ * *
+ * * See the NOTICE file distributed with this work for additional
+ * * information regarding copyright ownership.
+ * * Unless required by applicable law or agreed to in writing, software
+ * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * * License for the specific language governing permissions and limitations
+ * * under the License.
+ * *
+ * * SPDX-License-Identifier: Apache-2.0
+ * *****************************************************************************
+ *
+ */
+
+package net.brutex.gan;
+
+import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
+import org.deeplearning4j.nn.conf.GradientNormalization;
+import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
+import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
+import org.deeplearning4j.nn.conf.layers.ActivationLayer;
+import org.deeplearning4j.nn.conf.layers.DenseLayer;
+import org.deeplearning4j.nn.conf.layers.DropoutLayer;
+import org.deeplearning4j.nn.conf.layers.OutputLayer;
+import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
+import org.deeplearning4j.nn.weights.WeightInit;
+import org.nd4j.linalg.activations.Activation;
+import org.nd4j.linalg.activations.impl.ActivationLReLU;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.learning.config.Adam;
+import org.nd4j.linalg.learning.config.IUpdater;
+import org.nd4j.linalg.learning.config.Sgd;
+import org.nd4j.linalg.lossfunctions.LossFunctions;
+
+import javax.swing.*;
+
+
+/**
+ * Relatively small GAN example using only Dense layers with dropout to generate handwritten
+ * digits from MNIST data.
+ */
+public class MnistSimpleGAN {
+
+ private static final int LATENT_DIM = 100;
+
+ private static final double LEARNING_RATE = 0.0002;
+ private static final IUpdater UPDATER_ZERO = Sgd.builder().learningRate(0.0).build();
+ private static final IUpdater UPDATER = Adam.builder().learningRate(LEARNING_RATE).beta1(0.5).build();
+
+
+ /** Builds the generator (LATENT_DIM inputs -> 784 tanh outputs); the network is returned without init() being called here. */
+ public static MultiLayerNetwork getGenerator() {
+ MultiLayerConfiguration genConf = new NeuralNetConfiguration.Builder()
+ .weightInit(WeightInit.XAVIER)
+ .activation(Activation.IDENTITY)
+ .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer).gradientNormalizationThreshold(100)
+ .list()
+ .layer(new DenseLayer.Builder().nIn(LATENT_DIM).nOut(256).weightInit(WeightInit.NORMAL).build()) // was a literal 100
+ .layer(new ActivationLayer.Builder(new ActivationLReLU(0.2)).build())
+ .layer(new DenseLayer.Builder().nIn(256).nOut(512).build())
+ .layer(new ActivationLayer.Builder(new ActivationLReLU(0.2)).build())
+ .layer(new DenseLayer.Builder().nIn(512).nOut(1024).build())
+ .layer(new ActivationLayer.Builder(new ActivationLReLU(0.2)).build())
+ .layer(new DenseLayer.Builder().nIn(1024).nOut(784).activation(Activation.TANH).build())
+ .build();
+ return new MultiLayerNetwork(genConf);
+ }
+
+
+ /** Builds the discriminator (784 inputs -> 1 sigmoid output, XENT loss) using the given updater; init() is not called here. */
+ public static MultiLayerNetwork getDiscriminator(IUpdater updater) {
+ MultiLayerConfiguration discConf = new NeuralNetConfiguration.Builder()
+ .seed(42)
+ .updater(updater) // network-wide updater; the parameterized layers below set the same updater again
+ .weightInit(WeightInit.XAVIER)
+ .activation(Activation.IDENTITY)
+ .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
+ .gradientNormalizationThreshold(100)
+ .list()
+ .layer(new DenseLayer.Builder().nIn(784).nOut(1024).updater(updater).build())
+ .layer(new ActivationLayer.Builder(new ActivationLReLU(0.2)).build())
+ .layer(new DropoutLayer.Builder(1 - 0.5).build()) // presumably "1 - drop probability", i.e. a retain probability - TODO confirm
+ .layer(new DenseLayer.Builder().nIn(1024).nOut(512).updater(updater).build())
+ .layer(new ActivationLayer.Builder(new ActivationLReLU(0.2)).build())
+ .layer(new DropoutLayer.Builder(1 - 0.5).build())
+ .layer(new DenseLayer.Builder().nIn(512).nOut(256).updater(updater).build())
+ .layer(new ActivationLayer.Builder(new ActivationLReLU(0.2)).build())
+ .layer(new DropoutLayer.Builder(1 - 0.5).build())
+ .layer(new OutputLayer.Builder(LossFunctions.LossFunction.XENT).nIn(256).nOut(1)
+ .activation(Activation.SIGMOID).updater(updater).build())
+ .build();
+ return new MultiLayerNetwork(discConf);
+ }
+
+ /** Trains the dense GAN for 100 passes over MNIST and shows 9 generated digits every 10th iteration. */
+ public static void main(String[] args) throws Exception {
+ GAN gan = new GAN.Builder()
+ .generator(MnistSimpleGAN::getGenerator)
+ .discriminator(MnistSimpleGAN::getDiscriminator)
+ .latentDimension(LATENT_DIM)
+ .seed(42)
+ .updater(UPDATER)
+ .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
+ .gradientNormalizationThreshold(100)
+ .build();
+
+ Nd4j.getMemoryManager().setAutoGcWindow(15 * 1000);
+
+ int batchSize = 128;
+ MnistDataSetIterator trainData = new MnistDataSetIterator(batchSize, true, 42);
+
+ // Create the window used to visualize progress; a fresh latent batch is sampled at every visualization step below.
+ int numSamples = 9;
+ JFrame frame = GANVisualizationUtils.initFrame();
+ JPanel panel = GANVisualizationUtils.initPanel(frame, numSamples);
+
+ for (int i = 0; i < 100; i++) {
+ trainData.reset();
+ int j = 0;
+ while (trainData.hasNext()) {
+ gan.fit(trainData.next());
+ //gan.fit(trainData, 1);
+ // Visualize on iterations 0, 10, 20, ... of each pass.
+ if (j % 10 == 0) {
+ INDArray fakeIn = Nd4j.rand(new int[]{batchSize, LATENT_DIM});
+ System.out.println("Epoch " + (i + 1) + " Iteration " + j + " Visualizing...");
+ INDArray[] samples = new INDArray[numSamples];
+ for (int k = 0; k < numSamples; k++) {
+ INDArray input = fakeIn.getRow(k); // only the first numSamples rows of the batch-sized sample are used
+ samples[k] = gan.getGenerator().output(input, false);
+ }
+ GANVisualizationUtils.visualize(samples, frame, panel);
+ }
+ j++;
+ }
+ }
+ }
+}
\ No newline at end of file
diff --git a/brutex-extended-tests/src/test/java/net/brutex/spark/BaseSparkSessionTest.java b/brutex-extended-tests/src/test/java/net/brutex/spark/BaseSparkSessionTest.java
index 5f81489e0..3b1b36c72 100644
--- a/brutex-extended-tests/src/test/java/net/brutex/spark/BaseSparkSessionTest.java
+++ b/brutex-extended-tests/src/test/java/net/brutex/spark/BaseSparkSessionTest.java
@@ -20,48 +20,103 @@
package net.brutex.spark;
+import java.io.File;
+import java.io.IOException;
+import java.net.URI;
+import java.util.EnumSet;
import lombok.extern.slf4j.Slf4j;
+import org.apache.commons.io.FileUtils;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.CreateFlag;
+import org.apache.hadoop.fs.FSDataOutputStream;
+import org.apache.hadoop.fs.FileContext;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.FileUtil;
+import org.apache.hadoop.fs.Options.CreateOpts;
import org.apache.spark.SparkConf;
+import org.apache.spark.SparkContext;
import org.apache.spark.sql.SparkSession;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import java.io.Serializable;
+import org.junit.jupiter.api.Test;
@Slf4j
public abstract class BaseSparkSessionTest implements Serializable {
- private static SparkSession spark;
- public static SparkSession getSession() {
- SparkConf sparkConf = new SparkConf()
- .setMaster("spark://10.5.5.200:7077")
- .setAppName(BaseSparkSessionTest.class.getSimpleName())
- .set("spark.driver.bindAddress", "10.5.5.145")
- .set("spark.network.timeout", "240000")
- .set("spark.driver.host", "10.5.5.145")
- .set("spark.deploy.mode", "client")
- .set("spark.executor.memory", "4g")
- .set("spark.cores.max", "4")
- .set("spark.worker.cleanup.enabled", "true")
- .set("spark.driver.extraJavaOptions", "-Dlog4j.configurationFile=log4j2.xml")
- .set("spark.executor.extraJavaOptions", "-Dlog4j.configurationFile=log4j2.xml")
- .set("spark.hadoop.fs.defaultFS", "hdfs://10.5.5.200:9000");
+ private static SparkSession spark;
- spark = SparkSession.builder()
- .config(sparkConf)
- .getOrCreate();
+ /** Uploads the shadow ("-all") jar to HDFS, gets or creates the Spark session for the test cluster, and registers the jar with it. */
+ public static SparkSession getSession() {
+ final String jarPath = uploadToHdfs("./build/libs/brutex-extended-tests-1.0.0-SNAPSHOT-all.jar");
- return spark;
+ SparkConf sparkConf = new SparkConf()
+ .setMaster("spark://10.5.5.200:7077")
+ .setAppName(BaseSparkSessionTest.class.getSimpleName())
+ .set("spark.driver.bindAddress", "10.5.5.145")
+ .set("spark.blockManager.port", "65001")
+ //.set("spark.driver.bindAddress", "0.0.0.0")
+ .set("spark.network.timeout", "240000")
+ .set("spark.driver.host", "10.5.5.145")
+ .set("spark.deploy.mode", "cluster") // NOTE(review): the session is created inside this JVM; confirm this key has the intended effect
+ .set("spark.executor.memory", "4g")
+ .set("spark.cores.max", "4")
+ .set("spark.worker.cleanup.enabled", "true")
+ .set("spark.driver.extraJavaOptions", "-Dlog4j.configurationFile=log4j2.xml")
+ .set("spark.executor.extraJavaOptions", "-Dlog4j.configurationFile=log4j2.xml")
+ .set("spark.hadoop.fs.defaultFS", "hdfs://10.5.5.200:9000")
+ //.set("spark.jars", jarPath)
+ ;
+ spark = SparkSession.builder()
+ .config(sparkConf)
+ .getOrCreate();
+ // NOTE: every call to getSession() repeats the HDFS upload above and this addJar call.
+ spark.sparkContext().addJar(jarPath);
+ return spark;
+ }
+ /**
+ * Uploads a local file to {@code /user/brian/} on the test HDFS cluster, replacing any previous copy.
+ * @param jarFile path of the local file to upload
+ * @return fully qualified {@code hdfs://} URL of the uploaded file
+ * @throws RuntimeException if {@code jarFile} is missing or a directory, or if the copy to HDFS fails
+ */
+ public static String uploadToHdfs(String jarFile) {
+ File f = new File(jarFile);
+ if (!f.isFile()) throw new RuntimeException("File to upload does not exist."); // isFile() is false for missing paths and directories
+ final String base = "hdfs://10.5.5.200:9000"; // no trailing slash: targetPath already starts with one
+ final String targetPath = "/user/brian/" + f.getName();
+ try {
+ Configuration conf = new Configuration();
+ org.apache.hadoop.fs.FileSystem hdfs = FileSystem.get(URI.create(base + "/"), conf);
+ org.apache.hadoop.fs.Path target = new org.apache.hadoop.fs.Path(targetPath);
+ try {
+ hdfs.delete(target, false); // best effort: remove a stale copy left by an earlier run
+ } catch (Exception e) {
+ log.warn("Could not delete {} before upload", target, e);
+ }
+ FileUtil.copy(f, hdfs, target, false, conf);
+ } catch (IOException ioe) {
+ throw new RuntimeException("Upload of " + f + " to " + base + targetPath + " failed.", ioe);
+ }
+ return base + targetPath;
}
- @BeforeAll
- public static void beforeAll() {
- }
+ @BeforeAll
+ public static void beforeAll() {
- @AfterAll
- public static synchronized void afterAll() {
- getSession().close();
+ }
- }
+ @AfterAll
+ public static synchronized void afterAll() {
+ // Close only a session that was actually created; getSession() would re-upload the jar (and possibly create a session) first.
+ if (spark != null) { spark.close(); }
+ }
+
+ /** Smoke test: obtains a session via getSession() and logs its Spark version and session UUID. */
+ @Test
+ public void testSessionCreation() {
+ SparkSession session = getSession();
+ log.info("Spark {} session id: {}", session.version(), session.sessionUUID());
+ }
}
diff --git a/brutex-extended-tests/src/test/java/net/brutex/spark/BrianTest.java b/brutex-extended-tests/src/test/java/net/brutex/spark/BrianTest.java
index cc88a0914..efb54aa29 100644
--- a/brutex-extended-tests/src/test/java/net/brutex/spark/BrianTest.java
+++ b/brutex-extended-tests/src/test/java/net/brutex/spark/BrianTest.java
@@ -20,22 +20,34 @@
*/
package net.brutex.spark;
-import com.fasterxml.jackson.core.Version;
+import java.io.IOException;
import lombok.extern.slf4j.Slf4j;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
-import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.api.java.function.ForeachFunction;
import org.apache.spark.api.java.function.Function;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.RowFactory;
+import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.Metadata;
+import org.apache.spark.sql.types.StringType;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.transform.TransformProcess;
import org.datavec.api.transform.filter.FilterInvalidValues;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.Writable;
+import org.datavec.spark.transform.Normalization;
import org.datavec.spark.transform.SparkTransformExecutor;
import org.datavec.spark.transform.misc.StringToWritablesFunction;
+import org.deeplearning4j.datasets.iterator.impl.EmnistDataSetIterator;
+import org.deeplearning4j.datasets.iterator.impl.EmnistDataSetIterator.Set;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
@@ -47,7 +59,6 @@ import org.deeplearning4j.spark.api.TrainingMaster;
import org.deeplearning4j.spark.datavec.DataVecDataSetFunction;
import org.deeplearning4j.spark.impl.multilayer.SparkDl4jMultiLayer;
import org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster;
-import org.deeplearning4j.ui.api.UIServer;
import org.junit.jupiter.api.*;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
@@ -56,7 +67,6 @@ import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions;
-import java.nio.file.Paths;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
@@ -70,23 +80,76 @@ import java.util.Random;
@Slf4j
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
@Tag("integration")
-public class BrianTest /*extends BaseDL4JTest*/ {
- static {
- String OS = System.getProperty("os.name").toLowerCase();
+public class BrianTest extends BaseSparkSessionTest {
+/*
+ static {
+ String OS = System.getProperty("os.name").toLowerCase();
- if (OS.contains("win")) {
- System.setProperty("hadoop.home.dir", Paths.get("c:\\java\\winutils").toAbsolutePath().toString());
- } else {
- System.setProperty("hadoop.home.dir", "/");
- }
+ if (OS.contains("win")) {
+ System.setProperty("hadoop.home.dir",
+ Paths.get("c:\\java\\winutils").toAbsolutePath().toString());
+ } else {
+ System.setProperty("hadoop.home.dir", "/");
}
+ }
+*/
+ private JavaSparkContext sc;
+ private JavaRDD rdd;
- public long getTimeoutMilliseconds() {
- return 400000L;
- }
- private JavaSparkContext sc;
- private JavaRDD rdd;
+ /** Exploratory test: prints EMNIST batch statistics, uploads a local CSV to HDFS and loads it into a six-column string DataFrame. */
+ @Test
+ public void wrapEmnitDataset() throws IOException, InterruptedException {
+ SparkSession sc = getSession(); // local 'sc' shadows the JavaSparkContext field of the same name
+ EmnistDataSetIterator dataset = new EmnistDataSetIterator(Set.BALANCED, 128, true);
+ DataSet ds = dataset.next();
+ System.out.println( "Number of features " + ds.numInputs());
+ System.out.println( "Number of samples " + ds.numExamples());
+ System.out.println( "Outcomes " + ds.numOutcomes());
+ final String oppsFile = uploadToHdfs("c:/temp/opps.csv");
+ // NOTE(review): the path above is a hard-coded local file; uploadToHdfs throws if it does not exist on this machine.
+ //System.out.println( "Reading file from " + oppsFile);
+
+ JavaRDD rdd = sc.sparkContext().textFile(oppsFile, 1)
+ .toJavaRDD();
+ System.out.println("Count " + rdd.count());
+ //while(true) Thread.sleep(1000);
+
+ //rdd.foreach( s -> {
+ // System.out.println("* "+s);
+ // });
+
+ //JavaRDD rdd2 = rdd.flatMap( s -> Arrays.asList( s.split(";")).iterator() );
+ //rdd2.collect().forEach( a -> System.out.print("# " + a + " ") );
+ // Schema: six non-nullable string columns, matched by position to the ';'-separated fields of each line.
+ StructType struct = new StructType(Arrays.asList(
+ StructField.apply("stage", DataTypes.StringType, false, Metadata.empty()),
+ StructField.apply("period", DataTypes.StringType, false, Metadata.empty()),
+ StructField.apply("portfolio", DataTypes.StringType, false, Metadata.empty()),
+ StructField.apply("country", DataTypes.StringType, false, Metadata.empty()),
+ StructField.apply("lfr", DataTypes.StringType, false, Metadata.empty()),
+ StructField.apply("saas", DataTypes.StringType, false, Metadata.empty())
+ ).toArray(new StructField[]{})
+ );
+ JavaRDD rdd3 = rdd.map( attributes -> RowFactory.create(attributes.split(";")));
+ // frame2 (the "lfr" column cast to float) is built but not used afterwards in this method.
+ Dataset frame = sc.createDataFrame(rdd3, struct);
+ Dataset frame2 = frame.select(frame.col("lfr").cast(DataTypes.FloatType));
+ frame.show(200);
+
+ // frame.collect().map(row -> System.out.println(row.fieldIndex("stage") + row.fieldIndex("country")));
+
+
+ // Prints every row; presumably the output appears in the executor logs rather than here - TODO confirm.
+ //frame.agg( frame.col("stage"), frame.col("lfr"));
+ frame.foreach((ForeachFunction) s -> System.out.println(s));
+
+ //sc.read().csv(rdd2);
+ //Normalization normalization = Normalization.zeromeanUnitVariance()
+ //sc.
+
+ }
+
/*
@BeforeAll
@@ -109,120 +172,53 @@ public class BrianTest /*extends BaseDL4JTest*/ {
}
*/
- @BeforeAll
- public void setUp() throws Exception {
- log.info("Running @BeforeEach scope");
- System.setProperty("hadoop.home.dir", Paths.get("c:\\java\\winutils").toAbsolutePath().toString());
- Version version = com.fasterxml.jackson.databind.cfg.PackageVersion.VERSION;
- System.out.println("Jackson version found: " + version);
- SparkConf sparkConf = new SparkConf()
- .setMaster("spark://10.5.5.200:7077")
- .setAppName("Brian3")
- .set("spark.driver.bindAddress", "10.5.5.145")
- .set("spark.network.timeout", "240000")
- .set("spark.driver.host", "10.5.5.145")
- .set("spark.driver.bindAddress", "10.5.5.145")
- .set("spark.deploy.mode", "cluster")
- .set("spark.executor.memory", "2g")
- .set("spark.executor.cores", "2")
- .set("spark.cores.max", "4")
- .set("spark.worker.cleanup.enabled", "false")
- .set("spark.driver.extraJavaOptions", "-Dlog4j.configurationFile=log4j2.xml")
- .set("spark.executor.extraJavaOptions", "-Dlog4j.configurationFile=log4j2.xml")
- .set("spark.driver.extraClassPath", "brian-spark-dl4j-tests-1.0.1-SNAPSHOT-bin.jar;brian-spark-dl4j-tests-1.0.1-SNAPSHOT-tests.jar")
- .set("spark.executor.extraClassPath", "brian-spark-dl4j-tests-1.0.1-SNAPSHOT-bin.jar;brian-spark-dl4j-tests-1.0.1-SNAPSHOT-tests.jar")
- .set("spark.hadoop.fs.defaultFS", "hdfs://10.5.5.200:9000");
- //.set("spark.driver.cores", "2")
- //.set("spark.driver.memory", "8g")
- //.set("spark.driver.host", "10.5.5.145")
- //.setExecutorEnv("spark.executor.cores", "2")
- //.setExecutorEnv("spark.executor.memory", "2g")
- //.set("spark.submit.deployMode", "client")
-/*
- SparkSession spark = SparkSession
- .builder()
- .master("spark://10.5.5.200:7077")
- .config("spark.driver.bindAddress", "10.5.5.145")
- .config("spark.driver.host", "10.5.5.145")
- //.config("spark.driver.memory", "5g")
- .appName("BrianTest2")
- .getOrCreate();
-*/
- sc = new JavaSparkContext(sparkConf);
+ @Test
+ ////@Ignore("AB 2019/05/21 - Failing - Issue #7657")
+ public void testStringsTokenization1() throws Exception {
- // sc.addJar("C:\\Users\\brian\\_projects\\deeplearning4j\\deeplearning4j\\deeplearning4j-scaleout\\spark\\dl4j-spark-nlp-java8\\target\\dl4j-spark-nlp-java8_2.12-1.0.0-SNAPSHOT-tests.jar");
- // sc.addJar("C:\\Users\\brian\\_projects\\deeplearning4j\\datavec\\datavec-api\\target\\datavec-api-1.0.0-SNAPSHOT.jar");
- // sc.addJar("C:\\Users\\brian\\_projects\\deeplearning4j\\nd4j\\nd4j-uberjar\\target\\nd4j-uberjar-1.0.0-SNAPSHOT.jar");
- // sc.addJar("C:\\Users\\brian\\_projects\\deeplearning4j\\nd4j\\nd4j-common\\target\\nd4j-common-1.0.0-SNAPSHOT.jar");
- // sc.addJar("C:\\Users\\brian\\_projects\\deeplearning4j\\datavec\\datavec-spark\\target\\datavec-spark_2.12-1.0.0-SNAPSHOT.jar");
- sc.addJar("C:\\Users\\brian\\_projects\\Brian-Spark-DL4J-Tests\\target\\brian-spark-dl4j-tests-1.0.1-SNAPSHOT-bin.jar");
- sc.addJar("C:\\Users\\brian\\_projects\\Brian-Spark-DL4J-Tests\\target\\brian-spark-dl4j-tests-1.0.1-SNAPSHOT-tests.jar");
+ //shrink for Test
+ //List list = Arrays.asList(new String[]{"asdsad", "asdasdasd", "asdasdasd", "3easdasd"});
+ //JavaRDD rdd = sc.parallelize(list);
+ // rdd = rdd.sample(true, 1.0, 1);
+ log.info("Datenmenge: " + rdd.count());
+ log.info("Sample: " + rdd.top(3));
- rdd = sc.textFile("hdfs://10.5.5.200:9000/user/zeppelin/cities_full.csv.gz");
+ Assertions.assertEquals(146889, rdd.count());
+ }
+ @Test
+ public void testSchemaCreation() throws Exception {
+ rdd.cache();
+ JavaRDD cities = rdd.map((Function) line -> {
+ return line.split(",")[1];
+ }).cache();
- }
+ JavaRDD stateCodeList = rdd.map((Function) line -> {
+ return line.split(",")[2];
+ }).cache();
- @AfterAll
- public void tearDown() throws Exception {
- sc.close();
- sc.stop();
- UIServer.stopInstance();
+ JavaRDD countryCodeList = rdd.map((Function) line -> {
+ return line.split(",")[3];
+ }).cache();
- }
+ CSVRecordReader recordReader = new CSVRecordReader(0, ',');
+ JavaRDD> convertedRDD = rdd.map((Function>) s -> {
+ return new StringToWritablesFunction(recordReader).call(s);
+ });
- @Test
- ////@Ignore("AB 2019/05/21 - Failing - Issue #7657")
- public void testStringsTokenization1() throws Exception {
+ //Source Schema
+ Schema inputSchema = new Schema.Builder()
+ .addColumnLong("city_id")
+ .addColumnsString("city_name", "state_code", "country_code")
+ .addColumnsString("country_full")
+ .addColumnsDouble("lat", "lon")
+ .build();
- //shrink for Test
- //List list = Arrays.asList(new String[]{"asdsad", "asdasdasd", "asdasdasd", "3easdasd"});
- //JavaRDD rdd = sc.parallelize(list);
-
- // rdd = rdd.sample(true, 1.0, 1);
- log.info("Datenmenge: " + rdd.count());
- log.info("Sample: " + rdd.top(3));
-
- Assertions.assertEquals(146889, rdd.count());
- }
-
- @Test
- public void testSchemaCreation() throws Exception {
-
-
- rdd.cache();
-
- JavaRDD cities = rdd.map( (Function) line -> {
- return line.split(",")[1];
- }).cache();
-
- JavaRDD stateCodeList = rdd.map( (Function) line -> {
- return line.split(",")[2];
- }).cache();
-
- JavaRDD countryCodeList = rdd.map( (Function) line -> {
- return line.split(",")[3];
- }).cache();
-
-
- CSVRecordReader recordReader = new CSVRecordReader(0, ',');
- JavaRDD> convertedRDD = rdd.map((Function>) s -> {
- return new StringToWritablesFunction( recordReader).call(s);
- });
-
- //Source Schema
- Schema inputSchema = new Schema.Builder()
- .addColumnLong("city_id")
- .addColumnsString("city_name", "state_code", "country_code")
- .addColumnsString("country_full")
- .addColumnsDouble("lat", "lon")
- .build();
-
- //Running Transformation
+ //Running Transformation
/*
TransformProcess tp = new TransformProcess.Builder(inputSchema)
.removeColumns("country_full", "lat", "lon")
@@ -236,38 +232,40 @@ public class BrianTest /*extends BaseDL4JTest*/ {
.categoricalToOneHot("country_code")
.build();
*/
- TransformProcess tp = new TransformProcess.Builder(inputSchema)
- .removeAllColumnsExceptFor("country_code", "lat", "lon")
- .stringToCategorical("country_code", Arrays.asList("GR", "FR", "DE", "CH"))
- .filter(new FilterInvalidValues())
- .categoricalToOneHot("country_code")
- .build();
+ TransformProcess tp = new TransformProcess.Builder(inputSchema)
+ .removeAllColumnsExceptFor("country_code", "lat", "lon")
+ .stringToCategorical("country_code", Arrays.asList("GR", "FR", "DE", "CH"))
+ .filter(new FilterInvalidValues())
+ .categoricalToOneHot("country_code")
+ .build();
- //log.info("Final Schema: " +tp.getFinalSchema().toString());
- //Execute Transformation Process
- convertedRDD.repartition(8);
- convertedRDD.cache();
- JavaRDD> processedData = SparkTransformExecutor.execute(convertedRDD, tp);
- processedData.repartition(8);
- processedData.cache();
- //log.info("Datenmenge nach processing: " + processedData.count());
+ //log.info("Final Schema: " +tp.getFinalSchema().toString());
+ //Execute Transformation Process
+ convertedRDD.repartition(8);
+ convertedRDD.cache();
+ JavaRDD> processedData = SparkTransformExecutor.execute(convertedRDD, tp);
+ processedData.repartition(8);
+ processedData.cache();
+ //log.info("Datenmenge nach processing: " + processedData.count());
+ //Vectorisieren
+ int labelIndex = 0; //in welcher Spalte ist das Label
+ int numLabels = 4; //Anzahl der Klassen 0-236 = 237 Werte
- //Vectorisieren
- int labelIndex = 0; //in welcher Spalte ist das Label
- int numLabels = 4; //Anzahl der Klassen 0-236 = 237 Werte
+ DataVecDataSetFunction datavecFunction = new DataVecDataSetFunction(labelIndex, numLabels,
+ false);
+ JavaRDD rddDataSet = processedData.map(datavecFunction);
+ log.info("rddDataset: " + rddDataSet.toDebugString());
+ Random rand = new Random();
+ rddDataSet.sortBy((Function) s -> {
+ return rand.nextDouble();
+ }, true, 8);
- DataVecDataSetFunction datavecFunction = new DataVecDataSetFunction(labelIndex, numLabels, false);
- JavaRDD rddDataSet = processedData.map(datavecFunction);
- log.info("rddDataset: " + rddDataSet.toDebugString());
- Random rand = new Random();
- rddDataSet.sortBy( (Function) s -> {return rand.nextDouble(); }, true, 8);
+ //og.info("Sample: " + rddDataSet.sample(false, 0.005, 0).collect());
- //og.info("Sample: " + rddDataSet.sample(false, 0.005, 0).collect());
-
- /* Skip, this will save each record one by one to hdfs
- */
- //Now save this hard work
+ /* Skip, this will save each record one by one to hdfs
+ */
+ //Now save this hard work
/*
int miniBatchSize = 1; //Minibatch size of the saved DataSet objects
final String exportPath = "hdfs://10.5.5.200:9000/user/brian/data";
@@ -278,63 +276,67 @@ public class BrianTest /*extends BaseDL4JTest*/ {
paths.collect();
*/
- //Create Trainingmaster
+ //Create Trainingmaster
- TrainingMaster trainingMaster = new ParameterAveragingTrainingMaster.Builder(4)
- .rddTrainingApproach(RDDTrainingApproach.Direct) //when "export", tries to save everything first
- .batchSizePerWorker(1000)
- .collectTrainingStats(true)
- .build();
+ TrainingMaster trainingMaster = new ParameterAveragingTrainingMaster.Builder(4)
+ .rddTrainingApproach(
+ RDDTrainingApproach.Direct) //when "export", tries to save everything first
+ .batchSizePerWorker(1000)
+ .collectTrainingStats(true)
+ .build();
- //Define Network
+ //Define Network
- MultiLayerConfiguration multiLayerConfiguration = new NeuralNetConfiguration.Builder()
- .seed(123)
- .updater(new Nesterovs(0.1, 0.9))
- .list()
- .layer(0, new DenseLayer.Builder().nIn(5).nOut(20).weightInit(WeightInit.XAVIER).activation(Activation.RELU).l2(0.001).build())
- .layer(1, new DenseLayer.Builder().nIn(20).nOut(20).weightInit(WeightInit.XAVIER).activation(Activation.RELU).build())
- //.layer(2, new DenseLayer.Builder().nIn(9).nOut(9).weightInit(WeightInit.XAVIER).activation(Activation.RELU).build())
- .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.XENT).nIn(20).nOut(4).weightInit(WeightInit.XAVIER).activation(Activation.SIGMOID).build())
- .build();
+ MultiLayerConfiguration multiLayerConfiguration = new NeuralNetConfiguration.Builder()
+ .seed(123)
+ .updater(new Nesterovs(0.1, 0.9))
+ .list()
+ .layer(0, new DenseLayer.Builder().nIn(5).nOut(20).weightInit(WeightInit.XAVIER)
+ .activation(Activation.RELU).l2(0.001).build())
+ .layer(1, new DenseLayer.Builder().nIn(20).nOut(20).weightInit(WeightInit.XAVIER)
+ .activation(Activation.RELU).build())
+ //.layer(2, new DenseLayer.Builder().nIn(9).nOut(9).weightInit(WeightInit.XAVIER).activation(Activation.RELU).build())
+ .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.XENT).nIn(20).nOut(4)
+ .weightInit(WeightInit.XAVIER).activation(Activation.SIGMOID).build())
+ .build();
- //Define SparkNet
- SparkDl4jMultiLayer sparkNet = new SparkDl4jMultiLayer(sc, multiLayerConfiguration, trainingMaster);
+ //Define SparkNet
+ SparkDl4jMultiLayer sparkNet = new SparkDl4jMultiLayer(sc, multiLayerConfiguration,
+ trainingMaster);
+ JavaRDD[] split = rddDataSet.randomSplit(new double[]{0.9, 0.1}, 123);
+ //JavaRDD trainingData = split[0];
+ JavaRDD trainingData = rddDataSet;
+ JavaRDD testData = split[1];
- JavaRDD[] split = rddDataSet.randomSplit(new double[] {0.9, 0.1}, 123);
- //JavaRDD trainingData = split[0];
- JavaRDD trainingData = rddDataSet;
- JavaRDD testData = split[1];
-
- //Run Training on subset
- for(int i =0; i<20; i++) {
- sparkNet.fit(trainingData);
- }
-
- //Evaluieren
- MultiLayerNetwork finalNet = sparkNet.getNetwork();
-
- //Speichern
- Configuration conf = sc.hadoopConfiguration();
- conf.set("hadoop.tmp.dir", "/user/brian/tmp");
- FileSystem fs = FileSystem.get(conf);
- Path p = new Path("hdfs://10.5.5.200:9000/user/brian/model");
- //fs.mkdirs(p);
- //ModelSerializer.writeModel(finalNet, fs.create(p), true );
-
- Evaluation eval = new Evaluation(4); // outputNum = 10: number of output classes
- Iterator iter = testData.toLocalIterator();
- log.info("testData has " + testData.count() + " DataSets");
- while(iter.hasNext()){
- DataSet next = iter.next();
- //log.info("getFeatures " + next.getFeatures() );
- INDArray output = finalNet.output(next.getFeatures()); //get the networks prediction
- //log.info("output "+ output.toStringFull());
- eval.eval(next.getLabels(), output); //check the prediction against the true class
- //log.info("Predict " + finalNet.predict(next));
- }
- log.info("Evaluation stats: " + eval.stats());
+ //Run Training on subset
+ for (int i = 0; i < 20; i++) {
+ sparkNet.fit(trainingData);
}
+ //Evaluieren
+ MultiLayerNetwork finalNet = sparkNet.getNetwork();
+
+ //Speichern
+ Configuration conf = sc.hadoopConfiguration();
+ conf.set("hadoop.tmp.dir", "/user/brian/tmp");
+ FileSystem fs = FileSystem.get(conf);
+ Path p = new Path("hdfs://10.5.5.200:9000/user/brian/model");
+ //fs.mkdirs(p);
+ //ModelSerializer.writeModel(finalNet, fs.create(p), true );
+
+ Evaluation eval = new Evaluation(4); // outputNum = 10: number of output classes
+ Iterator iter = testData.toLocalIterator();
+ log.info("testData has " + testData.count() + " DataSets");
+ while (iter.hasNext()) {
+ DataSet next = iter.next();
+ //log.info("getFeatures " + next.getFeatures() );
+ INDArray output = finalNet.output(next.getFeatures()); //get the networks prediction
+ //log.info("output "+ output.toStringFull());
+ eval.eval(next.getLabels(), output); //check the prediction against the true class
+ //log.info("Predict " + finalNet.predict(next));
+ }
+ log.info("Evaluation stats: " + eval.stats());
+ }
+
}
diff --git a/cavis-common-platform/build.gradle b/cavis-common-platform/build.gradle
index aaf070d84..a1a728508 100644
--- a/cavis-common-platform/build.gradle
+++ b/cavis-common-platform/build.gradle
@@ -25,8 +25,8 @@ ext {
def flatbuffers = [version: "1.10.0"]
- def spark = [version: "3.1.2"]
- def scala = [version:"2.12.10"] //[version:"2.13.5"]
+ def spark = [version: "3.2.2"]
+ def scala = [version:"2.12.15"] //[version:"2.13.5"]
def netty = [version: "4.1.68.Final"]
diff --git a/cavis-datavec/dvec-api/src/main/java/net/brutex/cavis/dvec/api/FieldInterface.java b/cavis-datavec/dvec-api/src/main/java/net/brutex/cavis/dvec/api/AbstractField.java
similarity index 51%
rename from cavis-datavec/dvec-api/src/main/java/net/brutex/cavis/dvec/api/FieldInterface.java
rename to cavis-datavec/dvec-api/src/main/java/net/brutex/cavis/dvec/api/AbstractField.java
index 92705bea8..87aa389ce 100644
--- a/cavis-datavec/dvec-api/src/main/java/net/brutex/cavis/dvec/api/FieldInterface.java
+++ b/cavis-datavec/dvec-api/src/main/java/net/brutex/cavis/dvec/api/AbstractField.java
@@ -21,59 +21,44 @@
package net.brutex.cavis.dvec.api;
-import java.io.Serializable;
import java.nio.Buffer;
-import java.nio.LongBuffer;
-import java.util.List;
+import java.nio.ByteBuffer;
import net.brutex.cavis.dvec.api.exceptions.DVecException;
/**
- * A Field can be considered a "column" in a {@code Record}, as such a Field may refer to multiple
- * entries of that "column". Fields are typed as Buffers. Some of them defined in the dvec core api,
- * other (i.e. Image or Arrow) require dvec extensions accordingly.
+ * Abstract implementation of the {@link Field} interface that handles all data storage
+ * in memory and adds basic error handling.
*
* @author Brian Rosenberger
* @since 1.0
*/
-public interface FieldInterface extends Serializable {
+public abstract class AbstractField implements Field {
/**
- * Get a reference to the metadata for this Field.
- *
- * @return the {@link FieldMetadata}
- */
- FieldMetadata getFieldMetadata();
-
- /**
- * Get the 1st field as Buffer. This deserializes the data from the underlying storage.
- *
- * @return T underlying Buffer
- */
- default T read() throws DVecException {
- return read(0, 1);
- }
-
- /**
- * Get a range of fields as a {@code Buffer}
+ * {@inheritDoc}
*
* @param start Index of starting position, zero based
* @param length how many fields to read
- * @return the buffers
+ * @return the list of Buffer
*/
- T read(long start, long length) throws DVecException;
-
- /**
- * Write the data into the underlying storage.
- */
- default void write(T buffer) {
- write(0, buffer);
+ @Override
+ public T read(long start, long length) throws DVecException {
+ if (start<0 || start>internalStorage.capacity()-1 ) {
+ throw new DVecException("Read on Field start position is out of bounds.");
+ }
+ if (start+length> internalStorage.capacity()) {
+ throw new DVecException("Read on Field exceeds field length");
+ }
+ return null;
}
- /**
- * Write the data into the underyling storage starting at a position
- *
- * @param pos the position to start
- */
- void write(long pos, T buffer);
+ @Override
+ public void write(long pos, T buffer) {
+
+ }
+
+ private ByteBuffer internalStorage = null;
+
+
}
diff --git a/cavis-datavec/dvec-api/src/main/java/net/brutex/cavis/dvec/api/Field.java b/cavis-datavec/dvec-api/src/main/java/net/brutex/cavis/dvec/api/Field.java
index a3be6313f..ace9be2a1 100644
--- a/cavis-datavec/dvec-api/src/main/java/net/brutex/cavis/dvec/api/Field.java
+++ b/cavis-datavec/dvec-api/src/main/java/net/brutex/cavis/dvec/api/Field.java
@@ -21,46 +21,57 @@
package net.brutex.cavis.dvec.api;
+import java.io.Serializable;
import java.nio.Buffer;
-import java.nio.ByteBuffer;
-import java.util.ArrayList;
-import java.util.List;
import net.brutex.cavis.dvec.api.exceptions.DVecException;
/**
- * Abtract implementation of the Field interface {@see FieldInterface}, that handles all data storage
- * in memory and adds basic error handling.
+ * A Field can be considered a "column" in a {@code Record}, as such a Field may refer to multiple
+ * entries of that "column". Fields are typed as Buffers. Some of them defined in the dvec core api,
+ * other (i.e. Image or Arrow) require dvec extensions accordingly.
*
* @author Brian Rosenberger
* @since 1.0
*/
-public abstract class Field implements FieldInterface {
+public interface Field extends Serializable {
/**
- * {@inheritDoc}
+ * Get a reference to the metadata for this Field.
+ *
+ * @return the {@link FieldMetadata}
+ */
+ FieldMetadata getFieldMetadata();
+
+ /**
+ * Get the 1st field as Buffer. This deserializes the data from the underlying storage.
+ *
+ * @return T underlying Buffer
+ */
+ default T read() throws DVecException {
+ return read(0, 1);
+ }
+
+ /**
+ * Get a range of fields as a {@code Buffer}
*
* @param start Index of starting position, zero based
* @param length how many fields to read
- * @return the list of Buffer
+ * @return the buffers
*/
- @Override
- public T read(long start, long length) throws DVecException {
- if (start<0 || start>internalStorage.capacity()-1 ) {
- throw new DVecException("Read on Field start position is out of bounds.");
- }
- if (start+length> internalStorage.capacity()) {
- throw new DVecException("Read on Field exceeds field length");
- }
- return null;
+ T read(long start, long length) throws DVecException;
+
+ /**
+ * Write the data into the underlying storage.
+ */
+ default void write(T buffer) {
+ write(0, buffer);
}
- @Override
- public void write(long pos, T buffer) {
-
- }
-
- private ByteBuffer internalStorage = null;
-
-
+ /**
+ * Write the data into the underlying storage starting at a position
+ *
+ * @param pos the position to start
+ */
+ void write(long pos, T buffer);
}
diff --git a/cavis-dnn/cavis-dnn-common/src/main/java/org/deeplearning4j/common/config/DL4JClassLoading.java b/cavis-dnn/cavis-dnn-common/src/main/java/org/deeplearning4j/common/config/DL4JClassLoading.java
index 83cba9988..16f6f134a 100644
--- a/cavis-dnn/cavis-dnn-common/src/main/java/org/deeplearning4j/common/config/DL4JClassLoading.java
+++ b/cavis-dnn/cavis-dnn-common/src/main/java/org/deeplearning4j/common/config/DL4JClassLoading.java
@@ -99,9 +99,13 @@ public class DL4JClassLoading {
.asSubclass(superclass)
.getDeclaredConstructor(parameterTypes)
.newInstance(args);
- } catch (InstantiationException | IllegalAccessException | InvocationTargetException
+ } catch (InstantiationException | IllegalAccessException
| NoSuchMethodException instantiationException) {
log.error(String.format("Cannot create instance of class '%s'.", className), instantiationException);
+
+ throw new RuntimeException(instantiationException);
+ } catch (InvocationTargetException instantiationException) {
+ log.error(String.format("InvocationTargetException was '%s'.", instantiationException.getTargetException().getMessage()), instantiationException);
throw new RuntimeException(instantiationException);
}
}
diff --git a/cavis-dnn/cavis-dnn-cudnn/build.gradle b/cavis-dnn/cavis-dnn-cudnn/build.gradle
new file mode 100644
index 000000000..725ca1f85
--- /dev/null
+++ b/cavis-dnn/cavis-dnn-cudnn/build.gradle
@@ -0,0 +1,23 @@
+
+apply from: "${project.rootProject.projectDir}/createTestBackends.gradle"
+
+ext {
+ buildTarget = rootProject.ext.buildTarget
+}
+
+dependencies {
+ implementation platform(projects.cavisCommonPlatform)
+ implementation projects.cavisNative.cavisNativeJcublas
+ implementation projects.cavisDnn.cavisDnnApi
+ implementation projects.cavisDnn.cavisDnnNn
+
+ implementation group: "org.bytedeco", name: "cuda"
+ implementation group: "org.bytedeco", name: "cuda", classifier: buildTarget
+ implementation group: "org.bytedeco", name: "cuda", classifier: "${buildTarget}-redist"
+
+ implementation group: "org.bytedeco", name: "javacpp"
+ implementation group: "org.bytedeco", name: "javacpp", classifier: buildTarget
+
+ implementation 'com.jakewharton.byteunits:byteunits:0.9.1'
+
+}
\ No newline at end of file
diff --git a/cavis-dnn/cavis-dnn-cudnn/src/main/java/org/deeplearning4j/cuda/BaseCudnnHelper.java b/cavis-dnn/cavis-dnn-cudnn/src/main/java/org/deeplearning4j/cuda/BaseCudnnHelper.java
new file mode 100644
index 000000000..5465f6224
--- /dev/null
+++ b/cavis-dnn/cavis-dnn-cudnn/src/main/java/org/deeplearning4j/cuda/BaseCudnnHelper.java
@@ -0,0 +1,252 @@
+/*
+ * ******************************************************************************
+ * *
+ * *
+ * * This program and the accompanying materials are made available under the
+ * * terms of the Apache License, Version 2.0 which is available at
+ * * https://www.apache.org/licenses/LICENSE-2.0.
+ * *
+ * * See the NOTICE file distributed with this work for additional
+ * * information regarding copyright ownership.
+ * * Unless required by applicable law or agreed to in writing, software
+ * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * * License for the specific language governing permissions and limitations
+ * * under the License.
+ * *
+ * * SPDX-License-Identifier: Apache-2.0
+ * *****************************************************************************
+ */
+
+package org.deeplearning4j.cuda;
+
+import lombok.NonNull;
+import lombok.extern.slf4j.Slf4j;
+import org.bytedeco.javacpp.*;
+import org.nd4j.jita.allocator.impl.AtomicAllocator;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.factory.Nd4j;
+
+import org.bytedeco.cuda.cudnn.*;
+import static org.bytedeco.cuda.global.cudart.*;
+import static org.bytedeco.cuda.global.cudnn.*;
+
+/**
+ * Functionality shared by all cuDNN-based helpers.
+ *
+ * @author saudet
+ */
+@Slf4j
+public abstract class BaseCudnnHelper {
+
+ /* public BaseCudnnHelper() {
+
+ }
+ */
+
+ protected static void checkCuda(int error) {
+ if (error != cudaSuccess) {
+ throw new RuntimeException("CUDA error = " + error + ": " + cudaGetErrorString(error).getString());
+ }
+ }
+
+ protected static void checkCudnn(int status) {
+ if (status != CUDNN_STATUS_SUCCESS) {
+ throw new RuntimeException("cuDNN status = " + status + ": " + cudnnGetErrorString(status).getString());
+ }
+ }
+
+ protected static class CudnnContext extends cudnnContext {
+
+ protected static class Deallocator extends CudnnContext implements Pointer.Deallocator {
+ Deallocator(CudnnContext c) {
+ super(c);
+ }
+
+ @Override
+ public void deallocate() {
+ destroyHandles();
+ }
+ }
+
+ public CudnnContext() {
+ // ensure that cuDNN initializes on the same device as ND4J for this thread
+ Nd4j.create(1);
+ AtomicAllocator.getInstance();
+ // This needs to be called in subclasses:
+ // createHandles();
+ // deallocator(new Deallocator(this));
+ }
+
+ public CudnnContext(CudnnContext c) {
+ super(c);
+ }
+
+ protected void createHandles() {
+ checkCudnn(cudnnCreate(this));
+ }
+
+ protected void destroyHandles() {
+ checkCudnn(cudnnDestroy(this));
+ }
+ }
+
+ protected static class DataCache extends Pointer {
+
+ static class Deallocator extends DataCache implements Pointer.Deallocator {
+ Deallocator(DataCache c) {
+ super(c);
+ }
+
+ @Override
+ public void deallocate() {
+ checkCuda(cudaFree(this));
+ setNull();
+ }
+ }
+
+ static class HostDeallocator extends DataCache implements Pointer.Deallocator {
+ HostDeallocator(DataCache c) {
+ super(c);
+ }
+
+ @Override
+ public void deallocate() {
+ checkCuda(cudaFreeHost(this));
+ setNull();
+ }
+ }
+
+ public DataCache() {}
+
+ public DataCache(long size) {
+ position = 0;
+ limit = capacity = size;
+ int error = cudaMalloc(this, size);
+ if (error != cudaSuccess) {
+ log.warn("Cannot allocate " + size + " bytes of device memory (CUDA error = " + error
+ + "), proceeding with host memory");
+ checkCuda(cudaMallocHost(this, size));
+ deallocator(new HostDeallocator(this));
+ } else {
+ deallocator(new Deallocator(this));
+ }
+ }
+
+ public DataCache(DataCache c) {
+ super(c);
+ }
+ }
+
+ protected static class TensorArray extends PointerPointer {
+
+ static class Deallocator extends TensorArray implements Pointer.Deallocator {
+ Pointer owner;
+
+ Deallocator(TensorArray a, Pointer owner) {
+ this.address = a.address;
+ this.capacity = a.capacity;
+ this.owner = owner;
+ }
+
+ @Override
+ public void deallocate() {
+ for (int i = 0; !isNull() && i < capacity; i++) {
+ cudnnTensorStruct t = this.get(cudnnTensorStruct.class, i);
+ checkCudnn(cudnnDestroyTensorDescriptor(t));
+ }
+ if (owner != null) {
+ owner.deallocate();
+ owner = null;
+ }
+ setNull();
+ }
+ }
+
+ public TensorArray() {}
+
+ public TensorArray(long size) {
+ PointerPointer p = new PointerPointer(size);
+ p.deallocate(false);
+ this.address = p.address();
+ this.limit = p.limit();
+ this.capacity = p.capacity();
+
+ cudnnTensorStruct t = new cudnnTensorStruct();
+ for (int i = 0; i < capacity; i++) {
+ checkCudnn(cudnnCreateTensorDescriptor(t));
+ this.put(i, t);
+ }
+ deallocator(new Deallocator(this, p));
+ }
+
+ public TensorArray(TensorArray a) {
+ super(a);
+ }
+ }
+
+ protected final DataType nd4jDataType;
+ protected final int dataType;
+ protected final int dataTypeSize;
+ // both CUDNN_DATA_HALF and CUDNN_DATA_FLOAT need a float value for alpha and beta
+ protected final Pointer alpha;
+ protected final Pointer beta;
+ protected SizeTPointer sizeInBytes = new SizeTPointer(1);
+
+ public BaseCudnnHelper(@NonNull DataType dataType){
+ this.nd4jDataType = dataType;
+ this.dataType = dataType == DataType.DOUBLE ? CUDNN_DATA_DOUBLE
+ : dataType == DataType.FLOAT ? CUDNN_DATA_FLOAT : CUDNN_DATA_HALF;
+ this.dataTypeSize = dataType == DataType.DOUBLE ? 8 : dataType == DataType.FLOAT ? 4 : 2;
+ // both CUDNN_DATA_HALF and CUDNN_DATA_FLOAT need a float value for alpha and beta
+ this.alpha = this.dataType == CUDNN_DATA_DOUBLE ? new DoublePointer(1.0) : new FloatPointer(1.0f);
+ this.beta = this.dataType == CUDNN_DATA_DOUBLE ? new DoublePointer(0.0) : new FloatPointer(0.0f);
+ }
+
+ public static int toCudnnDataType(DataType type){
+ switch (type){
+ case DOUBLE:
+ return CUDNN_DATA_DOUBLE;
+ case FLOAT:
+ return CUDNN_DATA_FLOAT;
+ case INT:
+ return CUDNN_DATA_INT32;
+ case HALF:
+ return CUDNN_DATA_HALF;
+ default:
+ throw new RuntimeException("Cannot convert type: " + type);
+ }
+ }
+
+ public boolean checkSupported() {
+ // add general checks here, if any
+ return true;
+ }
+
+
+ /**
+ * From CuDNN documentation -
+ * "Tensors are restricted to having at least 4 dimensions... When working with lower dimensional data, it is
+ * recommended that the user create a 4Dtensor, and set the size along unused dimensions to 1."
+ *
+ * This method implements that - basically appends 1s to the end (shape or stride) to make it length 4,
+ * or leaves it unmodified if the length is already 4 or more.
+ * This method can be used for both shape and strides
+ *
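+ * @param shapeOrStrides the shape or stride array to adapt
+ * @return the input array itself if its length is 4 or more, otherwise a new length-4 array padded with trailing 1s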
+ */
+ protected static int[] adaptForTensorDescr(int[] shapeOrStrides){
+ if(shapeOrStrides.length >= 4)
+ return shapeOrStrides;
+ int[] out = new int[4];
+ int i=0;
+ for(; i backpropGradient(INDArray input, INDArray weights, INDArray bias, INDArray delta, int[] kernel,
+ int[] strides, int[] pad, INDArray biasGradView, INDArray weightGradView, IActivation afn,
+ AlgoMode mode, BwdFilterAlgo bwdFilterAlgo, BwdDataAlgo bwdDataAlgo,
+ ConvolutionMode convolutionMode, int[] dilation, CNN2DFormat format, LayerWorkspaceMgr workspaceMgr) {
+
+ //AB 2020/04/21 - cuDNN does have NHWC support (with limitations) however I have been unable to get it working
+ // correctly on NHWC data, even after updating all descriptors, tensor format, etc.
+ //Therefore: all computation here is done in NCHW format only
+ //As of a future (next?) release we'll likely switch to C++ for cuDNN support
+ boolean origNHWC = false;
+ if(format == CNN2DFormat.NHWC){
+ input = input.permute(0,3,1,2); //NHWC to NCHW
+ delta = delta.permute(0,3,1,2);
+ origNHWC = true;
+ }
+
+ int TENSOR_FORMAT = CUDNN_TENSOR_NCHW;
+
+ int code;
+
+ val miniBatch = input.size(0);
+ val outDepth = weights.size(0);
+ val inDepth = weights.size(1);
+ val kH = weights.size(2);
+ val kW = weights.size(3);
+
+ CudnnForwardArgs args = getCudnnForwardArgs(input, kernel, strides, pad, dilation, convolutionMode, null, CNN2DFormat.NCHW); //Note hardcoded NCHW due to above
+ input = args.getInput();
+ val inH = input.size(2);
+ val inW = input.size(3);
+ val srcStride = input.stride();
+ val outSize = args.getOutSize();
+ val outH = outSize[0];
+ val outW = outSize[1];
+
+ if (!Shape.strideDescendingCAscendingF(delta)) {
+ // apparently not supported by cuDNN
+ delta = delta.dup();
+ }
+
+ val deltaStride = delta.stride();
+ int[] algo1 = new int[1];
+ int[] algo2 = new int[1];
+
+
+ if (Nd4j.getExecutioner() instanceof GridExecutioner)
+ ((GridExecutioner) Nd4j.getExecutioner()).flushQueue();
+
+ code = cudnnSetTensor4dDescriptorEx(cudnnContext.srcTensorDesc, dataType, (int) miniBatch, (int) inDepth,(int) inH, (int) inW,
+ (int) srcStride[0], (int) srcStride[1], (int) srcStride[2], (int) srcStride[3]);
+ checkCudnn(false, "cudnnSetTensor4dDescriptorEx", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);
+ code = cudnnSetTensor4dDescriptorEx(cudnnContext.deltaTensorDesc, dataType, (int) miniBatch, (int) outDepth, (int) outH, (int) outW,
+ (int) deltaStride[0], (int) deltaStride[1], (int) deltaStride[2], (int) deltaStride[3]);
+ checkCudnn(false, "cudnnSetTensor4dDescriptorEx", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);
+ code = cudnnSetConvolution2dDescriptor(cudnnContext.convDesc, pad[0], pad[1], strides[0], strides[1], dilation[0],
+ dilation[1], CUDNN_CROSS_CORRELATION, dataType);
+ checkCudnn(false, "cudnnSetConvolution2dDescriptor", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);
+ code = cudnnSetFilter4dDescriptor(cudnnContext.filterDesc, dataType, TENSOR_FORMAT, (int) outDepth, (int) inDepth, (int) kH, (int) kW);
+ checkCudnn(false, "cudnnSetFilter4dDescriptor", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);
+
+ if (mode == AlgoMode.USER_SPECIFIED && bwdFilterAlgo != null && bwdDataAlgo != null) {
+ switch (bwdFilterAlgo) {
+ case ALGO_0:
+ algo1[0] = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0;
+ break;
+ case ALGO_1:
+ algo1[0] = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1;
+ break;
+ case FFT:
+ algo1[0] = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT;
+ break;
+ case ALGO_3:
+ algo1[0] = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3;
+ break;
+ case WINOGRAD:
+ algo1[0] = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD;
+ break;
+ case WINOGRAD_NONFUSED:
+ algo1[0] = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED;
+ break;
+ case FFT_TILING:
+ algo1[0] = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING;
+ break;
+ case COUNT:
+ algo1[0] = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT;
+ break;
+ default:
+ throw new IllegalArgumentException("Unknown BwdFilterAlgo: " + bwdFilterAlgo);
+ }
+
+ switch (bwdDataAlgo) {
+ case ALGO_0:
+ algo2[0] = CUDNN_CONVOLUTION_BWD_DATA_ALGO_0;
+ break;
+ case ALGO_1:
+ algo2[0] = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
+ break;
+ case FFT:
+ algo2[0] = CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT;
+ break;
+ case FFT_TILING:
+ algo2[0] = CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING;
+ break;
+ case WINOGRAD:
+ algo2[0] = CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD;
+ break;
+ case WINOGRAD_NONFUSED:
+ algo2[0] = CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED;
+ break;
+ case COUNT:
+ algo2[0] = CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT;
+ break;
+ default:
+ throw new IllegalArgumentException("Unknown BwdDataAlgo: " + bwdDataAlgo);
+ }
+ } else {
+ /*
+ code = cudnnGetConvolutionBackwardFilterAlgorithm(cudnnContext, cudnnContext.srcTensorDesc,
+ cudnnContext.deltaTensorDesc, cudnnContext.convDesc, cudnnContext.filterDesc,
+ mode == AlgoMode.NO_WORKSPACE ? CUDNN_CONVOLUTION_BWD_FILTER_NO_WORKSPACE
+ : CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST,
+ 0, algo1);
+ */
+ val fa = new cudnnConvolutionBwdFilterAlgoPerf_t();
+ val counts = new int[1];
+ code = cudnnFindConvolutionBackwardFilterAlgorithm(cudnnContext, cudnnContext.srcTensorDesc,
+ cudnnContext.deltaTensorDesc, cudnnContext.convDesc, cudnnContext.filterDesc, 1, counts, fa);
+ algo1[0] = fa.algo();
+
+ checkCudnn(false, "cudnnGetConvolutionBackwardFilterAlgorithm", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);
+
+ /*
+ code = cudnnGetConvolutionBackwardDataAlgorithm(cudnnContext, cudnnContext.filterDesc,
+ cudnnContext.deltaTensorDesc, cudnnContext.convDesc, cudnnContext.srcTensorDesc,
+ mode == AlgoMode.NO_WORKSPACE ? CUDNN_CONVOLUTION_BWD_DATA_NO_WORKSPACE
+ : CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST,
+ 0, algo2);
+ */
+
+ val da = new cudnnConvolutionBwdDataAlgoPerf_t();
+ code = cudnnFindConvolutionBackwardDataAlgorithm(cudnnContext, cudnnContext.filterDesc,
+ cudnnContext.deltaTensorDesc, cudnnContext.convDesc, cudnnContext.srcTensorDesc, 1, counts, da);
+
+ algo2[0] = da.algo();
+ checkCudnn(false, "cudnnGetConvolutionBackwardDataAlgorithm", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);
+ }
+
+ if(log.isTraceEnabled()){
+ BwdFilterAlgo fa = BwdFilterAlgo.values()[algo1[0]];
+ BwdDataAlgo da = BwdDataAlgo.values()[algo2[0]];
+ log.trace("CudnnConvolutionHelper backward algorithm selection: mode {}, filter algorithm {}, data algorithm {}", mode, fa, da);
+ }
+
+ INDArray epsNext = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, weights.dataType(), new long[] {(int) miniBatch,(int) inDepth, (int) inH, (int) inW}, 'c');
+
+ val dstStride = epsNext.stride();
+
+ Allocator allocator = AtomicAllocator.getInstance();
+ CudaContext context = allocator.getFlowController().prepareActionAllWrite(input, weights, weightGradView,
+ biasGradView, delta, epsNext);
+ Pointer srcData = allocator.getPointer(input, context);
+ Pointer filterData = allocator.getPointer(weights, context);
+ Pointer filterGradData = allocator.getPointer(weightGradView, context);
+ Pointer biasGradData = allocator.getPointer(biasGradView, context);
+ Pointer deltaData = allocator.getPointer(delta, context);
+ Pointer dstData = allocator.getPointer(epsNext, context);
+
+ code = cudnnSetStream(cudnnContext, new CUstream_st(context.getCublasStream()));
+ checkCudnn(false, "cudnnSetStream", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);
+
+ code = cudnnSetTensor4dDescriptorEx(cudnnContext.dstTensorDesc, dataType, (int) miniBatch, (int) inDepth, (int) inH, (int) inW,
+ (int) dstStride[0], (int) dstStride[1], (int) dstStride[2], (int) dstStride[3]);
+ checkCudnn(false, "cudnnSetTensor4dDescriptorEx", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);
+
+ code = cudnnGetConvolutionBackwardFilterWorkspaceSize(cudnnContext, cudnnContext.srcTensorDesc,
+ cudnnContext.deltaTensorDesc, cudnnContext.convDesc, cudnnContext.filterDesc, algo1[0],
+ sizeInBytes);
+ checkCudnn(false, "cudnnGetConvolutionBackwardFilterWorkspaceSize", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);
+
+ long sizeInBytes1 = sizeInBytes.get(0);
+ code = cudnnGetConvolutionBackwardDataWorkspaceSize(cudnnContext, cudnnContext.filterDesc,
+ cudnnContext.deltaTensorDesc, cudnnContext.convDesc, cudnnContext.dstTensorDesc, algo2[0],
+ sizeInBytes);
+ checkCudnn(false, "cudnnGetConvolutionBackwardDataWorkspaceSize", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);
+
+ DataCache workSpace = workspaceMgr.getHelperWorkspace(LayerWorkspaceMgr.CUDNN_WORKSPACE_KEY);
+ long sizeInBytes2 = sizeInBytes.get(0);
+ if (workSpace == null || sizeInBytes1 > workSpace.capacity() || sizeInBytes2 > workSpace.capacity()) {
+ long newSize = Math.max(sizeInBytes1, sizeInBytes2);
+ if(log.isTraceEnabled()){
+ if(workSpace == null){
+ log.trace("CudnnConvolutionHelper backpropGradient: Allocating initial workspace of size {} ({})", newSize,
+ BinaryByteUnit.format(newSize, "#.00"));
+ } else {
+ log.trace("CudnnConvolutionHelper backpropGradient: Deallocating workspace of size {} ({}), allocating new workspace of size {} ({})",
+ workSpace.capacity(), BinaryByteUnit.format(workSpace.capacity(), "#.00"),
+ newSize, BinaryByteUnit.format(newSize, "#.00"));
+ }
+ }
+ if(workSpace != null)
+ workSpace.deallocate();
+ workSpace = new DataCache(newSize);
+ workspaceMgr.setHelperWorkspace(LayerWorkspaceMgr.CUDNN_WORKSPACE_KEY, workSpace);
+ }
+
+ code = cudnnSetTensor4dDescriptor(cudnnContext.biasTensorDesc, TENSOR_FORMAT, dataType, 1, (int) outDepth, 1, 1);
+ checkCudnn(false, "cudnnSetTensor4dDescriptor", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);
+
+ code = cudnnConvolutionBackwardBias(cudnnContext, alpha, cudnnContext.deltaTensorDesc, deltaData, beta,
+ cudnnContext.biasTensorDesc, biasGradData);
+ checkCudnn(false, "cudnnConvolutionBackwardBias", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);
+
+ code = cudnnConvolutionBackwardFilter(cudnnContext, alpha, cudnnContext.srcTensorDesc, srcData,
+ cudnnContext.deltaTensorDesc, deltaData, cudnnContext.convDesc, algo1[0], workSpace,
+ workSpace.capacity(), beta, cudnnContext.filterDesc, filterGradData);
+ checkCudnn(false, "cudnnConvolutionBackwardFilter", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);
+
+ code = cudnnConvolutionBackwardData(cudnnContext, alpha, cudnnContext.filterDesc, filterData,
+ cudnnContext.deltaTensorDesc, deltaData, cudnnContext.convDesc, algo2[0], workSpace,
+ workSpace.capacity(), beta, cudnnContext.dstTensorDesc, dstData);
+ checkCudnn(false, "cudnnConvolutionBackwardData", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);
+
+ allocator.getFlowController().registerActionAllWrite(context, input, weights, weightGradView, biasGradView,
+ delta, epsNext);
+
+ Gradient retGradient = new DefaultGradient();
+ retGradient.setGradientFor(ConvolutionParamInitializer.BIAS_KEY, biasGradView);
+ retGradient.setGradientFor(ConvolutionParamInitializer.WEIGHT_KEY, weightGradView, 'c');
+
+ if (CudaEnvironment.getInstance().getConfiguration().isDebug())
+ context.syncOldStream();
+
+ //Note that: if we had to manually pad for SAME mode, we have to 'undo' this manual padding for the epsilon
+ // we return. The returned epsilon (i.e., dL/dIn array) has to be the same shape as the *original* input.
+ if(args.isManualPadBottom() || args.isManualPadRight()) {
+ epsNext = epsNext.get(all(), all(),
+ interval(0, epsNext.size(2) - (args.isManualPadBottom() ? 1 : 0)),
+ interval(0, epsNext.size(3) - (args.isManualPadRight() ? 1 : 0)));
+ }
+
+ if(origNHWC){
+ epsNext = epsNext.permute(0,2,3,1); //NCHW to NHWC
+ }
+
+ return new Pair<>(retGradient, epsNext);
+ }
+
+ @Override
+ public INDArray preOutput(INDArray input, INDArray weights, INDArray bias, int[] kernel, int[] strides, int[] pad,
+ AlgoMode mode, FwdAlgo fwdAlgo, ConvolutionMode convolutionMode, int[] dilation, CNN2DFormat format,
+ LayerWorkspaceMgr workspaceMgr) {
+
+ //AB 2020/04/21 - cuDNN does have NHWC support (with limitations) however I have been unable to get it working
+ // correctly on NHWC data, even after updating all descriptors, tensor format, etc.
+ //Therefore: all computation here is done in NCHW format only
+ //As of a future (next?) release we'll likely switch to C++ for cuDNN support
+ boolean origNHWC = false;
+ if(format == CNN2DFormat.NHWC){
+ input = input.permute(0,3,1,2); //NHWC to NCHW
+ origNHWC = true;
+ }
+
+ int TENSOR_FORMAT = CUDNN_TENSOR_NCHW;
+
+ int code;
+
+ val miniBatch = input.size(0);
+ val outDepth = weights.size(0);
+ val inDepth = weights.size(1);
+ val kH = weights.size(2);
+ val kW = weights.size(3);
+
+ CudnnForwardArgs args = getCudnnForwardArgs(input, kernel, strides, pad, dilation, convolutionMode, null, CNN2DFormat.NCHW); //Note hardcoded NCHW due to above
+ input = args.getInput();
+ val inH = input.size(2);
+ val inW = input.size(3);
+ val srcStride = input.stride();
+ val outSize = args.getOutSize();
+
+ if (Nd4j.getExecutioner() instanceof GridExecutioner)
+ ((GridExecutioner) Nd4j.getExecutioner()).flushQueue();
+
+ INDArray z = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, weights.dataType(), new long[] {(int) miniBatch, (int) outDepth, outSize[0], outSize[1]});
+
+ code = cudnnSetTensor4dDescriptorEx(cudnnContext.srcTensorDesc, dataType, (int) miniBatch, (int) inDepth, (int) inH, (int) inW,
+ (int) srcStride[0], (int) srcStride[1], (int) srcStride[2], (int) srcStride[3]);
+ checkCudnn(true, "cudnnSetTensor4dDescriptorEx", code, input, weights, bias, null, kernel, strides, pad, mode, fwdAlgo, null, null, convolutionMode, dilation);
+
+ code = cudnnSetFilter4dDescriptor(cudnnContext.filterDesc, dataType, TENSOR_FORMAT, (int) outDepth, (int) inDepth, (int) kH, (int) kW);
+ checkCudnn(true, "cudnnSetFilter4dDescriptor", code, input, weights, bias, null, kernel, strides, pad, mode, fwdAlgo, null, null, convolutionMode, dilation);
+
+ code = cudnnSetConvolution2dDescriptor(cudnnContext.convDesc, pad[0], pad[1], strides[0], strides[1], dilation[0],
+ dilation[1], CUDNN_CROSS_CORRELATION, dataType);
+ checkCudnn(true, "cudnnSetConvolution2dDescriptor", code, input, weights, bias, null, kernel, strides, pad, mode, fwdAlgo, null, null, convolutionMode, dilation);
+
+
+ // find dimension of convolution output
+ // checkCudnn(cudnnGetConvolution2dForwardOutputDim(cudnnContext.convDesc, cudnnContext.srcTensorDesc, cudnnContext.filterDesc, n, c, h, w));
+ // INDArray z = Nd4j.createUninitialized(new int[]{n[0],c[0],h[0],w[0]},'c');
+
+
+ int[] algo = new int[1];
+ val dstStride = z.stride();
+ code = cudnnSetTensor4dDescriptorEx(cudnnContext.dstTensorDesc, dataType, (int) miniBatch, (int) outDepth, (int) outSize[0],
+ (int) outSize[1], (int) dstStride[0], (int) dstStride[1], (int) dstStride[2], (int) dstStride[3]);
+ checkCudnn(true, "cudnnSetTensor4dDescriptorEx", code, input, weights, bias, null, kernel, strides, pad, mode, fwdAlgo, null, null, convolutionMode, dilation);
+
+ if (mode == AlgoMode.USER_SPECIFIED && fwdAlgo != null) {
+ switch (fwdAlgo) {
+ case IMPLICIT_GEMM:
+ algo[0] = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM;
+ break;
+ case IMPLICIT_PRECOMP_GEMM:
+ algo[0] = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
+ break;
+ case GEMM:
+ algo[0] = CUDNN_CONVOLUTION_FWD_ALGO_GEMM;
+ break;
+ case DIRECT:
+ algo[0] = CUDNN_CONVOLUTION_FWD_ALGO_DIRECT;
+ break;
+ case FFT:
+ algo[0] = CUDNN_CONVOLUTION_FWD_ALGO_FFT;
+ break;
+ case FFT_TILING:
+ algo[0] = CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING;
+ break;
+ case WINOGRAD:
+ algo[0] = CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD;
+ break;
+ case WINOGRAD_NONFUSED:
+ algo[0] = CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED;
+ break;
+ case COUNT:
+ algo[0] = CUDNN_CONVOLUTION_FWD_ALGO_COUNT;
+ break;
+ default:
+ throw new IllegalArgumentException("Unknown FwdAlgo: " + fwdAlgo);
+ }
+ } else {
+ /*
+ code = cudnnGetConvolutionForwardAlgorithm_v7(cudnnContext, cudnnContext.srcTensorDesc,
+ cudnnContext.filterDesc, cudnnContext.convDesc,
+ cudnnContext.dstTensorDesc, mode == AlgoMode.NO_WORKSPACE
+ ? CUDNN_CONVOLUTION_FWD_ : CUDNN_CONVOLUTION_FWD_PREFER_FASTEST,
+ 0, algo);
+ */
+
+ val cdf = new cudnnConvolutionFwdAlgoPerf_t();
+ val count = new int[1];
+ code = cudnnFindConvolutionForwardAlgorithm(cudnnContext, cudnnContext.srcTensorDesc, cudnnContext.filterDesc, cudnnContext.convDesc, cudnnContext.dstTensorDesc, 1, count, cdf);
+
+ if(code != CUDNN_STATUS_SUCCESS){
+ //If CuDNN can't infer algorithm - try IMPLICIT_GEMM
+ //Why this specifically? According to the docs, it seems to have the least number of restrictions
+ // to things like dilation
+
+ OneTimeLogger.warn(log, "Error getting CuDNN forward algorithm - falling back on IMPLICIT_GEMM");
+ mode = AlgoMode.USER_SPECIFIED;
+ fwdAlgo = FwdAlgo.IMPLICIT_GEMM;
+ algo[0] = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM;
+ }
+
+ algo[0] = cdf.algo();
+ }
+
+ if(log.isTraceEnabled()){
+ FwdAlgo a = FwdAlgo.values()[algo[0]];
+ log.trace("CudnnConvolutionHelper forward algorithm selection: mode {}, algorithm {}", mode, a);
+ }
+
+ Allocator allocator = AtomicAllocator.getInstance();
+ CudaContext context = allocator.getFlowController().prepareAction(z, input, weights, bias);
+ Pointer srcData = allocator.getPointer(input, context);
+ Pointer filterData = allocator.getPointer(weights, context);
+ Pointer biasData = allocator.getPointer(bias, context);
+ Pointer dstData = allocator.getPointer(z, context);
+
+ code = cudnnSetStream(cudnnContext, new CUstream_st(context.getCublasStream()));
+ checkCudnn(true, "cudnnSetStream", code, input, weights, bias, null, kernel, strides, pad, mode, fwdAlgo, null, null, convolutionMode, dilation);
+
+ code = cudnnGetConvolutionForwardWorkspaceSize(cudnnContext, cudnnContext.srcTensorDesc,
+ cudnnContext.filterDesc, cudnnContext.convDesc, cudnnContext.dstTensorDesc, algo[0],
+ sizeInBytes);
+ checkCudnn(true, "cudnnGetConvolutionForwardWorkspaceSize", code, input, weights, bias, null, kernel, strides, pad, mode, fwdAlgo, null, null, convolutionMode, dilation);
+
+ DataCache workSpace = workspaceMgr.getHelperWorkspace(LayerWorkspaceMgr.CUDNN_WORKSPACE_KEY);
+ if (workSpace == null || sizeInBytes.get(0) > workSpace.capacity()) {
+ if(log.isTraceEnabled()){
+ if(workSpace == null){
+ log.trace("CudnnConvolutionHelper preOutput: allocating initial workspace of size {} ({})",
+ sizeInBytes.get(), BinaryByteUnit.format(sizeInBytes.get(), "#.00"));
+ } else {
+ log.trace("CudnnConvolutionHelper preOutput: Deallocating workspace of size {} ({}), allocating new workspace of size {} ({})",
+ workSpace.capacity(), BinaryByteUnit.format(workSpace.capacity(), "#.00"),
+ sizeInBytes.get(), BinaryByteUnit.format(sizeInBytes.get(), "#.00"));
+ }
+ }
+ if(workSpace != null)
+ workSpace.deallocate();
+ workSpace = new DataCache(sizeInBytes.get(0));
+ workspaceMgr.setHelperWorkspace(LayerWorkspaceMgr.CUDNN_WORKSPACE_KEY, workSpace);
+ }
+ code = cudnnConvolutionForward(cudnnContext, alpha, cudnnContext.srcTensorDesc, srcData,
+ cudnnContext.filterDesc, filterData, cudnnContext.convDesc, algo[0], workSpace,
+ workSpace.capacity(), beta, cudnnContext.dstTensorDesc, dstData);
+ checkCudnn(true, "cudnnConvolutionForward", code, input, weights, bias, null, kernel, strides, pad, mode, fwdAlgo, null, null, convolutionMode, dilation);
+
+
+ code = cudnnSetTensor4dDescriptor(cudnnContext.biasTensorDesc, TENSOR_FORMAT, dataType, 1, (int) outDepth, 1, 1);
+ checkCudnn(true, "cudnnSetTensor4dDescriptor", code, input, weights, bias, null, kernel, strides, pad, mode, fwdAlgo, null, null, convolutionMode, dilation);
+
+ code = cudnnAddTensor(cudnnContext, alpha, cudnnContext.biasTensorDesc, biasData, alpha,
+ cudnnContext.dstTensorDesc, dstData);
+ checkCudnn(true, "cudnnAddTensor", code, input, weights, bias, null, kernel, strides, pad, mode, fwdAlgo, null, null, convolutionMode, dilation);
+
+ allocator.registerAction(context, z, input, weights, bias);
+
+ if (CudaEnvironment.getInstance().getConfiguration().isDebug())
+ context.syncOldStream();
+
+ if(origNHWC){
+ z = z.permute(0,2,3,1); //NCHW to NHWC
+ }
+
+ return z;
+ }
+
+ /**
+  * Turns a non-success cuDNN status code into a {@link RuntimeException} whose message names the
+  * failing step and lists the convolution configuration that was in effect.
+  *
+  * @param forward         true if the call belongs to the forward pass, false for the backward pass
+  * @param step            name of the cuDNN call that returned {@code code}
+  * @param code            status code to check; the method returns silently on CUDNN_STATUS_SUCCESS
+  * @param input           layer input; only its shape is reported
+  * @param weights         layer weights; only their shape is reported
+  * @param bias            layer bias; may be null, in which case "null" is reported
+  * @param delta           gradient array; its shape is read only when {@code forward} is false
+  * @param kernel          kernel size, reported as-is
+  * @param strides         strides, reported as-is
+  * @param pad             padding, reported as-is
+  * @param mode            algorithm selection mode, reported as-is
+  * @param fwdAlgo         forward algorithm; reported only for the forward pass
+  * @param bwdFilterAlgo   backward filter algorithm; reported only for the backward pass
+  * @param bwdDataAlgo     backward data algorithm; reported only for the backward pass
+  * @param convolutionMode convolution mode, reported as-is
+  * @param dilation        dilation, reported as-is
+  * @throws RuntimeException if {@code code} is not CUDNN_STATUS_SUCCESS
+  */
+ private void checkCudnn(boolean forward, String step, int code, INDArray input, INDArray weights, INDArray bias, INDArray delta,
+         int[] kernel, int[] strides, int[] pad,
+         AlgoMode mode, FwdAlgo fwdAlgo, BwdFilterAlgo bwdFilterAlgo, BwdDataAlgo bwdDataAlgo, ConvolutionMode convolutionMode, int[] dilation) {
+
+     // Nothing to report for a successful call.
+     if (code == CUDNN_STATUS_SUCCESS) {
+         return;
+     }
+
+     StringBuilder message = new StringBuilder();
+
+     // What failed, and where.
+     message.append("CuDNN error = ").append(code).append(": ");
+     message.append(cudnnGetErrorString(code).getString());
+     message.append(" during ");
+     if (forward) {
+         message.append("forward pass");
+     } else {
+         message.append("backward pass");
+     }
+     message.append(" - step ").append(step);
+
+     // Shapes of the arrays involved.
+     message.append(": inputShape=").append(Arrays.toString(input.shape()));
+     message.append(", weightsShape=").append(Arrays.toString(weights.shape()));
+     String biasShape = (bias == null) ? null : Arrays.toString(bias.shape());
+     message.append(", biasShape=").append(biasShape);
+     if (!forward) {
+         message.append(", gradientShape=").append(Arrays.toString(delta.shape()));
+     }
+
+     // Layer configuration.
+     message.append(", kernel=").append(Arrays.toString(kernel));
+     message.append(", stride=").append(Arrays.toString(strides));
+     message.append(", padding=").append(Arrays.toString(pad));
+     message.append(", dilation=").append(Arrays.toString(dilation));
+     message.append(", AlgoMode=").append(mode);
+     if (forward) {
+         message.append(", fwdAlgo=").append(fwdAlgo);
+     } else {
+         message.append(", bwdFilterAlgo=").append(bwdFilterAlgo);
+         message.append(", bwdDataAlgo=").append(bwdDataAlgo);
+     }
+     message.append(", convolutionMode=").append(convolutionMode);
+
+     throw new RuntimeException(message.toString());
+ }
+
+ /**
+  * Applies an activation function to {@code z} through cuDNN, passing the same device pointer as
+  * both source and destination of the cuDNN call.
+  * <p>
+  * The tensor layout is taken from {@code cudnnContext.dstTensorDesc}, which this method does not
+  * configure itself.
+  *
+  * @param z        pre-activations to transform
+  * @param afn      activation function; selected by the value of {@code afn.toString()}
+  * @param training not used by this implementation
+  * @return {@code z} for "identity", "sigmoid", "relu", "tanh", "softmax" and "logsoftmax";
+  *         null for any other activation name
+  */
+ @Override
+ public INDArray activate(INDArray z, IActivation afn, boolean training) {
+     if (Nd4j.getExecutioner() instanceof GridExecutioner) {
+         ((GridExecutioner) Nd4j.getExecutioner()).flushQueue();
+     }
+
+     Allocator allocator = AtomicAllocator.getInstance();
+     CudaContext context = allocator.getFlowController().prepareAction(z);
+     Pointer dstData = allocator.getPointer(z, context);
+
+     checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getCublasStream())));
+
+     final String activationName = afn.toString();
+     INDArray activation = z;
+     switch (activationName) {
+         case "identity":
+             // Nothing to compute.
+             break;
+         case "sigmoid":
+         case "relu":
+         case "tanh": {
+             // These three differ only in the cuDNN activation mode constant.
+             final int activationMode = "sigmoid".equals(activationName) ? CUDNN_ACTIVATION_SIGMOID
+                     : "relu".equals(activationName) ? CUDNN_ACTIVATION_RELU
+                     : CUDNN_ACTIVATION_TANH;
+             checkCudnn(cudnnSetActivationDescriptor(cudnnContext.activationDesc, activationMode,
+                     CUDNN_PROPAGATE_NAN, 0));
+             checkCudnn(cudnnActivationForward(cudnnContext, cudnnContext.activationDesc, alpha,
+                     cudnnContext.dstTensorDesc, dstData, beta, cudnnContext.dstTensorDesc, dstData));
+             break;
+         }
+         case "softmax":
+         case "logsoftmax": {
+             // Both use channel mode; only the softmax algorithm constant differs.
+             final int softmaxAlgorithm = "softmax".equals(activationName)
+                     ? CUDNN_SOFTMAX_ACCURATE
+                     : CUDNN_SOFTMAX_LOG;
+             checkCudnn(cudnnSoftmaxForward(cudnnContext, softmaxAlgorithm, CUDNN_SOFTMAX_MODE_CHANNEL, alpha,
+                     cudnnContext.dstTensorDesc, dstData, beta, cudnnContext.dstTensorDesc, dstData));
+             break;
+         }
+         default:
+             // Activation not handled here: signal this to the caller with a null result.
+             activation = null;
+     }
+
+     allocator.registerAction(context, activation);
+
+     if (CudaEnvironment.getInstance().getConfiguration().isDebug()) {
+         context.syncOldStream();
+     }
+
+     return activation;
+ }
+
+ /**
+  * Prepares the input array, padding and output size for a cuDNN forward call.
+  * <p>
+  * The input is copied into a 'c'-ordered array when it is a view or does not have default strides.
+  * In {@link ConvolutionMode#Same} mode the top/left and bottom/right paddings are computed; when they
+  * differ, one extra row and/or column is appended to the input here so that the padding left for
+  * cuDNN is the same on both sides.
+  *
+  * @param input           input activations, laid out according to {@code format}
+  * @param kernel          kernel size
+  * @param strides         strides
+  * @param padding         padding; returned unchanged unless the mode is Same, in which case the computed
+  *                        top/left padding is returned instead
+  * @param dilation        dilation
+  * @param convolutionMode convolution mode
+  * @param poolingType     Used when preparing data for subsampling layers ONLY. Null for convolution layers
+  * @param format          data format of {@code input} (NCHW, otherwise treated as NHWC)
+  * @return the manual-padding flags, the (possibly copied and padded) input, the original input,
+  *         the padding to pass to cuDNN and the output size
+  */
+ public static CudnnForwardArgs getCudnnForwardArgs(INDArray input, int[] kernel, int[] strides, int[] padding, int[] dilation,
+         ConvolutionMode convolutionMode, PoolingType poolingType, CNN2DFormat format){
+     final INDArray origInput = input;
+
+     //Check if we need to dup the input: views, non-contiguous, etc. CuDNN also seems to have issues if strides
+     // are non-default for C order - even if they *should* be OK otherwise
+     final boolean needsCopy = input.isView() || !Shape.hasDefaultStridesForShape(input);
+     if (needsCopy) {
+         input = input.dup('c');
+     }
+
+     final boolean nchw = (format == CNN2DFormat.NCHW);
+     val inH = input.size(nchw ? 2 : 1);
+     val inW = input.size(nchw ? 3 : 2);
+
+     if (convolutionMode != ConvolutionMode.Same) {
+         //Padding is used exactly as supplied. getOutputSize also performs validation
+         int[] plainOutSize = ConvolutionUtils.getOutputSize(input, kernel, strides, padding, convolutionMode, dilation, format);
+         return new CudnnForwardArgs(false, false, input, origInput, padding, plainOutSize);
+     }
+
+     int[] sameOutSize = ConvolutionUtils.getOutputSize(input, kernel, strides, null, convolutionMode, dilation, format); //Also performs validation
+     padding = ConvolutionUtils.getSameModeTopLeftPadding(sameOutSize, new int[] {(int) inH, (int) inW}, kernel, strides, dilation);
+     int[] bottomRight = ConvolutionUtils.getSameModeBottomRightPadding(sameOutSize, new int[] {(int) inH, (int) inW}, kernel, strides, dilation);
+
+     if (Arrays.equals(padding, bottomRight)) {
+         //Symmetric padding: cuDNN can handle this without help
+         return new CudnnForwardArgs(false, false, input, origInput, padding, sameOutSize);
+     }
+
+     /*
+     CuDNN - even as of 7.1 (CUDA 9.1) still doesn't have support for proper SAME mode padding (i.e., asymmetric
+     padding) - padding can *only* be specified as the same amount for both the top/bottom, and for left/right.
+     In SAME mode padding, sometimes these are the same - but often they are not.
+     Note that when they differ, the bottom or right padding will be exactly 1 more than the top or left padding.
+     As per TF, we'll manually pad here: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/conv_ops.cc#L571-L607
+      */
+     final boolean manualPadBottom = padding[0] != bottomRight[0];
+     final boolean manualPadRight = padding[1] != bottomRight[1];
+     final int extraRows = manualPadBottom ? 1 : 0;
+     final int extraCols = manualPadRight ? 1 : 0;
+
+     final long[] paddedShape = nchw
+             ? new long[]{input.size(0), input.size(1), input.size(2) + extraRows, input.size(3) + extraCols}
+             : new long[]{input.size(0), input.size(1) + extraRows, input.size(2) + extraCols, input.size(3)};
+
+     final INDArray padded;
+     if (poolingType != PoolingType.MAX) {
+         //Convolution (null pooling type) and non-max pooling
+         padded = Nd4j.create(input.dataType(), paddedShape);
+     } else {
+         //For max pooling, we don't want to include the padding in the maximum values. But, CuDNN doesn't know
+         // that these values are padding and hence should be excluded. Instead: We'll use -infinity so that,
+         // if the 'real' (non-padding) values are all < 0, we take the real value, not the padding value
+         padded = Nd4j.valueArrayOf(paddedShape, Double.NEGATIVE_INFINITY, input.dataType());
+     }
+
+     //Copy the real values into the top-left corner of the enlarged array
+     final INDArrayIndex rows = interval(0, inH);
+     final INDArrayIndex cols = interval(0, inW);
+     final INDArrayIndex[] realRegion = nchw
+             ? new INDArrayIndex[]{all(), all(), rows, cols}
+             : new INDArrayIndex[]{all(), rows, cols, all()};
+     padded.put(realRegion, input);
+
+     //Now: we've manually applied the "extra" bottom/right padding only - if required. Consequently, we
+     // now have the same amount of padding required for top/bottom, and left/right - which we'll let
+     // CuDNN handle
+     return new CudnnForwardArgs(manualPadBottom, manualPadRight, padded, origInput, padding, sameOutSize);
+ }
+
+
+ /**
+ * Result holder for {@code getCudnnForwardArgs}.
+ * The all-args constructor generated by Lombok takes the fields in declaration order, so the
+ * field order below must not be changed.
+ */
+ @AllArgsConstructor
+ @Data
+ public static class CudnnForwardArgs {
+ // true if one extra row was appended at the bottom of the input before handing it to cuDNN
+ private boolean manualPadBottom;
+ // true if one extra column was appended at the right of the input before handing it to cuDNN
+ private boolean manualPadRight;
+ // input to pass to cuDNN (possibly a 'c'-order copy and/or manually padded)
+ private INDArray input;
+ // the input array exactly as the caller supplied it
+ private INDArray origInput;
+ // padding to pass to cuDNN (the computed top/left padding in Same mode)
+ private int[] padding;
+ // output size as computed by ConvolutionUtils.getOutputSize
+ private int[] outSize;
+ }
+
+ /**
+ * Reports the memory used by this helper.
+ *
+ * @return always an empty map
+ */
+ @Override
+ public Map helperMemoryUse() {
+ //No memory use other than shared, and the structs (which are small)
+ return Collections.emptyMap();
+ }
+
+}
\ No newline at end of file
diff --git a/cavis-dnn/cavis-dnn-cudnn/src/main/java/org/deeplearning4j/cuda/convolution/subsampling/CudnnSubsamplingHelper.java b/cavis-dnn/cavis-dnn-cudnn/src/main/java/org/deeplearning4j/cuda/convolution/subsampling/CudnnSubsamplingHelper.java
new file mode 100644
index 000000000..b92810959
--- /dev/null
+++ b/cavis-dnn/cavis-dnn-cudnn/src/main/java/org/deeplearning4j/cuda/convolution/subsampling/CudnnSubsamplingHelper.java
@@ -0,0 +1,308 @@
+/*
+ * ******************************************************************************
+ * *
+ * *
+ * * This program and the accompanying materials are made available under the
+ * * terms of the Apache License, Version 2.0 which is available at
+ * * https://www.apache.org/licenses/LICENSE-2.0.
+ * *
+ * * See the NOTICE file distributed with this work for additional
+ * * information regarding copyright ownership.
+ * * Unless required by applicable law or agreed to in writing, software
+ * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * * License for the specific language governing permissions and limitations
+ * * under the License.
+ * *
+ * * SPDX-License-Identifier: Apache-2.0
+ * *****************************************************************************
+ */
+
+package org.deeplearning4j.cuda.convolution.subsampling;
+
+import lombok.extern.slf4j.Slf4j;
+import lombok.val;
+import org.bytedeco.javacpp.Pointer;
+import org.deeplearning4j.nn.conf.CNN2DFormat;
+import org.deeplearning4j.nn.conf.ConvolutionMode;
+import org.deeplearning4j.nn.conf.layers.PoolingType;
+import org.deeplearning4j.nn.gradient.DefaultGradient;
+import org.deeplearning4j.nn.gradient.Gradient;
+import org.deeplearning4j.cuda.BaseCudnnHelper;
+import org.deeplearning4j.cuda.convolution.CudnnConvolutionHelper;
+import org.deeplearning4j.nn.layers.convolution.subsampling.SubsamplingHelper;
+import org.nd4j.jita.allocator.Allocator;
+import org.nd4j.jita.allocator.impl.AtomicAllocator;
+import org.nd4j.jita.conf.CudaEnvironment;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.executioner.GridExecutioner;
+import org.nd4j.linalg.api.shape.Shape;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.jcublas.context.CudaContext;
+import org.nd4j.common.primitives.Pair;
+import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
+import org.deeplearning4j.nn.workspace.ArrayType;
+
+import java.util.Collections;
+import java.util.Map;
+
+import org.bytedeco.cuda.cudart.*;
+import org.bytedeco.cuda.cudnn.*;
+
+import static org.bytedeco.cuda.global.cudnn.*;
+import static org.deeplearning4j.cuda.convolution.CudnnConvolutionHelper.getCudnnForwardArgs;
+import static org.nd4j.linalg.indexing.NDArrayIndex.all;
+import static org.nd4j.linalg.indexing.NDArrayIndex.interval;
+
+/**
+ * cuDNN-based helper for the subsampling layer.
+ *
+ * Only AVG and MAX pooling without dilation are handled; for any other pooling type, or when either
+ * dilation value differs from 1, {@code activate} and {@code backpropGradient} return null.
+ *
+ * @author saudet
+ */
+@Slf4j
+public class CudnnSubsamplingHelper extends BaseCudnnHelper implements SubsamplingHelper {
+
+ /**
+ * @param dataType data type passed on to {@link BaseCudnnHelper}
+ */
+ public CudnnSubsamplingHelper(DataType dataType) {
+ super(dataType);
+ }
+
+ /**
+ * Holds the cuDNN descriptors used by this helper: three tensor descriptors (source, destination,
+ * delta) and one pooling descriptor.
+ */
+ private static class CudnnSubsamplingContext extends CudnnContext {
+
+ /**
+ * Copy of the context that is registered as its deallocator; {@link #deallocate()} destroys the handles.
+ */
+ private static class Deallocator extends CudnnSubsamplingContext implements Pointer.Deallocator {
+ Deallocator(CudnnSubsamplingContext c) {
+ super(c);
+ }
+
+ @Override
+ public void deallocate() {
+ destroyHandles();
+ }
+ }
+
+ // These field initialisers run before the constructor body, so the structs exist when createHandles() is called
+ private cudnnTensorStruct srcTensorDesc = new cudnnTensorStruct(), dstTensorDesc = new cudnnTensorStruct(),
+ deltaTensorDesc = new cudnnTensorStruct();
+ private cudnnPoolingStruct poolingDesc = new cudnnPoolingStruct();
+
+ /** Creates the cuDNN handles and registers a deallocator that destroys them. */
+ public CudnnSubsamplingContext() {
+ createHandles();
+ deallocator(new Deallocator(this));
+ }
+
+ /**
+ * Copy constructor: wraps the descriptors of {@code c} in new struct objects.
+ * NOTE(review): presumably the new objects refer to the same native descriptors as {@code c} - confirm.
+ */
+ public CudnnSubsamplingContext(CudnnSubsamplingContext c) {
+ super(c);
+ srcTensorDesc = new cudnnTensorStruct(c.srcTensorDesc);
+ dstTensorDesc = new cudnnTensorStruct(c.dstTensorDesc);
+ deltaTensorDesc = new cudnnTensorStruct(c.deltaTensorDesc);
+ poolingDesc = new cudnnPoolingStruct(c.poolingDesc);
+ }
+
+ /** Creates the superclass handles first, then the three tensor descriptors and the pooling descriptor. */
+ @Override
+ protected void createHandles() {
+ super.createHandles();
+ checkCudnn(cudnnCreateTensorDescriptor(srcTensorDesc));
+ checkCudnn(cudnnCreateTensorDescriptor(dstTensorDesc));
+ checkCudnn(cudnnCreateTensorDescriptor(deltaTensorDesc));
+ checkCudnn(cudnnCreatePoolingDescriptor(poolingDesc));
+ }
+
+ /** Destroys the descriptors created by this class, then the superclass handles. */
+ @Override
+ protected void destroyHandles() {
+ checkCudnn(cudnnDestroyPoolingDescriptor(poolingDesc));
+ checkCudnn(cudnnDestroyTensorDescriptor(srcTensorDesc));
+ checkCudnn(cudnnDestroyTensorDescriptor(dstTensorDesc));
+ checkCudnn(cudnnDestroyTensorDescriptor(deltaTensorDesc));
+ super.destroyHandles();
+ }
+ }
+
+ // One set of descriptors per helper instance; reconfigured on every activate/backpropGradient call
+ private CudnnSubsamplingContext cudnnContext = new CudnnSubsamplingContext();
+
+ /**
+ * Computes the gradient with respect to the layer input using cudnnPoolingBackward.
+ * The forward pass is recomputed here because its output is one of the arguments of the backward call.
+ *
+ * @param input layer input, laid out according to {@code format}
+ * @param epsilon gradient with respect to the layer output; copied if it is a view or has non-default strides
+ * @param kernel kernel size
+ * @param strides strides
+ * @param pad padding passed to the cuDNN pooling descriptor
+ * @param poolingType pooling type; only AVG and MAX are handled
+ * @param convolutionMode convolution mode, forwarded to {@code getCudnnForwardArgs}
+ * @param dilation dilation; both values must be 1
+ * @param format data format (NCHW, otherwise treated as NHWC)
+ * @param workspaceMgr used to allocate the returned epsilon array
+ * @return pair of an empty gradient and the epsilon for the layer below, or null if dilation or the
+ * pooling type is not supported
+ */
+ @Override
+ public Pair backpropGradient(INDArray input, INDArray epsilon, int[] kernel, int[] strides,
+ int[] pad, PoolingType poolingType, ConvolutionMode convolutionMode,
+ int[] dilation, CNN2DFormat format, LayerWorkspaceMgr workspaceMgr) {
+ if(dilation[0] != 1 || dilation[1] != 1){
+ //CuDNN doesn't support dilated subsampling
+ return null;
+ }
+
+ // Positions of the channel, height and width dimensions for the given format
+ boolean nchw = format == CNN2DFormat.NCHW;
+ int chIdx = nchw ? 1 : 3;
+ int hIdx = nchw ? 2 : 1;
+ int wIdx = nchw ? 3 : 2;
+
+ //We require the output as one of the arguments for backprop here
+ //TODO we could add cache mode support here somehow...
+ // NOTE(review): for pooling types other than AVG/MAX this call returns null and the method only
+ // returns null itself after getCudnnForwardArgs has run - the work in between is wasted
+ INDArray reduced = activate(input, true, kernel, strides, pad, poolingType, convolutionMode, dilation, format, workspaceMgr);
+
+ val miniBatch = input.size(0);
+ val depth = input.size(chIdx);
+
+ // May replace the input with a copy and/or a manually padded array (Same mode)
+ CudnnConvolutionHelper.CudnnForwardArgs args = getCudnnForwardArgs(input, kernel, strides, pad, dilation, convolutionMode, poolingType, format);
+ input = args.getInput();
+ val inH = input.size(hIdx);
+ val inW = input.size(wIdx);
+ val srcStride = input.stride();
+ int[] outSize = args.getOutSize();
+ int outH = outSize[0];
+ int outW = outSize[1];
+
+ //subsampling doesn't have weights and thus gradients are not calculated for this layer
+ //only scale and reshape epsilon
+ Gradient retGradient = new DefaultGradient();
+
+ //Epsilons in shape: [miniBatch, channels, outH, outW]
+ //Epsilons out shape: [miniBatch, channels, inH, inW]
+
+ int poolingMode;
+ switch (poolingType) {
+ case AVG:
+ poolingMode = CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING;
+ break;
+ case MAX:
+ poolingMode = CUDNN_POOLING_MAX;
+ break;
+ default:
+ return null;
+ }
+
+ if (!Shape.hasDefaultStridesForShape(epsilon) || epsilon.isView()) {
+ // apparently not supported by cuDNN
+ epsilon = epsilon.dup('c');
+ }
+
+ // NOTE(review): srcStride was read before this dup(); the source descriptor below assumes the copy
+ // keeps the same strides - confirm
+ input = input.dup();
+
+ val deltaStride = epsilon.stride();
+
+ if (Nd4j.getExecutioner() instanceof GridExecutioner)
+ ((GridExecutioner) Nd4j.getExecutioner()).flushQueue();
+
+ // Descriptors are always given in (N, C, H, W) order; the strides are picked per format via chIdx/hIdx/wIdx
+ checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.srcTensorDesc, dataType, (int) miniBatch, (int) depth, (int) inH, (int) inW,
+ (int) srcStride[0], (int) srcStride[chIdx], (int) srcStride[hIdx], (int) srcStride[wIdx]));
+ checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.deltaTensorDesc, dataType, (int) miniBatch, (int) depth, (int) outH, (int) outW,
+ (int) deltaStride[0], (int) deltaStride[chIdx], (int) deltaStride[hIdx], (int) deltaStride[wIdx]));
+ checkCudnn(cudnnSetPooling2dDescriptor(cudnnContext.poolingDesc, poolingMode, CUDNN_PROPAGATE_NAN, kernel[0],
+ kernel[1], pad[0], pad[1], strides[0], strides[1]));
+
+ // Output epsilon has the shape of the (possibly manually padded) input
+ long[] outEpsShape = nchw ? new long[] {miniBatch, depth, inH, inW} : new long[] {miniBatch, inH, inW, depth};
+ INDArray outEpsilon = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, input.dataType(), outEpsShape, 'c');
+
+ val dstStride = outEpsilon.stride();
+ checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.dstTensorDesc, dataType, (int) miniBatch, (int) depth, (int) inH, (int) inW,
+ (int) dstStride[0], (int) dstStride[chIdx], (int) dstStride[hIdx], (int) dstStride[wIdx]));
+
+ Allocator allocator = AtomicAllocator.getInstance();
+ CudaContext context = allocator.getFlowController().prepareAction(input, epsilon, reduced, outEpsilon);
+ Pointer srcData = allocator.getPointer(input, context);
+ Pointer epsData = allocator.getPointer(epsilon, context);
+ Pointer zData = allocator.getPointer(reduced, context);
+ Pointer dstData = allocator.getPointer(outEpsilon, context);
+
+ checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getCublasStream())));
+ // The forward output (zData) and the incoming epsilon (epsData) are both described by deltaTensorDesc
+ checkCudnn(cudnnPoolingBackward(cudnnContext, cudnnContext.poolingDesc, alpha, cudnnContext.deltaTensorDesc,
+ zData, cudnnContext.deltaTensorDesc, epsData, cudnnContext.srcTensorDesc, srcData, beta,
+ cudnnContext.dstTensorDesc, dstData));
+
+ allocator.registerAction(context, outEpsilon, input, epsilon, reduced);
+
+ if (CudaEnvironment.getInstance().getConfiguration().isDebug())
+ context.syncOldStream();
+
+ //Note that: if we had to manually pad for SAME mode, we have to 'undo' this manual padding for the epsilon
+ // we return. The returned epsilon (i.e., dL/dIn array) has to be the same shape as the *original* input.
+ if(args.isManualPadBottom() || args.isManualPadRight()) {
+ if(nchw){
+ outEpsilon = outEpsilon.get(all(), all(),
+ interval(0, outEpsilon.size(2) - (args.isManualPadBottom() ? 1 : 0)),
+ interval(0, outEpsilon.size(3) - (args.isManualPadRight() ? 1 : 0)));
+ } else {
+ outEpsilon = outEpsilon.get(all(),
+ interval(0, outEpsilon.size(1) - (args.isManualPadBottom() ? 1 : 0)),
+ interval(0, outEpsilon.size(2) - (args.isManualPadRight() ? 1 : 0)),
+ all());
+ }
+ }
+
+ return new Pair<>(retGradient, outEpsilon);
+ }
+
+
+ /**
+ * Runs the pooling forward pass using cudnnPoolingForward.
+ *
+ * @param input layer input, laid out according to {@code format}
+ * @param training not used by this implementation
+ * @param kernel kernel size
+ * @param strides strides
+ * @param pad padding passed to the cuDNN pooling descriptor
+ * @param poolingType pooling type; only AVG and MAX are handled
+ * @param convolutionMode convolution mode, forwarded to {@code getCudnnForwardArgs}
+ * @param dilation dilation; both values must be 1
+ * @param format data format (NCHW, otherwise treated as NHWC)
+ * @param workspaceMgr used to allocate the output array
+ * @return the pooled activations, or null if dilation or the pooling type is not supported
+ */
+ @Override
+ public INDArray activate(INDArray input, boolean training, int[] kernel, int[] strides, int[] pad,
+ PoolingType poolingType, ConvolutionMode convolutionMode, int[] dilation, CNN2DFormat format, LayerWorkspaceMgr workspaceMgr) {
+ if(dilation[0] != 1 || dilation[1] != 1){
+ //CuDNN doesn't support dilated subsampling
+ return null;
+ }
+
+ // Positions of the channel, height and width dimensions for the given format
+ boolean nchw = format == CNN2DFormat.NCHW;
+ int chIdx = nchw ? 1 : 3;
+ int hIdx = nchw ? 2 : 1;
+ int wIdx = nchw ? 3 : 2;
+
+ val miniBatch = input.size(0);
+ val inDepth = input.size(nchw ? 1 : 3);
+
+ // May replace the input with a copy and/or a manually padded array (Same mode)
+ CudnnConvolutionHelper.CudnnForwardArgs args = getCudnnForwardArgs(input, kernel, strides, pad, dilation, convolutionMode, poolingType, format);
+ input = args.getInput();
+ val inH = input.size(nchw ? 2 : 1);
+ val inW = input.size(nchw ? 3 : 2);
+ val srcStride = input.stride();
+ val outSize = args.getOutSize();
+ int outH = outSize[0];
+ int outW = outSize[1];
+
+
+ // NOTE(review): unsupported pooling types are only rejected here, after getCudnnForwardArgs has already run
+ int poolingMode;
+ switch (poolingType) {
+ case AVG:
+ poolingMode = CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING;
+ break;
+ case MAX:
+ poolingMode = CUDNN_POOLING_MAX;
+ break;
+ default:
+ return null;
+ }
+
+ if (Nd4j.getExecutioner() instanceof GridExecutioner)
+ ((GridExecutioner) Nd4j.getExecutioner()).flushQueue();
+
+ // Descriptors are always given in (N, C, H, W) order; the strides are picked per format via chIdx/hIdx/wIdx
+ checkCudnn(cudnnSetPooling2dDescriptor(cudnnContext.poolingDesc, poolingMode, CUDNN_PROPAGATE_NAN, kernel[0],
+ kernel[1], pad[0], pad[1], strides[0], strides[1]));
+ checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.srcTensorDesc, dataType, (int) miniBatch, (int) inDepth, (int) inH, (int) inW,
+ (int) srcStride[0], (int) srcStride[chIdx], (int) srcStride[hIdx], (int) srcStride[wIdx]));
+
+ long[] outShape = nchw ? new long[] {miniBatch, inDepth, outH, outW} : new long[] {miniBatch, outH, outW, inDepth};
+ INDArray reduced = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, input.dataType(), outShape, 'c');
+
+ val dstStride = reduced.stride();
+ checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.dstTensorDesc, dataType, (int) miniBatch, (int) inDepth, (int) outH, (int) outW,
+ (int) dstStride[0], (int) dstStride[chIdx], (int) dstStride[hIdx], (int) dstStride[wIdx]));
+
+ Allocator allocator = AtomicAllocator.getInstance();
+ CudaContext context = allocator.getFlowController().prepareAction(input, reduced);
+ Pointer srcData = allocator.getPointer(input, context);
+ Pointer dstData = allocator.getPointer(reduced, context);
+
+ checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getCublasStream())));
+ checkCudnn(cudnnPoolingForward(cudnnContext, cudnnContext.poolingDesc, alpha, cudnnContext.srcTensorDesc,
+ srcData, beta, cudnnContext.dstTensorDesc, dstData));
+
+ allocator.registerAction(context, reduced, input);
+
+ if (CudaEnvironment.getInstance().getConfiguration().isDebug())
+ context.syncOldStream();
+
+ return reduced;
+ }
+
+ /**
+ * Reports the memory used by this helper.
+ *
+ * @return always an empty map
+ */
+ @Override
+ public Map helperMemoryUse() {
+ //No persistent memory use other than the structs (which are small)
+ return Collections.emptyMap();
+ }
+
+}
diff --git a/cavis-dnn/cavis-dnn-cudnn/src/main/java/org/deeplearning4j/cuda/dropout/CudnnDropoutHelper.java b/cavis-dnn/cavis-dnn-cudnn/src/main/java/org/deeplearning4j/cuda/dropout/CudnnDropoutHelper.java
new file mode 100644
index 000000000..83fd9c7f0
--- /dev/null
+++ b/cavis-dnn/cavis-dnn-cudnn/src/main/java/org/deeplearning4j/cuda/dropout/CudnnDropoutHelper.java
@@ -0,0 +1,245 @@
+/*
+ * ******************************************************************************
+ * *
+ * *
+ * * This program and the accompanying materials are made available under the
+ * * terms of the Apache License, Version 2.0 which is available at
+ * * https://www.apache.org/licenses/LICENSE-2.0.
+ * *
+ * * See the NOTICE file distributed with this work for additional
+ * * information regarding copyright ownership.
+ * * Unless required by applicable law or agreed to in writing, software
+ * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * * License for the specific language governing permissions and limitations
+ * * under the License.
+ * *
+ * * SPDX-License-Identifier: Apache-2.0
+ * *****************************************************************************
+ */
+
+package org.deeplearning4j.cuda.dropout;
+
+import lombok.Data;
+import lombok.extern.slf4j.Slf4j;
+import com.jakewharton.byteunits.BinaryByteUnit;
+import org.bytedeco.javacpp.*;
+import org.deeplearning4j.nn.conf.dropout.DropoutHelper;
+import org.deeplearning4j.cuda.BaseCudnnHelper;
+import org.nd4j.jita.allocator.Allocator;
+import org.nd4j.jita.allocator.impl.AtomicAllocator;
+import org.nd4j.jita.conf.CudaEnvironment;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.jcublas.context.CudaContext;
+import org.nd4j.common.util.ArrayUtil;
+
+import org.bytedeco.cuda.cudart.*;
+import org.bytedeco.cuda.cudnn.*;
+
+import java.util.Collections;
+import java.util.Map;
+
+import static org.bytedeco.cuda.global.cudnn.*;
+
+/**
+ * CuDNN dropout helper
+ *
+ * The RNG state buffer, the mask buffer and the dropout descriptor are cached between calls and only
+ * re-created when a larger buffer is needed or the dropout probability changes.
+ *
+ * Note that for repeatability between calls (for example, for gradient checks), we need to do two things:
+ * (a) set the ND4J RNG seed
+ * (b) clear the rngStates field
+ *
+ * @author Alex Black
+ */
+@Data
+@Slf4j
+public class CudnnDropoutHelper extends BaseCudnnHelper implements DropoutHelper {
+
+ /**
+ * Holds the cuDNN descriptors used by this helper: four tensor descriptors and one dropout descriptor.
+ */
+ private static class CudnnDropoutContext extends CudnnContext {
+
+ /**
+ * Copy of the context that is registered as its deallocator; {@link #deallocate()} destroys the handles.
+ */
+ private static class Deallocator extends CudnnDropoutContext implements Pointer.Deallocator {
+ Deallocator(CudnnDropoutContext c) {
+ super(c);
+ }
+
+ @Override
+ public void deallocate() {
+ destroyHandles();
+ }
+ }
+
+ // These field initialisers run before the constructor body, so the structs exist when createHandles() is called
+ private cudnnTensorStruct xTensorDesc = new cudnnTensorStruct(); //Input
+ private cudnnTensorStruct dxTensorDesc = new cudnnTensorStruct(); //Grad at input
+ private cudnnTensorStruct yTensorDesc = new cudnnTensorStruct(); //Output
+ private cudnnTensorStruct dyTensorDesc = new cudnnTensorStruct(); //Grad at output
+ private cudnnDropoutStruct dropoutDesc = new cudnnDropoutStruct();
+
+ /** Creates the cuDNN handles and registers a deallocator that destroys them. */
+ public CudnnDropoutContext() {
+ createHandles();
+ deallocator(new Deallocator(this));
+ }
+
+ /**
+ * Copy constructor: wraps the descriptors of {@code c} in new struct objects.
+ * NOTE(review): presumably the new objects refer to the same native descriptors as {@code c} - confirm.
+ */
+ public CudnnDropoutContext(CudnnDropoutContext c) {
+ super(c);
+ xTensorDesc = new cudnnTensorStruct(c.xTensorDesc);
+ dxTensorDesc = new cudnnTensorStruct(c.dxTensorDesc);
+ yTensorDesc = new cudnnTensorStruct(c.yTensorDesc);
+ dyTensorDesc = new cudnnTensorStruct(c.dyTensorDesc);
+ dropoutDesc = new cudnnDropoutStruct(c.dropoutDesc);
+ }
+
+ /** Creates the superclass handles first, then the four tensor descriptors and the dropout descriptor. */
+ @Override
+ protected void createHandles() {
+ super.createHandles();
+ checkCudnn(cudnnCreateTensorDescriptor(xTensorDesc));
+ checkCudnn(cudnnCreateTensorDescriptor(dxTensorDesc));
+ checkCudnn(cudnnCreateTensorDescriptor(yTensorDesc));
+ checkCudnn(cudnnCreateTensorDescriptor(dyTensorDesc));
+ checkCudnn(cudnnCreateDropoutDescriptor(dropoutDesc));
+ }
+
+ /** Destroys the descriptors created by this class, then the superclass handles. */
+ @Override
+ protected void destroyHandles() {
+ checkCudnn(cudnnDestroyTensorDescriptor(xTensorDesc));
+ checkCudnn(cudnnDestroyTensorDescriptor(dxTensorDesc));
+ checkCudnn(cudnnDestroyTensorDescriptor(yTensorDesc));
+ checkCudnn(cudnnDestroyTensorDescriptor(dyTensorDesc));
+ checkCudnn(cudnnDestroyDropoutDescriptor(dropoutDesc));
+ super.destroyHandles();
+ }
+ }
+
+ private CudnnDropoutContext cudnnContext = new CudnnDropoutContext();
+ // false until cudnnSetDropoutDescriptor has been called for the current rngStates buffer
+ private boolean initializedDescriptor = false;
+ private DataCache rngStates; //"Pointer to user-allocated GPU memory that will hold random number generator states."
+ private DataCache mask; //Mask: persistence between forward and backward
+ // Single-element buffers that receive the sizes reported by cuDNN; created on first use
+ private SizeTPointer stateSizeBytesPtr;
+ private SizeTPointer reserveSizeBytesPtr;
+ // Drop probability the dropout descriptor was last initialized with
+ private float lastInitializedP;
+
+ /**
+ * @param dataType data type passed on to {@link BaseCudnnHelper}
+ */
+ public CudnnDropoutHelper(DataType dataType){
+ super(dataType);
+ }
+
+ /**
+ * Reports the memory used by this helper.
+ *
+ * @return always an empty map
+ */
+ //@Override
+ public Map helperMemoryUse() {
+ return Collections.emptyMap();
+ }
+
+ /**
+ * @return always true
+ */
+ @Override
+ public boolean checkSupported() {
+ return true;
+ }
+
+ /**
+ * Applies dropout to {@code input} using cudnnDropoutForward and writes the result to {@code resultArray}.
+ * The mask buffer filled here is the one later read by {@link #backprop(INDArray, INDArray)}.
+ *
+ * @param input array to apply dropout to
+ * @param resultArray array that receives the output
+ * @param dropoutInputRetainProb probability of retaining a value; cuDNN is given {@code 1 - dropoutInputRetainProb}
+ */
+ @Override
+ public void applyDropout(INDArray input, INDArray resultArray, double dropoutInputRetainProb) {
+ float p = (float)(1.0 - dropoutInputRetainProb); //CuDNN uses p = probability of setting to 0. We use p = probability of retaining
+
+ //TODO int cast
+ int[] inShape = adaptForTensorDescr(ArrayUtil.toInts(input.shape()));
+ int[] inStride = adaptForTensorDescr(ArrayUtil.toInts(input.stride()));
+ checkCudnn(cudnnSetTensorNdDescriptor(cudnnContext.xTensorDesc, dataType, inShape.length, inShape, inStride));
+
+ int[] outShape = adaptForTensorDescr(ArrayUtil.toInts(resultArray.shape()));
+ int[] outStride = adaptForTensorDescr(ArrayUtil.toInts(resultArray.stride()));
+ checkCudnn(cudnnSetTensorNdDescriptor(cudnnContext.yTensorDesc, dataType, outShape.length, outShape, outStride));
+
+
+ // Ask cuDNN how much memory it needs for the RNG states and for the mask (reserve space)
+ if(stateSizeBytesPtr == null){
+ stateSizeBytesPtr = new SizeTPointer(1);
+ reserveSizeBytesPtr = new SizeTPointer(1);
+ }
+ checkCudnn(cudnnDropoutGetStatesSize(cudnnContext, stateSizeBytesPtr));
+ long rngStateSizeBytes = stateSizeBytesPtr.get();
+ checkCudnn(cudnnDropoutGetReserveSpaceSize(cudnnContext.xTensorDesc, reserveSizeBytesPtr));
+ long maskReserveSizeBytes = reserveSizeBytesPtr.get();
+
+ // (Re)allocate the RNG state buffer if it is missing or too small
+ if(rngStates == null || rngStates.capacity() < rngStateSizeBytes){
+ if(log.isTraceEnabled()){
+ if(rngStates == null){
+ log.trace("CudnnDropoutHelper: Allocating intial RNG states workspace of size {} ({})", rngStateSizeBytes,
+ BinaryByteUnit.format(rngStateSizeBytes, "#.00"));
+ } else {
+ log.trace("CudnnDropoutHelper: Deallocating RNG states of size {} ({}), allocating new workspace of size {} ({})",
+ rngStates.capacity(), BinaryByteUnit.format(rngStates.capacity(), "#.00"),
+ rngStateSizeBytes, BinaryByteUnit.format(rngStateSizeBytes, "#.00"));
+ }
+ }
+
+ if(rngStates != null)
+ rngStates.deallocate();
+ //states = "Pointer to user-allocated GPU memory that will hold random number generator states."
+ rngStates = new DataCache(rngStateSizeBytes);
+ // A new state buffer forces the dropout descriptor to be set up again below
+ initializedDescriptor = false;
+ }
+ // (Re)allocate the mask buffer if it is missing or too small
+ if(mask == null || mask.capacity() < maskReserveSizeBytes){
+ if(log.isTraceEnabled()){
+ if(mask == null){
+ log.trace("CudnnDropoutHelper: Allocating intial mask array of size {} ({})", maskReserveSizeBytes,
+ BinaryByteUnit.format(maskReserveSizeBytes, "#.00"));
+ } else {
+ log.trace("CudnnDropoutHelper: Deallocating mask array of size {} ({}), allocating new mask array of size {} ({})",
+ mask.capacity(), BinaryByteUnit.format(mask.capacity(), "#.00"),
+ maskReserveSizeBytes, BinaryByteUnit.format(maskReserveSizeBytes, "#.00"));
+ }
+ }
+
+ if(mask != null)
+ mask.deallocate();
+ //mask = "Pointer to user-allocated GPU memory used by this function. It is expected
+ //that contents of reserveSpace does not change between cudnnDropoutForward and
+ //cudnnDropoutBackward calls."
+ mask = new DataCache(maskReserveSizeBytes);
+ }
+
+ //Dropout descriptor: (re)initialize if required
+ if(!initializedDescriptor || p != lastInitializedP) {
+ if(log.isTraceEnabled()){
+ log.trace("CudnnDropoutHelper: (re)initializing dropout descriptor");
+ }
+ //NOTE: cudnnSetDropoutDescriptor has some internal computation/initialization, and hence is expensive to
+ // call - so we want to call this as infrequently as possible, and cache the result
+ // The seed is drawn from the ND4J RNG
+ long seed = Nd4j.getRandom().nextLong();
+ lastInitializedP = p;
+ checkCudnn(cudnnSetDropoutDescriptor(cudnnContext.dropoutDesc, cudnnContext, p, rngStates, rngStates.capacity(), seed));
+ initializedDescriptor = true;
+ }
+
+ Allocator allocator = AtomicAllocator.getInstance();
+ CudaContext context = allocator.getFlowController().prepareAction(input, resultArray);
+ Pointer xPtr = allocator.getPointer(input, context);
+ Pointer yPtr = allocator.getPointer(resultArray, context);
+
+ checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getCublasStream())));
+ checkCudnn(cudnnDropoutForward(cudnnContext, cudnnContext.dropoutDesc, cudnnContext.xTensorDesc, xPtr,
+ cudnnContext.yTensorDesc, yPtr, mask, mask.capacity()));
+
+ allocator.registerAction(context, input, resultArray);
+ if (CudaEnvironment.getInstance().getConfiguration().isDebug())
+ context.syncOldStream();
+ }
+
+ /**
+ * Backpropagates through dropout using cudnnDropoutBackward and the mask buffer held by this helper.
+ * NOTE(review): {@code mask} is only assigned in {@code applyDropout}; calling this method first
+ * would fail with a NullPointerException - confirm callers always run the forward pass first.
+ * NOTE(review): unlike {@code applyDropout}, no {@code cudnnSetStream} call is made here - confirm
+ * this is intended.
+ *
+ * @param gradAtOutput gradient with respect to the dropout output
+ * @param gradAtInput array that receives the gradient with respect to the dropout input
+ */
+ @Override
+ public void backprop(INDArray gradAtOutput, INDArray gradAtInput) {
+ int[] gradAtOutShape = adaptForTensorDescr(ArrayUtil.toInts(gradAtOutput.shape()));
+ int[] gradAtOutStride = adaptForTensorDescr(ArrayUtil.toInts(gradAtOutput.stride()));
+ checkCudnn(cudnnSetTensorNdDescriptor(cudnnContext.dyTensorDesc, dataType, gradAtOutShape.length, gradAtOutShape, gradAtOutStride));
+
+ int[] gradAtInShape = adaptForTensorDescr(ArrayUtil.toInts(gradAtInput.shape()));
+ int[] gradAtInStride = adaptForTensorDescr(ArrayUtil.toInts(gradAtInput.stride()));
+ checkCudnn(cudnnSetTensorNdDescriptor(cudnnContext.dxTensorDesc, dataType, gradAtInShape.length, gradAtInShape, gradAtInStride));
+
+ Allocator allocator = AtomicAllocator.getInstance();
+ CudaContext context = allocator.getFlowController().prepareAction(gradAtOutput, gradAtInput);
+ Pointer dyPtr = allocator.getPointer(gradAtOutput, context);
+ Pointer dxPtr = allocator.getPointer(gradAtInput, context);
+
+ checkCudnn(cudnnDropoutBackward(cudnnContext, cudnnContext.dropoutDesc, cudnnContext.dyTensorDesc, dyPtr,
+ cudnnContext.dxTensorDesc, dxPtr, mask, mask.capacity()));
+
+ allocator.registerAction(context, gradAtOutput, gradAtInput);
+ if (CudaEnvironment.getInstance().getConfiguration().isDebug())
+ context.syncOldStream();
+ }
+}
diff --git a/cavis-dnn/cavis-dnn-cudnn/src/main/java/org/deeplearning4j/cuda/normalization/CudnnBatchNormalizationHelper.java b/cavis-dnn/cavis-dnn-cudnn/src/main/java/org/deeplearning4j/cuda/normalization/CudnnBatchNormalizationHelper.java
new file mode 100644
index 000000000..fea813aa0
--- /dev/null
+++ b/cavis-dnn/cavis-dnn-cudnn/src/main/java/org/deeplearning4j/cuda/normalization/CudnnBatchNormalizationHelper.java
@@ -0,0 +1,384 @@
+/*
+ * ******************************************************************************
+ * *
+ * *
+ * * This program and the accompanying materials are made available under the
+ * * terms of the Apache License, Version 2.0 which is available at
+ * * https://www.apache.org/licenses/LICENSE-2.0.
+ * *
+ * * See the NOTICE file distributed with this work for additional
+ * * information regarding copyright ownership.
+ * * Unless required by applicable law or agreed to in writing, software
+ * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * * License for the specific language governing permissions and limitations
+ * * under the License.
+ * *
+ * * SPDX-License-Identifier: Apache-2.0
+ * *****************************************************************************
+ */
+
+package org.deeplearning4j.cuda.normalization;
+
+import lombok.extern.slf4j.Slf4j;
+import lombok.val;
+import org.bytedeco.javacpp.Pointer;
+import org.deeplearning4j.nn.conf.CNN2DFormat;
+import org.deeplearning4j.nn.gradient.DefaultGradient;
+import org.deeplearning4j.nn.gradient.Gradient;
+import org.deeplearning4j.cuda.BaseCudnnHelper;
+import org.deeplearning4j.nn.layers.normalization.BatchNormalizationHelper;
+import org.deeplearning4j.nn.params.BatchNormalizationParamInitializer;
+import org.deeplearning4j.nn.workspace.ArrayType;
+import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
+import org.nd4j.jita.allocator.Allocator;
+import org.nd4j.jita.allocator.impl.AtomicAllocator;
+import org.nd4j.jita.conf.CudaEnvironment;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.memory.MemoryWorkspace;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.executioner.GridExecutioner;
+import org.nd4j.linalg.api.shape.Shape;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.jcublas.context.CudaContext;
+import org.nd4j.common.primitives.Pair;
+import org.nd4j.common.util.ArrayUtil;
+
+import java.util.HashMap;
+import java.util.Map;
+
+import org.bytedeco.cuda.cudart.*;
+import org.bytedeco.cuda.cudnn.*;
+
+import static org.bytedeco.cuda.global.cudnn.*;
+
+/**
+ * cuDNN-based helper for the batch normalization layer.
+ *
+ * @author saudet
+ */
+@Slf4j
+public class CudnnBatchNormalizationHelper extends BaseCudnnHelper implements BatchNormalizationHelper {
+
+ public CudnnBatchNormalizationHelper(DataType dataType) { // creates the helper; dataType is forwarded unchanged to BaseCudnnHelper
+ super(dataType);
+ }
+
+ private static class CudnnBatchNormalizationContext extends CudnnContext { // cuDNN context plus the four tensor descriptors used by this helper
+ // Deallocator: a copy of the context whose deallocate() releases the descriptors and the parent handles
+ private static class Deallocator extends CudnnBatchNormalizationContext implements Pointer.Deallocator {
+ Deallocator(CudnnBatchNormalizationContext c) {
+ super(c); // copy constructor: wraps c's descriptors instead of creating new ones
+ }
+ // Called through the Pointer.Deallocator interface
+ @Override
+ public void deallocate() {
+ destroyHandles();
+ }
+ }
+ // src = layer input, dst = output written by cuDNN, delta = incoming epsilon, gammaBeta = parameter tensors (as used by the helper methods)
+ private cudnnTensorStruct srcTensorDesc = new cudnnTensorStruct(), dstTensorDesc = new cudnnTensorStruct(),
+ deltaTensorDesc = new cudnnTensorStruct(), gammaBetaTensorDesc = new cudnnTensorStruct();
+ // Primary constructor: creates the native descriptors and registers a Deallocator for them
+ public CudnnBatchNormalizationContext() {
+ createHandles();
+ deallocator(new Deallocator(this));
+ }
+ // Copy constructor (used by Deallocator): no createHandles() call, no deallocator registration
+ public CudnnBatchNormalizationContext(CudnnBatchNormalizationContext c) {
+ super(c);
+ srcTensorDesc = new cudnnTensorStruct(c.srcTensorDesc); // presumably shares c's native address (JavaCPP pointer copy) - TODO confirm
+ dstTensorDesc = new cudnnTensorStruct(c.dstTensorDesc);
+ deltaTensorDesc = new cudnnTensorStruct(c.deltaTensorDesc);
+ gammaBetaTensorDesc = new cudnnTensorStruct(c.gammaBetaTensorDesc);
+ }
+ // Parent handles are created first, then the four tensor descriptors
+ @Override
+ protected void createHandles() {
+ super.createHandles();
+ checkCudnn(cudnnCreateTensorDescriptor(srcTensorDesc));
+ checkCudnn(cudnnCreateTensorDescriptor(dstTensorDesc));
+ checkCudnn(cudnnCreateTensorDescriptor(deltaTensorDesc));
+ checkCudnn(cudnnCreateTensorDescriptor(gammaBetaTensorDesc));
+ }
+ // Descriptors are destroyed before the parent handles (reverse of createHandles)
+ @Override
+ protected void destroyHandles() {
+ checkCudnn(cudnnDestroyTensorDescriptor(srcTensorDesc));
+ checkCudnn(cudnnDestroyTensorDescriptor(dstTensorDesc));
+ checkCudnn(cudnnDestroyTensorDescriptor(deltaTensorDesc));
+ checkCudnn(cudnnDestroyTensorDescriptor(gammaBetaTensorDesc));
+ super.destroyHandles();
+ }
+ }
+
+ protected final int batchNormMode = CUDNN_BATCHNORM_SPATIAL; // would need to increase rank of gamma and beta for CUDNN_BATCHNORM_PER_ACTIVATION
+ // One context per helper instance; its descriptors are re-configured on every preOutput/backpropGradient call
+ private CudnnBatchNormalizationContext cudnnContext = new CudnnBatchNormalizationContext();
+ private INDArray meanCache; // allocated lazily by preOutput(training=true); read by backpropGradient and getMeanCache
+ private INDArray varCache; // allocated lazily by preOutput(training=true); getVarCache maps each value v to 1/(v*v) - eps
+ private double eps; // epsilon from the most recent preOutput/backpropGradient call; used by getVarCache
+
+ public boolean checkSupported(double eps, boolean isFixedGammaBeta) { // true if the no-arg check passes and eps is not below cuDNN's minimum; isFixedGammaBeta is not consulted
+ final boolean baseOk = checkSupported(); // no-arg overload, defined outside this class body
+ final boolean epsOk = !(eps < CUDNN_BN_MIN_EPSILON); // negated '<' keeps the outcome for a NaN eps unchanged
+ if (!epsOk) {
+ log.warn("Not supported: eps < CUDNN_BN_MIN_EPSILON (" + eps + " < " + CUDNN_BN_MIN_EPSILON + ")");
+ }
+ return baseOk && epsOk;
+ }
+
+ @Override
+ public Pair<Gradient, INDArray> backpropGradient(INDArray input, INDArray epsilon, long[] shape, INDArray gamma, INDArray beta,
+ INDArray dGammaView, INDArray dBetaView, double eps, CNN2DFormat format, LayerWorkspaceMgr layerWorkspaceMgr) {
+ // Runs cudnnBatchNormalizationBackward; returns (gradient holder with GAMMA/BETA entries, gradient w.r.t. the input)
+ boolean nchw = format == CNN2DFormat.NCHW;
+ // Remembered so getVarCache() can subtract it again
+ this.eps = eps;
+ // Any format other than NCHW is treated as NHWC: pick the axis positions accordingly
+ int cudnnTensorFormat = nchw ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC;
+ int chIdx = nchw ? 1 : 3;
+ int hIdx = nchw ? 2 : 1;
+ int wIdx = nchw ? 3 : 2;
+ // Dimensions are narrowed to int because the cuDNN descriptor calls below take ints
+ val miniBatch = (int) input.size(0);
+ val depth = (int) input.size(chIdx);
+ val inH = (int) input.size(hIdx);
+ val inW = (int) input.size(wIdx);
+ // For FP16 input the parameter/gradient arrays are swapped for FP32 copies; the originals are kept for the write-back at the end
+ final boolean isHalf = (input.dataType() == DataType.HALF);
+ INDArray gammaOrig = null;
+ INDArray dGammaViewOrig = null;
+ INDArray dBetaViewOrig = null;
+ if(isHalf) { //Convert FP16 to FP32 if required (CuDNN BN doesn't support FP16 for these params, only for input/output)
+ gammaOrig = gamma;
+ dGammaViewOrig = dGammaView;
+ dBetaViewOrig = dBetaView;
+ /*
+ From CuDNN docs: bnScale, resultBnScaleDiff, resultBnBiasDiff, savedMean, savedInvVariance
+ "Note: The data type of this tensor descriptor must be 'float' for FP16 and FP32 input tensors, and 'double'
+ for FP64 input tensors."
+ >> Last 2 are the meanCache and varCache; first 3 are below
+ */
+ gamma = gamma.castTo(DataType.FLOAT);
+ dGammaView = dGammaView.castTo(DataType.FLOAT);
+ dBetaView = dBetaView.castTo(DataType.FLOAT);
+ }
+ // Holder for the parameter gradients; filled after the cuDNN call
+ Gradient retGradient = new DefaultGradient();
+ // epsilon must have default strides; input is passed with whatever strides it has
+ if (!Shape.hasDefaultStridesForShape(epsilon)) {
+ // apparently not supported by cuDNN
+ epsilon = epsilon.dup('c');
+ }
+ // Strides are read after the possible dup() so they describe the array actually handed to cuDNN
+ val srcStride = ArrayUtil.toInts(input.stride());
+ val deltaStride = ArrayUtil.toInts(epsilon.stride());
+ // Flush any ops still queued in a grid executioner before touching the buffers directly
+ if (Nd4j.getExecutioner() instanceof GridExecutioner)
+ ((GridExecutioner) Nd4j.getExecutioner()).flushQueue();
+ // Descriptors always list dimensions as N,C,H,W; the layout is expressed through the strides picked via chIdx/hIdx/wIdx
+ checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.srcTensorDesc, dataType, miniBatch, depth, inH, inW,
+ srcStride[0], srcStride[chIdx], srcStride[hIdx], srcStride[wIdx]));
+ checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.deltaTensorDesc, dataType, miniBatch, depth, inH, inW,
+ deltaStride[0], deltaStride[chIdx], deltaStride[hIdx], deltaStride[wIdx]));
+ // Output gradient is allocated uninitialised in the ACTIVATION_GRAD workspace, in the same layout as the input
+ long[] nextEpsShape = nchw ? new long[] {miniBatch, depth, inH, inW} : new long[] {miniBatch, inH, inW, depth};
+ INDArray nextEpsilon = layerWorkspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, input.dataType(), nextEpsShape, 'c');
+ val dstStride = ArrayUtil.toInts(nextEpsilon.stride());
+ // The gamma/beta descriptor takes its data type from gamma (FP32 after the FP16 conversion) and its dims from 'shape'
+ checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.dstTensorDesc, dataType, miniBatch, depth, inH, inW,
+ dstStride[0], dstStride[chIdx], dstStride[hIdx], dstStride[wIdx]));
+ checkCudnn(cudnnSetTensor4dDescriptor(cudnnContext.gammaBetaTensorDesc, cudnnTensorFormat, toCudnnDataType(gamma.data().dataType()), (int)shape[0],
+ (int)shape[1], shape.length > 2 ? (int)shape[2] : 1, shape.length > 3 ? (int)shape[3] : 1));
+ // Register all arrays with the flow controller, then resolve their pointers for this context
+ Allocator allocator = AtomicAllocator.getInstance();
+ CudaContext context = allocator.getFlowController().prepareActionAllWrite(input, epsilon, nextEpsilon, gamma,
+ dGammaView, dBetaView);
+ Pointer srcData = allocator.getPointer(input, context);
+ Pointer epsData = allocator.getPointer(epsilon, context);
+ Pointer dstData = allocator.getPointer(nextEpsilon, context);
+ Pointer gammaData = allocator.getPointer(gamma, context);
+ Pointer dGammaData = allocator.getPointer(dGammaView, context);
+ Pointer dBetaData = allocator.getPointer(dBetaView, context);
+ Pointer meanCacheData = allocator.getPointer(meanCache, context); // NOTE(review): caches are only allocated by preOutput(training=true) and are not null-checked here
+ Pointer varCacheData = allocator.getPointer(varCache, context);
+ // this.beta is the helper's field, not the 'beta' parameter (which is not used in this method)
+ checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getCublasStream())));
+ checkCudnn(cudnnBatchNormalizationBackward(cudnnContext, batchNormMode, alpha, this.beta, alpha, alpha,
+ cudnnContext.srcTensorDesc, srcData, cudnnContext.deltaTensorDesc, epsData,
+ cudnnContext.dstTensorDesc, dstData, cudnnContext.gammaBetaTensorDesc, gammaData, dGammaData,
+ dBetaData, eps, meanCacheData, varCacheData));
+ // Report the completed action for the same set of arrays that was prepared above
+ allocator.getFlowController().registerActionAllWrite(context, input, epsilon, nextEpsilon, gamma, dGammaView,
+ dBetaView);
+ // NOTE(review): in the FP16 case these entries are the FP32 copies, not the caller's views (those are filled by assign() below)
+ retGradient.setGradientFor(BatchNormalizationParamInitializer.GAMMA, dGammaView);
+ retGradient.setGradientFor(BatchNormalizationParamInitializer.BETA, dBetaView);
+ // Always synchronise before the results are read back
+ context.syncOldStream();
+
+ //Convert back and assign, if required:
+ if(isHalf){
+ gammaOrig.assign(gamma.castTo(DataType.HALF));
+ dGammaViewOrig.assign(dGammaView.castTo(DataType.HALF));
+ dBetaViewOrig.assign(dBetaView.castTo(DataType.HALF));
+ }
+
+ return new Pair<>(retGradient, nextEpsilon);
+ }
+
+
+ @Override
+ public INDArray preOutput(INDArray x, boolean training, long[] shape, INDArray gamma, INDArray beta, INDArray mean,
+ INDArray var, double decay, double eps, CNN2DFormat format, LayerWorkspaceMgr workspaceMgr) { // cuDNN batch-norm forward; returns a new activations array
+ boolean nchw = format == CNN2DFormat.NCHW; // any other format is treated as NHWC
+ int cudnnTensorFormat = nchw ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC;
+ int chIdx = nchw ? 1 : 3;
+ int hIdx = nchw ? 2 : 1;
+ int wIdx = nchw ? 3 : 2;
+ // eps is remembered for getVarCache(); for FP16 input the parameter arrays are replaced by FP32 copies
+ this.eps = eps;
+ final boolean isHalf = (x.dataType() == DataType.FLOAT16);
+ INDArray origGamma = gamma;
+ INDArray origBeta = beta;
+ INDArray origMean = mean;
+ INDArray origVar = var;
+ if(isHalf) {
+ gamma = gamma.castTo(DataType.FLOAT);
+ beta = beta.castTo(DataType.FLOAT);
+ mean = mean.castTo(DataType.FLOAT);
+ var = var.castTo(DataType.FLOAT);
+ }
+
+ //Notation difference between CuDNN and our implementation:
+ //Us: runningMean = (1-decay) * batchMean + decay * runningMean
+ //CuDNN: runningMean = decay * batchMean + (1-decay) * runningMean
+ //i.e., "decay" has a different meaning...
+ //Disable in-place updating of running mean/variance, so that all parameter changes are done via the update/gradient
+ // vector. This is necessary for BatchNormalization to be safe to use in distributed gradient sharing settings
+ decay = 0.0; //From cudnn docs: runningMean = newMean*factor + runningMean*(1-factor). -> 0 = "in-place modification of running mean disabled"
+ // Dimensions are narrowed to int for the cuDNN descriptor calls
+ val miniBatch = (int) x.size(0);
+ val inDepth = (int) x.size(chIdx);
+ val inH = (int) x.size(hIdx);
+ val inW = (int) x.size(wIdx);
+ // Input descriptor: dims listed as N,C,H,W, layout expressed through x's own strides
+ val srcStride = ArrayUtil.toInts(x.stride());
+ checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.srcTensorDesc, dataType, miniBatch, inDepth, inH, inW,
+ srcStride[0], srcStride[chIdx], srcStride[hIdx], srcStride[wIdx]));
+ // Output is allocated uninitialised in the ACTIVATIONS workspace with the same layout and data type as x
+ long[] actShape = nchw ? new long[] {miniBatch, inDepth, inH, inW} : new long[] {miniBatch, inH, inW, inDepth};
+ INDArray activations = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, x.dataType(), actShape, 'c');
+
+ val dstStride = ArrayUtil.toInts(activations.stride());
+ checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.dstTensorDesc, dataType, miniBatch, inDepth, inH, inW,
+ dstStride[0], dstStride[chIdx], dstStride[hIdx], dstStride[wIdx]));
+ // Parameter descriptor: data type comes from 'mean' (FP32 after the FP16 conversion), dims from 'shape'
+ checkCudnn(cudnnSetTensor4dDescriptor(cudnnContext.gammaBetaTensorDesc, cudnnTensorFormat, toCudnnDataType(mean.data().dataType()), (int)shape[0],
+ (int)shape[1], shape.length > 2 ? (int)shape[2] : 1, shape.length > 3 ? (int)shape[3] : 1));
+ // Register the arrays with the flow controller, then resolve their pointers
+ Allocator allocator = AtomicAllocator.getInstance();
+ CudaContext context =
+ allocator.getFlowController().prepareActionAllWrite(x, activations, gamma, beta, mean, var);
+ Pointer srcData = allocator.getPointer(x, context);
+ Pointer dstData = allocator.getPointer(activations, context);
+ Pointer gammaData = allocator.getPointer(gamma, context);
+ Pointer betaData = allocator.getPointer(beta, context);
+ Pointer meanData = allocator.getPointer(mean, context);
+ Pointer varData = allocator.getPointer(var, context);
+ // Flush ops still queued in a grid executioner before cuDNN touches the buffers
+ if (Nd4j.getExecutioner() instanceof GridExecutioner)
+ ((GridExecutioner) Nd4j.getExecutioner()).flushQueue();
+
+ checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getCublasStream())));
+ if (training) {
+ if(meanCache == null || meanCache.length() < mean.length()){ // (re)allocate outside any workspace so the cache outlives this call
+ try(MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
+ meanCache = Nd4j.createUninitialized(x.dataType(), mean.length());
+ }
+ if(x.dataType() == DataType.HALF){
+ try(MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
+ meanCache = meanCache.castTo(DataType.FLOAT);
+ }
+ }
+ }
+ if(varCache == null || varCache.length() < mean.length()){
+ try(MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
+ varCache = Nd4j.createUninitialized(x.dataType(), mean.length());
+ }
+ if(x.dataType() == DataType.HALF){ // same guard as meanCache: the array was created with x's type, so that is the type that decides whether FP32 promotion is needed
+ try(MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
+ varCache = varCache.castTo(DataType.FLOAT);
+ }
+ }
+ }
+ Pointer meanCacheData = allocator.getPointer(meanCache, context);
+ Pointer varCacheData = allocator.getPointer(varCache, context);
+ // Training: cuDNN additionally fills meanCache/varCache; decay is 0.0 here (see above)
+ checkCudnn(cudnnBatchNormalizationForwardTraining(cudnnContext, batchNormMode, this.alpha, this.beta,
+ cudnnContext.srcTensorDesc, srcData, cudnnContext.dstTensorDesc, dstData,
+ cudnnContext.gammaBetaTensorDesc, gammaData, betaData, decay, meanData, varData, eps,
+ meanCacheData, varCacheData));
+ } else {
+ checkCudnn(cudnnBatchNormalizationForwardInference(cudnnContext, batchNormMode, this.alpha, this.beta,
+ cudnnContext.srcTensorDesc, srcData, cudnnContext.dstTensorDesc, dstData,
+ cudnnContext.gammaBetaTensorDesc, gammaData, betaData, meanData, varData, eps));
+ }
+
+ allocator.getFlowController().registerActionAllWrite(context, x, activations, gamma, beta, mean, var);
+
+ if (CudaEnvironment.getInstance().getConfiguration().isDebug())
+ context.syncOldStream();
+ // NOTE(review): the debug-only sync above is followed by an unconditional one, so debug mode synchronises twice
+ context.syncOldStream();
+ if(training) { // the caches were not part of prepareActionAllWrite, so their device-side write is recorded by hand
+ AtomicAllocator.getInstance().getAllocationPoint(meanCache).tickDeviceWrite();
+ AtomicAllocator.getInstance().getAllocationPoint(varCache).tickDeviceWrite();
+ }
+
+ if(training && isHalf){
+ //Update the running mean and variance arrays; also gamma/beta
+ origMean.assign(mean.castTo(DataType.HALF));
+ origVar.assign(var.castTo(DataType.HALF));
+ origGamma.assign(gamma.castTo(DataType.HALF));
+ origBeta.assign(beta.castTo(DataType.HALF));
+ }
+
+ return activations;
+ }
+
+ @Override
+ public INDArray getMeanCache(DataType dataType) {
+ // preOutput stores the cache as FP32 when the input is FP16, so a HALF request
+ // receives a converted array; every other request receives the cache field itself.
+ final boolean wantsHalf = (dataType == DataType.HALF);
+ final INDArray result = wantsHalf ? meanCache.castTo(DataType.HALF) : meanCache;
+ return result;
+ }
+
+ @Override
+ public INDArray getVarCache(DataType dataType) {
+ // Maps every cached element v to 1/(v*v) - eps, using the eps remembered from the last forward/backward call.
+ // For a HALF request the arithmetic is carried out on a HALF copy of the cache.
+ final boolean wantsHalf = (dataType == DataType.HALF);
+ final INDArray source = wantsHalf ? varCache.castTo(DataType.HALF) : varCache;
+ INDArray variance = source.mul(source);
+ variance = variance.rdivi(1.0);
+ variance = variance.subi(eps);
+ if (!wantsHalf) {
+ return variance;
+ }
+ // HALF path: the result goes through castTo(HALF) once more before it is returned
+ return variance.castTo(DataType.HALF);
+ }
+
+
+ @Override
+ public Map<String, Long> helperMemoryUse() { // element count times element size for each persistent cache; 0 while a cache is still unallocated
+ Map<String, Long> memUse = new HashMap<>();
+ memUse.put("meanCache", meanCache == null ? 0 : meanCache.length() * meanCache.data().getElementSize());
+ memUse.put("varCache", varCache == null ? 0 : varCache.length() * varCache.data().getElementSize());
+ return memUse;
+ }
+}
diff --git a/cavis-dnn/cavis-dnn-cudnn/src/main/java/org/deeplearning4j/cuda/normalization/CudnnLocalResponseNormalizationHelper.java b/cavis-dnn/cavis-dnn-cudnn/src/main/java/org/deeplearning4j/cuda/normalization/CudnnLocalResponseNormalizationHelper.java
new file mode 100644
index 000000000..e0257a3ec
--- /dev/null
+++ b/cavis-dnn/cavis-dnn-cudnn/src/main/java/org/deeplearning4j/cuda/normalization/CudnnLocalResponseNormalizationHelper.java
@@ -0,0 +1,240 @@
+/*
+ * ******************************************************************************
+ * *
+ * *
+ * * This program and the accompanying materials are made available under the
+ * * terms of the Apache License, Version 2.0 which is available at
+ * * https://www.apache.org/licenses/LICENSE-2.0.
+ * *
+ * * See the NOTICE file distributed with this work for additional
+ * * information regarding copyright ownership.
+ * * Unless required by applicable law or agreed to in writing, software
+ * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * * License for the specific language governing permissions and limitations
+ * * under the License.
+ * *
+ * * SPDX-License-Identifier: Apache-2.0
+ * *****************************************************************************
+ */
+
+package org.deeplearning4j.cuda.normalization;
+
+import lombok.extern.slf4j.Slf4j;
+import lombok.val;
+import org.bytedeco.javacpp.Pointer;
+import org.deeplearning4j.nn.gradient.DefaultGradient;
+import org.deeplearning4j.nn.gradient.Gradient;
+import org.deeplearning4j.cuda.BaseCudnnHelper;
+import org.deeplearning4j.nn.layers.normalization.LocalResponseNormalizationHelper;
+import org.nd4j.jita.allocator.Allocator;
+import org.nd4j.jita.allocator.impl.AtomicAllocator;
+import org.nd4j.jita.conf.CudaEnvironment;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.executioner.GridExecutioner;
+import org.nd4j.linalg.api.shape.Shape;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.jcublas.context.CudaContext;
+import org.nd4j.common.primitives.Pair;
+import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
+import org.deeplearning4j.nn.workspace.ArrayType;
+import org.nd4j.common.util.ArrayUtil;
+
+import java.util.Collections;
+import java.util.Map;
+
+import org.bytedeco.cuda.cudart.*;
+import org.bytedeco.cuda.cudnn.*;
+
+import static org.bytedeco.cuda.global.cudnn.*;
+
+/**
+ * cuDNN-based helper for the local response normalization layer.
+ *
+ * @author saudet
+ */
+@Slf4j
+public class CudnnLocalResponseNormalizationHelper extends BaseCudnnHelper implements LocalResponseNormalizationHelper {
+
+ public CudnnLocalResponseNormalizationHelper(DataType dataType) { // creates the helper; dataType is forwarded unchanged to BaseCudnnHelper
+ super(dataType);
+ }
+
+ private static class CudnnLocalResponseNormalizationContext extends CudnnContext { // cuDNN context plus three tensor descriptors and the LRN descriptor
+ // Deallocator: a copy of the context whose deallocate() releases the descriptors and the parent handles
+ private static class Deallocator extends CudnnLocalResponseNormalizationContext implements Pointer.Deallocator {
+ Deallocator(CudnnLocalResponseNormalizationContext c) {
+ super(c); // copy constructor: wraps c's descriptors instead of creating new ones
+ }
+ // Called through the Pointer.Deallocator interface
+ @Override
+ public void deallocate() {
+ destroyHandles();
+ }
+ }
+ // src = layer input, dst = array written by cuDNN, delta = incoming epsilon (as used by the helper methods)
+ private cudnnTensorStruct srcTensorDesc = new cudnnTensorStruct(), dstTensorDesc = new cudnnTensorStruct(),
+ deltaTensorDesc = new cudnnTensorStruct();
+ private cudnnLRNStruct lrnDesc = new cudnnLRNStruct(); // configured with (n, alpha, beta, k) on every activate/backpropGradient call
+ // Primary constructor: creates the native descriptors and registers a Deallocator for them
+ public CudnnLocalResponseNormalizationContext() {
+ createHandles();
+ deallocator(new Deallocator(this));
+ }
+ // Copy constructor (used by Deallocator): no createHandles() call, no deallocator registration
+ public CudnnLocalResponseNormalizationContext(CudnnLocalResponseNormalizationContext c) {
+ super(c);
+ srcTensorDesc = new cudnnTensorStruct(c.srcTensorDesc); // presumably shares c's native address (JavaCPP pointer copy) - TODO confirm
+ dstTensorDesc = new cudnnTensorStruct(c.dstTensorDesc);
+ deltaTensorDesc = new cudnnTensorStruct(c.deltaTensorDesc);
+ lrnDesc = new cudnnLRNStruct(c.lrnDesc);
+ }
+ // Parent handles are created first, then the tensor descriptors, then the LRN descriptor
+ @Override
+ protected void createHandles() {
+ super.createHandles();
+ checkCudnn(cudnnCreateTensorDescriptor(srcTensorDesc));
+ checkCudnn(cudnnCreateTensorDescriptor(dstTensorDesc));
+ checkCudnn(cudnnCreateTensorDescriptor(deltaTensorDesc));
+ checkCudnn(cudnnCreateLRNDescriptor(lrnDesc));
+ }
+ // LRN descriptor is destroyed first, then the tensor descriptors, then the parent handles
+ @Override
+ protected void destroyHandles() {
+ checkCudnn(cudnnDestroyLRNDescriptor(lrnDesc));
+ checkCudnn(cudnnDestroyTensorDescriptor(srcTensorDesc));
+ checkCudnn(cudnnDestroyTensorDescriptor(dstTensorDesc));
+ checkCudnn(cudnnDestroyTensorDescriptor(deltaTensorDesc));
+ super.destroyHandles();
+ }
+ }
+
+ private CudnnLocalResponseNormalizationContext cudnnContext = new CudnnLocalResponseNormalizationContext(); // one context per helper; descriptors are re-configured on every call
+ private INDArray activations = null; // output of the most recent activate() call; backpropGradient() passes it to cuDNN without a null check
+
+ public boolean checkSupported(double k, double n, double alpha, double beta) { // true only if the no-arg check and every limit below pass; alpha is not range-checked
+ final boolean baseOk = checkSupported(); // evaluated first, before any warning is logged
+ final boolean nBelowMin = n < CUDNN_LRN_MIN_N;
+ final boolean nAboveMax = n > CUDNN_LRN_MAX_N;
+ final boolean kBelowMin = k < CUDNN_LRN_MIN_K;
+ final boolean betaBelowMin = beta < CUDNN_LRN_MIN_BETA;
+ if (nBelowMin) {
+ log.warn("Not supported: n < CUDNN_LRN_MIN_N (" + n + " < " + CUDNN_LRN_MIN_N + ")");
+ }
+ if (nAboveMax) {
+ log.warn("Not supported: n > CUDNN_LRN_MAX_N (" + n + " > " + CUDNN_LRN_MAX_N + ")");
+ }
+ if (kBelowMin) {
+ log.warn("Not supported: k < CUDNN_LRN_MIN_K (" + k + " < " + CUDNN_LRN_MIN_K + ")");
+ }
+ if (betaBelowMin) {
+ log.warn("Not supported: beta < CUDNN_LRN_MIN_BETA (" + beta + " < " + CUDNN_LRN_MIN_BETA + ")");
+ }
+ return baseOk && !(nBelowMin || nAboveMax || kBelowMin || betaBelowMin); // every violated limit has been logged, not just the first
+ }
+
+ @Override
+ public Pair<Gradient, INDArray> backpropGradient(INDArray input, INDArray epsilon, double k, double n, double alpha,
+ double beta, LayerWorkspaceMgr workspaceMgr) { // runs cudnnLRNCrossChannelBackward; returns (empty gradient holder, gradient w.r.t. the input)
+ val miniBatch = (int) input.size(0); // dimensions are read in fixed N,C,H,W order
+ val depth = (int) input.size(1);
+ val inH = (int) input.size(2);
+ val inW = (int) input.size(3);
+ // No entries are ever added to this holder in this method
+ Gradient retGradient = new DefaultGradient();
+ // epsilon must have default strides; input is passed with whatever strides it has
+ if (!Shape.hasDefaultStridesForShape(epsilon)) {
+ // apparently not supported by cuDNN
+ epsilon = epsilon.dup('c');
+ }
+ // Strides are read after the possible dup() so they describe the array actually handed to cuDNN
+ val srcStride = ArrayUtil.toInts(input.stride());
+ val deltaStride = ArrayUtil.toInts(epsilon.stride());
+ // Flush ops still queued in a grid executioner before cuDNN touches the buffers
+ if (Nd4j.getExecutioner() instanceof GridExecutioner)
+ ((GridExecutioner) Nd4j.getExecutioner()).flushQueue();
+ // Configure the input, epsilon and LRN descriptors; the LRN hyper-parameters n/alpha/beta/k go into the LRN descriptor
+ checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.srcTensorDesc, dataType, miniBatch, depth, inH, inW,
+ srcStride[0], srcStride[1], srcStride[2], srcStride[3]));
+ checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.deltaTensorDesc, dataType, miniBatch, depth, inH, inW,
+ deltaStride[0], deltaStride[1], deltaStride[2], deltaStride[3]));
+ checkCudnn(cudnnSetLRNDescriptor(cudnnContext.lrnDesc, (int) n, alpha, beta, k));
+ // Result array: uninitialised, ACTIVATION_GRAD workspace, same shape and data type as the input
+ INDArray nextEpsilon = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, input.dataType(), new long[] {miniBatch, depth, inH, inW}, 'c');
+
+ val dstStride = ArrayUtil.toInts(nextEpsilon.stride());
+ checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.dstTensorDesc, dataType, miniBatch, depth, inH, inW,
+ dstStride[0], dstStride[1], dstStride[2], dstStride[3]));
+ // 'activations' is the field written by activate(); it is used here without a null check
+ Allocator allocator = AtomicAllocator.getInstance();
+ CudaContext context =
+ allocator.getFlowController().prepareActionAllWrite(input, epsilon, activations, nextEpsilon);
+ Pointer srcData = allocator.getPointer(input, context);
+ Pointer epsData = allocator.getPointer(epsilon, context);
+ Pointer zData = allocator.getPointer(activations, context);
+ Pointer dstData = allocator.getPointer(nextEpsilon, context);
+ // this.alpha/this.beta are the helper's fields, not the LRN alpha/beta parameters; the activations reuse the epsilon descriptor
+ checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getCublasStream())));
+ checkCudnn(cudnnLRNCrossChannelBackward(cudnnContext, cudnnContext.lrnDesc, CUDNN_LRN_CROSS_CHANNEL_DIM1,
+ this.alpha, cudnnContext.deltaTensorDesc, zData, cudnnContext.deltaTensorDesc, epsData,
+ cudnnContext.srcTensorDesc, srcData, this.beta, cudnnContext.dstTensorDesc, dstData));
+ // Report the completed action for the same set of arrays that was prepared above
+ allocator.getFlowController().registerActionAllWrite(context, input, epsilon, activations, nextEpsilon);
+ // Stream is synchronised only when the debug flag is set
+ if (CudaEnvironment.getInstance().getConfiguration().isDebug())
+ context.syncOldStream();
+
+ return new Pair<>(retGradient, nextEpsilon);
+ }
+
+
+ @Override
+ public INDArray activate(INDArray input, boolean training, double k, double n, double alpha, double beta, LayerWorkspaceMgr workspaceMgr) { // cuDNN LRN forward; 'training' is not consulted
+ val miniBatch = (int) input.size(0); // dimensions are read in fixed N,C,H,W order
+ val inDepth = (int) input.size(1);
+ val inH = (int) input.size(2);
+ val inW = (int) input.size(3);
+ // Input is copied to 'c' order when its strides are not the default ones for its shape
+ if(!Shape.hasDefaultStridesForShape(input)){
+ input = input.dup('c');
+ }
+ // Strides are read after the possible dup()
+ val srcStride = ArrayUtil.toInts(input.stride());
+ checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.srcTensorDesc, dataType, miniBatch, inDepth, inH, inW,
+ srcStride[0], srcStride[1], srcStride[2], srcStride[3]));
+ // The result is stored in the 'activations' field (read later by backpropGradient) as well as returned
+ activations = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, input.dataType(), new long[] {miniBatch, inDepth, inH, inW}, 'c');
+
+ val dstStride = ArrayUtil.toInts(activations.stride());
+ checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.dstTensorDesc, dataType, miniBatch, inDepth, inH, inW,
+ dstStride[0], dstStride[1], dstStride[2], dstStride[3]));
+ checkCudnn(cudnnSetLRNDescriptor(cudnnContext.lrnDesc, (int) n, alpha, beta, k)); // n is truncated to int; alpha/beta/k are the LRN hyper-parameters
+ // Register both arrays with the flow controller, then resolve their pointers
+ Allocator allocator = AtomicAllocator.getInstance();
+ CudaContext context = allocator.getFlowController().prepareActionAllWrite(input, activations);
+ Pointer srcData = allocator.getPointer(input, context);
+ Pointer dstData = allocator.getPointer(activations, context);
+ // Flush ops still queued in a grid executioner before cuDNN touches the buffers
+ if (Nd4j.getExecutioner() instanceof GridExecutioner)
+ ((GridExecutioner) Nd4j.getExecutioner()).flushQueue();
+ // this.alpha/this.beta are the helper's fields, distinct from the alpha/beta parameters used for the LRN descriptor
+ checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getCublasStream())));
+ checkCudnn(cudnnLRNCrossChannelForward(cudnnContext, cudnnContext.lrnDesc, CUDNN_LRN_CROSS_CHANNEL_DIM1,
+ this.alpha, cudnnContext.srcTensorDesc, srcData, this.beta, cudnnContext.dstTensorDesc,
+ dstData));
+ // Report the completed action for the same arrays that were prepared above
+ allocator.getFlowController().registerActionAllWrite(context, input, activations);
+ // Stream is synchronised only when the debug flag is set
+ if (CudaEnvironment.getInstance().getConfiguration().isDebug())
+ context.syncOldStream();
+
+ return activations;
+ }
+
+ @Override
+ public Map<String, Long> helperMemoryUse() { // always reports an empty map
+ //No persistent memory use other than the structs (which are small)
+ return Collections.emptyMap();
+ }
+}
diff --git a/cavis-dnn/cavis-dnn-cudnn/src/main/java/org/deeplearning4j/cuda/recurrent/CudnnLSTMHelper.java b/cavis-dnn/cavis-dnn-cudnn/src/main/java/org/deeplearning4j/cuda/recurrent/CudnnLSTMHelper.java
new file mode 100644
index 000000000..120078d07
--- /dev/null
+++ b/cavis-dnn/cavis-dnn-cudnn/src/main/java/org/deeplearning4j/cuda/recurrent/CudnnLSTMHelper.java
@@ -0,0 +1,659 @@
+/*
+ * ******************************************************************************
+ * *
+ * *
+ * * This program and the accompanying materials are made available under the
+ * * terms of the Apache License, Version 2.0 which is available at
+ * * https://www.apache.org/licenses/LICENSE-2.0.
+ * *
+ * * See the NOTICE file distributed with this work for additional
+ * * information regarding copyright ownership.
+ * * Unless required by applicable law or agreed to in writing, software
+ * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * * License for the specific language governing permissions and limitations
+ * * under the License.
+ * *
+ * * SPDX-License-Identifier: Apache-2.0
+ * *****************************************************************************
+ */
+
+package org.deeplearning4j.cuda.recurrent;
+
+import lombok.extern.slf4j.Slf4j;
+import lombok.val;
+import com.jakewharton.byteunits.BinaryByteUnit;
+import org.bytedeco.javacpp.Pointer;
+import org.deeplearning4j.nn.api.Layer;
+import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
+import org.deeplearning4j.nn.gradient.DefaultGradient;
+import org.deeplearning4j.nn.gradient.Gradient;
+import org.deeplearning4j.cuda.BaseCudnnHelper;
+import org.deeplearning4j.nn.layers.recurrent.FwdPassReturn;
+import org.deeplearning4j.nn.layers.recurrent.LSTMHelper;
+import org.nd4j.jita.allocator.Allocator;
+import org.nd4j.jita.allocator.impl.AtomicAllocator;
+import org.nd4j.linalg.activations.IActivation;
+import org.nd4j.linalg.activations.impl.ActivationSigmoid;
+import org.nd4j.linalg.activations.impl.ActivationTanH;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.shape.Shape;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.jcublas.context.CudaContext;
+import org.nd4j.common.primitives.Pair;
+import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
+import org.deeplearning4j.nn.workspace.ArrayType;
+
+import java.util.HashMap;
+import java.util.Map;
+
+import org.bytedeco.cuda.cudart.*;
+import org.bytedeco.cuda.cudnn.*;
+import static org.bytedeco.cuda.global.cudart.*;
+import static org.bytedeco.cuda.global.cudnn.*;
+
+/**
+ * cuDNN-based helper for the recurrent LSTM layer (no peephole connections).
+ *
+ * @author saudet
+ */
+@Slf4j
+public class CudnnLSTMHelper extends BaseCudnnHelper implements LSTMHelper {
+
+ public CudnnLSTMHelper(DataType dataType) {
+ super(dataType); // the data type is handed unchanged to the BaseCudnnHelper constructor
+ }
+
+ private static class CudnnLSTMContext extends CudnnContext {
+ // Bundles every cuDNN descriptor this helper uses, in addition to whatever the parent CudnnContext creates.
+ private static class Deallocator extends CudnnLSTMContext implements Pointer.Deallocator {
+ Deallocator(CudnnLSTMContext c) {
+ super(c);
+ }
+ // Destroys the descriptors; an instance is registered via deallocator(...) in the no-arg constructor below.
+ @Override
+ public void deallocate() {
+ destroyHandles();
+ }
+ }
+ // Tensor descriptors for hidden (h) and cell (c) state: x-suffix = initial, y-suffix = final, d-prefix = gradient.
+ private cudnnTensorStruct hxDesc = new cudnnTensorStruct(), cxDesc = new cudnnTensorStruct();
+ private cudnnTensorStruct hyDesc = new cudnnTensorStruct(), cyDesc = new cudnnTensorStruct();
+ private cudnnTensorStruct dhxDesc = new cudnnTensorStruct(), dcxDesc = new cudnnTensorStruct();
+ private cudnnTensorStruct dhyDesc = new cudnnTensorStruct(), dcyDesc = new cudnnTensorStruct();
+ // Filter descriptors: packed weights (w), weight gradients (dw), and the per-gate matrix/bias views queried from cuDNN.
+ private cudnnFilterStruct wDesc = new cudnnFilterStruct(), dwDesc = new cudnnFilterStruct();
+ private cudnnFilterStruct linLayerMatDesc = new cudnnFilterStruct(), linLayerBiasDesc = new cudnnFilterStruct();
+
+ private cudnnRNNStruct rnnDesc = new cudnnRNNStruct();
+ private cudnnDropoutStruct dropoutDesc = new cudnnDropoutStruct();
+ private cudnnActivationStruct activationDesc = new cudnnActivationStruct();
+ // Creates all native descriptors and registers a Deallocator (built from this instance) that destroys them again.
+ public CudnnLSTMContext() {
+ createHandles();
+ deallocator(new Deallocator(this));
+ }
+ // Copy constructor used by Deallocator: rebuilds each descriptor object from the corresponding one in c.
+ public CudnnLSTMContext(CudnnLSTMContext c) {
+ super(c);
+ hxDesc = new cudnnTensorStruct(c.hxDesc);
+ cxDesc = new cudnnTensorStruct(c.cxDesc);
+ hyDesc = new cudnnTensorStruct(c.hyDesc);
+ cyDesc = new cudnnTensorStruct(c.cyDesc);
+ dhxDesc = new cudnnTensorStruct(c.dhxDesc);
+ dcxDesc = new cudnnTensorStruct(c.dcxDesc);
+ dhyDesc = new cudnnTensorStruct(c.dhyDesc);
+ dcyDesc = new cudnnTensorStruct(c.dcyDesc);
+
+ wDesc = new cudnnFilterStruct(c.wDesc);
+ dwDesc = new cudnnFilterStruct(c.dwDesc);
+ linLayerMatDesc = new cudnnFilterStruct(c.linLayerMatDesc);
+ linLayerBiasDesc = new cudnnFilterStruct(c.linLayerBiasDesc);
+
+ rnnDesc = new cudnnRNNStruct(c.rnnDesc);
+ dropoutDesc = new cudnnDropoutStruct(c.dropoutDesc);
+ activationDesc = new cudnnActivationStruct(c.activationDesc);
+ }
+ // Creates the parent's handles first, then one cuDNN descriptor for each field declared above.
+ @Override
+ protected void createHandles() {
+ super.createHandles();
+
+ checkCudnn(cudnnCreateTensorDescriptor(hxDesc));
+ checkCudnn(cudnnCreateTensorDescriptor(cxDesc));
+ checkCudnn(cudnnCreateTensorDescriptor(hyDesc));
+ checkCudnn(cudnnCreateTensorDescriptor(cyDesc));
+ checkCudnn(cudnnCreateTensorDescriptor(dhxDesc));
+ checkCudnn(cudnnCreateTensorDescriptor(dcxDesc));
+ checkCudnn(cudnnCreateTensorDescriptor(dhyDesc));
+ checkCudnn(cudnnCreateTensorDescriptor(dcyDesc));
+
+ checkCudnn(cudnnCreateFilterDescriptor(wDesc));
+ checkCudnn(cudnnCreateFilterDescriptor(dwDesc));
+ checkCudnn(cudnnCreateFilterDescriptor(linLayerMatDesc));
+ checkCudnn(cudnnCreateFilterDescriptor(linLayerBiasDesc));
+
+ checkCudnn(cudnnCreateRNNDescriptor(rnnDesc));
+ checkCudnn(cudnnCreateDropoutDescriptor(dropoutDesc));
+ checkCudnn(cudnnCreateActivationDescriptor(activationDesc));
+ }
+ // Destroys this class's descriptors, then delegates to the parent for its own handles.
+ @Override
+ protected void destroyHandles() {
+ checkCudnn(cudnnDestroyActivationDescriptor(activationDesc));
+ checkCudnn(cudnnDestroyDropoutDescriptor(dropoutDesc));
+ checkCudnn(cudnnDestroyRNNDescriptor(rnnDesc));
+
+ checkCudnn(cudnnDestroyFilterDescriptor(wDesc));
+ checkCudnn(cudnnDestroyFilterDescriptor(dwDesc));
+ checkCudnn(cudnnDestroyFilterDescriptor(linLayerMatDesc));
+ checkCudnn(cudnnDestroyFilterDescriptor(linLayerBiasDesc));
+
+ checkCudnn(cudnnDestroyTensorDescriptor(hxDesc));
+ checkCudnn(cudnnDestroyTensorDescriptor(cxDesc));
+ checkCudnn(cudnnDestroyTensorDescriptor(hyDesc));
+ checkCudnn(cudnnDestroyTensorDescriptor(cyDesc));
+ checkCudnn(cudnnDestroyTensorDescriptor(dhxDesc));
+ checkCudnn(cudnnDestroyTensorDescriptor(dcxDesc));
+ checkCudnn(cudnnDestroyTensorDescriptor(dhyDesc));
+ checkCudnn(cudnnDestroyTensorDescriptor(dcyDesc));
+
+ super.destroyHandles();
+ }
+ }
+
+ // These constants might eventually become variable parameters...
+ protected static final int NUM_LAYERS = 1;
+ protected static final float DROPOUT = 0;
+ protected static final boolean BIDIRECTIONAL = false;
+ protected static final int RNN_MODE = CUDNN_LSTM;
+ protected static final int NUM_LINEAR_LAYERS = 8; // CUDNN_LSTM
+ // The descriptor arrays and buffers below start empty; activate() replaces them when their capacity is too small.
+ private CudnnLSTMContext cudnnContext = new CudnnLSTMContext();
+ private TensorArray xDesc = new TensorArray();
+ private TensorArray yDesc = new TensorArray();
+ private TensorArray dxDesc = new TensorArray();
+ private TensorArray dyDesc = new TensorArray();
+ private DataCache stateSpace = new DataCache();
+ private DataCache reserveSpace = new DataCache();
+ private DataCache weightsSpace = new DataCache();
+ // Checked in activate() before cudnnSetDropoutDescriptor is called.
+ private boolean initializedDropoutDescriptor = false;
+
+ private static INDArray toCOrder(INDArray arr) {
+ if (arr.isView() || arr.ordering() != 'c' || !Shape.strideDescendingCAscendingF(arr)) {
+ arr = arr.dup('c');
+ }
+ return arr;
+ }
+
+    @Override
+    public boolean checkSupported(IActivation gateActivationFn, IActivation activationFn,
+            boolean hasPeepholeConnections) {
+        // Every unmet requirement is logged (not just the first); the base check runs before any of the warnings.
+        final boolean baseSupported = checkSupported();
+        final boolean sigmoidGates = gateActivationFn instanceof ActivationSigmoid;
+        final boolean tanhLayer = activationFn instanceof ActivationTanH;
+        if (!sigmoidGates) {
+            log.warn("Not supported: Gate activation functions != ActivationSigmoid");
+        }
+        if (!tanhLayer) {
+            log.warn("Not supported: Layer activation functions != ActivationTanH");
+        }
+        if (hasPeepholeConnections) {
+            log.warn("Not supported: LSTM layers with peephole connections");
+        }
+        return baseSupported && sigmoidGates && tanhLayer && !hasPeepholeConnections;
+    }
+
+    @Override
+    public Pair<Gradient, INDArray> backpropGradient(final NeuralNetConfiguration conf,
+            final IActivation gateActivationFn, final INDArray input, final INDArray recurrentWeights, //Shape: [hiddenLayerSize,4*hiddenLayerSize+3]; order: [wI,wF,wO,wG,wFF,wOO,wGG]
+            final INDArray inputWeights, //Shape: [n^(L-1),4*hiddenLayerSize]; order: [wi,wf,wo,wg]
+            final INDArray epsilon, final boolean truncatedBPTT, final int tbpttBackwardLength,
+            final FwdPassReturn fwdPass, final boolean forwards, final String inputWeightKey,
+            final String recurrentWeightKey, final String biasWeightKey,
+            final Map<String, INDArray> gradientViews, INDArray maskArray, //Input mask: should only be used with bidirectional RNNs + variable length
+            final boolean hasPeepholeConnections, //True for GravesLSTM, false for LSTM
+            final LayerWorkspaceMgr workspaceMgr) {
+        // Runs cudnnRNNBackwardData and cudnnRNNBackwardWeights, then copies the per-gate gradients into the gradientViews arrays.
+        //Expect errors to have shape: [miniBatchSize,n^(L+1),timeSeriesLength]
+        val hiddenLayerSize = recurrentWeights.size(0); //i.e., n^L
+        val prevLayerSize = inputWeights.size(0); //n^(L-1)
+        val inputLayerSize = input.size(1);
+        val miniBatchSize = epsilon.size(0);
+        boolean is2dInput = epsilon.rank() < 3; //Edge case: T=1 may have shape [miniBatchSize,n^(L+1)], equiv. to [miniBatchSize,n^(L+1),1]
+        long timeSeriesLength = (is2dInput ? 1 : epsilon.size(2));
+        // The buffers handed to cuDNN are time-major, hence the permute to [timeSeriesLength, miniBatchSize, size].
+        INDArray x = toCOrder(input.permute(2, 0, 1));
+        INDArray dy = toCOrder(epsilon.permute(2, 0, 1));
+        INDArray dx = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, inputWeights.dataType(), new long[] {timeSeriesLength, miniBatchSize, prevLayerSize}, 'c');
+
+        INDArray iwGradientsOut = gradientViews.get(inputWeightKey);
+        INDArray rwGradientsOut = gradientViews.get(recurrentWeightKey); //Order: {I,F,O,G}
+        INDArray bGradientsOut = gradientViews.get(biasWeightKey);
+
+        INDArray outputActivations = toCOrder(fwdPass.fwdPassOutput.permute(2, 0, 1));
+        INDArray prevStepMemCellState = toCOrder(fwdPass.prevMemCell);
+        INDArray prevStepActivations = toCOrder(fwdPass.prevAct);
+
+        Nd4j.getExecutioner().commit();
+
+        Allocator allocator = AtomicAllocator.getInstance();
+        CudaContext context = allocator.getFlowController().prepareActionAllWrite(x, dy, dx, outputActivations,
+                prevStepMemCellState, prevStepActivations, iwGradientsOut, rwGradientsOut, bGradientsOut);
+        Pointer xData = allocator.getPointer(x, context);
+        Pointer dyData = allocator.getPointer(dy, context);
+        Pointer dxData = allocator.getPointer(dx, context);
+        Pointer outputActivationsData = allocator.getPointer(outputActivations, context);
+        Pointer prevMemCellStateData = allocator.getPointer(prevStepMemCellState, context);
+        Pointer prevStepActivationsData = allocator.getPointer(prevStepActivations, context);
+        Pointer iwGradientsOutData = allocator.getPointer(iwGradientsOut, context);
+        Pointer rwGradientsOutData = allocator.getPointer(rwGradientsOut, context);
+        Pointer bGradientsOutData = allocator.getPointer(bGradientsOut, context);
+
+        CUstream_st stream = new CUstream_st(context.getCublasStream());
+        checkCudnn(cudnnSetStream(cudnnContext, stream));
+        // Truncated BPTT skips the leading steps; x rows are inputLayerSize wide, dy/output rows hiddenLayerSize wide.
+        if (truncatedBPTT) {
+            val skippedRows = Math.max(0, timeSeriesLength - tbpttBackwardLength) * miniBatchSize;
+            xData.position(skippedRows * inputLayerSize * dataTypeSize);
+            dyData.position(skippedRows * hiddenLayerSize * (BIDIRECTIONAL ? 2 : 1) * dataTypeSize);
+            outputActivationsData.position(skippedRows * hiddenLayerSize * (BIDIRECTIONAL ? 2 : 1) * dataTypeSize);
+            timeSeriesLength = (int) Math.min(timeSeriesLength, tbpttBackwardLength);
+        }
+
+        cudnnTensorStruct xDesc0 = xDesc.get(cudnnTensorStruct.class, 0);
+
+        DataCache workSpace = workspaceMgr.getHelperWorkspace(LayerWorkspaceMgr.CUDNN_WORKSPACE_KEY);
+        checkCudnn(cudnnRNNBackwardData(cudnnContext, cudnnContext.rnnDesc, (int) timeSeriesLength, yDesc,
+                outputActivationsData, dyDesc, dyData, cudnnContext.dhyDesc, null, cudnnContext.dcyDesc, null,
+                cudnnContext.wDesc, weightsSpace, cudnnContext.hxDesc, prevStepActivationsData, //hx: initial hidden state of RNN
+                cudnnContext.cxDesc, prevMemCellStateData, //cx: initial cell state of RNN
+                dxDesc, dxData, //dx: gradient at input of each time step
+                cudnnContext.dhxDesc, null, //dhx: gradient at initial hidden state of RNN
+                cudnnContext.dcxDesc, null, //dcx: Gradient at initial cell state
+                workSpace, workSpace.limit(), reserveSpace, reserveSpace.limit()));
+
+        // cudnnRNNBackwardWeights adds to the data in dW.
+        checkCuda(cudaMemsetAsync(weightsSpace, 0, weightsSpace.limit(), stream));
+
+        checkCudnn(cudnnRNNBackwardWeights(cudnnContext, cudnnContext.rnnDesc, (int) timeSeriesLength, xDesc, xData, //Input data
+                cudnnContext.hxDesc, prevStepActivationsData, //Initial hidden state
+                yDesc, outputActivationsData, //Output data
+                workSpace, workSpace.limit(), cudnnContext.dwDesc, weightsSpace, reserveSpace,
+                reserveSpace.limit()));
+
+        int[] dataType = new int[1];
+        int[] format = new int[1];
+        int[] nbDims = new int[1];
+        int[] filterDimA = new int[3];
+        Pointer linLayerMat = new Pointer();
+        Pointer linLayerBias = new Pointer();
+
+        for (int layer = 0; layer < NUM_LAYERS * (BIDIRECTIONAL ? 2 : 1); layer++) {
+            for (int linLayerID = 0; linLayerID < NUM_LINEAR_LAYERS; linLayerID++) {
+                checkCudnn(cudnnGetRNNLinLayerMatrixParams(cudnnContext, cudnnContext.rnnDesc, layer, xDesc0,
+                        cudnnContext.wDesc, weightsSpace, linLayerID, cudnnContext.linLayerMatDesc,
+                        linLayerMat));
+
+                checkCudnn(cudnnGetFilterNdDescriptor(cudnnContext.linLayerMatDesc, 3, dataType, format, nbDims,
+                        filterDimA));
+
+                checkCudnn(cudnnGetRNNLinLayerBiasParams(cudnnContext, cudnnContext.rnnDesc, layer, xDesc0,
+                        cudnnContext.wDesc, weightsSpace, linLayerID, cudnnContext.linLayerBiasDesc,
+                        linLayerBias));
+
+                checkCudnn(cudnnGetFilterNdDescriptor(cudnnContext.linLayerBiasDesc, 3, dataType, format, nbDims,
+                        filterDimA));
+
+                // our data is in "new, forget, output, and input gates" order (aka IFOG), each kind of weight packed together
+                int position = 0;
+                long size = 0;
+                Pointer data = null;
+                switch (linLayerID) {
+                    case 0:
+                        data = iwGradientsOutData;
+                        position = 3;
+                        size = inputLayerSize;
+                        break; // input gate
+                    case 1:
+                        data = iwGradientsOutData;
+                        position = 1;
+                        size = inputLayerSize;
+                        break; // forget gate
+                    case 2:
+                        data = iwGradientsOutData;
+                        position = 0;
+                        size = inputLayerSize;
+                        break; // new gate (input modulation gate)
+                    case 3:
+                        data = iwGradientsOutData;
+                        position = 2;
+                        size = inputLayerSize;
+                        break; // output gate
+                    case 4:
+                        data = rwGradientsOutData;
+                        position = 3;
+                        size = hiddenLayerSize;
+                        break; // input gate
+                    case 5:
+                        data = rwGradientsOutData;
+                        position = 1;
+                        size = hiddenLayerSize;
+                        break; // forget gate
+                    case 6:
+                        data = rwGradientsOutData;
+                        position = 0;
+                        size = hiddenLayerSize;
+                        break; // new gate (input modulation gate)
+                    case 7:
+                        data = rwGradientsOutData;
+                        position = 2;
+                        size = hiddenLayerSize;
+                        break; // output gate
+                    default:
+                        throw new RuntimeException();
+                }
+                checkCuda(cudaMemcpyAsync(data.position(position * size * hiddenLayerSize * dataTypeSize), linLayerMat,
+                        size * hiddenLayerSize * dataTypeSize, cudaMemcpyDeviceToDevice, stream));
+                if (linLayerID < 4) {
+                    checkCuda(cudaMemcpyAsync(bGradientsOutData.position(position * hiddenLayerSize * dataTypeSize),
+                            linLayerBias, hiddenLayerSize * dataTypeSize, cudaMemcpyDeviceToDevice, stream));
+                }
+            }
+        }
+
+        allocator.getFlowController().registerActionAllWrite(context, x, dy, dx, outputActivations,
+                prevStepMemCellState, prevStepActivations, iwGradientsOut, rwGradientsOut, bGradientsOut);
+
+        Gradient retGradient = new DefaultGradient();
+        retGradient.gradientForVariable().put(inputWeightKey, iwGradientsOut);
+        retGradient.gradientForVariable().put(recurrentWeightKey, rwGradientsOut);
+        retGradient.gradientForVariable().put(biasWeightKey, bGradientsOut);
+
+        INDArray epsilonNext = dx.permute(1, 2, 0); //i.e., what would be W^L*(delta^L)^T. Shape: [m,n^(L-1),T]
+
+        return new Pair<>(retGradient, epsilonNext);
+    }
+
+ @Override
+ public FwdPassReturn activate(final Layer layer, final NeuralNetConfiguration conf,
+ final IActivation gateActivationFn, //Activation function for the gates - sigmoid or hard sigmoid (must be found in range 0 to 1)
+ INDArray input, final INDArray recurrentWeights, //Shape: [hiddenLayerSize,4*hiddenLayerSize+3]; order: [wI,wF,wO,wG,wFF,wOO,wGG]
+ final INDArray inputWeights, //Shape: [n^(L-1),4*hiddenLayerSize]; order: [wi,wf,wo,wg]
+ final INDArray biases, //Shape: [4,hiddenLayerSize]; order: [bi,bf,bo,bg]^T
+ final boolean training, final INDArray prevOutputActivations, final INDArray prevMemCellState,
+ boolean forBackprop, boolean forwards, final String inputWeightKey, INDArray maskArray, //Input mask: should only be used with bidirectional RNNs + variable length
+ final boolean hasPeepholeConnections, //True for GravesLSTM, false for LSTM
+ final LayerWorkspaceMgr workspaceMgr) {
+
+ boolean is2dInput = input.rank() < 3; //Edge case of T=1, may have shape [m,nIn], equiv. to [m,nIn,1]
+ val timeSeriesLength = (is2dInput ? 1 : input.size(2));
+ val hiddenLayerSize = recurrentWeights.size(0);
+ val miniBatchSize = input.size(0);
+ val inputLayerSize = input.size(1);
+
+ INDArray x = toCOrder(input.permute(2, 0, 1));
+ INDArray linInputWeights = inputWeights;
+ INDArray linRecurrentWeights = recurrentWeights;
+ INDArray linBiases = biases;
+
+ INDArray prevAct = toCOrder(prevOutputActivations);
+ INDArray prevMemCell = toCOrder(prevMemCellState);
+
+ INDArray outputActivations = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS,
+ inputWeights.dataType(), new long[] {timeSeriesLength, miniBatchSize, hiddenLayerSize * (BIDIRECTIONAL ? 2 : 1)}, 'c');
+ INDArray finalMemCellState = Nd4j.createUninitialized( inputWeights.dataType(),
+ new long[] {/*numLayers * (bidirectional ? 2 : 1),*/ miniBatchSize, hiddenLayerSize}, 'c');
+ INDArray finalStepActivations = Nd4j.createUninitialized( inputWeights.dataType(),
+ new long[] {/*numLayers * (bidirectional ? 2 : 1),*/ miniBatchSize, hiddenLayerSize}, 'c');
+
+ FwdPassReturn toReturn = new FwdPassReturn();
+ toReturn.prevAct = prevAct;
+ toReturn.prevMemCell = prevMemCell;
+
+ Nd4j.getExecutioner().commit();
+
+
+
+ if (timeSeriesLength > xDesc.capacity()) {
+ xDesc.deallocate();
+ xDesc = new TensorArray(timeSeriesLength);
+ }
+ if (timeSeriesLength > yDesc.capacity()) {
+ yDesc.deallocate();
+ yDesc = new TensorArray(timeSeriesLength);
+ }
+ if (timeSeriesLength > dxDesc.capacity()) {
+ dxDesc.deallocate();
+ dxDesc = new TensorArray(timeSeriesLength);
+ }
+ if (timeSeriesLength > dyDesc.capacity()) {
+ dyDesc.deallocate();
+ dyDesc = new TensorArray(timeSeriesLength);
+ }
+
+ for (int i = 0; i < timeSeriesLength; i++) {
+ int[] dimA = {(int) miniBatchSize, (int) inputLayerSize, 1};
+ int[] strideA = {(int) dimA[2] * dimA[1], dimA[2], 1};
+
+ checkCudnn(cudnnSetTensorNdDescriptor(xDesc.get(cudnnTensorStruct.class, i), dataType, 3, dimA, strideA));
+ checkCudnn(cudnnSetTensorNdDescriptor(dxDesc.get(cudnnTensorStruct.class, i), dataType, 3, dimA, strideA));
+
+ int[] dimB = {(int) miniBatchSize, (int) hiddenLayerSize * (BIDIRECTIONAL ? 2 : 1), 1};
+ int[] strideB = {dimB[2] * dimB[1], dimB[2], 1};
+
+ checkCudnn(cudnnSetTensorNdDescriptor(yDesc.get(cudnnTensorStruct.class, i), dataType, 3, dimB, strideB));
+ checkCudnn(cudnnSetTensorNdDescriptor(dyDesc.get(cudnnTensorStruct.class, i), dataType, 3, dimB, strideB));
+ }
+
+ int[] dimC = {NUM_LAYERS * (BIDIRECTIONAL ? 2 : 1), (int) miniBatchSize, (int) hiddenLayerSize};
+ int[] strideC = {dimC[2] * dimC[1], dimC[2], 1};
+
+ checkCudnn(cudnnSetTensorNdDescriptor(cudnnContext.hxDesc, dataType, 3, dimC, strideC));
+ checkCudnn(cudnnSetTensorNdDescriptor(cudnnContext.cxDesc, dataType, 3, dimC, strideC));
+ checkCudnn(cudnnSetTensorNdDescriptor(cudnnContext.hyDesc, dataType, 3, dimC, strideC));
+ checkCudnn(cudnnSetTensorNdDescriptor(cudnnContext.cyDesc, dataType, 3, dimC, strideC));
+ checkCudnn(cudnnSetTensorNdDescriptor(cudnnContext.dhxDesc, dataType, 3, dimC, strideC));
+ checkCudnn(cudnnSetTensorNdDescriptor(cudnnContext.dcxDesc, dataType, 3, dimC, strideC));
+ checkCudnn(cudnnSetTensorNdDescriptor(cudnnContext.dhyDesc, dataType, 3, dimC, strideC));
+ checkCudnn(cudnnSetTensorNdDescriptor(cudnnContext.dcyDesc, dataType, 3, dimC, strideC));
+
+ checkCudnn(cudnnDropoutGetStatesSize(cudnnContext, sizeInBytes));
+ long stateSize = sizeInBytes.get(0);
+ if (stateSize > stateSpace.capacity()) {
+ stateSpace.deallocate();
+ stateSpace = new DataCache(stateSize);
+ }
+ stateSpace.limit(stateSize);
+
+ if(!initializedDropoutDescriptor) {
+ checkCudnn(cudnnSetDropoutDescriptor(cudnnContext.dropoutDesc, cudnnContext, DROPOUT, stateSpace, stateSize,
+ Nd4j.getRandom().getSeed()));
+ }
+
+ checkCudnn(cudnnSetRNNDescriptor_v6(cudnnContext, cudnnContext.rnnDesc, (int) hiddenLayerSize, NUM_LAYERS, cudnnContext.dropoutDesc,
+ CUDNN_LINEAR_INPUT, BIDIRECTIONAL ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL, RNN_MODE,
+ CUDNN_RNN_ALGO_STANDARD, dataType));
+
+ cudnnTensorStruct xDesc0 = xDesc.get(cudnnTensorStruct.class, 0);
+ checkCudnn(cudnnGetRNNParamsSize(cudnnContext, cudnnContext.rnnDesc, xDesc0, sizeInBytes, dataType));
+ long weightsSize = sizeInBytes.get(0);
+ if (weightsSize > weightsSpace.capacity()) {
+ weightsSpace.deallocate();
+ weightsSpace = new DataCache(weightsSize);
+ }
+ weightsSpace.limit(weightsSize);
+
+ int[] dimW = {(int) weightsSize / dataTypeSize, 1, 1};
+
+ checkCudnn(cudnnSetFilterNdDescriptor(cudnnContext.wDesc, dataType, CUDNN_TENSOR_NCHW, 3, dimW));
+ checkCudnn(cudnnSetFilterNdDescriptor(cudnnContext.dwDesc, dataType, CUDNN_TENSOR_NCHW, 3, dimW));
+
+ checkCudnn(cudnnGetRNNWorkspaceSize(cudnnContext, cudnnContext.rnnDesc, (int) timeSeriesLength, xDesc, sizeInBytes));
+ long workSize = sizeInBytes.get(0);
+ DataCache workSpace = workspaceMgr.getHelperWorkspace(LayerWorkspaceMgr.CUDNN_WORKSPACE_KEY);
+ if (workSpace == null || workSize > workSpace.capacity()) {
+ if(log.isTraceEnabled()){
+ if(workSpace == null){
+ log.trace("CudnnLSTMHelper activate: Allocating initial workspace of size {} ({})", workSize,
+ BinaryByteUnit.format(workSize, "#.00"));
+ } else {
+ log.trace("CudnnLSTMHelper activate: Deallocating workspace of size {} ({}), allocating new workspace of size {} ({})",
+ workSpace.capacity(), BinaryByteUnit.format(workSpace.capacity(), "#.00"),
+ workSize, BinaryByteUnit.format(workSize, "#.00"));
+ }
+ }
+ if(workSpace != null)
+ workSpace.deallocate();
+ workSpace = new DataCache(workSize);
+ workspaceMgr.setHelperWorkspace(LayerWorkspaceMgr.CUDNN_WORKSPACE_KEY, workSpace);
+ }
+ workSpace.limit(workSize);
+
+ checkCudnn(cudnnGetRNNTrainingReserveSize(cudnnContext, cudnnContext.rnnDesc, (int) timeSeriesLength, xDesc,
+ sizeInBytes));
+ long reserveSize = sizeInBytes.get(0);
+ if (reserveSize > reserveSpace.capacity()) {
+ reserveSpace.deallocate();
+ reserveSpace = new DataCache(reserveSize);
+ }
+ reserveSpace.limit(reserveSize);
+
+ Allocator allocator = AtomicAllocator.getInstance();
+ CudaContext context = allocator.getFlowController().prepareActionAllWrite(x, linInputWeights,
+ linRecurrentWeights, linBiases, prevAct, prevMemCell, outputActivations, finalMemCellState,
+ finalStepActivations);
+ Pointer xData = allocator.getPointer(x, context);
+ Pointer linInputWeightsData = allocator.getPointer(linInputWeights, context);
+ Pointer linRecurrentWeightsData = allocator.getPointer(linRecurrentWeights, context);
+ Pointer linBiasesData = allocator.getPointer(linBiases, context);
+ Pointer prevActData = allocator.getPointer(prevAct, context);
+ Pointer prevMemCellData = allocator.getPointer(prevMemCell, context);
+ Pointer outputActivationsData = allocator.getPointer(outputActivations, context);
+ Pointer finalMemCellStateData = allocator.getPointer(finalMemCellState, context);
+ Pointer finalTimeStepActivationsData = allocator.getPointer(finalStepActivations, context);
+
+ CUstream_st stream = new CUstream_st(context.getCublasStream());
+ checkCudnn(cudnnSetStream(cudnnContext, stream));
+
+ checkCuda(cudaMemsetAsync(weightsSpace, 0, weightsSpace.limit(), stream));
+
+ int[] dataType = new int[1];
+ int[] format = new int[1];
+ int[] nbDims = new int[1];
+ int[] filterDimA = new int[3];
+ Pointer linLayerMat = new Pointer();
+ Pointer linLayerBias = new Pointer();
+
+ for (int layerNum = 0; layerNum < NUM_LAYERS * (BIDIRECTIONAL ? 2 : 1); layerNum++) {
+ for (int linLayerID = 0; linLayerID < NUM_LINEAR_LAYERS; linLayerID++) {
+ checkCudnn(cudnnGetRNNLinLayerMatrixParams(cudnnContext, cudnnContext.rnnDesc, layerNum, xDesc0,
+ cudnnContext.wDesc, weightsSpace, linLayerID, cudnnContext.linLayerMatDesc,
+ linLayerMat));
+
+ checkCudnn(cudnnGetFilterNdDescriptor(cudnnContext.linLayerMatDesc, 3, dataType, format, nbDims,
+ filterDimA));
+
+ checkCudnn(cudnnGetRNNLinLayerBiasParams(cudnnContext, cudnnContext.rnnDesc, layerNum, xDesc0,
+ cudnnContext.wDesc, weightsSpace, linLayerID, cudnnContext.linLayerBiasDesc,
+ linLayerBias));
+
+ checkCudnn(cudnnGetFilterNdDescriptor(cudnnContext.linLayerBiasDesc, 3, dataType, format, nbDims,
+ filterDimA));
+
+ // our data is in "new, forget, output, and input gates" order (aka IFOG), each kind of weight packed together
+ int position = 0;
+ long size = 0;
+ Pointer data = null;
+ switch (linLayerID) {
+ case 0:
+ data = linInputWeightsData;
+ position = 3;
+ size = inputLayerSize;
+ break; // input gate
+ case 1:
+ data = linInputWeightsData;
+ position = 1;
+ size = inputLayerSize;
+ break; // forget gate
+ case 2:
+ data = linInputWeightsData;
+ position = 0;
+ size = inputLayerSize;
+ break; // new gate
+ case 3:
+ data = linInputWeightsData;
+ position = 2;
+ size = inputLayerSize;
+ break; // output gate
+ case 4:
+ data = linRecurrentWeightsData;
+ position = 3;
+ size = hiddenLayerSize;
+ break; // input gate
+ case 5:
+ data = linRecurrentWeightsData;
+ position = 1;
+ size = hiddenLayerSize;
+ break; // forget gate
+ case 6:
+ data = linRecurrentWeightsData;
+ position = 0;
+ size = hiddenLayerSize;
+ break; // new gate
+ case 7:
+ data = linRecurrentWeightsData;
+ position = 2;
+ size = hiddenLayerSize;
+ break; // output gate
+ default:
+ throw new RuntimeException();
+ }
+ checkCuda(cudaMemcpyAsync(linLayerMat, data.position(position * size * hiddenLayerSize * dataTypeSize),
+ size * hiddenLayerSize * dataTypeSize, cudaMemcpyDeviceToDevice, stream));
+ if (linLayerID < 4) {
+ checkCuda(cudaMemcpyAsync(linLayerBias,
+ linBiasesData.position(position * hiddenLayerSize * dataTypeSize),
+ hiddenLayerSize * dataTypeSize, cudaMemcpyDeviceToDevice, stream));
+ }
+ }
+ }
+
+ if (training) {
+ checkCudnn(cudnnRNNForwardTraining(cudnnContext, cudnnContext.rnnDesc, (int) timeSeriesLength, xDesc, xData,
+ cudnnContext.hxDesc, prevActData, cudnnContext.cxDesc, prevMemCellData, cudnnContext.wDesc,
+ weightsSpace, yDesc, outputActivationsData, cudnnContext.hyDesc,
+ finalTimeStepActivationsData, cudnnContext.cyDesc, finalMemCellStateData, workSpace,
+ workSpace.limit(), reserveSpace, reserveSpace.limit()));
+ } else {
+ checkCudnn(cudnnRNNForwardInference(cudnnContext, cudnnContext.rnnDesc, (int) timeSeriesLength, xDesc, xData,
+ cudnnContext.hxDesc, prevActData, cudnnContext.cxDesc, prevMemCellData, cudnnContext.wDesc,
+ weightsSpace, yDesc, outputActivationsData, cudnnContext.hyDesc,
+ finalTimeStepActivationsData, cudnnContext.cyDesc, finalMemCellStateData, workSpace,
+ workSpace.limit()));
+ }
+
+ allocator.getFlowController().registerActionAllWrite(context, x, linInputWeights, linRecurrentWeights,
+ linBiases, prevAct, prevMemCell, outputActivations, finalMemCellState, finalStepActivations);
+
+ toReturn.fwdPassOutput = outputActivations.permute(1, 2, 0);
+ toReturn.lastAct = finalStepActivations;
+ toReturn.lastMemCell = finalMemCellState;
+ toReturn.prevAct = prevAct;
+ toReturn.prevMemCell = prevMemCell;
+
+ return toReturn;
+ }
+
+    @Override
+    public Map<String, Long> helperMemoryUse() {
+        Map<String, Long> memUse = new HashMap<>(); // buffer name -> current capacity() of that buffer
+        memUse.put("stateStace", stateSpace.capacity()); // key is misspelled but kept as-is, since callers may read it
+        memUse.put("reserveSpace", reserveSpace.capacity());
+        memUse.put("weightsSpace", weightsSpace.capacity());
+        return memUse;
+    }
+}
diff --git a/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/preprocessors/KerasFlattenRnnPreprocessor.java b/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/preprocessors/KerasFlattenRnnPreprocessor.java
index 37df1f31b..92ad3c679 100644
--- a/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/preprocessors/KerasFlattenRnnPreprocessor.java
+++ b/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/preprocessors/KerasFlattenRnnPreprocessor.java
@@ -21,6 +21,7 @@
package org.deeplearning4j.nn.modelimport.keras.preprocessors;
import lombok.Data;
+import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException;
@@ -32,6 +33,7 @@ import com.fasterxml.jackson.annotation.JsonProperty;
@Slf4j
@Data
+@EqualsAndHashCode(callSuper=false)
public class KerasFlattenRnnPreprocessor extends BaseInputPreProcessor {
private long tsLength;
diff --git a/cavis-full/build.gradle b/cavis-full/build.gradle
index c18c258ad..558d36365 100644
--- a/cavis-full/build.gradle
+++ b/cavis-full/build.gradle
@@ -22,7 +22,7 @@ dependencies {
&& !sproj.name.equals("Cavis")
&& !sproj.name.equals("cavis-datavec")
&& !sproj.name.equals("cavis-dnn")
- && !sproj.name.equals("cavis-native") && !sproj.name.equals("cavis-native-lib")
+ && !sproj.name.equals("cavis-native")
&& !sproj.name.equals("cavis-nd4j")
&& !sproj.name.equals("cavis-ui")
&& !sproj.name.equals("cavis-zoo")) {
@@ -31,7 +31,7 @@ dependencies {
}
// if(withCpu) api project(path: ":cavis-native:cavis-native-lib", configuration: "cpuSupportApiElements")
// if(withCuda) api project(path: ":cavis-native:cavis-native-lib", configuration: "cudaSupportApiElements")
-
+/*
api(projects.cavisNative.cavisNativeLib) {
capabilities {
//if(withCuda()) it.requireCapability(group: "net.brutex.cavis.cavis-native", name: "cavis-native-lib-cuda-support", version: project.version)
@@ -44,7 +44,7 @@ dependencies {
//if(withCpu()) it.requireCapability(group: "net.brutex.cavis.cavis-native", name: "cavis-native-lib-cpu-support", version: project.version)
}
}
-
+*/
//if(withCpu()) api project(path: ":cavis-native:cavis-native-lib", configuration: "cpuSupportImplementation")
//if(withCuda()) api project(path: ":cavis-native:cavis-native-lib", configuration: "cudaSupportImplementation")
//if(withCuda()) api project(path: ":cavis-native:cavis-native-lib", configuration: "cudaSupportCompileClasspath")
diff --git a/cavis-native/cavis-native-lib/CMakeLists.txt b/cavis-native/cavis-native-lib/CMakeLists.txt
index 3795e7bd0..24360e856 100644
--- a/cavis-native/cavis-native-lib/CMakeLists.txt
+++ b/cavis-native/cavis-native-lib/CMakeLists.txt
@@ -121,7 +121,7 @@ endfunction()
if (SD_CUDA)
#enable_language(CUDA)
- find_package(CUDAToolkit 11.2 REQUIRED)
+ find_package(CUDAToolkit 11.4 REQUIRED)
message(STATUS "CUDAToolkit_VERSION: ${CUDAToolkit_VERSION}")
message(STATUS "CUDAToolkit_VERSION_MAJOR: ${CUDAToolkit_VERSION_MAJOR}")
message(STATUS "CUDAToolkit_VERSION_MINOR: ${CUDAToolkit_VERSION_MINOR}")
diff --git a/chooseBackend.gradle b/chooseBackend.gradle
index 7a3159f59..d1a33af9e 100644
--- a/chooseBackend.gradle
+++ b/chooseBackend.gradle
@@ -20,11 +20,9 @@
*/
ext {
chip = (properties.CAVIS_CHIP ?: "cuda,cpu").toLowerCase() //the default is to build for CPU and CUDA
- testChip = (properties.CAVIS_TEST_CHIP ?: " ").toLowerCase() //the default is without specific backend
- logger.debug("Building for chips ${chip} and running tests with backends for ${testChip}")
+ logger.debug("Building for chips ${chip} and running tests with backends for ${chip}") // the same chip value is reported for build and tests; there is no separate test chip setting here
chipList = chip.split(",")
- testChipList = testChip.split(",")
/* just for usability */
withCuda = { ->
@@ -33,10 +31,4 @@ ext {
withCpu = { ->
return chip.contains("cpu")
}
- withCudaTest = { ->
- return testChip.contains("cuda")
- }
- withCpuTest = { ->
- return testChip.contains("cpu")
- }
}
diff --git a/createTestBackends.gradle b/createTestBackends.gradle
index a0cef6c24..638e511e2 100644
--- a/createTestBackends.gradle
+++ b/createTestBackends.gradle
@@ -24,7 +24,7 @@ ext {
buildTarget = rootProject.ext.buildTarget
apply from: new File("${project.rootProject.projectDir}/chooseBackend.gradle")
- testChipList.each { thisChip ->
+ chipList.each { thisChip -> // chipList is defined by the chooseBackend.gradle script applied above; one "<chip>TestImplementation" configuration is registered per entry
configurations.register("${thisChip}TestImplementation") {
it.extendsFrom configurations.testImplementation, configurations.implementation
@@ -79,33 +79,44 @@ ext {
dependencies {
- if (withCudaTest()) {
+ if (withCuda()) { // declare the cudaTestRuntime dependencies only if withCuda() (chooseBackend.gradle) is true
cudaTestRuntime platform(projects.cavisCommonPlatform)
cudaTestRuntime projects.cavisNative.cavisNativeJcublas
+ cudaTestRuntime projects.cavisDnn.cavisDnnCudnn // type-safe accessor; presumably resolves to ':cavis-dnn:cavis-dnn-cudnn' - confirm it is included in settings.gradle
cudaTestRuntime group: "org.bytedeco", name: "openblas"
cudaTestRuntime group: "org.bytedeco", name: "openblas", classifier: buildTarget
cudaTestRuntime group: "org.bytedeco", name: "cuda"
cudaTestRuntime group: "org.bytedeco", name: "cuda", classifier: buildTarget
cudaTestRuntime group: "org.bytedeco", name: "cuda", classifier: "${buildTarget}-redist"
+ cudaTestRuntime (project( path: ":cavis-native:cavis-native-lib", configuration: "cudaSupportRuntimeElements")) // targets the named configuration directly; the capability-based request below is commented out
+ /*
cudaTestRuntime(project(":cavis-native:cavis-native-lib")) {
+
capabilities {
it.requireCapabilities "net.brutex.cavis.native:cavis-native-lib-cuda-support:1.0.0-SNAPSHOT"
}
}
+
+ */
}
- if (withCpuTest()) {
+ if (withCpu()) { // declare the cpuTestRuntime dependencies only if withCpu() (chooseBackend.gradle) is true
cpuTestRuntime platform(projects.cavisCommonPlatform)
cpuTestRuntime projects.cavisNative.cavisNativeCpu
cpuTestRuntime group: "org.bytedeco", name: "openblas"
cpuTestRuntime group: "org.bytedeco", name: "openblas", classifier: buildTarget
cpuTestRuntime group: "org.bytedeco", name: "opencv"
cpuTestRuntime group: "org.bytedeco", name: "opencv", classifier: buildTarget
+ cpuTestRuntime project( path: ":cavis-native:cavis-native-lib", configuration: "cpuSupportRuntimeElements") // targets the named configuration directly; the capability-based request below is commented out
+ /*
cpuTestRuntime(project(":cavis-native:cavis-native-lib")) {
+
capabilities {
it.requireCapabilities "net.brutex.cavis.native:cavis-native-lib-cpu-support:1.0.0-SNAPSHOT"
}
}
+
+ */
}
}
}
\ No newline at end of file
diff --git a/settings.gradle b/settings.gradle
index 17d2ee1b9..efde26230 100644
--- a/settings.gradle
+++ b/settings.gradle
@@ -89,6 +89,7 @@ include ':cavis-native:cavis-native-lib'
include ':cavis-native:cavis-native-common'
include ':cavis-dnn'
include ':cavis-dnn:cavis-dnn-api'
+include ':cavis-dnn:cavis-dnn-cudnn'
include ':cavis-dnn:cavis-dnn-common'
include ':cavis-dnn:cavis-dnn-common-tests'
include ':cavis-dnn:cavis-dnn-core'
@@ -151,3 +152,6 @@ include ':cavis-zoo'
include ':cavis-zoo:cavis-zoo-models'
include ':brutex-extended-tests'
include ':cavis-full'
+include 'cavis-dnn:cavis-dnn-cudnn' // NOTE(review): same project path as the include added next to ':cavis-dnn:cavis-dnn-api' - looks redundant, confirm and drop one
+findProject(':cavis-dnn:cavis-dnn-cudnn')?.name = 'cavis-dnn-cudnn' // assigns the project name; the safe-navigation operator (?.) skips the assignment if findProject returns null
+