Snapshot update (#8194)

* fix double consumption of rng on cpu

Signed-off-by: raver119 <raver119@gmail.com>

* Shyrma docs (#222)

* - documenting and profiling matrix_set_diag cuda kernel

Signed-off-by: Yurii <yurii@skymind.io>

* - correct formula of pnorm pooling in cuda 2d/3d kernels
- remove helper matrix_diag which duplicates work of helper matrix_set_diag

Signed-off-by: Yurii <yurii@skymind.io>

* cublasHandle sharing + lock

Signed-off-by: raver119 <raver119@gmail.com>

* cublasHandle sharing + lock

Signed-off-by: raver119 <raver119@gmail.com>

* Documentation from serialization/deserialization in NLP (#221)

* refactoring

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* Javadocs

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* Javadoc fixed

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* Cleanup

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* dedicated lock for getCudaCublasHandle

Signed-off-by: raver119 <raver119@gmail.com>

* Small fixes (#223)

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* ELU DL4J fixes (#224)

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* javadoc (#225)

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* Small test compilation fix (#226)

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* #8182 remove spark version suffix (#227)

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* [WIP] Thread safety (#229)

* sync after cublas*gemm

Signed-off-by: raver119 <raver119@gmail.com>

* mutex for CublasHelper

Signed-off-by: raver119 <raver119@gmail.com>

* don't store cublasHandle in LaunchContext, it's per-device anyway

Signed-off-by: raver119 <raver119@gmail.com>

* some printout

Signed-off-by: raver119 <raver119@gmail.com>

* check for field instead

Signed-off-by: raver119 <raver119@gmail.com>

* pew-pew

Signed-off-by: raver119 <raver119@gmail.com>

* don't release ContextBuffers until device changed

Signed-off-by: raver119 <raver119@gmail.com>

* small tweak

Signed-off-by: raver119 <raver119@gmail.com>

* some logging in sgemm

Signed-off-by: raver119 <raver119@gmail.com>

* stream sync

Signed-off-by: raver119 <raver119@gmail.com>

* some more logging

Signed-off-by: raver119 <raver119@gmail.com>

* some more error checks

Signed-off-by: raver119 <raver119@gmail.com>

* one fancy test

Signed-off-by: raver119 <raver119@gmail.com>

* one fancy test

Signed-off-by: raver119 <raver119@gmail.com>

* minor AffinityManager fix

Signed-off-by: raver119 <raver119@gmail.com>

* cudaEvent error logging improvement

Signed-off-by: raver119 <raver119@gmail.com>

* ConstantHelper thread safety

Signed-off-by: raver119 <raver119@gmail.com>

* - minor corrections in ConstantTadHelper

Signed-off-by: Yurii <yurii@skymind.io>

* ConstantShapeHelper thread safety

Signed-off-by: raver119 <raver119@gmail.com>

* ConstantTadHelper.cu updated

Signed-off-by: raver119 <raver119@gmail.com>

* logging off

Signed-off-by: raver119 <raver119@gmail.com>

* logging off

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2019-09-03 22:02:02 +03:00 committed by GitHub
parent 9d03bb9425
commit 7abc574eeb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
58 changed files with 835 additions and 839 deletions

View File

@ -38,7 +38,7 @@
<dependency> <dependency>
<groupId>org.datavec</groupId> <groupId>org.datavec</groupId>
<artifactId>datavec-spark-inference-server_2.11</artifactId> <artifactId>datavec-spark-inference-server_2.11</artifactId>
<version>1.0.0_spark_2-SNAPSHOT</version> <version>1.0.0-SNAPSHOT</version>
<scope>test</scope> <scope>test</scope>
</dependency> </dependency>
<dependency> <dependency>

View File

@ -25,7 +25,7 @@
<artifactId>datavec-spark-inference-server_2.11</artifactId> <artifactId>datavec-spark-inference-server_2.11</artifactId>
<packaging>jar</packaging> <packaging>jar</packaging>
<version>1.0.0_spark_2-SNAPSHOT</version> <version>1.0.0-SNAPSHOT</version>
<name>datavec-spark-inference-server</name> <name>datavec-spark-inference-server</name>
<properties> <properties>

View File

@ -24,7 +24,7 @@
</parent> </parent>
<modelVersion>4.0.0</modelVersion> <modelVersion>4.0.0</modelVersion>
<version>1.0.0_spark_2-SNAPSHOT</version> <version>1.0.0-SNAPSHOT</version>
<artifactId>datavec-spark_2.11</artifactId> <artifactId>datavec-spark_2.11</artifactId>
<properties> <properties>

View File

@ -0,0 +1,63 @@
package org.deeplearning4j;
import org.deeplearning4j.datasets.iterator.EarlyTerminationDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.junit.Ignore;
import org.junit.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.RmsProp;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.util.concurrent.CountDownLatch;
@Ignore
public class RandomTests {
@Test
public void testReproduce() throws Exception {
final MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new RmsProp())
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).list()
.layer(0, new org.deeplearning4j.nn.conf.layers.DenseLayer.Builder().nIn(28 * 28).nOut(10)
.activation(Activation.TANH).build())
.layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(
LossFunctions.LossFunction.MCXENT).nIn(10).nOut(10)
.activation(Activation.SOFTMAX).build())
.build();
for (int e = 0; e < 3; e++) {
int nThreads = 10;
final CountDownLatch l = new CountDownLatch(nThreads);
for (int i = 0; i < nThreads; i++) {
final int j = i;
Thread t = new Thread(new Runnable() {
@Override
public void run() {
try {
MultiLayerNetwork net = new MultiLayerNetwork(conf.clone());
net.init();
DataSetIterator iter = new EarlyTerminationDataSetIterator(new MnistDataSetIterator(10, false, 12345), 100);
net.fit(iter);
} catch (Throwable t) {
System.out.println("Thread failed: " + j);
t.printStackTrace();
} finally {
l.countDown();
}
}
});
t.start();
}
l.await();
System.out.println("DONE " + e + "\n");
}
}
}

View File

@ -833,14 +833,14 @@ public class WordVectorSerializerTest extends BaseDL4JTest {
public void testB64_1() throws Exception { public void testB64_1() throws Exception {
String wordA = "night"; String wordA = "night";
String wordB = "night day"; String wordB = "night day";
String encA = WordVectorSerializer.encodeB64(wordA); String encA = WordVectorSerializer.ReadHelper.encodeB64(wordA);
String encB = WordVectorSerializer.encodeB64(wordB); String encB = WordVectorSerializer.ReadHelper.encodeB64(wordB);
assertEquals(wordA, WordVectorSerializer.decodeB64(encA)); assertEquals(wordA, WordVectorSerializer.ReadHelper.decodeB64(encA));
assertEquals(wordB, WordVectorSerializer.decodeB64(encB)); assertEquals(wordB, WordVectorSerializer.ReadHelper.decodeB64(encB));
assertEquals(wordA, WordVectorSerializer.decodeB64(wordA)); assertEquals(wordA, WordVectorSerializer.ReadHelper.decodeB64(wordA));
assertEquals(wordB, WordVectorSerializer.decodeB64(wordB)); assertEquals(wordB, WordVectorSerializer.ReadHelper.decodeB64(wordB));
} }

View File

@ -24,7 +24,6 @@ import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils; import org.apache.commons.io.IOUtils;
import org.apache.commons.io.LineIterator; import org.apache.commons.io.LineIterator;
import org.apache.commons.io.output.CloseShieldOutputStream; import org.apache.commons.io.output.CloseShieldOutputStream;
import org.deeplearning4j.exception.DL4JException;
import org.deeplearning4j.exception.DL4JInvalidInputException; import org.deeplearning4j.exception.DL4JInvalidInputException;
import org.deeplearning4j.models.embeddings.WeightLookupTable; import org.deeplearning4j.models.embeddings.WeightLookupTable;
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable; import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
@ -52,7 +51,6 @@ import org.deeplearning4j.text.sentenceiterator.BasicLineIterator;
import org.deeplearning4j.text.tokenization.tokenizer.TokenPreProcess; import org.deeplearning4j.text.tokenization.tokenizer.TokenPreProcess;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory; import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.deeplearning4j.util.DL4JFileUtils; import org.deeplearning4j.util.DL4JFileUtils;
import org.nd4j.base.Preconditions;
import org.nd4j.compression.impl.NoOp; import org.nd4j.compression.impl.NoOp;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.nd4j.linalg.exception.ND4JIllegalStateException;
@ -68,8 +66,6 @@ import org.nd4j.util.OneTimeLogger;
import java.io.*; import java.io.*;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicInteger;
@ -78,6 +74,80 @@ import java.util.zip.*;
/** /**
* This is utility class, providing various methods for WordVectors serialization * This is utility class, providing various methods for WordVectors serialization
* *
* List of available serialization methods (please keep this list consistent with source code):
*
* <ul>
* <li>Serializers for Word2Vec:</li>
* {@link #writeWordVectors(WeightLookupTable, File)}
* {@link #writeWordVectors(WeightLookupTable, OutputStream)}
* {@link #writeWord2VecModel(Word2Vec, File)}
* {@link #writeWord2VecModel(Word2Vec, String)}
* {@link #writeWord2VecModel(Word2Vec, OutputStream)}
*
* <li>Deserializers for Word2Vec:</li>
* {@link #readWord2VecModel(File)}
* {@link #readWord2VecModel(String)}
* {@link #readWord2VecModel(File, boolean)}
* {@link #readWord2VecModel(String, boolean)}
* {@link #readAsBinaryNoLineBreaks(File)}
* {@link #readAsBinary(File)}
* {@link #readAsCsv(File)}
* {@link #readBinaryModel(File, boolean, boolean)}
* {@link #readWord2VecFromText(File, File, File, File, VectorsConfiguration)}
* {@link #readWord2Vec(String, boolean)}
* {@link #readWord2Vec(File, boolean)}
* {@link #readWord2Vec(InputStream, boolean)}
*
* <li>Serializers for ParaVec:</li>
* {@link #writeParagraphVectors(ParagraphVectors, File)}
* {@link #writeParagraphVectors(ParagraphVectors, String)}
* {@link #writeParagraphVectors(ParagraphVectors, OutputStream)}
*
* <li>Deserializers for ParaVec:</li>
* {@link #readParagraphVectors(File)}
* {@link #readParagraphVectors(String)}
* {@link #readParagraphVectors(InputStream)}
*
* <li>Serializers for GloVe:</li>
* {@link #writeWordVectors(Glove, File)}
* {@link #writeWordVectors(Glove, String)}
* {@link #writeWordVectors(Glove, OutputStream)}
*
* <li>Adapters</li>
* {@link #fromTableAndVocab(WeightLookupTable, VocabCache)}
* {@link #fromPair(Pair)}
* {@link #loadTxt(File)}
*
* <li>Serializers to tSNE format</li>
* {@link #writeTsneFormat(Glove, INDArray, File)}
* {@link #writeTsneFormat(Word2Vec, INDArray, File)}
*
* <li>FastText serializer:</li>
* {@link #writeWordVectors(FastText, File)}
*
* <li>FastText deserializer:</li>
* {@link #readWordVectors(File)}
*
* <li>SequenceVectors serializers:</li>
* {@link #writeSequenceVectors(SequenceVectors, OutputStream)}
* {@link #writeSequenceVectors(SequenceVectors, SequenceElementFactory, File)}
* {@link #writeSequenceVectors(SequenceVectors, SequenceElementFactory, String)}
* {@link #writeSequenceVectors(SequenceVectors, SequenceElementFactory, OutputStream)}
* {@link #writeLookupTable(WeightLookupTable, File)}
* {@link #writeVocabCache(VocabCache, File)}
* {@link #writeVocabCache(VocabCache, OutputStream)}
*
* <li>SequenceVectors deserializers:</li>
* {@link #readSequenceVectors(File, boolean)}
* {@link #readSequenceVectors(String, boolean)}
* {@link #readSequenceVectors(SequenceElementFactory, File)}
* {@link #readSequenceVectors(InputStream, boolean)}
* {@link #readSequenceVectors(SequenceElementFactory, InputStream)}
* {@link #readLookupTable(File)}
* {@link #readLookupTable(InputStream)}
*
* </ul>
*
* @author Adam Gibson * @author Adam Gibson
* @author raver119 * @author raver119
* @author alexander@skymind.io * @author alexander@skymind.io
@ -97,7 +167,7 @@ public class WordVectorSerializer {
* @throws IOException * @throws IOException
* @throws NumberFormatException * @throws NumberFormatException
*/ */
private static Word2Vec readTextModel(File modelFile) throws IOException, NumberFormatException { /*private static Word2Vec readTextModel(File modelFile) throws IOException, NumberFormatException {
InMemoryLookupTable lookupTable; InMemoryLookupTable lookupTable;
VocabCache cache; VocabCache cache;
INDArray syn0; INDArray syn0;
@ -142,7 +212,7 @@ public class WordVectorSerializer {
ret.setLookupTable(lookupTable); ret.setLookupTable(lookupTable);
} }
return ret; return ret;
} }*/
/** /**
* Read a binary word2vec file. * Read a binary word2vec file.
@ -173,8 +243,8 @@ public class WordVectorSerializer {
try (BufferedInputStream bis = new BufferedInputStream(GzipUtils.isCompressedFilename(modelFile.getName()) try (BufferedInputStream bis = new BufferedInputStream(GzipUtils.isCompressedFilename(modelFile.getName())
? new GZIPInputStream(new FileInputStream(modelFile)) : new FileInputStream(modelFile)); ? new GZIPInputStream(new FileInputStream(modelFile)) : new FileInputStream(modelFile));
DataInputStream dis = new DataInputStream(bis)) { DataInputStream dis = new DataInputStream(bis)) {
words = Integer.parseInt(readString(dis)); words = Integer.parseInt(ReadHelper.readString(dis));
size = Integer.parseInt(readString(dis)); size = Integer.parseInt(ReadHelper.readString(dis));
syn0 = Nd4j.create(words, size); syn0 = Nd4j.create(words, size);
cache = new AbstractCache<>(); cache = new AbstractCache<>();
@ -188,11 +258,11 @@ public class WordVectorSerializer {
float[] vector = new float[size]; float[] vector = new float[size];
for (int i = 0; i < words; i++) { for (int i = 0; i < words; i++) {
word = readString(dis); word = ReadHelper.readString(dis);
log.trace("Loading " + word + " with word " + i); log.trace("Loading " + word + " with word " + i);
for (int j = 0; j < size; j++) { for (int j = 0; j < size; j++) {
vector[j] = readFloat(dis); vector[j] = ReadHelper.readFloat(dis);
} }
if (cache.containsWord(word)) if (cache.containsWord(word))
@ -236,64 +306,6 @@ public class WordVectorSerializer {
} }
/**
* Read a float from a data input stream Credit to:
* https://github.com/NLPchina/Word2VEC_java/blob/master/src/com/ansj/vec/Word2VEC.java
*
* @param is
* @return
* @throws IOException
*/
public static float readFloat(InputStream is) throws IOException {
byte[] bytes = new byte[4];
is.read(bytes);
return getFloat(bytes);
}
/**
* Read a string from a data input stream Credit to:
* https://github.com/NLPchina/Word2VEC_java/blob/master/src/com/ansj/vec/Word2VEC.java
*
* @param b
* @return
* @throws IOException
*/
public static float getFloat(byte[] b) {
int accum = 0;
accum = accum | (b[0] & 0xff) << 0;
accum = accum | (b[1] & 0xff) << 8;
accum = accum | (b[2] & 0xff) << 16;
accum = accum | (b[3] & 0xff) << 24;
return Float.intBitsToFloat(accum);
}
/**
* Read a string from a data input stream Credit to:
* https://github.com/NLPchina/Word2VEC_java/blob/master/src/com/ansj/vec/Word2VEC.java
*
* @param dis
* @return
* @throws IOException
*/
public static String readString(DataInputStream dis) throws IOException {
byte[] bytes = new byte[MAX_SIZE];
byte b = dis.readByte();
int i = -1;
StringBuilder sb = new StringBuilder();
while (b != 32 && b != 10) {
i++;
bytes[i] = b;
b = dis.readByte();
if (i == 49) {
sb.append(new String(bytes, "UTF-8"));
i = -1;
bytes = new byte[MAX_SIZE];
}
}
sb.append(new String(bytes, 0, i + 1, "UTF-8"));
return sb.toString();
}
/** /**
* This method writes word vectors to the given path. * This method writes word vectors to the given path.
* Please note: this method doesn't load whole vocab/lookupTable into memory, so it's able to process large vocabularies served over network. * Please note: this method doesn't load whole vocab/lookupTable into memory, so it's able to process large vocabularies served over network.
@ -355,7 +367,7 @@ public class WordVectorSerializer {
val builder = new StringBuilder(); val builder = new StringBuilder();
val l = element.getLabel(); val l = element.getLabel();
builder.append(encodeB64(l)).append(" "); builder.append(ReadHelper.encodeB64(l)).append(" ");
val vec = lookupTable.vector(element.getLabel()); val vec = lookupTable.vector(element.getLabel());
for (int i = 0; i < vec.length(); i++) { for (int i = 0; i < vec.length(); i++) {
builder.append(vec.getDouble(i)); builder.append(vec.getDouble(i));
@ -518,7 +530,7 @@ public class WordVectorSerializer {
try (PrintWriter writer = new PrintWriter(new FileWriter(tempFileCodes))) { try (PrintWriter writer = new PrintWriter(new FileWriter(tempFileCodes))) {
for (int i = 0; i < vectors.getVocab().numWords(); i++) { for (int i = 0; i < vectors.getVocab().numWords(); i++) {
VocabWord word = vectors.getVocab().elementAtIndex(i); VocabWord word = vectors.getVocab().elementAtIndex(i);
StringBuilder builder = new StringBuilder(encodeB64(word.getLabel())).append(" "); StringBuilder builder = new StringBuilder(ReadHelper.encodeB64(word.getLabel())).append(" ");
for (int code : word.getCodes()) { for (int code : word.getCodes()) {
builder.append(code).append(" "); builder.append(code).append(" ");
} }
@ -536,7 +548,7 @@ public class WordVectorSerializer {
try (PrintWriter writer = new PrintWriter(new FileWriter(tempFileHuffman))) { try (PrintWriter writer = new PrintWriter(new FileWriter(tempFileHuffman))) {
for (int i = 0; i < vectors.getVocab().numWords(); i++) { for (int i = 0; i < vectors.getVocab().numWords(); i++) {
VocabWord word = vectors.getVocab().elementAtIndex(i); VocabWord word = vectors.getVocab().elementAtIndex(i);
StringBuilder builder = new StringBuilder(encodeB64(word.getLabel())).append(" "); StringBuilder builder = new StringBuilder(ReadHelper.encodeB64(word.getLabel())).append(" ");
for (int point : word.getPoints()) { for (int point : word.getPoints()) {
builder.append(point).append(" "); builder.append(point).append(" ");
} }
@ -554,7 +566,7 @@ public class WordVectorSerializer {
try (PrintWriter writer = new PrintWriter(new FileWriter(tempFileFreqs))) { try (PrintWriter writer = new PrintWriter(new FileWriter(tempFileFreqs))) {
for (int i = 0; i < vectors.getVocab().numWords(); i++) { for (int i = 0; i < vectors.getVocab().numWords(); i++) {
VocabWord word = vectors.getVocab().elementAtIndex(i); VocabWord word = vectors.getVocab().elementAtIndex(i);
StringBuilder builder = new StringBuilder(encodeB64(word.getLabel())).append(" ") StringBuilder builder = new StringBuilder(ReadHelper.encodeB64(word.getLabel())).append(" ")
.append(word.getElementFrequency()).append(" ") .append(word.getElementFrequency()).append(" ")
.append(vectors.getVocab().docAppearedIn(word.getLabel())); .append(vectors.getVocab().docAppearedIn(word.getLabel()));
@ -638,7 +650,7 @@ public class WordVectorSerializer {
try (PrintWriter writer = new PrintWriter(new FileWriter(tempFileCodes))) { try (PrintWriter writer = new PrintWriter(new FileWriter(tempFileCodes))) {
for (int i = 0; i < vectors.getVocab().numWords(); i++) { for (int i = 0; i < vectors.getVocab().numWords(); i++) {
VocabWord word = vectors.getVocab().elementAtIndex(i); VocabWord word = vectors.getVocab().elementAtIndex(i);
StringBuilder builder = new StringBuilder(encodeB64(word.getLabel())).append(" "); StringBuilder builder = new StringBuilder(ReadHelper.encodeB64(word.getLabel())).append(" ");
for (int code : word.getCodes()) { for (int code : word.getCodes()) {
builder.append(code).append(" "); builder.append(code).append(" ");
} }
@ -656,7 +668,7 @@ public class WordVectorSerializer {
try (PrintWriter writer = new PrintWriter(new FileWriter(tempFileHuffman))) { try (PrintWriter writer = new PrintWriter(new FileWriter(tempFileHuffman))) {
for (int i = 0; i < vectors.getVocab().numWords(); i++) { for (int i = 0; i < vectors.getVocab().numWords(); i++) {
VocabWord word = vectors.getVocab().elementAtIndex(i); VocabWord word = vectors.getVocab().elementAtIndex(i);
StringBuilder builder = new StringBuilder(encodeB64(word.getLabel())).append(" "); StringBuilder builder = new StringBuilder(ReadHelper.encodeB64(word.getLabel())).append(" ");
for (int point : word.getPoints()) { for (int point : word.getPoints()) {
builder.append(point).append(" "); builder.append(point).append(" ");
} }
@ -677,7 +689,7 @@ public class WordVectorSerializer {
StringBuilder builder = new StringBuilder(); StringBuilder builder = new StringBuilder();
for (VocabWord word : vectors.getVocab().tokens()) { for (VocabWord word : vectors.getVocab().tokens()) {
if (word.isLabel()) if (word.isLabel())
builder.append(encodeB64(word.getLabel())).append("\n"); builder.append(ReadHelper.encodeB64(word.getLabel())).append("\n");
} }
IOUtils.write(builder.toString().trim(), zipfile, StandardCharsets.UTF_8); IOUtils.write(builder.toString().trim(), zipfile, StandardCharsets.UTF_8);
@ -688,7 +700,7 @@ public class WordVectorSerializer {
try (PrintWriter writer = new PrintWriter(new FileWriter(tempFileFreqs))) { try (PrintWriter writer = new PrintWriter(new FileWriter(tempFileFreqs))) {
for (int i = 0; i < vectors.getVocab().numWords(); i++) { for (int i = 0; i < vectors.getVocab().numWords(); i++) {
VocabWord word = vectors.getVocab().elementAtIndex(i); VocabWord word = vectors.getVocab().elementAtIndex(i);
builder = new StringBuilder(encodeB64(word.getLabel())).append(" ").append(word.getElementFrequency()) builder = new StringBuilder(ReadHelper.encodeB64(word.getLabel())).append(" ").append(word.getElementFrequency())
.append(" ").append(vectors.getVocab().docAppearedIn(word.getLabel())); .append(" ").append(vectors.getVocab().docAppearedIn(word.getLabel()));
writer.println(builder.toString().trim()); writer.println(builder.toString().trim());
@ -744,7 +756,7 @@ public class WordVectorSerializer {
try (BufferedReader reader = new BufferedReader(new InputStreamReader(stream, StandardCharsets.UTF_8))) { try (BufferedReader reader = new BufferedReader(new InputStreamReader(stream, StandardCharsets.UTF_8))) {
String line; String line;
while ((line = reader.readLine()) != null) { while ((line = reader.readLine()) != null) {
VocabWord word = vectors.getVocab().tokenFor(decodeB64(line.trim())); VocabWord word = vectors.getVocab().tokenFor(ReadHelper.decodeB64(line.trim()));
if (word != null) { if (word != null) {
word.markAsLabel(true); word.markAsLabel(true);
} }
@ -836,7 +848,7 @@ public class WordVectorSerializer {
String line; String line;
while ((line = reader.readLine()) != null) { while ((line = reader.readLine()) != null) {
String[] split = line.split(" "); String[] split = line.split(" ");
VocabWord word = w2v.getVocab().tokenFor(decodeB64(split[0])); VocabWord word = w2v.getVocab().tokenFor(ReadHelper.decodeB64(split[0]));
word.setElementFrequency((long) Double.parseDouble(split[1])); word.setElementFrequency((long) Double.parseDouble(split[1]));
word.setSequencesCount((long) Double.parseDouble(split[2])); word.setSequencesCount((long) Double.parseDouble(split[2]));
} }
@ -946,7 +958,7 @@ public class WordVectorSerializer {
reader = new BufferedReader(new FileReader(h_points)); reader = new BufferedReader(new FileReader(h_points));
while ((line = reader.readLine()) != null) { while ((line = reader.readLine()) != null) {
String[] split = line.split(" "); String[] split = line.split(" ");
VocabWord word = vocab.wordFor(decodeB64(split[0])); VocabWord word = vocab.wordFor(ReadHelper.decodeB64(split[0]));
List<Integer> points = new ArrayList<>(); List<Integer> points = new ArrayList<>();
for (int i = 1; i < split.length; i++) { for (int i = 1; i < split.length; i++) {
points.add(Integer.parseInt(split[i])); points.add(Integer.parseInt(split[i]));
@ -960,7 +972,7 @@ public class WordVectorSerializer {
reader = new BufferedReader(new FileReader(h_codes)); reader = new BufferedReader(new FileReader(h_codes));
while ((line = reader.readLine()) != null) { while ((line = reader.readLine()) != null) {
String[] split = line.split(" "); String[] split = line.split(" ");
VocabWord word = vocab.wordFor(decodeB64(split[0])); VocabWord word = vocab.wordFor(ReadHelper.decodeB64(split[0]));
List<Byte> codes = new ArrayList<>(); List<Byte> codes = new ArrayList<>();
for (int i = 1; i < split.length; i++) { for (int i = 1; i < split.length; i++) {
codes.add(Byte.parseByte(split[i])); codes.add(Byte.parseByte(split[i]));
@ -1704,7 +1716,7 @@ public class WordVectorSerializer {
if (line.isEmpty()) if (line.isEmpty())
line = iter.nextLine(); line = iter.nextLine();
String[] split = line.split(" "); String[] split = line.split(" ");
String word = decodeB64(split[0]); //split[0].replaceAll(whitespaceReplacement, " "); String word = ReadHelper.decodeB64(split[0]); //split[0].replaceAll(whitespaceReplacement, " ");
VocabWord word1 = new VocabWord(1.0, word); VocabWord word1 = new VocabWord(1.0, word);
word1.setIndex(cache.numWords()); word1.setIndex(cache.numWords());
@ -1994,7 +2006,13 @@ public class WordVectorSerializer {
private static final String SYN1_ENTRY = "syn1.bin"; private static final String SYN1_ENTRY = "syn1.bin";
private static final String SYN1_NEG_ENTRY = "syn1neg.bin"; private static final String SYN1_NEG_ENTRY = "syn1neg.bin";
/**
* This method saves specified SequenceVectors model to target OutputStream
*
* @param vectors SequenceVectors model
* @param stream Target output stream
* @param <T>
*/
public static <T extends SequenceElement> void writeSequenceVectors(@NonNull SequenceVectors<T> vectors, public static <T extends SequenceElement> void writeSequenceVectors(@NonNull SequenceVectors<T> vectors,
@NonNull OutputStream stream) @NonNull OutputStream stream)
throws IOException { throws IOException {
@ -2040,7 +2058,13 @@ public class WordVectorSerializer {
} }
} }
/**
* This method loads SequenceVectors from specified file path
*
* @param path String
* @param readExtendedTables boolean
* @param <T>
*/
public static <T extends SequenceElement> SequenceVectors<T> readSequenceVectors(@NonNull String path, public static <T extends SequenceElement> SequenceVectors<T> readSequenceVectors(@NonNull String path,
boolean readExtendedTables) boolean readExtendedTables)
throws IOException { throws IOException {
@ -2050,6 +2074,14 @@ public class WordVectorSerializer {
return vectors; return vectors;
} }
/**
* This method loads SequenceVectors from specified file path
*
* @param file File
* @param readExtendedTables boolean
* @param <T>
*/
public static <T extends SequenceElement> SequenceVectors<T> readSequenceVectors(@NonNull File file, public static <T extends SequenceElement> SequenceVectors<T> readSequenceVectors(@NonNull File file,
boolean readExtendedTables) boolean readExtendedTables)
throws IOException { throws IOException {
@ -2058,6 +2090,13 @@ public class WordVectorSerializer {
return vectors; return vectors;
} }
/**
* This method loads SequenceVectors from specified input stream
*
* @param stream InputStream
* @param readExtendedTables boolean
* @param <T>
*/
public static <T extends SequenceElement> SequenceVectors<T> readSequenceVectors(@NonNull InputStream stream, public static <T extends SequenceElement> SequenceVectors<T> readSequenceVectors(@NonNull InputStream stream,
boolean readExtendedTables) boolean readExtendedTables)
throws IOException { throws IOException {
@ -2381,6 +2420,12 @@ public class WordVectorSerializer {
} }
} }
/**
* This method loads Word2Vec model from binary file
*
* @param file File
* @return Word2Vec
*/
public static Word2Vec readAsBinary(@NonNull File file) { public static Word2Vec readAsBinary(@NonNull File file) {
boolean originalPeriodic = Nd4j.getMemoryManager().isPeriodicGcActive(); boolean originalPeriodic = Nd4j.getMemoryManager().isPeriodicGcActive();
int originalFreq = Nd4j.getMemoryManager().getOccasionalGcFrequency(); int originalFreq = Nd4j.getMemoryManager().getOccasionalGcFrequency();
@ -2403,6 +2448,12 @@ public class WordVectorSerializer {
} }
} }
/**
* This method loads Word2Vec model from csv file
*
* @param file File
* @return Word2Vec
*/
public static Word2Vec readAsCsv(@NonNull File file) { public static Word2Vec readAsCsv(@NonNull File file) {
Word2Vec vec; Word2Vec vec;
@ -2491,7 +2542,7 @@ public class WordVectorSerializer {
String line; String line;
while ((line = reader.readLine()) != null) { while ((line = reader.readLine()) != null) {
String[] split = line.split(" "); String[] split = line.split(" ");
VocabWord word = new VocabWord(Double.valueOf(split[1]), decodeB64(split[0])); VocabWord word = new VocabWord(Double.valueOf(split[1]), ReadHelper.decodeB64(split[0]));
word.setIndex(cnt.getAndIncrement()); word.setIndex(cnt.getAndIncrement());
word.incrementSequencesCount(Long.valueOf(split[2])); word.incrementSequencesCount(Long.valueOf(split[2]));
@ -2669,7 +2720,7 @@ public class WordVectorSerializer {
* *
* In return you get StaticWord2Vec model, which might be used as lookup table only in multi-gpu environment. * In return you get StaticWord2Vec model, which might be used as lookup table only in multi-gpu environment.
* *
* @param file File should point to previously saved w2v model * @param inputStream InputStream should point to previously saved w2v model
* @return * @return
*/ */
public static WordVectors loadStaticModel(InputStream inputStream) throws IOException { public static WordVectors loadStaticModel(InputStream inputStream) throws IOException {
@ -2685,6 +2736,17 @@ public class WordVectorSerializer {
} }
// TODO: this method needs better name :) // TODO: this method needs better name :)
/**
* This method restores previously saved w2v model. File can be in one of the following formats:
* 1) Binary model, either compressed or not. Like well-known Google Model
* 2) Popular CSV word2vec text format
* 3) DL4j compressed format
*
* In return you get StaticWord2Vec model, which might be used as lookup table only in multi-gpu environment.
*
* @param file File
* @return
*/
public static WordVectors loadStaticModel(@NonNull File file) { public static WordVectors loadStaticModel(@NonNull File file) {
if (!file.exists() || file.isDirectory()) if (!file.exists() || file.isDirectory())
throw new RuntimeException( throw new RuntimeException(
@ -2843,8 +2905,8 @@ public class WordVectorSerializer {
throw new RuntimeException(e); throw new RuntimeException(e);
} }
try { try {
numWords = Integer.parseInt(readString(stream)); numWords = Integer.parseInt(ReadHelper.readString(stream));
vectorLength = Integer.parseInt(readString(stream)); vectorLength = Integer.parseInt(ReadHelper.readString(stream));
} catch (IOException e) { } catch (IOException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
} }
@ -2858,13 +2920,13 @@ public class WordVectorSerializer {
@Override @Override
public Pair<VocabWord, float[]> next() { public Pair<VocabWord, float[]> next() {
try { try {
String word = readString(stream); String word = ReadHelper.readString(stream);
VocabWord element = new VocabWord(1.0, word); VocabWord element = new VocabWord(1.0, word);
element.setIndex(idxCounter.getAndIncrement()); element.setIndex(idxCounter.getAndIncrement());
float[] vector = new float[vectorLength]; float[] vector = new float[vectorLength];
for (int i = 0; i < vectorLength; i++) { for (int i = 0; i < vectorLength; i++) {
vector[i] = readFloat(stream); vector[i] = ReadHelper.readFloat(stream);
} }
return Pair.makePair(element, vector); return Pair.makePair(element, vector);
@ -2913,7 +2975,7 @@ public class WordVectorSerializer {
String[] split = nextLine.split(" "); String[] split = nextLine.split(" ");
VocabWord word = new VocabWord(1.0, decodeB64(split[0])); VocabWord word = new VocabWord(1.0, ReadHelper.decodeB64(split[0]));
word.setIndex(idxCounter.getAndIncrement()); word.setIndex(idxCounter.getAndIncrement());
float[] vector = new float[split.length - 1]; float[] vector = new float[split.length - 1];
@ -2937,26 +2999,12 @@ public class WordVectorSerializer {
} }
} }
public static String encodeB64(String word) { /**
try { * This method saves Word2Vec model to output stream
return "B64:" + Base64.encodeBase64String(word.getBytes("UTF-8")).replaceAll("(\r|\n)", ""); *
} catch (Exception e) { * @param word2Vec Word2Vec
throw new RuntimeException(e); * @param stream OutputStream
} */
}
public static String decodeB64(String word) {
if (word.startsWith("B64:")) {
String arp = word.replaceFirst("B64:", "");
try {
return new String(Base64.decodeBase64(arp), "UTF-8");
} catch (Exception e) {
throw new RuntimeException(e);
}
} else
return word;
}
public static void writeWord2Vec(@NonNull Word2Vec word2Vec, @NonNull OutputStream stream) public static void writeWord2Vec(@NonNull Word2Vec word2Vec, @NonNull OutputStream stream)
throws IOException { throws IOException {
@ -2968,6 +3016,13 @@ public class WordVectorSerializer {
writeSequenceVectors(vectors, stream); writeSequenceVectors(vectors, stream);
} }
/**
* This method restores Word2Vec model from file
*
* @param path String
* @param readExtendedTables booleab
* @return Word2Vec
*/
public static Word2Vec readWord2Vec(@NonNull String path, boolean readExtendedTables) public static Word2Vec readWord2Vec(@NonNull String path, boolean readExtendedTables)
throws IOException { throws IOException {
@ -2976,6 +3031,12 @@ public class WordVectorSerializer {
return word2Vec; return word2Vec;
} }
/**
* This method saves table of weights to file
*
* @param weightLookupTable WeightLookupTable
* @param file File
*/
public static <T extends SequenceElement> void writeLookupTable(WeightLookupTable<T> weightLookupTable, public static <T extends SequenceElement> void writeLookupTable(WeightLookupTable<T> weightLookupTable,
@NonNull File file) throws IOException { @NonNull File file) throws IOException {
try (BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(file), try (BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(file),
@ -3038,7 +3099,7 @@ public class WordVectorSerializer {
headerRead = true; headerRead = true;
weightLookupTable = new InMemoryLookupTable.Builder().cache(vocabCache).vectorLength(layerSize).build(); weightLookupTable = new InMemoryLookupTable.Builder().cache(vocabCache).vectorLength(layerSize).build();
} else { } else {
String label = decodeB64(tokens[0]); String label = ReadHelper.decodeB64(tokens[0]);
int freq = Integer.parseInt(tokens[1]); int freq = Integer.parseInt(tokens[1]);
int rows = Integer.parseInt(tokens[2]); int rows = Integer.parseInt(tokens[2]);
int cols = Integer.parseInt(tokens[3]); int cols = Integer.parseInt(tokens[3]);
@ -3071,6 +3132,13 @@ public class WordVectorSerializer {
return weightLookupTable; return weightLookupTable;
} }
/**
* This method loads Word2Vec model from file
*
* @param file File
* @param readExtendedTables boolean
* @return Word2Vec
*/
public static Word2Vec readWord2Vec(@NonNull File file, boolean readExtendedTables) public static Word2Vec readWord2Vec(@NonNull File file, boolean readExtendedTables)
throws IOException { throws IOException {
@ -3078,6 +3146,13 @@ public class WordVectorSerializer {
return word2Vec; return word2Vec;
} }
/**
* This method loads Word2Vec model from input stream
*
* @param stream InputStream
* @param readExtendedTable boolean
* @return Word2Vec
*/
public static Word2Vec readWord2Vec(@NonNull InputStream stream, public static Word2Vec readWord2Vec(@NonNull InputStream stream,
boolean readExtendedTable) throws IOException { boolean readExtendedTable) throws IOException {
SequenceVectors<VocabWord> vectors = readSequenceVectors(stream, readExtendedTable); SequenceVectors<VocabWord> vectors = readSequenceVectors(stream, readExtendedTable);
@ -3088,6 +3163,12 @@ public class WordVectorSerializer {
return word2Vec; return word2Vec;
} }
/**
* This method loads FastText model to file
*
* @param vectors FastText
* @param path File
*/
public static void writeWordVectors(@NonNull FastText vectors, @NonNull File path) throws IOException { public static void writeWordVectors(@NonNull FastText vectors, @NonNull File path) throws IOException {
ObjectOutputStream outputStream = null; ObjectOutputStream outputStream = null;
try { try {
@ -3106,6 +3187,11 @@ public class WordVectorSerializer {
} }
} }
/**
* This method unloads FastText model from file
*
* @param path File
*/
public static FastText readWordVectors(File path) { public static FastText readWordVectors(File path) {
FastText result = null; FastText result = null;
try { try {
@ -3124,6 +3210,13 @@ public class WordVectorSerializer {
return result; return result;
} }
/**
* This method prints memory usage to log
*
* @param numWords
* @param vectorLength
* @param numTables
*/
public static void printOutProjectedMemoryUse(long numWords, int vectorLength, int numTables) { public static void printOutProjectedMemoryUse(long numWords, int vectorLength, int numTables) {
double memSize = numWords * vectorLength * Nd4j.sizeOfDataType() * numTables; double memSize = numWords * vectorLength * Nd4j.sizeOfDataType() * numTables;
@ -3144,4 +3237,102 @@ public class WordVectorSerializer {
OneTimeLogger.info(log, "Projected memory use for model: [{} {}]", String.format("%.2f", value), sfx); OneTimeLogger.info(log, "Projected memory use for model: [{} {}]", String.format("%.2f", value), sfx);
} }
/**
* Helper static methods to read data from input stream.
*/
public static class ReadHelper {
/**
* Read a float from a data input stream Credit to:
* https://github.com/NLPchina/Word2VEC_java/blob/master/src/com/ansj/vec/Word2VEC.java
*
* @param is
* @return
* @throws IOException
*/
private static float readFloat(InputStream is) throws IOException {
byte[] bytes = new byte[4];
is.read(bytes);
return getFloat(bytes);
}
/**
* Read a string from a data input stream Credit to:
* https://github.com/NLPchina/Word2VEC_java/blob/master/src/com/ansj/vec/Word2VEC.java
*
* @param b
* @return
* @throws IOException
*/
private static float getFloat(byte[] b) {
int accum = 0;
accum = accum | (b[0] & 0xff) << 0;
accum = accum | (b[1] & 0xff) << 8;
accum = accum | (b[2] & 0xff) << 16;
accum = accum | (b[3] & 0xff) << 24;
return Float.intBitsToFloat(accum);
}
/**
* Read a string from a data input stream Credit to:
* https://github.com/NLPchina/Word2VEC_java/blob/master/src/com/ansj/vec/Word2VEC.java
*
* @param dis
* @return
* @throws IOException
*/
private static String readString(DataInputStream dis) throws IOException {
byte[] bytes = new byte[MAX_SIZE];
byte b = dis.readByte();
int i = -1;
StringBuilder sb = new StringBuilder();
while (b != 32 && b != 10) {
i++;
bytes[i] = b;
b = dis.readByte();
if (i == 49) {
sb.append(new String(bytes, "UTF-8"));
i = -1;
bytes = new byte[MAX_SIZE];
}
}
sb.append(new String(bytes, 0, i + 1, "UTF-8"));
return sb.toString();
}
private static final String B64 = "B64:";
/**
* Encode input string
*
* @param word String
* @return String
*/
public static String encodeB64(String word) {
try {
return B64 + Base64.encodeBase64String(word.getBytes("UTF-8")).replaceAll("(\r|\n)", "");
} catch (Exception e) {
throw new RuntimeException(e);
}
}
/**
* Encode input string
*
* @param word String
* @return String
*/
public static String decodeB64(String word) {
if (word.startsWith(B64)) {
String arp = word.replaceFirst(B64, "");
try {
return new String(Base64.decodeBase64(arp), "UTF-8");
} catch (Exception e) {
throw new RuntimeException(e);
}
} else
return word;
}
}
} }

View File

@ -24,7 +24,7 @@
<artifactId>deeplearning4j-aws_2.11</artifactId> <artifactId>deeplearning4j-aws_2.11</artifactId>
<name>DeepLearning4j-AWS</name> <name>DeepLearning4j-AWS</name>
<version>1.0.0_spark_2-SNAPSHOT</version> <version>1.0.0-SNAPSHOT</version>
<build> <build>
<plugins> <plugins>

View File

@ -18,7 +18,7 @@
<parent> <parent>
<artifactId>spark_2.11</artifactId> <artifactId>spark_2.11</artifactId>
<groupId>org.deeplearning4j</groupId> <groupId>org.deeplearning4j</groupId>
<version>1.0.0_spark_2-SNAPSHOT</version> <version>1.0.0-SNAPSHOT</version>
</parent> </parent>
<modelVersion>4.0.0</modelVersion> <modelVersion>4.0.0</modelVersion>

View File

@ -18,7 +18,7 @@
<parent> <parent>
<artifactId>spark_2.11</artifactId> <artifactId>spark_2.11</artifactId>
<groupId>org.deeplearning4j</groupId> <groupId>org.deeplearning4j</groupId>
<version>1.0.0_spark_2-SNAPSHOT</version> <version>1.0.0-SNAPSHOT</version>
</parent> </parent>
<modelVersion>4.0.0</modelVersion> <modelVersion>4.0.0</modelVersion>
<artifactId>dl4j-spark-nlp_2.11</artifactId> <artifactId>dl4j-spark-nlp_2.11</artifactId>

View File

@ -19,7 +19,7 @@
<parent> <parent>
<artifactId>spark_2.11</artifactId> <artifactId>spark_2.11</artifactId>
<groupId>org.deeplearning4j</groupId> <groupId>org.deeplearning4j</groupId>
<version>1.0.0_spark_2-SNAPSHOT</version> <version>1.0.0-SNAPSHOT</version>
</parent> </parent>
<modelVersion>4.0.0</modelVersion> <modelVersion>4.0.0</modelVersion>

View File

@ -18,7 +18,7 @@
<parent> <parent>
<artifactId>spark_2.11</artifactId> <artifactId>spark_2.11</artifactId>
<groupId>org.deeplearning4j</groupId> <groupId>org.deeplearning4j</groupId>
<version>1.0.0_spark_2-SNAPSHOT</version> <version>1.0.0-SNAPSHOT</version>
</parent> </parent>
<modelVersion>4.0.0</modelVersion> <modelVersion>4.0.0</modelVersion>
<artifactId>dl4j-spark_2.11</artifactId> <artifactId>dl4j-spark_2.11</artifactId>

View File

@ -17,7 +17,6 @@
package org.deeplearning4j.spark; package org.deeplearning4j.spark;
import org.apache.spark.serializer.SerializerInstance; import org.apache.spark.serializer.SerializerInstance;
import org.deeplearning4j.eval.*;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
@ -28,6 +27,9 @@ import org.deeplearning4j.nn.conf.graph.rnn.LastTimeStepVertex;
import org.deeplearning4j.nn.conf.layers.*; import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor; import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor;
import org.junit.Test; import org.junit.Test;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.evaluation.classification.*;
import org.nd4j.evaluation.regression.RegressionEvaluation;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Adam; import org.nd4j.linalg.learning.config.Adam;

View File

@ -19,7 +19,6 @@ package org.deeplearning4j.spark.impl.multilayer;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaRDD;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
@ -30,6 +29,7 @@ import org.deeplearning4j.spark.BaseSparkTest;
import org.deeplearning4j.spark.api.TrainingMaster; import org.deeplearning4j.spark.api.TrainingMaster;
import org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster; import org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster;
import org.junit.Test; import org.junit.Test;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;

View File

@ -29,15 +29,13 @@ import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.util.MLUtils; import org.apache.spark.mllib.util.MLUtils;
import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator; import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.eval.ROC;
import org.deeplearning4j.eval.ROCMultiClass;
import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.BaseLayer; import org.deeplearning4j.nn.conf.layers.BaseLayer;
import org.deeplearning4j.nn.conf.layers.BatchNormalization;
import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.variational.GaussianReconstructionDistribution; import org.deeplearning4j.nn.conf.layers.variational.GaussianReconstructionDistribution;
@ -56,6 +54,9 @@ import org.junit.Ignore;
import org.junit.Rule; import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import org.junit.rules.TemporaryFolder; import org.junit.rules.TemporaryFolder;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.evaluation.classification.ROC;
import org.nd4j.evaluation.classification.ROCMultiClass;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.DataSet;
@ -63,6 +64,7 @@ import org.nd4j.linalg.dataset.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.io.ClassPathResource; import org.nd4j.linalg.io.ClassPathResource;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.IUpdater; import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.learning.config.Nesterovs; import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.learning.config.RmsProp; import org.nd4j.linalg.learning.config.RmsProp;
@ -70,7 +72,6 @@ import org.nd4j.linalg.lossfunctions.LossFunctions;
import scala.Tuple2; import scala.Tuple2;
import java.io.File; import java.io.File;
import java.nio.file.Files;
import java.nio.file.Path; import java.nio.file.Path;
import java.util.*; import java.util.*;
@ -121,11 +122,6 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
new ParameterAveragingTrainingMaster(true, numExecutors(), 1, 5, 1, 0)); new ParameterAveragingTrainingMaster(true, numExecutors(), 1, 5, 1, 0));
MultiLayerNetwork network2 = master.fitLabeledPoint(data); MultiLayerNetwork network2 = master.fitLabeledPoint(data);
Evaluation evaluation = new Evaluation();
evaluation.eval(d.getLabels(), network2.output(d.getFeatures()));
System.out.println(evaluation.stats());
} }
@ -137,20 +133,15 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
.getAbsolutePath()) .getAbsolutePath())
.toJavaRDD().map(new TestFn()); .toJavaRDD().map(new TestFn());
DataSet d = new IrisDataSetIterator(150, 150).next();
MultiLayerConfiguration conf = MultiLayerConfiguration conf =
new NeuralNetConfiguration.Builder().seed(123) new NeuralNetConfiguration.Builder().seed(123)
.optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT) .updater(new Adam(1e-6))
.miniBatch(true).maxNumLineSearchIterations(10)
.list().layer(0,
new DenseLayer.Builder().nIn(4).nOut(100)
.weightInit(WeightInit.XAVIER) .weightInit(WeightInit.XAVIER)
.activation(Activation.RELU) .list()
.build()) .layer(new BatchNormalization.Builder().nIn(4).nOut(4).build())
.layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder( .layer(new DenseLayer.Builder().nIn(4).nOut(32).activation(Activation.RELU).build())
LossFunctions.LossFunction.MCXENT).nIn(100).nOut(3) .layer(new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(32).nOut(3)
.activation(Activation.SOFTMAX) .activation(Activation.SOFTMAX).build())
.weightInit(WeightInit.XAVIER).build())
.build(); .build();
@ -161,10 +152,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
SparkDl4jMultiLayer master = new SparkDl4jMultiLayer(sc, getBasicConf(), SparkDl4jMultiLayer master = new SparkDl4jMultiLayer(sc, getBasicConf(),
new ParameterAveragingTrainingMaster(true, numExecutors(), 1, 5, 1, 0)); new ParameterAveragingTrainingMaster(true, numExecutors(), 1, 5, 1, 0));
MultiLayerNetwork network2 = master.fitLabeledPoint(data); master.fitLabeledPoint(data);
Evaluation evaluation = new Evaluation();
evaluation.eval(d.getLabels(), network2.output(d.getFeatures()));
System.out.println(evaluation.stats());
} }
@Test(timeout = 120000L) @Test(timeout = 120000L)
@ -465,8 +453,8 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
tempDirF.deleteOnExit(); tempDirF.deleteOnExit();
int dataSetObjSize = 1; int dataSetObjSize = 1;
int batchSizePerExecutor = 25; int batchSizePerExecutor = 16;
int numSplits = 10; int numSplits = 5;
int averagingFrequency = 3; int averagingFrequency = 3;
int totalExamples = numExecutors() * batchSizePerExecutor * numSplits * averagingFrequency; int totalExamples = numExecutors() * batchSizePerExecutor * numSplits * averagingFrequency;
DataSetIterator iter = new MnistDataSetIterator(dataSetObjSize, totalExamples, false); DataSetIterator iter = new MnistDataSetIterator(dataSetObjSize, totalExamples, false);

View File

@ -22,7 +22,7 @@
</parent> </parent>
<modelVersion>4.0.0</modelVersion> <modelVersion>4.0.0</modelVersion>
<artifactId>spark_2.11</artifactId> <artifactId>spark_2.11</artifactId>
<version>1.0.0_spark_2-SNAPSHOT</version> <version>1.0.0-SNAPSHOT</version>
<packaging>pom</packaging> <packaging>pom</packaging>
<name>Spark parent</name> <name>Spark parent</name>
@ -36,7 +36,7 @@
<properties> <properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding> <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<project.reporting.outputEncoding>UTF-8</project.reporting.outputEncoding> <project.reporting.outputEncoding>UTF-8</project.reporting.outputEncoding>
<datavec.spark.version>1.0.0_spark_2-SNAPSHOT</datavec.spark.version> <datavec.spark.version>1.0.0-SNAPSHOT</datavec.spark.version>
<scala.macros.version>2.1.0</scala.macros.version> <scala.macros.version>2.1.0</scala.macros.version>

View File

@ -24,11 +24,13 @@
#include <map> #include <map>
#include <array/ConstantDescriptor.h> #include <array/ConstantDescriptor.h>
#include <array/ConstantDataBuffer.h> #include <array/ConstantDataBuffer.h>
#include <mutex>
namespace nd4j { namespace nd4j {
class ConstantHolder { class ConstantHolder {
private: private:
int _deviceId = 0; int _deviceId = 0;
std::mutex _mutex;
std::map<nd4j::DataType, ConstantDataBuffer> _buffers; std::map<nd4j::DataType, ConstantDataBuffer> _buffers;
public: public:
@ -53,6 +55,8 @@ namespace nd4j {
template <typename T> template <typename T>
ConstantDataBuffer* getConstantDataBuffer(); ConstantDataBuffer* getConstantDataBuffer();
std::mutex* mutex();
}; };
} }

View File

@ -16,6 +16,10 @@ namespace nd4j {
return _buffers.count(dataType) > 0; return _buffers.count(dataType) > 0;
} }
std::mutex* ConstantHolder::mutex() {
return &_mutex;
}
template <typename T> template <typename T>
bool ConstantHolder::hasBuffer() { bool ConstantHolder::hasBuffer() {
return hasBuffer(DataTypeUtils::fromT<T>()); return hasBuffer(DataTypeUtils::fromT<T>());

View File

@ -47,7 +47,7 @@ namespace nd4j {
_currentMutex.unlock(); _currentMutex.unlock();
setCurrentDevice(globalThreadToDevice); setCurrentNativeDevice(globalThreadToDevice);
} }
// if we already know affinity - just return it // if we already know affinity - just return it
@ -92,6 +92,8 @@ namespace nd4j {
void AffinityManager::setCurrentNativeDevice(int deviceId) { void AffinityManager::setCurrentNativeDevice(int deviceId) {
auto res = cudaSetDevice(deviceId); auto res = cudaSetDevice(deviceId);
if (res != 0)
throw cuda_exception::build("setCurrentDevice failed", res);
} }
void AffinityManager::setCurrentDevice(int deviceId) { void AffinityManager::setCurrentDevice(int deviceId) {
@ -104,17 +106,22 @@ namespace nd4j {
res = cudaStreamSynchronize(*LaunchContext::defaultContext()->getCudaSpecialStream()); res = cudaStreamSynchronize(*LaunchContext::defaultContext()->getCudaSpecialStream());
if (res != 0) if (res != 0)
throw cuda_exception::build("setCurrentDevice -> specialSync failed", res); throw cuda_exception::build("setCurrentDevice -> specialSync failed", res);
if (deviceId != previousDeviceId) {
// discard existing stuff
nd4j_printf("AffinityManager::setCurrentDevice() was invoked, releasing buffers\n", "");
LaunchContext::releaseBuffers();
}
} }
if (deviceId != previousDeviceId) {
auto res = cudaSetDevice(deviceId); auto res = cudaSetDevice(deviceId);
if (res != 0) if (res != 0)
throw cuda_exception::build("cudaSetDevice failed", res); throw cuda_exception::build("cudaSetDevice failed", res);
}
// update thread-device affinity // update thread-device affinity
globalThreadToDevice = deviceId; globalThreadToDevice = deviceId;
// discard existing stuff
LaunchContext::releaseBuffers();
} }
std::atomic<int> AffinityManager::_lastDevice;// = std::atomic<int>(initialV); std::atomic<int> AffinityManager::_lastDevice;// = std::atomic<int>(initialV);

View File

@ -107,7 +107,6 @@ namespace nd4j {
////// //////
_allocated = false; _allocated = false;
_initialized = false;
_deviceId = -1; _deviceId = -1;
this->_specialStream = nullptr; this->_specialStream = nullptr;
@ -116,6 +115,8 @@ namespace nd4j {
this->_reductionPointer = nullptr; this->_reductionPointer = nullptr;
this->_scalarPointer = nullptr; this->_scalarPointer = nullptr;
} }
_initialized = false;
} }
ContextBuffers::~ContextBuffers() { ContextBuffers::~ContextBuffers() {
@ -163,21 +164,21 @@ namespace nd4j {
} }
void* ContextBuffers::reductionBuffer() { void* ContextBuffers::reductionBuffer() {
if (_reductionPointer == nullptr) if (!_initialized)
initialize(); initialize();
return _reductionPointer; return _reductionPointer;
} }
void* ContextBuffers::scalarBuffer() { void* ContextBuffers::scalarBuffer() {
if (_scalarPointer == nullptr) if (!_initialized)
initialize(); initialize();
return _scalarPointer; return _scalarPointer;
} }
void* ContextBuffers::allocationBuffer() { void* ContextBuffers::allocationBuffer() {
if (_allocationPointer == nullptr) if (!_initialized)
initialize(); initialize();
return _allocationPointer; return _allocationPointer;
@ -204,15 +205,23 @@ namespace nd4j {
} }
void* ContextBuffers::execStream() { void* ContextBuffers::execStream() {
if (_execStream == nullptr) if (!_initialized) {
//nd4j_printf("execStream not initialized\n", "");
initialize(); initialize();
} else {
//nd4j_printf("execStream is initialized\n", "");
}
return _execStream; return _execStream;
} }
void* ContextBuffers::specialStream() { void* ContextBuffers::specialStream() {
if (_specialStream == nullptr) if (!_initialized) {
//nd4j_printf("specialStream not initialized\n", "");
initialize(); initialize();
} else {
//nd4j_printf("specialStream is initialized\n", "");
}
return _specialStream; return _specialStream;
} }

View File

@ -57,10 +57,6 @@ LaunchContext::LaunchContext() {
_deviceID = 0; _deviceID = 0;
_isAllocated = true; _isAllocated = true;
_cublasHandle = CublasHelper::getInstance()->handle();
_cusolverHandle = CublasHelper::getInstance()->solver();
} }
LaunchContext::LaunchContext(Nd4jPointer cudaStream, Nd4jPointer reductionPointer, Nd4jPointer scalarPointer, Nd4jPointer allocationPointer) { LaunchContext::LaunchContext(Nd4jPointer cudaStream, Nd4jPointer reductionPointer, Nd4jPointer scalarPointer, Nd4jPointer allocationPointer) {
@ -89,13 +85,13 @@ LaunchContext::LaunchContext() {
_contexts.resize(numDevices); _contexts.resize(numDevices);
for (int e = 0; e < numDevices; e++) { for (int e = 0; e < numDevices; e++) {
AffinityManager::setCurrentDevice(e); AffinityManager::setCurrentNativeDevice(e);
LaunchContext::_contexts[e] = std::make_shared<LaunchContext>(); LaunchContext::_contexts[e] = std::make_shared<LaunchContext>();
} }
// don't forget to restore device back again // don't forget to restore device back again
AffinityManager::setCurrentDevice(deviceId); AffinityManager::setCurrentNativeDevice(deviceId);
} }
_mutex.unlock(); _mutex.unlock();
@ -117,11 +113,11 @@ LaunchContext::LaunchContext() {
}; };
void* LaunchContext::getCublasHandle() const { void* LaunchContext::getCublasHandle() const {
return _cublasHandle; return CublasHelper::getInstance()->handle();
}; };
void* LaunchContext::getCusolverHandle() const { void* LaunchContext::getCusolverHandle() const {
return _cusolverHandle; return CublasHelper::getInstance()->solver();
}; };
cudaStream_t* LaunchContext::getCudaStream() const { cudaStream_t* LaunchContext::getCudaStream() const {
@ -162,6 +158,7 @@ LaunchContext::LaunchContext() {
}; };
void LaunchContext::releaseBuffers() { void LaunchContext::releaseBuffers() {
nd4j_printf("LaunchContext::releaseBuffers() was invoked\n", "");
contextBuffers.release(); contextBuffers.release();
} }

View File

@ -38,12 +38,13 @@ namespace nd4j {
static ConstantHelper* _INSTANCE; static ConstantHelper* _INSTANCE;
ConstantHelper(); ConstantHelper();
std::vector<std::map<ConstantDescriptor, ConstantHolder>> _cache; std::vector<std::map<ConstantDescriptor, ConstantHolder*>> _cache;
// tracking of per-device constant memory buffers (CUDA only atm) // tracking of per-device constant memory buffers (CUDA only atm)
std::vector<Nd4jPointer> _devicePointers; std::vector<Nd4jPointer> _devicePointers;
std::vector<Nd4jLong> _deviceOffsets; std::vector<Nd4jLong> _deviceOffsets;
std::mutex _mutex; std::mutex _mutex;
std::mutex _mutexHolder;
std::vector<Nd4jLong> _counters; std::vector<Nd4jLong> _counters;
public: public:

View File

@ -48,10 +48,10 @@ namespace nd4j {
static ConstantShapeHelper* getInstance(); static ConstantShapeHelper* getInstance();
ConstantDataBuffer& bufferForShapeInfo(nd4j::DataType dataType, char order, const std::vector<Nd4jLong> &shape); ConstantDataBuffer bufferForShapeInfo(nd4j::DataType dataType, char order, const std::vector<Nd4jLong> &shape);
ConstantDataBuffer& bufferForShapeInfo(const ShapeDescriptor &descriptor); ConstantDataBuffer bufferForShapeInfo(const ShapeDescriptor &descriptor);
ConstantDataBuffer& bufferForShapeInfo(const Nd4jLong *shapeInfo); ConstantDataBuffer bufferForShapeInfo(const Nd4jLong *shapeInfo);
ConstantDataBuffer& bufferForShapeInfo(const nd4j::DataType dataType, const char order, const int rank, const Nd4jLong* shape); ConstantDataBuffer bufferForShapeInfo(const nd4j::DataType dataType, const char order, const int rank, const Nd4jLong* shape);
Nd4jLong* emptyShapeInfo(const nd4j::DataType dataType); Nd4jLong* emptyShapeInfo(const nd4j::DataType dataType);

View File

@ -54,11 +54,11 @@ namespace nd4j {
* @param keepUnitiesInShape * @param keepUnitiesInShape
* @return * @return
*/ */
TadPack& tadForDimensions(const Nd4jLong *originalShape, const std::vector<int> &dimensions, const bool keepUnitiesInShape = false); TadPack tadForDimensions(const Nd4jLong *originalShape, const std::vector<int> &dimensions, const bool keepUnitiesInShape = false);
TadPack& tadForDimensions(const Nd4jLong *originalShape, int* dimensions, int dimLength, const bool keepUnitiesInShape = false); TadPack tadForDimensions(const Nd4jLong *originalShape, int* dimensions, int dimLength, const bool keepUnitiesInShape = false);
TadPack& tadForDimensions(const Nd4jLong *originalShape, int dimensions, const bool keepUnitiesInShape = false); TadPack tadForDimensions(const Nd4jLong *originalShape, int dimensions, const bool keepUnitiesInShape = false);
TadPack& tadForDimensions(ShapeDescriptor &descriptor, std::vector<int> &dimensions, const bool keepUnitiesInShape = false); TadPack tadForDimensions(ShapeDescriptor &descriptor, std::vector<int> &dimensions, const bool keepUnitiesInShape = false);
TadPack& tadForDimensions(TadDescriptor &descriptor); TadPack tadForDimensions(TadDescriptor &descriptor);
/** /**
* This method returns number of cached TAD shapes/offsets on specific device * This method returns number of cached TAD shapes/offsets on specific device

View File

@ -33,7 +33,8 @@ namespace nd4j {
_cache.resize(numDevices); _cache.resize(numDevices);
_counters.resize(numDevices); _counters.resize(numDevices);
for (int e = 0; e < numDevices; e++) { for (int e = 0; e < numDevices; e++) {
std::map<ConstantDescriptor, ConstantHolder> map; std::map<ConstantDescriptor, ConstantHolder*> map;
_cache[e] = map; _cache[e] = map;
_counters[e] = 0L; _counters[e] = 0L;
} }
@ -70,15 +71,26 @@ namespace nd4j {
ConstantDataBuffer* ConstantHelper::constantBuffer(const ConstantDescriptor &descriptor, nd4j::DataType dataType) { ConstantDataBuffer* ConstantHelper::constantBuffer(const ConstantDescriptor &descriptor, nd4j::DataType dataType) {
const auto deviceId = getCurrentDevice(); const auto deviceId = getCurrentDevice();
// we're locking away cache modification
_mutexHolder.lock();
if (_cache[deviceId].count(descriptor) == 0) { if (_cache[deviceId].count(descriptor) == 0) {
ConstantHolder holder; _cache[deviceId][descriptor] = new ConstantHolder();
_cache[deviceId][descriptor] = holder;
} }
ConstantHolder* holder = &_cache[deviceId][descriptor]; auto holder = _cache[deviceId][descriptor];
// releasing cache lock
_mutexHolder.unlock();
ConstantDataBuffer* result;
// access to this holder instance is synchronous
holder->mutex()->lock();
if (holder->hasBuffer(dataType)) if (holder->hasBuffer(dataType))
return holder->getConstantDataBuffer(dataType); result = holder->getConstantDataBuffer(dataType);
else { else {
auto size = descriptor.length() * DataTypeUtils::sizeOf(dataType); auto size = descriptor.length() * DataTypeUtils::sizeOf(dataType);
auto cbuff = new int8_t[size]; auto cbuff = new int8_t[size];
@ -94,8 +106,11 @@ namespace nd4j {
ConstantDataBuffer dataBuffer(cbuff, nullptr, descriptor.length(), DataTypeUtils::sizeOf(dataType)); ConstantDataBuffer dataBuffer(cbuff, nullptr, descriptor.length(), DataTypeUtils::sizeOf(dataType));
holder->addBuffer(dataBuffer, dataType); holder->addBuffer(dataBuffer, dataType);
return holder->getConstantDataBuffer(dataType); result = holder->getConstantDataBuffer(dataType);
} }
holder->mutex()->unlock();
return result;
} }
Nd4jLong ConstantHelper::getCachedAmount(int deviceId) { Nd4jLong ConstantHelper::getCachedAmount(int deviceId) {

View File

@ -41,18 +41,18 @@ namespace nd4j {
return _INSTANCE; return _INSTANCE;
} }
ConstantDataBuffer& ConstantShapeHelper::bufferForShapeInfo(nd4j::DataType dataType, char order, const std::vector<Nd4jLong> &shape) { ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo(nd4j::DataType dataType, char order, const std::vector<Nd4jLong> &shape) {
ShapeDescriptor descriptor(dataType, order, shape); ShapeDescriptor descriptor(dataType, order, shape);
return bufferForShapeInfo(descriptor); return bufferForShapeInfo(descriptor);
} }
ConstantDataBuffer& ConstantShapeHelper::bufferForShapeInfo(const nd4j::DataType dataType, const char order, const int rank, const Nd4jLong* shape) { ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo(const nd4j::DataType dataType, const char order, const int rank, const Nd4jLong* shape) {
ShapeDescriptor descriptor(dataType, order, shape, rank); ShapeDescriptor descriptor(dataType, order, shape, rank);
return bufferForShapeInfo(descriptor); return bufferForShapeInfo(descriptor);
} }
ConstantDataBuffer& ConstantShapeHelper::bufferForShapeInfo(const ShapeDescriptor &descriptor) { ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo(const ShapeDescriptor &descriptor) {
int deviceId = 0; int deviceId = 0;
_mutex.lock(); _mutex.lock();
@ -62,19 +62,19 @@ namespace nd4j {
ConstantDataBuffer buffer(hPtr, nullptr, shape::shapeInfoLength(hPtr)*sizeof(Nd4jLong), DataType::INT64); ConstantDataBuffer buffer(hPtr, nullptr, shape::shapeInfoLength(hPtr)*sizeof(Nd4jLong), DataType::INT64);
ShapeDescriptor descriptor1(descriptor); ShapeDescriptor descriptor1(descriptor);
_cache[deviceId][descriptor1] = buffer; _cache[deviceId][descriptor1] = buffer;
ConstantDataBuffer &r = _cache[deviceId][descriptor1]; auto r = _cache[deviceId][descriptor1];
_mutex.unlock(); _mutex.unlock();
return r; return r;
} else { } else {
ConstantDataBuffer &r = _cache[deviceId].at(descriptor); auto r = _cache[deviceId].at(descriptor);
_mutex.unlock(); _mutex.unlock();
return r; return r;
} }
} }
ConstantDataBuffer& ConstantShapeHelper::bufferForShapeInfo(const Nd4jLong *shapeInfo) { ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo(const Nd4jLong *shapeInfo) {
ShapeDescriptor descriptor(shapeInfo); ShapeDescriptor descriptor(shapeInfo);
return bufferForShapeInfo(descriptor); return bufferForShapeInfo(descriptor);
} }

View File

@ -38,25 +38,25 @@ namespace nd4j {
return _INSTANCE; return _INSTANCE;
} }
TadPack& ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int dimension, const bool keepUnitiesInShape) { TadPack ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int dimension, const bool keepUnitiesInShape) {
return tadForDimensions(originalShape, &dimension, 1, keepUnitiesInShape); return tadForDimensions(originalShape, &dimension, 1, keepUnitiesInShape);
} }
TadPack& ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, const std::vector<int> &dimensions, const bool keepUnitiesInShape) { TadPack ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, const std::vector<int> &dimensions, const bool keepUnitiesInShape) {
return tadForDimensions(originalShape, const_cast<int *>(dimensions.data()), dimensions.size(), keepUnitiesInShape); return tadForDimensions(originalShape, const_cast<int *>(dimensions.data()), dimensions.size(), keepUnitiesInShape);
} }
TadPack& ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int* dimensions, int dimLength, const bool keepUnitiesInShape) { TadPack ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int* dimensions, int dimLength, const bool keepUnitiesInShape) {
TadDescriptor tadDescriptor(originalShape, dimensions, dimLength, keepUnitiesInShape); TadDescriptor tadDescriptor(originalShape, dimensions, dimLength, keepUnitiesInShape);
return tadForDimensions(tadDescriptor); return tadForDimensions(tadDescriptor);
} }
TadPack& ConstantTadHelper::tadForDimensions(ShapeDescriptor &descriptor, std::vector<int> &dimensions, const bool keepUnitiesInShape) { TadPack ConstantTadHelper::tadForDimensions(ShapeDescriptor &descriptor, std::vector<int> &dimensions, const bool keepUnitiesInShape) {
TadDescriptor tadDescriptor(descriptor, dimensions, keepUnitiesInShape); TadDescriptor tadDescriptor(descriptor, dimensions, keepUnitiesInShape);
return tadForDimensions(tadDescriptor); return tadForDimensions(tadDescriptor);
} }
TadPack& ConstantTadHelper::tadForDimensions(TadDescriptor &descriptor) { TadPack ConstantTadHelper::tadForDimensions(TadDescriptor &descriptor) {
const int deviceId = 0; const int deviceId = 0;
_mutex.lock(); _mutex.lock();
@ -105,7 +105,7 @@ namespace nd4j {
return r; return r;
} else { } else {
TadPack &r = _cache[deviceId][descriptor]; TadPack r = _cache[deviceId][descriptor];
_mutex.unlock(); _mutex.unlock();
return r; return r;

View File

@ -24,11 +24,13 @@
#include <dll.h> #include <dll.h>
#include <pointercast.h> #include <pointercast.h>
#include <vector> #include <vector>
#include <mutex>
namespace nd4j { namespace nd4j {
class CublasHelper { class CublasHelper {
private: private:
static CublasHelper *_INSTANCE; static CublasHelper *_INSTANCE;
static std::mutex _mutex;
std::vector<void*> _cache; std::vector<void*> _cache;
std::vector<void*> _solvers; std::vector<void*> _solvers;

View File

@ -68,7 +68,7 @@ namespace nd4j {
throw cuda_exception::build("cudaSetDevice failed", res); throw cuda_exception::build("cudaSetDevice failed", res);
auto constant = getConstantSpace(); auto constant = getConstantSpace();
std::map<ConstantDescriptor, ConstantHolder> devCache; std::map<ConstantDescriptor, ConstantHolder*> devCache;
_devicePointers[e] = constant; _devicePointers[e] = constant;
_deviceOffsets[e] = 0; _deviceOffsets[e] = 0;
@ -136,15 +136,24 @@ namespace nd4j {
ConstantDataBuffer* ConstantHelper::constantBuffer(const ConstantDescriptor &descriptor, nd4j::DataType dataType) { ConstantDataBuffer* ConstantHelper::constantBuffer(const ConstantDescriptor &descriptor, nd4j::DataType dataType) {
const auto deviceId = getCurrentDevice(); const auto deviceId = getCurrentDevice();
if (_cache[deviceId].count(descriptor) == 0) { // all cache modifications are synchronous
ConstantHolder holder; _mutexHolder.lock();
_cache[deviceId][descriptor] = holder;
}
ConstantHolder* holder = &_cache[deviceId][descriptor]; if (_cache[deviceId].count(descriptor) == 0) {
_cache[deviceId][descriptor] = new ConstantHolder();
}
auto holder = _cache[deviceId][descriptor];
// release cache lock
_mutexHolder.unlock();
ConstantDataBuffer* result;
// access to this holder instance is synchronous
holder->mutex()->lock();
if (holder->hasBuffer(dataType)) { if (holder->hasBuffer(dataType)) {
return holder->getConstantDataBuffer(dataType); result = holder->getConstantDataBuffer(dataType);
} else { } else {
auto numBytes = descriptor.length() * DataTypeUtils::sizeOf(dataType); auto numBytes = descriptor.length() * DataTypeUtils::sizeOf(dataType);
auto cbuff = new int8_t[numBytes]; auto cbuff = new int8_t[numBytes];
@ -160,10 +169,14 @@ namespace nd4j {
auto dbuff = replicatePointer(cbuff, descriptor.length() * DataTypeUtils::sizeOf(dataType)); auto dbuff = replicatePointer(cbuff, descriptor.length() * DataTypeUtils::sizeOf(dataType));
ConstantDataBuffer dataBuffer(cbuff, dbuff, descriptor.length(), DataTypeUtils::sizeOf(dataType)); ConstantDataBuffer dataBuffer(cbuff, dbuff, descriptor.length(), DataTypeUtils::sizeOf(dataType));
holder->addBuffer(dataBuffer, dataType);
return holder->getConstantDataBuffer(dataType); holder->addBuffer(dataBuffer, dataType);
result = holder->getConstantDataBuffer(dataType);
} }
// release holder lock
holder->mutex()->unlock();
return result;
} }
Nd4jLong ConstantHelper::getCachedAmount(int deviceId) { Nd4jLong ConstantHelper::getCachedAmount(int deviceId) {

View File

@ -44,17 +44,17 @@ namespace nd4j {
return _INSTANCE; return _INSTANCE;
} }
ConstantDataBuffer& ConstantShapeHelper::bufferForShapeInfo(nd4j::DataType dataType, char order, const std::vector<Nd4jLong> &shape) { ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo(nd4j::DataType dataType, char order, const std::vector<Nd4jLong> &shape) {
ShapeDescriptor descriptor(dataType, order, shape); ShapeDescriptor descriptor(dataType, order, shape);
return bufferForShapeInfo(descriptor); return bufferForShapeInfo(descriptor);
} }
ConstantDataBuffer& ConstantShapeHelper::bufferForShapeInfo(const nd4j::DataType dataType, const char order, const int rank, const Nd4jLong* shape) { ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo(const nd4j::DataType dataType, const char order, const int rank, const Nd4jLong* shape) {
ShapeDescriptor descriptor(dataType, order, shape, rank); ShapeDescriptor descriptor(dataType, order, shape, rank);
return bufferForShapeInfo(descriptor); return bufferForShapeInfo(descriptor);
} }
ConstantDataBuffer& ConstantShapeHelper::bufferForShapeInfo(const ShapeDescriptor &descriptor) { ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo(const ShapeDescriptor &descriptor) {
int deviceId = AffinityManager::currentDeviceId(); int deviceId = AffinityManager::currentDeviceId();
_mutex.lock(); _mutex.lock();
@ -65,19 +65,19 @@ namespace nd4j {
ConstantDataBuffer buffer(hPtr, dPtr, shape::shapeInfoLength(hPtr) * sizeof(Nd4jLong), DataType::INT64); ConstantDataBuffer buffer(hPtr, dPtr, shape::shapeInfoLength(hPtr) * sizeof(Nd4jLong), DataType::INT64);
ShapeDescriptor descriptor1(descriptor); ShapeDescriptor descriptor1(descriptor);
_cache[deviceId][descriptor1] = buffer; _cache[deviceId][descriptor1] = buffer;
ConstantDataBuffer &r = _cache[deviceId][descriptor1]; auto r = _cache[deviceId][descriptor1];
_mutex.unlock(); _mutex.unlock();
return r; return r;
} else { } else {
ConstantDataBuffer &r = _cache[deviceId].at(descriptor); ConstantDataBuffer r = _cache[deviceId].at(descriptor);
_mutex.unlock(); _mutex.unlock();
return r; return r;
} }
} }
ConstantDataBuffer& ConstantShapeHelper::bufferForShapeInfo(const Nd4jLong *shapeInfo) { ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo(const Nd4jLong *shapeInfo) {
ShapeDescriptor descriptor(shapeInfo); ShapeDescriptor descriptor(shapeInfo);
return bufferForShapeInfo(descriptor); return bufferForShapeInfo(descriptor);
} }

View File

@ -43,25 +43,25 @@ namespace nd4j {
return _INSTANCE; return _INSTANCE;
} }
TadPack& ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int dimension, const bool keepUnitiesInShape) { TadPack ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int dimension, const bool keepUnitiesInShape) {
return tadForDimensions(originalShape, &dimension, 1, keepUnitiesInShape); return tadForDimensions(originalShape, &dimension, 1, keepUnitiesInShape);
} }
TadPack& ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, const std::vector<int> &dimensions, const bool keepUnitiesInShape) { TadPack ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, const std::vector<int> &dimensions, const bool keepUnitiesInShape) {
return tadForDimensions(originalShape, const_cast<int *>(dimensions.data()), dimensions.size(), keepUnitiesInShape); return tadForDimensions(originalShape, const_cast<int *>(dimensions.data()), dimensions.size(), keepUnitiesInShape);
} }
TadPack& ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int* dimensions, int dimLength, const bool keepUnitiesInShape) { TadPack ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int* dimensions, int dimLength, const bool keepUnitiesInShape) {
TadDescriptor tadDescriptor(originalShape, dimensions, dimLength, keepUnitiesInShape); TadDescriptor tadDescriptor(originalShape, dimensions, dimLength, keepUnitiesInShape);
return tadForDimensions(tadDescriptor); return tadForDimensions(tadDescriptor);
} }
TadPack& ConstantTadHelper::tadForDimensions(ShapeDescriptor &descriptor, std::vector<int> &dimensions, const bool keepUnitiesInShape) { TadPack ConstantTadHelper::tadForDimensions(ShapeDescriptor &descriptor, std::vector<int> &dimensions, const bool keepUnitiesInShape) {
TadDescriptor tadDescriptor(descriptor, dimensions, keepUnitiesInShape); TadDescriptor tadDescriptor(descriptor, dimensions, keepUnitiesInShape);
return tadForDimensions(tadDescriptor); return tadForDimensions(tadDescriptor);
} }
TadPack& ConstantTadHelper::tadForDimensions(TadDescriptor &descriptor) { TadPack ConstantTadHelper::tadForDimensions(TadDescriptor &descriptor) {
const int deviceId = AffinityManager::currentDeviceId(); const int deviceId = AffinityManager::currentDeviceId();
_mutex.lock(); _mutex.lock();
@ -96,14 +96,14 @@ namespace nd4j {
TadPack t(shapesBuffer, offsetsBuffer, numOfSubArrs); TadPack t(shapesBuffer, offsetsBuffer, numOfSubArrs);
_cache[deviceId][descriptor] = t; _cache[deviceId][descriptor] = t;
TadPack &r = _cache[deviceId][descriptor]; TadPack r = _cache[deviceId][descriptor];
_mutex.unlock(); _mutex.unlock();
delete[] shapeInfo; delete[] shapeInfo;
return r; return r;
} else { } else {
TadPack &r = _cache[deviceId][descriptor]; TadPack r = _cache[deviceId][descriptor];
_mutex.unlock(); _mutex.unlock();
return r; return r;

View File

@ -27,6 +27,7 @@
#include <execution/AffinityManager.h> #include <execution/AffinityManager.h>
namespace nd4j { namespace nd4j {
std::mutex CublasHelper::_mutex;
static void* handle_() { static void* handle_() {
auto _handle = new cublasHandle_t(); auto _handle = new cublasHandle_t();
@ -56,22 +57,24 @@ namespace nd4j {
} }
CublasHelper::CublasHelper() { CublasHelper::CublasHelper() {
//nd4j_printf("Initializing cuBLAS\n","");
auto numDevices = AffinityManager::numberOfDevices(); auto numDevices = AffinityManager::numberOfDevices();
auto currentDevice = AffinityManager::currentDeviceId(); auto currentDevice = AffinityManager::currentDeviceId();
_cache.resize(numDevices); _cache.resize(numDevices);
_solvers.resize(numDevices); _solvers.resize(numDevices);
for (int e = 0; e < numDevices; e++) { for (int e = 0; e < numDevices; e++) {
AffinityManager::setCurrentDevice(e); AffinityManager::setCurrentNativeDevice(e);
_cache[e] = handle_(); _cache[e] = handle_();
_solvers[e] = solver_(); _solvers[e] = solver_();
} }
// don't forget to restore back original device // don't forget to restore back original device
AffinityManager::setCurrentDevice(currentDevice); AffinityManager::setCurrentNativeDevice(currentDevice);
} }
CublasHelper::~CublasHelper() { CublasHelper::~CublasHelper() {
nd4j_printf("Releasing cuBLAS\n","");
auto numDevices = AffinityManager::numberOfDevices(); auto numDevices = AffinityManager::numberOfDevices();
for (int e = 0; e < numDevices; e++) for (int e = 0; e < numDevices; e++)
@ -79,8 +82,10 @@ namespace nd4j {
} }
CublasHelper* CublasHelper::getInstance() { CublasHelper* CublasHelper::getInstance() {
_mutex.lock();
if (!_INSTANCE) if (!_INSTANCE)
_INSTANCE = new nd4j::CublasHelper(); _INSTANCE = new nd4j::CublasHelper();
_mutex.unlock();
return _INSTANCE; return _INSTANCE;
} }

View File

@ -162,9 +162,6 @@ namespace functions {
} }
} }
} }
// update rng state
rng->rewindH(length);
}; };
@ -223,8 +220,6 @@ namespace functions {
} }
} }
} }
// update rng state
rng->rewindH(length);
} }
@ -256,9 +251,6 @@ namespace functions {
z[offset] = OpClass::op(i+threadOffset, length, rng, extraArguments); z[offset] = OpClass::op(i+threadOffset, length, rng, extraArguments);
} }
} }
// update rng state
rng->rewindH(length);
} }
template<typename X> template<typename X>

View File

@ -15,7 +15,7 @@
******************************************************************************/ ******************************************************************************/
// //
// @author Yurii Shyrma (iuriish@yahoo.com), created on 07.12.2017 // @author Yurii Shyrma (iuriish@yahoo.com)
// //
#include <op_boilerplate.h> #include <op_boilerplate.h>
@ -38,10 +38,9 @@ CONFIGURABLE_OP_IMPL(matrix_set_diag, 2, 1, false, 0, 0) {
for(int i = 0; i < diagonal->rankOf() - 1; ++i) for(int i = 0; i < diagonal->rankOf() - 1; ++i)
REQUIRE_TRUE(diagonal->sizeAt(i) == input->sizeAt(i), 0, "MATRIX_SET_DIAG op: the shapes of diagonal and input arrays must be equal till last diagonal dimension but one, however got diagonal=%s and input=%s instead !", ShapeUtils::shapeAsString(diagonal).c_str(), ShapeUtils::shapeAsString(input).c_str()); REQUIRE_TRUE(diagonal->sizeAt(i) == input->sizeAt(i), 0, "MATRIX_SET_DIAG op: the shapes of diagonal and input arrays must be equal till last diagonal dimension but one, however got diagonal=%s and input=%s instead !", ShapeUtils::shapeAsString(diagonal).c_str(), ShapeUtils::shapeAsString(input).c_str());
REQUIRE_TRUE(diagonal->sizeAt(-1) == (int)nd4j::math::nd4j_min<Nd4jLong>(input->sizeAt(-1), input->sizeAt(-2)), REQUIRE_TRUE(diagonal->sizeAt(-1) == (int)nd4j::math::nd4j_min<Nd4jLong>(input->sizeAt(-1), input->sizeAt(-2)), 0, "MATRIX_SET_DIAG op: the value of last dimension of diagonal array must be equal to min(input_last_shape=%i, input_last_but_one_shape=%i), but got %i instead !", input->sizeAt(-1), input->sizeAt(-2), diagonal->sizeAt(-1));
0, "MATRIX_SET_DIAG op: the value of last dimension of diagonal array must be equal to min(input_last_shape=%i, input_last_but_one_shape=%i), but got %i instead !", input->sizeAt(-1), input->sizeAt(-2), diagonal->sizeAt(-1));
helpers::matrixSetDiag(block.launchContext(), input, diagonal, output); helpers::matrixSetDiag(block.launchContext(), *input, *diagonal, *output, false);
return Status::OK(); return Status::OK();
} }

View File

@ -15,26 +15,30 @@
******************************************************************************/ ******************************************************************************/
// //
// Created to use with batched tensor by GS <sgazeos@gmail.com> 3/21/2018 // @author GS <sgazeos@gmail.com> 3/21/2018
// @author Yurii Shyrma (iuriish@yahoo.com)
// //
#include <ops/declarable/CustomOperations.h> #include <ops/declarable/CustomOperations.h>
#include <ops/declarable/helpers/matrix_diag.h> #include <ops/declarable/helpers/matrixSetDiag.h>
namespace nd4j { namespace nd4j {
namespace ops { namespace ops {
CUSTOM_OP_IMPL(matrix_diag, 1, 1, false, 0, 0) { CUSTOM_OP_IMPL(matrix_diag, 1, 1, false, 0, 0) {
auto input = INPUT_VARIABLE(0);
auto diagonal = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0);
REQUIRE_TRUE(!input->isScalar(), 0, "CUSTOM_OP matrix_diag: input array must be at list a vector, but scalar was given!"); REQUIRE_TRUE(!diagonal->isScalar(), 0, "CUSTOM_OP matrix_diag: input diagonal array must be at list a vector, but scalar was given!");
output->nullify(); helpers::matrixSetDiag(block.launchContext(), *output, *diagonal, *output, true);
return helpers::matrixDiag(block.launchContext(), input, output);
return Status::OK();
} }
DECLARE_SHAPE_FN(matrix_diag) { DECLARE_SHAPE_FN(matrix_diag) {
Nd4jLong* outShapeInfo = nullptr; Nd4jLong* outShapeInfo = nullptr;
auto in = inputShape->at(0); auto in = inputShape->at(0);
int inRank = shape::rank(in); int inRank = shape::rank(in);

View File

@ -76,7 +76,19 @@ namespace nd4j {
#endif #endif
/** /**
* Returns a batched matrix tensor with new batched diagonal values. * Inserts elements provided by diagonal array into the main diagonal of innermost matrices of input array
*
* Input arrays:
* input: input array, considered as batch of matrices
* diagonal: array containing elements to be inserted into input array,
* following rank condition should be satisfied: diagonal_rank = input_rank - 1,
* the shapes of diagonal and input arrays must be equal except last dimension of input array,
* for example if input_shape = [A,B,C,D] then diagonal_shape = [A,B,C],
* also last dimension of diagonal array should be equal to smaller of last and last but one input dimensions
* that is: diagonal_shape[-1] = min(input_shape[-1], input_shape[-2])
*
* Output array:
* has the same shape as input, corresponding diagonal elements are substituted
*/ */
#if NOT_EXCLUDED(OP_matrix_set_diag) #if NOT_EXCLUDED(OP_matrix_set_diag)
DECLARE_CONFIGURABLE_OP(matrix_set_diag, 2, 1, false, 0, 0); DECLARE_CONFIGURABLE_OP(matrix_set_diag, 2, 1, false, 0, 0);

View File

@ -2411,7 +2411,7 @@ void ConvolutionUtils::getMKLDNNMemoryDescConv3d(
for (Nd4jLong kd = dstart; kd < dend; kd += iStep2) for (Nd4jLong kd = dstart; kd < dend; kd += iStep2)
for (Nd4jLong kh = hstart; kh < hend; kh += iStep3) for (Nd4jLong kh = hstart; kh < hend; kh += iStep3)
for (Nd4jLong kw = wstart; kw < wend; kw += iStep4) for (Nd4jLong kw = wstart; kw < wend; kw += iStep4)
pgI[kd + kh + kw] += valO * nd4j::math::nd4j_pow<T,T,T>(nd4j::math::nd4j_abs<T>(pIn[kd + kh + kw]), extraParam0 - (T)1.f); pgI[kd + kh + kw] += valO * nd4j::math::nd4j_pow<T,T,T>(nd4j::math::nd4j_abs<T>(pIn[kd + kh + kw]), extraParam0 - (T)1.f) * nd4j::math::nd4j_sgn<T,T>(pIn[kd + kh + kw]);
} }
else { else {

View File

@ -15,7 +15,7 @@
******************************************************************************/ ******************************************************************************/
// //
// Created by Yurii Shyrma on 07.12.2017. // @author Yurii Shyrma (iuriish@yahoo.com)
// //
#include "ResultSet.h" #include "ResultSet.h"
@ -27,32 +27,49 @@ namespace helpers {
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
// Returns a batched matrix tensor with new batched diagonal values.
// for detailed explanations please take a look on web page: https://www.tensorflow.org/api_docs/python/tf/matrix_set_diag
template<typename T> template<typename T>
static void _matrixSetDiag(const NDArray* input, const NDArray* diagonal, NDArray* output) { void matrixSetDiag_(const NDArray& input, const NDArray& diagonal, NDArray& output, const bool zeroPad) {
*output = *input; // input and output are the same array (x == z) when zeroPad = true
// xRank = zRank, xRank = yRank + 1
// xLen = zLen
const int lastDimSize = input->sizeAt(-1); const T* x = input.bufferAsT<T>();
const int last2DimSize = input->sizeAt(-1) * input->sizeAt(-2); const T* y = diagonal.bufferAsT<T>();
const int lastSmallDim = diagonal->sizeAt(-1); T* z = output.bufferAsT<T>();
const int batchSize = input->lengthOf()/last2DimSize;
for(int i = 0; i < batchSize; ++i ) const Nd4jLong* xShapeInfo = input.getShapeInfo();
for(int j = 0; j < lastSmallDim; ++j) { const Nd4jLong* yShapeInfo = diagonal.getShapeInfo();
output->p(i*last2DimSize + j*(lastDimSize + 1), diagonal->e<T>(i*lastSmallDim + j)); const Nd4jLong* zShapeInfo = output.getShapeInfo();
const bool areSameOffsets = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); // shapes are definitely the same, but strides might not
const int xRank = input.rankOf();
const auto xLen = input.lengthOf();
std::vector<Nd4jLong> coords(xRank); // we use the same coordinates storage both for input and output since their ranks are the same
PRAGMA_OMP_PARALLEL_FOR_ARGS(firstprivate(coords))
for (Nd4jLong i = 0; i < xLen; ++i) {
shape::index2coords(xRank, xShapeInfo + 1, i, xLen, coords.data());
const auto xOffset = shape::getOffset(0, xShapeInfo + 1, xShapeInfo + xRank + 1, coords.data(), xRank);
const auto zOffset = areSameOffsets ? xOffset : shape::getOffset(0, zShapeInfo + 1, zShapeInfo + xRank + 1, coords.data(), xRank);
// condition to be on diagonal of innermost matrix
if(coords[xRank - 2] == coords[xRank - 1])
z[zOffset] = y[shape::getOffset(0, yShapeInfo + 1, yShapeInfo + xRank, coords.data(), xRank - 1)];
else
z[zOffset] = zeroPad ? static_cast<T>(0) : x[xOffset];
}
} }
//////////////////////////////////////////////////////////////////////////
void matrixSetDiag(nd4j::LaunchContext* context, const NDArray& input, const NDArray& diagonal, NDArray& output, const bool zeroPad) {
BUILD_SINGLE_SELECTOR(input.dataType(), matrixSetDiag_, (input, diagonal, output, zeroPad), LIBND4J_TYPES);
} }
void matrixSetDiag(nd4j::LaunchContext * context, const NDArray* input, const NDArray* diagonal, NDArray* output) {
BUILD_SINGLE_SELECTOR(input->dataType(), _matrixSetDiag, (input, diagonal, output), LIBND4J_TYPES);
}
BUILD_SINGLE_TEMPLATE(template void _matrixSetDiag, (const NDArray* input, const NDArray* diagonal, NDArray* output), LIBND4J_TYPES);
} }
} }
} }

View File

@ -1,65 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// Created by GS <sgazeos@gmail.com> on 3/21/2018.
//
#include "ResultSet.h"
#include <ops/declarable/helpers/matrix_diag.h>
#include <Status.h>
namespace nd4j {
namespace ops {
namespace helpers {
//////////////////////////////////////////////////////////////////////////
// Returns a batched matrix tensor with new batched diagonal values.
// for detailed explanations please take a look on web page: https://www.tensorflow.org/api_docs/python/tf/matrix_set_diag
template <typename T>
static int _matrixDiag(const NDArray* input, NDArray* output) {
auto listOut = output->allTensorsAlongDimension({output->rankOf() - 2, output->rankOf() - 1});
auto listDiag = input->allTensorsAlongDimension({input->rankOf() - 1});
if (listOut->size() != listDiag->size()) {
nd4j_printf("matrix_diag: Input matrix has wrong shape.", "");
return ND4J_STATUS_VALIDATION;
}
int lastDimension = input->sizeAt(-1);
// TODO: tune this properlys
int lO = listOut->size();
PRAGMA_OMP_PARALLEL_FOR_IF(lO > Environment::getInstance()->tadThreshold())
for(int i = 0; i < lO; ++i)
for (int e = 0; e < lastDimension; e++)
listOut->at(i)->p(e, e, listDiag->at(i)->e<T>(e));
delete listOut;
delete listDiag;
return Status::OK();
}
int matrixDiag(nd4j::LaunchContext * context, const NDArray* input, NDArray* output) {
BUILD_SINGLE_SELECTOR(input->dataType(), return _matrixDiag, (input, output), LIBND4J_TYPES);
}
BUILD_SINGLE_TEMPLATE(template int _matrixDiag, (const NDArray* input, NDArray* output), LIBND4J_TYPES);
}
}
}

View File

@ -957,9 +957,13 @@ __global__ static void pooling2dBPCuda(const void* vx, const Nd4jLong* xShapeInf
val *= nd4j::math::nd4j_pow<T,T,T>(sum, ((T)1.f - extraParam0) / extraParam0); val *= nd4j::math::nd4j_pow<T,T,T>(sum, ((T)1.f - extraParam0) / extraParam0);
for (coords[2] = hstart; coords[2] < hend; coords[2] += dH) for (coords[2] = hstart; coords[2] < hend; coords[2] += dH) {
for (coords[3] = wstart; coords[3] < wend; coords[3] += dW) for (coords[3] = wstart; coords[3] < wend; coords[3] += dW) {
nd4j::math::atomics::nd4j_atomicAdd<T>(&z[shape::getOffset(0, zShapeInfo + 1, zShapeInfo + rank + 1, coords, rank)], val * nd4j::math::nd4j_pow<T,T,T>(nd4j::math::nd4j_abs<T>(x[shape::getOffset(0, xShapeInfo + 1, xShapeInfo + rank + 1, coords, rank)]), extraParam0 - 1.f)); const auto xOffset = shape::getOffset(0, xShapeInfo + 1, xShapeInfo + rank + 1, coords, rank);
const auto zOffset = shape::getOffset(0, zShapeInfo + 1, zShapeInfo + rank + 1, coords, rank);
nd4j::math::atomics::nd4j_atomicAdd<T>(&z[zOffset], val * nd4j::math::nd4j_pow<T,T,T>(nd4j::math::nd4j_abs<T>(x[xOffset]), extraParam0 - 1.f) * nd4j::math::nd4j_sgn<T,T>(x[xOffset]));
}
}
} }
break; break;
} }
@ -1123,10 +1127,15 @@ __global__ static void pooling3dBPCuda(const void* vx, const Nd4jLong* xShapeInf
val *= nd4j::math::nd4j_pow<T,T,T>(sum, ((T)1.f - extraParam0) / extraParam0); val *= nd4j::math::nd4j_pow<T,T,T>(sum, ((T)1.f - extraParam0) / extraParam0);
for (coords[2] = dstart; coords[2] < dend; coords[2] += dD) for (coords[2] = dstart; coords[2] < dend; coords[2] += dD) {
for (coords[3] = hstart; coords[3] < hend; coords[3] += dH) for (coords[3] = hstart; coords[3] < hend; coords[3] += dH) {
for (coords[4] = wstart; coords[4] < wend; coords[4] += dW) for (coords[4] = wstart; coords[4] < wend; coords[4] += dW) {
nd4j::math::atomics::nd4j_atomicAdd<T>(&z[shape::getOffset(0, zShapeInfo + 1, zShapeInfo + rank + 1, coords, rank)], val * nd4j::math::nd4j_pow<T,T,T>(nd4j::math::nd4j_abs<T>(x[shape::getOffset(0, xShapeInfo + 1, xShapeInfo + rank + 1, coords, rank)]), extraParam0 - 1.f)); const auto xOffset = shape::getOffset(0, xShapeInfo + 1, xShapeInfo + rank + 1, coords, rank);
const auto zOffset = shape::getOffset(0, zShapeInfo + 1, zShapeInfo + rank + 1, coords, rank);
nd4j::math::atomics::nd4j_atomicAdd<T>(&z[zOffset], val * nd4j::math::nd4j_pow<T,T,T>(nd4j::math::nd4j_abs<T>(x[xOffset]), extraParam0 - 1.f) * nd4j::math::nd4j_sgn<T,T>(x[xOffset]));
}
}
}
} }
break; break;
} }

View File

@ -15,63 +15,87 @@
******************************************************************************/ ******************************************************************************/
// //
// Created by Yurii Shyrma on 07.12.2017. // @author Yurii Shyrma (iuriish@yahoo.com)
// //
#include "ResultSet.h" #include "ResultSet.h"
#include <ops/declarable/helpers/matrixSetDiag.h> #include <ops/declarable/helpers/matrixSetDiag.h>
#include <PointersManager.h>
namespace nd4j { namespace nd4j {
namespace ops { namespace ops {
namespace helpers { namespace helpers {
///////////////////////////////////////////////////////////////////
template<typename T> template<typename T>
static __global__ void matrixSetDiagKernel(void* outputBuffer, Nd4jLong* outputShape, void const* diagonalBuffer, Nd4jLong* diagonalShape, Nd4jLong lastDimSize, Nd4jLong last2DimSize, Nd4jLong lastSmallDim, Nd4jLong batchSize) { __global__ static void matrixSetDiagCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vy, const Nd4jLong* yShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const bool zeroPad) {
__shared__ T* z;
__shared__ T const* x; // x - input, shape [A,B,C]
__shared__ Nd4jLong outLength, diagonalLen; // y - diagonal, shape [A,B]
// z - output, shape [A,B,C]
// input and output are the same array (x == z) when zeroPad = true
const auto x = reinterpret_cast<const T*>(vx);
const auto y = reinterpret_cast<const T*>(vy);
auto z = reinterpret_cast<T*>(vz);
__shared__ int xRank; // xRank = zRank, xRank = yRank + 1
__shared__ Nd4jLong xLen, *sharedMem; // xLen = zLen
__shared__ bool areSameOffsets;
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
z = reinterpret_cast<T*>(outputBuffer);
x = reinterpret_cast<T const*>(diagonalBuffer); extern __shared__ unsigned char shmem[];
outLength = shape::length(outputShape); sharedMem = reinterpret_cast<Nd4jLong*>(shmem);
diagonalLen = shape::length(diagonalShape);
areSameOffsets = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); // shapes are definitely the same, but strides might not
xRank = shape::rank(xShapeInfo);
xLen = shape::length(xShapeInfo);
} }
__syncthreads(); __syncthreads();
for(int i = blockIdx.x; i < batchSize; i+= gridDim.x ) auto coords = sharedMem + threadIdx.x * xRank; // we provide (xRank * sizeof(Nd4jLong) * threadIdx.x) amount of shared memory per each thread
for(int j = threadIdx.x; j < lastSmallDim; j += blockDim.x) { const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
// z[i * last2DimSize + j * (lastDimSize + 1)] = x[i * lastSmallDim + j];
z[shape::getIndexOffset(i * last2DimSize + j * (lastDimSize + 1), outputShape, outLength)] = x[shape::getIndexOffset(i * lastSmallDim + j, diagonalShape, diagonalLen)]; for (Nd4jLong i = tid; i < xLen; i += gridDim.x * blockDim.x) {
shape::index2coords(xRank, xShapeInfo + 1, i, xLen, coords);
const auto xOffset = shape::getOffset(0, xShapeInfo + 1, xShapeInfo + xRank + 1, coords, xRank);
const auto zOffset = areSameOffsets ? xOffset : shape::getOffset(0, zShapeInfo + 1, zShapeInfo + xRank + 1, coords, xRank);
// condition to be on diagonal of innermost matrix
if(coords[xRank - 2] == coords[xRank - 1])
z[zOffset] = y[shape::getOffset(0, yShapeInfo + 1, yShapeInfo + xRank, coords, xRank - 1)];
else
z[zOffset] = zeroPad ? static_cast<T>(0) : x[xOffset];
} }
} }
//////////////////////////////////////////////////////////////////////////
// Returns a batched matrix tensor with new batched diagonal values. ///////////////////////////////////////////////////////////////////
// for detailed explanations please take a look on web page: https://www.tensorflow.org/api_docs/python/tf/matrix_set_diag
template<typename T> template<typename T>
static void _matrixSetDiag(nd4j::LaunchContext * context, const NDArray* input, const NDArray* diagonal, NDArray* output) { static void matrixSetDiagCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, const void* vy, const Nd4jLong* yShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const bool zeroPad) {
*output = *input;
const int lastDimSize = input->sizeAt(-1);
const int last2DimSize = input->sizeAt(-1) * input->sizeAt(-2);
const int lastSmallDim = diagonal->sizeAt(-1);
const int batchSize = input->lengthOf()/last2DimSize;
auto stream = context->getCudaStream();
dim3 launchDims(256, 512, 8192);
matrixSetDiagKernel<T><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(output->specialBuffer(), output->specialShapeInfo(), diagonal->getSpecialBuffer(), diagonal->getSpecialShapeInfo(), lastDimSize, last2DimSize, lastSmallDim, batchSize);
//// #pragma omp parallel for if(batchSize > Environment::getInstance()->elementwiseThreshold()) schedule(static)
// for(int i = 0; i < batchSize; ++i )
// for(int j = 0; j < lastSmallDim; ++j) {
// output->p(i*last2DimSize + j*(lastDimSize + 1), diagonal->e<T>(i*lastSmallDim + j));
// }
matrixSetDiagCuda<T><<<blocksPerGrid, threadsPerBlock, sharedMem, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, zeroPad);
} }
void matrixSetDiag(nd4j::LaunchContext * context, const NDArray* input, const NDArray* diagonal, NDArray* output) { ///////////////////////////////////////////////////////////////////
BUILD_SINGLE_SELECTOR(input->dataType(), _matrixSetDiag, (context, input, diagonal, output), LIBND4J_TYPES); void matrixSetDiag(nd4j::LaunchContext* context, const NDArray& input, const NDArray& diagonal, NDArray& output, const bool zeroPad) {
}
BUILD_SINGLE_TEMPLATE(template void _matrixSetDiag, (nd4j::LaunchContext * context, const NDArray* input, const NDArray* diagonal, NDArray* output), LIBND4J_TYPES); const int threadsPerBlock = MAX_NUM_THREADS / 2;
const int blocksPerGrid = (input.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
const int sharedMem = threadsPerBlock * sizeof(Nd4jLong) * input.rankOf() + 128;
PointersManager manager(context, "matrixSetDiag");
NDArray::prepareSpecialUse({&output}, {&input, &diagonal});
BUILD_SINGLE_SELECTOR(input.dataType(), matrixSetDiagCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), diagonal.getSpecialBuffer(), diagonal.getSpecialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), zeroPad), LIBND4J_TYPES);
NDArray::registerSpecialUse({&output}, {&input, &diagonal});
manager.synchronize();
}
} }
} }

View File

@ -1,95 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// Created by GS <sgazeos@gmail.com> on 3/21/2018.
//
#include "ResultSet.h"
#include <ops/declarable/helpers/matrix_diag.h>
#include <Status.h>
#include <ShapeUtils.h>
#include <ShapeUtils.h>
#include <TAD.h>
#include <cuda_exception.h>
#include <helpers/ConstantTadHelper.h>
namespace nd4j {
namespace ops {
namespace helpers {
template <typename T>
static __global__ void matrixDiagKernel(void const* inputBuffer, void* outputBuffer, Nd4jLong numTads, Nd4jLong inputLength,
Nd4jLong* tadOnlyInputShapeInfo, Nd4jLong *tadInputOffsets,
Nd4jLong* tadOnlyOutputShapeInfo, Nd4jLong *tadOutputOffsets) {
int totalThreads = blockDim.x;
for (Nd4jLong i = blockIdx.x; i < numTads; i += gridDim.x) {
auto yOffset = tadInputOffsets[i];
auto xOffset = tadOutputOffsets[i];
for (Nd4jLong j = threadIdx.x; j < inputLength; j += totalThreads) {
Nd4jLong coords[2] = {j, j};
Nd4jLong tadOffset = shape::getOffset(0, shape::shapeOf(tadOnlyOutputShapeInfo), shape::stride(tadOnlyOutputShapeInfo), coords, 2);
//shape::getIndexOffset(j, tadOnlyOutputShapeInfo, inputLength)
*(reinterpret_cast<T*>(outputBuffer) + xOffset + tadOffset) = *(reinterpret_cast<T const*>(inputBuffer) + yOffset + shape::getIndexOffset(j, tadOnlyInputShapeInfo, inputLength));
}
}
}
//////////////////////////////////////////////////////////////////////////
// Returns a batched matrix tensor with new batched diagonal values.
// for detailed explanations please take a look on web page: https://www.tensorflow.org/api_docs/python/tf/matrix_set_diag
template <typename T>
static int _matrixDiag(nd4j::LaunchContext * context, const NDArray* input, NDArray* output) {
cudaStream_t* stream = context->getCudaStream();
//auto listOut = output->allTensorsAlongDimension({output->rankOf() - 2, output->rankOf() - 1});
//auto listDiag = input->allTensorsAlongDimension({input->rankOf() - 1});
//auto repeatDelta = shape::prodLong(newShape.data(), rank) / this->lengthOf();
std::vector<int> dimsToExclude = ShapeUtils::evalDimsToExclude(input->rankOf(), {input->rankOf() - 1});
const Nd4jLong numTads = ShapeUtils::getNumOfSubArrs(input->getShapeInfo(), dimsToExclude); //this->tensorsAlongDimension({dimension});
//printf("Repeat delta %lld, numTads %lld\n", repeatDelta, numTads);
//tadOnlyInputShapeInfo, tadInputOffsets, tadOnlyOutputShapeInfo, tadOutputOffsets;
std::vector<int> inputDims({input->rankOf() - 1});
std::vector<int> outputDims({output->rankOf() - 2, output->rankOf() - 1});
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), inputDims);
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), outputDims);
if (!input->isActualOnDeviceSide())
input->syncToDevice();
if (!output->isActualOnDeviceSide())
output->syncToDevice();
// create cuda stream and LaunchContext
cudaError_t cudaResult;
dim3 launchDims(256, 512, 8192);
matrixDiagKernel<T><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(input->getSpecialBuffer(), output->getSpecialBuffer(), numTads, input->sizeAt(-1), packX.specialShapeInfo(), packX.specialOffsets(), packZ.specialShapeInfo(), packZ.specialOffsets());
return Status::OK();
}
int matrixDiag(nd4j::LaunchContext * context, const NDArray* input, NDArray* output) {
BUILD_SINGLE_SELECTOR(input->dataType(), return _matrixDiag, (context, input, output), LIBND4J_TYPES);
}
BUILD_SINGLE_TEMPLATE(template int _matrixDiag, (nd4j::LaunchContext * context, const NDArray* input, NDArray* output), LIBND4J_TYPES);
}
}
}

View File

@ -28,8 +28,7 @@ namespace nd4j {
namespace ops { namespace ops {
namespace helpers { namespace helpers {
void matrixSetDiag(nd4j::LaunchContext * context, const NDArray* input, const NDArray* diagonal, NDArray* output); void matrixSetDiag(nd4j::LaunchContext* context, const NDArray& input, const NDArray& diagonal, NDArray& output, const bool zeroPad);
} }
} }

View File

@ -1,34 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author GS <sgazeos@gmail.com>
//
#ifndef __MATRIX_DIAG_HELPERS__
#define __MATRIX_DIAG_HELPERS__
#include <op_boilerplate.h>
#include <NDArray.h>
namespace nd4j {
namespace ops {
namespace helpers {
int matrixDiag(nd4j::LaunchContext * context, NDArray const* input, NDArray* output);
}
}
}
#endif

View File

@ -117,9 +117,9 @@ TEST_F(DeclarableOpsTests3, Test_Unique_1) {
auto v = result->at(0); auto v = result->at(0);
auto i = result->at(1); auto i = result->at(1);
v->printIndexedBuffer("Values"); // v->printIndexedBuffer("Values");
i->printIndexedBuffer("Indices"); // i->printIndexedBuffer("Indices");
i->printShapeInfo("Indices shape"); // i->printShapeInfo("Indices shape");
ASSERT_TRUE(expV.isSameShape(v)); ASSERT_TRUE(expV.isSameShape(v));
ASSERT_TRUE(expV.equalsTo(v)); ASSERT_TRUE(expV.equalsTo(v));
@ -145,12 +145,12 @@ TEST_F(DeclarableOpsTests3, Test_Unique_2) {
auto i = result->at(1); auto i = result->at(1);
auto c = result->at(2); auto c = result->at(2);
v->printShapeInfo(); // v->printShapeInfo();
v->printIndexedBuffer("Values"); // v->printIndexedBuffer("Values");
i->printShapeInfo(); // i->printShapeInfo();
i->printIndexedBuffer("Indices"); // i->printIndexedBuffer("Indices");
c->printShapeInfo(); // c->printShapeInfo();
c->printIndexedBuffer("Counts"); // c->printIndexedBuffer("Counts");
ASSERT_TRUE(expV.isSameShape(v)); ASSERT_TRUE(expV.isSameShape(v));
ASSERT_TRUE(expV.equalsTo(v)); ASSERT_TRUE(expV.equalsTo(v));
@ -200,11 +200,11 @@ TEST_F(DeclarableOpsTests3, Test_Norm_1) {
auto result1 = op.execute({&x}, {1.}, {1}); auto result1 = op.execute({&x}, {1.}, {1});
ASSERT_EQ(result1->status(), ND4J_STATUS_OK); ASSERT_EQ(result1->status(), ND4J_STATUS_OK);
auto z1 = result1->at(0); auto z1 = result1->at(0);
z1->printIndexedBuffer("Z1"); // z1->printIndexedBuffer("Z1");
auto exp1 = x.reduceAlongDims(reduce::Norm2, dims, false, false); auto exp1 = x.reduceAlongDims(reduce::Norm2, dims, false, false);
exp1.printIndexedBuffer("EXP1"); // exp1.printIndexedBuffer("EXP1");
z1->printShapeInfo("Z1 shape"); // z1->printShapeInfo("Z1 shape");
exp1.printShapeInfo("EXP1 shape"); // exp1.printShapeInfo("EXP1 shape");
ASSERT_TRUE(exp1.isSameShape(z1)); ASSERT_TRUE(exp1.isSameShape(z1));
ASSERT_TRUE(exp1.equalsTo(z1)); ASSERT_TRUE(exp1.equalsTo(z1));
@ -714,7 +714,7 @@ TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_7) {
auto exp = MmulHelper::mmul(&x, &y); auto exp = MmulHelper::mmul(&x, &y);
exp->printShapeInfo("exp shape"); // exp->printShapeInfo("exp shape");
nd4j::ops::batched_gemm op; nd4j::ops::batched_gemm op;
auto result = op.execute({&a, &b, &x, &x, &x, &y, &y, &y}, {}, {112, 112, 2, 3, 5, 5, 3, 2, 3}); auto result = op.execute({&a, &b, &x, &x, &x, &y, &y, &y}, {}, {112, 112, 2, 3, 5, 5, 3, 2, 3});

View File

@ -79,7 +79,7 @@ TEST_F(SortCudaTests, test_linear_sort_by_val_2) {
sortByValue(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), true); sortByValue(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), true);
k.tickWriteDevice(); k.tickWriteDevice();
v.tickWriteDevice(); v.tickWriteDevice();
k.printIndexedBuffer("KEYS"); // k.printIndexedBuffer("KEYS");
ASSERT_EQ(ek, k); ASSERT_EQ(ek, k);
ASSERT_EQ(ev, v); ASSERT_EQ(ev, v);
} }
@ -98,8 +98,8 @@ TEST_F(SortCudaTests, test_tad_sort_by_key_1) {
k.tickWriteDevice(); k.tickWriteDevice();
v.tickWriteDevice(); v.tickWriteDevice();
k.printIndexedBuffer("k"); // k.printIndexedBuffer("k");
v.printIndexedBuffer("v"); // v.printIndexedBuffer("v");
ASSERT_EQ(ek, k); ASSERT_EQ(ek, k);
ASSERT_EQ(ev, v); ASSERT_EQ(ev, v);

View File

@ -1562,8 +1562,8 @@ public class DifferentialFunctionFactory {
} }
public SDVariable eluBp(SDVariable in, SDVariable epsilon) { public SDVariable eluBp(SDVariable in, SDVariable epsilon, double alpha) {
return new EluBp(sameDiff(), in, epsilon).outputVariable(); return new EluBp(sameDiff(), in, epsilon, alpha).outputVariable();
} }

View File

@ -18,14 +18,12 @@ package org.nd4j.linalg.activations.impl;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;
import lombok.Getter; import lombok.Getter;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.EluBp;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.activations.BaseActivationFunction; import org.nd4j.linalg.activations.BaseActivationFunction;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.EluBp;
import org.nd4j.linalg.api.ops.impl.transforms.strict.ELU; import org.nd4j.linalg.api.ops.impl.transforms.strict.ELU;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.BooleanIndexing; import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.indexing.conditions.Conditions;
/** /**
* f(x) = alpha * (exp(x) - 1.0); x < 0 * f(x) = alpha * (exp(x) - 1.0); x < 0
@ -55,15 +53,7 @@ public class ActivationELU extends BaseActivationFunction {
*/ */
@Override @Override
public INDArray getActivation(INDArray in, boolean training) { public INDArray getActivation(INDArray in, boolean training) {
// no support in ELU native to override alpha return Nd4j.exec(new ELU(in, in, alpha))[0];
if (this.alpha != 1.00) {
INDArray alphaMultiple = Nd4j.getExecutioner().exec(new ELU(in.dup()))[0];
alphaMultiple.muli(alpha);
BooleanIndexing.replaceWhere(in, alphaMultiple, Conditions.lessThan(0));
} else {
Nd4j.getExecutioner().execAndReturn(new ELU(in));
}
return in;
} }
/* /*

View File

@ -1195,7 +1195,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
} }
} }
return this; return this;
} }
@ -3089,12 +3088,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return mmuli(other, result); return mmuli(other, result);
} }
/**
* in place (element wise) division of two matrices
*
* @param other the second ndarray to divide
* @return the result of the divide
*/
@Override @Override
public INDArray div(INDArray other) { public INDArray div(INDArray other) {
if (Shape.areShapesBroadcastable(this.shape(), other.shape())) { if (Shape.areShapesBroadcastable(this.shape(), other.shape())) {
@ -3104,25 +3097,12 @@ public abstract class BaseNDArray implements INDArray, Iterable {
} }
} }
/**
* copy (element wise) division of two matrices
*
* @param other the second ndarray to divide
* @param result the result ndarray
* @return the result of the divide
*/
@Override @Override
public INDArray div(INDArray other, INDArray result) { public INDArray div(INDArray other, INDArray result) {
validateNumericalArray("div", true); validateNumericalArray("div", true);
return divi(other, result); return divi(other, result);
} }
/**
* copy (element wise) multiplication of two matrices
*
* @param other the second ndarray to multiply
* @return the result of the addition
*/
@Override @Override
public INDArray mul(INDArray other) { public INDArray mul(INDArray other) {
validateNumericalArray("mul", false); validateNumericalArray("mul", false);
@ -3134,24 +3114,11 @@ public abstract class BaseNDArray implements INDArray, Iterable {
} }
} }
/**
* copy (element wise) multiplication of two matrices
*
* @param other the second ndarray to multiply
* @param result the result ndarray
* @return the result of the multiplication
*/
@Override @Override
public INDArray mul(INDArray other, INDArray result) { public INDArray mul(INDArray other, INDArray result) {
return muli(other, result); return muli(other, result);
} }
/**
* copy subtraction of two matrices
*
* @param other the second ndarray to subtract
* @return the result of the addition
*/
@Override @Override
public INDArray sub(INDArray other) { public INDArray sub(INDArray other) {
validateNumericalArray("sub", false); validateNumericalArray("sub", false);
@ -3162,24 +3129,11 @@ public abstract class BaseNDArray implements INDArray, Iterable {
} }
} }
/**
* copy subtraction of two matrices
*
* @param other the second ndarray to subtract
* @param result the result ndarray
* @return the result of the subtraction
*/
@Override @Override
public INDArray sub(INDArray other, INDArray result) { public INDArray sub(INDArray other, INDArray result) {
return subi(other, result); return subi(other, result);
} }
/**
* copy addition of two matrices
*
* @param other the second ndarray to add
* @return the result of the addition
*/
@Override @Override
public INDArray add(INDArray other) { public INDArray add(INDArray other) {
validateNumericalArray("add", false); validateNumericalArray("add", false);
@ -3190,65 +3144,29 @@ public abstract class BaseNDArray implements INDArray, Iterable {
} }
} }
/**
* copy addition of two matrices
*
* @param other the second ndarray to add
* @param result the result ndarray
* @return the result of the addition
*/
@Override @Override
public INDArray add(INDArray other, INDArray result) { public INDArray add(INDArray other, INDArray result) {
validateNumericalArray("add", false); validateNumericalArray("add", false);
return addi(other, result); return addi(other, result);
} }
/**
* Perform an copy matrix multiplication
*
* @param other the other matrix to perform matrix multiply with
* @param transpose the transpose status of each ndarray
* @return the result of the matrix multiplication
*/
@Override @Override
public INDArray mmuli(INDArray other, MMulTranspose transpose) { public INDArray mmuli(INDArray other, MMulTranspose transpose) {
validateNumericalArray("mmuli", false); validateNumericalArray("mmuli", false);
return dup().mmuli(other, this,transpose); return dup().mmuli(other, this,transpose);
} }
/**
* Perform an copy matrix multiplication
*
* @param other the other matrix to perform matrix multiply with
* @return the result of the matrix multiplication
*/
@Override @Override
public INDArray mmuli(INDArray other) { public INDArray mmuli(INDArray other) {
validateNumericalArray("mmuli", false); validateNumericalArray("mmuli", false);
return dup().mmuli(other, this); return dup().mmuli(other, this);
} }
/**
* Perform an in place matrix multiplication
*
* @param other the other matrix to perform matrix multiply with
* @param result the result ndarray
* @return the result of the matrix multiplication
*/
@Override @Override
public INDArray mmuli(INDArray other, INDArray result, MMulTranspose transpose) { public INDArray mmuli(INDArray other, INDArray result, MMulTranspose transpose) {
return transpose.exec(this, other, result); return transpose.exec(this, other, result);
} }
/**
* Perform an copy matrix multiplication
*
* @param other the other matrix to perform matrix multiply with
* @param result the result ndarray
* @return the result of the matrix multiplication
*/
@Override @Override
public INDArray mmuli(INDArray other, INDArray result) { public INDArray mmuli(INDArray other, INDArray result) {
validateNumericalArray("mmuli", false); validateNumericalArray("mmuli", false);
@ -3347,24 +3265,11 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return Nd4j.create(shape, stride); return Nd4j.create(shape, stride);
} }
/**
* in place (element wise) division of two matrices
*
* @param other the second ndarray to divide
* @return the result of the divide
*/
@Override @Override
public INDArray divi(INDArray other) { public INDArray divi(INDArray other) {
return divi(other, this); return divi(other, this);
} }
/**
* in place (element wise) division of two matrices
*
* @param other the second ndarray to divide
* @param result the result ndarray
* @return the result of the divide
*/
@Override @Override
public INDArray divi(INDArray other, INDArray result) { public INDArray divi(INDArray other, INDArray result) {
validateNumericalArray("divi", false); validateNumericalArray("divi", false);
@ -3373,24 +3278,11 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return result; return result;
} }
/**
* in place (element wise) multiplication of two matrices
*
* @param other the second ndarray to multiply
* @return the result of the multiplication
*/
@Override @Override
public INDArray muli(INDArray other) { public INDArray muli(INDArray other) {
return muli(other, this); return muli(other, this);
} }
/**
* in place (element wise) multiplication of two matrices
*
* @param other the second ndarray to multiply
* @param result the result ndarray
* @return the result of the multiplication
*/
@Override @Override
public INDArray muli(INDArray other, INDArray result) { public INDArray muli(INDArray other, INDArray result) {
validateNumericalArray("muli", false); validateNumericalArray("muli", false);
@ -3399,12 +3291,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return result; return result;
} }
/**
* in place subtraction of two matrices
*
* @param other the second ndarray to subtract
* @return the result of the addition
*/
@Override @Override
public INDArray subi(INDArray other) { public INDArray subi(INDArray other) {
return subi(other, this); return subi(other, this);
@ -3425,24 +3311,11 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return result; return result;
} }
/**
* in place addition of two matrices
*
* @param other the second ndarray to add
* @return the result of the addition
*/
@Override @Override
public INDArray addi(INDArray other) { public INDArray addi(INDArray other) {
return addi(other, this); return addi(other, this);
} }
/**
* in place addition of two matrices
*
* @param other the second ndarray to add
* @param result the result ndarray
* @return the result of the addition
*/
@Override @Override
public INDArray addi(INDArray other, INDArray result) { public INDArray addi(INDArray other, INDArray result) {
validateNumericalArray("addi", false); validateNumericalArray("addi", false);
@ -3451,25 +3324,12 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return result; return result;
} }
/**
* Returns the normmax along the specified dimension
*
* @param dimension the dimension to getScalar the norm1 along
* @param keepDims whether to keep reduced dimensions as dimensions of size 1
* @return the norm1 along the specified dimension
*/
@Override @Override
public INDArray normmax(boolean keepDims, int... dimension) { public INDArray normmax(boolean keepDims, int... dimension) {
validateNumericalArray("normmax", false); validateNumericalArray("normmax", false);
return Nd4j.getExecutioner().exec(new NormMax(this, keepDims, dimension)); return Nd4j.getExecutioner().exec(new NormMax(this, keepDims, dimension));
} }
/**
* Returns the normmax along the specified dimension
*
* @param dimension the dimension to getScalar the norm1 along
* @return the norm1 along the specified dimension
*/
@Override @Override
public INDArray normmax(int... dimension) { public INDArray normmax(int... dimension) {
return normmax(false, dimension); return normmax(false, dimension);
@ -4071,49 +3931,23 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return reshape(Nd4j.order(), shape); return reshape(Nd4j.order(), shape);
} }
/**
* Returns the product along a given dimension
*
* @param dimension the dimension to getScalar the product along
* @param keepDims whether to keep reduced dimensions as dimensions of size 1
* @return the product along the specified dimension
*/
@Override @Override
public INDArray prod(boolean keepDims, int... dimension) { public INDArray prod(boolean keepDims, int... dimension) {
validateNumericalArray("prod", false); validateNumericalArray("prod", false);
return Nd4j.getExecutioner().exec(new Prod(this, keepDims, dimension)); return Nd4j.getExecutioner().exec(new Prod(this, keepDims, dimension));
} }
/**
* Returns the product along a given dimension
*
* @param dimension the dimension to getScalar the product along
* @return the product along the specified dimension
*/
@Override @Override
public INDArray prod(int... dimension) { public INDArray prod(int... dimension) {
return prod(false, dimension); return prod(false, dimension);
} }
/**
* Returns the overall mean of this ndarray
*
* @param dimension the dimension to getScalar the mean along
* @param keepDims whether to keep reduced dimensions as dimensions of size 1
* @return the mean along the specified dimension of this ndarray
*/
@Override @Override
public INDArray mean(boolean keepDims, int... dimension) { public INDArray mean(boolean keepDims, int... dimension) {
validateNumericalArray("mean", false); validateNumericalArray("mean", false);
return Nd4j.getExecutioner().exec(new Mean(this, keepDims, dimension)); return Nd4j.getExecutioner().exec(new Mean(this, keepDims, dimension));
} }
/**
* Returns the overall mean of this ndarray
*
* @param dimension the dimension to getScalar the mean along
* @return the mean along the specified dimension of this ndarray
*/
@Override @Override
public INDArray mean(int... dimension) { public INDArray mean(int... dimension) {
return mean(false, dimension); return mean(false, dimension);
@ -4136,50 +3970,24 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return mean(result, false, dimension); return mean(result, false, dimension);
} }
/**
* Returns the overall variance of this ndarray
*
* @param dimension the dimension to getScalar the mean along
* @return the mean along the specified dimension of this ndarray
*/
@Override @Override
public INDArray var(int... dimension) { public INDArray var(int... dimension) {
validateNumericalArray("var", false); validateNumericalArray("var", false);
return Nd4j.getExecutioner().exec(new Variance(this, dimension)); return Nd4j.getExecutioner().exec(new Variance(this, dimension));
} }
/**
* Returns the overall variance of this ndarray
*
* @param biasCorrected boolean on whether to apply corrected bias
* @param dimension the dimension to getScalar the mean along
* @return the mean along the specified dimension of this ndarray
*/
@Override @Override
public INDArray var(boolean biasCorrected, int... dimension) { public INDArray var(boolean biasCorrected, int... dimension) {
validateNumericalArray("var", false); validateNumericalArray("var", false);
return Nd4j.getExecutioner().exec(new Variance(this, biasCorrected, dimension)); return Nd4j.getExecutioner().exec(new Variance(this, biasCorrected, dimension));
} }
/**
* Returns the overall max of this ndarray
*
* @param dimension the dimension to getScalar the mean along
* @param keepDims whether to keep reduced dimensions as dimensions of size 1
* @return the mean along the specified dimension of this ndarray
*/
@Override @Override
public INDArray max(boolean keepDims, int... dimension) { public INDArray max(boolean keepDims, int... dimension) {
validateNumericalArray("max", false); validateNumericalArray("max", false);
return Nd4j.getExecutioner().exec(new Max(this, keepDims, dimension)); return Nd4j.getExecutioner().exec(new Max(this, keepDims, dimension));
} }
/**
* Returns the overall max of this ndarray
*
* @param dimension the dimension to getScalar the mean along
* @return the mean along the specified dimension of this ndarray
*/
@Override @Override
public INDArray max(int... dimension) { public INDArray max(int... dimension) {
return max(false, dimension); return max(false, dimension);
@ -4191,25 +3999,12 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return Nd4j.getExecutioner().exec(new AMax(this, dimension)); return Nd4j.getExecutioner().exec(new AMax(this, dimension));
} }
/**
* Returns the overall min of this ndarray
*
* @param dimension the dimension to getScalar the mean along
* @param keepDims whether to keep reduced dimensions as dimensions of size 1
* @return the mean along the specified dimension of this ndarray
*/
@Override @Override
public INDArray min(boolean keepDims, int... dimension) { public INDArray min(boolean keepDims, int... dimension) {
validateNumericalArray("min", false); validateNumericalArray("min", false);
return Nd4j.getExecutioner().exec(new Min(this, keepDims, dimension)); return Nd4j.getExecutioner().exec(new Min(this, keepDims, dimension));
} }
/**
* Returns the overall min of this ndarray
*
* @param dimension the dimension to getScalar the mean along
* @return the mean along the specified dimension of this ndarray
*/
@Override @Override
public INDArray min(int... dimension) { public INDArray min(int... dimension) {
return min(false, dimension); return min(false, dimension);
@ -4290,39 +4085,17 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return sum(result, false, dimension); return sum(result, false, dimension);
} }
/**
* Returns the norm1 along the specified dimension
*
* @param dimension the dimension to getScalar the norm1 along
* @return the norm1 along the specified dimension
*/
@Override @Override
public INDArray norm1(int... dimension) { public INDArray norm1(int... dimension) {
return norm1(false, dimension); return norm1(false, dimension);
} }
/**
* Returns the norm1 along the specified dimension
*
* @param dimension the dimension to getScalar the norm1 along
* @param keepDims whether to keep reduced dimensions as dimensions of size 1
* @return the norm1 along the specified dimension
*/
@Override @Override
public INDArray norm1(boolean keepDims, int... dimension) { public INDArray norm1(boolean keepDims, int... dimension) {
validateNumericalArray("norm1", false); validateNumericalArray("norm1", false);
return Nd4j.getExecutioner().exec(new Norm1(this, keepDims, dimension)); return Nd4j.getExecutioner().exec(new Norm1(this, keepDims, dimension));
} }
/**
* Standard deviation of an ndarray along a dimension
*
* @param dimension the dimension to getScalar the std along
* @return the standard deviation along a particular dimension
*/
@Override @Override
public INDArray std(int... dimension) { public INDArray std(int... dimension) {
return std(true, dimension); return std(true, dimension);
@ -4345,32 +4118,17 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return Nd4j.getExecutioner().exec(new StandardDeviation(this, biasCorrected)).getDouble(0); return Nd4j.getExecutioner().exec(new StandardDeviation(this, biasCorrected)).getDouble(0);
} }
/**
* Returns the norm2 along the specified dimension
*
* @param dimension the dimension to getScalar the norm2 along
* @param keepDims whether to keep reduced dimensions as dimensions of size 1
* @return the norm2 along the specified dimension
*/
@Override @Override
public INDArray norm2(boolean keepDims, int... dimension) { public INDArray norm2(boolean keepDims, int... dimension) {
validateNumericalArray("norm2", false); validateNumericalArray("norm2", false);
return Nd4j.getExecutioner().exec(new Norm2(this, keepDims, dimension)); return Nd4j.getExecutioner().exec(new Norm2(this, keepDims, dimension));
} }
/**
* Returns the norm2 along the specified dimension
*
* @param dimension the dimension to getScalar the norm2 along
* @return the norm2 along the specified dimension
*/
@Override @Override
public INDArray norm2(int... dimension) { public INDArray norm2(int... dimension) {
return norm2(false, dimension); return norm2(false, dimension);
} }
/** /**
* Number of columns (shape[1]), throws an exception when * Number of columns (shape[1]), throws an exception when
* called when not 2d * called when not 2d

View File

@ -1232,8 +1232,6 @@ public abstract class BaseSparseNDArray implements ISparseNDArray {
return null; return null;
} }
@Override @Override
public INDArray normmax(boolean keepDims, int... dimension) { public INDArray normmax(boolean keepDims, int... dimension) {
return null; return null;

View File

@ -1404,7 +1404,13 @@ public interface INDArray extends Serializable, AutoCloseable {
*/ */
INDArray add(INDArray other, INDArray result); INDArray add(INDArray other, INDArray result);
/**
* Perform an copy matrix multiplication
*
* @param other the other matrix to perform matrix multiply with
* @param transpose the transpose status of each ndarray
* @return the result of the matrix multiplication
*/
INDArray mmuli(INDArray other, MMulTranspose transpose); INDArray mmuli(INDArray other, MMulTranspose transpose);
/** /**
@ -1415,7 +1421,13 @@ public interface INDArray extends Serializable, AutoCloseable {
*/ */
INDArray mmuli(INDArray other); INDArray mmuli(INDArray other);
/**
* Perform an in place matrix multiplication
*
* @param other the other matrix to perform matrix multiply with
* @param result the result ndarray
* @return the result of the matrix multiplication
*/
INDArray mmuli(INDArray other, INDArray result, MMulTranspose transpose); INDArray mmuli(INDArray other, INDArray result, MMulTranspose transpose);
/** /**
@ -1497,7 +1509,6 @@ public interface INDArray extends Serializable, AutoCloseable {
*/ */
INDArray addi(INDArray other, INDArray result); INDArray addi(INDArray other, INDArray result);
/** /**
* Returns the max norm (aka infinity norm, equal to the maximum absolute value) along the specified dimension(s) * Returns the max norm (aka infinity norm, equal to the maximum absolute value) along the specified dimension(s)
* *
@ -1506,7 +1517,6 @@ public interface INDArray extends Serializable, AutoCloseable {
*/ */
INDArray normmax(int... dimension); INDArray normmax(int... dimension);
/** /**
* Returns the max norm (aka infinity norm, equal to the maximum absolute value) along the specified dimension(s) * Returns the max norm (aka infinity norm, equal to the maximum absolute value) along the specified dimension(s)
* *
@ -1585,7 +1595,7 @@ public interface INDArray extends Serializable, AutoCloseable {
/** /**
* Calculate the standard deviation for the entire array * Calculate the standard deviation for the entire array
* *
* @return * @return standard deviation
*/ */
Number stdNumber(); Number stdNumber();

View File

@ -33,8 +33,9 @@ public class EluBp extends DynamicCustomOp {
public EluBp(){ } public EluBp(){ }
public EluBp(SameDiff sd, SDVariable input, SDVariable gradient){ public EluBp(SameDiff sd, SDVariable input, SDVariable gradient, double alpha){
super(sd, new SDVariable[]{input, gradient}); super(sd, new SDVariable[]{input, gradient});
addTArgument(alpha);
} }
public EluBp(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output) { public EluBp(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output) {

View File

@ -23,13 +23,9 @@ import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Map;
/** /**
* ELU: Exponential Linear Unit (alpha=1.0)<br> * ELU: Exponential Linear Unit (alpha=1.0)<br>
@ -41,19 +37,31 @@ import java.util.Map;
* @author Alex Black * @author Alex Black
*/ */
public class ELU extends DynamicCustomOp { public class ELU extends DynamicCustomOp {
public static final double DEFAULT_ALPHA = 1.0;
protected double alpha;
public ELU(SameDiff sameDiff, SDVariable i_v) { public ELU(SameDiff sameDiff, SDVariable i_v) {
super(sameDiff, new SDVariable[]{i_v}); super(sameDiff, new SDVariable[]{i_v});
this.alpha = DEFAULT_ALPHA;
addTArgument(alpha);
} }
public ELU() { public ELU() {
} }
public ELU(INDArray x, INDArray z) { public ELU(INDArray x, INDArray z) {
this(x, z, DEFAULT_ALPHA);
}
public ELU(INDArray x, INDArray z, double alpha) {
super(null, wrapOrNull(x), wrapOrNull(z)); super(null, wrapOrNull(x), wrapOrNull(z));
this.alpha = alpha;
addTArgument(alpha);
} }
public ELU(INDArray x) { public ELU(INDArray x) {
this(x, null); this(x, null, DEFAULT_ALPHA);
} }
@Override @Override
@ -75,7 +83,7 @@ public class ELU extends DynamicCustomOp {
public List<SDVariable> doDiff(List<SDVariable> i_v) { public List<SDVariable> doDiff(List<SDVariable> i_v) {
//ELU: e^x-1 if x<0, x otherwise //ELU: e^x-1 if x<0, x otherwise
//dL/dIn = dL/Out * dOut/dIn //dL/dIn = dL/Out * dOut/dIn
return Collections.singletonList(f().eluBp(arg(), i_v.get(0))); return Collections.singletonList(f().eluBp(arg(), i_v.get(0), alpha));
} }
@Override @Override

View File

@ -18,6 +18,7 @@ package org.nd4j.jita.allocator.pointers.cuda;
import lombok.Getter; import lombok.Getter;
import lombok.Setter; import lombok.Setter;
import lombok.val;
import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacpp.Pointer;
import org.nd4j.jita.allocator.pointers.CudaPointer; import org.nd4j.jita.allocator.pointers.CudaPointer;
import org.nd4j.linalg.exception.ND4JException; import org.nd4j.linalg.exception.ND4JException;
@ -69,8 +70,9 @@ public class cudaEvent_t extends CudaPointer {
if (res == 0) if (res == 0)
throw new ND4JException("CUDA exception happened. Terminating. Last op: [" + Nd4j.getExecutioner().getLastOp() +"]"); throw new ND4JException("CUDA exception happened. Terminating. Last op: [" + Nd4j.getExecutioner().getLastOp() +"]");
if (NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorCode() != 0) val code = NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorCode();
throw new RuntimeException(NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorMessage()); if (code != 0)
throw new RuntimeException(NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorMessage() + "; Error code: " + code);
} }
} }
@ -78,8 +80,9 @@ public class cudaEvent_t extends CudaPointer {
if (!isDestroyed()) { if (!isDestroyed()) {
int res = NativeOpsHolder.getInstance().getDeviceNativeOps().registerEvent(this, stream); int res = NativeOpsHolder.getInstance().getDeviceNativeOps().registerEvent(this, stream);
if (NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorCode() != 0) val code = NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorCode();
throw new RuntimeException(NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorMessage()); if (code != 0)
throw new RuntimeException(NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorMessage() + "; Error code: " + code);
} }
} }
} }

View File

@ -16,6 +16,7 @@
package org.nd4j.jita.handler.impl; package org.nd4j.jita.handler.impl;
import org.nd4j.nativeblas.OpaqueLaunchContext;
import org.nd4j.shade.guava.collect.HashBasedTable; import org.nd4j.shade.guava.collect.HashBasedTable;
import org.nd4j.shade.guava.collect.Table; import org.nd4j.shade.guava.collect.Table;
import lombok.Getter; import lombok.Getter;
@ -105,6 +106,8 @@ public class CudaZeroHandler implements MemoryHandler {
private final AllocationStatus INITIAL_LOCATION; private final AllocationStatus INITIAL_LOCATION;
private final List<cublasHandle_t> cublasHandles = new ArrayList<>();
private final AffinityManager affinityManager = Nd4j.getAffinityManager(); private final AffinityManager affinityManager = Nd4j.getAffinityManager();
/* /*
@ -162,6 +165,7 @@ public class CudaZeroHandler implements MemoryHandler {
int numDevices = NativeOpsHolder.getInstance().getDeviceNativeOps().getAvailableDevices(); int numDevices = NativeOpsHolder.getInstance().getDeviceNativeOps().getAvailableDevices();
for (int i = 0; i < numDevices; i++) { for (int i = 0; i < numDevices; i++) {
deviceAllocations.add(new ConcurrentHashMap<Long, Long>()); deviceAllocations.add(new ConcurrentHashMap<Long, Long>());
cublasHandles.add(null);
} }
if (NativeOpsHolder.getInstance().getDeviceNativeOps().getDeviceMajor(0) < 3) { if (NativeOpsHolder.getInstance().getDeviceNativeOps().getDeviceMajor(0) < 3) {
@ -1176,6 +1180,25 @@ public class CudaZeroHandler implements MemoryHandler {
return getCudaContext(); return getCudaContext();
} }
//
private final ReentrantReadWriteLock lock = new ReentrantReadWriteLock();
protected cublasHandle_t getCudaCublasHandle(OpaqueLaunchContext lc) {
val deviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread();
try {
lock.writeLock().lock();
if (cublasHandles.get(deviceId) == null) {
cublasHandles.remove(deviceId);
cublasHandles.add(deviceId, new cublasHandle_t(nativeOps.lcBlasHandle(lc)));
}
return cublasHandles.get(deviceId);
} finally {
lock.writeLock().unlock();
}
}
/** /**
* This method returns CudaContext for current thread. If context doesn't exist - it gets created first. * This method returns CudaContext for current thread. If context doesn't exist - it gets created first.
* @return * @return
@ -1183,8 +1206,6 @@ public class CudaZeroHandler implements MemoryHandler {
public CudaContext getCudaContext() { public CudaContext getCudaContext() {
val lc = nativeOps.defaultLaunchContext(); val lc = nativeOps.defaultLaunchContext();
// TODO: maybe make ThreadLocal cache for context?
return CudaContext.builder() return CudaContext.builder()
.bufferScalar(nativeOps.lcScalarPointer(lc)) .bufferScalar(nativeOps.lcScalarPointer(lc))
.bufferReduction(nativeOps.lcReductionPointer(lc)) .bufferReduction(nativeOps.lcReductionPointer(lc))
@ -1192,7 +1213,7 @@ public class CudaZeroHandler implements MemoryHandler {
.bufferSpecial(nativeOps.lcScalarPointer(lc)) .bufferSpecial(nativeOps.lcScalarPointer(lc))
.oldStream(new cudaStream_t(nativeOps.lcExecutionStream(lc))) .oldStream(new cudaStream_t(nativeOps.lcExecutionStream(lc)))
.specialStream(new cudaStream_t(nativeOps.lcCopyStream(lc))) .specialStream(new cudaStream_t(nativeOps.lcCopyStream(lc)))
.cublasHandle(new cublasHandle_t(nativeOps.lcBlasHandle(lc))) .cublasHandle(getCudaCublasHandle(lc))
.solverHandle(new cusolverDnHandle_t(nativeOps.lcSolverHandle(lc))) .solverHandle(new cusolverDnHandle_t(nativeOps.lcSolverHandle(lc)))
.build(); .build();
} }

View File

@ -17,6 +17,7 @@
package org.nd4j.linalg.jcublas.blas; package org.nd4j.linalg.jcublas.blas;
import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import org.bytedeco.javacpp.DoublePointer; import org.bytedeco.javacpp.DoublePointer;
import org.bytedeco.javacpp.FloatPointer; import org.bytedeco.javacpp.FloatPointer;
@ -52,6 +53,7 @@ import static org.nd4j.linalg.jcublas.blas.CudaBlas.*;
* *
* @author Adam Gibson * @author Adam Gibson
*/ */
@Slf4j
public class JcublasLevel3 extends BaseLevel3 { public class JcublasLevel3 extends BaseLevel3 {
private Allocator allocator = AtomicAllocator.getInstance(); private Allocator allocator = AtomicAllocator.getInstance();
private Nd4jBlas nd4jBlas = (Nd4jBlas) Nd4j.factory().blas(); private Nd4jBlas nd4jBlas = (Nd4jBlas) Nd4j.factory().blas();
@ -78,7 +80,7 @@ public class JcublasLevel3 extends BaseLevel3 {
int arch = CudaEnvironment.getInstance().getCurrentDeviceArchitecture(); int arch = CudaEnvironment.getInstance().getCurrentDeviceArchitecture();
if ((CUDA_VERSION >= 8000 && (arch == 53 || arch == 60 || arch == 70)) || (CUDA_VERSION >= 8000 && CUDA_VERSION < 9020)) { if ((CUDA_VERSION >= 8000 && (arch == 53 || arch == 60 || arch >= 70)) || (CUDA_VERSION >= 8000 && CUDA_VERSION < 9020)) {
// on these selected archs we run with cublasHgemm // on these selected archs we run with cublasHgemm
__half alphaHalf = new __half(); __half alphaHalf = new __half();
__half betaHalf = new __half(); __half betaHalf = new __half();
@ -96,7 +98,11 @@ public class JcublasLevel3 extends BaseLevel3 {
new FloatPointer(alpha), (ShortPointer) cAPointer.getDevicePointer(), 2, lda, new FloatPointer(alpha), (ShortPointer) cAPointer.getDevicePointer(), 2, lda,
(ShortPointer) cBPointer.getDevicePointer(), 2, ldb, new FloatPointer(beta), (ShortPointer) cBPointer.getDevicePointer(), 2, ldb, new FloatPointer(beta),
(ShortPointer) cCPointer.getDevicePointer(), 2, ldc); (ShortPointer) cCPointer.getDevicePointer(), 2, ldc);
} }
ctx.getOldStream().synchronize();
} }
allocator.registerAction(ctx, C, A, B); allocator.registerAction(ctx, C, A, B);
@ -114,18 +120,24 @@ public class JcublasLevel3 extends BaseLevel3 {
val ctx = allocator.getFlowController().prepareAction(C, A, B); val ctx = allocator.getFlowController().prepareAction(C, A, B);
//log.info("Synchronizing CUDA stream");
ctx.getOldStream().synchronize();
val cAPointer = new CublasPointer(A, ctx); val cAPointer = new CublasPointer(A, ctx);
val cBPointer = new CublasPointer(B, ctx); val cBPointer = new CublasPointer(B, ctx);
val cCPointer = new CublasPointer(C, ctx); val cCPointer = new CublasPointer(C, ctx);
val handle = ctx.getCublasHandle(); val handle = ctx.getCublasHandle();
synchronized (handle) { synchronized (handle) {
//log.info("Handle: {}; Stream: {}", handle.address(), ctx.getCublasStream().address());
cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream())); cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream()));
cublasSgemm_v2(new cublasContext(handle), convertTranspose(TransA), convertTranspose(TransB), M, N, K, cublasSgemm_v2(new cublasContext(handle), convertTranspose(TransA), convertTranspose(TransB), M, N, K,
new FloatPointer(alpha), (FloatPointer) cAPointer.getDevicePointer(), lda, new FloatPointer(alpha), (FloatPointer) cAPointer.getDevicePointer(), lda,
(FloatPointer) cBPointer.getDevicePointer(), ldb, new FloatPointer(beta), (FloatPointer) cBPointer.getDevicePointer(), ldb, new FloatPointer(beta),
(FloatPointer) cCPointer.getDevicePointer(), ldc); (FloatPointer) cCPointer.getDevicePointer(), ldc);
ctx.getOldStream().synchronize();
} }
allocator.registerAction(ctx, C, A, B); allocator.registerAction(ctx, C, A, B);
@ -244,6 +256,8 @@ public class JcublasLevel3 extends BaseLevel3 {
new DoublePointer(alpha), (DoublePointer) cAPointer.getDevicePointer(), lda, new DoublePointer(alpha), (DoublePointer) cAPointer.getDevicePointer(), lda,
(DoublePointer) cBPointer.getDevicePointer(), ldb, new DoublePointer(beta), (DoublePointer) cBPointer.getDevicePointer(), ldb, new DoublePointer(beta),
(DoublePointer) cCPointer.getDevicePointer(), ldc); (DoublePointer) cCPointer.getDevicePointer(), ldc);
ctx.getOldStream().synchronize();
} }
allocator.registerAction(ctx, C, A, B); allocator.registerAction(ctx, C, A, B);

View File

@ -2548,6 +2548,9 @@ public class CudaExecutioner extends DefaultOpExecutioner {
@Override @Override
public DataBuffer createShapeInfo(long[] shape, long[] stride, long elementWiseStride, char order, DataType dtype, boolean empty) { public DataBuffer createShapeInfo(long[] shape, long[] stride, long elementWiseStride, char order, DataType dtype, boolean empty) {
if (nativeOps.lastErrorCode() != 0)
throw new RuntimeException(nativeOps.lastErrorMessage());
OpaqueConstantDataBuffer dbf = nativeOps.shapeBuffer(shape.length, new LongPointer(shape), new LongPointer(stride), dtype.toInt(), order, elementWiseStride, empty); OpaqueConstantDataBuffer dbf = nativeOps.shapeBuffer(shape.length, new LongPointer(shape), new LongPointer(stride), dtype.toInt(), order, elementWiseStride, empty);
if (nativeOps.lastErrorCode() != 0) if (nativeOps.lastErrorCode() != 0)
@ -2562,6 +2565,9 @@ public class CudaExecutioner extends DefaultOpExecutioner {
@Override @Override
public TadPack tadShapeInfoAndOffsets(INDArray array, int[] dimension) { public TadPack tadShapeInfoAndOffsets(INDArray array, int[] dimension) {
if (nativeOps.lastErrorCode() != 0)
throw new RuntimeException(nativeOps.lastErrorMessage());
OpaqueTadPack pack = nativeOps.tadOnlyShapeInfo((LongPointer) array.shapeInfoDataBuffer().addressPointer(), new IntPointer(dimension), dimension.length); OpaqueTadPack pack = nativeOps.tadOnlyShapeInfo((LongPointer) array.shapeInfoDataBuffer().addressPointer(), new IntPointer(dimension), dimension.length);
if (nativeOps.lastErrorCode() != 0) if (nativeOps.lastErrorCode() != 0)
@ -2577,6 +2583,9 @@ public class CudaExecutioner extends DefaultOpExecutioner {
@Override @Override
public DataBuffer createConstantBuffer(long[] values, DataType desiredType) { public DataBuffer createConstantBuffer(long[] values, DataType desiredType) {
if (nativeOps.lastErrorCode() != 0)
throw new RuntimeException(nativeOps.lastErrorMessage());
OpaqueConstantDataBuffer dbf = nativeOps.constantBufferLong(desiredType.toInt(), new LongPointer(values), values.length); OpaqueConstantDataBuffer dbf = nativeOps.constantBufferLong(desiredType.toInt(), new LongPointer(values), values.length);
if (nativeOps.lastErrorCode() != 0) if (nativeOps.lastErrorCode() != 0)
@ -2590,6 +2599,9 @@ public class CudaExecutioner extends DefaultOpExecutioner {
@Override @Override
public DataBuffer createConstantBuffer(double[] values, DataType desiredType) { public DataBuffer createConstantBuffer(double[] values, DataType desiredType) {
if (nativeOps.lastErrorCode() != 0)
throw new RuntimeException(nativeOps.lastErrorMessage());
OpaqueConstantDataBuffer dbf = nativeOps.constantBufferDouble(desiredType.toInt(), new DoublePointer(values), values.length); OpaqueConstantDataBuffer dbf = nativeOps.constantBufferDouble(desiredType.toInt(), new DoublePointer(values), values.length);
if (nativeOps.lastErrorCode() != 0) if (nativeOps.lastErrorCode() != 0)

View File

@ -3830,6 +3830,9 @@ public native @Cast("Nd4jPointer") Pointer lcSolverHandle(OpaqueLaunchContext lc
* @param writeList * @param writeList
* @param readList * @param readList
*/ */
// TODO: it would be nice to have NDArray::registerSpecialUse signature that accepts something else beyond initializer_list
// TODO: it would be nice to have NDArray::registerSpecialUse signature that accepts something else beyond initializer_list
/** /**

View File

@ -3830,6 +3830,9 @@ public native @Cast("Nd4jPointer") Pointer lcSolverHandle(OpaqueLaunchContext lc
* @param writeList * @param writeList
* @param readList * @param readList
*/ */
// TODO: it would be nice to have NDArray::registerSpecialUse signature that accepts something else beyond initializer_list
// TODO: it would be nice to have NDArray::registerSpecialUse signature that accepts something else beyond initializer_list
/** /**
@ -16982,7 +16985,19 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
// #endif // #endif
/** /**
* Returns a batched matrix tensor with new batched diagonal values. * Inserts elements provided by diagonal array into the main diagonal of innermost matrices of input array
*
* Input arrays:
* input: input array, considered as batch of matrices
* diagonal: array containing elements to be inserted into input array,
* following rank condition should be satisfied: diagonal_rank = input_rank - 1,
* the shapes of diagonal and input arrays must be equal except last dimension of input array,
* for example if input_shape = [A,B,C,D] then diagonal_shape = [A,B,C],
* also last dimension of diagonal array should be equal to smaller of last and last but one input dimensions
* that is: diagonal_shape[-1] = min(input_shape[-1], input_shape[-2])
*
* Output array:
* has the same shape as input, corresponding diagonal elements are substituted
*/ */
// #if NOT_EXCLUDED(OP_matrix_set_diag) // #if NOT_EXCLUDED(OP_matrix_set_diag)
@Namespace("nd4j::ops") public static class matrix_set_diag extends DeclarableOp { @Namespace("nd4j::ops") public static class matrix_set_diag extends DeclarableOp {