Adding cuDNN support

master
Brian Rosenberger 2023-03-10 11:20:32 +01:00
parent a39e44c782
commit aab7b423d1
28 changed files with 4441 additions and 360 deletions

View File

@ -19,8 +19,12 @@
* *
*/ */
apply plugin: 'java' plugins {
apply plugin: 'maven-publish' id 'java-library'
id 'maven-publish'
id 'com.github.johnrengelman.shadow' version '7.1.2'
}
apply from: "${project.rootProject.projectDir}/createTestBackends.gradle" apply from: "${project.rootProject.projectDir}/createTestBackends.gradle"
@ -54,6 +58,7 @@ dependencies {
implementation projects.cavisDnn.cavisDnnSpark.cavisDnnSparkParameterserver implementation projects.cavisDnn.cavisDnnSpark.cavisDnnSparkParameterserver
implementation projects.cavisDnn.cavisDnnNnParent.cavisDnnNnCore implementation projects.cavisDnn.cavisDnnNnParent.cavisDnnNnCore
implementation projects.cavisDnn.cavisDnnNn implementation projects.cavisDnn.cavisDnnNn
implementation projects.cavisUi.cavisUiCommon implementation projects.cavisUi.cavisUiCommon
implementation projects.cavisUi.cavisUiVertx implementation projects.cavisUi.cavisUiVertx
implementation projects.cavisUi.cavisUiModel implementation projects.cavisUi.cavisUiModel
@ -66,11 +71,21 @@ dependencies {
implementation projects.cavisDnn.cavisDnnParallelwrapper implementation projects.cavisDnn.cavisDnnParallelwrapper
implementation projects.cavisZoo.cavisZooModels implementation projects.cavisZoo.cavisZooModels
testRuntimeOnly "net.brutex.ai:dl4j-test-resources:1.0.1-SNAPSHOT" testRuntimeOnly "net.brutex.ai:dl4j-test-resources:1.0.1-SNAPSHOT"
} }
test { test {
dependsOn jar enabled true
dependsOn shadowJar
} }
shadowJar {
enabled true;
zip64 true //need this to support jars with more than 65535 entries
archiveClassifier.set('all')
from sourceSets.test.output
}

View File

@ -0,0 +1,279 @@
/*
*
* ******************************************************************************
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*
*/
package net.brutex.gan;
import org.apache.commons.lang3.ArrayUtils;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.PerformanceListener;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.impl.ActivationLReLU;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import javax.swing.*;
import java.awt.*;
import java.awt.image.BufferedImage;
import java.io.File;
import java.util.Arrays;
public class App {
private static final double LEARNING_RATE = 0.0002;
private static final double GRADIENT_THRESHOLD = 100.0;
private static final IUpdater UPDATER = Adam.builder().learningRate(LEARNING_RATE).beta1(0.5).build();
private static JFrame frame;
private static JPanel panel;
private static Layer[] genLayers() {
return new Layer[] {
new DenseLayer.Builder().nIn(100).nOut(256).weightInit(WeightInit.NORMAL).build(),
new ActivationLayer.Builder(new ActivationLReLU(0.2)).build(),
new DenseLayer.Builder().nIn(256).nOut(512).build(),
new ActivationLayer.Builder(new ActivationLReLU(0.2)).build(),
new DenseLayer.Builder().nIn(512).nOut(1024).build(),
new ActivationLayer.Builder(new ActivationLReLU(0.2)).build(),
new DenseLayer.Builder().nIn(1024).nOut(784).activation(Activation.TANH).build()
};
}
/**
* Returns a network config that takes in a 10x10 random number and produces a 28x28 grayscale image.
*
* @return config
*/
private static MultiLayerConfiguration generator() {
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.seed(42)
.updater(UPDATER)
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
.gradientNormalizationThreshold(GRADIENT_THRESHOLD)
.weightInit(WeightInit.XAVIER)
.activation(Activation.IDENTITY)
.list(genLayers())
.build();
return conf;
}
private static Layer[] disLayers() {
return new Layer[]{
new DenseLayer.Builder().nIn(784).nOut(1024).build(),
new ActivationLayer.Builder(new ActivationLReLU(0.2)).build(),
new DropoutLayer.Builder(1 - 0.5).build(),
new DenseLayer.Builder().nIn(1024).nOut(512).build(),
new ActivationLayer.Builder(new ActivationLReLU(0.2)).build(),
new DropoutLayer.Builder(1 - 0.5).build(),
new DenseLayer.Builder().nIn(512).nOut(256).build(),
new ActivationLayer.Builder(new ActivationLReLU(0.2)).build(),
new DropoutLayer.Builder(1 - 0.5).build(),
new OutputLayer.Builder(LossFunctions.LossFunction.XENT).nIn(256).nOut(1).activation(Activation.SIGMOID).build()
};
}
private static MultiLayerConfiguration discriminator() {
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.seed(42)
.updater(UPDATER)
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
.gradientNormalizationThreshold(GRADIENT_THRESHOLD)
.weightInit(WeightInit.XAVIER)
.activation(Activation.IDENTITY)
.list(disLayers())
.build();
return conf;
}
private static MultiLayerConfiguration gan() {
Layer[] genLayers = genLayers();
Layer[] disLayers = Arrays.stream(disLayers())
.map((layer) -> {
if (layer instanceof DenseLayer || layer instanceof OutputLayer) {
return new FrozenLayerWithBackprop(layer);
} else {
return layer;
}
}).toArray(Layer[]::new);
Layer[] layers = ArrayUtils.addAll(genLayers, disLayers);
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.seed(42)
.updater(UPDATER)
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
.gradientNormalizationThreshold(GRADIENT_THRESHOLD)
.weightInit(WeightInit.XAVIER)
.activation(Activation.IDENTITY)
.list(layers)
.build();
return conf;
}
@Test
public void runTest() throws Exception {
main();
}
public static void main(String... args) throws Exception {
Nd4j.getMemoryManager().setAutoGcWindow(15 * 1000);
MnistDataSetIterator trainData = new MnistDataSetIterator(128, true, 42);
MultiLayerNetwork gen = new MultiLayerNetwork(generator());
MultiLayerNetwork dis = new MultiLayerNetwork(discriminator());
MultiLayerNetwork gan = new MultiLayerNetwork(gan());
gen.init();
dis.init();
gan.init();
copyParams(gen, dis, gan);
gen.setListeners(new PerformanceListener(10, true));
dis.setListeners(new PerformanceListener(10, true));
gan.setListeners(new PerformanceListener(10, true));
trainData.reset();
int j = 0;
for (int i = 0; i < 20; i++) {
while (trainData.hasNext()) {
j++;
// generate data
INDArray real = trainData.next().getFeatures().muli(2).subi(1);
int batchSize = (int) real.shape()[0];
INDArray fakeIn = Nd4j.rand(batchSize, 100);
INDArray fake = gan.activateSelectedLayers(0, gen.getLayers().length - 1, fakeIn);
DataSet realSet = new DataSet(real, Nd4j.zeros(batchSize, 1));
DataSet fakeSet = new DataSet(fake, Nd4j.ones(batchSize, 1));
DataSet data = DataSet.merge(Arrays.asList(realSet, fakeSet));
dis.fit(data);
dis.fit(data);
// Update the discriminator in the GAN network
updateGan(gen, dis, gan);
gan.fit(new DataSet(Nd4j.rand(batchSize, 100), Nd4j.zeros(batchSize, 1)));
if (j % 10 == 1) {
System.out.println("Iteration " + j + " Visualizing...");
INDArray[] samples = new INDArray[9];
DataSet fakeSet2 = new DataSet(fakeIn, Nd4j.ones(batchSize, 1));
for (int k = 0; k < 9; k++) {
INDArray input = fakeSet2.get(k).getFeatures();
//samples[k] = gen.output(input, false);
samples[k] = gan.activateSelectedLayers(0, gen.getLayers().length - 1, input);
}
visualize(samples);
}
}
trainData.reset();
}
// Copy the GANs generator to gen.
updateGen(gen, gan);
gen.save(new File("mnist-mlp-generator.dlj"));
}
private static void copyParams(MultiLayerNetwork gen, MultiLayerNetwork dis, MultiLayerNetwork gan) {
int genLayerCount = gen.getLayers().length;
for (int i = 0; i < gan.getLayers().length; i++) {
if (i < genLayerCount) {
gen.getLayer(i).setParams(gan.getLayer(i).params());
} else {
dis.getLayer(i - genLayerCount).setParams(gan.getLayer(i).params());
}
}
}
private static void updateGen(MultiLayerNetwork gen, MultiLayerNetwork gan) {
for (int i = 0; i < gen.getLayers().length; i++) {
gen.getLayer(i).setParams(gan.getLayer(i).params());
}
}
private static void updateGan(MultiLayerNetwork gen, MultiLayerNetwork dis, MultiLayerNetwork gan) {
int genLayerCount = gen.getLayers().length;
for (int i = genLayerCount; i < gan.getLayers().length; i++) {
gan.getLayer(i).setParams(dis.getLayer(i - genLayerCount).params());
}
}
private static void visualize(INDArray[] samples) {
if (frame == null) {
frame = new JFrame();
frame.setTitle("Viz");
frame.setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE);
frame.setLayout(new BorderLayout());
panel = new JPanel();
panel.setLayout(new GridLayout(samples.length / 3, 1, 8, 8));
frame.add(panel, BorderLayout.CENTER);
frame.setVisible(true);
}
panel.removeAll();
for (INDArray sample : samples) {
panel.add(getImage(sample));
}
frame.revalidate();
frame.pack();
}
private static JLabel getImage(INDArray tensor) {
BufferedImage bi = new BufferedImage(28, 28, BufferedImage.TYPE_BYTE_GRAY);
for (int i = 0; i < 784; i++) {
int pixel = (int)(((tensor.getDouble(i) + 1) * 2) * 255);
bi.getRaster().setSample(i % 28, i / 28, 0, pixel);
}
ImageIcon orig = new ImageIcon(bi);
Image imageScaled = orig.getImage().getScaledInstance((8 * 28), (8 * 28), Image.SCALE_REPLICATE);
ImageIcon scaled = new ImageIcon(imageScaled);
return new JLabel(scaled);
}
}

View File

@ -0,0 +1,411 @@
/*
*
* ******************************************************************************
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*
*/
package net.brutex.gan;
import org.apache.commons.lang3.ArrayUtils;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.*;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.optimize.api.BaseTrainingListener;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.learning.config.Sgd;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Supplier;
/**
* Implementation of vanilla Generative Adversarial Networks as introduced in https://arxiv.org/pdf/1406.2661.pdf.
* <p>
* A DL4J GAN is initialized from two networks: a generator and a discriminator and will build a third network,
* the GAN network, from the first two.
*
* @author Max Pumperla
*/
public class GAN {
private static final IUpdater UPDATER_ZERO = Sgd.builder().learningRate(0.0).build();
public interface DiscriminatorProvider {
MultiLayerNetwork provide(IUpdater updater);
}
protected Supplier<MultiLayerNetwork> generatorSupplier;
protected DiscriminatorProvider discriminatorSupplier;
protected MultiLayerNetwork generator;
protected MultiLayerNetwork discriminator;
protected MultiLayerNetwork gan;
protected int latentDim;
protected IUpdater updater;
protected IUpdater biasUpdater;
protected OptimizationAlgorithm optimizer;
protected GradientNormalization gradientNormalizer;
protected double gradientNormalizationThreshold;
protected WorkspaceMode trainingWorkSpaceMode;
protected WorkspaceMode inferenceWorkspaceMode;
protected CacheMode cacheMode;
protected long seed;
private Double[] discriminatorLearningRates;
public GAN(Builder builder) {
this.generatorSupplier = builder.generator;
this.discriminatorSupplier = builder.discriminator;
this.latentDim = builder.latentDimension;
this.updater = builder.iUpdater;
this.biasUpdater = builder.biasUpdater;
this.optimizer = builder.optimizationAlgo;
this.gradientNormalizer = builder.gradientNormalization;
this.gradientNormalizationThreshold = builder.gradientNormalizationThreshold;
this.trainingWorkSpaceMode = builder.trainingWorkspaceMode;
this.inferenceWorkspaceMode = builder.inferenceWorkspaceMode;
this.cacheMode = builder.cacheMode;
this.seed = builder.seed;
defineGan();
}
public MultiLayerNetwork getGenerator() {
return generator;
}
public MultiLayerNetwork getDiscriminator() {
return discriminator;
}
public Evaluation evaluateGan(DataSetIterator data) {
return gan.evaluate(data);
}
public Evaluation evaluateGan(DataSetIterator data, List<String> labelsList) {
return gan.evaluate(data, labelsList);
}
public void setGeneratorListeners(BaseTrainingListener[] listeners) {
generator.setListeners(listeners);
}
public void setDiscriminatorListeners(BaseTrainingListener[] listeners) {
discriminator.setListeners(listeners);
}
public void setGanListeners(BaseTrainingListener[] listeners) {
gan.setListeners(listeners);
}
public void fit(DataSetIterator realData, int numEpochs) {
for (int i = 0; i < numEpochs; i++) {
while (realData.hasNext()) {
// Get real images as features
DataSet next = realData.next();
fit(next);
}
realData.reset();
}
}
public void fit(DataSet next) {
int batchSize;
INDArray realImages = next.getFeatures().muli(2).subi(1);
batchSize = (int) realImages.shape()[0];
// Sample from latent space and let the generate create fake images.
INDArray randomLatentData = Nd4j.rand(new int[]{batchSize, latentDim});
INDArray fakeImages = generator.output(randomLatentData);
// Real images are marked as "0", fake images at "1".
DataSet realSet = new DataSet(realImages, Nd4j.zeros(batchSize, 1));
DataSet fakeSet = new DataSet(fakeImages, Nd4j.ones(batchSize, 1));
// Fit the discriminator on a combined batch of real and fake images.
DataSet combined = DataSet.merge(Arrays.asList(realSet, fakeSet));
/*for (int i = 0; i < discriminator.getLayers().length; i++) {
if (discriminatorLearningRates[i] != null) {
discriminator.setLearningRate(i, discriminatorLearningRates[i]);
}
}*/
discriminator.fit(combined);
//discriminator.fit(combined);
// Update the discriminator in the GAN network
updateGanWithDiscriminator();
// Generate a new set of adversarial examples and try to mislead the discriminator.
// by labeling the fake images as real images we reward the generator when it's output
// tricks the discriminator.
INDArray adversarialExamples = Nd4j.rand(new int[]{batchSize, latentDim});
INDArray misleadingLabels = Nd4j.zeros(batchSize, 1);
DataSet adversarialSet = new DataSet(adversarialExamples, misleadingLabels);
// Set learning rate of discriminator part of gan to zero.
/*for (int i = generator.getLayers().length; i < gan.getLayers().length; i++) {
gan.setLearningRate(i, 0.0);
}*/
// Fit the GAN on the adversarial set, trying to fool the discriminator by generating
// better fake images.
gan.fit(adversarialSet);
// Copy the GANs generator part to "generator".
updateGeneratorFromGan();
}
private void defineGan() {
generator = generatorSupplier.get();
generator.init();
Layer[] genLayers = generator.getLayers();
int numGenLayers = genLayers.length;
discriminator = discriminatorSupplier.provide(updater);
discriminator.init();
MultiLayerNetwork ganDiscriminator = discriminatorSupplier.provide(UPDATER_ZERO);
ganDiscriminator.init();
Layer[] disLayers = ganDiscriminator.getLayers();
Layer[] layers = ArrayUtils.addAll(genLayers, disLayers);
MultiLayerConfiguration genConf = generator.getLayerWiseConfigurations();
MultiLayerConfiguration disConf = ganDiscriminator.getLayerWiseConfigurations();
org.deeplearning4j.nn.conf.layers.Layer[] confLayers = new org.deeplearning4j.nn.conf.layers.Layer[layers.length];
Map<Integer, InputPreProcessor> preProcessors = new HashMap<>();
for (int i = 0; i < layers.length; i++) {
confLayers[i] = layers[i].conf().getLayer();
if (i < numGenLayers) {
preProcessors.put(i, genConf.getInputPreProcess(i));
} else {
preProcessors.put(i, disConf.getInputPreProcess(i - numGenLayers));
}
}
MultiLayerConfiguration ganConf = new NeuralNetConfiguration.Builder()
.seed(seed)
.updater(updater)
.biasUpdater(biasUpdater)
.optimizationAlgo(optimizer)
.gradientNormalization(gradientNormalizer)
.gradientNormalizationThreshold(gradientNormalizationThreshold)
.activation(Activation.IDENTITY)
.trainingWorkspaceMode(trainingWorkSpaceMode)
.inferenceWorkspaceMode(inferenceWorkspaceMode)
.cacheMode(cacheMode)
.list(confLayers)
.inputPreProcessors(preProcessors)
.build();
gan = new MultiLayerNetwork(ganConf);
gan.init();
// we lose proper init here, need to copy weights after
copyParamsToGan();
}
private void copyParamsToGan() {
int genLayerCount = generator.getLayers().length;
for (int i = 0; i < gan.getLayers().length; i++) {
if (i < genLayerCount) {
generator.getLayer(i).setParams(gan.getLayer(i).params());
} else {
discriminator.getLayer(i - genLayerCount).setParams(gan.getLayer(i).params());
}
}
}
/**
* After the GAN has been trained on misleading images, we update the generator the
* new weights (we don't have to update the discriminator, as it is frozen in the GAN).
*/
private void updateGeneratorFromGan() {
for (int i = 0; i < generator.getLayers().length; i++) {
generator.getLayer(i).setParams(gan.getLayer(i).params());
}
}
/**
* After the discriminator has been trained, we update the respective parts of the GAN network
* as well.
*/
private void updateGanWithDiscriminator() {
int genLayerCount = generator.getLayers().length;
for (int i = genLayerCount; i < gan.getLayers().length; i++) {
gan.getLayer(i).setParams(discriminator.getLayer(i - genLayerCount).params());
}
}
/**
* GAN builder, used as a starting point for creating a MultiLayerConfiguration or
* ComputationGraphConfiguration.<br>
*/
public static class Builder implements Cloneable {
protected Supplier<MultiLayerNetwork> generator;
protected DiscriminatorProvider discriminator;
protected int latentDimension;
protected IUpdater iUpdater = new Sgd();
protected IUpdater biasUpdater = null;
protected long seed = System.currentTimeMillis();
protected OptimizationAlgorithm optimizationAlgo = OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT;
protected GradientNormalization gradientNormalization = GradientNormalization.None;
protected double gradientNormalizationThreshold = 1.0;
protected WorkspaceMode trainingWorkspaceMode = WorkspaceMode.ENABLED;
protected WorkspaceMode inferenceWorkspaceMode = WorkspaceMode.ENABLED;
protected CacheMode cacheMode = CacheMode.NONE;
public Builder() {
}
/**
* Set the (fake) image generator of the GAN.
*
* @param generator MultilayerNetwork
* @return Builder
*/
public GAN.Builder generator(Supplier<MultiLayerNetwork> generator) {
this.generator = generator;
return this;
}
/**
* Set the image discriminator of the GAN.
*
* @param discriminator MultilayerNetwork
* @return Builder
*/
public GAN.Builder discriminator(DiscriminatorProvider discriminator) {
this.discriminator = discriminator;
return this;
}
/**
* Set the latent dimension, i.e. the input vector space dimension of the generator.
*
* @param latentDimension latent space input dimension.
* @return Builder
*/
public GAN.Builder latentDimension(int latentDimension) {
this.latentDimension = latentDimension;
return this;
}
/**
* Random number generator seed. Used for reproducibility between runs
*/
public GAN.Builder seed(long seed) {
this.seed = seed;
Nd4j.getRandom().setSeed(seed);
return this;
}
/**
* Optimization algorithm to use. Most common: OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT
*
* @param optimizationAlgo Optimization algorithm to use when training
*/
public GAN.Builder optimizationAlgo(OptimizationAlgorithm optimizationAlgo) {
this.optimizationAlgo = optimizationAlgo;
return this;
}
/**
* Gradient updater configuration. For example, {@link org.nd4j.linalg.learning.config.Adam}
* or {@link org.nd4j.linalg.learning.config.Nesterovs}<br>
* Note: values set by this method will be applied to all applicable layers in the network, unless a different
* value is explicitly set on a given layer. In other words: values set via this method are used as the default
* value, and can be overridden on a per-layer basis.
*
* @param updater Updater to use
*/
public GAN.Builder updater(IUpdater updater) {
this.iUpdater = updater;
return this;
}
/**
* Gradient updater configuration, for the biases only. If not set, biases will use the updater as
* set by {@link #updater(IUpdater)}<br>
* Note: values set by this method will be applied to all applicable layers in the network, unless a different
* value is explicitly set on a given layer. In other words: values set via this method are used as the default
* value, and can be overridden on a per-layer basis.
*
* @param updater Updater to use for bias parameters
*/
public GAN.Builder biasUpdater(IUpdater updater) {
this.biasUpdater = updater;
return this;
}
/**
* Gradient normalization strategy. Used to specify gradient renormalization, gradient clipping etc.
* See {@link GradientNormalization} for details<br>
* Note: values set by this method will be applied to all applicable layers in the network, unless a different
* value is explicitly set on a given layer. In other words: values set via this method are used as the default
* value, and can be overridden on a per-layer basis.
*
* @param gradientNormalization Type of normalization to use. Defaults to None.
* @see GradientNormalization
*/
public GAN.Builder gradientNormalization(GradientNormalization gradientNormalization) {
this.gradientNormalization = gradientNormalization;
return this;
}
/**
* Threshold for gradient normalization, only used for GradientNormalization.ClipL2PerLayer,
* GradientNormalization.ClipL2PerParamType, and GradientNormalization.ClipElementWiseAbsoluteValue<br>
* Not used otherwise.<br>
* L2 threshold for first two types of clipping, or absolute value threshold for last type of clipping.<br>
* Note: values set by this method will be applied to all applicable layers in the network, unless a different
* value is explicitly set on a given layer. In other words: values set via this method are used as the default
* value, and can be overridden on a per-layer basis.
*/
public GAN.Builder gradientNormalizationThreshold(double threshold) {
this.gradientNormalizationThreshold = threshold;
return this;
}
public GAN build() {
return new GAN(this);
}
}
}

View File

@ -0,0 +1,73 @@
/*
*
* ******************************************************************************
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*
*/
package net.brutex.gan;
import org.nd4j.linalg.api.ndarray.INDArray;
import javax.swing.*;
import java.awt.*;
import java.awt.image.BufferedImage;
public class GANVisualizationUtils {
public static JFrame initFrame() {
JFrame frame = new JFrame();
frame.setTitle("Viz");
frame.setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE);
frame.setLayout(new BorderLayout());
return frame;
}
public static JPanel initPanel(JFrame frame, int numSamples) {
JPanel panel = new JPanel();
panel.setLayout(new GridLayout(numSamples / 3, 1, 8, 8));
frame.add(panel, BorderLayout.CENTER);
frame.setVisible(true);
return panel;
}
public static void visualize(INDArray[] samples, JFrame frame, JPanel panel) {
panel.removeAll();
for (int i = 0; i < samples.length; i++) {
panel.add(getImage(samples[i]));
}
frame.revalidate();
frame.pack();
}
private static JLabel getImage(INDArray tensor) {
BufferedImage bi = new BufferedImage(28, 28, BufferedImage.TYPE_BYTE_GRAY);
for (int i = 0; i < 784; i++) {
int pixel = (int) (((tensor.getDouble(i) + 1) * 2) * 255);
bi.getRaster().setSample(i % 28, i / 28, 0, pixel);
}
ImageIcon orig = new ImageIcon(bi);
Image imageScaled = orig.getImage().getScaledInstance((8 * 28), (8 * 28), Image.SCALE_REPLICATE);
ImageIcon scaled = new ImageIcon(imageScaled);
return new JLabel(scaled);
}
}

View File

@ -0,0 +1,193 @@
/*
*
* ******************************************************************************
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*
*/
package net.brutex.gan;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor;
import org.deeplearning4j.nn.conf.preprocessor.FeedForwardToCnnPreProcessor;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.PerformanceListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.RmsProp;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import javax.swing.*;
import java.awt.*;
import java.awt.image.BufferedImage;
import java.util.function.Supplier;
/**
* Training and visualizing a deep convolutional generative adversarial network (DCGAN) on handwritten digits.
*
* @author Max Pumperla, wmeddie
*/
public class MnistDCGANExample {
private static JFrame frame;
private static JPanel panel;
private static final int latentDim = 100;
private static final int height = 28;
private static final int width = 28;
private static final int channels = 1;
private static void visualize(INDArray[] samples) {
if (frame == null) {
frame = new JFrame();
frame.setTitle("Viz");
frame.setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE);
frame.setLayout(new BorderLayout());
panel = new JPanel();
panel.setLayout(new GridLayout(samples.length / 3, 1, 8, 8));
frame.add(panel, BorderLayout.CENTER);
frame.setVisible(true);
}
panel.removeAll();
for (int i = 0; i < samples.length; i++) {
panel.add(getImage(samples[i]));
}
frame.revalidate();
frame.pack();
}
private static JLabel getImage(INDArray tensor) {
BufferedImage bi = new BufferedImage(28, 28, BufferedImage.TYPE_BYTE_GRAY);
for (int i = 0; i < 784; i++) {
int pixel = (int) (((tensor.getDouble(i) + 1) * 2) * 255);
bi.getRaster().setSample(i % 28, i / 28, 0, pixel);
}
ImageIcon orig = new ImageIcon(bi);
Image imageScaled = orig.getImage().getScaledInstance((8 * 28), (8 * 28), Image.SCALE_REPLICATE);
ImageIcon scaled = new ImageIcon(imageScaled);
return new JLabel(scaled);
}
public static void main(String[] args) throws Exception {
Supplier<MultiLayerNetwork> genSupplier = () -> {
return new MultiLayerNetwork(new NeuralNetConfiguration.Builder().list()
.layer(0, new DenseLayer.Builder().nIn(latentDim).nOut(width / 2 * height / 2 * 128)
.activation(Activation.LEAKYRELU).weightInit(WeightInit.NORMAL).build())
.layer(1, new Convolution2D.Builder().nIn(128).nOut(128).kernelSize(5, 5)
.convolutionMode(ConvolutionMode.Same).activation(Activation.LEAKYRELU).build())
// Up-sampling to 28x28x256
.layer(2, new Deconvolution2D.Builder().nIn(128).nOut(128).stride(2, 2)
.kernelSize(5, 5).convolutionMode(ConvolutionMode.Same)
.activation(Activation.LEAKYRELU).build())
.layer(3, new Convolution2D.Builder().nIn(128).nOut(128).kernelSize(5, 5)
.convolutionMode(ConvolutionMode.Same).activation(Activation.LEAKYRELU).build())
.layer(4, new Convolution2D.Builder().nIn(128).nOut(128).kernelSize(5, 5)
.convolutionMode(ConvolutionMode.Same).activation(Activation.LEAKYRELU).build())
.layer(5, new Convolution2D.Builder().nIn(128).nOut(channels).kernelSize(7, 7)
.convolutionMode(ConvolutionMode.Same).activation(Activation.LEAKYRELU).build())
.layer(6, new ActivationLayer.Builder().activation(Activation.TANH).build())
.inputPreProcessor(1,
new FeedForwardToCnnPreProcessor(height / 2, width / 2, 128))
.inputPreProcessor(6, new CnnToFeedForwardPreProcessor(height, width, channels))
.setInputType(InputType.feedForward(latentDim))
.build());
};
GAN.DiscriminatorProvider discriminatorProvider = (updater) -> {
return new MultiLayerNetwork(new NeuralNetConfiguration.Builder()
.updater(new RmsProp.Builder().learningRate(0.0008).rmsDecay(1e-8).build())
//.gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
//.gradientNormalizationThreshold(100.0)
.list()
.layer(0, new Convolution2D.Builder().nIn(channels).nOut(64).kernelSize(3, 3)
.activation(Activation.LEAKYRELU).build())
.layer(1, new Convolution2D.Builder().nIn(64).nOut(64).kernelSize(3, 3).stride(2, 2)
.activation(Activation.LEAKYRELU).build())
.layer(2, new Convolution2D.Builder().nIn(64).nOut(64).kernelSize(3, 3).stride(2, 2)
.activation(Activation.LEAKYRELU).build())
.layer(3, new Convolution2D.Builder().nIn(64).nOut(64).kernelSize(3, 3).stride(2, 2)
.activation(Activation.LEAKYRELU).build())
.layer(4, new DropoutLayer.Builder().dropOut(0.5).build())
.layer(5, new DenseLayer.Builder().nIn(64 * 2 * 2).nOut(1).activation(Activation.SIGMOID).build())
.layer(6, new LossLayer.Builder().lossFunction(LossFunctions.LossFunction.XENT).build())
.inputPreProcessor(0, new FeedForwardToCnnPreProcessor(height, width, channels))
.inputPreProcessor(4, new CnnToFeedForwardPreProcessor(2, 2, 64))
.setInputType(InputType.convolutionalFlat(height, width, channels))
.build());
};
GAN gan = new GAN.Builder()
.generator(genSupplier)
.discriminator(discriminatorProvider)
.latentDimension(latentDim)
//.gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
//.gradientNormalizationThreshold(1.0)
.updater(new RmsProp.Builder().learningRate(0.0008).rmsDecay(1e-8).build())
.build();
gan.getGenerator().setListeners(new PerformanceListener(1, true));
gan.getDiscriminator().setListeners(new PerformanceListener(1, true));
Nd4j.getMemoryManager().setAutoGcWindow(15 * 1000);
int batchSize = 64;
MnistDataSetIterator trainData = new MnistDataSetIterator(batchSize, true, 42);
for (int i = 0; i < 10; i++) {
//gan.fit(trainData, 1);
System.out.println("Starting epoch: " + (i + 1));
trainData.reset();
int j = 0;
while (trainData.hasNext()) {
DataSet next = trainData.next();
gan.fit(next);
if (j % 1 == 0) {
System.out.println("Epoch " + (i + 1) + " iteration " + j + " Visualizing...");
INDArray fakeIn = Nd4j.rand(new int[]{batchSize, latentDim});
INDArray[] samples = new INDArray[9];
for (int k = 0; k < 9; k++) {
samples[k] = gan.getGenerator().output(fakeIn.getRow(k), false);
}
visualize(samples);
}
j++;
}
System.out.println("Finished epoch: " + (i + 1));
}
}
}

View File

@ -0,0 +1,146 @@
/*
*
* ******************************************************************************
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*
*/
package net.brutex.gan;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.ActivationLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.DropoutLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.impl.ActivationLReLU;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import javax.swing.*;
/**
* Relatively small GAN example using only Dense layers with dropout to generate handwritten
* digits from MNIST data.
*/
public class MnistSimpleGAN {
private static final int LATENT_DIM = 100;
private static final double LEARNING_RATE = 0.0002;
private static final IUpdater UPDATER_ZERO = Sgd.builder().learningRate(0.0).build();
private static final IUpdater UPDATER = Adam.builder().learningRate(LEARNING_RATE).beta1(0.5).build();
public static MultiLayerNetwork getGenerator() {
MultiLayerConfiguration genConf = new NeuralNetConfiguration.Builder()
.weightInit(WeightInit.XAVIER)
.activation(Activation.IDENTITY)
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
.gradientNormalizationThreshold(100)
.list()
.layer(new DenseLayer.Builder().nIn(100).nOut(256).weightInit(WeightInit.NORMAL).build())
.layer(new ActivationLayer.Builder(new ActivationLReLU(0.2)).build())
.layer(new DenseLayer.Builder().nIn(256).nOut(512).build())
.layer(new ActivationLayer.Builder(new ActivationLReLU(0.2)).build())
.layer(new DenseLayer.Builder().nIn(512).nOut(1024).build())
.layer(new ActivationLayer.Builder(new ActivationLReLU(0.2)).build())
.layer(new DenseLayer.Builder().nIn(1024).nOut(784).activation(Activation.TANH).build())
.build();
return new MultiLayerNetwork(genConf);
}
public static MultiLayerNetwork getDiscriminator(IUpdater updater) {
MultiLayerConfiguration discConf = new NeuralNetConfiguration.Builder()
.seed(42)
.updater(updater)
.weightInit(WeightInit.XAVIER)
.activation(Activation.IDENTITY)
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
.gradientNormalizationThreshold(100)
.list()
.layer(new DenseLayer.Builder().nIn(784).nOut(1024).updater(updater).build())
.layer(new ActivationLayer.Builder(new ActivationLReLU(0.2)).build())
.layer(new DropoutLayer.Builder(1 - 0.5).build())
.layer(new DenseLayer.Builder().nIn(1024).nOut(512).updater(updater).build())
.layer(new ActivationLayer.Builder(new ActivationLReLU(0.2)).build())
.layer(new DropoutLayer.Builder(1 - 0.5).build())
.layer(new DenseLayer.Builder().nIn(512).nOut(256).updater(updater).build())
.layer(new ActivationLayer.Builder(new ActivationLReLU(0.2)).build())
.layer(new DropoutLayer.Builder(1 - 0.5).build())
.layer(new OutputLayer.Builder(LossFunctions.LossFunction.XENT).nIn(256).nOut(1)
.activation(Activation.SIGMOID).updater(updater).build())
.build();
return new MultiLayerNetwork(discConf);
}
public static void main(String[] args) throws Exception {
GAN gan = new GAN.Builder()
.generator(MnistSimpleGAN::getGenerator)
.discriminator(MnistSimpleGAN::getDiscriminator)
.latentDimension(LATENT_DIM)
.seed(42)
.updater(UPDATER)
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
.gradientNormalizationThreshold(100)
.build();
Nd4j.getMemoryManager().setAutoGcWindow(15 * 1000);
int batchSize = 128;
MnistDataSetIterator trainData = new MnistDataSetIterator(batchSize, true, 42);
// Sample from latent space once to visualize progress on image generation.
int numSamples = 9;
JFrame frame = GANVisualizationUtils.initFrame();
JPanel panel = GANVisualizationUtils.initPanel(frame, numSamples);
for (int i = 0; i < 100; i++) {
trainData.reset();
int j = 0;
while (trainData.hasNext()) {
gan.fit(trainData.next());
//gan.fit(trainData, 1);
if (j % 10 == 0) {
INDArray fakeIn = Nd4j.rand(new int[]{batchSize, LATENT_DIM});
System.out.println("Epoch " + (i + 1) + " Iteration " + j + " Visualizing...");
INDArray[] samples = new INDArray[numSamples];
for (int k = 0; k < numSamples; k++) {
INDArray input = fakeIn.getRow(k);
samples[k] = gan.getGenerator().output(input, false);
}
GANVisualizationUtils.visualize(samples, frame, panel);
}
j++;
}
}
}
}

View File

@ -20,39 +20,87 @@
package net.brutex.spark; package net.brutex.spark;
import java.io.File;
import java.io.IOException;
import java.net.URI;
import java.util.EnumSet;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FileUtils;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.CreateFlag;
import org.apache.hadoop.fs.FSDataOutputStream;
import org.apache.hadoop.fs.FileContext;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.FileUtil;
import org.apache.hadoop.fs.Options.CreateOpts;
import org.apache.spark.SparkConf; import org.apache.spark.SparkConf;
import org.apache.spark.SparkContext;
import org.apache.spark.sql.SparkSession; import org.apache.spark.sql.SparkSession;
import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeAll;
import java.io.Serializable; import java.io.Serializable;
import org.junit.jupiter.api.Test;
@Slf4j @Slf4j
public abstract class BaseSparkSessionTest implements Serializable { public abstract class BaseSparkSessionTest implements Serializable {
private static SparkSession spark; private static SparkSession spark;
public static SparkSession getSession() { public static SparkSession getSession() {
final String jarPath = uploadToHdfs("./build/libs/brutex-extended-tests-1.0.0-SNAPSHOT-all.jar");
SparkConf sparkConf = new SparkConf() SparkConf sparkConf = new SparkConf()
.setMaster("spark://10.5.5.200:7077") .setMaster("spark://10.5.5.200:7077")
.setAppName(BaseSparkSessionTest.class.getSimpleName()) .setAppName(BaseSparkSessionTest.class.getSimpleName())
.set("spark.driver.bindAddress", "10.5.5.145") .set("spark.driver.bindAddress", "10.5.5.145")
.set("spark.blockManager.port", "65001")
//.set("spark.driver.bindAddress", "0.0.0.0")
.set("spark.network.timeout", "240000") .set("spark.network.timeout", "240000")
.set("spark.driver.host", "10.5.5.145") .set("spark.driver.host", "10.5.5.145")
.set("spark.deploy.mode", "client") .set("spark.deploy.mode", "cluster")
.set("spark.executor.memory", "4g") .set("spark.executor.memory", "4g")
.set("spark.cores.max", "4") .set("spark.cores.max", "4")
.set("spark.worker.cleanup.enabled", "true") .set("spark.worker.cleanup.enabled", "true")
.set("spark.driver.extraJavaOptions", "-Dlog4j.configurationFile=log4j2.xml") .set("spark.driver.extraJavaOptions", "-Dlog4j.configurationFile=log4j2.xml")
.set("spark.executor.extraJavaOptions", "-Dlog4j.configurationFile=log4j2.xml") .set("spark.executor.extraJavaOptions", "-Dlog4j.configurationFile=log4j2.xml")
.set("spark.hadoop.fs.defaultFS", "hdfs://10.5.5.200:9000"); .set("spark.hadoop.fs.defaultFS", "hdfs://10.5.5.200:9000")
//.set("spark.jars", jarPath)
;
spark = SparkSession.builder() spark = SparkSession.builder()
.config(sparkConf) .config(sparkConf)
.getOrCreate(); .getOrCreate();
spark.sparkContext().addJar(jarPath);
return spark; return spark;
} }
public static String uploadToHdfs(String jarFile) {
File f = new File(jarFile);
if(!f.exists() && !f.isFile()) throw new RuntimeException("File to upload does not exist.");
final String base = "hdfs://10.5.5.200:9000/";
String targetPath = "/user/brian/" + f.getName();
try {
Configuration conf = new Configuration();
//FileContext hdfs = FileContext.getFileContext(URI.create(base), conf);
org.apache.hadoop.fs.FileSystem hdfs = FileSystem.get(URI.create(base), conf);
//String file = SparkFiles.get("phpMawTba");
org.apache.hadoop.fs.Path target = new org.apache.hadoop.fs.Path(targetPath);
try {
hdfs.delete(target, false);
} catch (Exception e) {};
FileUtil.copy(f, hdfs, target, false, conf);
//Apache Commons
//FileUtils.copyFile(f, fTarget);
} catch(IOException ioe) {
ioe.printStackTrace();
}
return base + targetPath;
}
@BeforeAll @BeforeAll
public static void beforeAll() { public static void beforeAll() {
@ -64,4 +112,11 @@ public abstract class BaseSparkSessionTest implements Serializable {
getSession().close(); getSession().close();
} }
@Test
public void testSessionCreation() {
SparkSession session = getSession();
log.info("Spark {} session id: {}", session.version(), session.sessionUUID());
}
} }

View File

@ -20,22 +20,34 @@
*/ */
package net.brutex.spark; package net.brutex.spark;
import com.fasterxml.jackson.core.Version; import java.io.IOException;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path; import org.apache.hadoop.fs.Path;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.ForeachFunction;
import org.apache.spark.api.java.function.Function; import org.apache.spark.api.java.function.Function;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StringType;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader; import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.transform.TransformProcess; import org.datavec.api.transform.TransformProcess;
import org.datavec.api.transform.filter.FilterInvalidValues; import org.datavec.api.transform.filter.FilterInvalidValues;
import org.datavec.api.transform.schema.Schema; import org.datavec.api.transform.schema.Schema;
import org.datavec.api.Writable; import org.datavec.api.Writable;
import org.datavec.spark.transform.Normalization;
import org.datavec.spark.transform.SparkTransformExecutor; import org.datavec.spark.transform.SparkTransformExecutor;
import org.datavec.spark.transform.misc.StringToWritablesFunction; import org.datavec.spark.transform.misc.StringToWritablesFunction;
import org.deeplearning4j.datasets.iterator.impl.EmnistDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.EmnistDataSetIterator.Set;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.DenseLayer;
@ -47,7 +59,6 @@ import org.deeplearning4j.spark.api.TrainingMaster;
import org.deeplearning4j.spark.datavec.DataVecDataSetFunction; import org.deeplearning4j.spark.datavec.DataVecDataSetFunction;
import org.deeplearning4j.spark.impl.multilayer.SparkDl4jMultiLayer; import org.deeplearning4j.spark.impl.multilayer.SparkDl4jMultiLayer;
import org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster; import org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster;
import org.deeplearning4j.ui.api.UIServer;
import org.junit.jupiter.api.*; import org.junit.jupiter.api.*;
import org.nd4j.evaluation.classification.Evaluation; import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
@ -56,7 +67,6 @@ import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.learning.config.Nesterovs; import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions; import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.nio.file.Paths;
import java.util.Arrays; import java.util.Arrays;
import java.util.Iterator; import java.util.Iterator;
import java.util.List; import java.util.List;
@ -70,24 +80,77 @@ import java.util.Random;
@Slf4j @Slf4j
@TestInstance(TestInstance.Lifecycle.PER_CLASS) @TestInstance(TestInstance.Lifecycle.PER_CLASS)
@Tag("integration") @Tag("integration")
public class BrianTest /*extends BaseDL4JTest*/ { public class BrianTest extends BaseSparkSessionTest {
/*
static { static {
String OS = System.getProperty("os.name").toLowerCase(); String OS = System.getProperty("os.name").toLowerCase();
if (OS.contains("win")) { if (OS.contains("win")) {
System.setProperty("hadoop.home.dir", Paths.get("c:\\java\\winutils").toAbsolutePath().toString()); System.setProperty("hadoop.home.dir",
Paths.get("c:\\java\\winutils").toAbsolutePath().toString());
} else { } else {
System.setProperty("hadoop.home.dir", "/"); System.setProperty("hadoop.home.dir", "/");
} }
} }
*/
public long getTimeoutMilliseconds() {
return 400000L;
}
private JavaSparkContext sc; private JavaSparkContext sc;
private JavaRDD<String> rdd; private JavaRDD<String> rdd;
@Test
public void wrapEmnitDataset() throws IOException, InterruptedException {
SparkSession sc = getSession();
EmnistDataSetIterator dataset = new EmnistDataSetIterator(Set.BALANCED, 128, true);
DataSet ds = dataset.next();
System.out.println( "Number of features " + ds.numInputs());
System.out.println( "Number of samples " + ds.numExamples());
System.out.println( "Outcomes " + ds.numOutcomes());
final String oppsFile = uploadToHdfs("c:/temp/opps.csv");
//System.out.println( "Reading file from " + oppsFile);
JavaRDD<String> rdd = sc.sparkContext().textFile(oppsFile, 1)
.toJavaRDD();
System.out.println("Count " + rdd.count());
//while(true) Thread.sleep(1000);
//rdd.foreach( s -> {
// System.out.println("* "+s);
// });
//JavaRDD<String> rdd2 = rdd.flatMap( s -> Arrays.asList( s.split(";")).iterator() );
//rdd2.collect().forEach( a -> System.out.print("# " + a + " ") );
StructType struct = new StructType(Arrays.asList(
StructField.apply("stage", DataTypes.StringType, false, Metadata.empty()),
StructField.apply("period", DataTypes.StringType, false, Metadata.empty()),
StructField.apply("portfolio", DataTypes.StringType, false, Metadata.empty()),
StructField.apply("country", DataTypes.StringType, false, Metadata.empty()),
StructField.apply("lfr", DataTypes.StringType, false, Metadata.empty()),
StructField.apply("saas", DataTypes.StringType, false, Metadata.empty())
).toArray(new StructField[]{})
);
JavaRDD<Row> rdd3 = rdd.map( attributes -> RowFactory.create(attributes.split(";")));
Dataset<Row> frame = sc.createDataFrame(rdd3, struct);
Dataset<Row> frame2 = frame.select(frame.col("lfr").cast(DataTypes.FloatType));
frame.show(200);
// frame.collect().map(row -> System.out.println(row.fieldIndex("stage") + row.fieldIndex("country")));
//frame.agg( frame.col("stage"), frame.col("lfr"));
frame.foreach((ForeachFunction<Row>) s -> System.out.println(s));
//sc.read().csv(rdd2);
//Normalization normalization = Normalization.zeromeanUnitVariance()
//sc.
}
/* /*
@BeforeAll @BeforeAll
public void loadData() { public void loadData() {
@ -109,71 +172,6 @@ public class BrianTest /*extends BaseDL4JTest*/ {
} }
*/ */
@BeforeAll
public void setUp() throws Exception {
log.info("Running @BeforeEach scope");
System.setProperty("hadoop.home.dir", Paths.get("c:\\java\\winutils").toAbsolutePath().toString());
Version version = com.fasterxml.jackson.databind.cfg.PackageVersion.VERSION;
System.out.println("Jackson version found: " + version);
SparkConf sparkConf = new SparkConf()
.setMaster("spark://10.5.5.200:7077")
.setAppName("Brian3")
.set("spark.driver.bindAddress", "10.5.5.145")
.set("spark.network.timeout", "240000")
.set("spark.driver.host", "10.5.5.145")
.set("spark.driver.bindAddress", "10.5.5.145")
.set("spark.deploy.mode", "cluster")
.set("spark.executor.memory", "2g")
.set("spark.executor.cores", "2")
.set("spark.cores.max", "4")
.set("spark.worker.cleanup.enabled", "false")
.set("spark.driver.extraJavaOptions", "-Dlog4j.configurationFile=log4j2.xml")
.set("spark.executor.extraJavaOptions", "-Dlog4j.configurationFile=log4j2.xml")
.set("spark.driver.extraClassPath", "brian-spark-dl4j-tests-1.0.1-SNAPSHOT-bin.jar;brian-spark-dl4j-tests-1.0.1-SNAPSHOT-tests.jar")
.set("spark.executor.extraClassPath", "brian-spark-dl4j-tests-1.0.1-SNAPSHOT-bin.jar;brian-spark-dl4j-tests-1.0.1-SNAPSHOT-tests.jar")
.set("spark.hadoop.fs.defaultFS", "hdfs://10.5.5.200:9000");
//.set("spark.driver.cores", "2")
//.set("spark.driver.memory", "8g")
//.set("spark.driver.host", "10.5.5.145")
//.setExecutorEnv("spark.executor.cores", "2")
//.setExecutorEnv("spark.executor.memory", "2g")
//.set("spark.submit.deployMode", "client")
/*
SparkSession spark = SparkSession
.builder()
.master("spark://10.5.5.200:7077")
.config("spark.driver.bindAddress", "10.5.5.145")
.config("spark.driver.host", "10.5.5.145")
//.config("spark.driver.memory", "5g")
.appName("BrianTest2")
.getOrCreate();
*/
sc = new JavaSparkContext(sparkConf);
// sc.addJar("C:\\Users\\brian\\_projects\\deeplearning4j\\deeplearning4j\\deeplearning4j-scaleout\\spark\\dl4j-spark-nlp-java8\\target\\dl4j-spark-nlp-java8_2.12-1.0.0-SNAPSHOT-tests.jar");
// sc.addJar("C:\\Users\\brian\\_projects\\deeplearning4j\\datavec\\datavec-api\\target\\datavec-api-1.0.0-SNAPSHOT.jar");
// sc.addJar("C:\\Users\\brian\\_projects\\deeplearning4j\\nd4j\\nd4j-uberjar\\target\\nd4j-uberjar-1.0.0-SNAPSHOT.jar");
// sc.addJar("C:\\Users\\brian\\_projects\\deeplearning4j\\nd4j\\nd4j-common\\target\\nd4j-common-1.0.0-SNAPSHOT.jar");
// sc.addJar("C:\\Users\\brian\\_projects\\deeplearning4j\\datavec\\datavec-spark\\target\\datavec-spark_2.12-1.0.0-SNAPSHOT.jar");
sc.addJar("C:\\Users\\brian\\_projects\\Brian-Spark-DL4J-Tests\\target\\brian-spark-dl4j-tests-1.0.1-SNAPSHOT-bin.jar");
sc.addJar("C:\\Users\\brian\\_projects\\Brian-Spark-DL4J-Tests\\target\\brian-spark-dl4j-tests-1.0.1-SNAPSHOT-tests.jar");
rdd = sc.textFile("hdfs://10.5.5.200:9000/user/zeppelin/cities_full.csv.gz");
}
@AfterAll
public void tearDown() throws Exception {
sc.close();
sc.stop();
UIServer.stopInstance();
}
@Test @Test
////@Ignore("AB 2019/05/21 - Failing - Issue #7657") ////@Ignore("AB 2019/05/21 - Failing - Issue #7657")
@ -193,7 +191,6 @@ public class BrianTest /*extends BaseDL4JTest*/ {
@Test @Test
public void testSchemaCreation() throws Exception { public void testSchemaCreation() throws Exception {
rdd.cache(); rdd.cache();
JavaRDD<String> cities = rdd.map((Function<String, String>) line -> { JavaRDD<String> cities = rdd.map((Function<String, String>) line -> {
@ -208,7 +205,6 @@ public class BrianTest /*extends BaseDL4JTest*/ {
return line.split(",")[3]; return line.split(",")[3];
}).cache(); }).cache();
CSVRecordReader recordReader = new CSVRecordReader(0, ','); CSVRecordReader recordReader = new CSVRecordReader(0, ',');
JavaRDD<List<Writable>> convertedRDD = rdd.map((Function<String, List<Writable>>) s -> { JavaRDD<List<Writable>> convertedRDD = rdd.map((Function<String, List<Writable>>) s -> {
return new StringToWritablesFunction(recordReader).call(s); return new StringToWritablesFunction(recordReader).call(s);
@ -252,16 +248,18 @@ public class BrianTest /*extends BaseDL4JTest*/ {
processedData.cache(); processedData.cache();
//log.info("Datenmenge nach processing: " + processedData.count()); //log.info("Datenmenge nach processing: " + processedData.count());
//Vectorisieren //Vectorisieren
int labelIndex = 0; //in welcher Spalte ist das Label int labelIndex = 0; //in welcher Spalte ist das Label
int numLabels = 4; //Anzahl der Klassen 0-236 = 237 Werte int numLabels = 4; //Anzahl der Klassen 0-236 = 237 Werte
DataVecDataSetFunction datavecFunction = new DataVecDataSetFunction(labelIndex, numLabels, false); DataVecDataSetFunction datavecFunction = new DataVecDataSetFunction(labelIndex, numLabels,
false);
JavaRDD<DataSet> rddDataSet = processedData.map(datavecFunction); JavaRDD<DataSet> rddDataSet = processedData.map(datavecFunction);
log.info("rddDataset: " + rddDataSet.toDebugString()); log.info("rddDataset: " + rddDataSet.toDebugString());
Random rand = new Random(); Random rand = new Random();
rddDataSet.sortBy( (Function<DataSet, Double>) s -> {return rand.nextDouble(); }, true, 8); rddDataSet.sortBy((Function<DataSet, Double>) s -> {
return rand.nextDouble();
}, true, 8);
//og.info("Sample: " + rddDataSet.sample(false, 0.005, 0).collect()); //og.info("Sample: " + rddDataSet.sample(false, 0.005, 0).collect());
@ -281,7 +279,8 @@ public class BrianTest /*extends BaseDL4JTest*/ {
//Create Trainingmaster //Create Trainingmaster
TrainingMaster trainingMaster = new ParameterAveragingTrainingMaster.Builder(4) TrainingMaster trainingMaster = new ParameterAveragingTrainingMaster.Builder(4)
.rddTrainingApproach(RDDTrainingApproach.Direct) //when "export", tries to save everything first .rddTrainingApproach(
RDDTrainingApproach.Direct) //when "export", tries to save everything first
.batchSizePerWorker(1000) .batchSizePerWorker(1000)
.collectTrainingStats(true) .collectTrainingStats(true)
.build(); .build();
@ -292,15 +291,18 @@ public class BrianTest /*extends BaseDL4JTest*/ {
.seed(123) .seed(123)
.updater(new Nesterovs(0.1, 0.9)) .updater(new Nesterovs(0.1, 0.9))
.list() .list()
.layer(0, new DenseLayer.Builder().nIn(5).nOut(20).weightInit(WeightInit.XAVIER).activation(Activation.RELU).l2(0.001).build()) .layer(0, new DenseLayer.Builder().nIn(5).nOut(20).weightInit(WeightInit.XAVIER)
.layer(1, new DenseLayer.Builder().nIn(20).nOut(20).weightInit(WeightInit.XAVIER).activation(Activation.RELU).build()) .activation(Activation.RELU).l2(0.001).build())
.layer(1, new DenseLayer.Builder().nIn(20).nOut(20).weightInit(WeightInit.XAVIER)
.activation(Activation.RELU).build())
//.layer(2, new DenseLayer.Builder().nIn(9).nOut(9).weightInit(WeightInit.XAVIER).activation(Activation.RELU).build()) //.layer(2, new DenseLayer.Builder().nIn(9).nOut(9).weightInit(WeightInit.XAVIER).activation(Activation.RELU).build())
.layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.XENT).nIn(20).nOut(4).weightInit(WeightInit.XAVIER).activation(Activation.SIGMOID).build()) .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.XENT).nIn(20).nOut(4)
.weightInit(WeightInit.XAVIER).activation(Activation.SIGMOID).build())
.build(); .build();
//Define SparkNet //Define SparkNet
SparkDl4jMultiLayer sparkNet = new SparkDl4jMultiLayer(sc, multiLayerConfiguration, trainingMaster); SparkDl4jMultiLayer sparkNet = new SparkDl4jMultiLayer(sc, multiLayerConfiguration,
trainingMaster);
JavaRDD<DataSet>[] split = rddDataSet.randomSplit(new double[]{0.9, 0.1}, 123); JavaRDD<DataSet>[] split = rddDataSet.randomSplit(new double[]{0.9, 0.1}, 123);
//JavaRDD<DataSet> trainingData = split[0]; //JavaRDD<DataSet> trainingData = split[0];

View File

@ -25,8 +25,8 @@ ext {
def flatbuffers = [version: "1.10.0"] def flatbuffers = [version: "1.10.0"]
def spark = [version: "3.1.2"] def spark = [version: "3.2.2"]
def scala = [version:"2.12.10"] //[version:"2.13.5"] def scala = [version:"2.12.15"] //[version:"2.13.5"]
def netty = [version: "4.1.68.Final"] def netty = [version: "4.1.68.Final"]

View File

@ -21,59 +21,44 @@
package net.brutex.cavis.dvec.api; package net.brutex.cavis.dvec.api;
import java.io.Serializable;
import java.nio.Buffer; import java.nio.Buffer;
import java.nio.LongBuffer; import java.nio.ByteBuffer;
import java.util.List;
import net.brutex.cavis.dvec.api.exceptions.DVecException; import net.brutex.cavis.dvec.api.exceptions.DVecException;
/** /**
* A Field can be considered a "column" in a {@code Record}, as such a Field may refer to multiple * Abtract implementation of the Field interface {@see FieldInterface}, that handles all data storage
* entries of that "column". Fields are typed as Buffers. Some of them defined in the dvec core api, * in memory and adds basic error handling.
* other (i.e. Image or Arrow) require dvec extensions accordingly.
* *
* @author Brian Rosenberger * @author Brian Rosenberger
* @since 1.0 * @since 1.0
*/ */
public interface FieldInterface<T extends Buffer> extends Serializable { public abstract class AbstractField<T extends Buffer> implements Field<T> {
/** /**
* Get a reference to the metadata for this Field. * {@inheritDoc}
*
* @return the {@link FieldMetadata}
*/
FieldMetadata getFieldMetadata();
/**
* Get the 1st field as Buffer. This deserializes the data from the underlying storage.
*
* @return T underlying Buffer
*/
default T read() throws DVecException {
return read(0, 1);
}
/**
* Get a range of fields as a {@code Buffer}
* *
* @param start Index of starting position, zero based * @param start Index of starting position, zero based
* @param length how many fields to read * @param length how many fields to read
* @return the buffers * @return the list of Buffer
*/ */
T read(long start, long length) throws DVecException; @Override
public T read(long start, long length) throws DVecException {
/** if (start<0 || start>internalStorage.capacity()-1 ) {
* Write the data into the underlying storage. throw new DVecException("Read on Field start position is out of bounds.");
*/ }
default void write(T buffer) { if (start+length> internalStorage.capacity()) {
write(0, buffer); throw new DVecException("Read on Field exceeds field length");
}
return null;
} }
/** @Override
* Write the data into the underyling storage starting at a position public void write(long pos, T buffer) {
*
* @param pos the position to start }
*/
void write(long pos, T buffer); private ByteBuffer internalStorage = null;
} }

View File

@ -21,46 +21,57 @@
package net.brutex.cavis.dvec.api; package net.brutex.cavis.dvec.api;
import java.io.Serializable;
import java.nio.Buffer; import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
import net.brutex.cavis.dvec.api.exceptions.DVecException; import net.brutex.cavis.dvec.api.exceptions.DVecException;
/** /**
* Abtract implementation of the Field interface {@see FieldInterface}, that handles all data storage * A Field can be considered a "column" in a {@code Record}, as such a Field may refer to multiple
* in memory and adds basic error handling. * entries of that "column". Fields are typed as Buffers. Some of them defined in the dvec core api,
* other (i.e. Image or Arrow) require dvec extensions accordingly.
* *
* @author Brian Rosenberger * @author Brian Rosenberger
* @since 1.0 * @since 1.0
*/ */
public abstract class Field<T extends Buffer> implements FieldInterface<T> { public interface Field<T extends Buffer> extends Serializable {
/** /**
* {@inheritDoc} * Get a reference to the metadata for this Field.
*
* @return the {@link FieldMetadata}
*/
FieldMetadata getFieldMetadata();
/**
* Get the 1st field as Buffer. This deserializes the data from the underlying storage.
*
* @return T underlying Buffer
*/
default T read() throws DVecException {
return read(0, 1);
}
/**
* Get a range of fields as a {@code Buffer}
* *
* @param start Index of starting position, zero based * @param start Index of starting position, zero based
* @param length how many fields to read * @param length how many fields to read
* @return the list of Buffer * @return the buffers
*/ */
@Override T read(long start, long length) throws DVecException;
public T read(long start, long length) throws DVecException {
if (start<0 || start>internalStorage.capacity()-1 ) { /**
throw new DVecException("Read on Field start position is out of bounds."); * Write the data into the underlying storage.
} */
if (start+length> internalStorage.capacity()) { default void write(T buffer) {
throw new DVecException("Read on Field exceeds field length"); write(0, buffer);
}
return null;
} }
@Override /**
public void write(long pos, T buffer) { * Write the data into the underyling storage starting at a position
*
} * @param pos the position to start
*/
private ByteBuffer internalStorage = null; void write(long pos, T buffer);
} }

View File

@ -4877,7 +4877,7 @@ public class Nd4j {
* Create an ndarray of zeros * Create an ndarray of zeros
* *
* @param shape the shape of the array * @param shape the shape of the array
* @return an ndarray with ones filled in * @return an ndarray with zeros filled in
*/ */
public static INDArray zeros(int[] shape, char order) { public static INDArray zeros(int[] shape, char order) {
checkShapeValues(shape); checkShapeValues(shape);
@ -4896,7 +4896,7 @@ public class Nd4j {
* Create an ndarray of zeros * Create an ndarray of zeros
* *
* @param shape the shape of the array * @param shape the shape of the array
* @return an ndarray with ones filled in * @return an ndarray with zeros filled in
*/ */
public static INDArray zeros(@NonNull int... shape) { public static INDArray zeros(@NonNull int... shape) {
return Nd4j.create(shape); return Nd4j.create(shape);
@ -4907,7 +4907,7 @@ public class Nd4j {
* Create an ndarray of zeros * Create an ndarray of zeros
* *
* @param shape the shape of the array * @param shape the shape of the array
* @return an ndarray with ones filled in * @return an ndarray with zeros filled in
*/ */
public static INDArray zeros(@NonNull long... shape) { public static INDArray zeros(@NonNull long... shape) {
return Nd4j.create(shape); return Nd4j.create(shape);

View File

@ -99,9 +99,14 @@ public class DL4JClassLoading {
.asSubclass(superclass) .asSubclass(superclass)
.getDeclaredConstructor(parameterTypes) .getDeclaredConstructor(parameterTypes)
.newInstance(args); .newInstance(args);
} catch (InstantiationException | IllegalAccessException | InvocationTargetException } catch (InstantiationException | IllegalAccessException
| NoSuchMethodException instantiationException) { | NoSuchMethodException instantiationException) {
log.error(String.format("Cannot create instance of class '%s'.", className), instantiationException); log.error(String.format("Cannot create instance of class '%s'.", className), instantiationException);
throw new RuntimeException(instantiationException);
} catch (InvocationTargetException instantiationException) {
log.error(String.format("---------- ----------- ---------- \nInvocationTargetException was '%s'.", instantiationException.getTargetException().getMessage()), instantiationException);
log.error(String.format("java.library.path was '%s'\n---------- ---------- ----------", System.getProperty("java.library.path")));
throw new RuntimeException(instantiationException); throw new RuntimeException(instantiationException);
} }
} }

View File

@ -0,0 +1,23 @@
apply from: "${project.rootProject.projectDir}/createTestBackends.gradle"
ext {
buildTarget = rootProject.ext.buildTarget
}
dependencies {
implementation platform(projects.cavisCommonPlatform)
implementation project(":cavis-native:cavis-native-jcublas")
implementation projects.cavisDnn.cavisDnnApi
implementation projects.cavisDnn.cavisDnnNn
implementation group: "org.bytedeco", name: "cuda"
implementation group: "org.bytedeco", name: "cuda", classifier: buildTarget
implementation group: "org.bytedeco", name: "cuda", classifier: "${buildTarget}-redist"
implementation group: "org.bytedeco", name: "javacpp"
implementation group: "org.bytedeco", name: "javacpp", classifier: buildTarget
implementation 'com.jakewharton.byteunits:byteunits:0.9.1'
}

View File

@ -0,0 +1,252 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.cuda;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.bytedeco.javacpp.*;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.factory.Nd4j;
import org.bytedeco.cuda.cudnn.*;
import static org.bytedeco.cuda.global.cudart.*;
import static org.bytedeco.cuda.global.cudnn.*;
/**
* Functionality shared by all cuDNN-based helpers.
*
* @author saudet
*/
@Slf4j
public abstract class BaseCudnnHelper {
/* public BaseCudnnHelper() {
}
*/
protected static void checkCuda(int error) {
if (error != cudaSuccess) {
throw new RuntimeException("CUDA error = " + error + ": " + cudaGetErrorString(error).getString());
}
}
protected static void checkCudnn(int status) {
if (status != CUDNN_STATUS_SUCCESS) {
throw new RuntimeException("cuDNN status = " + status + ": " + cudnnGetErrorString(status).getString());
}
}
protected static class CudnnContext extends cudnnContext {
protected static class Deallocator extends CudnnContext implements Pointer.Deallocator {
Deallocator(CudnnContext c) {
super(c);
}
@Override
public void deallocate() {
destroyHandles();
}
}
public CudnnContext() {
// insure that cuDNN initializes on the same device as ND4J for this thread
Nd4j.create(1);
AtomicAllocator.getInstance();
// This needs to be called in subclasses:
// createHandles();
// deallocator(new Deallocator(this));
}
public CudnnContext(CudnnContext c) {
super(c);
}
protected void createHandles() {
checkCudnn(cudnnCreate(this));
}
protected void destroyHandles() {
checkCudnn(cudnnDestroy(this));
}
}
protected static class DataCache extends Pointer {
static class Deallocator extends DataCache implements Pointer.Deallocator {
Deallocator(DataCache c) {
super(c);
}
@Override
public void deallocate() {
checkCuda(cudaFree(this));
setNull();
}
}
static class HostDeallocator extends DataCache implements Pointer.Deallocator {
HostDeallocator(DataCache c) {
super(c);
}
@Override
public void deallocate() {
checkCuda(cudaFreeHost(this));
setNull();
}
}
public DataCache() {}
public DataCache(long size) {
position = 0;
limit = capacity = size;
int error = cudaMalloc(this, size);
if (error != cudaSuccess) {
log.warn("Cannot allocate " + size + " bytes of device memory (CUDA error = " + error
+ "), proceeding with host memory");
checkCuda(cudaMallocHost(this, size));
deallocator(new HostDeallocator(this));
} else {
deallocator(new Deallocator(this));
}
}
public DataCache(DataCache c) {
super(c);
}
}
protected static class TensorArray extends PointerPointer<cudnnTensorStruct> {
static class Deallocator extends TensorArray implements Pointer.Deallocator {
Pointer owner;
Deallocator(TensorArray a, Pointer owner) {
this.address = a.address;
this.capacity = a.capacity;
this.owner = owner;
}
@Override
public void deallocate() {
for (int i = 0; !isNull() && i < capacity; i++) {
cudnnTensorStruct t = this.get(cudnnTensorStruct.class, i);
checkCudnn(cudnnDestroyTensorDescriptor(t));
}
if (owner != null) {
owner.deallocate();
owner = null;
}
setNull();
}
}
public TensorArray() {}
public TensorArray(long size) {
PointerPointer p = new PointerPointer(size);
p.deallocate(false);
this.address = p.address();
this.limit = p.limit();
this.capacity = p.capacity();
cudnnTensorStruct t = new cudnnTensorStruct();
for (int i = 0; i < capacity; i++) {
checkCudnn(cudnnCreateTensorDescriptor(t));
this.put(i, t);
}
deallocator(new Deallocator(this, p));
}
public TensorArray(TensorArray a) {
super(a);
}
}
protected final DataType nd4jDataType;
protected final int dataType;
protected final int dataTypeSize;
// both CUDNN_DATA_HALF and CUDNN_DATA_FLOAT need a float value for alpha and beta
protected final Pointer alpha;
protected final Pointer beta;
protected SizeTPointer sizeInBytes = new SizeTPointer(1);
public BaseCudnnHelper(@NonNull DataType dataType){
this.nd4jDataType = dataType;
this.dataType = dataType == DataType.DOUBLE ? CUDNN_DATA_DOUBLE
: dataType == DataType.FLOAT ? CUDNN_DATA_FLOAT : CUDNN_DATA_HALF;
this.dataTypeSize = dataType == DataType.DOUBLE ? 8 : dataType == DataType.FLOAT ? 4 : 2;
// both CUDNN_DATA_HALF and CUDNN_DATA_FLOAT need a float value for alpha and beta
this.alpha = this.dataType == CUDNN_DATA_DOUBLE ? new DoublePointer(1.0) : new FloatPointer(1.0f);
this.beta = this.dataType == CUDNN_DATA_DOUBLE ? new DoublePointer(0.0) : new FloatPointer(0.0f);
}
public static int toCudnnDataType(DataType type){
switch (type){
case DOUBLE:
return CUDNN_DATA_DOUBLE;
case FLOAT:
return CUDNN_DATA_FLOAT;
case INT:
return CUDNN_DATA_INT32;
case HALF:
return CUDNN_DATA_HALF;
default:
throw new RuntimeException("Cannot convert type: " + type);
}
}
public boolean checkSupported() {
// add general checks here, if any
return true;
}
/**
* From CuDNN documentation -
* "Tensors are restricted to having at least 4 dimensions... When working with lower dimensional data, it is
* recommended that the user create a 4Dtensor, and set the size along unused dimensions to 1."
*
* This method implements that - basically appends 1s to the end (shape or stride) to make it length 4,
* or leaves it unmodified if the length is already 4 or more.
* This method can be used for both shape and strides
*
* @param shapeOrStrides
* @return
*/
protected static int[] adaptForTensorDescr(int[] shapeOrStrides){
if(shapeOrStrides.length >= 4)
return shapeOrStrides;
int[] out = new int[4];
int i=0;
for(; i<shapeOrStrides.length; i++ ){
out[i] = shapeOrStrides[i];
}
for(; i<4; i++ ){
out[i] = 1;
}
return out;
}
}

View File

@ -0,0 +1,758 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.cuda.convolution;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import com.jakewharton.byteunits.BinaryByteUnit;
import org.bytedeco.javacpp.Pointer;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer.AlgoMode;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer.BwdDataAlgo;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer.BwdFilterAlgo;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer.FwdAlgo;
import org.deeplearning4j.nn.conf.layers.PoolingType;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.cuda.BaseCudnnHelper;
import org.deeplearning4j.nn.layers.convolution.ConvolutionHelper;
import org.deeplearning4j.nn.params.ConvolutionParamInitializer;
import org.deeplearning4j.util.ConvolutionUtils;
import org.nd4j.jita.allocator.Allocator;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.conf.CudaEnvironment;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.executioner.GridExecutioner;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.jcublas.context.CudaContext;
import org.nd4j.common.primitives.Pair;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.nd4j.common.util.OneTimeLogger;
import java.util.Arrays;
import java.util.Collections;
import java.util.Map;
import org.bytedeco.cuda.cudart.*;
import org.bytedeco.cuda.cudnn.*;
import static org.bytedeco.cuda.global.cudnn.*;
import static org.nd4j.linalg.indexing.NDArrayIndex.all;
import static org.nd4j.linalg.indexing.NDArrayIndex.interval;
/**
* cuDNN-based helper for the convolution layer.
*
* @author saudet
*/
@Slf4j
public class CudnnConvolutionHelper extends BaseCudnnHelper implements ConvolutionHelper {
public CudnnConvolutionHelper(DataType dataType) {
super(dataType);
}
private static class CudnnConvolutionContext extends CudnnContext {
private static class Deallocator extends CudnnConvolutionContext implements Pointer.Deallocator {
Deallocator(CudnnConvolutionContext c) {
super(c);
}
@Override
public void deallocate() {
destroyHandles();
}
}
private cudnnTensorStruct srcTensorDesc = new cudnnTensorStruct(), dstTensorDesc = new cudnnTensorStruct(),
biasTensorDesc = new cudnnTensorStruct(), deltaTensorDesc = new cudnnTensorStruct();
private cudnnFilterStruct filterDesc = new cudnnFilterStruct();
private cudnnConvolutionStruct convDesc = new cudnnConvolutionStruct();
private cudnnActivationStruct activationDesc = new cudnnActivationStruct();
public CudnnConvolutionContext() {
createHandles();
deallocator(new Deallocator(this));
}
public CudnnConvolutionContext(CudnnConvolutionContext c) {
super(c);
srcTensorDesc = new cudnnTensorStruct(c.srcTensorDesc);
dstTensorDesc = new cudnnTensorStruct(c.dstTensorDesc);
biasTensorDesc = new cudnnTensorStruct(c.biasTensorDesc);
deltaTensorDesc = new cudnnTensorStruct(c.deltaTensorDesc);
filterDesc = new cudnnFilterStruct(c.filterDesc);
convDesc = new cudnnConvolutionStruct(c.convDesc);
activationDesc = new cudnnActivationStruct(c.activationDesc);
}
@Override
protected void createHandles() {
super.createHandles();
checkCudnn(cudnnCreateTensorDescriptor(srcTensorDesc));
checkCudnn(cudnnCreateTensorDescriptor(dstTensorDesc));
checkCudnn(cudnnCreateTensorDescriptor(biasTensorDesc));
checkCudnn(cudnnCreateTensorDescriptor(deltaTensorDesc));
checkCudnn(cudnnCreateFilterDescriptor(filterDesc));
checkCudnn(cudnnCreateConvolutionDescriptor(convDesc));
checkCudnn(cudnnCreateActivationDescriptor(activationDesc));
}
@Override
protected void destroyHandles() {
checkCudnn(cudnnDestroyActivationDescriptor(activationDesc));
checkCudnn(cudnnDestroyConvolutionDescriptor(convDesc));
checkCudnn(cudnnDestroyFilterDescriptor(filterDesc));
checkCudnn(cudnnDestroyTensorDescriptor(srcTensorDesc));
checkCudnn(cudnnDestroyTensorDescriptor(dstTensorDesc));
checkCudnn(cudnnDestroyTensorDescriptor(biasTensorDesc));
checkCudnn(cudnnDestroyTensorDescriptor(deltaTensorDesc));
super.destroyHandles();
}
}
private CudnnConvolutionContext cudnnContext = new CudnnConvolutionContext();
@Override
public Pair<Gradient, INDArray> backpropGradient(INDArray input, INDArray weights, INDArray bias, INDArray delta, int[] kernel,
int[] strides, int[] pad, INDArray biasGradView, INDArray weightGradView, IActivation afn,
AlgoMode mode, BwdFilterAlgo bwdFilterAlgo, BwdDataAlgo bwdDataAlgo,
ConvolutionMode convolutionMode, int[] dilation, CNN2DFormat format, LayerWorkspaceMgr workspaceMgr) {
//AB 2020/04/21 - cuDNN does have NHWC support (with limitations) however I have been unable to get it working
// correctly on NHWC data, even after updating all descriptors, tensor format, etc.
//Therefore: all computation here is done in NCHW format only
//As of a future (next?) release we'll likely switch to C++ for cuDNN support
boolean origNHWC = false;
if(format == CNN2DFormat.NHWC){
input = input.permute(0,3,1,2); //NHWC to NCHW
delta = delta.permute(0,3,1,2);
origNHWC = true;
}
int TENSOR_FORMAT = CUDNN_TENSOR_NCHW;
int code;
val miniBatch = input.size(0);
val outDepth = weights.size(0);
val inDepth = weights.size(1);
val kH = weights.size(2);
val kW = weights.size(3);
CudnnForwardArgs args = getCudnnForwardArgs(input, kernel, strides, pad, dilation, convolutionMode, null, CNN2DFormat.NCHW); //Note hardcoded NCHW due to above
input = args.getInput();
val inH = input.size(2);
val inW = input.size(3);
val srcStride = input.stride();
val outSize = args.getOutSize();
val outH = outSize[0];
val outW = outSize[1];
if (!Shape.strideDescendingCAscendingF(delta)) {
// apparently not supported by cuDNN
delta = delta.dup();
}
val deltaStride = delta.stride();
int[] algo1 = new int[1];
int[] algo2 = new int[1];
if (Nd4j.getExecutioner() instanceof GridExecutioner)
((GridExecutioner) Nd4j.getExecutioner()).flushQueue();
code = cudnnSetTensor4dDescriptorEx(cudnnContext.srcTensorDesc, dataType, (int) miniBatch, (int) inDepth,(int) inH, (int) inW,
(int) srcStride[0], (int) srcStride[1], (int) srcStride[2], (int) srcStride[3]);
checkCudnn(false, "cudnnSetTensor4dDescriptorEx", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);
code = cudnnSetTensor4dDescriptorEx(cudnnContext.deltaTensorDesc, dataType, (int) miniBatch, (int) outDepth, (int) outH, (int) outW,
(int) deltaStride[0], (int) deltaStride[1], (int) deltaStride[2], (int) deltaStride[3]);
checkCudnn(false, "cudnnSetTensor4dDescriptorEx", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);
code = cudnnSetConvolution2dDescriptor(cudnnContext.convDesc, pad[0], pad[1], strides[0], strides[1], dilation[0],
dilation[1], CUDNN_CROSS_CORRELATION, dataType);
checkCudnn(false, "cudnnSetConvolution2dDescriptor", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);
code = cudnnSetFilter4dDescriptor(cudnnContext.filterDesc, dataType, TENSOR_FORMAT, (int) outDepth, (int) inDepth, (int) kH, (int) kW);
checkCudnn(false, "cudnnSetFilter4dDescriptor", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);
if (mode == AlgoMode.USER_SPECIFIED && bwdFilterAlgo != null && bwdDataAlgo != null) {
switch (bwdFilterAlgo) {
case ALGO_0:
algo1[0] = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0;
break;
case ALGO_1:
algo1[0] = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1;
break;
case FFT:
algo1[0] = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT;
break;
case ALGO_3:
algo1[0] = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3;
break;
case WINOGRAD:
algo1[0] = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD;
break;
case WINOGRAD_NONFUSED:
algo1[0] = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED;
break;
case FFT_TILING:
algo1[0] = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING;
break;
case COUNT:
algo1[0] = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT;
break;
default:
throw new IllegalArgumentException("Unknown BwdFilterAlgo: " + bwdFilterAlgo);
}
switch (bwdDataAlgo) {
case ALGO_0:
algo2[0] = CUDNN_CONVOLUTION_BWD_DATA_ALGO_0;
break;
case ALGO_1:
algo2[0] = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
break;
case FFT:
algo2[0] = CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT;
break;
case FFT_TILING:
algo2[0] = CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING;
break;
case WINOGRAD:
algo2[0] = CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD;
break;
case WINOGRAD_NONFUSED:
algo2[0] = CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED;
break;
case COUNT:
algo2[0] = CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT;
break;
default:
throw new IllegalArgumentException("Unknown BwdDataAlgo: " + bwdDataAlgo);
}
} else {
/*
code = cudnnGetConvolutionBackwardFilterAlgorithm(cudnnContext, cudnnContext.srcTensorDesc,
cudnnContext.deltaTensorDesc, cudnnContext.convDesc, cudnnContext.filterDesc,
mode == AlgoMode.NO_WORKSPACE ? CUDNN_CONVOLUTION_BWD_FILTER_NO_WORKSPACE
: CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST,
0, algo1);
*/
val fa = new cudnnConvolutionBwdFilterAlgoPerf_t();
val counts = new int[1];
code = cudnnFindConvolutionBackwardFilterAlgorithm(cudnnContext, cudnnContext.srcTensorDesc,
cudnnContext.deltaTensorDesc, cudnnContext.convDesc, cudnnContext.filterDesc, 1, counts, fa);
algo1[0] = fa.algo();
checkCudnn(false, "cudnnGetConvolutionBackwardFilterAlgorithm", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);
/*
code = cudnnGetConvolutionBackwardDataAlgorithm(cudnnContext, cudnnContext.filterDesc,
cudnnContext.deltaTensorDesc, cudnnContext.convDesc, cudnnContext.srcTensorDesc,
mode == AlgoMode.NO_WORKSPACE ? CUDNN_CONVOLUTION_BWD_DATA_NO_WORKSPACE
: CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST,
0, algo2);
*/
val da = new cudnnConvolutionBwdDataAlgoPerf_t();
code = cudnnFindConvolutionBackwardDataAlgorithm(cudnnContext, cudnnContext.filterDesc,
cudnnContext.deltaTensorDesc, cudnnContext.convDesc, cudnnContext.srcTensorDesc, 1, counts, da);
algo2[0] = da.algo();
checkCudnn(false, "cudnnGetConvolutionBackwardDataAlgorithm", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);
}
if(log.isTraceEnabled()){
BwdFilterAlgo fa = BwdFilterAlgo.values()[algo1[0]];
BwdDataAlgo da = BwdDataAlgo.values()[algo2[0]];
log.trace("CudnnConvolutionHelper backward algorithm selection: mode {}, filter algorithm {}, data algorithm {}", mode, fa, da);
}
INDArray epsNext = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, weights.dataType(), new long[] {(int) miniBatch,(int) inDepth, (int) inH, (int) inW}, 'c');
val dstStride = epsNext.stride();
Allocator allocator = AtomicAllocator.getInstance();
CudaContext context = allocator.getFlowController().prepareActionAllWrite(input, weights, weightGradView,
biasGradView, delta, epsNext);
Pointer srcData = allocator.getPointer(input, context);
Pointer filterData = allocator.getPointer(weights, context);
Pointer filterGradData = allocator.getPointer(weightGradView, context);
Pointer biasGradData = allocator.getPointer(biasGradView, context);
Pointer deltaData = allocator.getPointer(delta, context);
Pointer dstData = allocator.getPointer(epsNext, context);
code = cudnnSetStream(cudnnContext, new CUstream_st(context.getCublasStream()));
checkCudnn(false, "cudnnSetStream", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);
code = cudnnSetTensor4dDescriptorEx(cudnnContext.dstTensorDesc, dataType, (int) miniBatch, (int) inDepth, (int) inH, (int) inW,
(int) dstStride[0], (int) dstStride[1], (int) dstStride[2], (int) dstStride[3]);
checkCudnn(false, "cudnnSetTensor4dDescriptorEx", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);
code = cudnnGetConvolutionBackwardFilterWorkspaceSize(cudnnContext, cudnnContext.srcTensorDesc,
cudnnContext.deltaTensorDesc, cudnnContext.convDesc, cudnnContext.filterDesc, algo1[0],
sizeInBytes);
checkCudnn(false, "cudnnGetConvolutionBackwardFilterWorkspaceSize", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);
long sizeInBytes1 = sizeInBytes.get(0);
code = cudnnGetConvolutionBackwardDataWorkspaceSize(cudnnContext, cudnnContext.filterDesc,
cudnnContext.deltaTensorDesc, cudnnContext.convDesc, cudnnContext.dstTensorDesc, algo2[0],
sizeInBytes);
checkCudnn(false, "cudnnGetConvolutionBackwardDataWorkspaceSize", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);
DataCache workSpace = workspaceMgr.getHelperWorkspace(LayerWorkspaceMgr.CUDNN_WORKSPACE_KEY);
long sizeInBytes2 = sizeInBytes.get(0);
if (workSpace == null || sizeInBytes1 > workSpace.capacity() || sizeInBytes2 > workSpace.capacity()) {
long newSize = Math.max(sizeInBytes1, sizeInBytes2);
if(log.isTraceEnabled()){
if(workSpace == null){
log.trace("CudnnConvolutionHelper backpropGradient: Allocating initial workspace of size {} ({})", newSize,
BinaryByteUnit.format(newSize, "#.00"));
} else {
log.trace("CudnnConvolutionHelper backpropGradient: Deallocating workspace of size {} ({}), allocating new workspace of size {} ({})",
workSpace.capacity(), BinaryByteUnit.format(workSpace.capacity(), "#.00"),
newSize, BinaryByteUnit.format(newSize, "#.00"));
}
}
if(workSpace != null)
workSpace.deallocate();
workSpace = new DataCache(newSize);
workspaceMgr.setHelperWorkspace(LayerWorkspaceMgr.CUDNN_WORKSPACE_KEY, workSpace);
}
code = cudnnSetTensor4dDescriptor(cudnnContext.biasTensorDesc, TENSOR_FORMAT, dataType, 1, (int) outDepth, 1, 1);
checkCudnn(false, "cudnnSetTensor4dDescriptor", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);
code = cudnnConvolutionBackwardBias(cudnnContext, alpha, cudnnContext.deltaTensorDesc, deltaData, beta,
cudnnContext.biasTensorDesc, biasGradData);
checkCudnn(false, "cudnnConvolutionBackwardBias", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);
code = cudnnConvolutionBackwardFilter(cudnnContext, alpha, cudnnContext.srcTensorDesc, srcData,
cudnnContext.deltaTensorDesc, deltaData, cudnnContext.convDesc, algo1[0], workSpace,
workSpace.capacity(), beta, cudnnContext.filterDesc, filterGradData);
checkCudnn(false, "cudnnConvolutionBackwardFilter", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);
code = cudnnConvolutionBackwardData(cudnnContext, alpha, cudnnContext.filterDesc, filterData,
cudnnContext.deltaTensorDesc, deltaData, cudnnContext.convDesc, algo2[0], workSpace,
workSpace.capacity(), beta, cudnnContext.dstTensorDesc, dstData);
checkCudnn(false, "cudnnConvolutionBackwardData", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);
allocator.getFlowController().registerActionAllWrite(context, input, weights, weightGradView, biasGradView,
delta, epsNext);
Gradient retGradient = new DefaultGradient();
retGradient.setGradientFor(ConvolutionParamInitializer.BIAS_KEY, biasGradView);
retGradient.setGradientFor(ConvolutionParamInitializer.WEIGHT_KEY, weightGradView, 'c');
if (CudaEnvironment.getInstance().getConfiguration().isDebug())
context.syncOldStream();
//Note that: if we had to manually pad for SAME mode, we have to 'undo' this manual padding for the epsilon
// we return. The returned epsilon (i.e., dL/dIn array) has to be the same shape as the *original* input.
if(args.isManualPadBottom() || args.isManualPadRight()) {
epsNext = epsNext.get(all(), all(),
interval(0, epsNext.size(2) - (args.isManualPadBottom() ? 1 : 0)),
interval(0, epsNext.size(3) - (args.isManualPadRight() ? 1 : 0)));
}
if(origNHWC){
epsNext = epsNext.permute(0,2,3,1); //NCHW to NHWC
}
return new Pair<>(retGradient, epsNext);
}
@Override
public INDArray preOutput(INDArray input, INDArray weights, INDArray bias, int[] kernel, int[] strides, int[] pad,
AlgoMode mode, FwdAlgo fwdAlgo, ConvolutionMode convolutionMode, int[] dilation, CNN2DFormat format,
LayerWorkspaceMgr workspaceMgr) {
//AB 2020/04/21 - cuDNN does have NHWC support (with limitations) however I have been unable to get it working
// correctly on NHWC data, even after updating all descriptors, tensor format, etc.
//Therefore: all computation here is done in NCHW format only
//As of a future (next?) release we'll likely switch to C++ for cuDNN support
boolean origNHWC = false;
if(format == CNN2DFormat.NHWC){
input = input.permute(0,3,1,2); //NHWC to NCHW
origNHWC = true;
}
int TENSOR_FORMAT = CUDNN_TENSOR_NCHW;
int code;
val miniBatch = input.size(0);
val outDepth = weights.size(0);
val inDepth = weights.size(1);
val kH = weights.size(2);
val kW = weights.size(3);
CudnnForwardArgs args = getCudnnForwardArgs(input, kernel, strides, pad, dilation, convolutionMode, null, CNN2DFormat.NCHW); //Note hardcoded NCHW due to above
input = args.getInput();
val inH = input.size(2);
val inW = input.size(3);
val srcStride = input.stride();
val outSize = args.getOutSize();
if (Nd4j.getExecutioner() instanceof GridExecutioner)
((GridExecutioner) Nd4j.getExecutioner()).flushQueue();
INDArray z = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, weights.dataType(), new long[] {(int) miniBatch, (int) outDepth, outSize[0], outSize[1]});
code = cudnnSetTensor4dDescriptorEx(cudnnContext.srcTensorDesc, dataType, (int) miniBatch, (int) inDepth, (int) inH, (int) inW,
(int) srcStride[0], (int) srcStride[1], (int) srcStride[2], (int) srcStride[3]);
checkCudnn(true, "cudnnSetTensor4dDescriptorEx", code, input, weights, bias, null, kernel, strides, pad, mode, fwdAlgo, null, null, convolutionMode, dilation);
code = cudnnSetFilter4dDescriptor(cudnnContext.filterDesc, dataType, TENSOR_FORMAT, (int) outDepth, (int) inDepth, (int) kH, (int) kW);
checkCudnn(true, "cudnnSetFilter4dDescriptor", code, input, weights, bias, null, kernel, strides, pad, mode, fwdAlgo, null, null, convolutionMode, dilation);
code = cudnnSetConvolution2dDescriptor(cudnnContext.convDesc, pad[0], pad[1], strides[0], strides[1], dilation[0],
dilation[1], CUDNN_CROSS_CORRELATION, dataType);
checkCudnn(true, "cudnnSetConvolution2dDescriptor", code, input, weights, bias, null, kernel, strides, pad, mode, fwdAlgo, null, null, convolutionMode, dilation);
// find dimension of convolution output
// checkCudnn(cudnnGetConvolution2dForwardOutputDim(cudnnContext.convDesc, cudnnContext.srcTensorDesc, cudnnContext.filterDesc, n, c, h, w));
// INDArray z = Nd4j.createUninitialized(new int[]{n[0],c[0],h[0],w[0]},'c');
int[] algo = new int[1];
val dstStride = z.stride();
code = cudnnSetTensor4dDescriptorEx(cudnnContext.dstTensorDesc, dataType, (int) miniBatch, (int) outDepth, (int) outSize[0],
(int) outSize[1], (int) dstStride[0], (int) dstStride[1], (int) dstStride[2], (int) dstStride[3]);
checkCudnn(true, "cudnnSetTensor4dDescriptorEx", code, input, weights, bias, null, kernel, strides, pad, mode, fwdAlgo, null, null, convolutionMode, dilation);
if (mode == AlgoMode.USER_SPECIFIED && fwdAlgo != null) {
switch (fwdAlgo) {
case IMPLICIT_GEMM:
algo[0] = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM;
break;
case IMPLICIT_PRECOMP_GEMM:
algo[0] = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
break;
case GEMM:
algo[0] = CUDNN_CONVOLUTION_FWD_ALGO_GEMM;
break;
case DIRECT:
algo[0] = CUDNN_CONVOLUTION_FWD_ALGO_DIRECT;
break;
case FFT:
algo[0] = CUDNN_CONVOLUTION_FWD_ALGO_FFT;
break;
case FFT_TILING:
algo[0] = CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING;
break;
case WINOGRAD:
algo[0] = CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD;
break;
case WINOGRAD_NONFUSED:
algo[0] = CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED;
break;
case COUNT:
algo[0] = CUDNN_CONVOLUTION_FWD_ALGO_COUNT;
break;
default:
throw new IllegalArgumentException("Unknown FwdAlgo: " + fwdAlgo);
}
} else {
/*
code = cudnnGetConvolutionForwardAlgorithm_v7(cudnnContext, cudnnContext.srcTensorDesc,
cudnnContext.filterDesc, cudnnContext.convDesc,
cudnnContext.dstTensorDesc, mode == AlgoMode.NO_WORKSPACE
? CUDNN_CONVOLUTION_FWD_ : CUDNN_CONVOLUTION_FWD_PREFER_FASTEST,
0, algo);
*/
val cdf = new cudnnConvolutionFwdAlgoPerf_t();
val count = new int[1];
code = cudnnFindConvolutionForwardAlgorithm(cudnnContext, cudnnContext.srcTensorDesc, cudnnContext.filterDesc, cudnnContext.convDesc, cudnnContext.dstTensorDesc, 1, count, cdf);
if(code != CUDNN_STATUS_SUCCESS){
//If CuDNN can't infer algorithm - try IMPLICIT_GEMM
//Why this specifically? According to the docs, it seems to have the least number of restrictions
// to things like dilation
OneTimeLogger.warn(log, "Error getting CuDNN forward algorithm - falling back on IMPLICIT_GEMM");
mode = AlgoMode.USER_SPECIFIED;
fwdAlgo = FwdAlgo.IMPLICIT_GEMM;
algo[0] = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM;
}
algo[0] = cdf.algo();
}
if(log.isTraceEnabled()){
FwdAlgo a = FwdAlgo.values()[algo[0]];
log.trace("CudnnConvolutionHelper forward algorithm selection: mode {}, algorithm {}", mode, a);
}
Allocator allocator = AtomicAllocator.getInstance();
CudaContext context = allocator.getFlowController().prepareAction(z, input, weights, bias);
Pointer srcData = allocator.getPointer(input, context);
Pointer filterData = allocator.getPointer(weights, context);
Pointer biasData = allocator.getPointer(bias, context);
Pointer dstData = allocator.getPointer(z, context);
code = cudnnSetStream(cudnnContext, new CUstream_st(context.getCublasStream()));
checkCudnn(true, "cudnnSetStream", code, input, weights, bias, null, kernel, strides, pad, mode, fwdAlgo, null, null, convolutionMode, dilation);
code = cudnnGetConvolutionForwardWorkspaceSize(cudnnContext, cudnnContext.srcTensorDesc,
cudnnContext.filterDesc, cudnnContext.convDesc, cudnnContext.dstTensorDesc, algo[0],
sizeInBytes);
checkCudnn(true, "cudnnGetConvolutionForwardWorkspaceSize", code, input, weights, bias, null, kernel, strides, pad, mode, fwdAlgo, null, null, convolutionMode, dilation);
DataCache workSpace = workspaceMgr.getHelperWorkspace(LayerWorkspaceMgr.CUDNN_WORKSPACE_KEY);
if (workSpace == null || sizeInBytes.get(0) > workSpace.capacity()) {
if(log.isTraceEnabled()){
if(workSpace == null){
log.trace("CudnnConvolutionHelper preOutput: allocating initial workspace of size {} ({})",
sizeInBytes.get(), BinaryByteUnit.format(sizeInBytes.get(), "#.00"));
} else {
log.trace("CudnnConvolutionHelper preOutput: Deallocating workspace of size {} ({}), allocating new workspace of size {} ({})",
workSpace.capacity(), BinaryByteUnit.format(workSpace.capacity(), "#.00"),
sizeInBytes.get(), BinaryByteUnit.format(sizeInBytes.get(), "#.00"));
}
}
if(workSpace != null)
workSpace.deallocate();
workSpace = new DataCache(sizeInBytes.get(0));
workspaceMgr.setHelperWorkspace(LayerWorkspaceMgr.CUDNN_WORKSPACE_KEY, workSpace);
}
code = cudnnConvolutionForward(cudnnContext, alpha, cudnnContext.srcTensorDesc, srcData,
cudnnContext.filterDesc, filterData, cudnnContext.convDesc, algo[0], workSpace,
workSpace.capacity(), beta, cudnnContext.dstTensorDesc, dstData);
checkCudnn(true, "cudnnConvolutionForward", code, input, weights, bias, null, kernel, strides, pad, mode, fwdAlgo, null, null, convolutionMode, dilation);
code = cudnnSetTensor4dDescriptor(cudnnContext.biasTensorDesc, TENSOR_FORMAT, dataType, 1, (int) outDepth, 1, 1);
checkCudnn(true, "cudnnSetTensor4dDescriptor", code, input, weights, bias, null, kernel, strides, pad, mode, fwdAlgo, null, null, convolutionMode, dilation);
code = cudnnAddTensor(cudnnContext, alpha, cudnnContext.biasTensorDesc, biasData, alpha,
cudnnContext.dstTensorDesc, dstData);
checkCudnn(true, "cudnnAddTensor", code, input, weights, bias, null, kernel, strides, pad, mode, fwdAlgo, null, null, convolutionMode, dilation);
allocator.registerAction(context, z, input, weights, bias);
if (CudaEnvironment.getInstance().getConfiguration().isDebug())
context.syncOldStream();
if(origNHWC){
z = z.permute(0,2,3,1); //NCHW to NHWC
}
return z;
}
private void checkCudnn(boolean forward, String step, int code, INDArray input, INDArray weights, INDArray bias, INDArray delta,
int[] kernel, int[] strides, int[] pad,
AlgoMode mode, FwdAlgo fwdAlgo, BwdFilterAlgo bwdFilterAlgo, BwdDataAlgo bwdDataAlgo, ConvolutionMode convolutionMode, int[] dilation) {
if (code != CUDNN_STATUS_SUCCESS) {
StringBuilder sb = new StringBuilder();
sb.append("CuDNN error = ").append(code).append(": ").append(cudnnGetErrorString(code).getString())
.append(" during ")
.append(forward ? "forward pass" : "backward pass")
.append(" - step ").append(step)
.append(": inputShape=").append(Arrays.toString(input.shape()))
.append(", weightsShape=").append(Arrays.toString(weights.shape()))
.append(", biasShape=").append(bias == null ? null : Arrays.toString(bias.shape()));
if (!forward) {
sb.append(", gradientShape=").append(Arrays.toString(delta.shape()));
}
sb.append(", kernel=").append(Arrays.toString(kernel))
.append(", stride=").append(Arrays.toString(strides))
.append(", padding=").append(Arrays.toString(pad))
.append(", dilation=").append(Arrays.toString(dilation))
.append(", AlgoMode=").append(mode);
if (forward) {
sb.append(", fwdAlgo=").append(fwdAlgo);
} else {
sb.append(", bwdFilterAlgo=").append(bwdFilterAlgo)
.append(", bwdDataAlgo=").append(bwdDataAlgo);
}
sb.append(", convolutionMode=").append(convolutionMode);
throw new RuntimeException(sb.toString());
}
}
@Override
public INDArray activate(INDArray z, IActivation afn, boolean training) {
if (Nd4j.getExecutioner() instanceof GridExecutioner)
((GridExecutioner) Nd4j.getExecutioner()).flushQueue();
INDArray activation = z;
Allocator allocator = AtomicAllocator.getInstance();
CudaContext context = allocator.getFlowController().prepareAction(z);
Pointer dstData = allocator.getPointer(z, context);
checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getCublasStream())));
switch (afn.toString()) {
case "identity":
break;
case "sigmoid":
checkCudnn(cudnnSetActivationDescriptor(cudnnContext.activationDesc, CUDNN_ACTIVATION_SIGMOID,
CUDNN_PROPAGATE_NAN, 0));
checkCudnn(cudnnActivationForward(cudnnContext, cudnnContext.activationDesc, alpha,
cudnnContext.dstTensorDesc, dstData, beta, cudnnContext.dstTensorDesc, dstData));
break;
case "relu":
checkCudnn(cudnnSetActivationDescriptor(cudnnContext.activationDesc, CUDNN_ACTIVATION_RELU,
CUDNN_PROPAGATE_NAN, 0));
checkCudnn(cudnnActivationForward(cudnnContext, cudnnContext.activationDesc, alpha,
cudnnContext.dstTensorDesc, dstData, beta, cudnnContext.dstTensorDesc, dstData));
break;
case "tanh":
checkCudnn(cudnnSetActivationDescriptor(cudnnContext.activationDesc, CUDNN_ACTIVATION_TANH,
CUDNN_PROPAGATE_NAN, 0));
checkCudnn(cudnnActivationForward(cudnnContext, cudnnContext.activationDesc, alpha,
cudnnContext.dstTensorDesc, dstData, beta, cudnnContext.dstTensorDesc, dstData));
break;
case "softmax":
checkCudnn(cudnnSoftmaxForward(cudnnContext, CUDNN_SOFTMAX_ACCURATE, CUDNN_SOFTMAX_MODE_CHANNEL, alpha,
cudnnContext.dstTensorDesc, dstData, beta, cudnnContext.dstTensorDesc, dstData));
break;
case "logsoftmax":
checkCudnn(cudnnSoftmaxForward(cudnnContext, CUDNN_SOFTMAX_LOG, CUDNN_SOFTMAX_MODE_CHANNEL, alpha,
cudnnContext.dstTensorDesc, dstData, beta, cudnnContext.dstTensorDesc, dstData));
break;
default:
activation = null;
}
allocator.registerAction(context, activation);
if (CudaEnvironment.getInstance().getConfiguration().isDebug())
context.syncOldStream();
return activation;
}
/**
* @param poolingType Used when preparing data for subsampling layers ONLY. Null for convolution layers
* @return
*/
public static CudnnForwardArgs getCudnnForwardArgs(INDArray input, int[] kernel, int[] strides, int[] padding, int[] dilation,
ConvolutionMode convolutionMode, PoolingType poolingType, CNN2DFormat format){
INDArray origInput = input;
//Check if we need to dup the input: views, non-contiguous, etc. CuDNN also seems to have has issues if strides
// are non-default for C order - even if they *should* be OK otherwise
if(input.isView() || !Shape.hasDefaultStridesForShape(input)){
input = input.dup('c');
}
boolean nchw = format == CNN2DFormat.NCHW;
int hIdx = nchw ? 2 : 1;
int wIdx = nchw ? 3 : 2;
val inH = input.size(hIdx);
val inW = input.size(wIdx);
boolean manualPadBottom = false;
boolean manualPadRight = false;
int[] outSize;
if (convolutionMode == ConvolutionMode.Same) {
outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, null, convolutionMode, dilation, format); //Also performs validation
padding = ConvolutionUtils.getSameModeTopLeftPadding(outSize, new int[] {(int) inH, (int) inW}, kernel, strides, dilation);
int[] padBottomRight = ConvolutionUtils.getSameModeBottomRightPadding(outSize, new int[] {(int) inH, (int) inW}, kernel, strides, dilation);
if(!Arrays.equals(padding, padBottomRight)){
/*
CuDNN - even as of 7.1 (CUDA 9.1) still doesn't have support for proper SAME mode padding (i.e., asymmetric
padding) - padding can *only* be specified as the same amount for both the top/bottom, and for left/right.
In SAME mode padding, sometimes these are the same - but often they are not.
Note that when they differ, the bottom or right padding will be exactly 1 more than the top or left padding.
As per TF, we'll manually pad here: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/conv_ops.cc#L571-L607
*/
manualPadBottom = (padding[0] != padBottomRight[0]);
manualPadRight = (padding[1] != padBottomRight[1]);
//NCHW format
long[] newShape;
if(nchw){
newShape = new long[]{input.size(0), input.size(1),
input.size(2) + (manualPadBottom ? 1 : 0),
input.size(3) + (manualPadRight ? 1 : 0)};
} else {
newShape = new long[]{input.size(0),
input.size(1) + (manualPadBottom ? 1 : 0),
input.size(2) + (manualPadRight ? 1 : 0),
input.size(3)};
}
INDArray newInput;
if(poolingType == null || poolingType != PoolingType.MAX){
newInput = Nd4j.create(input.dataType(), newShape);
} else {
//For max pooling, we don't want to include the padding in the maximum values. But, CuDNN doesn't knowm
// that these values are padding and hence should be excluded. Instead: We'll use -infinity so that,
// if the 'real' (non-padding) values are all < 0, we take the real value, not the padding value
newInput = Nd4j.valueArrayOf(newShape, Double.NEGATIVE_INFINITY, input.dataType());
}
if(nchw){
newInput.put(new INDArrayIndex[]{all(), all(), interval(0,input.size(2)),
interval(0, input.size(3))}, input);
} else {
newInput.put(new INDArrayIndex[]{all(), interval(0,input.size(1)),
interval(0, input.size(2)), all()}, input);
}
input = newInput;
//Now: we've manually applied the "extra" bottom/right padding only - if required. Consequently, we
// now have the same amount of padding required for top/bottom, and left/right - which we'll let
// CuDNN handle
}
} else {
outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, padding, convolutionMode, dilation, format); //Also performs validation
}
return new CudnnForwardArgs(manualPadBottom, manualPadRight, input, origInput, padding, outSize);
}
@AllArgsConstructor
@Data
public static class CudnnForwardArgs {
private boolean manualPadBottom;
private boolean manualPadRight;
private INDArray input;
private INDArray origInput;
private int[] padding;
private int[] outSize;
}
@Override
public Map<String, Long> helperMemoryUse() {
//No memory use other than shared, and the structs (which are small)
return Collections.emptyMap();
}
}

View File

@ -0,0 +1,308 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.cuda.convolution.subsampling;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.bytedeco.javacpp.Pointer;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.layers.PoolingType;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.cuda.BaseCudnnHelper;
import org.deeplearning4j.cuda.convolution.CudnnConvolutionHelper;
import org.deeplearning4j.nn.layers.convolution.subsampling.SubsamplingHelper;
import org.nd4j.jita.allocator.Allocator;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.conf.CudaEnvironment;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.executioner.GridExecutioner;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.context.CudaContext;
import org.nd4j.common.primitives.Pair;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.nn.workspace.ArrayType;
import java.util.Collections;
import java.util.Map;
import org.bytedeco.cuda.cudart.*;
import org.bytedeco.cuda.cudnn.*;
import static org.bytedeco.cuda.global.cudnn.*;
import static org.deeplearning4j.cuda.convolution.CudnnConvolutionHelper.getCudnnForwardArgs;
import static org.nd4j.linalg.indexing.NDArrayIndex.all;
import static org.nd4j.linalg.indexing.NDArrayIndex.interval;
/**
* cuDNN-based helper for the subsampling layer.
*
* @author saudet
*/
@Slf4j
public class CudnnSubsamplingHelper extends BaseCudnnHelper implements SubsamplingHelper {
public CudnnSubsamplingHelper(DataType dataType) {
super(dataType);
}
private static class CudnnSubsamplingContext extends CudnnContext {
private static class Deallocator extends CudnnSubsamplingContext implements Pointer.Deallocator {
Deallocator(CudnnSubsamplingContext c) {
super(c);
}
@Override
public void deallocate() {
destroyHandles();
}
}
private cudnnTensorStruct srcTensorDesc = new cudnnTensorStruct(), dstTensorDesc = new cudnnTensorStruct(),
deltaTensorDesc = new cudnnTensorStruct();
private cudnnPoolingStruct poolingDesc = new cudnnPoolingStruct();
public CudnnSubsamplingContext() {
createHandles();
deallocator(new Deallocator(this));
}
public CudnnSubsamplingContext(CudnnSubsamplingContext c) {
super(c);
srcTensorDesc = new cudnnTensorStruct(c.srcTensorDesc);
dstTensorDesc = new cudnnTensorStruct(c.dstTensorDesc);
deltaTensorDesc = new cudnnTensorStruct(c.deltaTensorDesc);
poolingDesc = new cudnnPoolingStruct(c.poolingDesc);
}
@Override
protected void createHandles() {
super.createHandles();
checkCudnn(cudnnCreateTensorDescriptor(srcTensorDesc));
checkCudnn(cudnnCreateTensorDescriptor(dstTensorDesc));
checkCudnn(cudnnCreateTensorDescriptor(deltaTensorDesc));
checkCudnn(cudnnCreatePoolingDescriptor(poolingDesc));
}
@Override
protected void destroyHandles() {
checkCudnn(cudnnDestroyPoolingDescriptor(poolingDesc));
checkCudnn(cudnnDestroyTensorDescriptor(srcTensorDesc));
checkCudnn(cudnnDestroyTensorDescriptor(dstTensorDesc));
checkCudnn(cudnnDestroyTensorDescriptor(deltaTensorDesc));
super.destroyHandles();
}
}
private CudnnSubsamplingContext cudnnContext = new CudnnSubsamplingContext();
@Override
public Pair<Gradient, INDArray> backpropGradient(INDArray input, INDArray epsilon, int[] kernel, int[] strides,
int[] pad, PoolingType poolingType, ConvolutionMode convolutionMode,
int[] dilation, CNN2DFormat format, LayerWorkspaceMgr workspaceMgr) {
if(dilation[0] != 1 || dilation[1] != 1){
//CuDNN doesn't support dilated subsampling
return null;
}
boolean nchw = format == CNN2DFormat.NCHW;
int chIdx = nchw ? 1 : 3;
int hIdx = nchw ? 2 : 1;
int wIdx = nchw ? 3 : 2;
//We require the output as one of the arguments for backprop here
//TODO we could add cache mode support here somehow...
INDArray reduced = activate(input, true, kernel, strides, pad, poolingType, convolutionMode, dilation, format, workspaceMgr);
val miniBatch = input.size(0);
val depth = input.size(chIdx);
CudnnConvolutionHelper.CudnnForwardArgs args = getCudnnForwardArgs(input, kernel, strides, pad, dilation, convolutionMode, poolingType, format);
input = args.getInput();
val inH = input.size(hIdx);
val inW = input.size(wIdx);
val srcStride = input.stride();
int[] outSize = args.getOutSize();
int outH = outSize[0];
int outW = outSize[1];
//subsampling doesn't have weights and thus gradients are not calculated for this layer
//only scale and reshape epsilon
Gradient retGradient = new DefaultGradient();
//Epsilons in shape: [miniBatch, channels, outH, outW]
//Epsilons out shape: [miniBatch, channels, inH, inW]
int poolingMode;
switch (poolingType) {
case AVG:
poolingMode = CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING;
break;
case MAX:
poolingMode = CUDNN_POOLING_MAX;
break;
default:
return null;
}
if (!Shape.hasDefaultStridesForShape(epsilon) || epsilon.isView()) {
// apparently not supported by cuDNN
epsilon = epsilon.dup('c');
}
input = input.dup();
val deltaStride = epsilon.stride();
if (Nd4j.getExecutioner() instanceof GridExecutioner)
((GridExecutioner) Nd4j.getExecutioner()).flushQueue();
checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.srcTensorDesc, dataType, (int) miniBatch, (int) depth, (int) inH, (int) inW,
(int) srcStride[0], (int) srcStride[chIdx], (int) srcStride[hIdx], (int) srcStride[wIdx]));
checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.deltaTensorDesc, dataType, (int) miniBatch, (int) depth, (int) outH, (int) outW,
(int) deltaStride[0], (int) deltaStride[chIdx], (int) deltaStride[hIdx], (int) deltaStride[wIdx]));
checkCudnn(cudnnSetPooling2dDescriptor(cudnnContext.poolingDesc, poolingMode, CUDNN_PROPAGATE_NAN, kernel[0],
kernel[1], pad[0], pad[1], strides[0], strides[1]));
long[] outEpsShape = nchw ? new long[] {miniBatch, depth, inH, inW} : new long[] {miniBatch, inH, inW, depth};
INDArray outEpsilon = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, input.dataType(), outEpsShape, 'c');
val dstStride = outEpsilon.stride();
checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.dstTensorDesc, dataType, (int) miniBatch, (int) depth, (int) inH, (int) inW,
(int) dstStride[0], (int) dstStride[chIdx], (int) dstStride[hIdx], (int) dstStride[wIdx]));
Allocator allocator = AtomicAllocator.getInstance();
CudaContext context = allocator.getFlowController().prepareAction(input, epsilon, reduced, outEpsilon);
Pointer srcData = allocator.getPointer(input, context);
Pointer epsData = allocator.getPointer(epsilon, context);
Pointer zData = allocator.getPointer(reduced, context);
Pointer dstData = allocator.getPointer(outEpsilon, context);
checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getCublasStream())));
checkCudnn(cudnnPoolingBackward(cudnnContext, cudnnContext.poolingDesc, alpha, cudnnContext.deltaTensorDesc,
zData, cudnnContext.deltaTensorDesc, epsData, cudnnContext.srcTensorDesc, srcData, beta,
cudnnContext.dstTensorDesc, dstData));
allocator.registerAction(context, outEpsilon, input, epsilon, reduced);
if (CudaEnvironment.getInstance().getConfiguration().isDebug())
context.syncOldStream();
//Note that: if we had to manually pad for SAME mode, we have to 'undo' this manual padding for the epsilon
// we return. The returned epsilon (i.e., dL/dIn array) has to be the same shape as the *original* input.
if(args.isManualPadBottom() || args.isManualPadRight()) {
if(nchw){
outEpsilon = outEpsilon.get(all(), all(),
interval(0, outEpsilon.size(2) - (args.isManualPadBottom() ? 1 : 0)),
interval(0, outEpsilon.size(3) - (args.isManualPadRight() ? 1 : 0)));
} else {
outEpsilon = outEpsilon.get(all(),
interval(0, outEpsilon.size(1) - (args.isManualPadBottom() ? 1 : 0)),
interval(0, outEpsilon.size(2) - (args.isManualPadRight() ? 1 : 0)),
all());
}
}
return new Pair<>(retGradient, outEpsilon);
}
@Override
public INDArray activate(INDArray input, boolean training, int[] kernel, int[] strides, int[] pad,
PoolingType poolingType, ConvolutionMode convolutionMode, int[] dilation, CNN2DFormat format, LayerWorkspaceMgr workspaceMgr) {
if(dilation[0] != 1 || dilation[1] != 1){
//CuDNN doesn't support dilated subsampling
return null;
}
boolean nchw = format == CNN2DFormat.NCHW;
int chIdx = nchw ? 1 : 3;
int hIdx = nchw ? 2 : 1;
int wIdx = nchw ? 3 : 2;
val miniBatch = input.size(0);
val inDepth = input.size(nchw ? 1 : 3);
CudnnConvolutionHelper.CudnnForwardArgs args = getCudnnForwardArgs(input, kernel, strides, pad, dilation, convolutionMode, poolingType, format);
input = args.getInput();
val inH = input.size(nchw ? 2 : 1);
val inW = input.size(nchw ? 3 : 2);
val srcStride = input.stride();
val outSize = args.getOutSize();
int outH = outSize[0];
int outW = outSize[1];
int poolingMode;
switch (poolingType) {
case AVG:
poolingMode = CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING;
break;
case MAX:
poolingMode = CUDNN_POOLING_MAX;
break;
default:
return null;
}
if (Nd4j.getExecutioner() instanceof GridExecutioner)
((GridExecutioner) Nd4j.getExecutioner()).flushQueue();
checkCudnn(cudnnSetPooling2dDescriptor(cudnnContext.poolingDesc, poolingMode, CUDNN_PROPAGATE_NAN, kernel[0],
kernel[1], pad[0], pad[1], strides[0], strides[1]));
checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.srcTensorDesc, dataType, (int) miniBatch, (int) inDepth, (int) inH, (int) inW,
(int) srcStride[0], (int) srcStride[chIdx], (int) srcStride[hIdx], (int) srcStride[wIdx]));
long[] outShape = nchw ? new long[] {miniBatch, inDepth, outH, outW} : new long[] {miniBatch, outH, outW, inDepth};
INDArray reduced = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, input.dataType(), outShape, 'c');
val dstStride = reduced.stride();
checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.dstTensorDesc, dataType, (int) miniBatch, (int) inDepth, (int) outH, (int) outW,
(int) dstStride[0], (int) dstStride[chIdx], (int) dstStride[hIdx], (int) dstStride[wIdx]));
Allocator allocator = AtomicAllocator.getInstance();
CudaContext context = allocator.getFlowController().prepareAction(input, reduced);
Pointer srcData = allocator.getPointer(input, context);
Pointer dstData = allocator.getPointer(reduced, context);
checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getCublasStream())));
checkCudnn(cudnnPoolingForward(cudnnContext, cudnnContext.poolingDesc, alpha, cudnnContext.srcTensorDesc,
srcData, beta, cudnnContext.dstTensorDesc, dstData));
allocator.registerAction(context, reduced, input);
if (CudaEnvironment.getInstance().getConfiguration().isDebug())
context.syncOldStream();
return reduced;
}
@Override
public Map<String, Long> helperMemoryUse() {
//No persistent memory use other than the structs (which are small)
return Collections.emptyMap();
}
}

View File

@ -0,0 +1,245 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.cuda.dropout;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import com.jakewharton.byteunits.BinaryByteUnit;
import org.bytedeco.javacpp.*;
import org.deeplearning4j.nn.conf.dropout.DropoutHelper;
import org.deeplearning4j.cuda.BaseCudnnHelper;
import org.nd4j.jita.allocator.Allocator;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.conf.CudaEnvironment;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.context.CudaContext;
import org.nd4j.common.util.ArrayUtil;
import org.bytedeco.cuda.cudart.*;
import org.bytedeco.cuda.cudnn.*;
import java.util.Collections;
import java.util.Map;
import static org.bytedeco.cuda.global.cudnn.*;
/**
* CuDNN dropout helper
*
* Note that for repeatability between calls (for example, for gradient checks), we need to do two things:
* (a) set the ND4J RNG seed
* (b) clear the rngStates field
*
* @author Alex Black
*/
@Data
@Slf4j
public class CudnnDropoutHelper extends BaseCudnnHelper implements DropoutHelper {
private static class CudnnDropoutContext extends CudnnContext {
private static class Deallocator extends CudnnDropoutContext implements Pointer.Deallocator {
Deallocator(CudnnDropoutContext c) {
super(c);
}
@Override
public void deallocate() {
destroyHandles();
}
}
private cudnnTensorStruct xTensorDesc = new cudnnTensorStruct(); //Input
private cudnnTensorStruct dxTensorDesc = new cudnnTensorStruct(); //Grad at input
private cudnnTensorStruct yTensorDesc = new cudnnTensorStruct(); //Output
private cudnnTensorStruct dyTensorDesc = new cudnnTensorStruct(); //Grad at output
private cudnnDropoutStruct dropoutDesc = new cudnnDropoutStruct();
public CudnnDropoutContext() {
createHandles();
deallocator(new Deallocator(this));
}
public CudnnDropoutContext(CudnnDropoutContext c) {
super(c);
xTensorDesc = new cudnnTensorStruct(c.xTensorDesc);
dxTensorDesc = new cudnnTensorStruct(c.dxTensorDesc);
yTensorDesc = new cudnnTensorStruct(c.yTensorDesc);
dyTensorDesc = new cudnnTensorStruct(c.dyTensorDesc);
dropoutDesc = new cudnnDropoutStruct(c.dropoutDesc);
}
@Override
protected void createHandles() {
super.createHandles();
checkCudnn(cudnnCreateTensorDescriptor(xTensorDesc));
checkCudnn(cudnnCreateTensorDescriptor(dxTensorDesc));
checkCudnn(cudnnCreateTensorDescriptor(yTensorDesc));
checkCudnn(cudnnCreateTensorDescriptor(dyTensorDesc));
checkCudnn(cudnnCreateDropoutDescriptor(dropoutDesc));
}
@Override
protected void destroyHandles() {
checkCudnn(cudnnDestroyTensorDescriptor(xTensorDesc));
checkCudnn(cudnnDestroyTensorDescriptor(dxTensorDesc));
checkCudnn(cudnnDestroyTensorDescriptor(yTensorDesc));
checkCudnn(cudnnDestroyTensorDescriptor(dyTensorDesc));
checkCudnn(cudnnDestroyDropoutDescriptor(dropoutDesc));
super.destroyHandles();
}
}
private CudnnDropoutContext cudnnContext = new CudnnDropoutContext();
private boolean initializedDescriptor = false;
private DataCache rngStates; //"Pointer to user-allocated GPU memory that will hold random number generator states."
private DataCache mask; //Mask: persistence between forward and backward
private SizeTPointer stateSizeBytesPtr;
private SizeTPointer reserveSizeBytesPtr;
private float lastInitializedP;
public CudnnDropoutHelper(DataType dataType){
super(dataType);
}
//@Override
public Map<String, Long> helperMemoryUse() {
return Collections.emptyMap();
}
@Override
public boolean checkSupported() {
return true;
}
@Override
public void applyDropout(INDArray input, INDArray resultArray, double dropoutInputRetainProb) {
float p = (float)(1.0 - dropoutInputRetainProb); //CuDNN uses p = probability of setting to 0. We use p = probability of retaining
//TODO int cast
int[] inShape = adaptForTensorDescr(ArrayUtil.toInts(input.shape()));
int[] inStride = adaptForTensorDescr(ArrayUtil.toInts(input.stride()));
checkCudnn(cudnnSetTensorNdDescriptor(cudnnContext.xTensorDesc, dataType, inShape.length, inShape, inStride));
int[] outShape = adaptForTensorDescr(ArrayUtil.toInts(resultArray.shape()));
int[] outStride = adaptForTensorDescr(ArrayUtil.toInts(resultArray.stride()));
checkCudnn(cudnnSetTensorNdDescriptor(cudnnContext.yTensorDesc, dataType, outShape.length, outShape, outStride));
if(stateSizeBytesPtr == null){
stateSizeBytesPtr = new SizeTPointer(1);
reserveSizeBytesPtr = new SizeTPointer(1);
}
checkCudnn(cudnnDropoutGetStatesSize(cudnnContext, stateSizeBytesPtr));
long rngStateSizeBytes = stateSizeBytesPtr.get();
checkCudnn(cudnnDropoutGetReserveSpaceSize(cudnnContext.xTensorDesc, reserveSizeBytesPtr));
long maskReserveSizeBytes = reserveSizeBytesPtr.get();
if(rngStates == null || rngStates.capacity() < rngStateSizeBytes){
if(log.isTraceEnabled()){
if(rngStates == null){
log.trace("CudnnDropoutHelper: Allocating intial RNG states workspace of size {} ({})", rngStateSizeBytes,
BinaryByteUnit.format(rngStateSizeBytes, "#.00"));
} else {
log.trace("CudnnDropoutHelper: Deallocating RNG states of size {} ({}), allocating new workspace of size {} ({})",
rngStates.capacity(), BinaryByteUnit.format(rngStates.capacity(), "#.00"),
rngStateSizeBytes, BinaryByteUnit.format(rngStateSizeBytes, "#.00"));
}
}
if(rngStates != null)
rngStates.deallocate();
//states = "Pointer to user-allocated GPU memory that will hold random number generator states."
rngStates = new DataCache(rngStateSizeBytes);
initializedDescriptor = false;
}
if(mask == null || mask.capacity() < maskReserveSizeBytes){
if(log.isTraceEnabled()){
if(mask == null){
log.trace("CudnnDropoutHelper: Allocating intial mask array of size {} ({})", maskReserveSizeBytes,
BinaryByteUnit.format(maskReserveSizeBytes, "#.00"));
} else {
log.trace("CudnnDropoutHelper: Deallocating mask array of size {} ({}), allocating new mask array of size {} ({})",
mask.capacity(), BinaryByteUnit.format(mask.capacity(), "#.00"),
maskReserveSizeBytes, BinaryByteUnit.format(maskReserveSizeBytes, "#.00"));
}
}
if(mask != null)
mask.deallocate();
//mask = "Pointer to user-allocated GPU memory used by this function. It is expected
//that contents of reserveSpace doe not change between cudnnDropoutForward and
//cudnnDropoutBackward calls."
mask = new DataCache(maskReserveSizeBytes);
}
//Dropout descriptor: (re)initialize if required
if(!initializedDescriptor || p != lastInitializedP) {
if(log.isTraceEnabled()){
log.trace("CudnnDropoutHelper: (re)initializing dropout descriptor");
}
//NOTE: cudnnSetDropoutDescriptor has some internal computation/initialization, and hence is expensive to
// call - so we want to call this as infrequently as possible, and cache the result
long seed = Nd4j.getRandom().nextLong();
lastInitializedP = p;
checkCudnn(cudnnSetDropoutDescriptor(cudnnContext.dropoutDesc, cudnnContext, p, rngStates, rngStates.capacity(), seed));
initializedDescriptor = true;
}
Allocator allocator = AtomicAllocator.getInstance();
CudaContext context = allocator.getFlowController().prepareAction(input, resultArray);
Pointer xPtr = allocator.getPointer(input, context);
Pointer yPtr = allocator.getPointer(resultArray, context);
checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getCublasStream())));
checkCudnn(cudnnDropoutForward(cudnnContext, cudnnContext.dropoutDesc, cudnnContext.xTensorDesc, xPtr,
cudnnContext.yTensorDesc, yPtr, mask, mask.capacity()));
allocator.registerAction(context, input, resultArray);
if (CudaEnvironment.getInstance().getConfiguration().isDebug())
context.syncOldStream();
}
@Override
public void backprop(INDArray gradAtOutput, INDArray gradAtInput) {
int[] gradAtOutShape = adaptForTensorDescr(ArrayUtil.toInts(gradAtOutput.shape()));
int[] gradAtOutStride = adaptForTensorDescr(ArrayUtil.toInts(gradAtOutput.stride()));
checkCudnn(cudnnSetTensorNdDescriptor(cudnnContext.dyTensorDesc, dataType, gradAtOutShape.length, gradAtOutShape, gradAtOutStride));
int[] gradAtInShape = adaptForTensorDescr(ArrayUtil.toInts(gradAtInput.shape()));
int[] gradAtInStride = adaptForTensorDescr(ArrayUtil.toInts(gradAtInput.stride()));
checkCudnn(cudnnSetTensorNdDescriptor(cudnnContext.dxTensorDesc, dataType, gradAtInShape.length, gradAtInShape, gradAtInStride));
Allocator allocator = AtomicAllocator.getInstance();
CudaContext context = allocator.getFlowController().prepareAction(gradAtOutput, gradAtInput);
Pointer dyPtr = allocator.getPointer(gradAtOutput, context);
Pointer dxPtr = allocator.getPointer(gradAtInput, context);
checkCudnn(cudnnDropoutBackward(cudnnContext, cudnnContext.dropoutDesc, cudnnContext.dyTensorDesc, dyPtr,
cudnnContext.dxTensorDesc, dxPtr, mask, mask.capacity()));
allocator.registerAction(context, gradAtOutput, gradAtInput);
if (CudaEnvironment.getInstance().getConfiguration().isDebug())
context.syncOldStream();
}
}

View File

@ -0,0 +1,384 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.cuda.normalization;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.bytedeco.javacpp.Pointer;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.cuda.BaseCudnnHelper;
import org.deeplearning4j.nn.layers.normalization.BatchNormalizationHelper;
import org.deeplearning4j.nn.params.BatchNormalizationParamInitializer;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.jita.allocator.Allocator;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.conf.CudaEnvironment;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.executioner.GridExecutioner;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.context.CudaContext;
import org.nd4j.common.primitives.Pair;
import org.nd4j.common.util.ArrayUtil;
import java.util.HashMap;
import java.util.Map;
import org.bytedeco.cuda.cudart.*;
import org.bytedeco.cuda.cudnn.*;
import static org.bytedeco.cuda.global.cudnn.*;
/**
* cuDNN-based helper for the batch normalization layer.
*
* @author saudet
*/
@Slf4j
public class CudnnBatchNormalizationHelper extends BaseCudnnHelper implements BatchNormalizationHelper {
public CudnnBatchNormalizationHelper(DataType dataType) {
super(dataType);
}
private static class CudnnBatchNormalizationContext extends CudnnContext {
private static class Deallocator extends CudnnBatchNormalizationContext implements Pointer.Deallocator {
Deallocator(CudnnBatchNormalizationContext c) {
super(c);
}
@Override
public void deallocate() {
destroyHandles();
}
}
private cudnnTensorStruct srcTensorDesc = new cudnnTensorStruct(), dstTensorDesc = new cudnnTensorStruct(),
deltaTensorDesc = new cudnnTensorStruct(), gammaBetaTensorDesc = new cudnnTensorStruct();
public CudnnBatchNormalizationContext() {
createHandles();
deallocator(new Deallocator(this));
}
public CudnnBatchNormalizationContext(CudnnBatchNormalizationContext c) {
super(c);
srcTensorDesc = new cudnnTensorStruct(c.srcTensorDesc);
dstTensorDesc = new cudnnTensorStruct(c.dstTensorDesc);
deltaTensorDesc = new cudnnTensorStruct(c.deltaTensorDesc);
gammaBetaTensorDesc = new cudnnTensorStruct(c.gammaBetaTensorDesc);
}
@Override
protected void createHandles() {
super.createHandles();
checkCudnn(cudnnCreateTensorDescriptor(srcTensorDesc));
checkCudnn(cudnnCreateTensorDescriptor(dstTensorDesc));
checkCudnn(cudnnCreateTensorDescriptor(deltaTensorDesc));
checkCudnn(cudnnCreateTensorDescriptor(gammaBetaTensorDesc));
}
@Override
protected void destroyHandles() {
checkCudnn(cudnnDestroyTensorDescriptor(srcTensorDesc));
checkCudnn(cudnnDestroyTensorDescriptor(dstTensorDesc));
checkCudnn(cudnnDestroyTensorDescriptor(deltaTensorDesc));
checkCudnn(cudnnDestroyTensorDescriptor(gammaBetaTensorDesc));
super.destroyHandles();
}
}
protected final int batchNormMode = CUDNN_BATCHNORM_SPATIAL; // would need to increase rank of gamma and beta for CUDNN_BATCHNORM_PER_ACTIVATION
private CudnnBatchNormalizationContext cudnnContext = new CudnnBatchNormalizationContext();
private INDArray meanCache;
private INDArray varCache;
private double eps;
public boolean checkSupported(double eps, boolean isFixedGammaBeta) {
boolean supported = checkSupported();
if (eps < CUDNN_BN_MIN_EPSILON) {
supported = false;
log.warn("Not supported: eps < CUDNN_BN_MIN_EPSILON (" + eps + " < " + CUDNN_BN_MIN_EPSILON + ")");
}
return supported;
}
@Override
public Pair<Gradient, INDArray> backpropGradient(INDArray input, INDArray epsilon, long[] shape, INDArray gamma, INDArray beta,
INDArray dGammaView, INDArray dBetaView, double eps, CNN2DFormat format, LayerWorkspaceMgr layerWorkspaceMgr) {
boolean nchw = format == CNN2DFormat.NCHW;
this.eps = eps;
int cudnnTensorFormat = nchw ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC;
int chIdx = nchw ? 1 : 3;
int hIdx = nchw ? 2 : 1;
int wIdx = nchw ? 3 : 2;
val miniBatch = (int) input.size(0);
val depth = (int) input.size(chIdx);
val inH = (int) input.size(hIdx);
val inW = (int) input.size(wIdx);
final boolean isHalf = (input.dataType() == DataType.HALF);
INDArray gammaOrig = null;
INDArray dGammaViewOrig = null;
INDArray dBetaViewOrig = null;
if(isHalf) { //Convert FP16 to FP32 if required (CuDNN BN doesn't support FP16 for these params, only for input/output)
gammaOrig = gamma;
dGammaViewOrig = dGammaView;
dBetaViewOrig = dBetaView;
/*
From CuDNN docs: bnScale, resultBnScaleDiff, resultBnBiasDiff, savedMean, savedInvVariance
"Note: The data type of this tensor descriptor must be 'float' for FP16 and FP32 input tensors, and 'double'
for FP64 input tensors."
>> Last 2 are the meanCache and varCache; first 3 are below
*/
gamma = gamma.castTo(DataType.FLOAT);
dGammaView = dGammaView.castTo(DataType.FLOAT);
dBetaView = dBetaView.castTo(DataType.FLOAT);
}
Gradient retGradient = new DefaultGradient();
if (!Shape.hasDefaultStridesForShape(epsilon)) {
// apparently not supported by cuDNN
epsilon = epsilon.dup('c');
}
val srcStride = ArrayUtil.toInts(input.stride());
val deltaStride = ArrayUtil.toInts(epsilon.stride());
if (Nd4j.getExecutioner() instanceof GridExecutioner)
((GridExecutioner) Nd4j.getExecutioner()).flushQueue();
checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.srcTensorDesc, dataType, (int) miniBatch, (int) depth, (int) inH, (int) inW,
(int) srcStride[0], (int) srcStride[chIdx], (int) srcStride[hIdx], (int) srcStride[wIdx]));
checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.deltaTensorDesc, dataType, (int) miniBatch, (int) depth, (int) inH, (int) inW,
(int) deltaStride[0], (int) deltaStride[chIdx], (int) deltaStride[hIdx], (int) deltaStride[wIdx]));
long[] nextEpsShape = nchw ? new long[] {miniBatch, depth, inH, inW} : new long[] {miniBatch, inH, inW, depth};
INDArray nextEpsilon = layerWorkspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, input.dataType(), nextEpsShape, 'c');
val dstStride = ArrayUtil.toInts(nextEpsilon.stride());
checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.dstTensorDesc, dataType, miniBatch, depth, inH, inW,
dstStride[0], dstStride[chIdx], dstStride[hIdx], dstStride[wIdx]));
checkCudnn(cudnnSetTensor4dDescriptor(cudnnContext.gammaBetaTensorDesc, cudnnTensorFormat, toCudnnDataType(gamma.data().dataType()), (int)shape[0],
(int)shape[1], shape.length > 2 ? (int)shape[2] : 1, shape.length > 3 ? (int)shape[3] : 1));
Allocator allocator = AtomicAllocator.getInstance();
CudaContext context = allocator.getFlowController().prepareActionAllWrite(input, epsilon, nextEpsilon, gamma,
dGammaView, dBetaView);
Pointer srcData = allocator.getPointer(input, context);
Pointer epsData = allocator.getPointer(epsilon, context);
Pointer dstData = allocator.getPointer(nextEpsilon, context);
Pointer gammaData = allocator.getPointer(gamma, context);
Pointer dGammaData = allocator.getPointer(dGammaView, context);
Pointer dBetaData = allocator.getPointer(dBetaView, context);
Pointer meanCacheData = allocator.getPointer(meanCache, context);
Pointer varCacheData = allocator.getPointer(varCache, context);
checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getCublasStream())));
checkCudnn(cudnnBatchNormalizationBackward(cudnnContext, batchNormMode, alpha, this.beta, alpha, alpha,
cudnnContext.srcTensorDesc, srcData, cudnnContext.deltaTensorDesc, epsData,
cudnnContext.dstTensorDesc, dstData, cudnnContext.gammaBetaTensorDesc, gammaData, dGammaData,
dBetaData, eps, meanCacheData, varCacheData));
allocator.getFlowController().registerActionAllWrite(context, input, epsilon, nextEpsilon, gamma, dGammaView,
dBetaView);
retGradient.setGradientFor(BatchNormalizationParamInitializer.GAMMA, dGammaView);
retGradient.setGradientFor(BatchNormalizationParamInitializer.BETA, dBetaView);
context.syncOldStream();
//Convert back and assign, if required:
if(isHalf){
gammaOrig.assign(gamma.castTo(DataType.HALF));
dGammaViewOrig.assign(dGammaView.castTo(DataType.HALF));
dBetaViewOrig.assign(dBetaView.castTo(DataType.HALF));
}
return new Pair<>(retGradient, nextEpsilon);
}
@Override
public INDArray preOutput(INDArray x, boolean training, long[] shape, INDArray gamma, INDArray beta, INDArray mean,
INDArray var, double decay, double eps, CNN2DFormat format, LayerWorkspaceMgr workspaceMgr) {
boolean nchw = format == CNN2DFormat.NCHW;
int cudnnTensorFormat = nchw ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC;
int chIdx = nchw ? 1 : 3;
int hIdx = nchw ? 2 : 1;
int wIdx = nchw ? 3 : 2;
this.eps = eps;
final boolean isHalf = (x.dataType() == DataType.FLOAT16);
INDArray origGamma = gamma;
INDArray origBeta = beta;
INDArray origMean = mean;
INDArray origVar = var;
if(isHalf) {
gamma = gamma.castTo(DataType.FLOAT);
beta = beta.castTo(DataType.FLOAT);
mean = mean.castTo(DataType.FLOAT);
var = var.castTo(DataType.FLOAT);
}
//Notation difference between CuDNN and our implementation:
//Us: runningMean = (1-decay) * batchMean + decay * runningMean
//CuDNN: runningMean = decay * batchMean + (1-decay) * runningMean
//i.e., "decay" has a different meaning...
//Disable in-place updating of running mean/variance, so that all parameter changes are done via the update/gradient
// vector. This is necessary for BatchNormalization to be safe to use in distributed gradient sharing settings
decay = 0.0; //From cudnn docs: runningMean = newMean*factor + runningMean*(1-factor). -> 0 = "in-place modification of running mean disabled"
val miniBatch = (int) x.size(0);
val inDepth = (int) x.size(chIdx);
val inH = (int) x.size(hIdx);
val inW = (int) x.size(wIdx);
val srcStride = ArrayUtil.toInts(x.stride());
checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.srcTensorDesc, dataType, miniBatch, inDepth, inH, inW,
srcStride[0], srcStride[chIdx], srcStride[hIdx], srcStride[wIdx]));
long[] actShape = nchw ? new long[] {miniBatch, inDepth, inH, inW} : new long[] {miniBatch, inH, inW, inDepth};
INDArray activations = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, x.dataType(), actShape, 'c');
val dstStride = ArrayUtil.toInts(activations.stride());
checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.dstTensorDesc, dataType, miniBatch, inDepth, inH, inW,
dstStride[0], dstStride[chIdx], dstStride[hIdx], dstStride[wIdx]));
checkCudnn(cudnnSetTensor4dDescriptor(cudnnContext.gammaBetaTensorDesc, cudnnTensorFormat, toCudnnDataType(mean.data().dataType()), (int)shape[0],
(int)shape[1], shape.length > 2 ? (int)shape[2] : 1, shape.length > 3 ? (int)shape[3] : 1));
Allocator allocator = AtomicAllocator.getInstance();
CudaContext context =
allocator.getFlowController().prepareActionAllWrite(x, activations, gamma, beta, mean, var);
Pointer srcData = allocator.getPointer(x, context);
Pointer dstData = allocator.getPointer(activations, context);
Pointer gammaData = allocator.getPointer(gamma, context);
Pointer betaData = allocator.getPointer(beta, context);
Pointer meanData = allocator.getPointer(mean, context);
Pointer varData = allocator.getPointer(var, context);
if (Nd4j.getExecutioner() instanceof GridExecutioner)
((GridExecutioner) Nd4j.getExecutioner()).flushQueue();
checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getCublasStream())));
if (training) {
if(meanCache == null || meanCache.length() < mean.length()){
try(MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
meanCache = Nd4j.createUninitialized(x.dataType(), mean.length());
}
if(x.dataType() == DataType.HALF){
try(MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
meanCache = meanCache.castTo(DataType.FLOAT);
}
}
}
if(varCache == null || varCache.length() < mean.length()){
try(MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
varCache = Nd4j.createUninitialized(x.dataType(), mean.length());
}
if(nd4jDataType == DataType.HALF){
try(MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
varCache = varCache.castTo(DataType.FLOAT);
}
}
}
Pointer meanCacheData = allocator.getPointer(meanCache, context);
Pointer varCacheData = allocator.getPointer(varCache, context);
checkCudnn(cudnnBatchNormalizationForwardTraining(cudnnContext, batchNormMode, this.alpha, this.beta,
cudnnContext.srcTensorDesc, srcData, cudnnContext.dstTensorDesc, dstData,
cudnnContext.gammaBetaTensorDesc, gammaData, betaData, decay, meanData, varData, eps,
meanCacheData, varCacheData));
} else {
checkCudnn(cudnnBatchNormalizationForwardInference(cudnnContext, batchNormMode, this.alpha, this.beta,
cudnnContext.srcTensorDesc, srcData, cudnnContext.dstTensorDesc, dstData,
cudnnContext.gammaBetaTensorDesc, gammaData, betaData, meanData, varData, eps));
}
allocator.getFlowController().registerActionAllWrite(context, x, activations, gamma, beta, mean, var);
if (CudaEnvironment.getInstance().getConfiguration().isDebug())
context.syncOldStream();
context.syncOldStream();
if(training) {
AtomicAllocator.getInstance().getAllocationPoint(meanCache).tickDeviceWrite();
AtomicAllocator.getInstance().getAllocationPoint(varCache).tickDeviceWrite();
}
if(training && isHalf){
//Update the running mean and variance arrays; also gamma/beta
origMean.assign(mean.castTo(DataType.HALF));
origVar.assign(var.castTo(DataType.HALF));
origGamma.assign(gamma.castTo(DataType.HALF));
origBeta.assign(beta.castTo(DataType.HALF));
}
return activations;
}
@Override
public INDArray getMeanCache(DataType dataType) {
if(dataType == DataType.HALF){
//Buffer is FP32
return meanCache.castTo(DataType.HALF);
}
return meanCache;
}
@Override
public INDArray getVarCache(DataType dataType) {
INDArray ret;
if(dataType == DataType.HALF){
INDArray vc = varCache.castTo(DataType.HALF);
ret = vc.mul(vc).rdivi(1.0).subi(eps);
} else {
ret = varCache.mul(varCache).rdivi(1.0).subi(eps);
}
if(dataType == DataType.HALF){
//Buffer is FP32
return ret.castTo(DataType.HALF);
}
return ret;
}
@Override
public Map<String, Long> helperMemoryUse() {
Map<String,Long> memUse = new HashMap<>();
memUse.put("meanCache", meanCache == null ? 0 : meanCache.length() * meanCache.data().getElementSize());
memUse.put("varCache", varCache == null ? 0 : varCache.length() * varCache.data().getElementSize());
return memUse;
}
}

View File

@ -0,0 +1,240 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.cuda.normalization;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.bytedeco.javacpp.Pointer;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.cuda.BaseCudnnHelper;
import org.deeplearning4j.nn.layers.normalization.LocalResponseNormalizationHelper;
import org.nd4j.jita.allocator.Allocator;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.conf.CudaEnvironment;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.executioner.GridExecutioner;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.context.CudaContext;
import org.nd4j.common.primitives.Pair;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.nd4j.common.util.ArrayUtil;
import java.util.Collections;
import java.util.Map;
import org.bytedeco.cuda.cudart.*;
import org.bytedeco.cuda.cudnn.*;
import static org.bytedeco.cuda.global.cudnn.*;
/**
* cuDNN-based helper for the local response normalization layer.
*
* @author saudet
*/
@Slf4j
public class CudnnLocalResponseNormalizationHelper extends BaseCudnnHelper implements LocalResponseNormalizationHelper {
public CudnnLocalResponseNormalizationHelper(DataType dataType) {
super(dataType);
}
private static class CudnnLocalResponseNormalizationContext extends CudnnContext {
private static class Deallocator extends CudnnLocalResponseNormalizationContext implements Pointer.Deallocator {
Deallocator(CudnnLocalResponseNormalizationContext c) {
super(c);
}
@Override
public void deallocate() {
destroyHandles();
}
}
private cudnnTensorStruct srcTensorDesc = new cudnnTensorStruct(), dstTensorDesc = new cudnnTensorStruct(),
deltaTensorDesc = new cudnnTensorStruct();
private cudnnLRNStruct lrnDesc = new cudnnLRNStruct();
public CudnnLocalResponseNormalizationContext() {
createHandles();
deallocator(new Deallocator(this));
}
public CudnnLocalResponseNormalizationContext(CudnnLocalResponseNormalizationContext c) {
super(c);
srcTensorDesc = new cudnnTensorStruct(c.srcTensorDesc);
dstTensorDesc = new cudnnTensorStruct(c.dstTensorDesc);
deltaTensorDesc = new cudnnTensorStruct(c.deltaTensorDesc);
lrnDesc = new cudnnLRNStruct(c.lrnDesc);
}
@Override
protected void createHandles() {
super.createHandles();
checkCudnn(cudnnCreateTensorDescriptor(srcTensorDesc));
checkCudnn(cudnnCreateTensorDescriptor(dstTensorDesc));
checkCudnn(cudnnCreateTensorDescriptor(deltaTensorDesc));
checkCudnn(cudnnCreateLRNDescriptor(lrnDesc));
}
@Override
protected void destroyHandles() {
checkCudnn(cudnnDestroyLRNDescriptor(lrnDesc));
checkCudnn(cudnnDestroyTensorDescriptor(srcTensorDesc));
checkCudnn(cudnnDestroyTensorDescriptor(dstTensorDesc));
checkCudnn(cudnnDestroyTensorDescriptor(deltaTensorDesc));
super.destroyHandles();
}
}
private CudnnLocalResponseNormalizationContext cudnnContext = new CudnnLocalResponseNormalizationContext();
private INDArray activations = null;
public boolean checkSupported(double k, double n, double alpha, double beta) {
boolean supported = checkSupported();
if (n < CUDNN_LRN_MIN_N) {
supported = false;
log.warn("Not supported: n < CUDNN_LRN_MIN_N (" + n + " < " + CUDNN_LRN_MIN_N + ")");
}
if (n > CUDNN_LRN_MAX_N) {
supported = false;
log.warn("Not supported: n > CUDNN_LRN_MAX_N (" + n + " > " + CUDNN_LRN_MAX_N + ")");
}
if (k < CUDNN_LRN_MIN_K) {
supported = false;
log.warn("Not supported: k < CUDNN_LRN_MIN_K (" + k + " < " + CUDNN_LRN_MIN_K + ")");
}
if (beta < CUDNN_LRN_MIN_BETA) {
supported = false;
log.warn("Not supported: beta < CUDNN_LRN_MIN_BETA (" + beta + " < " + CUDNN_LRN_MIN_BETA + ")");
}
return supported;
}
@Override
public Pair<Gradient, INDArray> backpropGradient(INDArray input, INDArray epsilon, double k, double n, double alpha,
double beta, LayerWorkspaceMgr workspaceMgr) {
val miniBatch = (int) input.size(0);
val depth = (int) input.size(1);
val inH = (int) input.size(2);
val inW = (int) input.size(3);
Gradient retGradient = new DefaultGradient();
if (!Shape.hasDefaultStridesForShape(epsilon)) {
// apparently not supported by cuDNN
epsilon = epsilon.dup('c');
}
val srcStride = ArrayUtil.toInts(input.stride());
val deltaStride = ArrayUtil.toInts(epsilon.stride());
if (Nd4j.getExecutioner() instanceof GridExecutioner)
((GridExecutioner) Nd4j.getExecutioner()).flushQueue();
checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.srcTensorDesc, dataType, miniBatch, depth, inH, inW,
srcStride[0], srcStride[1], srcStride[2], srcStride[3]));
checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.deltaTensorDesc, dataType, miniBatch, depth, inH, inW,
deltaStride[0], deltaStride[1], deltaStride[2], deltaStride[3]));
checkCudnn(cudnnSetLRNDescriptor(cudnnContext.lrnDesc, (int) n, alpha, beta, k));
INDArray nextEpsilon = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, input.dataType(), new long[] {miniBatch, depth, inH, inW}, 'c');
val dstStride = ArrayUtil.toInts(nextEpsilon.stride());
checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.dstTensorDesc, dataType, miniBatch, depth, inH, inW,
dstStride[0], dstStride[1], dstStride[2], dstStride[3]));
Allocator allocator = AtomicAllocator.getInstance();
CudaContext context =
allocator.getFlowController().prepareActionAllWrite(input, epsilon, activations, nextEpsilon);
Pointer srcData = allocator.getPointer(input, context);
Pointer epsData = allocator.getPointer(epsilon, context);
Pointer zData = allocator.getPointer(activations, context);
Pointer dstData = allocator.getPointer(nextEpsilon, context);
checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getCublasStream())));
checkCudnn(cudnnLRNCrossChannelBackward(cudnnContext, cudnnContext.lrnDesc, CUDNN_LRN_CROSS_CHANNEL_DIM1,
this.alpha, cudnnContext.deltaTensorDesc, zData, cudnnContext.deltaTensorDesc, epsData,
cudnnContext.srcTensorDesc, srcData, this.beta, cudnnContext.dstTensorDesc, dstData));
allocator.getFlowController().registerActionAllWrite(context, input, epsilon, activations, nextEpsilon);
if (CudaEnvironment.getInstance().getConfiguration().isDebug())
context.syncOldStream();
return new Pair<>(retGradient, nextEpsilon);
}
@Override
public INDArray activate(INDArray input, boolean training, double k, double n, double alpha, double beta, LayerWorkspaceMgr workspaceMgr) {
val miniBatch = (int) input.size(0);
val inDepth = (int) input.size(1);
val inH = (int) input.size(2);
val inW = (int) input.size(3);
if(!Shape.hasDefaultStridesForShape(input)){
input = input.dup('c');
}
val srcStride = ArrayUtil.toInts(input.stride());
checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.srcTensorDesc, dataType, miniBatch, inDepth, inH, inW,
srcStride[0], srcStride[1], srcStride[2], srcStride[3]));
activations = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, input.dataType(), new long[] {miniBatch, inDepth, inH, inW}, 'c');
val dstStride = ArrayUtil.toInts(activations.stride());
checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.dstTensorDesc, dataType, miniBatch, inDepth, inH, inW,
dstStride[0], dstStride[1], dstStride[2], dstStride[3]));
checkCudnn(cudnnSetLRNDescriptor(cudnnContext.lrnDesc, (int) n, alpha, beta, k));
Allocator allocator = AtomicAllocator.getInstance();
CudaContext context = allocator.getFlowController().prepareActionAllWrite(input, activations);
Pointer srcData = allocator.getPointer(input, context);
Pointer dstData = allocator.getPointer(activations, context);
if (Nd4j.getExecutioner() instanceof GridExecutioner)
((GridExecutioner) Nd4j.getExecutioner()).flushQueue();
checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getCublasStream())));
checkCudnn(cudnnLRNCrossChannelForward(cudnnContext, cudnnContext.lrnDesc, CUDNN_LRN_CROSS_CHANNEL_DIM1,
this.alpha, cudnnContext.srcTensorDesc, srcData, this.beta, cudnnContext.dstTensorDesc,
dstData));
allocator.getFlowController().registerActionAllWrite(context, input, activations);
if (CudaEnvironment.getInstance().getConfiguration().isDebug())
context.syncOldStream();
return activations;
}
@Override
public Map<String, Long> helperMemoryUse() {
//No persistent memory use other than the structs (which are small)
return Collections.emptyMap();
}
}

View File

@ -0,0 +1,659 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.cuda.recurrent;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import com.jakewharton.byteunits.BinaryByteUnit;
import org.bytedeco.javacpp.Pointer;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.cuda.BaseCudnnHelper;
import org.deeplearning4j.nn.layers.recurrent.FwdPassReturn;
import org.deeplearning4j.nn.layers.recurrent.LSTMHelper;
import org.nd4j.jita.allocator.Allocator;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.activations.impl.ActivationSigmoid;
import org.nd4j.linalg.activations.impl.ActivationTanH;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.context.CudaContext;
import org.nd4j.common.primitives.Pair;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.nn.workspace.ArrayType;
import java.util.HashMap;
import java.util.Map;
import org.bytedeco.cuda.cudart.*;
import org.bytedeco.cuda.cudnn.*;
import static org.bytedeco.cuda.global.cudart.*;
import static org.bytedeco.cuda.global.cudnn.*;
/**
* cuDNN-based helper for the recurrent LSTM layer (no peephole connections).
*
* @author saudet
*/
@Slf4j
public class CudnnLSTMHelper extends BaseCudnnHelper implements LSTMHelper {
public CudnnLSTMHelper(DataType dataType) {
super(dataType);
}
private static class CudnnLSTMContext extends CudnnContext {
private static class Deallocator extends CudnnLSTMContext implements Pointer.Deallocator {
Deallocator(CudnnLSTMContext c) {
super(c);
}
@Override
public void deallocate() {
destroyHandles();
}
}
private cudnnTensorStruct hxDesc = new cudnnTensorStruct(), cxDesc = new cudnnTensorStruct();
private cudnnTensorStruct hyDesc = new cudnnTensorStruct(), cyDesc = new cudnnTensorStruct();
private cudnnTensorStruct dhxDesc = new cudnnTensorStruct(), dcxDesc = new cudnnTensorStruct();
private cudnnTensorStruct dhyDesc = new cudnnTensorStruct(), dcyDesc = new cudnnTensorStruct();
private cudnnFilterStruct wDesc = new cudnnFilterStruct(), dwDesc = new cudnnFilterStruct();
private cudnnFilterStruct linLayerMatDesc = new cudnnFilterStruct(), linLayerBiasDesc = new cudnnFilterStruct();
private cudnnRNNStruct rnnDesc = new cudnnRNNStruct();
private cudnnDropoutStruct dropoutDesc = new cudnnDropoutStruct();
private cudnnActivationStruct activationDesc = new cudnnActivationStruct();
public CudnnLSTMContext() {
createHandles();
deallocator(new Deallocator(this));
}
public CudnnLSTMContext(CudnnLSTMContext c) {
super(c);
hxDesc = new cudnnTensorStruct(c.hxDesc);
cxDesc = new cudnnTensorStruct(c.cxDesc);
hyDesc = new cudnnTensorStruct(c.hyDesc);
cyDesc = new cudnnTensorStruct(c.cyDesc);
dhxDesc = new cudnnTensorStruct(c.dhxDesc);
dcxDesc = new cudnnTensorStruct(c.dcxDesc);
dhyDesc = new cudnnTensorStruct(c.dhyDesc);
dcyDesc = new cudnnTensorStruct(c.dcyDesc);
wDesc = new cudnnFilterStruct(c.wDesc);
dwDesc = new cudnnFilterStruct(c.dwDesc);
linLayerMatDesc = new cudnnFilterStruct(c.linLayerMatDesc);
linLayerBiasDesc = new cudnnFilterStruct(c.linLayerBiasDesc);
rnnDesc = new cudnnRNNStruct(c.rnnDesc);
dropoutDesc = new cudnnDropoutStruct(c.dropoutDesc);
activationDesc = new cudnnActivationStruct(c.activationDesc);
}
@Override
protected void createHandles() {
super.createHandles();
checkCudnn(cudnnCreateTensorDescriptor(hxDesc));
checkCudnn(cudnnCreateTensorDescriptor(cxDesc));
checkCudnn(cudnnCreateTensorDescriptor(hyDesc));
checkCudnn(cudnnCreateTensorDescriptor(cyDesc));
checkCudnn(cudnnCreateTensorDescriptor(dhxDesc));
checkCudnn(cudnnCreateTensorDescriptor(dcxDesc));
checkCudnn(cudnnCreateTensorDescriptor(dhyDesc));
checkCudnn(cudnnCreateTensorDescriptor(dcyDesc));
checkCudnn(cudnnCreateFilterDescriptor(wDesc));
checkCudnn(cudnnCreateFilterDescriptor(dwDesc));
checkCudnn(cudnnCreateFilterDescriptor(linLayerMatDesc));
checkCudnn(cudnnCreateFilterDescriptor(linLayerBiasDesc));
checkCudnn(cudnnCreateRNNDescriptor(rnnDesc));
checkCudnn(cudnnCreateDropoutDescriptor(dropoutDesc));
checkCudnn(cudnnCreateActivationDescriptor(activationDesc));
}
@Override
protected void destroyHandles() {
checkCudnn(cudnnDestroyActivationDescriptor(activationDesc));
checkCudnn(cudnnDestroyDropoutDescriptor(dropoutDesc));
checkCudnn(cudnnDestroyRNNDescriptor(rnnDesc));
checkCudnn(cudnnDestroyFilterDescriptor(wDesc));
checkCudnn(cudnnDestroyFilterDescriptor(dwDesc));
checkCudnn(cudnnDestroyFilterDescriptor(linLayerMatDesc));
checkCudnn(cudnnDestroyFilterDescriptor(linLayerBiasDesc));
checkCudnn(cudnnDestroyTensorDescriptor(hxDesc));
checkCudnn(cudnnDestroyTensorDescriptor(cxDesc));
checkCudnn(cudnnDestroyTensorDescriptor(hyDesc));
checkCudnn(cudnnDestroyTensorDescriptor(cyDesc));
checkCudnn(cudnnDestroyTensorDescriptor(dhxDesc));
checkCudnn(cudnnDestroyTensorDescriptor(dcxDesc));
checkCudnn(cudnnDestroyTensorDescriptor(dhyDesc));
checkCudnn(cudnnDestroyTensorDescriptor(dcyDesc));
super.destroyHandles();
}
}
// These constants might eventually become variable parameters...
protected static final int NUM_LAYERS = 1;
protected static final float DROPOUT = 0;
protected static final boolean BIDIRECTIONAL = false;
protected static final int RNN_MODE = CUDNN_LSTM;
protected static final int NUM_LINEAR_LAYERS = 8; // CUDNN_LSTM
private CudnnLSTMContext cudnnContext = new CudnnLSTMContext();
private TensorArray xDesc = new TensorArray();
private TensorArray yDesc = new TensorArray();
private TensorArray dxDesc = new TensorArray();
private TensorArray dyDesc = new TensorArray();
private DataCache stateSpace = new DataCache();
private DataCache reserveSpace = new DataCache();
private DataCache weightsSpace = new DataCache();
private boolean initializedDropoutDescriptor = false;
private static INDArray toCOrder(INDArray arr) {
if (arr.isView() || arr.ordering() != 'c' || !Shape.strideDescendingCAscendingF(arr)) {
arr = arr.dup('c');
}
return arr;
}
@Override
public boolean checkSupported(IActivation gateActivationFn, IActivation activationFn,
boolean hasPeepholeConnections) {
boolean supported = checkSupported();
if (!(gateActivationFn instanceof ActivationSigmoid)) {
supported = false;
log.warn("Not supported: Gate activation functions != ActivationSigmoid");
}
if (!(activationFn instanceof ActivationTanH)) {
supported = false;
log.warn("Not supported: Layer activation functions != ActivationTanH");
}
if (hasPeepholeConnections) {
supported = false;
log.warn("Not supported: LSTM layers with peephole connections");
}
return supported;
}
@Override
public Pair<Gradient, INDArray> backpropGradient(final NeuralNetConfiguration conf,
final IActivation gateActivationFn, final INDArray input, final INDArray recurrentWeights, //Shape: [hiddenLayerSize,4*hiddenLayerSize+3]; order: [wI,wF,wO,wG,wFF,wOO,wGG]
final INDArray inputWeights, //Shape: [n^(L-1),4*hiddenLayerSize]; order: [wi,wf,wo,wg]
final INDArray epsilon, final boolean truncatedBPTT, final int tbpttBackwardLength,
final FwdPassReturn fwdPass, final boolean forwards, final String inputWeightKey,
final String recurrentWeightKey, final String biasWeightKey,
final Map<String, INDArray> gradientViews, INDArray maskArray, //Input mask: should only be used with bidirectional RNNs + variable length
final boolean hasPeepholeConnections, //True for GravesLSTM, false for LSTM
final LayerWorkspaceMgr workspaceMgr) {
//Expect errors to have shape: [miniBatchSize,n^(L+1),timeSeriesLength]
val hiddenLayerSize = recurrentWeights.size(0); //i.e., n^L
val prevLayerSize = inputWeights.size(0); //n^(L-1)
val inputLayerSize = input.size(1);
val miniBatchSize = epsilon.size(0);
boolean is2dInput = epsilon.rank() < 3; //Edge case: T=1 may have shape [miniBatchSize,n^(L+1)], equiv. to [miniBatchSize,n^(L+1),1]
long timeSeriesLength = (is2dInput ? 1 : epsilon.size(2));
INDArray x = toCOrder(input.permute(2, 0, 1));
INDArray dy = toCOrder(epsilon.permute(2, 0, 1));
INDArray dx = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, inputWeights.dataType(), new long[] {timeSeriesLength, miniBatchSize, prevLayerSize}, 'c');
INDArray iwGradientsOut = gradientViews.get(inputWeightKey);
INDArray rwGradientsOut = gradientViews.get(recurrentWeightKey); //Order: {I,F,O,G}
INDArray bGradientsOut = gradientViews.get(biasWeightKey);
INDArray outputActivations = toCOrder(fwdPass.fwdPassOutput.permute(2, 0, 1));
INDArray prevStepMemCellState = toCOrder(fwdPass.prevMemCell);
INDArray prevStepActivations = toCOrder(fwdPass.prevAct);
Nd4j.getExecutioner().commit();
Allocator allocator = AtomicAllocator.getInstance();
CudaContext context = allocator.getFlowController().prepareActionAllWrite(x, dy, dx, outputActivations,
prevStepMemCellState, prevStepActivations, iwGradientsOut, rwGradientsOut, bGradientsOut);
Pointer xData = allocator.getPointer(x, context);
Pointer dyData = allocator.getPointer(dy, context);
Pointer dxData = allocator.getPointer(dx, context);
Pointer outputActivationsData = allocator.getPointer(outputActivations, context);
Pointer prevMemCellStateData = allocator.getPointer(prevStepMemCellState, context);
Pointer prevStepActivationsData = allocator.getPointer(prevStepActivations, context);
Pointer iwGradientsOutData = allocator.getPointer(iwGradientsOut, context);
Pointer rwGradientsOutData = allocator.getPointer(rwGradientsOut, context);
Pointer bGradientsOutData = allocator.getPointer(bGradientsOut, context);
CUstream_st stream = new CUstream_st(context.getCublasStream());
checkCudnn(cudnnSetStream(cudnnContext, stream));
if (truncatedBPTT) {
val endIdx = Math.max(0, timeSeriesLength - tbpttBackwardLength) * miniBatchSize * hiddenLayerSize;
xData.position(endIdx * dataTypeSize);
dyData.position(endIdx * (BIDIRECTIONAL ? 2 : 1) * dataTypeSize);
outputActivationsData.position(endIdx * (BIDIRECTIONAL ? 2 : 1) * dataTypeSize);
timeSeriesLength = (int) Math.min(timeSeriesLength, tbpttBackwardLength);
}
cudnnTensorStruct xDesc0 = xDesc.get(cudnnTensorStruct.class, 0);
DataCache workSpace = workspaceMgr.getHelperWorkspace(LayerWorkspaceMgr.CUDNN_WORKSPACE_KEY);
checkCudnn(cudnnRNNBackwardData(cudnnContext, cudnnContext.rnnDesc, (int) timeSeriesLength, yDesc,
outputActivationsData, dyDesc, dyData, cudnnContext.dhyDesc, null, cudnnContext.dcyDesc, null,
cudnnContext.wDesc, weightsSpace, cudnnContext.hxDesc, prevStepActivationsData, //hx: initial hidden state of RNN
cudnnContext.cxDesc, prevMemCellStateData, //cx: initial cell state of RNN
dxDesc, dxData, //dx: gradient at input of each time step
cudnnContext.dhxDesc, null, //dhx: gradient at initial hidden state of RNN
cudnnContext.dcxDesc, null, //dcx: Gradient at initial cell state
workSpace, workSpace.limit(), reserveSpace, reserveSpace.limit()));
// cudnnRNNBackwardWeights adds to the data in dW.
checkCuda(cudaMemsetAsync(weightsSpace, 0, weightsSpace.limit(), stream));
checkCudnn(cudnnRNNBackwardWeights(cudnnContext, cudnnContext.rnnDesc, (int) timeSeriesLength, xDesc, xData, //Input data
cudnnContext.hxDesc, prevStepActivationsData, //Initial hidden state
yDesc, outputActivationsData, //Output data
workSpace, workSpace.limit(), cudnnContext.dwDesc, weightsSpace, reserveSpace,
reserveSpace.limit()));
int[] dataType = new int[1];
int[] format = new int[1];
int[] nbDims = new int[1];
int[] filterDimA = new int[3];
Pointer linLayerMat = new Pointer();
Pointer linLayerBias = new Pointer();
for (int layer = 0; layer < NUM_LAYERS * (BIDIRECTIONAL ? 2 : 1); layer++) {
for (int linLayerID = 0; linLayerID < NUM_LINEAR_LAYERS; linLayerID++) {
checkCudnn(cudnnGetRNNLinLayerMatrixParams(cudnnContext, cudnnContext.rnnDesc, layer, xDesc0,
cudnnContext.wDesc, weightsSpace, linLayerID, cudnnContext.linLayerMatDesc,
linLayerMat));
checkCudnn(cudnnGetFilterNdDescriptor(cudnnContext.linLayerMatDesc, 3, dataType, format, nbDims,
filterDimA));
checkCudnn(cudnnGetRNNLinLayerBiasParams(cudnnContext, cudnnContext.rnnDesc, layer, xDesc0,
cudnnContext.wDesc, weightsSpace, linLayerID, cudnnContext.linLayerBiasDesc,
linLayerBias));
checkCudnn(cudnnGetFilterNdDescriptor(cudnnContext.linLayerBiasDesc, 3, dataType, format, nbDims,
filterDimA));
// our data is in "new, forget, output, and input gates" order (aka IFOG), each kind of weight packed together
int position = 0;
long size = 0;
Pointer data = null;
switch (linLayerID) {
case 0:
data = iwGradientsOutData;
position = 3;
size = inputLayerSize;
break; // input gate
case 1:
data = iwGradientsOutData;
position = 1;
size = inputLayerSize;
break; // forget gate
case 2:
data = iwGradientsOutData;
position = 0;
size = inputLayerSize;
break; // new gate (input modulation gate)
case 3:
data = iwGradientsOutData;
position = 2;
size = inputLayerSize;
break; // output gate
case 4:
data = rwGradientsOutData;
position = 3;
size = hiddenLayerSize;
break; // input gate
case 5:
data = rwGradientsOutData;
position = 1;
size = hiddenLayerSize;
break; // forget gate
case 6:
data = rwGradientsOutData;
position = 0;
size = hiddenLayerSize;
break; // new gate (input modulation gate)
case 7:
data = rwGradientsOutData;
position = 2;
size = hiddenLayerSize;
break; // output gate
default:
throw new RuntimeException();
}
checkCuda(cudaMemcpyAsync(data.position(position * size * hiddenLayerSize * dataTypeSize), linLayerMat,
size * hiddenLayerSize * dataTypeSize, cudaMemcpyDeviceToDevice, stream));
if (linLayerID < 4) {
checkCuda(cudaMemcpyAsync(bGradientsOutData.position(position * hiddenLayerSize * dataTypeSize),
linLayerBias, hiddenLayerSize * dataTypeSize, cudaMemcpyDeviceToDevice, stream));
}
}
}
allocator.getFlowController().registerActionAllWrite(context, x, dy, dx, outputActivations,
prevStepMemCellState, prevStepActivations, iwGradientsOut, rwGradientsOut, bGradientsOut);
Gradient retGradient = new DefaultGradient();
retGradient.gradientForVariable().put(inputWeightKey, iwGradientsOut);
retGradient.gradientForVariable().put(recurrentWeightKey, rwGradientsOut);
retGradient.gradientForVariable().put(biasWeightKey, bGradientsOut);
INDArray epsilonNext = dx.permute(1, 2, 0); //i.e., what would be W^L*(delta^L)^T. Shape: [m,n^(L-1),T]
return new Pair<>(retGradient, epsilonNext);
}
@Override
public FwdPassReturn activate(final Layer layer, final NeuralNetConfiguration conf,
final IActivation gateActivationFn, //Activation function for the gates - sigmoid or hard sigmoid (must be found in range 0 to 1)
INDArray input, final INDArray recurrentWeights, //Shape: [hiddenLayerSize,4*hiddenLayerSize+3]; order: [wI,wF,wO,wG,wFF,wOO,wGG]
final INDArray inputWeights, //Shape: [n^(L-1),4*hiddenLayerSize]; order: [wi,wf,wo,wg]
final INDArray biases, //Shape: [4,hiddenLayerSize]; order: [bi,bf,bo,bg]^T
final boolean training, final INDArray prevOutputActivations, final INDArray prevMemCellState,
boolean forBackprop, boolean forwards, final String inputWeightKey, INDArray maskArray, //Input mask: should only be used with bidirectional RNNs + variable length
final boolean hasPeepholeConnections, //True for GravesLSTM, false for LSTM
final LayerWorkspaceMgr workspaceMgr) {
boolean is2dInput = input.rank() < 3; //Edge case of T=1, may have shape [m,nIn], equiv. to [m,nIn,1]
val timeSeriesLength = (is2dInput ? 1 : input.size(2));
val hiddenLayerSize = recurrentWeights.size(0);
val miniBatchSize = input.size(0);
val inputLayerSize = input.size(1);
INDArray x = toCOrder(input.permute(2, 0, 1));
INDArray linInputWeights = inputWeights;
INDArray linRecurrentWeights = recurrentWeights;
INDArray linBiases = biases;
INDArray prevAct = toCOrder(prevOutputActivations);
INDArray prevMemCell = toCOrder(prevMemCellState);
INDArray outputActivations = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS,
inputWeights.dataType(), new long[] {timeSeriesLength, miniBatchSize, hiddenLayerSize * (BIDIRECTIONAL ? 2 : 1)}, 'c');
INDArray finalMemCellState = Nd4j.createUninitialized( inputWeights.dataType(),
new long[] {/*numLayers * (bidirectional ? 2 : 1),*/ miniBatchSize, hiddenLayerSize}, 'c');
INDArray finalStepActivations = Nd4j.createUninitialized( inputWeights.dataType(),
new long[] {/*numLayers * (bidirectional ? 2 : 1),*/ miniBatchSize, hiddenLayerSize}, 'c');
FwdPassReturn toReturn = new FwdPassReturn();
toReturn.prevAct = prevAct;
toReturn.prevMemCell = prevMemCell;
Nd4j.getExecutioner().commit();
if (timeSeriesLength > xDesc.capacity()) {
xDesc.deallocate();
xDesc = new TensorArray(timeSeriesLength);
}
if (timeSeriesLength > yDesc.capacity()) {
yDesc.deallocate();
yDesc = new TensorArray(timeSeriesLength);
}
if (timeSeriesLength > dxDesc.capacity()) {
dxDesc.deallocate();
dxDesc = new TensorArray(timeSeriesLength);
}
if (timeSeriesLength > dyDesc.capacity()) {
dyDesc.deallocate();
dyDesc = new TensorArray(timeSeriesLength);
}
for (int i = 0; i < timeSeriesLength; i++) {
int[] dimA = {(int) miniBatchSize, (int) inputLayerSize, 1};
int[] strideA = {(int) dimA[2] * dimA[1], dimA[2], 1};
checkCudnn(cudnnSetTensorNdDescriptor(xDesc.get(cudnnTensorStruct.class, i), dataType, 3, dimA, strideA));
checkCudnn(cudnnSetTensorNdDescriptor(dxDesc.get(cudnnTensorStruct.class, i), dataType, 3, dimA, strideA));
int[] dimB = {(int) miniBatchSize, (int) hiddenLayerSize * (BIDIRECTIONAL ? 2 : 1), 1};
int[] strideB = {dimB[2] * dimB[1], dimB[2], 1};
checkCudnn(cudnnSetTensorNdDescriptor(yDesc.get(cudnnTensorStruct.class, i), dataType, 3, dimB, strideB));
checkCudnn(cudnnSetTensorNdDescriptor(dyDesc.get(cudnnTensorStruct.class, i), dataType, 3, dimB, strideB));
}
int[] dimC = {NUM_LAYERS * (BIDIRECTIONAL ? 2 : 1), (int) miniBatchSize, (int) hiddenLayerSize};
int[] strideC = {dimC[2] * dimC[1], dimC[2], 1};
checkCudnn(cudnnSetTensorNdDescriptor(cudnnContext.hxDesc, dataType, 3, dimC, strideC));
checkCudnn(cudnnSetTensorNdDescriptor(cudnnContext.cxDesc, dataType, 3, dimC, strideC));
checkCudnn(cudnnSetTensorNdDescriptor(cudnnContext.hyDesc, dataType, 3, dimC, strideC));
checkCudnn(cudnnSetTensorNdDescriptor(cudnnContext.cyDesc, dataType, 3, dimC, strideC));
checkCudnn(cudnnSetTensorNdDescriptor(cudnnContext.dhxDesc, dataType, 3, dimC, strideC));
checkCudnn(cudnnSetTensorNdDescriptor(cudnnContext.dcxDesc, dataType, 3, dimC, strideC));
checkCudnn(cudnnSetTensorNdDescriptor(cudnnContext.dhyDesc, dataType, 3, dimC, strideC));
checkCudnn(cudnnSetTensorNdDescriptor(cudnnContext.dcyDesc, dataType, 3, dimC, strideC));
checkCudnn(cudnnDropoutGetStatesSize(cudnnContext, sizeInBytes));
long stateSize = sizeInBytes.get(0);
if (stateSize > stateSpace.capacity()) {
stateSpace.deallocate();
stateSpace = new DataCache(stateSize);
}
stateSpace.limit(stateSize);
if(!initializedDropoutDescriptor) {
checkCudnn(cudnnSetDropoutDescriptor(cudnnContext.dropoutDesc, cudnnContext, DROPOUT, stateSpace, stateSize,
Nd4j.getRandom().getSeed()));
}
checkCudnn(cudnnSetRNNDescriptor_v6(cudnnContext, cudnnContext.rnnDesc, (int) hiddenLayerSize, NUM_LAYERS, cudnnContext.dropoutDesc,
CUDNN_LINEAR_INPUT, BIDIRECTIONAL ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL, RNN_MODE,
CUDNN_RNN_ALGO_STANDARD, dataType));
cudnnTensorStruct xDesc0 = xDesc.get(cudnnTensorStruct.class, 0);
checkCudnn(cudnnGetRNNParamsSize(cudnnContext, cudnnContext.rnnDesc, xDesc0, sizeInBytes, dataType));
long weightsSize = sizeInBytes.get(0);
if (weightsSize > weightsSpace.capacity()) {
weightsSpace.deallocate();
weightsSpace = new DataCache(weightsSize);
}
weightsSpace.limit(weightsSize);
int[] dimW = {(int) weightsSize / dataTypeSize, 1, 1};
checkCudnn(cudnnSetFilterNdDescriptor(cudnnContext.wDesc, dataType, CUDNN_TENSOR_NCHW, 3, dimW));
checkCudnn(cudnnSetFilterNdDescriptor(cudnnContext.dwDesc, dataType, CUDNN_TENSOR_NCHW, 3, dimW));
checkCudnn(cudnnGetRNNWorkspaceSize(cudnnContext, cudnnContext.rnnDesc, (int) timeSeriesLength, xDesc, sizeInBytes));
long workSize = sizeInBytes.get(0);
DataCache workSpace = workspaceMgr.getHelperWorkspace(LayerWorkspaceMgr.CUDNN_WORKSPACE_KEY);
if (workSpace == null || workSize > workSpace.capacity()) {
if(log.isTraceEnabled()){
if(workSpace == null){
log.trace("CudnnLSTMHelper activate: Allocating initial workspace of size {} ({})", workSize,
BinaryByteUnit.format(workSize, "#.00"));
} else {
log.trace("CudnnLSTMHelper activate: Deallocating workspace of size {} ({}), allocating new workspace of size {} ({})",
workSpace.capacity(), BinaryByteUnit.format(workSpace.capacity(), "#.00"),
workSize, BinaryByteUnit.format(workSize, "#.00"));
}
}
if(workSpace != null)
workSpace.deallocate();
workSpace = new DataCache(workSize);
workspaceMgr.setHelperWorkspace(LayerWorkspaceMgr.CUDNN_WORKSPACE_KEY, workSpace);
}
workSpace.limit(workSize);
checkCudnn(cudnnGetRNNTrainingReserveSize(cudnnContext, cudnnContext.rnnDesc, (int) timeSeriesLength, xDesc,
sizeInBytes));
long reserveSize = sizeInBytes.get(0);
if (reserveSize > reserveSpace.capacity()) {
reserveSpace.deallocate();
reserveSpace = new DataCache(reserveSize);
}
reserveSpace.limit(reserveSize);
Allocator allocator = AtomicAllocator.getInstance();
CudaContext context = allocator.getFlowController().prepareActionAllWrite(x, linInputWeights,
linRecurrentWeights, linBiases, prevAct, prevMemCell, outputActivations, finalMemCellState,
finalStepActivations);
Pointer xData = allocator.getPointer(x, context);
Pointer linInputWeightsData = allocator.getPointer(linInputWeights, context);
Pointer linRecurrentWeightsData = allocator.getPointer(linRecurrentWeights, context);
Pointer linBiasesData = allocator.getPointer(linBiases, context);
Pointer prevActData = allocator.getPointer(prevAct, context);
Pointer prevMemCellData = allocator.getPointer(prevMemCell, context);
Pointer outputActivationsData = allocator.getPointer(outputActivations, context);
Pointer finalMemCellStateData = allocator.getPointer(finalMemCellState, context);
Pointer finalTimeStepActivationsData = allocator.getPointer(finalStepActivations, context);
CUstream_st stream = new CUstream_st(context.getCublasStream());
checkCudnn(cudnnSetStream(cudnnContext, stream));
checkCuda(cudaMemsetAsync(weightsSpace, 0, weightsSpace.limit(), stream));
int[] dataType = new int[1];
int[] format = new int[1];
int[] nbDims = new int[1];
int[] filterDimA = new int[3];
Pointer linLayerMat = new Pointer();
Pointer linLayerBias = new Pointer();
for (int layerNum = 0; layerNum < NUM_LAYERS * (BIDIRECTIONAL ? 2 : 1); layerNum++) {
for (int linLayerID = 0; linLayerID < NUM_LINEAR_LAYERS; linLayerID++) {
checkCudnn(cudnnGetRNNLinLayerMatrixParams(cudnnContext, cudnnContext.rnnDesc, layerNum, xDesc0,
cudnnContext.wDesc, weightsSpace, linLayerID, cudnnContext.linLayerMatDesc,
linLayerMat));
checkCudnn(cudnnGetFilterNdDescriptor(cudnnContext.linLayerMatDesc, 3, dataType, format, nbDims,
filterDimA));
checkCudnn(cudnnGetRNNLinLayerBiasParams(cudnnContext, cudnnContext.rnnDesc, layerNum, xDesc0,
cudnnContext.wDesc, weightsSpace, linLayerID, cudnnContext.linLayerBiasDesc,
linLayerBias));
checkCudnn(cudnnGetFilterNdDescriptor(cudnnContext.linLayerBiasDesc, 3, dataType, format, nbDims,
filterDimA));
// our data is in "new, forget, output, and input gates" order (aka IFOG), each kind of weight packed together
int position = 0;
long size = 0;
Pointer data = null;
switch (linLayerID) {
case 0:
data = linInputWeightsData;
position = 3;
size = inputLayerSize;
break; // input gate
case 1:
data = linInputWeightsData;
position = 1;
size = inputLayerSize;
break; // forget gate
case 2:
data = linInputWeightsData;
position = 0;
size = inputLayerSize;
break; // new gate
case 3:
data = linInputWeightsData;
position = 2;
size = inputLayerSize;
break; // output gate
case 4:
data = linRecurrentWeightsData;
position = 3;
size = hiddenLayerSize;
break; // input gate
case 5:
data = linRecurrentWeightsData;
position = 1;
size = hiddenLayerSize;
break; // forget gate
case 6:
data = linRecurrentWeightsData;
position = 0;
size = hiddenLayerSize;
break; // new gate
case 7:
data = linRecurrentWeightsData;
position = 2;
size = hiddenLayerSize;
break; // output gate
default:
throw new RuntimeException();
}
checkCuda(cudaMemcpyAsync(linLayerMat, data.position(position * size * hiddenLayerSize * dataTypeSize),
size * hiddenLayerSize * dataTypeSize, cudaMemcpyDeviceToDevice, stream));
if (linLayerID < 4) {
checkCuda(cudaMemcpyAsync(linLayerBias,
linBiasesData.position(position * hiddenLayerSize * dataTypeSize),
hiddenLayerSize * dataTypeSize, cudaMemcpyDeviceToDevice, stream));
}
}
}
if (training) {
checkCudnn(cudnnRNNForwardTraining(cudnnContext, cudnnContext.rnnDesc, (int) timeSeriesLength, xDesc, xData,
cudnnContext.hxDesc, prevActData, cudnnContext.cxDesc, prevMemCellData, cudnnContext.wDesc,
weightsSpace, yDesc, outputActivationsData, cudnnContext.hyDesc,
finalTimeStepActivationsData, cudnnContext.cyDesc, finalMemCellStateData, workSpace,
workSpace.limit(), reserveSpace, reserveSpace.limit()));
} else {
checkCudnn(cudnnRNNForwardInference(cudnnContext, cudnnContext.rnnDesc, (int) timeSeriesLength, xDesc, xData,
cudnnContext.hxDesc, prevActData, cudnnContext.cxDesc, prevMemCellData, cudnnContext.wDesc,
weightsSpace, yDesc, outputActivationsData, cudnnContext.hyDesc,
finalTimeStepActivationsData, cudnnContext.cyDesc, finalMemCellStateData, workSpace,
workSpace.limit()));
}
allocator.getFlowController().registerActionAllWrite(context, x, linInputWeights, linRecurrentWeights,
linBiases, prevAct, prevMemCell, outputActivations, finalMemCellState, finalStepActivations);
toReturn.fwdPassOutput = outputActivations.permute(1, 2, 0);
toReturn.lastAct = finalStepActivations;
toReturn.lastMemCell = finalMemCellState;
toReturn.prevAct = prevAct;
toReturn.prevMemCell = prevMemCell;
return toReturn;
}
@Override
public Map<String, Long> helperMemoryUse() {
Map<String,Long> memUse = new HashMap<>();
memUse.put("stateStace", stateSpace.capacity());
memUse.put("reserveSpace", reserveSpace.capacity());
memUse.put("weightsSpace", weightsSpace.capacity());
return memUse;
}
}

View File

@ -21,6 +21,7 @@
package org.deeplearning4j.nn.modelimport.keras.preprocessors; package org.deeplearning4j.nn.modelimport.keras.preprocessors;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException; import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException;
@ -32,6 +33,7 @@ import com.fasterxml.jackson.annotation.JsonProperty;
@Slf4j @Slf4j
@Data @Data
@EqualsAndHashCode(callSuper=false)
public class KerasFlattenRnnPreprocessor extends BaseInputPreProcessor { public class KerasFlattenRnnPreprocessor extends BaseInputPreProcessor {
private long tsLength; private long tsLength;

View File

@ -1,29 +1,22 @@
plugins { plugins {
id 'java-library' id 'java-library'
id 'maven-publish' id 'maven-publish'
id 'com.github.johnrengelman.shadow' version '7.1.2'
} }
/* apply from: rootProject.projectDir.path+"/chooseBackend.gradle"
configurations.archives.artifacts.with { archives ->
archives.each {
println(it.name)
}
}
*/
dependencies { dependencies {
afterEvaluate {
//Todo clean this //Todo clean this
api platform(project(":cavis-common-platform")) api platform(project(":cavis-common-platform"))
//api "org.bytedeco:javacpp:1.5.7" //for some reason we needed to apply version numbers here, they do not end up in POM otherwise //api "org.bytedeco:javacpp:1.5.7" //for some reason we needed to apply version numbers here, they do not end up in POM otherwise
api "com.fasterxml.jackson.datatype:jackson-datatype-joda:2.10.5" //api "com.fasterxml.jackson.datatype:jackson-datatype-joda:2.10.5"
api 'org.slf4j:slf4j-simple:2.0.3' //api 'org.slf4j:slf4j-simple:2.0.3'
api 'org.slf4j:slf4j-api:2.0.3' //api 'org.slf4j:slf4j-api:2.0.3'
//TODO for the two below.. either platform specific uber jars or a single big one with all platforms //TODO for the two below.. either platform specific uber jars or a single big one with all platforms
api group: "org.bytedeco", name: "javacpp", version: "1.5.7", classifier: "linux-x86_64" //api group: "org.bytedeco", name: "javacpp", version: "1.5.7", classifier: "linux-x86_64"
//api group: "org.bytedeco", name: "javacpp", version: "1.5.7"
// api group: 'net.brutex.cavis-native', name: 'cavis-native-lib', version: '1.0.0-SNAPSHOT', classifier: "linux-x86_64-avx2-cpu"
//api group: 'net.brutex.cavis-native', name: 'cavis-native-lib', version: '1.0.0-SNAPSHOT'
rootProject.getAllprojects().each { Project sproj -> rootProject.getAllprojects().each { Project sproj ->
if (!sproj.name.equals(name) && !sproj.name.equals("cavis-common-platform") if (!sproj.name.equals(name) && !sproj.name.equals("cavis-common-platform")
&& !sproj.name.equals("Cavis") && !sproj.name.equals("Cavis")
@ -33,26 +26,41 @@ dependencies {
&& !sproj.name.equals("cavis-nd4j") && !sproj.name.equals("cavis-nd4j")
&& !sproj.name.equals("cavis-ui") && !sproj.name.equals("cavis-ui")
&& !sproj.name.equals("cavis-zoo")) { && !sproj.name.equals("cavis-zoo")) {
//compileOnly project(""+sproj.path) api project(path: sproj.path, configuration: 'runtimeElements')
api sproj }
if(! sproj.configurations.empty) { }
//compileOnly project(sproj.getPath()) // if(withCpu) api project(path: ":cavis-native:cavis-native-lib", configuration: "cpuSupportApiElements")
// if(withCuda) api project(path: ":cavis-native:cavis-native-lib", configuration: "cudaSupportApiElements")
/*
api(projects.cavisNative.cavisNativeLib) {
capabilities {
//if(withCuda()) it.requireCapability(group: "net.brutex.cavis.cavis-native", name: "cavis-native-lib-cuda-support", version: project.version)
if (withCpu()) it.requireCapability(group: "net.brutex.cavis.cavis-native", name: "cavis-native-lib-cpu-support", version: project.version)
}
}
api(projects.cavisNative.cavisNativeLib) {
capabilities {
if (withCuda()) it.requireCapability(group: "net.brutex.cavis.cavis-native", name: "cavis-native-lib-cuda-support", version: project.version)
//if(withCpu()) it.requireCapability(group: "net.brutex.cavis.cavis-native", name: "cavis-native-lib-cpu-support", version: project.version)
}
}
*/
//if(withCpu()) api project(path: ":cavis-native:cavis-native-lib", configuration: "cpuSupportImplementation")
//if(withCuda()) api project(path: ":cavis-native:cavis-native-lib", configuration: "cudaSupportImplementation")
//if(withCuda()) api project(path: ":cavis-native:cavis-native-lib", configuration: "cudaSupportCompileClasspath")
/* /*
sproj.configurations.each {Configuration conf -> api (project(':cavis-native:cavis-native-lib')) {
conf.dependencies.each {Dependency dep -> capabilities {
compileOnly dep if(withCpu()) requireCapability("net.brutex.cavis.cavis-native:cavis-native-lib-cpu-support")
//if(withCuda()) requireCapability("net.brutex.cavis.cavis-native:cavis-native-lib-cuda-support")
} }
} }
*/ */
} }
} }
}
}
/* /*
tasks.getByName("jar") { tasks.getByName("jar") {
@ -77,19 +85,39 @@ tasks.getByName("jar") {
} }
} }
/*
/*
artifacts {
archives customFatJar
}
*/ */
artifacts {
archives shadowJar
}
shadowJar {
enabled true;
zip64 true //need this to support jars with more than 65535 entries
archiveClassifier.set('')
}
publishing { publishing {
publications { publications {
mavenJava(MavenPublication) { /*mavenJava(MavenPublication) {
//artifact customFatJar //artifact customFatJar
// from components.java // from components.java
/* pom.withXml {
def dependenciesNode = asNode().dependencies
def dependencyNode = dependenciesNode.appendNode()
dependencyNode.appendNode('groupId', 'net.brutex.cavis')
dependencyNode.appendNode('artifactId', 'cavis-native-lib')
dependencyNode.appendNode('version', '1.0.0-SNAPSHOT')
//dependencyNode.appendNode('classifier', 'linux-x86_64-avx2-cpu')
//dependencyNode.appendNode('scope', 'compile')
}
}
*/
shadow(MavenPublication) { publication ->
project.shadow.component(publication)
} }
} }
} }

View File

@ -11,7 +11,8 @@ ext {
dependencies { dependencies {
implementation platform(projects.cavisCommonPlatform) implementation platform(projects.cavisCommonPlatform)
implementation project(":cavis-native:cavis-native-blas") //implementation project(":cavis-native:cavis-native-blas")
implementation projects.cavisNative.cavisNativeBlas
implementation group: "org.bytedeco", name: "cuda" implementation group: "org.bytedeco", name: "cuda"
implementation group: "org.bytedeco", name: "cuda", classifier: buildTarget implementation group: "org.bytedeco", name: "cuda", classifier: buildTarget

View File

@ -121,7 +121,7 @@ endfunction()
if (SD_CUDA) if (SD_CUDA)
#enable_language(CUDA) #enable_language(CUDA)
find_package(CUDAToolkit 11.2 REQUIRED) find_package(CUDAToolkit 11.4 REQUIRED)
message(STATUS "CUDAToolkit_VERSION: ${CUDAToolkit_VERSION}") message(STATUS "CUDAToolkit_VERSION: ${CUDAToolkit_VERSION}")
message(STATUS "CUDAToolkit_VERSION_MAJOR: ${CUDAToolkit_VERSION_MAJOR}") message(STATUS "CUDAToolkit_VERSION_MAJOR: ${CUDAToolkit_VERSION_MAJOR}")
message(STATUS "CUDAToolkit_VERSION_MINOR: ${CUDAToolkit_VERSION_MINOR}") message(STATUS "CUDAToolkit_VERSION_MINOR: ${CUDAToolkit_VERSION_MINOR}")

View File

@ -20,11 +20,9 @@
*/ */
ext { ext {
chip = (properties.CAVIS_CHIP ?: "cuda,cpu").toLowerCase() //the default is to build for CPU and CUDA chip = (properties.CAVIS_CHIP ?: "cuda,cpu").toLowerCase() //the default is to build for CPU and CUDA
testChip = (properties.CAVIS_TEST_CHIP ?: " ").toLowerCase() //the default is without specific backend logger.debug("Building for chips ${chip} and running tests with backends for ${chip}")
logger.debug("Building for chips ${chip} and running tests with backends for ${testChip}")
chipList = chip.split(",") chipList = chip.split(",")
testChipList = testChip.split(",")
/* just for usability */ /* just for usability */
withCuda = { -> withCuda = { ->
@ -33,10 +31,4 @@ ext {
withCpu = { -> withCpu = { ->
return chip.contains("cpu") return chip.contains("cpu")
} }
withCudaTest = { ->
return testChip.contains("cuda")
}
withCpuTest = { ->
return testChip.contains("cpu")
}
} }

View File

@ -24,7 +24,7 @@ ext {
buildTarget = rootProject.ext.buildTarget buildTarget = rootProject.ext.buildTarget
apply from: new File("${project.rootProject.projectDir}/chooseBackend.gradle") apply from: new File("${project.rootProject.projectDir}/chooseBackend.gradle")
testChipList.each { thisChip -> chipList.each { thisChip ->
configurations.register("${thisChip}TestImplementation") { configurations.register("${thisChip}TestImplementation") {
it.extendsFrom configurations.testImplementation, configurations.implementation it.extendsFrom configurations.testImplementation, configurations.implementation
@ -79,33 +79,44 @@ ext {
dependencies { dependencies {
if (withCudaTest()) { if (withCuda()) {
cudaTestRuntime platform(projects.cavisCommonPlatform) cudaTestRuntime platform(projects.cavisCommonPlatform)
cudaTestRuntime projects.cavisNative.cavisNativeJcublas cudaTestRuntime projects.cavisNative.cavisNativeJcublas
cudaTestRuntime projects.cavisDnn.cavisDnnCudnn
cudaTestRuntime group: "org.bytedeco", name: "openblas" cudaTestRuntime group: "org.bytedeco", name: "openblas"
cudaTestRuntime group: "org.bytedeco", name: "openblas", classifier: buildTarget cudaTestRuntime group: "org.bytedeco", name: "openblas", classifier: buildTarget
cudaTestRuntime group: "org.bytedeco", name: "cuda" cudaTestRuntime group: "org.bytedeco", name: "cuda"
cudaTestRuntime group: "org.bytedeco", name: "cuda", classifier: buildTarget cudaTestRuntime group: "org.bytedeco", name: "cuda", classifier: buildTarget
cudaTestRuntime group: "org.bytedeco", name: "cuda", classifier: "${buildTarget}-redist" cudaTestRuntime group: "org.bytedeco", name: "cuda", classifier: "${buildTarget}-redist"
cudaTestRuntime (project( path: ":cavis-native:cavis-native-lib", configuration: "cudaSupportRuntimeElements"))
/*
cudaTestRuntime(project(":cavis-native:cavis-native-lib")) { cudaTestRuntime(project(":cavis-native:cavis-native-lib")) {
capabilities { capabilities {
it.requireCapabilities "net.brutex.cavis.native:cavis-native-lib-cuda-support:1.0.0-SNAPSHOT" it.requireCapabilities "net.brutex.cavis.native:cavis-native-lib-cuda-support:1.0.0-SNAPSHOT"
} }
} }
*/
} }
if (withCpuTest()) { if (withCpu()) {
cpuTestRuntime platform(projects.cavisCommonPlatform) cpuTestRuntime platform(projects.cavisCommonPlatform)
cpuTestRuntime projects.cavisNative.cavisNativeCpu cpuTestRuntime projects.cavisNative.cavisNativeCpu
cpuTestRuntime group: "org.bytedeco", name: "openblas" cpuTestRuntime group: "org.bytedeco", name: "openblas"
cpuTestRuntime group: "org.bytedeco", name: "openblas", classifier: buildTarget cpuTestRuntime group: "org.bytedeco", name: "openblas", classifier: buildTarget
cpuTestRuntime group: "org.bytedeco", name: "opencv" cpuTestRuntime group: "org.bytedeco", name: "opencv"
cpuTestRuntime group: "org.bytedeco", name: "opencv", classifier: buildTarget cpuTestRuntime group: "org.bytedeco", name: "opencv", classifier: buildTarget
cpuTestRuntime project( path: ":cavis-native:cavis-native-lib", configuration: "cpuSupportRuntimeElements")
/*
cpuTestRuntime(project(":cavis-native:cavis-native-lib")) { cpuTestRuntime(project(":cavis-native:cavis-native-lib")) {
capabilities { capabilities {
it.requireCapabilities "net.brutex.cavis.native:cavis-native-lib-cpu-support:1.0.0-SNAPSHOT" it.requireCapabilities "net.brutex.cavis.native:cavis-native-lib-cpu-support:1.0.0-SNAPSHOT"
} }
} }
*/
} }
} }
} }

View File

@ -89,6 +89,7 @@ include ':cavis-native:cavis-native-lib'
include ':cavis-native:cavis-native-common' include ':cavis-native:cavis-native-common'
include ':cavis-dnn' include ':cavis-dnn'
include ':cavis-dnn:cavis-dnn-api' include ':cavis-dnn:cavis-dnn-api'
if(withCuda()) { include ':cavis-dnn:cavis-dnn-cudnn' }
include ':cavis-dnn:cavis-dnn-common' include ':cavis-dnn:cavis-dnn-common'
include ':cavis-dnn:cavis-dnn-common-tests' include ':cavis-dnn:cavis-dnn-common-tests'
include ':cavis-dnn:cavis-dnn-core' include ':cavis-dnn:cavis-dnn-core'
@ -116,6 +117,7 @@ include ':cavis-dnn:cavis-dnn-spark:cavis-dnn-spark-parameterserver'
include ':cavis-dnn:cavis-dnn-tsne' include ':cavis-dnn:cavis-dnn-tsne'
include ':cavis-datavec' include ':cavis-datavec'
include ':cavis-datavec:cavis-datavec-api' include ':cavis-datavec:cavis-datavec-api'
include ':cavis-datavec:dvec-api'
include ':cavis-datavec:cavis-datavec-data' include ':cavis-datavec:cavis-datavec-data'
include ':cavis-datavec:cavis-datavec-data:cavis-datavec-data-arrow' include ':cavis-datavec:cavis-datavec-data:cavis-datavec-data-arrow'
include ':cavis-datavec:cavis-datavec-data:cavis-datavec-data-image' include ':cavis-datavec:cavis-datavec-data:cavis-datavec-data-image'
@ -151,3 +153,4 @@ include ':cavis-zoo:cavis-zoo-models'
include ':brutex-extended-tests' include ':brutex-extended-tests'
include ':cavis-full' include ':cavis-full'