Adding cuDNN support

parent a39e44c782
commit aab7b423d1
@@ -19,8 +19,12 @@
  *
  */

-apply plugin: 'java'
-apply plugin: 'maven-publish'
+plugins {
+    id 'java-library'
+    id 'maven-publish'
+    id 'com.github.johnrengelman.shadow' version '7.1.2'
+}
+


 apply from: "${project.rootProject.projectDir}/createTestBackends.gradle"

@@ -54,6 +58,7 @@ dependencies {
     implementation projects.cavisDnn.cavisDnnSpark.cavisDnnSparkParameterserver
     implementation projects.cavisDnn.cavisDnnNnParent.cavisDnnNnCore
     implementation projects.cavisDnn.cavisDnnNn
+
     implementation projects.cavisUi.cavisUiCommon
     implementation projects.cavisUi.cavisUiVertx
     implementation projects.cavisUi.cavisUiModel
@@ -66,11 +71,21 @@ dependencies {
     implementation projects.cavisDnn.cavisDnnParallelwrapper

     implementation projects.cavisZoo.cavisZooModels

     testRuntimeOnly "net.brutex.ai:dl4j-test-resources:1.0.1-SNAPSHOT"
 }


 test {
-    dependsOn jar
+    enabled true
+    dependsOn shadowJar
 }


+shadowJar {
+    enabled true;
+    zip64 true //need this to support jars with more than 65535 entries
+    archiveClassifier.set('all')
+    from sourceSets.test.output
+}

@@ -0,0 +1,279 @@
/*
 *
 * ******************************************************************************
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 *
 */

package net.brutex.gan;

import org.apache.commons.lang3.ArrayUtils;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.PerformanceListener;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.impl.ActivationLReLU;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.lossfunctions.LossFunctions;

import javax.swing.*;
import java.awt.*;
import java.awt.image.BufferedImage;
import java.io.File;
import java.util.Arrays;

public class App {
    private static final double LEARNING_RATE = 0.0002;
    private static final double GRADIENT_THRESHOLD = 100.0;
    private static final IUpdater UPDATER = Adam.builder().learningRate(LEARNING_RATE).beta1(0.5).build();

    private static JFrame frame;
    private static JPanel panel;

    private static Layer[] genLayers() {
        return new Layer[] {
            new DenseLayer.Builder().nIn(100).nOut(256).weightInit(WeightInit.NORMAL).build(),
            new ActivationLayer.Builder(new ActivationLReLU(0.2)).build(),
            new DenseLayer.Builder().nIn(256).nOut(512).build(),
            new ActivationLayer.Builder(new ActivationLReLU(0.2)).build(),
            new DenseLayer.Builder().nIn(512).nOut(1024).build(),
            new ActivationLayer.Builder(new ActivationLReLU(0.2)).build(),
            new DenseLayer.Builder().nIn(1024).nOut(784).activation(Activation.TANH).build()
        };
    }

    /**
     * Returns a network config that takes in a 10x10 random number and produces a 28x28 grayscale image.
     *
     * @return config
     */
    private static MultiLayerConfiguration generator() {
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(42)
            .updater(UPDATER)
            .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
            .gradientNormalizationThreshold(GRADIENT_THRESHOLD)
            .weightInit(WeightInit.XAVIER)
            .activation(Activation.IDENTITY)
            .list(genLayers())
            .build();

        return conf;
    }

    private static Layer[] disLayers() {
        return new Layer[]{
            new DenseLayer.Builder().nIn(784).nOut(1024).build(),
            new ActivationLayer.Builder(new ActivationLReLU(0.2)).build(),
            new DropoutLayer.Builder(1 - 0.5).build(),
            new DenseLayer.Builder().nIn(1024).nOut(512).build(),
            new ActivationLayer.Builder(new ActivationLReLU(0.2)).build(),
            new DropoutLayer.Builder(1 - 0.5).build(),
            new DenseLayer.Builder().nIn(512).nOut(256).build(),
            new ActivationLayer.Builder(new ActivationLReLU(0.2)).build(),
            new DropoutLayer.Builder(1 - 0.5).build(),
            new OutputLayer.Builder(LossFunctions.LossFunction.XENT).nIn(256).nOut(1).activation(Activation.SIGMOID).build()
        };
    }

    private static MultiLayerConfiguration discriminator() {
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(42)
            .updater(UPDATER)
            .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
            .gradientNormalizationThreshold(GRADIENT_THRESHOLD)
            .weightInit(WeightInit.XAVIER)
            .activation(Activation.IDENTITY)
            .list(disLayers())
            .build();

        return conf;
    }

    private static MultiLayerConfiguration gan() {
        Layer[] genLayers = genLayers();
        Layer[] disLayers = Arrays.stream(disLayers())
            .map((layer) -> {
                if (layer instanceof DenseLayer || layer instanceof OutputLayer) {
                    return new FrozenLayerWithBackprop(layer);
                } else {
                    return layer;
                }
            }).toArray(Layer[]::new);
        Layer[] layers = ArrayUtils.addAll(genLayers, disLayers);

        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(42)
            .updater(UPDATER)
            .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
            .gradientNormalizationThreshold(GRADIENT_THRESHOLD)
            .weightInit(WeightInit.XAVIER)
            .activation(Activation.IDENTITY)
            .list(layers)
            .build();

        return conf;
    }


    @Test
    public void runTest() throws Exception {
        main();
    }

    public static void main(String... args) throws Exception {
        Nd4j.getMemoryManager().setAutoGcWindow(15 * 1000);

        MnistDataSetIterator trainData = new MnistDataSetIterator(128, true, 42);

        MultiLayerNetwork gen = new MultiLayerNetwork(generator());
        MultiLayerNetwork dis = new MultiLayerNetwork(discriminator());
        MultiLayerNetwork gan = new MultiLayerNetwork(gan());
        gen.init();
        dis.init();
        gan.init();

        copyParams(gen, dis, gan);

        gen.setListeners(new PerformanceListener(10, true));
        dis.setListeners(new PerformanceListener(10, true));
        gan.setListeners(new PerformanceListener(10, true));

        trainData.reset();

        int j = 0;
        for (int i = 0; i < 20; i++) {
            while (trainData.hasNext()) {
                j++;

                // generate data
                INDArray real = trainData.next().getFeatures().muli(2).subi(1);
                int batchSize = (int) real.shape()[0];

                INDArray fakeIn = Nd4j.rand(batchSize, 100);
                INDArray fake = gan.activateSelectedLayers(0, gen.getLayers().length - 1, fakeIn);

                DataSet realSet = new DataSet(real, Nd4j.zeros(batchSize, 1));
                DataSet fakeSet = new DataSet(fake, Nd4j.ones(batchSize, 1));

                DataSet data = DataSet.merge(Arrays.asList(realSet, fakeSet));

                dis.fit(data);
                dis.fit(data);

                // Update the discriminator in the GAN network
                updateGan(gen, dis, gan);

                gan.fit(new DataSet(Nd4j.rand(batchSize, 100), Nd4j.zeros(batchSize, 1)));


                if (j % 10 == 1) {
                    System.out.println("Iteration " + j + " Visualizing...");
                    INDArray[] samples = new INDArray[9];
                    DataSet fakeSet2 = new DataSet(fakeIn, Nd4j.ones(batchSize, 1));

                    for (int k = 0; k < 9; k++) {
                        INDArray input = fakeSet2.get(k).getFeatures();
                        //samples[k] = gen.output(input, false);
                        samples[k] = gan.activateSelectedLayers(0, gen.getLayers().length - 1, input);

                    }
                    visualize(samples);
                }
            }
            trainData.reset();
        }

        // Copy the GANs generator to gen.
        updateGen(gen, gan);

        gen.save(new File("mnist-mlp-generator.dlj"));
    }

    private static void copyParams(MultiLayerNetwork gen, MultiLayerNetwork dis, MultiLayerNetwork gan) {
        int genLayerCount = gen.getLayers().length;
        for (int i = 0; i < gan.getLayers().length; i++) {
            if (i < genLayerCount) {
                gen.getLayer(i).setParams(gan.getLayer(i).params());
            } else {
                dis.getLayer(i - genLayerCount).setParams(gan.getLayer(i).params());
            }
        }
    }

    private static void updateGen(MultiLayerNetwork gen, MultiLayerNetwork gan) {
        for (int i = 0; i < gen.getLayers().length; i++) {
            gen.getLayer(i).setParams(gan.getLayer(i).params());
        }
    }

    private static void updateGan(MultiLayerNetwork gen, MultiLayerNetwork dis, MultiLayerNetwork gan) {
        int genLayerCount = gen.getLayers().length;
        for (int i = genLayerCount; i < gan.getLayers().length; i++) {
            gan.getLayer(i).setParams(dis.getLayer(i - genLayerCount).params());
        }
    }

    private static void visualize(INDArray[] samples) {
        if (frame == null) {
            frame = new JFrame();
            frame.setTitle("Viz");
            frame.setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE);
            frame.setLayout(new BorderLayout());

            panel = new JPanel();

            panel.setLayout(new GridLayout(samples.length / 3, 1, 8, 8));
            frame.add(panel, BorderLayout.CENTER);
            frame.setVisible(true);
        }

        panel.removeAll();

        for (INDArray sample : samples) {
            panel.add(getImage(sample));
        }

        frame.revalidate();
        frame.pack();
    }

    private static JLabel getImage(INDArray tensor) {
        BufferedImage bi = new BufferedImage(28, 28, BufferedImage.TYPE_BYTE_GRAY);
        for (int i = 0; i < 784; i++) {
            int pixel = (int) (((tensor.getDouble(i) + 1) * 2) * 255);
            bi.getRaster().setSample(i % 28, i / 28, 0, pixel);
        }
        ImageIcon orig = new ImageIcon(bi);
        Image imageScaled = orig.getImage().getScaledInstance((8 * 28), (8 * 28), Image.SCALE_REPLICATE);

        ImageIcon scaled = new ImageIcon(imageScaled);

        return new JLabel(scaled);
    }
}
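For reference, a short sketch (not part of this commit) of how the generator saved at the end of main() above could be loaded back and sampled, using MultiLayerNetwork.load as the counterpart to the gen.save(...) call, with the same imports as in App and assuming the same 100-dimensional latent input:

    // Hypothetical follow-up, not from the commit: reload the trained generator and sample it.
    MultiLayerNetwork gen = MultiLayerNetwork.load(new File("mnist-mlp-generator.dlj"), true);
    INDArray latent = Nd4j.rand(10, 100);   // 10 random latent vectors
    INDArray images = gen.output(latent);   // 10 x 784 outputs in [-1, 1] (tanh), one row per image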
@@ -0,0 +1,411 @@
/*
 *
 * ******************************************************************************
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 *
 */

package net.brutex.gan;

import org.apache.commons.lang3.ArrayUtils;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.*;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.optimize.api.BaseTrainingListener;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.learning.config.Sgd;

import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Supplier;


/**
 * Implementation of vanilla Generative Adversarial Networks as introduced in https://arxiv.org/pdf/1406.2661.pdf.
 * <p>
 * A DL4J GAN is initialized from two networks: a generator and a discriminator, and will build a third network,
 * the GAN network, from the first two.
 *
 * @author Max Pumperla
 */
public class GAN {
    private static final IUpdater UPDATER_ZERO = Sgd.builder().learningRate(0.0).build();

    public interface DiscriminatorProvider {
        MultiLayerNetwork provide(IUpdater updater);
    }

    protected Supplier<MultiLayerNetwork> generatorSupplier;
    protected DiscriminatorProvider discriminatorSupplier;

    protected MultiLayerNetwork generator;
    protected MultiLayerNetwork discriminator;
    protected MultiLayerNetwork gan;
    protected int latentDim;

    protected IUpdater updater;
    protected IUpdater biasUpdater;
    protected OptimizationAlgorithm optimizer;
    protected GradientNormalization gradientNormalizer;
    protected double gradientNormalizationThreshold;
    protected WorkspaceMode trainingWorkSpaceMode;
    protected WorkspaceMode inferenceWorkspaceMode;
    protected CacheMode cacheMode;
    protected long seed;

    private Double[] discriminatorLearningRates;


    public GAN(Builder builder) {
        this.generatorSupplier = builder.generator;
        this.discriminatorSupplier = builder.discriminator;
        this.latentDim = builder.latentDimension;
        this.updater = builder.iUpdater;
        this.biasUpdater = builder.biasUpdater;
        this.optimizer = builder.optimizationAlgo;
        this.gradientNormalizer = builder.gradientNormalization;
        this.gradientNormalizationThreshold = builder.gradientNormalizationThreshold;
        this.trainingWorkSpaceMode = builder.trainingWorkspaceMode;
        this.inferenceWorkspaceMode = builder.inferenceWorkspaceMode;
        this.cacheMode = builder.cacheMode;
        this.seed = builder.seed;

        defineGan();
    }

    public MultiLayerNetwork getGenerator() {
        return generator;
    }

    public MultiLayerNetwork getDiscriminator() {
        return discriminator;
    }

    public Evaluation evaluateGan(DataSetIterator data) {
        return gan.evaluate(data);
    }

    public Evaluation evaluateGan(DataSetIterator data, List<String> labelsList) {
        return gan.evaluate(data, labelsList);
    }


    public void setGeneratorListeners(BaseTrainingListener[] listeners) {
        generator.setListeners(listeners);
    }

    public void setDiscriminatorListeners(BaseTrainingListener[] listeners) {
        discriminator.setListeners(listeners);
    }

    public void setGanListeners(BaseTrainingListener[] listeners) {
        gan.setListeners(listeners);
    }

    public void fit(DataSetIterator realData, int numEpochs) {
        for (int i = 0; i < numEpochs; i++) {
            while (realData.hasNext()) {
                // Get real images as features
                DataSet next = realData.next();
                fit(next);
            }
            realData.reset();
        }
    }

    public void fit(DataSet next) {
        int batchSize;
        INDArray realImages = next.getFeatures().muli(2).subi(1);
        batchSize = (int) realImages.shape()[0];

        // Sample from latent space and let the generator create fake images.
        INDArray randomLatentData = Nd4j.rand(new int[]{batchSize, latentDim});
        INDArray fakeImages = generator.output(randomLatentData);

        // Real images are marked as "0", fake images as "1".
        DataSet realSet = new DataSet(realImages, Nd4j.zeros(batchSize, 1));
        DataSet fakeSet = new DataSet(fakeImages, Nd4j.ones(batchSize, 1));

        // Fit the discriminator on a combined batch of real and fake images.
        DataSet combined = DataSet.merge(Arrays.asList(realSet, fakeSet));

        /*for (int i = 0; i < discriminator.getLayers().length; i++) {
            if (discriminatorLearningRates[i] != null) {
                discriminator.setLearningRate(i, discriminatorLearningRates[i]);
            }
        }*/

        discriminator.fit(combined);
        //discriminator.fit(combined);

        // Update the discriminator in the GAN network
        updateGanWithDiscriminator();

        // Generate a new set of adversarial examples and try to mislead the discriminator.
        // By labeling the fake images as real images we reward the generator when its output
        // tricks the discriminator.
        INDArray adversarialExamples = Nd4j.rand(new int[]{batchSize, latentDim});
        INDArray misleadingLabels = Nd4j.zeros(batchSize, 1);
        DataSet adversarialSet = new DataSet(adversarialExamples, misleadingLabels);

        // Set learning rate of discriminator part of gan to zero.
        /*for (int i = generator.getLayers().length; i < gan.getLayers().length; i++) {
            gan.setLearningRate(i, 0.0);
        }*/

        // Fit the GAN on the adversarial set, trying to fool the discriminator by generating
        // better fake images.
        gan.fit(adversarialSet);

        // Copy the GAN's generator part to "generator".
        updateGeneratorFromGan();
    }

    private void defineGan() {
        generator = generatorSupplier.get();
        generator.init();

        Layer[] genLayers = generator.getLayers();
        int numGenLayers = genLayers.length;

        discriminator = discriminatorSupplier.provide(updater);
        discriminator.init();

        MultiLayerNetwork ganDiscriminator = discriminatorSupplier.provide(UPDATER_ZERO);
        ganDiscriminator.init();

        Layer[] disLayers = ganDiscriminator.getLayers();
        Layer[] layers = ArrayUtils.addAll(genLayers, disLayers);
        MultiLayerConfiguration genConf = generator.getLayerWiseConfigurations();
        MultiLayerConfiguration disConf = ganDiscriminator.getLayerWiseConfigurations();
        org.deeplearning4j.nn.conf.layers.Layer[] confLayers = new org.deeplearning4j.nn.conf.layers.Layer[layers.length];

        Map<Integer, InputPreProcessor> preProcessors = new HashMap<>();
        for (int i = 0; i < layers.length; i++) {
            confLayers[i] = layers[i].conf().getLayer();
            if (i < numGenLayers) {
                preProcessors.put(i, genConf.getInputPreProcess(i));
            } else {
                preProcessors.put(i, disConf.getInputPreProcess(i - numGenLayers));
            }
        }

        MultiLayerConfiguration ganConf = new NeuralNetConfiguration.Builder()
            .seed(seed)
            .updater(updater)
            .biasUpdater(biasUpdater)
            .optimizationAlgo(optimizer)
            .gradientNormalization(gradientNormalizer)
            .gradientNormalizationThreshold(gradientNormalizationThreshold)
            .activation(Activation.IDENTITY)
            .trainingWorkspaceMode(trainingWorkSpaceMode)
            .inferenceWorkspaceMode(inferenceWorkspaceMode)
            .cacheMode(cacheMode)
            .list(confLayers)
            .inputPreProcessors(preProcessors)
            .build();
        gan = new MultiLayerNetwork(ganConf);
        gan.init();

        // we lose proper init here, need to copy weights after
        copyParamsToGan();
    }

    private void copyParamsToGan() {
        int genLayerCount = generator.getLayers().length;
        for (int i = 0; i < gan.getLayers().length; i++) {
            if (i < genLayerCount) {
                generator.getLayer(i).setParams(gan.getLayer(i).params());
            } else {
                discriminator.getLayer(i - genLayerCount).setParams(gan.getLayer(i).params());
            }
        }
    }

    /**
     * After the GAN has been trained on misleading images, we update the generator with the
     * new weights (we don't have to update the discriminator, as it is frozen in the GAN).
     */
    private void updateGeneratorFromGan() {
        for (int i = 0; i < generator.getLayers().length; i++) {
            generator.getLayer(i).setParams(gan.getLayer(i).params());
        }
    }

    /**
     * After the discriminator has been trained, we update the respective parts of the GAN network
     * as well.
     */
    private void updateGanWithDiscriminator() {
        int genLayerCount = generator.getLayers().length;
        for (int i = genLayerCount; i < gan.getLayers().length; i++) {
            gan.getLayer(i).setParams(discriminator.getLayer(i - genLayerCount).params());
        }
    }

    /**
     * GAN builder, used as a starting point for creating a MultiLayerConfiguration or
     * ComputationGraphConfiguration.<br>
     */
    public static class Builder implements Cloneable {
        protected Supplier<MultiLayerNetwork> generator;
        protected DiscriminatorProvider discriminator;
        protected int latentDimension;

        protected IUpdater iUpdater = new Sgd();
        protected IUpdater biasUpdater = null;
        protected long seed = System.currentTimeMillis();
        protected OptimizationAlgorithm optimizationAlgo = OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT;
        protected GradientNormalization gradientNormalization = GradientNormalization.None;
        protected double gradientNormalizationThreshold = 1.0;

        protected WorkspaceMode trainingWorkspaceMode = WorkspaceMode.ENABLED;
        protected WorkspaceMode inferenceWorkspaceMode = WorkspaceMode.ENABLED;
        protected CacheMode cacheMode = CacheMode.NONE;


        public Builder() {
        }


        /**
         * Set the (fake) image generator of the GAN.
         *
         * @param generator MultilayerNetwork
         * @return Builder
         */
        public GAN.Builder generator(Supplier<MultiLayerNetwork> generator) {
            this.generator = generator;
            return this;
        }

        /**
         * Set the image discriminator of the GAN.
         *
         * @param discriminator MultilayerNetwork
         * @return Builder
         */
        public GAN.Builder discriminator(DiscriminatorProvider discriminator) {
            this.discriminator = discriminator;
            return this;
        }

        /**
         * Set the latent dimension, i.e. the input vector space dimension of the generator.
         *
         * @param latentDimension latent space input dimension.
         * @return Builder
         */
        public GAN.Builder latentDimension(int latentDimension) {
            this.latentDimension = latentDimension;
            return this;
        }


        /**
         * Random number generator seed. Used for reproducibility between runs.
         */
        public GAN.Builder seed(long seed) {
            this.seed = seed;
            Nd4j.getRandom().setSeed(seed);
            return this;
        }

        /**
         * Optimization algorithm to use. Most common: OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT
         *
         * @param optimizationAlgo Optimization algorithm to use when training
         */
        public GAN.Builder optimizationAlgo(OptimizationAlgorithm optimizationAlgo) {
            this.optimizationAlgo = optimizationAlgo;
            return this;
        }


        /**
         * Gradient updater configuration. For example, {@link org.nd4j.linalg.learning.config.Adam}
         * or {@link org.nd4j.linalg.learning.config.Nesterovs}<br>
         * Note: values set by this method will be applied to all applicable layers in the network, unless a different
         * value is explicitly set on a given layer. In other words: values set via this method are used as the default
         * value, and can be overridden on a per-layer basis.
         *
         * @param updater Updater to use
         */
        public GAN.Builder updater(IUpdater updater) {
            this.iUpdater = updater;
            return this;
        }

        /**
         * Gradient updater configuration, for the biases only. If not set, biases will use the updater as
         * set by {@link #updater(IUpdater)}<br>
         * Note: values set by this method will be applied to all applicable layers in the network, unless a different
         * value is explicitly set on a given layer. In other words: values set via this method are used as the default
         * value, and can be overridden on a per-layer basis.
         *
         * @param updater Updater to use for bias parameters
         */
        public GAN.Builder biasUpdater(IUpdater updater) {
            this.biasUpdater = updater;
            return this;
        }

        /**
         * Gradient normalization strategy. Used to specify gradient renormalization, gradient clipping etc.
         * See {@link GradientNormalization} for details<br>
         * Note: values set by this method will be applied to all applicable layers in the network, unless a different
         * value is explicitly set on a given layer. In other words: values set via this method are used as the default
         * value, and can be overridden on a per-layer basis.
         *
         * @param gradientNormalization Type of normalization to use. Defaults to None.
         * @see GradientNormalization
         */
        public GAN.Builder gradientNormalization(GradientNormalization gradientNormalization) {
            this.gradientNormalization = gradientNormalization;
            return this;
        }

        /**
         * Threshold for gradient normalization, only used for GradientNormalization.ClipL2PerLayer,
         * GradientNormalization.ClipL2PerParamType, and GradientNormalization.ClipElementWiseAbsoluteValue<br>
         * Not used otherwise.<br>
         * L2 threshold for first two types of clipping, or absolute value threshold for last type of clipping.<br>
         * Note: values set by this method will be applied to all applicable layers in the network, unless a different
         * value is explicitly set on a given layer. In other words: values set via this method are used as the default
         * value, and can be overridden on a per-layer basis.
         */
        public GAN.Builder gradientNormalizationThreshold(double threshold) {
            this.gradientNormalizationThreshold = threshold;
            return this;
        }

        public GAN build() {
            return new GAN(this);
        }

    }

}
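For reference, a minimal usage sketch of the GAN.Builder API defined above (not part of this commit; buildGenerator() and buildDiscriminator(IUpdater) are hypothetical factory methods returning MultiLayerNetwork instances, in the style of MnistSimpleGAN further down):

    // Hypothetical sketch, same imports as the GAN class above plus MnistDataSetIterator and Adam.
    GAN gan = new GAN.Builder()
        .generator(MyNets::buildGenerator)          // Supplier<MultiLayerNetwork>
        .discriminator(MyNets::buildDiscriminator)  // DiscriminatorProvider, receives the updater
        .latentDimension(100)
        .updater(Adam.builder().learningRate(0.0002).beta1(0.5).build())
        .seed(42)
        .build();

    // fit() alternates discriminator and generator updates as implemented in GAN.fit(DataSet) above.
    DataSetIterator mnist = new MnistDataSetIterator(128, true, 42);
    gan.fit(mnist, 5);

    // Draw samples from the trained generator part.
    INDArray latent = Nd4j.rand(new int[]{9, 100});
    INDArray fakes = gan.getGenerator().output(latent, false);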
@@ -0,0 +1,73 @@
/*
 *
 * ******************************************************************************
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 *
 */

package net.brutex.gan;

import org.nd4j.linalg.api.ndarray.INDArray;

import javax.swing.*;
import java.awt.*;
import java.awt.image.BufferedImage;

public class GANVisualizationUtils {

    public static JFrame initFrame() {
        JFrame frame = new JFrame();
        frame.setTitle("Viz");
        frame.setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE);
        frame.setLayout(new BorderLayout());
        return frame;
    }

    public static JPanel initPanel(JFrame frame, int numSamples) {
        JPanel panel = new JPanel();

        panel.setLayout(new GridLayout(numSamples / 3, 1, 8, 8));
        frame.add(panel, BorderLayout.CENTER);
        frame.setVisible(true);
        return panel;
    }

    public static void visualize(INDArray[] samples, JFrame frame, JPanel panel) {
        panel.removeAll();

        for (int i = 0; i < samples.length; i++) {
            panel.add(getImage(samples[i]));
        }

        frame.revalidate();
        frame.pack();
    }

    private static JLabel getImage(INDArray tensor) {
        BufferedImage bi = new BufferedImage(28, 28, BufferedImage.TYPE_BYTE_GRAY);
        for (int i = 0; i < 784; i++) {
            int pixel = (int) (((tensor.getDouble(i) + 1) * 2) * 255);
            bi.getRaster().setSample(i % 28, i / 28, 0, pixel);
        }
        ImageIcon orig = new ImageIcon(bi);
        Image imageScaled = orig.getImage().getScaledInstance((8 * 28), (8 * 28), Image.SCALE_REPLICATE);

        ImageIcon scaled = new ImageIcon(imageScaled);

        return new JLabel(scaled);
    }
}
@@ -0,0 +1,193 @@
/*
 *
 * ******************************************************************************
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 *
 */

package net.brutex.gan;

import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor;
import org.deeplearning4j.nn.conf.preprocessor.FeedForwardToCnnPreProcessor;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.PerformanceListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.RmsProp;
import org.nd4j.linalg.lossfunctions.LossFunctions;

import javax.swing.*;
import java.awt.*;
import java.awt.image.BufferedImage;
import java.util.function.Supplier;


/**
 * Training and visualizing a deep convolutional generative adversarial network (DCGAN) on handwritten digits.
 *
 * @author Max Pumperla, wmeddie
 */
public class MnistDCGANExample {

    private static JFrame frame;
    private static JPanel panel;

    private static final int latentDim = 100;
    private static final int height = 28;
    private static final int width = 28;
    private static final int channels = 1;


    private static void visualize(INDArray[] samples) {
        if (frame == null) {
            frame = new JFrame();
            frame.setTitle("Viz");
            frame.setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE);
            frame.setLayout(new BorderLayout());

            panel = new JPanel();

            panel.setLayout(new GridLayout(samples.length / 3, 1, 8, 8));
            frame.add(panel, BorderLayout.CENTER);
            frame.setVisible(true);
        }

        panel.removeAll();

        for (int i = 0; i < samples.length; i++) {
            panel.add(getImage(samples[i]));
        }

        frame.revalidate();
        frame.pack();
    }

    private static JLabel getImage(INDArray tensor) {
        BufferedImage bi = new BufferedImage(28, 28, BufferedImage.TYPE_BYTE_GRAY);
        for (int i = 0; i < 784; i++) {
            int pixel = (int) (((tensor.getDouble(i) + 1) * 2) * 255);
            bi.getRaster().setSample(i % 28, i / 28, 0, pixel);
        }
        ImageIcon orig = new ImageIcon(bi);
        Image imageScaled = orig.getImage().getScaledInstance((8 * 28), (8 * 28), Image.SCALE_REPLICATE);

        ImageIcon scaled = new ImageIcon(imageScaled);

        return new JLabel(scaled);
    }

    public static void main(String[] args) throws Exception {
        Supplier<MultiLayerNetwork> genSupplier = () -> {
            return new MultiLayerNetwork(new NeuralNetConfiguration.Builder().list()
                .layer(0, new DenseLayer.Builder().nIn(latentDim).nOut(width / 2 * height / 2 * 128)
                    .activation(Activation.LEAKYRELU).weightInit(WeightInit.NORMAL).build())
                .layer(1, new Convolution2D.Builder().nIn(128).nOut(128).kernelSize(5, 5)
                    .convolutionMode(ConvolutionMode.Same).activation(Activation.LEAKYRELU).build())
                // Up-sampling to 28x28x256
                .layer(2, new Deconvolution2D.Builder().nIn(128).nOut(128).stride(2, 2)
                    .kernelSize(5, 5).convolutionMode(ConvolutionMode.Same)
                    .activation(Activation.LEAKYRELU).build())
                .layer(3, new Convolution2D.Builder().nIn(128).nOut(128).kernelSize(5, 5)
                    .convolutionMode(ConvolutionMode.Same).activation(Activation.LEAKYRELU).build())
                .layer(4, new Convolution2D.Builder().nIn(128).nOut(128).kernelSize(5, 5)
                    .convolutionMode(ConvolutionMode.Same).activation(Activation.LEAKYRELU).build())
                .layer(5, new Convolution2D.Builder().nIn(128).nOut(channels).kernelSize(7, 7)
                    .convolutionMode(ConvolutionMode.Same).activation(Activation.LEAKYRELU).build())
                .layer(6, new ActivationLayer.Builder().activation(Activation.TANH).build())
                .inputPreProcessor(1,
                    new FeedForwardToCnnPreProcessor(height / 2, width / 2, 128))
                .inputPreProcessor(6, new CnnToFeedForwardPreProcessor(height, width, channels))
                .setInputType(InputType.feedForward(latentDim))
                .build());
        };

        GAN.DiscriminatorProvider discriminatorProvider = (updater) -> {
            return new MultiLayerNetwork(new NeuralNetConfiguration.Builder()
                .updater(new RmsProp.Builder().learningRate(0.0008).rmsDecay(1e-8).build())
                //.gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
                //.gradientNormalizationThreshold(100.0)
                .list()
                .layer(0, new Convolution2D.Builder().nIn(channels).nOut(64).kernelSize(3, 3)
                    .activation(Activation.LEAKYRELU).build())
                .layer(1, new Convolution2D.Builder().nIn(64).nOut(64).kernelSize(3, 3).stride(2, 2)
                    .activation(Activation.LEAKYRELU).build())
                .layer(2, new Convolution2D.Builder().nIn(64).nOut(64).kernelSize(3, 3).stride(2, 2)
                    .activation(Activation.LEAKYRELU).build())
                .layer(3, new Convolution2D.Builder().nIn(64).nOut(64).kernelSize(3, 3).stride(2, 2)
                    .activation(Activation.LEAKYRELU).build())
                .layer(4, new DropoutLayer.Builder().dropOut(0.5).build())
                .layer(5, new DenseLayer.Builder().nIn(64 * 2 * 2).nOut(1).activation(Activation.SIGMOID).build())
                .layer(6, new LossLayer.Builder().lossFunction(LossFunctions.LossFunction.XENT).build())
                .inputPreProcessor(0, new FeedForwardToCnnPreProcessor(height, width, channels))
                .inputPreProcessor(4, new CnnToFeedForwardPreProcessor(2, 2, 64))
                .setInputType(InputType.convolutionalFlat(height, width, channels))
                .build());
        };

        GAN gan = new GAN.Builder()
            .generator(genSupplier)
            .discriminator(discriminatorProvider)
            .latentDimension(latentDim)
            //.gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
            //.gradientNormalizationThreshold(1.0)
            .updater(new RmsProp.Builder().learningRate(0.0008).rmsDecay(1e-8).build())
            .build();

        gan.getGenerator().setListeners(new PerformanceListener(1, true));
        gan.getDiscriminator().setListeners(new PerformanceListener(1, true));

        Nd4j.getMemoryManager().setAutoGcWindow(15 * 1000);

        int batchSize = 64;
        MnistDataSetIterator trainData = new MnistDataSetIterator(batchSize, true, 42);

        for (int i = 0; i < 10; i++) {
            //gan.fit(trainData, 1);

            System.out.println("Starting epoch: " + (i + 1));

            trainData.reset();
            int j = 0;
            while (trainData.hasNext()) {
                DataSet next = trainData.next();
                gan.fit(next);

                if (j % 1 == 0) {
                    System.out.println("Epoch " + (i + 1) + " iteration " + j + " Visualizing...");
                    INDArray fakeIn = Nd4j.rand(new int[]{batchSize, latentDim});

                    INDArray[] samples = new INDArray[9];
                    for (int k = 0; k < 9; k++) {
                        samples[k] = gan.getGenerator().output(fakeIn.getRow(k), false);
                    }
                    visualize(samples);
                }
                j++;
            }

            System.out.println("Finished epoch: " + (i + 1));
        }
    }
}
@@ -0,0 +1,146 @@
/*
 *
 * ******************************************************************************
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 *
 */

package net.brutex.gan;

import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.ActivationLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.DropoutLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.impl.ActivationLReLU;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.linalg.lossfunctions.LossFunctions;

import javax.swing.*;


/**
 * Relatively small GAN example using only Dense layers with dropout to generate handwritten
 * digits from MNIST data.
 */
public class MnistSimpleGAN {

    private static final int LATENT_DIM = 100;

    private static final double LEARNING_RATE = 0.0002;
    private static final IUpdater UPDATER_ZERO = Sgd.builder().learningRate(0.0).build();
    private static final IUpdater UPDATER = Adam.builder().learningRate(LEARNING_RATE).beta1(0.5).build();


    public static MultiLayerNetwork getGenerator() {
        MultiLayerConfiguration genConf = new NeuralNetConfiguration.Builder()
            .weightInit(WeightInit.XAVIER)
            .activation(Activation.IDENTITY)
            .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
            .gradientNormalizationThreshold(100)
            .list()
            .layer(new DenseLayer.Builder().nIn(100).nOut(256).weightInit(WeightInit.NORMAL).build())
            .layer(new ActivationLayer.Builder(new ActivationLReLU(0.2)).build())
            .layer(new DenseLayer.Builder().nIn(256).nOut(512).build())
            .layer(new ActivationLayer.Builder(new ActivationLReLU(0.2)).build())
            .layer(new DenseLayer.Builder().nIn(512).nOut(1024).build())
            .layer(new ActivationLayer.Builder(new ActivationLReLU(0.2)).build())
            .layer(new DenseLayer.Builder().nIn(1024).nOut(784).activation(Activation.TANH).build())
            .build();
        return new MultiLayerNetwork(genConf);
    }


    public static MultiLayerNetwork getDiscriminator(IUpdater updater) {
        MultiLayerConfiguration discConf = new NeuralNetConfiguration.Builder()
            .seed(42)
            .updater(updater)
            .weightInit(WeightInit.XAVIER)
            .activation(Activation.IDENTITY)
            .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
            .gradientNormalizationThreshold(100)
            .list()
            .layer(new DenseLayer.Builder().nIn(784).nOut(1024).updater(updater).build())
            .layer(new ActivationLayer.Builder(new ActivationLReLU(0.2)).build())
            .layer(new DropoutLayer.Builder(1 - 0.5).build())
            .layer(new DenseLayer.Builder().nIn(1024).nOut(512).updater(updater).build())
            .layer(new ActivationLayer.Builder(new ActivationLReLU(0.2)).build())
            .layer(new DropoutLayer.Builder(1 - 0.5).build())
            .layer(new DenseLayer.Builder().nIn(512).nOut(256).updater(updater).build())
            .layer(new ActivationLayer.Builder(new ActivationLReLU(0.2)).build())
            .layer(new DropoutLayer.Builder(1 - 0.5).build())
            .layer(new OutputLayer.Builder(LossFunctions.LossFunction.XENT).nIn(256).nOut(1)
                .activation(Activation.SIGMOID).updater(updater).build())
            .build();

        return new MultiLayerNetwork(discConf);
    }

    public static void main(String[] args) throws Exception {
        GAN gan = new GAN.Builder()
            .generator(MnistSimpleGAN::getGenerator)
            .discriminator(MnistSimpleGAN::getDiscriminator)
            .latentDimension(LATENT_DIM)
            .seed(42)
            .updater(UPDATER)
            .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
            .gradientNormalizationThreshold(100)
            .build();

        Nd4j.getMemoryManager().setAutoGcWindow(15 * 1000);

        int batchSize = 128;
        MnistDataSetIterator trainData = new MnistDataSetIterator(batchSize, true, 42);


        // Sample from latent space once to visualize progress on image generation.
        int numSamples = 9;
        JFrame frame = GANVisualizationUtils.initFrame();
        JPanel panel = GANVisualizationUtils.initPanel(frame, numSamples);

        for (int i = 0; i < 100; i++) {
            trainData.reset();
            int j = 0;
            while (trainData.hasNext()) {
                gan.fit(trainData.next());
                //gan.fit(trainData, 1);

                if (j % 10 == 0) {
                    INDArray fakeIn = Nd4j.rand(new int[]{batchSize, LATENT_DIM});
                    System.out.println("Epoch " + (i + 1) + " Iteration " + j + " Visualizing...");
                    INDArray[] samples = new INDArray[numSamples];
                    for (int k = 0; k < numSamples; k++) {
                        INDArray input = fakeIn.getRow(k);
                        samples[k] = gan.getGenerator().output(input, false);
                    }
                    GANVisualizationUtils.visualize(samples, frame, panel);
                }
                j++;
            }
        }
    }
}
@@ -20,39 +20,87 @@
 package net.brutex.spark;

+import java.io.File;
+import java.io.IOException;
+import java.net.URI;
+import java.util.EnumSet;
 import lombok.extern.slf4j.Slf4j;
+import org.apache.commons.io.FileUtils;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.CreateFlag;
+import org.apache.hadoop.fs.FSDataOutputStream;
+import org.apache.hadoop.fs.FileContext;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.FileUtil;
+import org.apache.hadoop.fs.Options.CreateOpts;
 import org.apache.spark.SparkConf;
+import org.apache.spark.SparkContext;
 import org.apache.spark.sql.SparkSession;
 import org.junit.jupiter.api.AfterAll;
 import org.junit.jupiter.api.BeforeAll;

 import java.io.Serializable;
+import org.junit.jupiter.api.Test;

 @Slf4j
 public abstract class BaseSparkSessionTest implements Serializable {

     private static SparkSession spark;

     public static SparkSession getSession() {
+        final String jarPath = uploadToHdfs("./build/libs/brutex-extended-tests-1.0.0-SNAPSHOT-all.jar");
+
         SparkConf sparkConf = new SparkConf()
             .setMaster("spark://10.5.5.200:7077")
             .setAppName(BaseSparkSessionTest.class.getSimpleName())
             .set("spark.driver.bindAddress", "10.5.5.145")
+            .set("spark.blockManager.port", "65001")
+            //.set("spark.driver.bindAddress", "0.0.0.0")
             .set("spark.network.timeout", "240000")
             .set("spark.driver.host", "10.5.5.145")
-            .set("spark.deploy.mode", "client")
+            .set("spark.deploy.mode", "cluster")
             .set("spark.executor.memory", "4g")
             .set("spark.cores.max", "4")
             .set("spark.worker.cleanup.enabled", "true")
             .set("spark.driver.extraJavaOptions", "-Dlog4j.configurationFile=log4j2.xml")
             .set("spark.executor.extraJavaOptions", "-Dlog4j.configurationFile=log4j2.xml")
-            .set("spark.hadoop.fs.defaultFS", "hdfs://10.5.5.200:9000");
+            .set("spark.hadoop.fs.defaultFS", "hdfs://10.5.5.200:9000")
+            //.set("spark.jars", jarPath)
+            ;
         spark = SparkSession.builder()
             .config(sparkConf)
             .getOrCreate();
+
+        spark.sparkContext().addJar(jarPath);
         return spark;
     }
+
+    public static String uploadToHdfs(String jarFile) {
+        File f = new File(jarFile);
+        if(!f.exists() && !f.isFile()) throw new RuntimeException("File to upload does not exist.");
+        final String base = "hdfs://10.5.5.200:9000/";
+        String targetPath = "/user/brian/" + f.getName();
+        try {
+            Configuration conf = new Configuration();
+
+            //FileContext hdfs = FileContext.getFileContext(URI.create(base), conf);
+            org.apache.hadoop.fs.FileSystem hdfs = FileSystem.get(URI.create(base), conf);
+            //String file = SparkFiles.get("phpMawTba");
+
+            org.apache.hadoop.fs.Path target = new org.apache.hadoop.fs.Path(targetPath);
+
+            try {
+                hdfs.delete(target, false);
+            } catch (Exception e) {};
+
+            FileUtil.copy(f, hdfs, target, false, conf);
+            //Apache Commons
+            //FileUtils.copyFile(f, fTarget);
+        } catch(IOException ioe) {
+            ioe.printStackTrace();
+        }
+        return base + targetPath;
+    }


     @BeforeAll
     public static void beforeAll() {
@@ -64,4 +112,11 @@ public abstract class BaseSparkSessionTest implements Serializable {
         getSession().close();

     }
+
+    @Test
+    public void testSessionCreation() {
+        SparkSession session = getSession();
+        log.info("Spark {} session id: {}", session.version(), session.sessionUUID());
+
+    }
 }

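For reference, a minimal sketch (not part of this commit) of how a concrete test class could build on BaseSparkSessionTest above; the class name and the toy data are hypothetical:

    package net.brutex.spark;

    import lombok.extern.slf4j.Slf4j;
    import org.apache.spark.api.java.JavaSparkContext;
    import org.apache.spark.sql.SparkSession;
    import org.junit.jupiter.api.Test;
    import java.util.Arrays;

    @Slf4j
    public class SimpleRddSparkTest extends BaseSparkSessionTest {

        @Test
        public void countsDistinctWords() {
            // getSession() uploads the shadow jar to HDFS and registers it with the Spark context.
            SparkSession session = getSession();
            JavaSparkContext jsc = new JavaSparkContext(session.sparkContext());
            long distinct = jsc.parallelize(Arrays.asList("a", "b", "a")).distinct().count();
            log.info("Distinct words: {}", distinct);
        }
    }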
@@ -20,22 +20,34 @@
 */
package net.brutex.spark;

-import com.fasterxml.jackson.core.Version;
import java.io.IOException;
import lombok.extern.slf4j.Slf4j;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
-import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.ForeachFunction;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StringType;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.transform.TransformProcess;
import org.datavec.api.transform.filter.FilterInvalidValues;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.Writable;
import org.datavec.spark.transform.Normalization;
import org.datavec.spark.transform.SparkTransformExecutor;
import org.datavec.spark.transform.misc.StringToWritablesFunction;
import org.deeplearning4j.datasets.iterator.impl.EmnistDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.EmnistDataSetIterator.Set;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;

@@ -47,7 +59,6 @@ import org.deeplearning4j.spark.api.TrainingMaster;
import org.deeplearning4j.spark.datavec.DataVecDataSetFunction;
import org.deeplearning4j.spark.impl.multilayer.SparkDl4jMultiLayer;
import org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster;
-import org.deeplearning4j.ui.api.UIServer;
import org.junit.jupiter.api.*;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;

@@ -56,7 +67,6 @@ import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions;

-import java.nio.file.Paths;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
@@ -70,24 +80,77 @@ import java.util.Random;
@Slf4j
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
@Tag("integration")
-public class BrianTest /*extends BaseDL4JTest*/ {
+public class BrianTest extends BaseSparkSessionTest {
  /*
  static {
    String OS = System.getProperty("os.name").toLowerCase();

    if (OS.contains("win")) {
      System.setProperty("hadoop.home.dir",
          Paths.get("c:\\java\\winutils").toAbsolutePath().toString());
    } else {
      System.setProperty("hadoop.home.dir", "/");
    }
  }
  */
-  public long getTimeoutMilliseconds() {
-    return 400000L;
-  }

  private JavaSparkContext sc;
  private JavaRDD<String> rdd;

  @Test
  public void wrapEmnitDataset() throws IOException, InterruptedException {
    SparkSession sc = getSession();
    EmnistDataSetIterator dataset = new EmnistDataSetIterator(Set.BALANCED, 128, true);
    DataSet ds = dataset.next();
    System.out.println("Number of features " + ds.numInputs());
    System.out.println("Number of samples " + ds.numExamples());
    System.out.println("Outcomes " + ds.numOutcomes());
    final String oppsFile = uploadToHdfs("c:/temp/opps.csv");

    //System.out.println( "Reading file from " + oppsFile);

    JavaRDD<String> rdd = sc.sparkContext().textFile(oppsFile, 1)
        .toJavaRDD();
    System.out.println("Count " + rdd.count());
    //while(true) Thread.sleep(1000);

    //rdd.foreach( s -> {
    //  System.out.println("* "+s);
    //  });

    //JavaRDD<String> rdd2 = rdd.flatMap( s -> Arrays.asList( s.split(";")).iterator() );
    //rdd2.collect().forEach( a -> System.out.print("# " + a + " ") );

    StructType struct = new StructType(Arrays.asList(
        StructField.apply("stage", DataTypes.StringType, false, Metadata.empty()),
        StructField.apply("period", DataTypes.StringType, false, Metadata.empty()),
        StructField.apply("portfolio", DataTypes.StringType, false, Metadata.empty()),
        StructField.apply("country", DataTypes.StringType, false, Metadata.empty()),
        StructField.apply("lfr", DataTypes.StringType, false, Metadata.empty()),
        StructField.apply("saas", DataTypes.StringType, false, Metadata.empty())
      ).toArray(new StructField[]{})
    );
    JavaRDD<Row> rdd3 = rdd.map( attributes -> RowFactory.create(attributes.split(";")));

    Dataset<Row> frame = sc.createDataFrame(rdd3, struct);
    Dataset<Row> frame2 = frame.select(frame.col("lfr").cast(DataTypes.FloatType));
    frame.show(200);

    // frame.collect().map(row -> System.out.println(row.fieldIndex("stage") + row.fieldIndex("country")));

    //frame.agg( frame.col("stage"), frame.col("lfr"));
    frame.foreach((ForeachFunction<Row>) s -> System.out.println(s));

    //sc.read().csv(rdd2);
    //Normalization normalization = Normalization.zeromeanUnitVariance()
    //sc.

  }

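  /*
   * Illustrative sketch only (not part of this commit): the same StructType can also be applied
   * when reading the CSV through the DataFrameReader instead of mapping an RDD<String> to Rows by
   * hand. The HDFS path below is a hypothetical example.
   */
  private Dataset<Row> readOppsWithSchema(SparkSession session, StructType schema) {
    return session.read()
        .schema(schema)               // reuse the explicit column definitions
        .option("delimiter", ";")     // the file is semicolon-separated
        .csv("hdfs://10.5.5.200:9000/user/brian/opps.csv");
  }
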
  /*
  @BeforeAll
  public void loadData() {
@@ -109,71 +172,6 @@ public class BrianTest /*extends BaseDL4JTest*/ {
  }
  */

-  @BeforeAll
-  public void setUp() throws Exception {
-    log.info("Running @BeforeEach scope");
-    System.setProperty("hadoop.home.dir", Paths.get("c:\\java\\winutils").toAbsolutePath().toString());
-    Version version = com.fasterxml.jackson.databind.cfg.PackageVersion.VERSION;
-    System.out.println("Jackson version found: " + version);
-    SparkConf sparkConf = new SparkConf()
-        .setMaster("spark://10.5.5.200:7077")
-        .setAppName("Brian3")
-        .set("spark.driver.bindAddress", "10.5.5.145")
-        .set("spark.network.timeout", "240000")
-        .set("spark.driver.host", "10.5.5.145")
-        .set("spark.driver.bindAddress", "10.5.5.145")
-        .set("spark.deploy.mode", "cluster")
-        .set("spark.executor.memory", "2g")
-        .set("spark.executor.cores", "2")
-        .set("spark.cores.max", "4")
-        .set("spark.worker.cleanup.enabled", "false")
-        .set("spark.driver.extraJavaOptions", "-Dlog4j.configurationFile=log4j2.xml")
-        .set("spark.executor.extraJavaOptions", "-Dlog4j.configurationFile=log4j2.xml")
-        .set("spark.driver.extraClassPath", "brian-spark-dl4j-tests-1.0.1-SNAPSHOT-bin.jar;brian-spark-dl4j-tests-1.0.1-SNAPSHOT-tests.jar")
-        .set("spark.executor.extraClassPath", "brian-spark-dl4j-tests-1.0.1-SNAPSHOT-bin.jar;brian-spark-dl4j-tests-1.0.1-SNAPSHOT-tests.jar")
-        .set("spark.hadoop.fs.defaultFS", "hdfs://10.5.5.200:9000");
-        //.set("spark.driver.cores", "2")
-        //.set("spark.driver.memory", "8g")
-        //.set("spark.driver.host", "10.5.5.145")
-        //.setExecutorEnv("spark.executor.cores", "2")
-        //.setExecutorEnv("spark.executor.memory", "2g")
-        //.set("spark.submit.deployMode", "client")
-
-    /*
-    SparkSession spark = SparkSession
-        .builder()
-        .master("spark://10.5.5.200:7077")
-        .config("spark.driver.bindAddress", "10.5.5.145")
-        .config("spark.driver.host", "10.5.5.145")
-        //.config("spark.driver.memory", "5g")
-        .appName("BrianTest2")
-        .getOrCreate();
-    */
-    sc = new JavaSparkContext(sparkConf);
-
-    // sc.addJar("C:\\Users\\brian\\_projects\\deeplearning4j\\deeplearning4j\\deeplearning4j-scaleout\\spark\\dl4j-spark-nlp-java8\\target\\dl4j-spark-nlp-java8_2.12-1.0.0-SNAPSHOT-tests.jar");
-    // sc.addJar("C:\\Users\\brian\\_projects\\deeplearning4j\\datavec\\datavec-api\\target\\datavec-api-1.0.0-SNAPSHOT.jar");
-    // sc.addJar("C:\\Users\\brian\\_projects\\deeplearning4j\\nd4j\\nd4j-uberjar\\target\\nd4j-uberjar-1.0.0-SNAPSHOT.jar");
-    // sc.addJar("C:\\Users\\brian\\_projects\\deeplearning4j\\nd4j\\nd4j-common\\target\\nd4j-common-1.0.0-SNAPSHOT.jar");
-    // sc.addJar("C:\\Users\\brian\\_projects\\deeplearning4j\\datavec\\datavec-spark\\target\\datavec-spark_2.12-1.0.0-SNAPSHOT.jar");
-    sc.addJar("C:\\Users\\brian\\_projects\\Brian-Spark-DL4J-Tests\\target\\brian-spark-dl4j-tests-1.0.1-SNAPSHOT-bin.jar");
-    sc.addJar("C:\\Users\\brian\\_projects\\Brian-Spark-DL4J-Tests\\target\\brian-spark-dl4j-tests-1.0.1-SNAPSHOT-tests.jar");
-
-    rdd = sc.textFile("hdfs://10.5.5.200:9000/user/zeppelin/cities_full.csv.gz");
-
-  }
-
-  @AfterAll
-  public void tearDown() throws Exception {
-    sc.close();
-    sc.stop();
-    UIServer.stopInstance();
-  }

  @Test
  ////@Ignore("AB 2019/05/21 - Failing - Issue #7657")

@@ -193,25 +191,23 @@ public class BrianTest /*extends BaseDL4JTest*/ {
  @Test
  public void testSchemaCreation() throws Exception {

    rdd.cache();

    JavaRDD<String> cities = rdd.map((Function<String, String>) line -> {
      return line.split(",")[1];
    }).cache();

    JavaRDD<String> stateCodeList = rdd.map((Function<String, String>) line -> {
      return line.split(",")[2];
    }).cache();

    JavaRDD<String> countryCodeList = rdd.map((Function<String, String>) line -> {
      return line.split(",")[3];
    }).cache();

    CSVRecordReader recordReader = new CSVRecordReader(0, ',');
    JavaRDD<List<Writable>> convertedRDD = rdd.map((Function<String, List<Writable>>) s -> {
      return new StringToWritablesFunction(recordReader).call(s);
    });

    //Source Schema
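    /*
     * Illustrative sketch only (not part of this commit): the elided portion of this test builds a
     * "processedData" RDD from convertedRDD; a typical route is a DataVec TransformProcess executed
     * on Spark. The schema and filter below are hypothetical, not the author's.
     */
    Schema exampleSchema = new Schema.Builder()
        .addColumnsString("id", "city", "state_code", "country_code")
        .build();
    TransformProcess exampleTp = new TransformProcess.Builder(exampleSchema)
        .filter(new FilterInvalidValues())   // drop rows with missing or invalid values
        .build();
    JavaRDD<List<Writable>> exampleProcessed = SparkTransformExecutor.execute(convertedRDD, exampleTp);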
@@ -252,16 +248,18 @@ public class BrianTest /*extends BaseDL4JTest*/ {
    processedData.cache();
    //log.info("Datenmenge nach processing: " + processedData.count());


    //Vectorize
    int labelIndex = 0;  //which column holds the label
    int numLabels = 4;   //number of classes, 0-236 = 237 values

    DataVecDataSetFunction datavecFunction = new DataVecDataSetFunction(labelIndex, numLabels,
        false);
    JavaRDD<DataSet> rddDataSet = processedData.map(datavecFunction);
    log.info("rddDataset: " + rddDataSet.toDebugString());
    Random rand = new Random();
    rddDataSet.sortBy((Function<DataSet, Double>) s -> {
      return rand.nextDouble();
    }, true, 8);

    //log.info("Sample: " + rddDataSet.sample(false, 0.005, 0).collect());

@@ -281,7 +279,8 @@ public class BrianTest /*extends BaseDL4JTest*/ {
    //Create Trainingmaster
    TrainingMaster trainingMaster = new ParameterAveragingTrainingMaster.Builder(4)
        .rddTrainingApproach(RDDTrainingApproach.Direct) //when "export", tries to save everything first
        .batchSizePerWorker(1000)
        .collectTrainingStats(true)
        .build();
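    /*
     * Illustrative alternative, not part of this commit: with the Export approach the DataSet RDD
     * is first written out and training then streams from those files, trading extra I/O for lower
     * memory pressure on the workers.
     */
    TrainingMaster exportTrainingMaster = new ParameterAveragingTrainingMaster.Builder(4)
        .rddTrainingApproach(RDDTrainingApproach.Export) // save first, then fit from the export
        .batchSizePerWorker(1000)
        .build();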
@@ -292,23 +291,26 @@ public class BrianTest /*extends BaseDL4JTest*/ {
        .seed(123)
        .updater(new Nesterovs(0.1, 0.9))
        .list()
        .layer(0, new DenseLayer.Builder().nIn(5).nOut(20).weightInit(WeightInit.XAVIER)
            .activation(Activation.RELU).l2(0.001).build())
        .layer(1, new DenseLayer.Builder().nIn(20).nOut(20).weightInit(WeightInit.XAVIER)
            .activation(Activation.RELU).build())
        //.layer(2, new DenseLayer.Builder().nIn(9).nOut(9).weightInit(WeightInit.XAVIER).activation(Activation.RELU).build())
        .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.XENT).nIn(20).nOut(4)
            .weightInit(WeightInit.XAVIER).activation(Activation.SIGMOID).build())
        .build();

    //Define SparkNet
    SparkDl4jMultiLayer sparkNet = new SparkDl4jMultiLayer(sc, multiLayerConfiguration,
        trainingMaster);

    JavaRDD<DataSet>[] split = rddDataSet.randomSplit(new double[]{0.9, 0.1}, 123);
    //JavaRDD<DataSet> trainingData = split[0];
    JavaRDD<DataSet> trainingData = rddDataSet;
    JavaRDD<DataSet> testData = split[1];

    //Run Training on subset
    for (int i = 0; i < 20; i++) {
      sparkNet.fit(trainingData);
    }

@@ -326,7 +328,7 @@ public class BrianTest /*extends BaseDL4JTest*/ {
    Evaluation eval = new Evaluation(4); // outputNum = 4: number of output classes
    Iterator<DataSet> iter = testData.toLocalIterator();
    log.info("testData has " + testData.count() + " DataSets");
    while (iter.hasNext()) {
      DataSet next = iter.next();
      //log.info("getFeatures " + next.getFeatures() );
      INDArray output = finalNet.output(next.getFeatures()); //get the network's prediction
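      /*
       * Generic usage sketch, not the author's elided code: an Evaluation instance is normally fed
       * label/prediction pairs inside this loop, e.g. eval.eval(next.getLabels(), output); and its
       * summary is printed afterwards with log.info(eval.stats());
       */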
@@ -25,8 +25,8 @@ ext {

    def flatbuffers = [version: "1.10.0"]

-   def spark = [version: "3.1.2"]
+   def spark = [version: "3.2.2"]
-   def scala = [version:"2.12.10"] //[version:"2.13.5"]
+   def scala = [version:"2.12.15"] //[version:"2.13.5"]

    def netty = [version: "4.1.68.Final"]

@@ -21,59 +21,44 @@

package net.brutex.cavis.dvec.api;

import java.nio.Buffer;
import java.nio.ByteBuffer;
import net.brutex.cavis.dvec.api.exceptions.DVecException;

/**
 * Abstract implementation of the Field interface {@see FieldInterface}, that handles all data storage
 * in memory and adds basic error handling.
 *
 * @author Brian Rosenberger
 * @since 1.0
 */
public abstract class AbstractField<T extends Buffer> implements Field<T> {

  /**
   * {@inheritDoc}
   *
   * @param start Index of starting position, zero based
   * @param length how many fields to read
   * @return the list of Buffer
   */
  @Override
  public T read(long start, long length) throws DVecException {
    if (start < 0 || start > internalStorage.capacity() - 1) {
      throw new DVecException("Read on Field start position is out of bounds.");
    }
    if (start + length > internalStorage.capacity()) {
      throw new DVecException("Read on Field exceeds field length");
    }
    return null;
  }

  @Override
  public void write(long pos, T buffer) {

  }

  private ByteBuffer internalStorage = null;

}
@@ -21,46 +21,57 @@

package net.brutex.cavis.dvec.api;

import java.io.Serializable;
import java.nio.Buffer;
import net.brutex.cavis.dvec.api.exceptions.DVecException;

/**
 * A Field can be considered a "column" in a {@code Record}, as such a Field may refer to multiple
 * entries of that "column". Fields are typed as Buffers. Some of them are defined in the dvec core
 * api, others (e.g. Image or Arrow) require dvec extensions accordingly.
 *
 * @author Brian Rosenberger
 * @since 1.0
 */
public interface Field<T extends Buffer> extends Serializable {

  /**
   * Get a reference to the metadata for this Field.
   *
   * @return the {@link FieldMetadata}
   */
  FieldMetadata getFieldMetadata();

  /**
   * Get the 1st field as Buffer. This deserializes the data from the underlying storage.
   *
   * @return T underlying Buffer
   */
  default T read() throws DVecException {
    return read(0, 1);
  }

  /**
   * Get a range of fields as a {@code Buffer}
   *
   * @param start Index of starting position, zero based
   * @param length how many fields to read
   * @return the buffers
   */
  T read(long start, long length) throws DVecException;

  /**
   * Write the data into the underlying storage.
   */
  default void write(T buffer) {
    write(0, buffer);
  }

  /**
   * Write the data into the underlying storage starting at a position
   *
   * @param pos the position to start
   */
  void write(long pos, T buffer);

}
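/*
 * Illustrative sketch only (not part of this commit): a minimal in-memory Field built against the
 * interface above. The class name, the fixed capacity and the null metadata are hypothetical
 * placeholders; a real implementation would supply proper FieldMetadata.
 */
class ExampleFloatField implements Field<java.nio.FloatBuffer> {

  private final java.nio.FloatBuffer storage = java.nio.FloatBuffer.allocate(16);

  @Override
  public FieldMetadata getFieldMetadata() {
    return null; // placeholder only, see note above
  }

  @Override
  public java.nio.FloatBuffer read(long start, long length) throws DVecException {
    if (start < 0 || start + length > storage.capacity()) {
      throw new DVecException("Requested range is out of bounds.");
    }
    java.nio.FloatBuffer view = storage.duplicate(); // independent position/limit
    view.position((int) start);
    view.limit((int) (start + length));
    return view.slice();                             // window over the requested entries
  }

  @Override
  public void write(long pos, java.nio.FloatBuffer buffer) {
    java.nio.FloatBuffer view = storage.duplicate();
    view.position((int) pos);
    view.put(buffer);                                // copy the incoming data at the offset
  }
}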
@@ -4877,7 +4877,7 @@ public class Nd4j {
     * Create an ndarray of zeros
     *
     * @param shape the shape of the array
-    * @return an ndarray with ones filled in
+    * @return an ndarray with zeros filled in
     */
    public static INDArray zeros(int[] shape, char order) {
        checkShapeValues(shape);

@@ -4896,7 +4896,7 @@ public class Nd4j {
     * Create an ndarray of zeros
     *
     * @param shape the shape of the array
-    * @return an ndarray with ones filled in
+    * @return an ndarray with zeros filled in
     */
    public static INDArray zeros(@NonNull int... shape) {
        return Nd4j.create(shape);

@@ -4907,7 +4907,7 @@ public class Nd4j {
     * Create an ndarray of zeros
     *
     * @param shape the shape of the array
-    * @return an ndarray with ones filled in
+    * @return an ndarray with zeros filled in
     */
    public static INDArray zeros(@NonNull long... shape) {
        return Nd4j.create(shape);
@@ -99,9 +99,14 @@ public class DL4JClassLoading {
                    .asSubclass(superclass)
                    .getDeclaredConstructor(parameterTypes)
                    .newInstance(args);
-        } catch (InstantiationException | IllegalAccessException | InvocationTargetException
+        } catch (InstantiationException | IllegalAccessException
                | NoSuchMethodException instantiationException) {
            log.error(String.format("Cannot create instance of class '%s'.", className), instantiationException);
            throw new RuntimeException(instantiationException);
        } catch (InvocationTargetException instantiationException) {
            log.error(String.format("---------- ----------- ---------- \nInvocationTargetException was '%s'.", instantiationException.getTargetException().getMessage()), instantiationException);
            log.error(String.format("java.library.path was '%s'\n---------- ---------- ----------", System.getProperty("java.library.path")));
            throw new RuntimeException(instantiationException);
        }
    }
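/*
 * Illustrative sketch only (not part of this commit): why InvocationTargetException gets its own
 * catch block above. A reflective newInstance() call wraps whatever the constructor threw, so
 * getTargetException() carries the message that is actually useful for diagnosis (for the cuDNN
 * helpers this is typically a missing native library on java.library.path).
 */
class ReflectiveCauseExample {
    static class Failing {
        Failing() { throw new UnsatisfiedLinkError("no cudnn in java.library.path"); }
    }

    public static void main(String[] args) throws Exception {
        try {
            Failing.class.getDeclaredConstructor().newInstance();
        } catch (java.lang.reflect.InvocationTargetException e) {
            // The wrapper itself says little; the wrapped target is the real story.
            System.err.println("Target exception: " + e.getTargetException().getMessage());
        }
    }
}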
@@ -0,0 +1,23 @@

apply from: "${project.rootProject.projectDir}/createTestBackends.gradle"

ext {
    buildTarget = rootProject.ext.buildTarget
}

dependencies {
    implementation platform(projects.cavisCommonPlatform)
    implementation project(":cavis-native:cavis-native-jcublas")
    implementation projects.cavisDnn.cavisDnnApi
    implementation projects.cavisDnn.cavisDnnNn

    implementation group: "org.bytedeco", name: "cuda"
    implementation group: "org.bytedeco", name: "cuda", classifier: buildTarget
    implementation group: "org.bytedeco", name: "cuda", classifier: "${buildTarget}-redist"

    implementation group: "org.bytedeco", name: "javacpp"
    implementation group: "org.bytedeco", name: "javacpp", classifier: buildTarget

    implementation 'com.jakewharton.byteunits:byteunits:0.9.1'

}
@@ -0,0 +1,252 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.deeplearning4j.cuda;

import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.bytedeco.javacpp.*;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.factory.Nd4j;

import org.bytedeco.cuda.cudnn.*;
import static org.bytedeco.cuda.global.cudart.*;
import static org.bytedeco.cuda.global.cudnn.*;

/**
 * Functionality shared by all cuDNN-based helpers.
 *
 * @author saudet
 */
@Slf4j
public abstract class BaseCudnnHelper {

    /* public BaseCudnnHelper() {

    }
    */

    protected static void checkCuda(int error) {
        if (error != cudaSuccess) {
            throw new RuntimeException("CUDA error = " + error + ": " + cudaGetErrorString(error).getString());
        }
    }

    protected static void checkCudnn(int status) {
        if (status != CUDNN_STATUS_SUCCESS) {
            throw new RuntimeException("cuDNN status = " + status + ": " + cudnnGetErrorString(status).getString());
        }
    }

    protected static class CudnnContext extends cudnnContext {

        protected static class Deallocator extends CudnnContext implements Pointer.Deallocator {
            Deallocator(CudnnContext c) {
                super(c);
            }

            @Override
            public void deallocate() {
                destroyHandles();
            }
        }

        public CudnnContext() {
            // ensure that cuDNN initializes on the same device as ND4J for this thread
            Nd4j.create(1);
            AtomicAllocator.getInstance();
            // This needs to be called in subclasses:
            // createHandles();
            // deallocator(new Deallocator(this));
        }

        public CudnnContext(CudnnContext c) {
            super(c);
        }

        protected void createHandles() {
            checkCudnn(cudnnCreate(this));
        }

        protected void destroyHandles() {
            checkCudnn(cudnnDestroy(this));
        }
    }

    protected static class DataCache extends Pointer {

        static class Deallocator extends DataCache implements Pointer.Deallocator {
            Deallocator(DataCache c) {
                super(c);
            }

            @Override
            public void deallocate() {
                checkCuda(cudaFree(this));
                setNull();
            }
        }

        static class HostDeallocator extends DataCache implements Pointer.Deallocator {
            HostDeallocator(DataCache c) {
                super(c);
            }

            @Override
            public void deallocate() {
                checkCuda(cudaFreeHost(this));
                setNull();
            }
        }

        public DataCache() {}

        public DataCache(long size) {
            position = 0;
            limit = capacity = size;
            int error = cudaMalloc(this, size);
            if (error != cudaSuccess) {
                log.warn("Cannot allocate " + size + " bytes of device memory (CUDA error = " + error
                        + "), proceeding with host memory");
                checkCuda(cudaMallocHost(this, size));
                deallocator(new HostDeallocator(this));
            } else {
                deallocator(new Deallocator(this));
            }
        }

        public DataCache(DataCache c) {
            super(c);
        }
    }

    protected static class TensorArray extends PointerPointer<cudnnTensorStruct> {

        static class Deallocator extends TensorArray implements Pointer.Deallocator {
            Pointer owner;

            Deallocator(TensorArray a, Pointer owner) {
                this.address = a.address;
                this.capacity = a.capacity;
                this.owner = owner;
            }

            @Override
            public void deallocate() {
                for (int i = 0; !isNull() && i < capacity; i++) {
                    cudnnTensorStruct t = this.get(cudnnTensorStruct.class, i);
                    checkCudnn(cudnnDestroyTensorDescriptor(t));
                }
                if (owner != null) {
                    owner.deallocate();
                    owner = null;
                }
                setNull();
            }
        }

        public TensorArray() {}

        public TensorArray(long size) {
            PointerPointer p = new PointerPointer(size);
            p.deallocate(false);
            this.address = p.address();
            this.limit = p.limit();
            this.capacity = p.capacity();

            cudnnTensorStruct t = new cudnnTensorStruct();
            for (int i = 0; i < capacity; i++) {
                checkCudnn(cudnnCreateTensorDescriptor(t));
                this.put(i, t);
            }
            deallocator(new Deallocator(this, p));
        }

        public TensorArray(TensorArray a) {
            super(a);
        }
    }

    protected final DataType nd4jDataType;
    protected final int dataType;
    protected final int dataTypeSize;
    // both CUDNN_DATA_HALF and CUDNN_DATA_FLOAT need a float value for alpha and beta
    protected final Pointer alpha;
    protected final Pointer beta;
    protected SizeTPointer sizeInBytes = new SizeTPointer(1);

    public BaseCudnnHelper(@NonNull DataType dataType) {
        this.nd4jDataType = dataType;
        this.dataType = dataType == DataType.DOUBLE ? CUDNN_DATA_DOUBLE
                : dataType == DataType.FLOAT ? CUDNN_DATA_FLOAT : CUDNN_DATA_HALF;
        this.dataTypeSize = dataType == DataType.DOUBLE ? 8 : dataType == DataType.FLOAT ? 4 : 2;
        // both CUDNN_DATA_HALF and CUDNN_DATA_FLOAT need a float value for alpha and beta
        this.alpha = this.dataType == CUDNN_DATA_DOUBLE ? new DoublePointer(1.0) : new FloatPointer(1.0f);
        this.beta = this.dataType == CUDNN_DATA_DOUBLE ? new DoublePointer(0.0) : new FloatPointer(0.0f);
    }

    public static int toCudnnDataType(DataType type) {
        switch (type) {
            case DOUBLE:
                return CUDNN_DATA_DOUBLE;
            case FLOAT:
                return CUDNN_DATA_FLOAT;
            case INT:
                return CUDNN_DATA_INT32;
            case HALF:
                return CUDNN_DATA_HALF;
            default:
                throw new RuntimeException("Cannot convert type: " + type);
        }
    }

    public boolean checkSupported() {
        // add general checks here, if any
        return true;
    }


    /**
     * From CuDNN documentation -
     * "Tensors are restricted to having at least 4 dimensions... When working with lower dimensional data, it is
     * recommended that the user create a 4D tensor, and set the size along unused dimensions to 1."
     *
     * This method implements that - basically appends 1s to the end (shape or stride) to make it length 4,
     * or leaves it unmodified if the length is already 4 or more.
     * This method can be used for both shape and strides
     *
     * @param shapeOrStrides
     * @return
     */
    protected static int[] adaptForTensorDescr(int[] shapeOrStrides) {
        if (shapeOrStrides.length >= 4)
            return shapeOrStrides;
        int[] out = new int[4];
        int i = 0;
        for (; i < shapeOrStrides.length; i++) {
            out[i] = shapeOrStrides[i];
        }
        for (; i < 4; i++) {
            out[i] = 1;
        }
        return out;
    }
}
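/*
 * Illustrative sketch only (not part of this commit): the effect of adaptForTensorDescr() when a
 * low-rank shape has to be described to cuDNN, which requires at least 4 dimensions.
 */
class AdaptShapeExample {
    public static void main(String[] args) {
        int[] shape = {32, 20};                            // e.g. [miniBatch, features]
        int[] adapted = java.util.Arrays.copyOf(shape, 4); // -> {32, 20, 0, 0}
        for (int i = shape.length; i < 4; i++) {
            adapted[i] = 1;                                // pad unused dimensions with 1
        }
        // Prints [32, 20, 1, 1] - the 4D form cuDNN descriptors expect.
        System.out.println(java.util.Arrays.toString(adapted));
    }
}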
@@ -0,0 +1,758 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.deeplearning4j.cuda.convolution;

import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import com.jakewharton.byteunits.BinaryByteUnit;
import org.bytedeco.javacpp.Pointer;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer.AlgoMode;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer.BwdDataAlgo;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer.BwdFilterAlgo;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer.FwdAlgo;
import org.deeplearning4j.nn.conf.layers.PoolingType;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.cuda.BaseCudnnHelper;
import org.deeplearning4j.nn.layers.convolution.ConvolutionHelper;
import org.deeplearning4j.nn.params.ConvolutionParamInitializer;
import org.deeplearning4j.util.ConvolutionUtils;
import org.nd4j.jita.allocator.Allocator;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.conf.CudaEnvironment;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.executioner.GridExecutioner;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.jcublas.context.CudaContext;
import org.nd4j.common.primitives.Pair;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.nd4j.common.util.OneTimeLogger;

import java.util.Arrays;
import java.util.Collections;
import java.util.Map;

import org.bytedeco.cuda.cudart.*;
import org.bytedeco.cuda.cudnn.*;

import static org.bytedeco.cuda.global.cudnn.*;
import static org.nd4j.linalg.indexing.NDArrayIndex.all;
import static org.nd4j.linalg.indexing.NDArrayIndex.interval;

/**
 * cuDNN-based helper for the convolution layer.
 *
 * @author saudet
 */
@Slf4j
public class CudnnConvolutionHelper extends BaseCudnnHelper implements ConvolutionHelper {

    public CudnnConvolutionHelper(DataType dataType) {
        super(dataType);
    }

    private static class CudnnConvolutionContext extends CudnnContext {

        private static class Deallocator extends CudnnConvolutionContext implements Pointer.Deallocator {
            Deallocator(CudnnConvolutionContext c) {
                super(c);
            }

            @Override
            public void deallocate() {
                destroyHandles();
            }
        }

        private cudnnTensorStruct srcTensorDesc = new cudnnTensorStruct(), dstTensorDesc = new cudnnTensorStruct(),
                biasTensorDesc = new cudnnTensorStruct(), deltaTensorDesc = new cudnnTensorStruct();
        private cudnnFilterStruct filterDesc = new cudnnFilterStruct();
        private cudnnConvolutionStruct convDesc = new cudnnConvolutionStruct();
        private cudnnActivationStruct activationDesc = new cudnnActivationStruct();

        public CudnnConvolutionContext() {
            createHandles();
            deallocator(new Deallocator(this));
        }

        public CudnnConvolutionContext(CudnnConvolutionContext c) {
            super(c);
            srcTensorDesc = new cudnnTensorStruct(c.srcTensorDesc);
            dstTensorDesc = new cudnnTensorStruct(c.dstTensorDesc);
            biasTensorDesc = new cudnnTensorStruct(c.biasTensorDesc);
            deltaTensorDesc = new cudnnTensorStruct(c.deltaTensorDesc);
            filterDesc = new cudnnFilterStruct(c.filterDesc);
            convDesc = new cudnnConvolutionStruct(c.convDesc);
            activationDesc = new cudnnActivationStruct(c.activationDesc);
        }

        @Override
        protected void createHandles() {
            super.createHandles();
            checkCudnn(cudnnCreateTensorDescriptor(srcTensorDesc));
            checkCudnn(cudnnCreateTensorDescriptor(dstTensorDesc));
            checkCudnn(cudnnCreateTensorDescriptor(biasTensorDesc));
            checkCudnn(cudnnCreateTensorDescriptor(deltaTensorDesc));
            checkCudnn(cudnnCreateFilterDescriptor(filterDesc));
            checkCudnn(cudnnCreateConvolutionDescriptor(convDesc));
            checkCudnn(cudnnCreateActivationDescriptor(activationDesc));
        }

        @Override
        protected void destroyHandles() {
            checkCudnn(cudnnDestroyActivationDescriptor(activationDesc));
            checkCudnn(cudnnDestroyConvolutionDescriptor(convDesc));
            checkCudnn(cudnnDestroyFilterDescriptor(filterDesc));
            checkCudnn(cudnnDestroyTensorDescriptor(srcTensorDesc));
            checkCudnn(cudnnDestroyTensorDescriptor(dstTensorDesc));
            checkCudnn(cudnnDestroyTensorDescriptor(biasTensorDesc));
            checkCudnn(cudnnDestroyTensorDescriptor(deltaTensorDesc));
            super.destroyHandles();
        }
    }

    private CudnnConvolutionContext cudnnContext = new CudnnConvolutionContext();

    @Override
    public Pair<Gradient, INDArray> backpropGradient(INDArray input, INDArray weights, INDArray bias, INDArray delta, int[] kernel,
                    int[] strides, int[] pad, INDArray biasGradView, INDArray weightGradView, IActivation afn,
                    AlgoMode mode, BwdFilterAlgo bwdFilterAlgo, BwdDataAlgo bwdDataAlgo,
                    ConvolutionMode convolutionMode, int[] dilation, CNN2DFormat format, LayerWorkspaceMgr workspaceMgr) {

        //AB 2020/04/21 - cuDNN does have NHWC support (with limitations) however I have been unable to get it working
        // correctly on NHWC data, even after updating all descriptors, tensor format, etc.
        //Therefore: all computation here is done in NCHW format only
        //As of a future (next?) release we'll likely switch to C++ for cuDNN support
        boolean origNHWC = false;
        if (format == CNN2DFormat.NHWC) {
            input = input.permute(0,3,1,2); //NHWC to NCHW
            delta = delta.permute(0,3,1,2);
            origNHWC = true;
        }

        int TENSOR_FORMAT = CUDNN_TENSOR_NCHW;

        int code;

        val miniBatch = input.size(0);
        val outDepth = weights.size(0);
        val inDepth = weights.size(1);
        val kH = weights.size(2);
        val kW = weights.size(3);

        CudnnForwardArgs args = getCudnnForwardArgs(input, kernel, strides, pad, dilation, convolutionMode, null, CNN2DFormat.NCHW); //Note hardcoded NCHW due to above
        input = args.getInput();
        val inH = input.size(2);
        val inW = input.size(3);
        val srcStride = input.stride();
        val outSize = args.getOutSize();
        val outH = outSize[0];
        val outW = outSize[1];

        if (!Shape.strideDescendingCAscendingF(delta)) {
            // apparently not supported by cuDNN
            delta = delta.dup();
        }

        val deltaStride = delta.stride();
        int[] algo1 = new int[1];
        int[] algo2 = new int[1];


        if (Nd4j.getExecutioner() instanceof GridExecutioner)
            ((GridExecutioner) Nd4j.getExecutioner()).flushQueue();

        code = cudnnSetTensor4dDescriptorEx(cudnnContext.srcTensorDesc, dataType, (int) miniBatch, (int) inDepth, (int) inH, (int) inW,
                        (int) srcStride[0], (int) srcStride[1], (int) srcStride[2], (int) srcStride[3]);
        checkCudnn(false, "cudnnSetTensor4dDescriptorEx", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);
        code = cudnnSetTensor4dDescriptorEx(cudnnContext.deltaTensorDesc, dataType, (int) miniBatch, (int) outDepth, (int) outH, (int) outW,
                        (int) deltaStride[0], (int) deltaStride[1], (int) deltaStride[2], (int) deltaStride[3]);
        checkCudnn(false, "cudnnSetTensor4dDescriptorEx", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);
        code = cudnnSetConvolution2dDescriptor(cudnnContext.convDesc, pad[0], pad[1], strides[0], strides[1], dilation[0],
                        dilation[1], CUDNN_CROSS_CORRELATION, dataType);
        checkCudnn(false, "cudnnSetConvolution2dDescriptor", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);
        code = cudnnSetFilter4dDescriptor(cudnnContext.filterDesc, dataType, TENSOR_FORMAT, (int) outDepth, (int) inDepth, (int) kH, (int) kW);
        checkCudnn(false, "cudnnSetFilter4dDescriptor", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);

        if (mode == AlgoMode.USER_SPECIFIED && bwdFilterAlgo != null && bwdDataAlgo != null) {
            switch (bwdFilterAlgo) {
                case ALGO_0:
                    algo1[0] = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0;
                    break;
                case ALGO_1:
                    algo1[0] = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1;
                    break;
                case FFT:
                    algo1[0] = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT;
                    break;
                case ALGO_3:
                    algo1[0] = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3;
                    break;
                case WINOGRAD:
                    algo1[0] = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD;
                    break;
                case WINOGRAD_NONFUSED:
                    algo1[0] = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED;
                    break;
                case FFT_TILING:
                    algo1[0] = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING;
                    break;
                case COUNT:
                    algo1[0] = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT;
                    break;
                default:
                    throw new IllegalArgumentException("Unknown BwdFilterAlgo: " + bwdFilterAlgo);
            }

            switch (bwdDataAlgo) {
                case ALGO_0:
                    algo2[0] = CUDNN_CONVOLUTION_BWD_DATA_ALGO_0;
                    break;
                case ALGO_1:
                    algo2[0] = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
                    break;
                case FFT:
                    algo2[0] = CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT;
                    break;
                case FFT_TILING:
                    algo2[0] = CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING;
                    break;
                case WINOGRAD:
                    algo2[0] = CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD;
                    break;
                case WINOGRAD_NONFUSED:
                    algo2[0] = CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED;
                    break;
                case COUNT:
                    algo2[0] = CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT;
                    break;
                default:
                    throw new IllegalArgumentException("Unknown BwdDataAlgo: " + bwdDataAlgo);
            }
        } else {
            /*
            code = cudnnGetConvolutionBackwardFilterAlgorithm(cudnnContext, cudnnContext.srcTensorDesc,
                            cudnnContext.deltaTensorDesc, cudnnContext.convDesc, cudnnContext.filterDesc,
                            mode == AlgoMode.NO_WORKSPACE ? CUDNN_CONVOLUTION_BWD_FILTER_NO_WORKSPACE
                                            : CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST,
                            0, algo1);
            */
            val fa = new cudnnConvolutionBwdFilterAlgoPerf_t();
            val counts = new int[1];
            code = cudnnFindConvolutionBackwardFilterAlgorithm(cudnnContext, cudnnContext.srcTensorDesc,
                    cudnnContext.deltaTensorDesc, cudnnContext.convDesc, cudnnContext.filterDesc, 1, counts, fa);
            algo1[0] = fa.algo();

            checkCudnn(false, "cudnnGetConvolutionBackwardFilterAlgorithm", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);

            /*
            code = cudnnGetConvolutionBackwardDataAlgorithm(cudnnContext, cudnnContext.filterDesc,
                            cudnnContext.deltaTensorDesc, cudnnContext.convDesc, cudnnContext.srcTensorDesc,
                            mode == AlgoMode.NO_WORKSPACE ? CUDNN_CONVOLUTION_BWD_DATA_NO_WORKSPACE
                                            : CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST,
                            0, algo2);
            */

            val da = new cudnnConvolutionBwdDataAlgoPerf_t();
            code = cudnnFindConvolutionBackwardDataAlgorithm(cudnnContext, cudnnContext.filterDesc,
                    cudnnContext.deltaTensorDesc, cudnnContext.convDesc, cudnnContext.srcTensorDesc, 1, counts, da);

            algo2[0] = da.algo();
            checkCudnn(false, "cudnnGetConvolutionBackwardDataAlgorithm", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);
        }

        if (log.isTraceEnabled()) {
            BwdFilterAlgo fa = BwdFilterAlgo.values()[algo1[0]];
            BwdDataAlgo da = BwdDataAlgo.values()[algo2[0]];
            log.trace("CudnnConvolutionHelper backward algorithm selection: mode {}, filter algorithm {}, data algorithm {}", mode, fa, da);
        }

        INDArray epsNext = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, weights.dataType(), new long[] {(int) miniBatch,(int) inDepth, (int) inH, (int) inW}, 'c');

        val dstStride = epsNext.stride();

        Allocator allocator = AtomicAllocator.getInstance();
        CudaContext context = allocator.getFlowController().prepareActionAllWrite(input, weights, weightGradView,
                        biasGradView, delta, epsNext);
        Pointer srcData = allocator.getPointer(input, context);
        Pointer filterData = allocator.getPointer(weights, context);
        Pointer filterGradData = allocator.getPointer(weightGradView, context);
        Pointer biasGradData = allocator.getPointer(biasGradView, context);
        Pointer deltaData = allocator.getPointer(delta, context);
        Pointer dstData = allocator.getPointer(epsNext, context);

        code = cudnnSetStream(cudnnContext, new CUstream_st(context.getCublasStream()));
        checkCudnn(false, "cudnnSetStream", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);

        code = cudnnSetTensor4dDescriptorEx(cudnnContext.dstTensorDesc, dataType, (int) miniBatch, (int) inDepth, (int) inH, (int) inW,
                        (int) dstStride[0], (int) dstStride[1], (int) dstStride[2], (int) dstStride[3]);
        checkCudnn(false, "cudnnSetTensor4dDescriptorEx", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);

        code = cudnnGetConvolutionBackwardFilterWorkspaceSize(cudnnContext, cudnnContext.srcTensorDesc,
                        cudnnContext.deltaTensorDesc, cudnnContext.convDesc, cudnnContext.filterDesc, algo1[0],
                        sizeInBytes);
        checkCudnn(false, "cudnnGetConvolutionBackwardFilterWorkspaceSize", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);

        long sizeInBytes1 = sizeInBytes.get(0);
        code = cudnnGetConvolutionBackwardDataWorkspaceSize(cudnnContext, cudnnContext.filterDesc,
                        cudnnContext.deltaTensorDesc, cudnnContext.convDesc, cudnnContext.dstTensorDesc, algo2[0],
                        sizeInBytes);
        checkCudnn(false, "cudnnGetConvolutionBackwardDataWorkspaceSize", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);

        DataCache workSpace = workspaceMgr.getHelperWorkspace(LayerWorkspaceMgr.CUDNN_WORKSPACE_KEY);
        long sizeInBytes2 = sizeInBytes.get(0);
        if (workSpace == null || sizeInBytes1 > workSpace.capacity() || sizeInBytes2 > workSpace.capacity()) {
            long newSize = Math.max(sizeInBytes1, sizeInBytes2);
            if (log.isTraceEnabled()) {
                if (workSpace == null) {
                    log.trace("CudnnConvolutionHelper backpropGradient: Allocating initial workspace of size {} ({})", newSize,
                            BinaryByteUnit.format(newSize, "#.00"));
                } else {
                    log.trace("CudnnConvolutionHelper backpropGradient: Deallocating workspace of size {} ({}), allocating new workspace of size {} ({})",
                            workSpace.capacity(), BinaryByteUnit.format(workSpace.capacity(), "#.00"),
                            newSize, BinaryByteUnit.format(newSize, "#.00"));
                }
            }
            if (workSpace != null)
                workSpace.deallocate();
            workSpace = new DataCache(newSize);
            workspaceMgr.setHelperWorkspace(LayerWorkspaceMgr.CUDNN_WORKSPACE_KEY, workSpace);
        }

        code = cudnnSetTensor4dDescriptor(cudnnContext.biasTensorDesc, TENSOR_FORMAT, dataType, 1, (int) outDepth, 1, 1);
        checkCudnn(false, "cudnnSetTensor4dDescriptor", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);

        code = cudnnConvolutionBackwardBias(cudnnContext, alpha, cudnnContext.deltaTensorDesc, deltaData, beta,
                        cudnnContext.biasTensorDesc, biasGradData);
        checkCudnn(false, "cudnnConvolutionBackwardBias", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);

        code = cudnnConvolutionBackwardFilter(cudnnContext, alpha, cudnnContext.srcTensorDesc, srcData,
                        cudnnContext.deltaTensorDesc, deltaData, cudnnContext.convDesc, algo1[0], workSpace,
                        workSpace.capacity(), beta, cudnnContext.filterDesc, filterGradData);
        checkCudnn(false, "cudnnConvolutionBackwardFilter", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);

        code = cudnnConvolutionBackwardData(cudnnContext, alpha, cudnnContext.filterDesc, filterData,
                        cudnnContext.deltaTensorDesc, deltaData, cudnnContext.convDesc, algo2[0], workSpace,
                        workSpace.capacity(), beta, cudnnContext.dstTensorDesc, dstData);
        checkCudnn(false, "cudnnConvolutionBackwardData", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);

        allocator.getFlowController().registerActionAllWrite(context, input, weights, weightGradView, biasGradView,
                        delta, epsNext);

        Gradient retGradient = new DefaultGradient();
        retGradient.setGradientFor(ConvolutionParamInitializer.BIAS_KEY, biasGradView);
        retGradient.setGradientFor(ConvolutionParamInitializer.WEIGHT_KEY, weightGradView, 'c');

        if (CudaEnvironment.getInstance().getConfiguration().isDebug())
            context.syncOldStream();

        //Note that: if we had to manually pad for SAME mode, we have to 'undo' this manual padding for the epsilon
        // we return. The returned epsilon (i.e., dL/dIn array) has to be the same shape as the *original* input.
        if (args.isManualPadBottom() || args.isManualPadRight()) {
            epsNext = epsNext.get(all(), all(),
                    interval(0, epsNext.size(2) - (args.isManualPadBottom() ? 1 : 0)),
                    interval(0, epsNext.size(3) - (args.isManualPadRight() ? 1 : 0)));
|
||||||
|
}
|
||||||
|
|
||||||
|
if(origNHWC){
|
||||||
|
epsNext = epsNext.permute(0,2,3,1); //NCHW to NHWC
|
||||||
|
}
|
||||||
|
|
||||||
|
return new Pair<>(retGradient, epsNext);
|
||||||
|
}

    @Override
    public INDArray preOutput(INDArray input, INDArray weights, INDArray bias, int[] kernel, int[] strides, int[] pad,
                              AlgoMode mode, FwdAlgo fwdAlgo, ConvolutionMode convolutionMode, int[] dilation, CNN2DFormat format,
                              LayerWorkspaceMgr workspaceMgr) {

        //AB 2020/04/21 - cuDNN does have NHWC support (with limitations) however I have been unable to get it working
        // correctly on NHWC data, even after updating all descriptors, tensor format, etc.
        //Therefore: all computation here is done in NCHW format only
        //As of a future (next?) release we'll likely switch to C++ for cuDNN support
        boolean origNHWC = false;
        if(format == CNN2DFormat.NHWC){
            input = input.permute(0,3,1,2); //NHWC to NCHW
            origNHWC = true;
        }
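        // For example (hypothetical shapes): a [16, 28, 28, 3] NHWC activations array becomes [16, 3, 28, 28]
        // after permute(0,3,1,2); the permutation is reversed again before this method returns.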

        int TENSOR_FORMAT = CUDNN_TENSOR_NCHW;

        int code;

        val miniBatch = input.size(0);
        val outDepth = weights.size(0);
        val inDepth = weights.size(1);
        val kH = weights.size(2);
        val kW = weights.size(3);

        CudnnForwardArgs args = getCudnnForwardArgs(input, kernel, strides, pad, dilation, convolutionMode, null, CNN2DFormat.NCHW); //Note hardcoded NCHW due to above
        input = args.getInput();
        val inH = input.size(2);
        val inW = input.size(3);
        val srcStride = input.stride();
        val outSize = args.getOutSize();

        if (Nd4j.getExecutioner() instanceof GridExecutioner)
            ((GridExecutioner) Nd4j.getExecutioner()).flushQueue();

        INDArray z = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, weights.dataType(), new long[] {(int) miniBatch, (int) outDepth, outSize[0], outSize[1]});

        code = cudnnSetTensor4dDescriptorEx(cudnnContext.srcTensorDesc, dataType, (int) miniBatch, (int) inDepth, (int) inH, (int) inW,
                (int) srcStride[0], (int) srcStride[1], (int) srcStride[2], (int) srcStride[3]);
        checkCudnn(true, "cudnnSetTensor4dDescriptorEx", code, input, weights, bias, null, kernel, strides, pad, mode, fwdAlgo, null, null, convolutionMode, dilation);

        code = cudnnSetFilter4dDescriptor(cudnnContext.filterDesc, dataType, TENSOR_FORMAT, (int) outDepth, (int) inDepth, (int) kH, (int) kW);
        checkCudnn(true, "cudnnSetFilter4dDescriptor", code, input, weights, bias, null, kernel, strides, pad, mode, fwdAlgo, null, null, convolutionMode, dilation);

        code = cudnnSetConvolution2dDescriptor(cudnnContext.convDesc, pad[0], pad[1], strides[0], strides[1], dilation[0],
                dilation[1], CUDNN_CROSS_CORRELATION, dataType);
        checkCudnn(true, "cudnnSetConvolution2dDescriptor", code, input, weights, bias, null, kernel, strides, pad, mode, fwdAlgo, null, null, convolutionMode, dilation);

        // find dimension of convolution output
        // checkCudnn(cudnnGetConvolution2dForwardOutputDim(cudnnContext.convDesc, cudnnContext.srcTensorDesc, cudnnContext.filterDesc, n, c, h, w));
        // INDArray z = Nd4j.createUninitialized(new int[]{n[0],c[0],h[0],w[0]},'c');

        int[] algo = new int[1];
        val dstStride = z.stride();
        code = cudnnSetTensor4dDescriptorEx(cudnnContext.dstTensorDesc, dataType, (int) miniBatch, (int) outDepth, (int) outSize[0],
                (int) outSize[1], (int) dstStride[0], (int) dstStride[1], (int) dstStride[2], (int) dstStride[3]);
        checkCudnn(true, "cudnnSetTensor4dDescriptorEx", code, input, weights, bias, null, kernel, strides, pad, mode, fwdAlgo, null, null, convolutionMode, dilation);

        if (mode == AlgoMode.USER_SPECIFIED && fwdAlgo != null) {
            switch (fwdAlgo) {
                case IMPLICIT_GEMM:
                    algo[0] = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM;
                    break;
                case IMPLICIT_PRECOMP_GEMM:
                    algo[0] = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
                    break;
                case GEMM:
                    algo[0] = CUDNN_CONVOLUTION_FWD_ALGO_GEMM;
                    break;
                case DIRECT:
                    algo[0] = CUDNN_CONVOLUTION_FWD_ALGO_DIRECT;
                    break;
                case FFT:
                    algo[0] = CUDNN_CONVOLUTION_FWD_ALGO_FFT;
                    break;
                case FFT_TILING:
                    algo[0] = CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING;
                    break;
                case WINOGRAD:
                    algo[0] = CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD;
                    break;
                case WINOGRAD_NONFUSED:
                    algo[0] = CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED;
                    break;
                case COUNT:
                    algo[0] = CUDNN_CONVOLUTION_FWD_ALGO_COUNT;
                    break;
                default:
                    throw new IllegalArgumentException("Unknown FwdAlgo: " + fwdAlgo);
            }
        } else {
            /*
            code = cudnnGetConvolutionForwardAlgorithm_v7(cudnnContext, cudnnContext.srcTensorDesc,
                    cudnnContext.filterDesc, cudnnContext.convDesc,
                    cudnnContext.dstTensorDesc, mode == AlgoMode.NO_WORKSPACE
                    ? CUDNN_CONVOLUTION_FWD_ : CUDNN_CONVOLUTION_FWD_PREFER_FASTEST,
                    0, algo);
            */

            val cdf = new cudnnConvolutionFwdAlgoPerf_t();
            val count = new int[1];
            code = cudnnFindConvolutionForwardAlgorithm(cudnnContext, cudnnContext.srcTensorDesc, cudnnContext.filterDesc, cudnnContext.convDesc, cudnnContext.dstTensorDesc, 1, count, cdf);

            if(code != CUDNN_STATUS_SUCCESS){
                //If CuDNN can't infer algorithm - try IMPLICIT_GEMM
                //Why this specifically? According to the docs, it seems to have the least number of restrictions
                // to things like dilation

                OneTimeLogger.warn(log, "Error getting CuDNN forward algorithm - falling back on IMPLICIT_GEMM");
                mode = AlgoMode.USER_SPECIFIED;
                fwdAlgo = FwdAlgo.IMPLICIT_GEMM;
                algo[0] = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM;
            } else {
                algo[0] = cdf.algo();
            }
        }

        if(log.isTraceEnabled()){
            FwdAlgo a = FwdAlgo.values()[algo[0]];
            log.trace("CudnnConvolutionHelper forward algorithm selection: mode {}, algorithm {}", mode, a);
        }

        Allocator allocator = AtomicAllocator.getInstance();
        CudaContext context = allocator.getFlowController().prepareAction(z, input, weights, bias);
        Pointer srcData = allocator.getPointer(input, context);
        Pointer filterData = allocator.getPointer(weights, context);
        Pointer biasData = allocator.getPointer(bias, context);
        Pointer dstData = allocator.getPointer(z, context);

        code = cudnnSetStream(cudnnContext, new CUstream_st(context.getCublasStream()));
        checkCudnn(true, "cudnnSetStream", code, input, weights, bias, null, kernel, strides, pad, mode, fwdAlgo, null, null, convolutionMode, dilation);

        code = cudnnGetConvolutionForwardWorkspaceSize(cudnnContext, cudnnContext.srcTensorDesc,
                cudnnContext.filterDesc, cudnnContext.convDesc, cudnnContext.dstTensorDesc, algo[0],
                sizeInBytes);
        checkCudnn(true, "cudnnGetConvolutionForwardWorkspaceSize", code, input, weights, bias, null, kernel, strides, pad, mode, fwdAlgo, null, null, convolutionMode, dilation);

        DataCache workSpace = workspaceMgr.getHelperWorkspace(LayerWorkspaceMgr.CUDNN_WORKSPACE_KEY);
        if (workSpace == null || sizeInBytes.get(0) > workSpace.capacity()) {
            if(log.isTraceEnabled()){
                if(workSpace == null){
                    log.trace("CudnnConvolutionHelper preOutput: allocating initial workspace of size {} ({})",
                            sizeInBytes.get(), BinaryByteUnit.format(sizeInBytes.get(), "#.00"));
                } else {
                    log.trace("CudnnConvolutionHelper preOutput: Deallocating workspace of size {} ({}), allocating new workspace of size {} ({})",
                            workSpace.capacity(), BinaryByteUnit.format(workSpace.capacity(), "#.00"),
                            sizeInBytes.get(), BinaryByteUnit.format(sizeInBytes.get(), "#.00"));
                }
            }
            if(workSpace != null)
                workSpace.deallocate();
            workSpace = new DataCache(sizeInBytes.get(0));
            workspaceMgr.setHelperWorkspace(LayerWorkspaceMgr.CUDNN_WORKSPACE_KEY, workSpace);
        }
        code = cudnnConvolutionForward(cudnnContext, alpha, cudnnContext.srcTensorDesc, srcData,
                cudnnContext.filterDesc, filterData, cudnnContext.convDesc, algo[0], workSpace,
                workSpace.capacity(), beta, cudnnContext.dstTensorDesc, dstData);
        checkCudnn(true, "cudnnConvolutionForward", code, input, weights, bias, null, kernel, strides, pad, mode, fwdAlgo, null, null, convolutionMode, dilation);

        code = cudnnSetTensor4dDescriptor(cudnnContext.biasTensorDesc, TENSOR_FORMAT, dataType, 1, (int) outDepth, 1, 1);
        checkCudnn(true, "cudnnSetTensor4dDescriptor", code, input, weights, bias, null, kernel, strides, pad, mode, fwdAlgo, null, null, convolutionMode, dilation);

        code = cudnnAddTensor(cudnnContext, alpha, cudnnContext.biasTensorDesc, biasData, alpha,
                cudnnContext.dstTensorDesc, dstData);
        checkCudnn(true, "cudnnAddTensor", code, input, weights, bias, null, kernel, strides, pad, mode, fwdAlgo, null, null, convolutionMode, dilation);

        allocator.registerAction(context, z, input, weights, bias);

        if (CudaEnvironment.getInstance().getConfiguration().isDebug())
            context.syncOldStream();

        if(origNHWC){
            z = z.permute(0,2,3,1); //NCHW to NHWC
        }

        return z;
    }

    private void checkCudnn(boolean forward, String step, int code, INDArray input, INDArray weights, INDArray bias, INDArray delta,
                            int[] kernel, int[] strides, int[] pad,
                            AlgoMode mode, FwdAlgo fwdAlgo, BwdFilterAlgo bwdFilterAlgo, BwdDataAlgo bwdDataAlgo, ConvolutionMode convolutionMode, int[] dilation) {

        if (code != CUDNN_STATUS_SUCCESS) {
            StringBuilder sb = new StringBuilder();
            sb.append("CuDNN error = ").append(code).append(": ").append(cudnnGetErrorString(code).getString())
                    .append(" during ")
                    .append(forward ? "forward pass" : "backward pass")
                    .append(" - step ").append(step)
                    .append(": inputShape=").append(Arrays.toString(input.shape()))
                    .append(", weightsShape=").append(Arrays.toString(weights.shape()))
                    .append(", biasShape=").append(bias == null ? null : Arrays.toString(bias.shape()));
            if (!forward) {
                sb.append(", gradientShape=").append(Arrays.toString(delta.shape()));
            }
            sb.append(", kernel=").append(Arrays.toString(kernel))
                    .append(", stride=").append(Arrays.toString(strides))
                    .append(", padding=").append(Arrays.toString(pad))
                    .append(", dilation=").append(Arrays.toString(dilation))
                    .append(", AlgoMode=").append(mode);
            if (forward) {
                sb.append(", fwdAlgo=").append(fwdAlgo);
            } else {
                sb.append(", bwdFilterAlgo=").append(bwdFilterAlgo)
                        .append(", bwdDataAlgo=").append(bwdDataAlgo);
            }
            sb.append(", convolutionMode=").append(convolutionMode);

            throw new RuntimeException(sb.toString());
        }
    }
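    // A failure surfaces as a RuntimeException along the lines of (illustrative values only):
    // "CuDNN error = 3: CUDNN_STATUS_BAD_PARAM during forward pass - step cudnnSetTensor4dDescriptorEx:
    //  inputShape=[16, 3, 28, 28], weightsShape=[8, 3, 5, 5], biasShape=[1, 8], kernel=[5, 5], stride=[1, 1], ..."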

    @Override
    public INDArray activate(INDArray z, IActivation afn, boolean training) {
        if (Nd4j.getExecutioner() instanceof GridExecutioner)
            ((GridExecutioner) Nd4j.getExecutioner()).flushQueue();

        INDArray activation = z;

        Allocator allocator = AtomicAllocator.getInstance();
        CudaContext context = allocator.getFlowController().prepareAction(z);
        Pointer dstData = allocator.getPointer(z, context);

        checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getCublasStream())));
        switch (afn.toString()) {
            case "identity":
                break;
            case "sigmoid":
                checkCudnn(cudnnSetActivationDescriptor(cudnnContext.activationDesc, CUDNN_ACTIVATION_SIGMOID,
                        CUDNN_PROPAGATE_NAN, 0));
                checkCudnn(cudnnActivationForward(cudnnContext, cudnnContext.activationDesc, alpha,
                        cudnnContext.dstTensorDesc, dstData, beta, cudnnContext.dstTensorDesc, dstData));
                break;
            case "relu":
                checkCudnn(cudnnSetActivationDescriptor(cudnnContext.activationDesc, CUDNN_ACTIVATION_RELU,
                        CUDNN_PROPAGATE_NAN, 0));
                checkCudnn(cudnnActivationForward(cudnnContext, cudnnContext.activationDesc, alpha,
                        cudnnContext.dstTensorDesc, dstData, beta, cudnnContext.dstTensorDesc, dstData));
                break;
            case "tanh":
                checkCudnn(cudnnSetActivationDescriptor(cudnnContext.activationDesc, CUDNN_ACTIVATION_TANH,
                        CUDNN_PROPAGATE_NAN, 0));
                checkCudnn(cudnnActivationForward(cudnnContext, cudnnContext.activationDesc, alpha,
                        cudnnContext.dstTensorDesc, dstData, beta, cudnnContext.dstTensorDesc, dstData));
                break;
            case "softmax":
                checkCudnn(cudnnSoftmaxForward(cudnnContext, CUDNN_SOFTMAX_ACCURATE, CUDNN_SOFTMAX_MODE_CHANNEL, alpha,
                        cudnnContext.dstTensorDesc, dstData, beta, cudnnContext.dstTensorDesc, dstData));
                break;
            case "logsoftmax":
                checkCudnn(cudnnSoftmaxForward(cudnnContext, CUDNN_SOFTMAX_LOG, CUDNN_SOFTMAX_MODE_CHANNEL, alpha,
                        cudnnContext.dstTensorDesc, dstData, beta, cudnnContext.dstTensorDesc, dstData));
                break;
            default:
                activation = null;
        }

        allocator.registerAction(context, activation);

        if (CudaEnvironment.getInstance().getConfiguration().isDebug())
            context.syncOldStream();

        return activation;
    }

    /**
     * @param poolingType Used when preparing data for subsampling layers ONLY. Null for convolution layers
     * @return The (possibly manually padded) input, the original input, and the padding and output size to use for the cuDNN call
     */
    public static CudnnForwardArgs getCudnnForwardArgs(INDArray input, int[] kernel, int[] strides, int[] padding, int[] dilation,
                                                       ConvolutionMode convolutionMode, PoolingType poolingType, CNN2DFormat format){
        INDArray origInput = input;

        //Check if we need to dup the input: views, non-contiguous, etc. CuDNN also seems to have issues if strides
        // are non-default for C order - even if they *should* be OK otherwise
        if(input.isView() || !Shape.hasDefaultStridesForShape(input)){
            input = input.dup('c');
        }

        boolean nchw = format == CNN2DFormat.NCHW;
        int hIdx = nchw ? 2 : 1;
        int wIdx = nchw ? 3 : 2;

        val inH = input.size(hIdx);
        val inW = input.size(wIdx);

        boolean manualPadBottom = false;
        boolean manualPadRight = false;

        int[] outSize;
        if (convolutionMode == ConvolutionMode.Same) {
            outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, null, convolutionMode, dilation, format); //Also performs validation
            padding = ConvolutionUtils.getSameModeTopLeftPadding(outSize, new int[] {(int) inH, (int) inW}, kernel, strides, dilation);
            int[] padBottomRight = ConvolutionUtils.getSameModeBottomRightPadding(outSize, new int[] {(int) inH, (int) inW}, kernel, strides, dilation);
            if(!Arrays.equals(padding, padBottomRight)){
                /*
                CuDNN - even as of 7.1 (CUDA 9.1) still doesn't have support for proper SAME mode padding (i.e., asymmetric
                padding) - padding can *only* be specified as the same amount for both the top/bottom, and for left/right.
                In SAME mode padding, sometimes these are the same - but often they are not.
                Note that when they differ, the bottom or right padding will be exactly 1 more than the top or left padding.
                As per TF, we'll manually pad here: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/conv_ops.cc#L571-L607
                */
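                // Worked example: inH = 5, kernel = 2, stride = 2 gives outH = ceil(5/2) = 3; total padding
                // = (3 - 1) * 2 + 2 - 5 = 1, so the top gets 0 and the bottom gets the extra 1 - hence the manual padding below.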
                manualPadBottom = (padding[0] != padBottomRight[0]);
                manualPadRight = (padding[1] != padBottomRight[1]);

                //NCHW format
                long[] newShape;
                if(nchw){
                    newShape = new long[]{input.size(0), input.size(1),
                            input.size(2) + (manualPadBottom ? 1 : 0),
                            input.size(3) + (manualPadRight ? 1 : 0)};
                } else {
                    newShape = new long[]{input.size(0),
                            input.size(1) + (manualPadBottom ? 1 : 0),
                            input.size(2) + (manualPadRight ? 1 : 0),
                            input.size(3)};
                }
                INDArray newInput;
                if(poolingType == null || poolingType != PoolingType.MAX){
                    newInput = Nd4j.create(input.dataType(), newShape);
                } else {
                    //For max pooling, we don't want to include the padding in the maximum values. But, CuDNN doesn't know
                    // that these values are padding and hence should be excluded. Instead: We'll use -infinity so that,
                    // if the 'real' (non-padding) values are all < 0, we take the real value, not the padding value
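                    // e.g. a pooling window covering real values {-0.7, -0.3} plus one -infinity padding cell yields
                    // max = -0.3 (a real value); zero-initialised padding would incorrectly yield 0.0 here.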
                    newInput = Nd4j.valueArrayOf(newShape, Double.NEGATIVE_INFINITY, input.dataType());
                }

                if(nchw){
                    newInput.put(new INDArrayIndex[]{all(), all(), interval(0,input.size(2)),
                            interval(0, input.size(3))}, input);
                } else {
                    newInput.put(new INDArrayIndex[]{all(), interval(0,input.size(1)),
                            interval(0, input.size(2)), all()}, input);
                }

                input = newInput;
                //Now: we've manually applied the "extra" bottom/right padding only - if required. Consequently, we
                // now have the same amount of padding required for top/bottom, and left/right - which we'll let
                // CuDNN handle
            }
        } else {
            outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, padding, convolutionMode, dilation, format); //Also performs validation
        }

        return new CudnnForwardArgs(manualPadBottom, manualPadRight, input, origInput, padding, outSize);
    }
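    // Hypothetical usage: for a [32, 3, 28, 28] NCHW input with kernel {2,2}, strides {2,2}, dilation {1,1}
    // and ConvolutionMode.Same, this returns outSize {14,14} and no manual bottom/right padding, since the
    // spatial size divides evenly by the stride.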

    @AllArgsConstructor
    @Data
    public static class CudnnForwardArgs {
        private boolean manualPadBottom;
        private boolean manualPadRight;
        private INDArray input;
        private INDArray origInput;
        private int[] padding;
        private int[] outSize;
    }

    @Override
    public Map<String, Long> helperMemoryUse() {
        //No memory use other than shared, and the structs (which are small)
        return Collections.emptyMap();
    }

}

@@ -0,0 +1,308 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.deeplearning4j.cuda.convolution.subsampling;

import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.bytedeco.javacpp.Pointer;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.layers.PoolingType;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.cuda.BaseCudnnHelper;
import org.deeplearning4j.cuda.convolution.CudnnConvolutionHelper;
import org.deeplearning4j.nn.layers.convolution.subsampling.SubsamplingHelper;
import org.nd4j.jita.allocator.Allocator;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.conf.CudaEnvironment;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.executioner.GridExecutioner;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.context.CudaContext;
import org.nd4j.common.primitives.Pair;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.nn.workspace.ArrayType;

import java.util.Collections;
import java.util.Map;

import org.bytedeco.cuda.cudart.*;
import org.bytedeco.cuda.cudnn.*;

import static org.bytedeco.cuda.global.cudnn.*;
import static org.deeplearning4j.cuda.convolution.CudnnConvolutionHelper.getCudnnForwardArgs;
import static org.nd4j.linalg.indexing.NDArrayIndex.all;
import static org.nd4j.linalg.indexing.NDArrayIndex.interval;

/**
 * cuDNN-based helper for the subsampling layer.
 *
 * @author saudet
 */
@Slf4j
public class CudnnSubsamplingHelper extends BaseCudnnHelper implements SubsamplingHelper {

    public CudnnSubsamplingHelper(DataType dataType) {
        super(dataType);
    }

    private static class CudnnSubsamplingContext extends CudnnContext {

        private static class Deallocator extends CudnnSubsamplingContext implements Pointer.Deallocator {
            Deallocator(CudnnSubsamplingContext c) {
                super(c);
            }

            @Override
            public void deallocate() {
                destroyHandles();
            }
        }

        private cudnnTensorStruct srcTensorDesc = new cudnnTensorStruct(), dstTensorDesc = new cudnnTensorStruct(),
                deltaTensorDesc = new cudnnTensorStruct();
        private cudnnPoolingStruct poolingDesc = new cudnnPoolingStruct();

        public CudnnSubsamplingContext() {
            createHandles();
            deallocator(new Deallocator(this));
        }

        public CudnnSubsamplingContext(CudnnSubsamplingContext c) {
            super(c);
            srcTensorDesc = new cudnnTensorStruct(c.srcTensorDesc);
            dstTensorDesc = new cudnnTensorStruct(c.dstTensorDesc);
            deltaTensorDesc = new cudnnTensorStruct(c.deltaTensorDesc);
            poolingDesc = new cudnnPoolingStruct(c.poolingDesc);
        }

        @Override
        protected void createHandles() {
            super.createHandles();
            checkCudnn(cudnnCreateTensorDescriptor(srcTensorDesc));
            checkCudnn(cudnnCreateTensorDescriptor(dstTensorDesc));
            checkCudnn(cudnnCreateTensorDescriptor(deltaTensorDesc));
            checkCudnn(cudnnCreatePoolingDescriptor(poolingDesc));
        }

        @Override
        protected void destroyHandles() {
            checkCudnn(cudnnDestroyPoolingDescriptor(poolingDesc));
            checkCudnn(cudnnDestroyTensorDescriptor(srcTensorDesc));
            checkCudnn(cudnnDestroyTensorDescriptor(dstTensorDesc));
            checkCudnn(cudnnDestroyTensorDescriptor(deltaTensorDesc));
            super.destroyHandles();
        }
    }

    private CudnnSubsamplingContext cudnnContext = new CudnnSubsamplingContext();

    @Override
    public Pair<Gradient, INDArray> backpropGradient(INDArray input, INDArray epsilon, int[] kernel, int[] strides,
                                                     int[] pad, PoolingType poolingType, ConvolutionMode convolutionMode,
                                                     int[] dilation, CNN2DFormat format, LayerWorkspaceMgr workspaceMgr) {
        if(dilation[0] != 1 || dilation[1] != 1){
            //CuDNN doesn't support dilated subsampling
            return null;
        }
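        // (Returning null above lets the calling layer fall back to the built-in, non-cuDNN subsampling implementation.)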

        boolean nchw = format == CNN2DFormat.NCHW;
        int chIdx = nchw ? 1 : 3;
        int hIdx = nchw ? 2 : 1;
        int wIdx = nchw ? 3 : 2;

        //We require the output as one of the arguments for backprop here
        //TODO we could add cache mode support here somehow...
        INDArray reduced = activate(input, true, kernel, strides, pad, poolingType, convolutionMode, dilation, format, workspaceMgr);

        val miniBatch = input.size(0);
        val depth = input.size(chIdx);

        CudnnConvolutionHelper.CudnnForwardArgs args = getCudnnForwardArgs(input, kernel, strides, pad, dilation, convolutionMode, poolingType, format);
        input = args.getInput();
        val inH = input.size(hIdx);
        val inW = input.size(wIdx);
        val srcStride = input.stride();
        int[] outSize = args.getOutSize();
        int outH = outSize[0];
        int outW = outSize[1];

        //subsampling doesn't have weights and thus gradients are not calculated for this layer
        //only scale and reshape epsilon
        Gradient retGradient = new DefaultGradient();

        //Epsilons in shape: [miniBatch, channels, outH, outW]
        //Epsilons out shape: [miniBatch, channels, inH, inW]
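        // e.g. 2x2 max pooling with stride 2 on a [32, 64, 28, 28] input: epsilon arrives as [32, 64, 14, 14]
        // and is routed back through the pooling windows to a [32, 64, 28, 28] output epsilon.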

        int poolingMode;
        switch (poolingType) {
            case AVG:
                poolingMode = CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING;
                break;
            case MAX:
                poolingMode = CUDNN_POOLING_MAX;
                break;
            default:
                return null;
        }

        if (!Shape.hasDefaultStridesForShape(epsilon) || epsilon.isView()) {
            // apparently not supported by cuDNN
            epsilon = epsilon.dup('c');
        }

        input = input.dup();

        val deltaStride = epsilon.stride();

        if (Nd4j.getExecutioner() instanceof GridExecutioner)
            ((GridExecutioner) Nd4j.getExecutioner()).flushQueue();

        checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.srcTensorDesc, dataType, (int) miniBatch, (int) depth, (int) inH, (int) inW,
                (int) srcStride[0], (int) srcStride[chIdx], (int) srcStride[hIdx], (int) srcStride[wIdx]));
        checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.deltaTensorDesc, dataType, (int) miniBatch, (int) depth, (int) outH, (int) outW,
                (int) deltaStride[0], (int) deltaStride[chIdx], (int) deltaStride[hIdx], (int) deltaStride[wIdx]));
        checkCudnn(cudnnSetPooling2dDescriptor(cudnnContext.poolingDesc, poolingMode, CUDNN_PROPAGATE_NAN, kernel[0],
                kernel[1], pad[0], pad[1], strides[0], strides[1]));

        long[] outEpsShape = nchw ? new long[] {miniBatch, depth, inH, inW} : new long[] {miniBatch, inH, inW, depth};
        INDArray outEpsilon = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, input.dataType(), outEpsShape, 'c');

        val dstStride = outEpsilon.stride();
        checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.dstTensorDesc, dataType, (int) miniBatch, (int) depth, (int) inH, (int) inW,
                (int) dstStride[0], (int) dstStride[chIdx], (int) dstStride[hIdx], (int) dstStride[wIdx]));

        Allocator allocator = AtomicAllocator.getInstance();
        CudaContext context = allocator.getFlowController().prepareAction(input, epsilon, reduced, outEpsilon);
        Pointer srcData = allocator.getPointer(input, context);
        Pointer epsData = allocator.getPointer(epsilon, context);
        Pointer zData = allocator.getPointer(reduced, context);
        Pointer dstData = allocator.getPointer(outEpsilon, context);

        checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getCublasStream())));
        checkCudnn(cudnnPoolingBackward(cudnnContext, cudnnContext.poolingDesc, alpha, cudnnContext.deltaTensorDesc,
                zData, cudnnContext.deltaTensorDesc, epsData, cudnnContext.srcTensorDesc, srcData, beta,
                cudnnContext.dstTensorDesc, dstData));

        allocator.registerAction(context, outEpsilon, input, epsilon, reduced);

        if (CudaEnvironment.getInstance().getConfiguration().isDebug())
            context.syncOldStream();

        //Note that: if we had to manually pad for SAME mode, we have to 'undo' this manual padding for the epsilon
        // we return. The returned epsilon (i.e., dL/dIn array) has to be the same shape as the *original* input.
        if(args.isManualPadBottom() || args.isManualPadRight()) {
            if(nchw){
                outEpsilon = outEpsilon.get(all(), all(),
                        interval(0, outEpsilon.size(2) - (args.isManualPadBottom() ? 1 : 0)),
                        interval(0, outEpsilon.size(3) - (args.isManualPadRight() ? 1 : 0)));
            } else {
                outEpsilon = outEpsilon.get(all(),
                        interval(0, outEpsilon.size(1) - (args.isManualPadBottom() ? 1 : 0)),
                        interval(0, outEpsilon.size(2) - (args.isManualPadRight() ? 1 : 0)),
                        all());
            }
        }

        return new Pair<>(retGradient, outEpsilon);
    }


    @Override
    public INDArray activate(INDArray input, boolean training, int[] kernel, int[] strides, int[] pad,
                             PoolingType poolingType, ConvolutionMode convolutionMode, int[] dilation, CNN2DFormat format, LayerWorkspaceMgr workspaceMgr) {
        if(dilation[0] != 1 || dilation[1] != 1){
            //CuDNN doesn't support dilated subsampling
            return null;
        }

        boolean nchw = format == CNN2DFormat.NCHW;
        int chIdx = nchw ? 1 : 3;
        int hIdx = nchw ? 2 : 1;
        int wIdx = nchw ? 3 : 2;

        val miniBatch = input.size(0);
        val inDepth = input.size(nchw ? 1 : 3);

        CudnnConvolutionHelper.CudnnForwardArgs args = getCudnnForwardArgs(input, kernel, strides, pad, dilation, convolutionMode, poolingType, format);
        input = args.getInput();
        val inH = input.size(nchw ? 2 : 1);
        val inW = input.size(nchw ? 3 : 2);
        val srcStride = input.stride();
        val outSize = args.getOutSize();
        int outH = outSize[0];
        int outW = outSize[1];

        int poolingMode;
        switch (poolingType) {
            case AVG:
                poolingMode = CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING;
                break;
            case MAX:
                poolingMode = CUDNN_POOLING_MAX;
                break;
            default:
                return null;
        }

        if (Nd4j.getExecutioner() instanceof GridExecutioner)
            ((GridExecutioner) Nd4j.getExecutioner()).flushQueue();

        checkCudnn(cudnnSetPooling2dDescriptor(cudnnContext.poolingDesc, poolingMode, CUDNN_PROPAGATE_NAN, kernel[0],
                kernel[1], pad[0], pad[1], strides[0], strides[1]));
        checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.srcTensorDesc, dataType, (int) miniBatch, (int) inDepth, (int) inH, (int) inW,
                (int) srcStride[0], (int) srcStride[chIdx], (int) srcStride[hIdx], (int) srcStride[wIdx]));

        long[] outShape = nchw ? new long[] {miniBatch, inDepth, outH, outW} : new long[] {miniBatch, outH, outW, inDepth};
        INDArray reduced = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, input.dataType(), outShape, 'c');

        val dstStride = reduced.stride();
        checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.dstTensorDesc, dataType, (int) miniBatch, (int) inDepth, (int) outH, (int) outW,
                (int) dstStride[0], (int) dstStride[chIdx], (int) dstStride[hIdx], (int) dstStride[wIdx]));

        Allocator allocator = AtomicAllocator.getInstance();
        CudaContext context = allocator.getFlowController().prepareAction(input, reduced);
        Pointer srcData = allocator.getPointer(input, context);
        Pointer dstData = allocator.getPointer(reduced, context);

        checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getCublasStream())));
        checkCudnn(cudnnPoolingForward(cudnnContext, cudnnContext.poolingDesc, alpha, cudnnContext.srcTensorDesc,
                srcData, beta, cudnnContext.dstTensorDesc, dstData));

        allocator.registerAction(context, reduced, input);

        if (CudaEnvironment.getInstance().getConfiguration().isDebug())
            context.syncOldStream();

        return reduced;
    }

    @Override
    public Map<String, Long> helperMemoryUse() {
        //No persistent memory use other than the structs (which are small)
        return Collections.emptyMap();
    }

}

@@ -0,0 +1,245 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.deeplearning4j.cuda.dropout;

import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import com.jakewharton.byteunits.BinaryByteUnit;
import org.bytedeco.javacpp.*;
import org.deeplearning4j.nn.conf.dropout.DropoutHelper;
import org.deeplearning4j.cuda.BaseCudnnHelper;
import org.nd4j.jita.allocator.Allocator;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.conf.CudaEnvironment;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.context.CudaContext;
import org.nd4j.common.util.ArrayUtil;

import org.bytedeco.cuda.cudart.*;
import org.bytedeco.cuda.cudnn.*;

import java.util.Collections;
import java.util.Map;

import static org.bytedeco.cuda.global.cudnn.*;

/**
 * CuDNN dropout helper
 *
 * Note that for repeatability between calls (for example, for gradient checks), we need to do two things:
 * (a) set the ND4J RNG seed
 * (b) clear the rngStates field
 *
 * @author Alex Black
 */
@Data
@Slf4j
public class CudnnDropoutHelper extends BaseCudnnHelper implements DropoutHelper {

    private static class CudnnDropoutContext extends CudnnContext {

        private static class Deallocator extends CudnnDropoutContext implements Pointer.Deallocator {
            Deallocator(CudnnDropoutContext c) {
                super(c);
            }

            @Override
            public void deallocate() {
                destroyHandles();
            }
        }

        private cudnnTensorStruct xTensorDesc = new cudnnTensorStruct(); //Input
        private cudnnTensorStruct dxTensorDesc = new cudnnTensorStruct(); //Grad at input
        private cudnnTensorStruct yTensorDesc = new cudnnTensorStruct(); //Output
        private cudnnTensorStruct dyTensorDesc = new cudnnTensorStruct(); //Grad at output
        private cudnnDropoutStruct dropoutDesc = new cudnnDropoutStruct();

        public CudnnDropoutContext() {
            createHandles();
            deallocator(new Deallocator(this));
        }

        public CudnnDropoutContext(CudnnDropoutContext c) {
            super(c);
            xTensorDesc = new cudnnTensorStruct(c.xTensorDesc);
            dxTensorDesc = new cudnnTensorStruct(c.dxTensorDesc);
            yTensorDesc = new cudnnTensorStruct(c.yTensorDesc);
            dyTensorDesc = new cudnnTensorStruct(c.dyTensorDesc);
            dropoutDesc = new cudnnDropoutStruct(c.dropoutDesc);
        }

        @Override
        protected void createHandles() {
            super.createHandles();
            checkCudnn(cudnnCreateTensorDescriptor(xTensorDesc));
            checkCudnn(cudnnCreateTensorDescriptor(dxTensorDesc));
            checkCudnn(cudnnCreateTensorDescriptor(yTensorDesc));
            checkCudnn(cudnnCreateTensorDescriptor(dyTensorDesc));
            checkCudnn(cudnnCreateDropoutDescriptor(dropoutDesc));
        }

        @Override
        protected void destroyHandles() {
            checkCudnn(cudnnDestroyTensorDescriptor(xTensorDesc));
            checkCudnn(cudnnDestroyTensorDescriptor(dxTensorDesc));
            checkCudnn(cudnnDestroyTensorDescriptor(yTensorDesc));
            checkCudnn(cudnnDestroyTensorDescriptor(dyTensorDesc));
            checkCudnn(cudnnDestroyDropoutDescriptor(dropoutDesc));
            super.destroyHandles();
        }
    }

    private CudnnDropoutContext cudnnContext = new CudnnDropoutContext();
    private boolean initializedDescriptor = false;
    private DataCache rngStates; //"Pointer to user-allocated GPU memory that will hold random number generator states."
    private DataCache mask; //Mask: persistence between forward and backward
    private SizeTPointer stateSizeBytesPtr;
    private SizeTPointer reserveSizeBytesPtr;
    private float lastInitializedP;

    public CudnnDropoutHelper(DataType dataType){
        super(dataType);
    }

    //@Override
    public Map<String, Long> helperMemoryUse() {
        return Collections.emptyMap();
    }

    @Override
    public boolean checkSupported() {
        return true;
    }

    @Override
    public void applyDropout(INDArray input, INDArray resultArray, double dropoutInputRetainProb) {
        float p = (float)(1.0 - dropoutInputRetainProb); //CuDNN uses p = probability of setting to 0. We use p = probability of retaining
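        // e.g. dropoutInputRetainProb = 0.8 (keep 80% of activations) maps to a cuDNN drop probability p = 0.2f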

        //TODO int cast
        int[] inShape = adaptForTensorDescr(ArrayUtil.toInts(input.shape()));
        int[] inStride = adaptForTensorDescr(ArrayUtil.toInts(input.stride()));
        checkCudnn(cudnnSetTensorNdDescriptor(cudnnContext.xTensorDesc, dataType, inShape.length, inShape, inStride));

        int[] outShape = adaptForTensorDescr(ArrayUtil.toInts(resultArray.shape()));
        int[] outStride = adaptForTensorDescr(ArrayUtil.toInts(resultArray.stride()));
        checkCudnn(cudnnSetTensorNdDescriptor(cudnnContext.yTensorDesc, dataType, outShape.length, outShape, outStride));

        if(stateSizeBytesPtr == null){
            stateSizeBytesPtr = new SizeTPointer(1);
            reserveSizeBytesPtr = new SizeTPointer(1);
        }
        checkCudnn(cudnnDropoutGetStatesSize(cudnnContext, stateSizeBytesPtr));
        long rngStateSizeBytes = stateSizeBytesPtr.get();
        checkCudnn(cudnnDropoutGetReserveSpaceSize(cudnnContext.xTensorDesc, reserveSizeBytesPtr));
        long maskReserveSizeBytes = reserveSizeBytesPtr.get();

        if(rngStates == null || rngStates.capacity() < rngStateSizeBytes){
            if(log.isTraceEnabled()){
                if(rngStates == null){
                    log.trace("CudnnDropoutHelper: Allocating initial RNG states workspace of size {} ({})", rngStateSizeBytes,
                            BinaryByteUnit.format(rngStateSizeBytes, "#.00"));
                } else {
                    log.trace("CudnnDropoutHelper: Deallocating RNG states of size {} ({}), allocating new workspace of size {} ({})",
                            rngStates.capacity(), BinaryByteUnit.format(rngStates.capacity(), "#.00"),
                            rngStateSizeBytes, BinaryByteUnit.format(rngStateSizeBytes, "#.00"));
                }
            }

            if(rngStates != null)
                rngStates.deallocate();
            //states = "Pointer to user-allocated GPU memory that will hold random number generator states."
            rngStates = new DataCache(rngStateSizeBytes);
            initializedDescriptor = false;
        }
        if(mask == null || mask.capacity() < maskReserveSizeBytes){
            if(log.isTraceEnabled()){
                if(mask == null){
                    log.trace("CudnnDropoutHelper: Allocating initial mask array of size {} ({})", maskReserveSizeBytes,
                            BinaryByteUnit.format(maskReserveSizeBytes, "#.00"));
                } else {
                    log.trace("CudnnDropoutHelper: Deallocating mask array of size {} ({}), allocating new mask array of size {} ({})",
                            mask.capacity(), BinaryByteUnit.format(mask.capacity(), "#.00"),
                            maskReserveSizeBytes, BinaryByteUnit.format(maskReserveSizeBytes, "#.00"));
                }
            }

            if(mask != null)
                mask.deallocate();
            //mask = "Pointer to user-allocated GPU memory used by this function. It is expected
            //that contents of reserveSpace does not change between cudnnDropoutForward and
            //cudnnDropoutBackward calls."
            mask = new DataCache(maskReserveSizeBytes);
        }

        //Dropout descriptor: (re)initialize if required
        if(!initializedDescriptor || p != lastInitializedP) {
            if(log.isTraceEnabled()){
                log.trace("CudnnDropoutHelper: (re)initializing dropout descriptor");
            }
            //NOTE: cudnnSetDropoutDescriptor has some internal computation/initialization, and hence is expensive to
            // call - so we want to call this as infrequently as possible, and cache the result
            long seed = Nd4j.getRandom().nextLong();
            lastInitializedP = p;
            checkCudnn(cudnnSetDropoutDescriptor(cudnnContext.dropoutDesc, cudnnContext, p, rngStates, rngStates.capacity(), seed));
            initializedDescriptor = true;
        }

        Allocator allocator = AtomicAllocator.getInstance();
        CudaContext context = allocator.getFlowController().prepareAction(input, resultArray);
        Pointer xPtr = allocator.getPointer(input, context);
        Pointer yPtr = allocator.getPointer(resultArray, context);

        checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getCublasStream())));
        checkCudnn(cudnnDropoutForward(cudnnContext, cudnnContext.dropoutDesc, cudnnContext.xTensorDesc, xPtr,
                cudnnContext.yTensorDesc, yPtr, mask, mask.capacity()));

        allocator.registerAction(context, input, resultArray);
        if (CudaEnvironment.getInstance().getConfiguration().isDebug())
            context.syncOldStream();
    }

    @Override
    public void backprop(INDArray gradAtOutput, INDArray gradAtInput) {
        int[] gradAtOutShape = adaptForTensorDescr(ArrayUtil.toInts(gradAtOutput.shape()));
        int[] gradAtOutStride = adaptForTensorDescr(ArrayUtil.toInts(gradAtOutput.stride()));
        checkCudnn(cudnnSetTensorNdDescriptor(cudnnContext.dyTensorDesc, dataType, gradAtOutShape.length, gradAtOutShape, gradAtOutStride));

        int[] gradAtInShape = adaptForTensorDescr(ArrayUtil.toInts(gradAtInput.shape()));
        int[] gradAtInStride = adaptForTensorDescr(ArrayUtil.toInts(gradAtInput.stride()));
        checkCudnn(cudnnSetTensorNdDescriptor(cudnnContext.dxTensorDesc, dataType, gradAtInShape.length, gradAtInShape, gradAtInStride));

        Allocator allocator = AtomicAllocator.getInstance();
        CudaContext context = allocator.getFlowController().prepareAction(gradAtOutput, gradAtInput);
        Pointer dyPtr = allocator.getPointer(gradAtOutput, context);
        Pointer dxPtr = allocator.getPointer(gradAtInput, context);

        checkCudnn(cudnnDropoutBackward(cudnnContext, cudnnContext.dropoutDesc, cudnnContext.dyTensorDesc, dyPtr,
                cudnnContext.dxTensorDesc, dxPtr, mask, mask.capacity()));

        allocator.registerAction(context, gradAtOutput, gradAtInput);
        if (CudaEnvironment.getInstance().getConfiguration().isDebug())
            context.syncOldStream();
    }
}

@@ -0,0 +1,384 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.deeplearning4j.cuda.normalization;

import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.bytedeco.javacpp.Pointer;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.cuda.BaseCudnnHelper;
import org.deeplearning4j.nn.layers.normalization.BatchNormalizationHelper;
import org.deeplearning4j.nn.params.BatchNormalizationParamInitializer;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.jita.allocator.Allocator;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.conf.CudaEnvironment;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.executioner.GridExecutioner;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.context.CudaContext;
import org.nd4j.common.primitives.Pair;
import org.nd4j.common.util.ArrayUtil;

import java.util.HashMap;
import java.util.Map;

import org.bytedeco.cuda.cudart.*;
import org.bytedeco.cuda.cudnn.*;

import static org.bytedeco.cuda.global.cudnn.*;

/**
 * cuDNN-based helper for the batch normalization layer.
 *
 * @author saudet
 */
@Slf4j
public class CudnnBatchNormalizationHelper extends BaseCudnnHelper implements BatchNormalizationHelper {

    public CudnnBatchNormalizationHelper(DataType dataType) {
        super(dataType);
    }

    private static class CudnnBatchNormalizationContext extends CudnnContext {

        private static class Deallocator extends CudnnBatchNormalizationContext implements Pointer.Deallocator {
            Deallocator(CudnnBatchNormalizationContext c) {
                super(c);
            }

            @Override
            public void deallocate() {
                destroyHandles();
            }
        }

        private cudnnTensorStruct srcTensorDesc = new cudnnTensorStruct(), dstTensorDesc = new cudnnTensorStruct(),
                deltaTensorDesc = new cudnnTensorStruct(), gammaBetaTensorDesc = new cudnnTensorStruct();

        public CudnnBatchNormalizationContext() {
            createHandles();
            deallocator(new Deallocator(this));
        }

        public CudnnBatchNormalizationContext(CudnnBatchNormalizationContext c) {
            super(c);
            srcTensorDesc = new cudnnTensorStruct(c.srcTensorDesc);
            dstTensorDesc = new cudnnTensorStruct(c.dstTensorDesc);
            deltaTensorDesc = new cudnnTensorStruct(c.deltaTensorDesc);
            gammaBetaTensorDesc = new cudnnTensorStruct(c.gammaBetaTensorDesc);
        }

        @Override
        protected void createHandles() {
            super.createHandles();
            checkCudnn(cudnnCreateTensorDescriptor(srcTensorDesc));
            checkCudnn(cudnnCreateTensorDescriptor(dstTensorDesc));
            checkCudnn(cudnnCreateTensorDescriptor(deltaTensorDesc));
            checkCudnn(cudnnCreateTensorDescriptor(gammaBetaTensorDesc));
        }

        @Override
        protected void destroyHandles() {
            checkCudnn(cudnnDestroyTensorDescriptor(srcTensorDesc));
            checkCudnn(cudnnDestroyTensorDescriptor(dstTensorDesc));
            checkCudnn(cudnnDestroyTensorDescriptor(deltaTensorDesc));
            checkCudnn(cudnnDestroyTensorDescriptor(gammaBetaTensorDesc));
            super.destroyHandles();
        }
    }

    protected final int batchNormMode = CUDNN_BATCHNORM_SPATIAL; // would need to increase rank of gamma and beta for CUDNN_BATCHNORM_PER_ACTIVATION

    private CudnnBatchNormalizationContext cudnnContext = new CudnnBatchNormalizationContext();
    private INDArray meanCache;
    private INDArray varCache;
    private double eps;

    public boolean checkSupported(double eps, boolean isFixedGammaBeta) {
        boolean supported = checkSupported();
        if (eps < CUDNN_BN_MIN_EPSILON) {
            supported = false;
            log.warn("Not supported: eps < CUDNN_BN_MIN_EPSILON (" + eps + " < " + CUDNN_BN_MIN_EPSILON + ")");
        }
        return supported;
    }

    @Override
    public Pair<Gradient, INDArray> backpropGradient(INDArray input, INDArray epsilon, long[] shape, INDArray gamma, INDArray beta,
                                                     INDArray dGammaView, INDArray dBetaView, double eps, CNN2DFormat format, LayerWorkspaceMgr layerWorkspaceMgr) {

        boolean nchw = format == CNN2DFormat.NCHW;

        this.eps = eps;

        int cudnnTensorFormat = nchw ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC;
        int chIdx = nchw ? 1 : 3;
        int hIdx = nchw ? 2 : 1;
        int wIdx = nchw ? 3 : 2;

        val miniBatch = (int) input.size(0);
        val depth = (int) input.size(chIdx);
        val inH = (int) input.size(hIdx);
        val inW = (int) input.size(wIdx);

        final boolean isHalf = (input.dataType() == DataType.HALF);
        INDArray gammaOrig = null;
        INDArray dGammaViewOrig = null;
        INDArray dBetaViewOrig = null;
        if(isHalf) { //Convert FP16 to FP32 if required (CuDNN BN doesn't support FP16 for these params, only for input/output)
            gammaOrig = gamma;
            dGammaViewOrig = dGammaView;
            dBetaViewOrig = dBetaView;
            /*
            From CuDNN docs: bnScale, resultBnScaleDiff, resultBnBiasDiff, savedMean, savedInvVariance
            "Note: The data type of this tensor descriptor must be 'float' for FP16 and FP32 input tensors, and 'double'
            for FP64 input tensors."
            >> Last 2 are the meanCache and varCache; first 3 are below
            */
            gamma = gamma.castTo(DataType.FLOAT);
            dGammaView = dGammaView.castTo(DataType.FLOAT);
            dBetaView = dBetaView.castTo(DataType.FLOAT);
        }

        Gradient retGradient = new DefaultGradient();

        if (!Shape.hasDefaultStridesForShape(epsilon)) {
            // apparently not supported by cuDNN
            epsilon = epsilon.dup('c');
        }

        val srcStride = ArrayUtil.toInts(input.stride());
        val deltaStride = ArrayUtil.toInts(epsilon.stride());

        if (Nd4j.getExecutioner() instanceof GridExecutioner)
            ((GridExecutioner) Nd4j.getExecutioner()).flushQueue();

        checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.srcTensorDesc, dataType, (int) miniBatch, (int) depth, (int) inH, (int) inW,
                (int) srcStride[0], (int) srcStride[chIdx], (int) srcStride[hIdx], (int) srcStride[wIdx]));
        checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.deltaTensorDesc, dataType, (int) miniBatch, (int) depth, (int) inH, (int) inW,
                (int) deltaStride[0], (int) deltaStride[chIdx], (int) deltaStride[hIdx], (int) deltaStride[wIdx]));

        long[] nextEpsShape = nchw ? new long[] {miniBatch, depth, inH, inW} : new long[] {miniBatch, inH, inW, depth};
        INDArray nextEpsilon = layerWorkspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, input.dataType(), nextEpsShape, 'c');
        val dstStride = ArrayUtil.toInts(nextEpsilon.stride());

        checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.dstTensorDesc, dataType, miniBatch, depth, inH, inW,
                dstStride[0], dstStride[chIdx], dstStride[hIdx], dstStride[wIdx]));
        checkCudnn(cudnnSetTensor4dDescriptor(cudnnContext.gammaBetaTensorDesc, cudnnTensorFormat, toCudnnDataType(gamma.data().dataType()), (int)shape[0],
                (int)shape[1], shape.length > 2 ? (int)shape[2] : 1, shape.length > 3 ? (int)shape[3] : 1));
|
||||||
|
|
||||||
|
Allocator allocator = AtomicAllocator.getInstance();
|
||||||
|
CudaContext context = allocator.getFlowController().prepareActionAllWrite(input, epsilon, nextEpsilon, gamma,
|
||||||
|
dGammaView, dBetaView);
|
||||||
|
Pointer srcData = allocator.getPointer(input, context);
|
||||||
|
Pointer epsData = allocator.getPointer(epsilon, context);
|
||||||
|
Pointer dstData = allocator.getPointer(nextEpsilon, context);
|
||||||
|
Pointer gammaData = allocator.getPointer(gamma, context);
|
||||||
|
Pointer dGammaData = allocator.getPointer(dGammaView, context);
|
||||||
|
Pointer dBetaData = allocator.getPointer(dBetaView, context);
|
||||||
|
Pointer meanCacheData = allocator.getPointer(meanCache, context);
|
||||||
|
Pointer varCacheData = allocator.getPointer(varCache, context);
|
||||||
|
|
||||||
|
checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getCublasStream())));
|
||||||
|
checkCudnn(cudnnBatchNormalizationBackward(cudnnContext, batchNormMode, alpha, this.beta, alpha, alpha,
|
||||||
|
cudnnContext.srcTensorDesc, srcData, cudnnContext.deltaTensorDesc, epsData,
|
||||||
|
cudnnContext.dstTensorDesc, dstData, cudnnContext.gammaBetaTensorDesc, gammaData, dGammaData,
|
||||||
|
dBetaData, eps, meanCacheData, varCacheData));
|
||||||
|
|
||||||
|
allocator.getFlowController().registerActionAllWrite(context, input, epsilon, nextEpsilon, gamma, dGammaView,
|
||||||
|
dBetaView);
|
||||||
|
|
||||||
|
retGradient.setGradientFor(BatchNormalizationParamInitializer.GAMMA, dGammaView);
|
||||||
|
retGradient.setGradientFor(BatchNormalizationParamInitializer.BETA, dBetaView);
|
||||||
|
|
||||||
|
context.syncOldStream();
|
||||||
|
|
||||||
|
//Convert back and assign, if required:
|
||||||
|
if(isHalf){
|
||||||
|
gammaOrig.assign(gamma.castTo(DataType.HALF));
|
||||||
|
dGammaViewOrig.assign(dGammaView.castTo(DataType.HALF));
|
||||||
|
dBetaViewOrig.assign(dBetaView.castTo(DataType.HALF));
|
||||||
|
}
|
||||||
|
|
||||||
|
return new Pair<>(retGradient, nextEpsilon);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public INDArray preOutput(INDArray x, boolean training, long[] shape, INDArray gamma, INDArray beta, INDArray mean,
|
||||||
|
INDArray var, double decay, double eps, CNN2DFormat format, LayerWorkspaceMgr workspaceMgr) {
|
||||||
|
boolean nchw = format == CNN2DFormat.NCHW;
|
||||||
|
int cudnnTensorFormat = nchw ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC;
|
||||||
|
int chIdx = nchw ? 1 : 3;
|
||||||
|
int hIdx = nchw ? 2 : 1;
|
||||||
|
int wIdx = nchw ? 3 : 2;
|
||||||
|
|
||||||
|
this.eps = eps;
|
||||||
|
final boolean isHalf = (x.dataType() == DataType.FLOAT16);
|
||||||
|
INDArray origGamma = gamma;
|
||||||
|
INDArray origBeta = beta;
|
||||||
|
INDArray origMean = mean;
|
||||||
|
INDArray origVar = var;
|
||||||
|
if(isHalf) {
|
||||||
|
gamma = gamma.castTo(DataType.FLOAT);
|
||||||
|
beta = beta.castTo(DataType.FLOAT);
|
||||||
|
mean = mean.castTo(DataType.FLOAT);
|
||||||
|
var = var.castTo(DataType.FLOAT);
|
||||||
|
}
|
||||||
|
|
||||||
|
//Notation difference between CuDNN and our implementation:
|
||||||
|
//Us: runningMean = (1-decay) * batchMean + decay * runningMean
|
||||||
|
//CuDNN: runningMean = decay * batchMean + (1-decay) * runningMean
|
||||||
|
//i.e., "decay" has a different meaning...
|
||||||
|
//Disable in-place updating of running mean/variance, so that all parameter changes are done via the update/gradient
|
||||||
|
// vector. This is necessary for BatchNormalization to be safe to use in distributed gradient sharing settings
|
||||||
|
decay = 0.0; //From cudnn docs: runningMean = newMean*factor + runningMean*(1-factor). -> 0 = "in-place modification of running mean disabled"
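// Illustration of cuDNN's exponentialAverageFactor semantics:
//   runningMean_new = factor * batchMean + (1 - factor) * runningMean_old
// With factor = 0.0 the call leaves runningMean/runningVar untouched, and DL4J instead applies
// its own decay-based update through the normal gradient/update vector, as described above.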
|
||||||
|
|
||||||
|
val miniBatch = (int) x.size(0);
|
||||||
|
val inDepth = (int) x.size(chIdx);
|
||||||
|
val inH = (int) x.size(hIdx);
|
||||||
|
val inW = (int) x.size(wIdx);
|
||||||
|
|
||||||
|
val srcStride = ArrayUtil.toInts(x.stride());
|
||||||
|
checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.srcTensorDesc, dataType, miniBatch, inDepth, inH, inW,
|
||||||
|
srcStride[0], srcStride[chIdx], srcStride[hIdx], srcStride[wIdx]));
|
||||||
|
|
||||||
|
long[] actShape = nchw ? new long[] {miniBatch, inDepth, inH, inW} : new long[] {miniBatch, inH, inW, inDepth};
|
||||||
|
INDArray activations = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, x.dataType(), actShape, 'c');
|
||||||
|
|
||||||
|
val dstStride = ArrayUtil.toInts(activations.stride());
|
||||||
|
checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.dstTensorDesc, dataType, miniBatch, inDepth, inH, inW,
|
||||||
|
dstStride[0], dstStride[chIdx], dstStride[hIdx], dstStride[wIdx]));
|
||||||
|
|
||||||
|
checkCudnn(cudnnSetTensor4dDescriptor(cudnnContext.gammaBetaTensorDesc, cudnnTensorFormat, toCudnnDataType(mean.data().dataType()), (int)shape[0],
|
||||||
|
(int)shape[1], shape.length > 2 ? (int)shape[2] : 1, shape.length > 3 ? (int)shape[3] : 1));
|
||||||
|
|
||||||
|
Allocator allocator = AtomicAllocator.getInstance();
|
||||||
|
CudaContext context =
|
||||||
|
allocator.getFlowController().prepareActionAllWrite(x, activations, gamma, beta, mean, var);
|
||||||
|
Pointer srcData = allocator.getPointer(x, context);
|
||||||
|
Pointer dstData = allocator.getPointer(activations, context);
|
||||||
|
Pointer gammaData = allocator.getPointer(gamma, context);
|
||||||
|
Pointer betaData = allocator.getPointer(beta, context);
|
||||||
|
Pointer meanData = allocator.getPointer(mean, context);
|
||||||
|
Pointer varData = allocator.getPointer(var, context);
|
||||||
|
|
||||||
|
if (Nd4j.getExecutioner() instanceof GridExecutioner)
|
||||||
|
((GridExecutioner) Nd4j.getExecutioner()).flushQueue();
|
||||||
|
|
||||||
|
checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getCublasStream())));
|
||||||
|
if (training) {
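// In training mode cuDNN also returns the per-batch statistics: meanCache receives the batch
// mean and varCache the saved inverse variance (1/sqrt(var + eps)). Both are kept so that
// backpropGradient(...) can pass the same saved values back to cudnnBatchNormalizationBackward.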
|
||||||
|
if(meanCache == null || meanCache.length() < mean.length()){
|
||||||
|
try(MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
|
||||||
|
meanCache = Nd4j.createUninitialized(x.dataType(), mean.length());
|
||||||
|
}
|
||||||
|
if(x.dataType() == DataType.HALF){
|
||||||
|
try(MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
|
||||||
|
meanCache = meanCache.castTo(DataType.FLOAT);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if(varCache == null || varCache.length() < mean.length()){
|
||||||
|
try(MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
|
||||||
|
varCache = Nd4j.createUninitialized(x.dataType(), mean.length());
|
||||||
|
}
|
||||||
|
if(nd4jDataType == DataType.HALF){
|
||||||
|
try(MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
|
||||||
|
varCache = varCache.castTo(DataType.FLOAT);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Pointer meanCacheData = allocator.getPointer(meanCache, context);
|
||||||
|
Pointer varCacheData = allocator.getPointer(varCache, context);
|
||||||
|
|
||||||
|
checkCudnn(cudnnBatchNormalizationForwardTraining(cudnnContext, batchNormMode, this.alpha, this.beta,
|
||||||
|
cudnnContext.srcTensorDesc, srcData, cudnnContext.dstTensorDesc, dstData,
|
||||||
|
cudnnContext.gammaBetaTensorDesc, gammaData, betaData, decay, meanData, varData, eps,
|
||||||
|
meanCacheData, varCacheData));
|
||||||
|
} else {
|
||||||
|
checkCudnn(cudnnBatchNormalizationForwardInference(cudnnContext, batchNormMode, this.alpha, this.beta,
|
||||||
|
cudnnContext.srcTensorDesc, srcData, cudnnContext.dstTensorDesc, dstData,
|
||||||
|
cudnnContext.gammaBetaTensorDesc, gammaData, betaData, meanData, varData, eps));
|
||||||
|
}
|
||||||
|
|
||||||
|
allocator.getFlowController().registerActionAllWrite(context, x, activations, gamma, beta, mean, var);
|
||||||
|
|
||||||
|
if (CudaEnvironment.getInstance().getConfiguration().isDebug())
|
||||||
|
context.syncOldStream();
|
||||||
|
|
||||||
|
context.syncOldStream();
|
||||||
|
if(training) {
|
||||||
|
AtomicAllocator.getInstance().getAllocationPoint(meanCache).tickDeviceWrite();
|
||||||
|
AtomicAllocator.getInstance().getAllocationPoint(varCache).tickDeviceWrite();
|
||||||
|
}
|
||||||
|
|
||||||
|
if(training && isHalf){
|
||||||
|
//Update the running mean and variance arrays; also gamma/beta
|
||||||
|
origMean.assign(mean.castTo(DataType.HALF));
|
||||||
|
origVar.assign(var.castTo(DataType.HALF));
|
||||||
|
origGamma.assign(gamma.castTo(DataType.HALF));
|
||||||
|
origBeta.assign(beta.castTo(DataType.HALF));
|
||||||
|
}
|
||||||
|
|
||||||
|
return activations;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public INDArray getMeanCache(DataType dataType) {
|
||||||
|
if(dataType == DataType.HALF){
|
||||||
|
//Buffer is FP32
|
||||||
|
return meanCache.castTo(DataType.HALF);
|
||||||
|
}
|
||||||
|
return meanCache;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public INDArray getVarCache(DataType dataType) {
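// varCache holds the inverse standard deviation saved by cudnnBatchNormalizationForwardTraining,
// i.e. 1/sqrt(var + eps). The expression below inverts that to recover the variance:
//   var = 1 / (cache * cache) - eps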
|
||||||
|
INDArray ret;
|
||||||
|
if(dataType == DataType.HALF){
|
||||||
|
INDArray vc = varCache.castTo(DataType.HALF);
|
||||||
|
ret = vc.mul(vc).rdivi(1.0).subi(eps);
|
||||||
|
} else {
|
||||||
|
ret = varCache.mul(varCache).rdivi(1.0).subi(eps);
|
||||||
|
}
|
||||||
|
if(dataType == DataType.HALF){
|
||||||
|
//Buffer is FP32
|
||||||
|
return ret.castTo(DataType.HALF);
|
||||||
|
}
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Map<String, Long> helperMemoryUse() {
|
||||||
|
Map<String,Long> memUse = new HashMap<>();
|
||||||
|
memUse.put("meanCache", meanCache == null ? 0 : meanCache.length() * meanCache.data().getElementSize());
|
||||||
|
memUse.put("varCache", varCache == null ? 0 : varCache.length() * varCache.data().getElementSize());
|
||||||
|
return memUse;
|
||||||
|
}
|
||||||
|
}
|
|
@@ -0,0 +1,240 @@
|
||||||
|
/*
|
||||||
|
* ******************************************************************************
|
||||||
|
* *
|
||||||
|
* *
|
||||||
|
* * This program and the accompanying materials are made available under the
|
||||||
|
* * terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
* *
|
||||||
|
* * See the NOTICE file distributed with this work for additional
|
||||||
|
* * information regarding copyright ownership.
|
||||||
|
* * Unless required by applicable law or agreed to in writing, software
|
||||||
|
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* * License for the specific language governing permissions and limitations
|
||||||
|
* * under the License.
|
||||||
|
* *
|
||||||
|
* * SPDX-License-Identifier: Apache-2.0
|
||||||
|
* *****************************************************************************
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.deeplearning4j.cuda.normalization;
|
||||||
|
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import lombok.val;
|
||||||
|
import org.bytedeco.javacpp.Pointer;
|
||||||
|
import org.deeplearning4j.nn.gradient.DefaultGradient;
|
||||||
|
import org.deeplearning4j.nn.gradient.Gradient;
|
||||||
|
import org.deeplearning4j.cuda.BaseCudnnHelper;
|
||||||
|
import org.deeplearning4j.nn.layers.normalization.LocalResponseNormalizationHelper;
|
||||||
|
import org.nd4j.jita.allocator.Allocator;
|
||||||
|
import org.nd4j.jita.allocator.impl.AtomicAllocator;
|
||||||
|
import org.nd4j.jita.conf.CudaEnvironment;
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.api.ops.executioner.GridExecutioner;
|
||||||
|
import org.nd4j.linalg.api.shape.Shape;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
import org.nd4j.linalg.jcublas.context.CudaContext;
|
||||||
|
import org.nd4j.common.primitives.Pair;
|
||||||
|
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||||
|
import org.deeplearning4j.nn.workspace.ArrayType;
|
||||||
|
import org.nd4j.common.util.ArrayUtil;
|
||||||
|
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
import org.bytedeco.cuda.cudart.*;
|
||||||
|
import org.bytedeco.cuda.cudnn.*;
|
||||||
|
|
||||||
|
import static org.bytedeco.cuda.global.cudnn.*;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* cuDNN-based helper for the local response normalization layer.
|
||||||
|
*
|
||||||
|
* @author saudet
|
||||||
|
*/
|
||||||
|
@Slf4j
|
||||||
|
public class CudnnLocalResponseNormalizationHelper extends BaseCudnnHelper implements LocalResponseNormalizationHelper {
|
||||||
|
|
||||||
|
public CudnnLocalResponseNormalizationHelper(DataType dataType) {
|
||||||
|
super(dataType);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static class CudnnLocalResponseNormalizationContext extends CudnnContext {
|
||||||
|
|
||||||
|
private static class Deallocator extends CudnnLocalResponseNormalizationContext implements Pointer.Deallocator {
|
||||||
|
Deallocator(CudnnLocalResponseNormalizationContext c) {
|
||||||
|
super(c);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void deallocate() {
|
||||||
|
destroyHandles();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private cudnnTensorStruct srcTensorDesc = new cudnnTensorStruct(), dstTensorDesc = new cudnnTensorStruct(),
|
||||||
|
deltaTensorDesc = new cudnnTensorStruct();
|
||||||
|
private cudnnLRNStruct lrnDesc = new cudnnLRNStruct();
|
||||||
|
|
||||||
|
public CudnnLocalResponseNormalizationContext() {
|
||||||
|
createHandles();
|
||||||
|
deallocator(new Deallocator(this));
|
||||||
|
}
|
||||||
|
|
||||||
|
public CudnnLocalResponseNormalizationContext(CudnnLocalResponseNormalizationContext c) {
|
||||||
|
super(c);
|
||||||
|
srcTensorDesc = new cudnnTensorStruct(c.srcTensorDesc);
|
||||||
|
dstTensorDesc = new cudnnTensorStruct(c.dstTensorDesc);
|
||||||
|
deltaTensorDesc = new cudnnTensorStruct(c.deltaTensorDesc);
|
||||||
|
lrnDesc = new cudnnLRNStruct(c.lrnDesc);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected void createHandles() {
|
||||||
|
super.createHandles();
|
||||||
|
checkCudnn(cudnnCreateTensorDescriptor(srcTensorDesc));
|
||||||
|
checkCudnn(cudnnCreateTensorDescriptor(dstTensorDesc));
|
||||||
|
checkCudnn(cudnnCreateTensorDescriptor(deltaTensorDesc));
|
||||||
|
checkCudnn(cudnnCreateLRNDescriptor(lrnDesc));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected void destroyHandles() {
|
||||||
|
checkCudnn(cudnnDestroyLRNDescriptor(lrnDesc));
|
||||||
|
checkCudnn(cudnnDestroyTensorDescriptor(srcTensorDesc));
|
||||||
|
checkCudnn(cudnnDestroyTensorDescriptor(dstTensorDesc));
|
||||||
|
checkCudnn(cudnnDestroyTensorDescriptor(deltaTensorDesc));
|
||||||
|
super.destroyHandles();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private CudnnLocalResponseNormalizationContext cudnnContext = new CudnnLocalResponseNormalizationContext();
|
||||||
|
private INDArray activations = null;
|
||||||
|
|
||||||
|
public boolean checkSupported(double k, double n, double alpha, double beta) {
|
||||||
|
boolean supported = checkSupported();
|
||||||
|
if (n < CUDNN_LRN_MIN_N) {
|
||||||
|
supported = false;
|
||||||
|
log.warn("Not supported: n < CUDNN_LRN_MIN_N (" + n + " < " + CUDNN_LRN_MIN_N + ")");
|
||||||
|
}
|
||||||
|
if (n > CUDNN_LRN_MAX_N) {
|
||||||
|
supported = false;
|
||||||
|
log.warn("Not supported: n > CUDNN_LRN_MAX_N (" + n + " > " + CUDNN_LRN_MAX_N + ")");
|
||||||
|
}
|
||||||
|
if (k < CUDNN_LRN_MIN_K) {
|
||||||
|
supported = false;
|
||||||
|
log.warn("Not supported: k < CUDNN_LRN_MIN_K (" + k + " < " + CUDNN_LRN_MIN_K + ")");
|
||||||
|
}
|
||||||
|
if (beta < CUDNN_LRN_MIN_BETA) {
|
||||||
|
supported = false;
|
||||||
|
log.warn("Not supported: beta < CUDNN_LRN_MIN_BETA (" + beta + " < " + CUDNN_LRN_MIN_BETA + ")");
|
||||||
|
}
|
||||||
|
return supported;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Pair<Gradient, INDArray> backpropGradient(INDArray input, INDArray epsilon, double k, double n, double alpha,
|
||||||
|
double beta, LayerWorkspaceMgr workspaceMgr) {
|
||||||
|
val miniBatch = (int) input.size(0);
|
||||||
|
val depth = (int) input.size(1);
|
||||||
|
val inH = (int) input.size(2);
|
||||||
|
val inW = (int) input.size(3);
|
||||||
|
|
||||||
|
Gradient retGradient = new DefaultGradient();
|
||||||
|
|
||||||
|
if (!Shape.hasDefaultStridesForShape(epsilon)) {
|
||||||
|
// apparently not supported by cuDNN
|
||||||
|
epsilon = epsilon.dup('c');
|
||||||
|
}
|
||||||
|
|
||||||
|
val srcStride = ArrayUtil.toInts(input.stride());
|
||||||
|
val deltaStride = ArrayUtil.toInts(epsilon.stride());
|
||||||
|
|
||||||
|
if (Nd4j.getExecutioner() instanceof GridExecutioner)
|
||||||
|
((GridExecutioner) Nd4j.getExecutioner()).flushQueue();
|
||||||
|
|
||||||
|
checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.srcTensorDesc, dataType, miniBatch, depth, inH, inW,
|
||||||
|
srcStride[0], srcStride[1], srcStride[2], srcStride[3]));
|
||||||
|
checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.deltaTensorDesc, dataType, miniBatch, depth, inH, inW,
|
||||||
|
deltaStride[0], deltaStride[1], deltaStride[2], deltaStride[3]));
|
||||||
|
checkCudnn(cudnnSetLRNDescriptor(cudnnContext.lrnDesc, (int) n, alpha, beta, k));
|
||||||
|
|
||||||
|
INDArray nextEpsilon = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, input.dataType(), new long[] {miniBatch, depth, inH, inW}, 'c');
|
||||||
|
|
||||||
|
val dstStride = ArrayUtil.toInts(nextEpsilon.stride());
|
||||||
|
checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.dstTensorDesc, dataType, miniBatch, depth, inH, inW,
|
||||||
|
dstStride[0], dstStride[1], dstStride[2], dstStride[3]));
|
||||||
|
|
||||||
|
Allocator allocator = AtomicAllocator.getInstance();
|
||||||
|
CudaContext context =
|
||||||
|
allocator.getFlowController().prepareActionAllWrite(input, epsilon, activations, nextEpsilon);
|
||||||
|
Pointer srcData = allocator.getPointer(input, context);
|
||||||
|
Pointer epsData = allocator.getPointer(epsilon, context);
|
||||||
|
Pointer zData = allocator.getPointer(activations, context);
|
||||||
|
Pointer dstData = allocator.getPointer(nextEpsilon, context);
|
||||||
|
|
||||||
|
checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getCublasStream())));
|
||||||
|
checkCudnn(cudnnLRNCrossChannelBackward(cudnnContext, cudnnContext.lrnDesc, CUDNN_LRN_CROSS_CHANNEL_DIM1,
|
||||||
|
this.alpha, cudnnContext.deltaTensorDesc, zData, cudnnContext.deltaTensorDesc, epsData,
|
||||||
|
cudnnContext.srcTensorDesc, srcData, this.beta, cudnnContext.dstTensorDesc, dstData));
|
||||||
|
|
||||||
|
allocator.getFlowController().registerActionAllWrite(context, input, epsilon, activations, nextEpsilon);
|
||||||
|
|
||||||
|
if (CudaEnvironment.getInstance().getConfiguration().isDebug())
|
||||||
|
context.syncOldStream();
|
||||||
|
|
||||||
|
return new Pair<>(retGradient, nextEpsilon);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public INDArray activate(INDArray input, boolean training, double k, double n, double alpha, double beta, LayerWorkspaceMgr workspaceMgr) {
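// The forward output is cached in the 'activations' field (rather than being purely local)
// because cudnnLRNCrossChannelBackward in backpropGradient(...) needs the forward activations
// (zData) in addition to the input and the incoming epsilon.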
|
||||||
|
val miniBatch = (int) input.size(0);
|
||||||
|
val inDepth = (int) input.size(1);
|
||||||
|
val inH = (int) input.size(2);
|
||||||
|
val inW = (int) input.size(3);
|
||||||
|
|
||||||
|
if(!Shape.hasDefaultStridesForShape(input)){
|
||||||
|
input = input.dup('c');
|
||||||
|
}
|
||||||
|
|
||||||
|
val srcStride = ArrayUtil.toInts(input.stride());
|
||||||
|
checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.srcTensorDesc, dataType, miniBatch, inDepth, inH, inW,
|
||||||
|
srcStride[0], srcStride[1], srcStride[2], srcStride[3]));
|
||||||
|
|
||||||
|
activations = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, input.dataType(), new long[] {miniBatch, inDepth, inH, inW}, 'c');
|
||||||
|
|
||||||
|
val dstStride = ArrayUtil.toInts(activations.stride());
|
||||||
|
checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.dstTensorDesc, dataType, miniBatch, inDepth, inH, inW,
|
||||||
|
dstStride[0], dstStride[1], dstStride[2], dstStride[3]));
|
||||||
|
checkCudnn(cudnnSetLRNDescriptor(cudnnContext.lrnDesc, (int) n, alpha, beta, k));
|
||||||
|
|
||||||
|
Allocator allocator = AtomicAllocator.getInstance();
|
||||||
|
CudaContext context = allocator.getFlowController().prepareActionAllWrite(input, activations);
|
||||||
|
Pointer srcData = allocator.getPointer(input, context);
|
||||||
|
Pointer dstData = allocator.getPointer(activations, context);
|
||||||
|
|
||||||
|
if (Nd4j.getExecutioner() instanceof GridExecutioner)
|
||||||
|
((GridExecutioner) Nd4j.getExecutioner()).flushQueue();
|
||||||
|
|
||||||
|
checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getCublasStream())));
|
||||||
|
checkCudnn(cudnnLRNCrossChannelForward(cudnnContext, cudnnContext.lrnDesc, CUDNN_LRN_CROSS_CHANNEL_DIM1,
|
||||||
|
this.alpha, cudnnContext.srcTensorDesc, srcData, this.beta, cudnnContext.dstTensorDesc,
|
||||||
|
dstData));
|
||||||
|
|
||||||
|
allocator.getFlowController().registerActionAllWrite(context, input, activations);
|
||||||
|
|
||||||
|
if (CudaEnvironment.getInstance().getConfiguration().isDebug())
|
||||||
|
context.syncOldStream();
|
||||||
|
|
||||||
|
return activations;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Map<String, Long> helperMemoryUse() {
|
||||||
|
//No persistent memory use other than the structs (which are small)
|
||||||
|
return Collections.emptyMap();
|
||||||
|
}
|
||||||
|
}
|
|
@@ -0,0 +1,659 @@
|
||||||
|
/*
|
||||||
|
* ******************************************************************************
|
||||||
|
* *
|
||||||
|
* *
|
||||||
|
* * This program and the accompanying materials are made available under the
|
||||||
|
* * terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
* *
|
||||||
|
* * See the NOTICE file distributed with this work for additional
|
||||||
|
* * information regarding copyright ownership.
|
||||||
|
* * Unless required by applicable law or agreed to in writing, software
|
||||||
|
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* * License for the specific language governing permissions and limitations
|
||||||
|
* * under the License.
|
||||||
|
* *
|
||||||
|
* * SPDX-License-Identifier: Apache-2.0
|
||||||
|
* *****************************************************************************
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.deeplearning4j.cuda.recurrent;
|
||||||
|
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import lombok.val;
|
||||||
|
import com.jakewharton.byteunits.BinaryByteUnit;
|
||||||
|
import org.bytedeco.javacpp.Pointer;
|
||||||
|
import org.deeplearning4j.nn.api.Layer;
|
||||||
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
|
import org.deeplearning4j.nn.gradient.DefaultGradient;
|
||||||
|
import org.deeplearning4j.nn.gradient.Gradient;
|
||||||
|
import org.deeplearning4j.cuda.BaseCudnnHelper;
|
||||||
|
import org.deeplearning4j.nn.layers.recurrent.FwdPassReturn;
|
||||||
|
import org.deeplearning4j.nn.layers.recurrent.LSTMHelper;
|
||||||
|
import org.nd4j.jita.allocator.Allocator;
|
||||||
|
import org.nd4j.jita.allocator.impl.AtomicAllocator;
|
||||||
|
import org.nd4j.linalg.activations.IActivation;
|
||||||
|
import org.nd4j.linalg.activations.impl.ActivationSigmoid;
|
||||||
|
import org.nd4j.linalg.activations.impl.ActivationTanH;
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.api.shape.Shape;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
import org.nd4j.linalg.jcublas.context.CudaContext;
|
||||||
|
import org.nd4j.common.primitives.Pair;
|
||||||
|
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||||
|
import org.deeplearning4j.nn.workspace.ArrayType;
|
||||||
|
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
import org.bytedeco.cuda.cudart.*;
|
||||||
|
import org.bytedeco.cuda.cudnn.*;
|
||||||
|
import static org.bytedeco.cuda.global.cudart.*;
|
||||||
|
import static org.bytedeco.cuda.global.cudnn.*;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* cuDNN-based helper for the recurrent LSTM layer (no peephole connections).
|
||||||
|
*
|
||||||
|
* @author saudet
|
||||||
|
*/
|
||||||
|
@Slf4j
|
||||||
|
public class CudnnLSTMHelper extends BaseCudnnHelper implements LSTMHelper {
|
||||||
|
|
||||||
|
public CudnnLSTMHelper(DataType dataType) {
|
||||||
|
super(dataType);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static class CudnnLSTMContext extends CudnnContext {
|
||||||
|
|
||||||
|
private static class Deallocator extends CudnnLSTMContext implements Pointer.Deallocator {
|
||||||
|
Deallocator(CudnnLSTMContext c) {
|
||||||
|
super(c);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void deallocate() {
|
||||||
|
destroyHandles();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private cudnnTensorStruct hxDesc = new cudnnTensorStruct(), cxDesc = new cudnnTensorStruct();
|
||||||
|
private cudnnTensorStruct hyDesc = new cudnnTensorStruct(), cyDesc = new cudnnTensorStruct();
|
||||||
|
private cudnnTensorStruct dhxDesc = new cudnnTensorStruct(), dcxDesc = new cudnnTensorStruct();
|
||||||
|
private cudnnTensorStruct dhyDesc = new cudnnTensorStruct(), dcyDesc = new cudnnTensorStruct();
|
||||||
|
|
||||||
|
private cudnnFilterStruct wDesc = new cudnnFilterStruct(), dwDesc = new cudnnFilterStruct();
|
||||||
|
private cudnnFilterStruct linLayerMatDesc = new cudnnFilterStruct(), linLayerBiasDesc = new cudnnFilterStruct();
|
||||||
|
|
||||||
|
private cudnnRNNStruct rnnDesc = new cudnnRNNStruct();
|
||||||
|
private cudnnDropoutStruct dropoutDesc = new cudnnDropoutStruct();
|
||||||
|
private cudnnActivationStruct activationDesc = new cudnnActivationStruct();
|
||||||
|
|
||||||
|
public CudnnLSTMContext() {
|
||||||
|
createHandles();
|
||||||
|
deallocator(new Deallocator(this));
|
||||||
|
}
|
||||||
|
|
||||||
|
public CudnnLSTMContext(CudnnLSTMContext c) {
|
||||||
|
super(c);
|
||||||
|
hxDesc = new cudnnTensorStruct(c.hxDesc);
|
||||||
|
cxDesc = new cudnnTensorStruct(c.cxDesc);
|
||||||
|
hyDesc = new cudnnTensorStruct(c.hyDesc);
|
||||||
|
cyDesc = new cudnnTensorStruct(c.cyDesc);
|
||||||
|
dhxDesc = new cudnnTensorStruct(c.dhxDesc);
|
||||||
|
dcxDesc = new cudnnTensorStruct(c.dcxDesc);
|
||||||
|
dhyDesc = new cudnnTensorStruct(c.dhyDesc);
|
||||||
|
dcyDesc = new cudnnTensorStruct(c.dcyDesc);
|
||||||
|
|
||||||
|
wDesc = new cudnnFilterStruct(c.wDesc);
|
||||||
|
dwDesc = new cudnnFilterStruct(c.dwDesc);
|
||||||
|
linLayerMatDesc = new cudnnFilterStruct(c.linLayerMatDesc);
|
||||||
|
linLayerBiasDesc = new cudnnFilterStruct(c.linLayerBiasDesc);
|
||||||
|
|
||||||
|
rnnDesc = new cudnnRNNStruct(c.rnnDesc);
|
||||||
|
dropoutDesc = new cudnnDropoutStruct(c.dropoutDesc);
|
||||||
|
activationDesc = new cudnnActivationStruct(c.activationDesc);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected void createHandles() {
|
||||||
|
super.createHandles();
|
||||||
|
|
||||||
|
checkCudnn(cudnnCreateTensorDescriptor(hxDesc));
|
||||||
|
checkCudnn(cudnnCreateTensorDescriptor(cxDesc));
|
||||||
|
checkCudnn(cudnnCreateTensorDescriptor(hyDesc));
|
||||||
|
checkCudnn(cudnnCreateTensorDescriptor(cyDesc));
|
||||||
|
checkCudnn(cudnnCreateTensorDescriptor(dhxDesc));
|
||||||
|
checkCudnn(cudnnCreateTensorDescriptor(dcxDesc));
|
||||||
|
checkCudnn(cudnnCreateTensorDescriptor(dhyDesc));
|
||||||
|
checkCudnn(cudnnCreateTensorDescriptor(dcyDesc));
|
||||||
|
|
||||||
|
checkCudnn(cudnnCreateFilterDescriptor(wDesc));
|
||||||
|
checkCudnn(cudnnCreateFilterDescriptor(dwDesc));
|
||||||
|
checkCudnn(cudnnCreateFilterDescriptor(linLayerMatDesc));
|
||||||
|
checkCudnn(cudnnCreateFilterDescriptor(linLayerBiasDesc));
|
||||||
|
|
||||||
|
checkCudnn(cudnnCreateRNNDescriptor(rnnDesc));
|
||||||
|
checkCudnn(cudnnCreateDropoutDescriptor(dropoutDesc));
|
||||||
|
checkCudnn(cudnnCreateActivationDescriptor(activationDesc));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected void destroyHandles() {
|
||||||
|
checkCudnn(cudnnDestroyActivationDescriptor(activationDesc));
|
||||||
|
checkCudnn(cudnnDestroyDropoutDescriptor(dropoutDesc));
|
||||||
|
checkCudnn(cudnnDestroyRNNDescriptor(rnnDesc));
|
||||||
|
|
||||||
|
checkCudnn(cudnnDestroyFilterDescriptor(wDesc));
|
||||||
|
checkCudnn(cudnnDestroyFilterDescriptor(dwDesc));
|
||||||
|
checkCudnn(cudnnDestroyFilterDescriptor(linLayerMatDesc));
|
||||||
|
checkCudnn(cudnnDestroyFilterDescriptor(linLayerBiasDesc));
|
||||||
|
|
||||||
|
checkCudnn(cudnnDestroyTensorDescriptor(hxDesc));
|
||||||
|
checkCudnn(cudnnDestroyTensorDescriptor(cxDesc));
|
||||||
|
checkCudnn(cudnnDestroyTensorDescriptor(hyDesc));
|
||||||
|
checkCudnn(cudnnDestroyTensorDescriptor(cyDesc));
|
||||||
|
checkCudnn(cudnnDestroyTensorDescriptor(dhxDesc));
|
||||||
|
checkCudnn(cudnnDestroyTensorDescriptor(dcxDesc));
|
||||||
|
checkCudnn(cudnnDestroyTensorDescriptor(dhyDesc));
|
||||||
|
checkCudnn(cudnnDestroyTensorDescriptor(dcyDesc));
|
||||||
|
|
||||||
|
super.destroyHandles();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// These constants might eventually become variable parameters...
|
||||||
|
protected static final int NUM_LAYERS = 1;
|
||||||
|
protected static final float DROPOUT = 0;
|
||||||
|
protected static final boolean BIDIRECTIONAL = false;
|
||||||
|
protected static final int RNN_MODE = CUDNN_LSTM;
|
||||||
|
protected static final int NUM_LINEAR_LAYERS = 8; // CUDNN_LSTM
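// For CUDNN_LSTM, cuDNN exposes 8 "linear layers" per (layer, direction): IDs 0-3 are the input
// weight matrices and IDs 4-7 the recurrent weight matrices, in cuDNN's input/forget/new(cell)/output
// gate order. The switch statements below map these IDs onto the column blocks DL4J uses.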
|
||||||
|
|
||||||
|
private CudnnLSTMContext cudnnContext = new CudnnLSTMContext();
|
||||||
|
private TensorArray xDesc = new TensorArray();
|
||||||
|
private TensorArray yDesc = new TensorArray();
|
||||||
|
private TensorArray dxDesc = new TensorArray();
|
||||||
|
private TensorArray dyDesc = new TensorArray();
|
||||||
|
private DataCache stateSpace = new DataCache();
|
||||||
|
private DataCache reserveSpace = new DataCache();
|
||||||
|
private DataCache weightsSpace = new DataCache();
|
||||||
|
|
||||||
|
private boolean initializedDropoutDescriptor = false;
|
||||||
|
|
||||||
|
private static INDArray toCOrder(INDArray arr) {
|
||||||
|
if (arr.isView() || arr.ordering() != 'c' || !Shape.strideDescendingCAscendingF(arr)) {
|
||||||
|
arr = arr.dup('c');
|
||||||
|
}
|
||||||
|
return arr;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean checkSupported(IActivation gateActivationFn, IActivation activationFn,
|
||||||
|
boolean hasPeepholeConnections) {
|
||||||
|
boolean supported = checkSupported();
|
||||||
|
if (!(gateActivationFn instanceof ActivationSigmoid)) {
|
||||||
|
supported = false;
|
||||||
|
log.warn("Not supported: Gate activation functions != ActivationSigmoid");
|
||||||
|
}
|
||||||
|
if (!(activationFn instanceof ActivationTanH)) {
|
||||||
|
supported = false;
|
||||||
|
log.warn("Not supported: Layer activation functions != ActivationTanH");
|
||||||
|
}
|
||||||
|
if (hasPeepholeConnections) {
|
||||||
|
supported = false;
|
||||||
|
log.warn("Not supported: LSTM layers with peephole connections");
|
||||||
|
}
|
||||||
|
return supported;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Pair<Gradient, INDArray> backpropGradient(final NeuralNetConfiguration conf,
|
||||||
|
final IActivation gateActivationFn, final INDArray input, final INDArray recurrentWeights, //Shape: [hiddenLayerSize,4*hiddenLayerSize+3]; order: [wI,wF,wO,wG,wFF,wOO,wGG]
|
||||||
|
final INDArray inputWeights, //Shape: [n^(L-1),4*hiddenLayerSize]; order: [wi,wf,wo,wg]
|
||||||
|
final INDArray epsilon, final boolean truncatedBPTT, final int tbpttBackwardLength,
|
||||||
|
final FwdPassReturn fwdPass, final boolean forwards, final String inputWeightKey,
|
||||||
|
final String recurrentWeightKey, final String biasWeightKey,
|
||||||
|
final Map<String, INDArray> gradientViews, INDArray maskArray, //Input mask: should only be used with bidirectional RNNs + variable length
|
||||||
|
final boolean hasPeepholeConnections, //True for GravesLSTM, false for LSTM
|
||||||
|
final LayerWorkspaceMgr workspaceMgr) {
|
||||||
|
|
||||||
|
//Expect errors to have shape: [miniBatchSize,n^(L+1),timeSeriesLength]
|
||||||
|
val hiddenLayerSize = recurrentWeights.size(0); //i.e., n^L
|
||||||
|
val prevLayerSize = inputWeights.size(0); //n^(L-1)
|
||||||
|
val inputLayerSize = input.size(1);
|
||||||
|
val miniBatchSize = epsilon.size(0);
|
||||||
|
boolean is2dInput = epsilon.rank() < 3; //Edge case: T=1 may have shape [miniBatchSize,n^(L+1)], equiv. to [miniBatchSize,n^(L+1),1]
|
||||||
|
long timeSeriesLength = (is2dInput ? 1 : epsilon.size(2));
|
||||||
|
|
||||||
|
INDArray x = toCOrder(input.permute(2, 0, 1));
|
||||||
|
INDArray dy = toCOrder(epsilon.permute(2, 0, 1));
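// DL4J stores recurrent activations as [miniBatch, size, timeSeriesLength]; the cuDNN RNN API
// expects packed, C-ordered, time-major data, so the permute(2, 0, 1) + dup (via toCOrder)
// above rearranges everything to [timeSeriesLength, miniBatch, size] before pointers are taken.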
|
||||||
|
INDArray dx = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, inputWeights.dataType(), new long[] {timeSeriesLength, miniBatchSize, prevLayerSize}, 'c');
|
||||||
|
|
||||||
|
INDArray iwGradientsOut = gradientViews.get(inputWeightKey);
|
||||||
|
INDArray rwGradientsOut = gradientViews.get(recurrentWeightKey); //Order: {I,F,O,G}
|
||||||
|
INDArray bGradientsOut = gradientViews.get(biasWeightKey);
|
||||||
|
|
||||||
|
INDArray outputActivations = toCOrder(fwdPass.fwdPassOutput.permute(2, 0, 1));
|
||||||
|
INDArray prevStepMemCellState = toCOrder(fwdPass.prevMemCell);
|
||||||
|
INDArray prevStepActivations = toCOrder(fwdPass.prevAct);
|
||||||
|
|
||||||
|
Nd4j.getExecutioner().commit();
|
||||||
|
|
||||||
|
Allocator allocator = AtomicAllocator.getInstance();
|
||||||
|
CudaContext context = allocator.getFlowController().prepareActionAllWrite(x, dy, dx, outputActivations,
|
||||||
|
prevStepMemCellState, prevStepActivations, iwGradientsOut, rwGradientsOut, bGradientsOut);
|
||||||
|
Pointer xData = allocator.getPointer(x, context);
|
||||||
|
Pointer dyData = allocator.getPointer(dy, context);
|
||||||
|
Pointer dxData = allocator.getPointer(dx, context);
|
||||||
|
Pointer outputActivationsData = allocator.getPointer(outputActivations, context);
|
||||||
|
Pointer prevMemCellStateData = allocator.getPointer(prevStepMemCellState, context);
|
||||||
|
Pointer prevStepActivationsData = allocator.getPointer(prevStepActivations, context);
|
||||||
|
Pointer iwGradientsOutData = allocator.getPointer(iwGradientsOut, context);
|
||||||
|
Pointer rwGradientsOutData = allocator.getPointer(rwGradientsOut, context);
|
||||||
|
Pointer bGradientsOutData = allocator.getPointer(bGradientsOut, context);
|
||||||
|
|
||||||
|
CUstream_st stream = new CUstream_st(context.getCublasStream());
|
||||||
|
checkCudnn(cudnnSetStream(cudnnContext, stream));
|
||||||
|
|
||||||
|
if (truncatedBPTT) {
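// For truncated BPTT only the last tbpttBackwardLength steps are backpropagated: the device
// pointers are advanced past the earlier time steps and timeSeriesLength is clamped, so cuDNN
// only ever sees the truncated tail of the sequence.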
|
||||||
|
val endIdx = Math.max(0, timeSeriesLength - tbpttBackwardLength) * miniBatchSize * hiddenLayerSize;
|
||||||
|
xData.position(endIdx * dataTypeSize);
|
||||||
|
dyData.position(endIdx * (BIDIRECTIONAL ? 2 : 1) * dataTypeSize);
|
||||||
|
outputActivationsData.position(endIdx * (BIDIRECTIONAL ? 2 : 1) * dataTypeSize);
|
||||||
|
timeSeriesLength = (int) Math.min(timeSeriesLength, tbpttBackwardLength);
|
||||||
|
}
|
||||||
|
|
||||||
|
cudnnTensorStruct xDesc0 = xDesc.get(cudnnTensorStruct.class, 0);
|
||||||
|
|
||||||
|
DataCache workSpace = workspaceMgr.getHelperWorkspace(LayerWorkspaceMgr.CUDNN_WORKSPACE_KEY);
|
||||||
|
checkCudnn(cudnnRNNBackwardData(cudnnContext, cudnnContext.rnnDesc, (int) timeSeriesLength, yDesc,
|
||||||
|
outputActivationsData, dyDesc, dyData, cudnnContext.dhyDesc, null, cudnnContext.dcyDesc, null,
|
||||||
|
cudnnContext.wDesc, weightsSpace, cudnnContext.hxDesc, prevStepActivationsData, //hx: initial hidden state of RNN
|
||||||
|
cudnnContext.cxDesc, prevMemCellStateData, //cx: initial cell state of RNN
|
||||||
|
dxDesc, dxData, //dx: gradient at input of each time step
|
||||||
|
cudnnContext.dhxDesc, null, //dhx: gradient at initial hidden state of RNN
|
||||||
|
cudnnContext.dcxDesc, null, //dcx: Gradient at initial cell state
|
||||||
|
workSpace, workSpace.limit(), reserveSpace, reserveSpace.limit()));
|
||||||
|
|
||||||
|
// cudnnRNNBackwardWeights adds to the data in dW.
|
||||||
|
checkCuda(cudaMemsetAsync(weightsSpace, 0, weightsSpace.limit(), stream));
|
||||||
|
|
||||||
|
checkCudnn(cudnnRNNBackwardWeights(cudnnContext, cudnnContext.rnnDesc, (int) timeSeriesLength, xDesc, xData, //Input data
|
||||||
|
cudnnContext.hxDesc, prevStepActivationsData, //Initial hidden state
|
||||||
|
yDesc, outputActivationsData, //Output data
|
||||||
|
workSpace, workSpace.limit(), cudnnContext.dwDesc, weightsSpace, reserveSpace,
|
||||||
|
reserveSpace.limit()));
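// cuDNN has now written all weight gradients into the packed weightsSpace buffer. The loop below
// locates each per-gate block via cudnnGetRNNLinLayer*Params and copies it device-to-device into
// the matching column block of DL4J's gradient views (input weights, recurrent weights, biases).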
|
||||||
|
|
||||||
|
int[] dataType = new int[1];
|
||||||
|
int[] format = new int[1];
|
||||||
|
int[] nbDims = new int[1];
|
||||||
|
int[] filterDimA = new int[3];
|
||||||
|
Pointer linLayerMat = new Pointer();
|
||||||
|
Pointer linLayerBias = new Pointer();
|
||||||
|
|
||||||
|
for (int layer = 0; layer < NUM_LAYERS * (BIDIRECTIONAL ? 2 : 1); layer++) {
|
||||||
|
for (int linLayerID = 0; linLayerID < NUM_LINEAR_LAYERS; linLayerID++) {
|
||||||
|
checkCudnn(cudnnGetRNNLinLayerMatrixParams(cudnnContext, cudnnContext.rnnDesc, layer, xDesc0,
|
||||||
|
cudnnContext.wDesc, weightsSpace, linLayerID, cudnnContext.linLayerMatDesc,
|
||||||
|
linLayerMat));
|
||||||
|
|
||||||
|
checkCudnn(cudnnGetFilterNdDescriptor(cudnnContext.linLayerMatDesc, 3, dataType, format, nbDims,
|
||||||
|
filterDimA));
|
||||||
|
|
||||||
|
checkCudnn(cudnnGetRNNLinLayerBiasParams(cudnnContext, cudnnContext.rnnDesc, layer, xDesc0,
|
||||||
|
cudnnContext.wDesc, weightsSpace, linLayerID, cudnnContext.linLayerBiasDesc,
|
||||||
|
linLayerBias));
|
||||||
|
|
||||||
|
checkCudnn(cudnnGetFilterNdDescriptor(cudnnContext.linLayerBiasDesc, 3, dataType, format, nbDims,
|
||||||
|
filterDimA));
|
||||||
|
|
||||||
|
// our data is in "new, forget, output, and input gates" order (aka IFOG), each kind of weight packed together
|
||||||
|
int position = 0;
|
||||||
|
long size = 0;
|
||||||
|
Pointer data = null;
|
||||||
|
switch (linLayerID) {
|
||||||
|
case 0:
|
||||||
|
data = iwGradientsOutData;
|
||||||
|
position = 3;
|
||||||
|
size = inputLayerSize;
|
||||||
|
break; // input gate
|
||||||
|
case 1:
|
||||||
|
data = iwGradientsOutData;
|
||||||
|
position = 1;
|
||||||
|
size = inputLayerSize;
|
||||||
|
break; // forget gate
|
||||||
|
case 2:
|
||||||
|
data = iwGradientsOutData;
|
||||||
|
position = 0;
|
||||||
|
size = inputLayerSize;
|
||||||
|
break; // new gate (input modulation gate)
|
||||||
|
case 3:
|
||||||
|
data = iwGradientsOutData;
|
||||||
|
position = 2;
|
||||||
|
size = inputLayerSize;
|
||||||
|
break; // output gate
|
||||||
|
case 4:
|
||||||
|
data = rwGradientsOutData;
|
||||||
|
position = 3;
|
||||||
|
size = hiddenLayerSize;
|
||||||
|
break; // input gate
|
||||||
|
case 5:
|
||||||
|
data = rwGradientsOutData;
|
||||||
|
position = 1;
|
||||||
|
size = hiddenLayerSize;
|
||||||
|
break; // forget gate
|
||||||
|
case 6:
|
||||||
|
data = rwGradientsOutData;
|
||||||
|
position = 0;
|
||||||
|
size = hiddenLayerSize;
|
||||||
|
break; // new gate (input modulation gate)
|
||||||
|
case 7:
|
||||||
|
data = rwGradientsOutData;
|
||||||
|
position = 2;
|
||||||
|
size = hiddenLayerSize;
|
||||||
|
break; // output gate
|
||||||
|
default:
|
||||||
|
throw new RuntimeException();
|
||||||
|
}
|
||||||
|
checkCuda(cudaMemcpyAsync(data.position(position * size * hiddenLayerSize * dataTypeSize), linLayerMat,
|
||||||
|
size * hiddenLayerSize * dataTypeSize, cudaMemcpyDeviceToDevice, stream));
|
||||||
|
if (linLayerID < 4) {
|
||||||
|
checkCuda(cudaMemcpyAsync(bGradientsOutData.position(position * hiddenLayerSize * dataTypeSize),
|
||||||
|
linLayerBias, hiddenLayerSize * dataTypeSize, cudaMemcpyDeviceToDevice, stream));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
allocator.getFlowController().registerActionAllWrite(context, x, dy, dx, outputActivations,
|
||||||
|
prevStepMemCellState, prevStepActivations, iwGradientsOut, rwGradientsOut, bGradientsOut);
|
||||||
|
|
||||||
|
Gradient retGradient = new DefaultGradient();
|
||||||
|
retGradient.gradientForVariable().put(inputWeightKey, iwGradientsOut);
|
||||||
|
retGradient.gradientForVariable().put(recurrentWeightKey, rwGradientsOut);
|
||||||
|
retGradient.gradientForVariable().put(biasWeightKey, bGradientsOut);
|
||||||
|
|
||||||
|
INDArray epsilonNext = dx.permute(1, 2, 0); //i.e., what would be W^L*(delta^L)^T. Shape: [m,n^(L-1),T]
|
||||||
|
|
||||||
|
return new Pair<>(retGradient, epsilonNext);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public FwdPassReturn activate(final Layer layer, final NeuralNetConfiguration conf,
|
||||||
|
final IActivation gateActivationFn, //Activation function for the gates - sigmoid or hard sigmoid (must be found in range 0 to 1)
|
||||||
|
INDArray input, final INDArray recurrentWeights, //Shape: [hiddenLayerSize,4*hiddenLayerSize+3]; order: [wI,wF,wO,wG,wFF,wOO,wGG]
|
||||||
|
final INDArray inputWeights, //Shape: [n^(L-1),4*hiddenLayerSize]; order: [wi,wf,wo,wg]
|
||||||
|
final INDArray biases, //Shape: [4,hiddenLayerSize]; order: [bi,bf,bo,bg]^T
|
||||||
|
final boolean training, final INDArray prevOutputActivations, final INDArray prevMemCellState,
|
||||||
|
boolean forBackprop, boolean forwards, final String inputWeightKey, INDArray maskArray, //Input mask: should only be used with bidirectional RNNs + variable length
|
||||||
|
final boolean hasPeepholeConnections, //True for GravesLSTM, false for LSTM
|
||||||
|
final LayerWorkspaceMgr workspaceMgr) {
|
||||||
|
|
||||||
|
boolean is2dInput = input.rank() < 3; //Edge case of T=1, may have shape [m,nIn], equiv. to [m,nIn,1]
|
||||||
|
val timeSeriesLength = (is2dInput ? 1 : input.size(2));
|
||||||
|
val hiddenLayerSize = recurrentWeights.size(0);
|
||||||
|
val miniBatchSize = input.size(0);
|
||||||
|
val inputLayerSize = input.size(1);
|
||||||
|
|
||||||
|
INDArray x = toCOrder(input.permute(2, 0, 1));
|
||||||
|
INDArray linInputWeights = inputWeights;
|
||||||
|
INDArray linRecurrentWeights = recurrentWeights;
|
||||||
|
INDArray linBiases = biases;
|
||||||
|
|
||||||
|
INDArray prevAct = toCOrder(prevOutputActivations);
|
||||||
|
INDArray prevMemCell = toCOrder(prevMemCellState);
|
||||||
|
|
||||||
|
INDArray outputActivations = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS,
|
||||||
|
inputWeights.dataType(), new long[] {timeSeriesLength, miniBatchSize, hiddenLayerSize * (BIDIRECTIONAL ? 2 : 1)}, 'c');
|
||||||
|
INDArray finalMemCellState = Nd4j.createUninitialized( inputWeights.dataType(),
|
||||||
|
new long[] {/*numLayers * (bidirectional ? 2 : 1),*/ miniBatchSize, hiddenLayerSize}, 'c');
|
||||||
|
INDArray finalStepActivations = Nd4j.createUninitialized( inputWeights.dataType(),
|
||||||
|
new long[] {/*numLayers * (bidirectional ? 2 : 1),*/ miniBatchSize, hiddenLayerSize}, 'c');
|
||||||
|
|
||||||
|
FwdPassReturn toReturn = new FwdPassReturn();
|
||||||
|
toReturn.prevAct = prevAct;
|
||||||
|
toReturn.prevMemCell = prevMemCell;
|
||||||
|
|
||||||
|
Nd4j.getExecutioner().commit();
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if (timeSeriesLength > xDesc.capacity()) {
|
||||||
|
xDesc.deallocate();
|
||||||
|
xDesc = new TensorArray(timeSeriesLength);
|
||||||
|
}
|
||||||
|
if (timeSeriesLength > yDesc.capacity()) {
|
||||||
|
yDesc.deallocate();
|
||||||
|
yDesc = new TensorArray(timeSeriesLength);
|
||||||
|
}
|
||||||
|
if (timeSeriesLength > dxDesc.capacity()) {
|
||||||
|
dxDesc.deallocate();
|
||||||
|
dxDesc = new TensorArray(timeSeriesLength);
|
||||||
|
}
|
||||||
|
if (timeSeriesLength > dyDesc.capacity()) {
|
||||||
|
dyDesc.deallocate();
|
||||||
|
dyDesc = new TensorArray(timeSeriesLength);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < timeSeriesLength; i++) {
|
||||||
|
int[] dimA = {(int) miniBatchSize, (int) inputLayerSize, 1};
|
||||||
|
int[] strideA = {(int) dimA[2] * dimA[1], dimA[2], 1};
|
||||||
|
|
||||||
|
checkCudnn(cudnnSetTensorNdDescriptor(xDesc.get(cudnnTensorStruct.class, i), dataType, 3, dimA, strideA));
|
||||||
|
checkCudnn(cudnnSetTensorNdDescriptor(dxDesc.get(cudnnTensorStruct.class, i), dataType, 3, dimA, strideA));
|
||||||
|
|
||||||
|
int[] dimB = {(int) miniBatchSize, (int) hiddenLayerSize * (BIDIRECTIONAL ? 2 : 1), 1};
|
||||||
|
int[] strideB = {dimB[2] * dimB[1], dimB[2], 1};
|
||||||
|
|
||||||
|
checkCudnn(cudnnSetTensorNdDescriptor(yDesc.get(cudnnTensorStruct.class, i), dataType, 3, dimB, strideB));
|
||||||
|
checkCudnn(cudnnSetTensorNdDescriptor(dyDesc.get(cudnnTensorStruct.class, i), dataType, 3, dimB, strideB));
|
||||||
|
}
|
||||||
|
|
||||||
|
int[] dimC = {NUM_LAYERS * (BIDIRECTIONAL ? 2 : 1), (int) miniBatchSize, (int) hiddenLayerSize};
|
||||||
|
int[] strideC = {dimC[2] * dimC[1], dimC[2], 1};
|
||||||
|
|
||||||
|
checkCudnn(cudnnSetTensorNdDescriptor(cudnnContext.hxDesc, dataType, 3, dimC, strideC));
|
||||||
|
checkCudnn(cudnnSetTensorNdDescriptor(cudnnContext.cxDesc, dataType, 3, dimC, strideC));
|
||||||
|
checkCudnn(cudnnSetTensorNdDescriptor(cudnnContext.hyDesc, dataType, 3, dimC, strideC));
|
||||||
|
checkCudnn(cudnnSetTensorNdDescriptor(cudnnContext.cyDesc, dataType, 3, dimC, strideC));
|
||||||
|
checkCudnn(cudnnSetTensorNdDescriptor(cudnnContext.dhxDesc, dataType, 3, dimC, strideC));
|
||||||
|
checkCudnn(cudnnSetTensorNdDescriptor(cudnnContext.dcxDesc, dataType, 3, dimC, strideC));
|
||||||
|
checkCudnn(cudnnSetTensorNdDescriptor(cudnnContext.dhyDesc, dataType, 3, dimC, strideC));
|
||||||
|
checkCudnn(cudnnSetTensorNdDescriptor(cudnnContext.dcyDesc, dataType, 3, dimC, strideC));
|
||||||
|
|
||||||
|
checkCudnn(cudnnDropoutGetStatesSize(cudnnContext, sizeInBytes));
|
||||||
|
long stateSize = sizeInBytes.get(0);
|
||||||
|
if (stateSize > stateSpace.capacity()) {
|
||||||
|
stateSpace.deallocate();
|
||||||
|
stateSpace = new DataCache(stateSize);
|
||||||
|
}
|
||||||
|
stateSpace.limit(stateSize);
|
||||||
|
|
||||||
|
if(!initializedDropoutDescriptor) {
|
||||||
|
checkCudnn(cudnnSetDropoutDescriptor(cudnnContext.dropoutDesc, cudnnContext, DROPOUT, stateSpace, stateSize,
|
||||||
|
Nd4j.getRandom().getSeed()));
|
||||||
|
}
|
||||||
|
|
||||||
|
checkCudnn(cudnnSetRNNDescriptor_v6(cudnnContext, cudnnContext.rnnDesc, (int) hiddenLayerSize, NUM_LAYERS, cudnnContext.dropoutDesc,
|
||||||
|
CUDNN_LINEAR_INPUT, BIDIRECTIONAL ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL, RNN_MODE,
|
||||||
|
CUDNN_RNN_ALGO_STANDARD, dataType));
|
||||||
|
|
||||||
|
cudnnTensorStruct xDesc0 = xDesc.get(cudnnTensorStruct.class, 0);
|
||||||
|
checkCudnn(cudnnGetRNNParamsSize(cudnnContext, cudnnContext.rnnDesc, xDesc0, sizeInBytes, dataType));
|
||||||
|
long weightsSize = sizeInBytes.get(0);
|
||||||
|
if (weightsSize > weightsSpace.capacity()) {
|
||||||
|
weightsSpace.deallocate();
|
||||||
|
weightsSpace = new DataCache(weightsSize);
|
||||||
|
}
|
||||||
|
weightsSpace.limit(weightsSize);
|
||||||
|
|
||||||
|
int[] dimW = {(int) weightsSize / dataTypeSize, 1, 1};
|
||||||
|
|
||||||
|
checkCudnn(cudnnSetFilterNdDescriptor(cudnnContext.wDesc, dataType, CUDNN_TENSOR_NCHW, 3, dimW));
|
||||||
|
checkCudnn(cudnnSetFilterNdDescriptor(cudnnContext.dwDesc, dataType, CUDNN_TENSOR_NCHW, 3, dimW));
|
||||||
|
|
||||||
|
checkCudnn(cudnnGetRNNWorkspaceSize(cudnnContext, cudnnContext.rnnDesc, (int) timeSeriesLength, xDesc, sizeInBytes));
|
||||||
|
long workSize = sizeInBytes.get(0);
|
||||||
|
DataCache workSpace = workspaceMgr.getHelperWorkspace(LayerWorkspaceMgr.CUDNN_WORKSPACE_KEY);
|
||||||
|
if (workSpace == null || workSize > workSpace.capacity()) {
|
||||||
|
if(log.isTraceEnabled()){
|
||||||
|
if(workSpace == null){
|
||||||
|
log.trace("CudnnLSTMHelper activate: Allocating initial workspace of size {} ({})", workSize,
|
||||||
|
BinaryByteUnit.format(workSize, "#.00"));
|
||||||
|
} else {
|
||||||
|
log.trace("CudnnLSTMHelper activate: Deallocating workspace of size {} ({}), allocating new workspace of size {} ({})",
|
||||||
|
workSpace.capacity(), BinaryByteUnit.format(workSpace.capacity(), "#.00"),
|
||||||
|
workSize, BinaryByteUnit.format(workSize, "#.00"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if(workSpace != null)
|
||||||
|
workSpace.deallocate();
|
||||||
|
workSpace = new DataCache(workSize);
|
||||||
|
workspaceMgr.setHelperWorkspace(LayerWorkspaceMgr.CUDNN_WORKSPACE_KEY, workSpace);
|
||||||
|
}
|
||||||
|
workSpace.limit(workSize);
|
||||||
|
|
||||||
|
checkCudnn(cudnnGetRNNTrainingReserveSize(cudnnContext, cudnnContext.rnnDesc, (int) timeSeriesLength, xDesc,
|
||||||
|
sizeInBytes));
|
||||||
|
long reserveSize = sizeInBytes.get(0);
|
||||||
|
if (reserveSize > reserveSpace.capacity()) {
|
||||||
|
reserveSpace.deallocate();
|
||||||
|
reserveSpace = new DataCache(reserveSize);
|
||||||
|
}
|
||||||
|
reserveSpace.limit(reserveSize);
|
||||||
|
|
||||||
|
Allocator allocator = AtomicAllocator.getInstance();
|
||||||
|
CudaContext context = allocator.getFlowController().prepareActionAllWrite(x, linInputWeights,
|
||||||
|
linRecurrentWeights, linBiases, prevAct, prevMemCell, outputActivations, finalMemCellState,
|
||||||
|
finalStepActivations);
|
||||||
|
Pointer xData = allocator.getPointer(x, context);
|
||||||
|
Pointer linInputWeightsData = allocator.getPointer(linInputWeights, context);
|
||||||
|
Pointer linRecurrentWeightsData = allocator.getPointer(linRecurrentWeights, context);
|
||||||
|
Pointer linBiasesData = allocator.getPointer(linBiases, context);
|
||||||
|
Pointer prevActData = allocator.getPointer(prevAct, context);
        Pointer prevMemCellData = allocator.getPointer(prevMemCell, context);
        Pointer outputActivationsData = allocator.getPointer(outputActivations, context);
        Pointer finalMemCellStateData = allocator.getPointer(finalMemCellState, context);
        Pointer finalTimeStepActivationsData = allocator.getPointer(finalStepActivations, context);

        CUstream_st stream = new CUstream_st(context.getCublasStream());
        checkCudnn(cudnnSetStream(cudnnContext, stream));

        checkCuda(cudaMemsetAsync(weightsSpace, 0, weightsSpace.limit(), stream));

        int[] dataType = new int[1];
        int[] format = new int[1];
        int[] nbDims = new int[1];
        int[] filterDimA = new int[3];
        Pointer linLayerMat = new Pointer();
        Pointer linLayerBias = new Pointer();

        for (int layerNum = 0; layerNum < NUM_LAYERS * (BIDIRECTIONAL ? 2 : 1); layerNum++) {
            for (int linLayerID = 0; linLayerID < NUM_LINEAR_LAYERS; linLayerID++) {
                checkCudnn(cudnnGetRNNLinLayerMatrixParams(cudnnContext, cudnnContext.rnnDesc, layerNum, xDesc0,
                        cudnnContext.wDesc, weightsSpace, linLayerID, cudnnContext.linLayerMatDesc,
                        linLayerMat));

                checkCudnn(cudnnGetFilterNdDescriptor(cudnnContext.linLayerMatDesc, 3, dataType, format, nbDims,
                        filterDimA));

                checkCudnn(cudnnGetRNNLinLayerBiasParams(cudnnContext, cudnnContext.rnnDesc, layerNum, xDesc0,
                        cudnnContext.wDesc, weightsSpace, linLayerID, cudnnContext.linLayerBiasDesc,
                        linLayerBias));

                checkCudnn(cudnnGetFilterNdDescriptor(cudnnContext.linLayerBiasDesc, 3, dataType, format, nbDims,
                        filterDimA));

                // our data is in "new, forget, output, and input gates" order (aka IFOG), each kind of weight packed together
                int position = 0;
                long size = 0;
                Pointer data = null;
                switch (linLayerID) {
                    case 0:
                        data = linInputWeightsData;
                        position = 3;
                        size = inputLayerSize;
                        break; // input gate
                    case 1:
                        data = linInputWeightsData;
                        position = 1;
                        size = inputLayerSize;
                        break; // forget gate
                    case 2:
                        data = linInputWeightsData;
                        position = 0;
                        size = inputLayerSize;
                        break; // new gate
                    case 3:
                        data = linInputWeightsData;
                        position = 2;
                        size = inputLayerSize;
                        break; // output gate
                    case 4:
                        data = linRecurrentWeightsData;
                        position = 3;
                        size = hiddenLayerSize;
                        break; // input gate
                    case 5:
                        data = linRecurrentWeightsData;
                        position = 1;
                        size = hiddenLayerSize;
                        break; // forget gate
                    case 6:
                        data = linRecurrentWeightsData;
                        position = 0;
                        size = hiddenLayerSize;
                        break; // new gate
                    case 7:
                        data = linRecurrentWeightsData;
                        position = 2;
                        size = hiddenLayerSize;
                        break; // output gate
                    default:
                        throw new RuntimeException();
                }
                checkCuda(cudaMemcpyAsync(linLayerMat, data.position(position * size * hiddenLayerSize * dataTypeSize),
                        size * hiddenLayerSize * dataTypeSize, cudaMemcpyDeviceToDevice, stream));
                if (linLayerID < 4) {
                    checkCuda(cudaMemcpyAsync(linLayerBias,
                            linBiasesData.position(position * hiddenLayerSize * dataTypeSize),
                            hiddenLayerSize * dataTypeSize, cudaMemcpyDeviceToDevice, stream));
                }
            }
        }

        if (training) {
            checkCudnn(cudnnRNNForwardTraining(cudnnContext, cudnnContext.rnnDesc, (int) timeSeriesLength, xDesc, xData,
                    cudnnContext.hxDesc, prevActData, cudnnContext.cxDesc, prevMemCellData, cudnnContext.wDesc,
                    weightsSpace, yDesc, outputActivationsData, cudnnContext.hyDesc,
                    finalTimeStepActivationsData, cudnnContext.cyDesc, finalMemCellStateData, workSpace,
                    workSpace.limit(), reserveSpace, reserveSpace.limit()));
        } else {
            checkCudnn(cudnnRNNForwardInference(cudnnContext, cudnnContext.rnnDesc, (int) timeSeriesLength, xDesc, xData,
                    cudnnContext.hxDesc, prevActData, cudnnContext.cxDesc, prevMemCellData, cudnnContext.wDesc,
                    weightsSpace, yDesc, outputActivationsData, cudnnContext.hyDesc,
                    finalTimeStepActivationsData, cudnnContext.cyDesc, finalMemCellStateData, workSpace,
                    workSpace.limit()));
        }

        allocator.getFlowController().registerActionAllWrite(context, x, linInputWeights, linRecurrentWeights,
                linBiases, prevAct, prevMemCell, outputActivations, finalMemCellState, finalStepActivations);

        toReturn.fwdPassOutput = outputActivations.permute(1, 2, 0);
        toReturn.lastAct = finalStepActivations;
        toReturn.lastMemCell = finalMemCellState;
        toReturn.prevAct = prevAct;
        toReturn.prevMemCell = prevMemCell;

        return toReturn;
    }

    @Override
    public Map<String, Long> helperMemoryUse() {
        Map<String, Long> memUse = new HashMap<>();
        memUse.put("stateStace", stateSpace.capacity());
        memUse.put("reserveSpace", reserveSpace.capacity());
        memUse.put("weightsSpace", weightsSpace.capacity());
        return memUse;
    }
}
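The switch block in the helper above is effectively a fixed lookup from cuDNN's per-gate linear-layer IDs (0-3 for the input weights, 4-7 for the recurrent weights; per the comments: input, forget, new, output gate) to the position of that gate's slice in the flat parameter buffer, which packs the gates as new, forget, output, input. The standalone sketch below restates that offset arithmetic on its own; the class name, the example sizes, and the GATE_POSITION table are illustrative values derived from the switch statement, not part of the helper's API.

import java.util.Arrays;

/** Sketch: map cuDNN RNN linLayerIDs to byte offsets in a flat, gate-packed weight buffer. */
public class GateOffsetSketch {

    // The parameter buffer packs each kind of weight as [new, forget, output, input];
    // cuDNN asks for gates as linLayerID 0..3 = input, forget, new, output (4..7 for recurrent weights).
    // Index = linLayerID % 4, value = position of that gate in the packed buffer.
    private static final int[] GATE_POSITION = {3, 1, 0, 2};

    /** Byte offset of one gate's weight block, mirroring position * size * hiddenLayerSize * dataTypeSize. */
    static long weightOffsetBytes(int linLayerID, long inputLayerSize, long hiddenLayerSize, long dataTypeSize) {
        int position = GATE_POSITION[linLayerID % 4];
        long size = linLayerID < 4 ? inputLayerSize : hiddenLayerSize; // input weights vs. recurrent weights
        return position * size * hiddenLayerSize * dataTypeSize;
    }

    public static void main(String[] args) {
        long in = 16, hidden = 32, bytes = 4; // illustrative sizes (float32)
        for (int id = 0; id < 8; id++) {
            System.out.printf("linLayerID %d -> offset %d bytes%n", id, weightOffsetBytes(id, in, hidden, bytes));
        }
        System.out.println("gate positions (new, forget, output, input packing): " + Arrays.toString(GATE_POSITION));
    }
}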
@@ -21,6 +21,7 @@
 package org.deeplearning4j.nn.modelimport.keras.preprocessors;

 import lombok.Data;
+import lombok.EqualsAndHashCode;
 import lombok.extern.slf4j.Slf4j;
 import org.deeplearning4j.nn.conf.inputs.InputType;
 import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException;
@@ -32,6 +33,7 @@ import com.fasterxml.jackson.annotation.JsonProperty;

 @Slf4j
 @Data
+@EqualsAndHashCode(callSuper=false)
 public class KerasFlattenRnnPreprocessor extends BaseInputPreProcessor {

 private long tsLength;
@@ -1,31 +1,24 @@
 plugins {
 id 'java-library'
 id 'maven-publish'
+id 'com.github.johnrengelman.shadow' version '7.1.2'
 }

-/*
-configurations.archives.artifacts.with { archives ->
-archives.each {
-println(it.name)
-}
-}
-*/
+apply from: rootProject.projectDir.path+"/chooseBackend.gradle"

 dependencies {
+afterEvaluate {
 //Todo clean this
 api platform(project(":cavis-common-platform"))
 //api "org.bytedeco:javacpp:1.5.7" //for some reason we needed to apply version numbers here, they do not end up in POM otherwise
-api "com.fasterxml.jackson.datatype:jackson-datatype-joda:2.10.5"
-api 'org.slf4j:slf4j-simple:2.0.3'
-api 'org.slf4j:slf4j-api:2.0.3'
+//api "com.fasterxml.jackson.datatype:jackson-datatype-joda:2.10.5"
+//api 'org.slf4j:slf4j-simple:2.0.3'
+//api 'org.slf4j:slf4j-api:2.0.3'
 //TODO for the two below.. either platform specific uber jars or a single big one with all platforms
-api group: "org.bytedeco", name: "javacpp", version: "1.5.7", classifier: "linux-x86_64"
-//api group: "org.bytedeco", name: "javacpp", version: "1.5.7"
-// api group: 'net.brutex.cavis-native', name: 'cavis-native-lib', version: '1.0.0-SNAPSHOT', classifier: "linux-x86_64-avx2-cpu"
-//api group: 'net.brutex.cavis-native', name: 'cavis-native-lib', version: '1.0.0-SNAPSHOT'
+//api group: "org.bytedeco", name: "javacpp", version: "1.5.7", classifier: "linux-x86_64"
 rootProject.getAllprojects().each { Project sproj ->
-if(!sproj.name.equals(name) && !sproj.name.equals("cavis-common-platform")
+if (!sproj.name.equals(name) && !sproj.name.equals("cavis-common-platform")
 && !sproj.name.equals("Cavis")
 && !sproj.name.equals("cavis-datavec")
 && !sproj.name.equals("cavis-dnn")
@@ -33,26 +26,41 @@ dependencies {
 && !sproj.name.equals("cavis-nd4j")
 && !sproj.name.equals("cavis-ui")
 && !sproj.name.equals("cavis-zoo")) {
-//compileOnly project(""+sproj.path)
-api sproj
-if(! sproj.configurations.empty) {
-//compileOnly project(sproj.getPath())
+api project(path: sproj.path, configuration: 'runtimeElements')
+}
+}
+// if(withCpu) api project(path: ":cavis-native:cavis-native-lib", configuration: "cpuSupportApiElements")
+// if(withCuda) api project(path: ":cavis-native:cavis-native-lib", configuration: "cudaSupportApiElements")
+/*
+api(projects.cavisNative.cavisNativeLib) {
+capabilities {
+//if(withCuda()) it.requireCapability(group: "net.brutex.cavis.cavis-native", name: "cavis-native-lib-cuda-support", version: project.version)
+if (withCpu()) it.requireCapability(group: "net.brutex.cavis.cavis-native", name: "cavis-native-lib-cpu-support", version: project.version)
+}
+}
+api(projects.cavisNative.cavisNativeLib) {
+capabilities {
+if (withCuda()) it.requireCapability(group: "net.brutex.cavis.cavis-native", name: "cavis-native-lib-cuda-support", version: project.version)
+//if(withCpu()) it.requireCapability(group: "net.brutex.cavis.cavis-native", name: "cavis-native-lib-cpu-support", version: project.version)
+}
+}
+*/
+//if(withCpu()) api project(path: ":cavis-native:cavis-native-lib", configuration: "cpuSupportImplementation")
+//if(withCuda()) api project(path: ":cavis-native:cavis-native-lib", configuration: "cudaSupportImplementation")
+//if(withCuda()) api project(path: ":cavis-native:cavis-native-lib", configuration: "cudaSupportCompileClasspath")

 /*
-sproj.configurations.each {Configuration conf ->
-conf.dependencies.each {Dependency dep ->
-compileOnly dep
+api (project(':cavis-native:cavis-native-lib')) {
+capabilities {
+if(withCpu()) requireCapability("net.brutex.cavis.cavis-native:cavis-native-lib-cpu-support")
+//if(withCuda()) requireCapability("net.brutex.cavis.cavis-native:cavis-native-lib-cuda-support")
 }
 }
-*/
+*/
 }
-}
-}

 }

 /*
 tasks.getByName("jar") {
@@ -77,19 +85,39 @@ tasks.getByName("jar") {

 }
 }
-/*

-/*
-artifacts {
-archives customFatJar
-}
 */

+artifacts {
+archives shadowJar
+}

+shadowJar {
+enabled true;
+zip64 true //need this to support jars with more than 65535 entries
+archiveClassifier.set('')
+}

 publishing {
 publications {
-mavenJava(MavenPublication) {
-// artifact customFatJar
+/*mavenJava(MavenPublication) {
+//artifact customFatJar
 // from components.java
+/* pom.withXml {
+def dependenciesNode = asNode().dependencies
+def dependencyNode = dependenciesNode.appendNode()

+dependencyNode.appendNode('groupId', 'net.brutex.cavis')
+dependencyNode.appendNode('artifactId', 'cavis-native-lib')
+dependencyNode.appendNode('version', '1.0.0-SNAPSHOT')
+//dependencyNode.appendNode('classifier', 'linux-x86_64-avx2-cpu')
+//dependencyNode.appendNode('scope', 'compile')
+}

+}
+*/
+shadow(MavenPublication) { publication ->
+project.shadow.component(publication)
 }
 }
 }
@@ -11,7 +11,8 @@ ext {
 dependencies {
 implementation platform(projects.cavisCommonPlatform)

-implementation project(":cavis-native:cavis-native-blas")
+//implementation project(":cavis-native:cavis-native-blas")
+implementation projects.cavisNative.cavisNativeBlas

 implementation group: "org.bytedeco", name: "cuda"
 implementation group: "org.bytedeco", name: "cuda", classifier: buildTarget
@@ -121,7 +121,7 @@ endfunction()

 if (SD_CUDA)
 #enable_language(CUDA)
-find_package(CUDAToolkit 11.2 REQUIRED)
+find_package(CUDAToolkit 11.4 REQUIRED)
 message(STATUS "CUDAToolkit_VERSION: ${CUDAToolkit_VERSION}")
 message(STATUS "CUDAToolkit_VERSION_MAJOR: ${CUDAToolkit_VERSION_MAJOR}")
 message(STATUS "CUDAToolkit_VERSION_MINOR: ${CUDAToolkit_VERSION_MINOR}")
@@ -20,11 +20,9 @@
 */
 ext {
 chip = (properties.CAVIS_CHIP ?: "cuda,cpu").toLowerCase() //the default is to build for CPU and CUDA
-testChip = (properties.CAVIS_TEST_CHIP ?: " ").toLowerCase() //the default is without specific backend
-logger.debug("Building for chips ${chip} and running tests with backends for ${testChip}")
+logger.debug("Building for chips ${chip} and running tests with backends for ${chip}")

 chipList = chip.split(",")
-testChipList = testChip.split(",")

 /* just for usability */
 withCuda = { ->
@@ -33,10 +31,4 @@ ext {
 withCpu = { ->
 return chip.contains("cpu")
 }
-withCudaTest = { ->
-return testChip.contains("cuda")
-}
-withCpuTest = { ->
-return testChip.contains("cpu")
-}
 }
@@ -24,7 +24,7 @@ ext {
 buildTarget = rootProject.ext.buildTarget
 apply from: new File("${project.rootProject.projectDir}/chooseBackend.gradle")

-testChipList.each { thisChip ->
+chipList.each { thisChip ->
 configurations.register("${thisChip}TestImplementation") {
 it.extendsFrom configurations.testImplementation, configurations.implementation

@@ -79,33 +79,44 @@ ext {

 dependencies {
-if (withCudaTest()) {
+if (withCuda()) {
 cudaTestRuntime platform(projects.cavisCommonPlatform)
 cudaTestRuntime projects.cavisNative.cavisNativeJcublas
+cudaTestRuntime projects.cavisDnn.cavisDnnCudnn
 cudaTestRuntime group: "org.bytedeco", name: "openblas"
 cudaTestRuntime group: "org.bytedeco", name: "openblas", classifier: buildTarget
 cudaTestRuntime group: "org.bytedeco", name: "cuda"
 cudaTestRuntime group: "org.bytedeco", name: "cuda", classifier: buildTarget
 cudaTestRuntime group: "org.bytedeco", name: "cuda", classifier: "${buildTarget}-redist"
+cudaTestRuntime (project( path: ":cavis-native:cavis-native-lib", configuration: "cudaSupportRuntimeElements"))
+/*
 cudaTestRuntime(project(":cavis-native:cavis-native-lib")) {

 capabilities {
 it.requireCapabilities "net.brutex.cavis.native:cavis-native-lib-cuda-support:1.0.0-SNAPSHOT"
 }
 }

+*/
 }

-if (withCpuTest()) {
+if (withCpu()) {
 cpuTestRuntime platform(projects.cavisCommonPlatform)
 cpuTestRuntime projects.cavisNative.cavisNativeCpu
 cpuTestRuntime group: "org.bytedeco", name: "openblas"
 cpuTestRuntime group: "org.bytedeco", name: "openblas", classifier: buildTarget
 cpuTestRuntime group: "org.bytedeco", name: "opencv"
 cpuTestRuntime group: "org.bytedeco", name: "opencv", classifier: buildTarget
+cpuTestRuntime project( path: ":cavis-native:cavis-native-lib", configuration: "cpuSupportRuntimeElements")
+/*
 cpuTestRuntime(project(":cavis-native:cavis-native-lib")) {

 capabilities {
 it.requireCapabilities "net.brutex.cavis.native:cavis-native-lib-cpu-support:1.0.0-SNAPSHOT"
 }
 }

+*/
 }
 }
 }
@@ -89,6 +89,7 @@ include ':cavis-native:cavis-native-lib'
 include ':cavis-native:cavis-native-common'
 include ':cavis-dnn'
 include ':cavis-dnn:cavis-dnn-api'
+if(withCuda()) { include ':cavis-dnn:cavis-dnn-cudnn' }
 include ':cavis-dnn:cavis-dnn-common'
 include ':cavis-dnn:cavis-dnn-common-tests'
 include ':cavis-dnn:cavis-dnn-core'
@@ -116,6 +117,7 @@ include ':cavis-dnn:cavis-dnn-spark:cavis-dnn-spark-parameterserver'
 include ':cavis-dnn:cavis-dnn-tsne'
 include ':cavis-datavec'
 include ':cavis-datavec:cavis-datavec-api'
+include ':cavis-datavec:dvec-api'
 include ':cavis-datavec:cavis-datavec-data'
 include ':cavis-datavec:cavis-datavec-data:cavis-datavec-data-arrow'
 include ':cavis-datavec:cavis-datavec-data:cavis-datavec-data-image'
@@ -151,3 +153,4 @@ include ':cavis-zoo:cavis-zoo-models'
 include ':brutex-extended-tests'
 include ':cavis-full'

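Taken together, the build-script changes above only wire the cuDNN module into the build when CUDA support is requested: settings.gradle includes ':cavis-dnn:cavis-dnn-cudnn' behind withCuda(), and the test-backend script adds it to the cudaTestRuntime configuration. One quick way to confirm at runtime which ND4J backend actually ended up on the classpath is to force backend initialisation and print it. The snippet below is an illustrative sketch using the public Nd4j API; it is not part of this commit.

import org.nd4j.linalg.factory.Nd4j;

/** Sketch: report which ND4J backend was loaded (e.g. when running tests with the CUDA/cuDNN runtime). */
public class BackendCheck {
    public static void main(String[] args) {
        // Creating any array forces backend initialisation from whatever is on the classpath.
        Nd4j.create(2, 2);
        System.out.println("ND4J backend: " + Nd4j.getBackend().getClass().getName());
    }
}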