Snapshot update (#8194)
* fix double consumption of rng on cpu  Signed-off-by: raver119 <raver119@gmail.com>
* Shyrma docs (#222)
* - documenting and profiling matrix_set_diag cuda kernel  Signed-off-by: Yurii <yurii@skymind.io>
* - correct formula of pnorm pooling in cuda 2d/3d kernels
  - remove helper matrix_diag which duplicates work of helper matrix_set_diag  Signed-off-by: Yurii <yurii@skymind.io>
* cublasHandle sharing + lock  Signed-off-by: raver119 <raver119@gmail.com>
* cublasHandle sharing + lock  Signed-off-by: raver119 <raver119@gmail.com>
* Documentation from serialization/deserialization in NLP (#221)
* refactoring  Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
* Javadocs  Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
* Javadoc fixed  Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
* Cleanup  Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
* dedicated lock for getCudaCublasHandle  Signed-off-by: raver119 <raver119@gmail.com>
* Small fixes (#223)  Signed-off-by: AlexDBlack <blacka101@gmail.com>
* ELU DL4J fixes (#224)  Signed-off-by: AlexDBlack <blacka101@gmail.com>
* javadoc (#225)  Signed-off-by: Robert Altena <Rob@Ra-ai.com>
* Small test compilation fix (#226)  Signed-off-by: AlexDBlack <blacka101@gmail.com>
* #8182 remove spark version suffix (#227)  Signed-off-by: AlexDBlack <blacka101@gmail.com>
* [WIP] Thread safety (#229)
* sync after cublas*gemm  Signed-off-by: raver119 <raver119@gmail.com>
* mutex for CublasHelper  Signed-off-by: raver119 <raver119@gmail.com>
* don't store cublasHandle in LaunchContext, it's per-device anyway  Signed-off-by: raver119 <raver119@gmail.com>
* some printout  Signed-off-by: raver119 <raver119@gmail.com>
* check for field instead  Signed-off-by: raver119 <raver119@gmail.com>
* pew-pew  Signed-off-by: raver119 <raver119@gmail.com>
* don't release ContextBuffers until device changed  Signed-off-by: raver119 <raver119@gmail.com>
* small tweak  Signed-off-by: raver119 <raver119@gmail.com>
* some logging in sgemm  Signed-off-by: raver119 <raver119@gmail.com>
* stream sync  Signed-off-by: raver119 <raver119@gmail.com>
* some more logging  Signed-off-by: raver119 <raver119@gmail.com>
* some more error checks  Signed-off-by: raver119 <raver119@gmail.com>
* one fancy test  Signed-off-by: raver119 <raver119@gmail.com>
* one fancy test  Signed-off-by: raver119 <raver119@gmail.com>
* minor AffinityManager fix  Signed-off-by: raver119 <raver119@gmail.com>
* cudaEvent error logging improvement  Signed-off-by: raver119 <raver119@gmail.com>
* ConstantHelper thread safety  Signed-off-by: raver119 <raver119@gmail.com>
* - minor corrections in ConstantTadHelper  Signed-off-by: Yurii <yurii@skymind.io>
* ConstantShapeHelper thread safety  Signed-off-by: raver119 <raver119@gmail.com>
* ConstantTadHelper.cu updated  Signed-off-by: raver119 <raver119@gmail.com>
* logging off  Signed-off-by: raver119 <raver119@gmail.com>
* logging off  Signed-off-by: raver119 <raver119@gmail.com>
parent 9d03bb9425
commit 7abc574eeb
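Several of the commit messages above ("cublasHandle sharing + lock", "mutex for CublasHelper", "don't store cublasHandle in LaunchContext, it's per-device anyway", "dedicated lock for getCudaCublasHandle") revolve around one idea: keep a single cuBLAS handle per device and guard its lazy creation with a dedicated lock. The sketch below only illustrates that pattern; it is not the actual nd4j/CublasHelper code, and the class and method names are invented for the example.

    // Hypothetical illustration only -- not the nd4j implementation.
    // One handle per device, created lazily; a dedicated lock makes lazy
    // creation safe when many threads call in concurrently.
    import java.util.Map;
    import java.util.concurrent.ConcurrentHashMap;
    import java.util.function.IntFunction;

    final class PerDeviceHandleCache<H> {
        private final Map<Integer, H> handles = new ConcurrentHashMap<>();
        private final Object lock = new Object();   // the "dedicated lock" from the commit message
        private final IntFunction<H> factory;       // e.g. wraps handle creation for a given deviceId

        PerDeviceHandleCache(IntFunction<H> factory) {
            this.factory = factory;
        }

        H handleForDevice(int deviceId) {
            H h = handles.get(deviceId);
            if (h != null)
                return h;                            // fast path once initialized
            synchronized (lock) {                    // slow path: create at most once per device
                return handles.computeIfAbsent(deviceId, factory::apply);
            }
        }
    }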
@@ -38,7 +38,7 @@
         <dependency>
             <groupId>org.datavec</groupId>
             <artifactId>datavec-spark-inference-server_2.11</artifactId>
-            <version>1.0.0_spark_2-SNAPSHOT</version>
+            <version>1.0.0-SNAPSHOT</version>
             <scope>test</scope>
         </dependency>
         <dependency>

@@ -25,7 +25,7 @@

     <artifactId>datavec-spark-inference-server_2.11</artifactId>
     <packaging>jar</packaging>
-    <version>1.0.0_spark_2-SNAPSHOT</version>
+    <version>1.0.0-SNAPSHOT</version>
     <name>datavec-spark-inference-server</name>

     <properties>

@@ -24,7 +24,7 @@
     </parent>

     <modelVersion>4.0.0</modelVersion>
-    <version>1.0.0_spark_2-SNAPSHOT</version>
+    <version>1.0.0-SNAPSHOT</version>
     <artifactId>datavec-spark_2.11</artifactId>

     <properties>
@@ -0,0 +1,63 @@
+package org.deeplearning4j;
+
+import org.deeplearning4j.datasets.iterator.EarlyTerminationDataSetIterator;
+import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
+import org.deeplearning4j.nn.api.OptimizationAlgorithm;
+import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
+import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
+import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
+import org.junit.Ignore;
+import org.junit.Test;
+import org.nd4j.linalg.activations.Activation;
+import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.learning.config.RmsProp;
+import org.nd4j.linalg.lossfunctions.LossFunctions;
+
+import java.util.concurrent.CountDownLatch;
+
+@Ignore
+public class RandomTests {
+
+    @Test
+    public void testReproduce() throws Exception {
+
+        final MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new RmsProp())
+                        .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).list()
+                        .layer(0, new org.deeplearning4j.nn.conf.layers.DenseLayer.Builder().nIn(28 * 28).nOut(10)
+                                        .activation(Activation.TANH).build())
+                        .layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(
+                                        LossFunctions.LossFunction.MCXENT).nIn(10).nOut(10)
+                                                        .activation(Activation.SOFTMAX).build())
+                        .build();
+
+        for (int e = 0; e < 3; e++) {
+
+            int nThreads = 10;
+            final CountDownLatch l = new CountDownLatch(nThreads);
+            for (int i = 0; i < nThreads; i++) {
+                final int j = i;
+                Thread t = new Thread(new Runnable() {
+                    @Override
+                    public void run() {
+                        try {
+                            MultiLayerNetwork net = new MultiLayerNetwork(conf.clone());
+                            net.init();
+                            DataSetIterator iter = new EarlyTerminationDataSetIterator(new MnistDataSetIterator(10, false, 12345), 100);
+                            net.fit(iter);
+                        } catch (Throwable t) {
+                            System.out.println("Thread failed: " + j);
+                            t.printStackTrace();
+                        } finally {
+                            l.countDown();
+                        }
+                    }
+                });
+                t.start();
+            }
+
+            l.await();
+            System.out.println("DONE " + e + "\n");
+        }
+    }
+}
@@ -833,14 +833,14 @@ public class WordVectorSerializerTest extends BaseDL4JTest {
     public void testB64_1() throws Exception {
         String wordA = "night";
         String wordB = "night day";
-        String encA = WordVectorSerializer.encodeB64(wordA);
+        String encA = WordVectorSerializer.ReadHelper.encodeB64(wordA);
-        String encB = WordVectorSerializer.encodeB64(wordB);
+        String encB = WordVectorSerializer.ReadHelper.encodeB64(wordB);

-        assertEquals(wordA, WordVectorSerializer.decodeB64(encA));
+        assertEquals(wordA, WordVectorSerializer.ReadHelper.decodeB64(encA));
-        assertEquals(wordB, WordVectorSerializer.decodeB64(encB));
+        assertEquals(wordB, WordVectorSerializer.ReadHelper.decodeB64(encB));

-        assertEquals(wordA, WordVectorSerializer.decodeB64(wordA));
+        assertEquals(wordA, WordVectorSerializer.ReadHelper.decodeB64(wordA));
-        assertEquals(wordB, WordVectorSerializer.decodeB64(wordB));
+        assertEquals(wordB, WordVectorSerializer.ReadHelper.decodeB64(wordB));

     }
@@ -24,7 +24,6 @@ import org.apache.commons.io.FileUtils;
 import org.apache.commons.io.IOUtils;
 import org.apache.commons.io.LineIterator;
 import org.apache.commons.io.output.CloseShieldOutputStream;
-import org.deeplearning4j.exception.DL4JException;
 import org.deeplearning4j.exception.DL4JInvalidInputException;
 import org.deeplearning4j.models.embeddings.WeightLookupTable;
 import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;

@@ -52,7 +51,6 @@ import org.deeplearning4j.text.sentenceiterator.BasicLineIterator;
 import org.deeplearning4j.text.tokenization.tokenizer.TokenPreProcess;
 import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
 import org.deeplearning4j.util.DL4JFileUtils;
-import org.nd4j.base.Preconditions;
 import org.nd4j.compression.impl.NoOp;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.exception.ND4JIllegalStateException;

@@ -68,8 +66,6 @@ import org.nd4j.util.OneTimeLogger;

 import java.io.*;
 import java.nio.charset.StandardCharsets;
-import java.nio.file.Files;
-import java.nio.file.Paths;
 import java.util.ArrayList;
 import java.util.List;
 import java.util.concurrent.atomic.AtomicInteger;
@@ -78,6 +74,80 @@ import java.util.zip.*;
 /**
  * This is utility class, providing various methods for WordVectors serialization
  *
+ * List of available serialization methods (please keep this list consistent with source code):
+ *
+ * <ul>
+ * <li>Serializers for Word2Vec:</li>
+ * {@link #writeWordVectors(WeightLookupTable, File)}
+ * {@link #writeWordVectors(WeightLookupTable, OutputStream)}
+ * {@link #writeWord2VecModel(Word2Vec, File)}
+ * {@link #writeWord2VecModel(Word2Vec, String)}
+ * {@link #writeWord2VecModel(Word2Vec, OutputStream)}
+ *
+ * <li>Deserializers for Word2Vec:</li>
+ * {@link #readWord2VecModel(File)}
+ * {@link #readWord2VecModel(String)}
+ * {@link #readWord2VecModel(File, boolean)}
+ * {@link #readWord2VecModel(String, boolean)}
+ * {@link #readAsBinaryNoLineBreaks(File)}
+ * {@link #readAsBinary(File)}
+ * {@link #readAsCsv(File)}
+ * {@link #readBinaryModel(File, boolean, boolean)}
+ * {@link #readWord2VecFromText(File, File, File, File, VectorsConfiguration)}
+ * {@link #readWord2Vec(String, boolean)}
+ * {@link #readWord2Vec(File, boolean)}
+ * {@link #readWord2Vec(InputStream, boolean)}
+ *
+ * <li>Serializers for ParaVec:</li>
+ * {@link #writeParagraphVectors(ParagraphVectors, File)}
+ * {@link #writeParagraphVectors(ParagraphVectors, String)}
+ * {@link #writeParagraphVectors(ParagraphVectors, OutputStream)}
+ *
+ * <li>Deserializers for ParaVec:</li>
+ * {@link #readParagraphVectors(File)}
+ * {@link #readParagraphVectors(String)}
+ * {@link #readParagraphVectors(InputStream)}
+ *
+ * <li>Serializers for GloVe:</li>
+ * {@link #writeWordVectors(Glove, File)}
+ * {@link #writeWordVectors(Glove, String)}
+ * {@link #writeWordVectors(Glove, OutputStream)}
+ *
+ * <li>Adapters</li>
+ * {@link #fromTableAndVocab(WeightLookupTable, VocabCache)}
+ * {@link #fromPair(Pair)}
+ * {@link #loadTxt(File)}
+ *
+ * <li>Serializers to tSNE format</li>
+ * {@link #writeTsneFormat(Glove, INDArray, File)}
+ * {@link #writeTsneFormat(Word2Vec, INDArray, File)}
+ *
+ * <li>FastText serializer:</li>
+ * {@link #writeWordVectors(FastText, File)}
+ *
+ * <li>FastText deserializer:</li>
+ * {@link #readWordVectors(File)}
+ *
+ * <li>SequenceVectors serializers:</li>
+ * {@link #writeSequenceVectors(SequenceVectors, OutputStream)}
+ * {@link #writeSequenceVectors(SequenceVectors, SequenceElementFactory, File)}
+ * {@link #writeSequenceVectors(SequenceVectors, SequenceElementFactory, String)}
+ * {@link #writeSequenceVectors(SequenceVectors, SequenceElementFactory, OutputStream)}
+ * {@link #writeLookupTable(WeightLookupTable, File)}
+ * {@link #writeVocabCache(VocabCache, File)}
+ * {@link #writeVocabCache(VocabCache, OutputStream)}
+ *
+ * <li>SequenceVectors deserializers:</li>
+ * {@link #readSequenceVectors(File, boolean)}
+ * {@link #readSequenceVectors(String, boolean)}
+ * {@link #readSequenceVectors(SequenceElementFactory, File)}
+ * {@link #readSequenceVectors(InputStream, boolean)}
+ * {@link #readSequenceVectors(SequenceElementFactory, InputStream)}
+ * {@link #readLookupTable(File)}
+ * {@link #readLookupTable(InputStream)}
+ *
+ * </ul>
+ *
  * @author Adam Gibson
  * @author raver119
  * @author alexander@skymind.io
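As a quick orientation for the serializer list added above, here is a minimal, hedged usage sketch. It only calls methods named in that Javadoc (writeWord2VecModel, readWord2VecModel); the import paths are the usual DL4J locations and the file name is a placeholder, so both should be checked against this revision.

    import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
    import org.deeplearning4j.models.word2vec.Word2Vec;

    import java.io.File;

    public class Word2VecSerializationSketch {
        public static void serializeRoundTrip(Word2Vec vec) throws Exception {
            File target = new File("w2v-model.zip");               // placeholder path

            // write the full model (vocab, lookup table, configuration)
            WordVectorSerializer.writeWord2VecModel(vec, target);

            // read it back; overloads with a boolean flag also exist (see the list above)
            Word2Vec restored = WordVectorSerializer.readWord2VecModel(target);
        }
    }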
@@ -97,7 +167,7 @@ public class WordVectorSerializer {
      * @throws IOException
      * @throws NumberFormatException
      */
-    private static Word2Vec readTextModel(File modelFile) throws IOException, NumberFormatException {
+    /*private static Word2Vec readTextModel(File modelFile) throws IOException, NumberFormatException {
         InMemoryLookupTable lookupTable;
         VocabCache cache;
         INDArray syn0;

@@ -142,7 +212,7 @@ public class WordVectorSerializer {
             ret.setLookupTable(lookupTable);
         }
         return ret;
-    }
+    }*/

     /**
      * Read a binary word2vec file.

@@ -173,8 +243,8 @@ public class WordVectorSerializer {
         try (BufferedInputStream bis = new BufferedInputStream(GzipUtils.isCompressedFilename(modelFile.getName())
                         ? new GZIPInputStream(new FileInputStream(modelFile)) : new FileInputStream(modelFile));
                         DataInputStream dis = new DataInputStream(bis)) {
-            words = Integer.parseInt(readString(dis));
+            words = Integer.parseInt(ReadHelper.readString(dis));
-            size = Integer.parseInt(readString(dis));
+            size = Integer.parseInt(ReadHelper.readString(dis));
             syn0 = Nd4j.create(words, size);
             cache = new AbstractCache<>();

@@ -188,11 +258,11 @@ public class WordVectorSerializer {
             float[] vector = new float[size];
             for (int i = 0; i < words; i++) {

-                word = readString(dis);
+                word = ReadHelper.readString(dis);
                 log.trace("Loading " + word + " with word " + i);

                 for (int j = 0; j < size; j++) {
-                    vector[j] = readFloat(dis);
+                    vector[j] = ReadHelper.readFloat(dis);
                 }

                 if (cache.containsWord(word))
@@ -236,64 +306,6 @@ public class WordVectorSerializer {

     }

-    /**
-     * Read a float from a data input stream Credit to:
-     * https://github.com/NLPchina/Word2VEC_java/blob/master/src/com/ansj/vec/Word2VEC.java
-     *
-     * @param is
-     * @return
-     * @throws IOException
-     */
-    public static float readFloat(InputStream is) throws IOException {
-        byte[] bytes = new byte[4];
-        is.read(bytes);
-        return getFloat(bytes);
-    }
-
-    /**
-     * Read a string from a data input stream Credit to:
-     * https://github.com/NLPchina/Word2VEC_java/blob/master/src/com/ansj/vec/Word2VEC.java
-     *
-     * @param b
-     * @return
-     * @throws IOException
-     */
-    public static float getFloat(byte[] b) {
-        int accum = 0;
-        accum = accum | (b[0] & 0xff) << 0;
-        accum = accum | (b[1] & 0xff) << 8;
-        accum = accum | (b[2] & 0xff) << 16;
-        accum = accum | (b[3] & 0xff) << 24;
-        return Float.intBitsToFloat(accum);
-    }
-
-    /**
-     * Read a string from a data input stream Credit to:
-     * https://github.com/NLPchina/Word2VEC_java/blob/master/src/com/ansj/vec/Word2VEC.java
-     *
-     * @param dis
-     * @return
-     * @throws IOException
-     */
-    public static String readString(DataInputStream dis) throws IOException {
-        byte[] bytes = new byte[MAX_SIZE];
-        byte b = dis.readByte();
-        int i = -1;
-        StringBuilder sb = new StringBuilder();
-        while (b != 32 && b != 10) {
-            i++;
-            bytes[i] = b;
-            b = dis.readByte();
-            if (i == 49) {
-                sb.append(new String(bytes, "UTF-8"));
-                i = -1;
-                bytes = new byte[MAX_SIZE];
-            }
-        }
-        sb.append(new String(bytes, 0, i + 1, "UTF-8"));
-        return sb.toString();
-    }
-
     /**
      * This method writes word vectors to the given path.
      * Please note: this method doesn't load whole vocab/lookupTable into memory, so it's able to process large vocabularies served over network.
@@ -355,7 +367,7 @@ public class WordVectorSerializer {
             val builder = new StringBuilder();

             val l = element.getLabel();
-            builder.append(encodeB64(l)).append(" ");
+            builder.append(ReadHelper.encodeB64(l)).append(" ");
             val vec = lookupTable.vector(element.getLabel());
             for (int i = 0; i < vec.length(); i++) {
                 builder.append(vec.getDouble(i));

@@ -518,7 +530,7 @@ public class WordVectorSerializer {
         try (PrintWriter writer = new PrintWriter(new FileWriter(tempFileCodes))) {
             for (int i = 0; i < vectors.getVocab().numWords(); i++) {
                 VocabWord word = vectors.getVocab().elementAtIndex(i);
-                StringBuilder builder = new StringBuilder(encodeB64(word.getLabel())).append(" ");
+                StringBuilder builder = new StringBuilder(ReadHelper.encodeB64(word.getLabel())).append(" ");
                 for (int code : word.getCodes()) {
                     builder.append(code).append(" ");
                 }

@@ -536,7 +548,7 @@ public class WordVectorSerializer {
         try (PrintWriter writer = new PrintWriter(new FileWriter(tempFileHuffman))) {
             for (int i = 0; i < vectors.getVocab().numWords(); i++) {
                 VocabWord word = vectors.getVocab().elementAtIndex(i);
-                StringBuilder builder = new StringBuilder(encodeB64(word.getLabel())).append(" ");
+                StringBuilder builder = new StringBuilder(ReadHelper.encodeB64(word.getLabel())).append(" ");
                 for (int point : word.getPoints()) {
                     builder.append(point).append(" ");
                 }

@@ -554,7 +566,7 @@ public class WordVectorSerializer {
         try (PrintWriter writer = new PrintWriter(new FileWriter(tempFileFreqs))) {
             for (int i = 0; i < vectors.getVocab().numWords(); i++) {
                 VocabWord word = vectors.getVocab().elementAtIndex(i);
-                StringBuilder builder = new StringBuilder(encodeB64(word.getLabel())).append(" ")
+                StringBuilder builder = new StringBuilder(ReadHelper.encodeB64(word.getLabel())).append(" ")
                                 .append(word.getElementFrequency()).append(" ")
                                 .append(vectors.getVocab().docAppearedIn(word.getLabel()));

@@ -638,7 +650,7 @@ public class WordVectorSerializer {
         try (PrintWriter writer = new PrintWriter(new FileWriter(tempFileCodes))) {
             for (int i = 0; i < vectors.getVocab().numWords(); i++) {
                 VocabWord word = vectors.getVocab().elementAtIndex(i);
-                StringBuilder builder = new StringBuilder(encodeB64(word.getLabel())).append(" ");
+                StringBuilder builder = new StringBuilder(ReadHelper.encodeB64(word.getLabel())).append(" ");
                 for (int code : word.getCodes()) {
                     builder.append(code).append(" ");
                 }

@@ -656,7 +668,7 @@ public class WordVectorSerializer {
         try (PrintWriter writer = new PrintWriter(new FileWriter(tempFileHuffman))) {
             for (int i = 0; i < vectors.getVocab().numWords(); i++) {
                 VocabWord word = vectors.getVocab().elementAtIndex(i);
-                StringBuilder builder = new StringBuilder(encodeB64(word.getLabel())).append(" ");
+                StringBuilder builder = new StringBuilder(ReadHelper.encodeB64(word.getLabel())).append(" ");
                 for (int point : word.getPoints()) {
                     builder.append(point).append(" ");
                 }

@@ -677,7 +689,7 @@ public class WordVectorSerializer {
             StringBuilder builder = new StringBuilder();
             for (VocabWord word : vectors.getVocab().tokens()) {
                 if (word.isLabel())
-                    builder.append(encodeB64(word.getLabel())).append("\n");
+                    builder.append(ReadHelper.encodeB64(word.getLabel())).append("\n");
             }
             IOUtils.write(builder.toString().trim(), zipfile, StandardCharsets.UTF_8);

@@ -688,7 +700,7 @@ public class WordVectorSerializer {
         try (PrintWriter writer = new PrintWriter(new FileWriter(tempFileFreqs))) {
             for (int i = 0; i < vectors.getVocab().numWords(); i++) {
                 VocabWord word = vectors.getVocab().elementAtIndex(i);
-                builder = new StringBuilder(encodeB64(word.getLabel())).append(" ").append(word.getElementFrequency())
+                builder = new StringBuilder(ReadHelper.encodeB64(word.getLabel())).append(" ").append(word.getElementFrequency())
                                 .append(" ").append(vectors.getVocab().docAppearedIn(word.getLabel()));

                 writer.println(builder.toString().trim());

@@ -744,7 +756,7 @@ public class WordVectorSerializer {
         try (BufferedReader reader = new BufferedReader(new InputStreamReader(stream, StandardCharsets.UTF_8))) {
             String line;
             while ((line = reader.readLine()) != null) {
-                VocabWord word = vectors.getVocab().tokenFor(decodeB64(line.trim()));
+                VocabWord word = vectors.getVocab().tokenFor(ReadHelper.decodeB64(line.trim()));
                 if (word != null) {
                     word.markAsLabel(true);
                 }

@@ -836,7 +848,7 @@ public class WordVectorSerializer {
             String line;
             while ((line = reader.readLine()) != null) {
                 String[] split = line.split(" ");
-                VocabWord word = w2v.getVocab().tokenFor(decodeB64(split[0]));
+                VocabWord word = w2v.getVocab().tokenFor(ReadHelper.decodeB64(split[0]));
                 word.setElementFrequency((long) Double.parseDouble(split[1]));
                 word.setSequencesCount((long) Double.parseDouble(split[2]));
             }

@@ -946,7 +958,7 @@ public class WordVectorSerializer {
             reader = new BufferedReader(new FileReader(h_points));
             while ((line = reader.readLine()) != null) {
                 String[] split = line.split(" ");
-                VocabWord word = vocab.wordFor(decodeB64(split[0]));
+                VocabWord word = vocab.wordFor(ReadHelper.decodeB64(split[0]));
                 List<Integer> points = new ArrayList<>();
                 for (int i = 1; i < split.length; i++) {
                     points.add(Integer.parseInt(split[i]));

@@ -960,7 +972,7 @@ public class WordVectorSerializer {
             reader = new BufferedReader(new FileReader(h_codes));
             while ((line = reader.readLine()) != null) {
                 String[] split = line.split(" ");
-                VocabWord word = vocab.wordFor(decodeB64(split[0]));
+                VocabWord word = vocab.wordFor(ReadHelper.decodeB64(split[0]));
                 List<Byte> codes = new ArrayList<>();
                 for (int i = 1; i < split.length; i++) {
                     codes.add(Byte.parseByte(split[i]));

@@ -1704,7 +1716,7 @@ public class WordVectorSerializer {
                 if (line.isEmpty())
                     line = iter.nextLine();
                 String[] split = line.split(" ");
-                String word = decodeB64(split[0]); //split[0].replaceAll(whitespaceReplacement, " ");
+                String word = ReadHelper.decodeB64(split[0]); //split[0].replaceAll(whitespaceReplacement, " ");
                 VocabWord word1 = new VocabWord(1.0, word);

                 word1.setIndex(cache.numWords());
@@ -1994,7 +2006,13 @@ public class WordVectorSerializer {
     private static final String SYN1_ENTRY = "syn1.bin";
     private static final String SYN1_NEG_ENTRY = "syn1neg.bin";

+    /**
+     * This method saves specified SequenceVectors model to target OutputStream
+     *
+     * @param vectors SequenceVectors model
+     * @param stream Target output stream
+     * @param <T>
+     */
     public static <T extends SequenceElement> void writeSequenceVectors(@NonNull SequenceVectors<T> vectors,
                                                                         @NonNull OutputStream stream)
             throws IOException {
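A hedged sketch of the SequenceVectors round trip that the new Javadoc above describes, using only the overloads listed in this patch (writeSequenceVectors to an OutputStream, readSequenceVectors(File, boolean)); the file argument is a placeholder and the vectors instance is assumed to exist already.

    import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
    import org.deeplearning4j.models.sequencevectors.SequenceVectors;
    import org.deeplearning4j.models.word2vec.VocabWord;

    import java.io.*;

    class SequenceVectorsRoundTripSketch {
        static SequenceVectors<VocabWord> roundTrip(SequenceVectors<VocabWord> vectors, File file) throws IOException {
            // save to any OutputStream (the overload documented above)
            try (OutputStream os = new BufferedOutputStream(new FileOutputStream(file))) {
                WordVectorSerializer.writeSequenceVectors(vectors, os);
            }
            // restore; the boolean is the readExtendedTables flag from the Javadoc
            return WordVectorSerializer.readSequenceVectors(file, false);
        }
    }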
@@ -2040,7 +2058,13 @@ public class WordVectorSerializer {
         }
     }

+    /**
+     * This method loads SequenceVectors from specified file path
+     *
+     * @param path String
+     * @param readExtendedTables boolean
+     * @param <T>
+     */
     public static <T extends SequenceElement> SequenceVectors<T> readSequenceVectors(@NonNull String path,
                                                                                      boolean readExtendedTables)
             throws IOException {

@@ -2050,6 +2074,14 @@ public class WordVectorSerializer {
         return vectors;
     }

+    /**
+     * This method loads SequenceVectors from specified file path
+     *
+     * @param file File
+     * @param readExtendedTables boolean
+     * @param <T>
+     */
+
     public static <T extends SequenceElement> SequenceVectors<T> readSequenceVectors(@NonNull File file,
                                                                                      boolean readExtendedTables)
             throws IOException {

@@ -2058,6 +2090,13 @@ public class WordVectorSerializer {
         return vectors;
     }

+    /**
+     * This method loads SequenceVectors from specified input stream
+     *
+     * @param stream InputStream
+     * @param readExtendedTables boolean
+     * @param <T>
+     */
     public static <T extends SequenceElement> SequenceVectors<T> readSequenceVectors(@NonNull InputStream stream,
                                                                                      boolean readExtendedTables)
             throws IOException {

@@ -2381,6 +2420,12 @@ public class WordVectorSerializer {
         }
     }

+    /**
+     * This method loads Word2Vec model from binary file
+     *
+     * @param file File
+     * @return Word2Vec
+     */
     public static Word2Vec readAsBinary(@NonNull File file) {
         boolean originalPeriodic = Nd4j.getMemoryManager().isPeriodicGcActive();
         int originalFreq = Nd4j.getMemoryManager().getOccasionalGcFrequency();

@@ -2403,6 +2448,12 @@ public class WordVectorSerializer {
         }
     }

+    /**
+     * This method loads Word2Vec model from csv file
+     *
+     * @param file File
+     * @return Word2Vec
+     */
     public static Word2Vec readAsCsv(@NonNull File file) {

         Word2Vec vec;

@@ -2491,7 +2542,7 @@ public class WordVectorSerializer {
             String line;
             while ((line = reader.readLine()) != null) {
                 String[] split = line.split(" ");
-                VocabWord word = new VocabWord(Double.valueOf(split[1]), decodeB64(split[0]));
+                VocabWord word = new VocabWord(Double.valueOf(split[1]), ReadHelper.decodeB64(split[0]));
                 word.setIndex(cnt.getAndIncrement());
                 word.incrementSequencesCount(Long.valueOf(split[2]));

@@ -2669,7 +2720,7 @@ public class WordVectorSerializer {
      *
      * In return you get StaticWord2Vec model, which might be used as lookup table only in multi-gpu environment.
      *
-     * @param file File should point to previously saved w2v model
+     * @param inputStream InputStream should point to previously saved w2v model
      * @return
      */
     public static WordVectors loadStaticModel(InputStream inputStream) throws IOException {
@@ -2685,6 +2736,17 @@ public class WordVectorSerializer {
     }

     // TODO: this method needs better name :)
+    /**
+     * This method restores previously saved w2v model. File can be in one of the following formats:
+     * 1) Binary model, either compressed or not. Like well-known Google Model
+     * 2) Popular CSV word2vec text format
+     * 3) DL4j compressed format
+     *
+     * In return you get StaticWord2Vec model, which might be used as lookup table only in multi-gpu environment.
+     *
+     * @param file File
+     * @return
+     */
     public static WordVectors loadStaticModel(@NonNull File file) {
         if (!file.exists() || file.isDirectory())
             throw new RuntimeException(
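For the loadStaticModel variants documented above, a minimal sketch under stated assumptions: the path is a placeholder, and getWordVector is used because plain lookup-table access is exactly the usage the Javadoc recommends for the returned static model.

    import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
    import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;

    import java.io.File;

    class StaticModelSketch {
        public static void main(String[] args) {
            // accepts a binary (possibly gzipped) model, the CSV text format, or the DL4J compressed format
            WordVectors staticVectors =
                    WordVectorSerializer.loadStaticModel(new File("/path/to/w2v-model.bin.gz")); // placeholder path
            double[] day = staticVectors.getWordVector("day");   // lookup-table style access
            System.out.println(day.length);
        }
    }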
@@ -2843,8 +2905,8 @@ public class WordVectorSerializer {
             throw new RuntimeException(e);
         }
         try {
-            numWords = Integer.parseInt(readString(stream));
+            numWords = Integer.parseInt(ReadHelper.readString(stream));
-            vectorLength = Integer.parseInt(readString(stream));
+            vectorLength = Integer.parseInt(ReadHelper.readString(stream));
         } catch (IOException e) {
             throw new RuntimeException(e);
         }

@@ -2858,13 +2920,13 @@ public class WordVectorSerializer {
             @Override
             public Pair<VocabWord, float[]> next() {
                 try {
-                    String word = readString(stream);
+                    String word = ReadHelper.readString(stream);
                     VocabWord element = new VocabWord(1.0, word);
                     element.setIndex(idxCounter.getAndIncrement());

                     float[] vector = new float[vectorLength];
                     for (int i = 0; i < vectorLength; i++) {
-                        vector[i] = readFloat(stream);
+                        vector[i] = ReadHelper.readFloat(stream);
                     }

                     return Pair.makePair(element, vector);

@@ -2913,7 +2975,7 @@ public class WordVectorSerializer {

                 String[] split = nextLine.split(" ");

-                VocabWord word = new VocabWord(1.0, decodeB64(split[0]));
+                VocabWord word = new VocabWord(1.0, ReadHelper.decodeB64(split[0]));
                 word.setIndex(idxCounter.getAndIncrement());

                 float[] vector = new float[split.length - 1];

@@ -2937,26 +2999,12 @@ public class WordVectorSerializer {
         }
     }

-    public static String encodeB64(String word) {
-        try {
-            return "B64:" + Base64.encodeBase64String(word.getBytes("UTF-8")).replaceAll("(\r|\n)", "");
-        } catch (Exception e) {
-            throw new RuntimeException(e);
-        }
-    }
-
-    public static String decodeB64(String word) {
-        if (word.startsWith("B64:")) {
-            String arp = word.replaceFirst("B64:", "");
-            try {
-                return new String(Base64.decodeBase64(arp), "UTF-8");
-            } catch (Exception e) {
-                throw new RuntimeException(e);
-            }
-        } else
-            return word;
-    }
-
+    /**
+     * This method saves Word2Vec model to output stream
+     *
+     * @param word2Vec Word2Vec
+     * @param stream OutputStream
+     */
     public static void writeWord2Vec(@NonNull Word2Vec word2Vec, @NonNull OutputStream stream)
             throws IOException {

@@ -2968,6 +3016,13 @@ public class WordVectorSerializer {
         writeSequenceVectors(vectors, stream);
     }

+    /**
+     * This method restores Word2Vec model from file
+     *
+     * @param path String
+     * @param readExtendedTables boolean
+     * @return Word2Vec
+     */
     public static Word2Vec readWord2Vec(@NonNull String path, boolean readExtendedTables)
             throws IOException {


@@ -2976,6 +3031,12 @@ public class WordVectorSerializer {
         return word2Vec;
     }

+    /**
+     * This method saves table of weights to file
+     *
+     * @param weightLookupTable WeightLookupTable
+     * @param file File
+     */
     public static <T extends SequenceElement> void writeLookupTable(WeightLookupTable<T> weightLookupTable,
                                                                     @NonNull File file) throws IOException {
         try (BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(file),

@@ -3038,7 +3099,7 @@ public class WordVectorSerializer {
                 headerRead = true;
                 weightLookupTable = new InMemoryLookupTable.Builder().cache(vocabCache).vectorLength(layerSize).build();
             } else {
-                String label = decodeB64(tokens[0]);
+                String label = ReadHelper.decodeB64(tokens[0]);
                 int freq = Integer.parseInt(tokens[1]);
                 int rows = Integer.parseInt(tokens[2]);
                 int cols = Integer.parseInt(tokens[3]);

@@ -3071,6 +3132,13 @@ public class WordVectorSerializer {
         return weightLookupTable;
     }

+    /**
+     * This method loads Word2Vec model from file
+     *
+     * @param file File
+     * @param readExtendedTables boolean
+     * @return Word2Vec
+     */
     public static Word2Vec readWord2Vec(@NonNull File file, boolean readExtendedTables)
             throws IOException {


@@ -3078,6 +3146,13 @@ public class WordVectorSerializer {
         return word2Vec;
     }

+    /**
+     * This method loads Word2Vec model from input stream
+     *
+     * @param stream InputStream
+     * @param readExtendedTable boolean
+     * @return Word2Vec
+     */
     public static Word2Vec readWord2Vec(@NonNull InputStream stream,
                     boolean readExtendedTable) throws IOException {
         SequenceVectors<VocabWord> vectors = readSequenceVectors(stream, readExtendedTable);

@@ -3087,7 +3162,13 @@ public class WordVectorSerializer {
         word2Vec.setModelUtils(vectors.getModelUtils());
         return word2Vec;
     }

+    /**
+     * This method saves FastText model to file
+     *
+     * @param vectors FastText
+     * @param path File
+     */
     public static void writeWordVectors(@NonNull FastText vectors, @NonNull File path) throws IOException {
         ObjectOutputStream outputStream = null;
         try {

@@ -3106,6 +3187,11 @@ public class WordVectorSerializer {
         }
     }

+    /**
+     * This method loads FastText model from file
+     *
+     * @param path File
+     */
     public static FastText readWordVectors(File path) {
         FastText result = null;
         try {

@@ -3124,6 +3210,13 @@ public class WordVectorSerializer {
         return result;
     }

+    /**
+     * This method prints memory usage to log
+     *
+     * @param numWords
+     * @param vectorLength
+     * @param numTables
+     */
     public static void printOutProjectedMemoryUse(long numWords, int vectorLength, int numTables) {
         double memSize = numWords * vectorLength * Nd4j.sizeOfDataType() * numTables;

@@ -3144,4 +3237,102 @@ public class WordVectorSerializer {
         OneTimeLogger.info(log, "Projected memory use for model: [{} {}]", String.format("%.2f", value), sfx);

     }

+    /**
+     * Helper static methods to read data from input stream.
+     */
+    public static class ReadHelper {
+        /**
+         * Read a float from a data input stream Credit to:
+         * https://github.com/NLPchina/Word2VEC_java/blob/master/src/com/ansj/vec/Word2VEC.java
+         *
+         * @param is
+         * @return
+         * @throws IOException
+         */
+        private static float readFloat(InputStream is) throws IOException {
+            byte[] bytes = new byte[4];
+            is.read(bytes);
+            return getFloat(bytes);
+        }
+
+        /**
+         * Read a string from a data input stream Credit to:
+         * https://github.com/NLPchina/Word2VEC_java/blob/master/src/com/ansj/vec/Word2VEC.java
+         *
+         * @param b
+         * @return
+         * @throws IOException
+         */
+        private static float getFloat(byte[] b) {
+            int accum = 0;
+            accum = accum | (b[0] & 0xff) << 0;
+            accum = accum | (b[1] & 0xff) << 8;
+            accum = accum | (b[2] & 0xff) << 16;
+            accum = accum | (b[3] & 0xff) << 24;
+            return Float.intBitsToFloat(accum);
+        }
+
+        /**
+         * Read a string from a data input stream Credit to:
+         * https://github.com/NLPchina/Word2VEC_java/blob/master/src/com/ansj/vec/Word2VEC.java
+         *
+         * @param dis
+         * @return
+         * @throws IOException
+         */
+        private static String readString(DataInputStream dis) throws IOException {
+            byte[] bytes = new byte[MAX_SIZE];
+            byte b = dis.readByte();
+            int i = -1;
+            StringBuilder sb = new StringBuilder();
+            while (b != 32 && b != 10) {
+                i++;
+                bytes[i] = b;
+                b = dis.readByte();
+                if (i == 49) {
+                    sb.append(new String(bytes, "UTF-8"));
+                    i = -1;
+                    bytes = new byte[MAX_SIZE];
+                }
+            }
+            sb.append(new String(bytes, 0, i + 1, "UTF-8"));
+            return sb.toString();
+        }
+
+        private static final String B64 = "B64:";
+
+        /**
+         * Encode input string
+         *
+         * @param word String
+         * @return String
+         */
+        public static String encodeB64(String word) {
+            try {
+                return B64 + Base64.encodeBase64String(word.getBytes("UTF-8")).replaceAll("(\r|\n)", "");
+            } catch (Exception e) {
+                throw new RuntimeException(e);
+            }
+        }
+
+        /**
+         * Decode input string
+         *
+         * @param word String
+         * @return String
+         */
+
+        public static String decodeB64(String word) {
+            if (word.startsWith(B64)) {
+                String arp = word.replaceFirst(B64, "");
+                try {
+                    return new String(Base64.decodeBase64(arp), "UTF-8");
+                } catch (Exception e) {
+                    throw new RuntimeException(e);
+                }
+            } else
+                return word;
+        }
+    }
 }
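A standalone check (not part of the patch) that makes the byte arithmetic in ReadHelper.getFloat above concrete: the shifts by 0/8/16/24 over bytes[0..3] assemble a little-endian IEEE-754 float, so the result matches ByteBuffer with LITTLE_ENDIAN order.

    import java.nio.ByteBuffer;
    import java.nio.ByteOrder;

    class GetFloatCheck {
        // same arithmetic as ReadHelper.getFloat in the hunk above
        static float getFloat(byte[] b) {
            int accum = 0;
            accum = accum | (b[0] & 0xff) << 0;
            accum = accum | (b[1] & 0xff) << 8;
            accum = accum | (b[2] & 0xff) << 16;
            accum = accum | (b[3] & 0xff) << 24;
            return Float.intBitsToFloat(accum);
        }

        public static void main(String[] args) {
            byte[] bytes = {0x00, 0x00, (byte) 0x80, 0x3f};  // 1.0f in little-endian byte order
            System.out.println(getFloat(bytes));                                                  // 1.0
            System.out.println(ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN).getFloat()); // 1.0
        }
    }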
@@ -24,7 +24,7 @@

     <artifactId>deeplearning4j-aws_2.11</artifactId>
     <name>DeepLearning4j-AWS</name>
-    <version>1.0.0_spark_2-SNAPSHOT</version>
+    <version>1.0.0-SNAPSHOT</version>

     <build>
         <plugins>

@@ -18,7 +18,7 @@
     <parent>
         <artifactId>spark_2.11</artifactId>
         <groupId>org.deeplearning4j</groupId>
-        <version>1.0.0_spark_2-SNAPSHOT</version>
+        <version>1.0.0-SNAPSHOT</version>
     </parent>
     <modelVersion>4.0.0</modelVersion>


@@ -18,7 +18,7 @@
     <parent>
         <artifactId>spark_2.11</artifactId>
         <groupId>org.deeplearning4j</groupId>
-        <version>1.0.0_spark_2-SNAPSHOT</version>
+        <version>1.0.0-SNAPSHOT</version>
     </parent>
     <modelVersion>4.0.0</modelVersion>
     <artifactId>dl4j-spark-nlp_2.11</artifactId>

@@ -19,7 +19,7 @@
     <parent>
         <artifactId>spark_2.11</artifactId>
         <groupId>org.deeplearning4j</groupId>
-        <version>1.0.0_spark_2-SNAPSHOT</version>
+        <version>1.0.0-SNAPSHOT</version>
     </parent>
     <modelVersion>4.0.0</modelVersion>


@@ -18,7 +18,7 @@
     <parent>
         <artifactId>spark_2.11</artifactId>
         <groupId>org.deeplearning4j</groupId>
-        <version>1.0.0_spark_2-SNAPSHOT</version>
+        <version>1.0.0-SNAPSHOT</version>
     </parent>
     <modelVersion>4.0.0</modelVersion>
     <artifactId>dl4j-spark_2.11</artifactId>
@@ -17,7 +17,6 @@
 package org.deeplearning4j.spark;

 import org.apache.spark.serializer.SerializerInstance;
-import org.deeplearning4j.eval.*;
 import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
 import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;

@@ -28,6 +27,9 @@ import org.deeplearning4j.nn.conf.graph.rnn.LastTimeStepVertex;
 import org.deeplearning4j.nn.conf.layers.*;
 import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor;
 import org.junit.Test;
+import org.nd4j.evaluation.IEvaluation;
+import org.nd4j.evaluation.classification.*;
+import org.nd4j.evaluation.regression.RegressionEvaluation;
 import org.nd4j.linalg.activations.Activation;
 import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.linalg.learning.config.Adam;

@@ -19,7 +19,6 @@ package org.deeplearning4j.spark.impl.multilayer;
 import lombok.extern.slf4j.Slf4j;
 import org.apache.spark.api.java.JavaRDD;
 import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
-import org.deeplearning4j.eval.Evaluation;
 import org.deeplearning4j.nn.api.OptimizationAlgorithm;
 import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;

@@ -30,6 +29,7 @@ import org.deeplearning4j.spark.BaseSparkTest;
 import org.deeplearning4j.spark.api.TrainingMaster;
 import org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster;
 import org.junit.Test;
+import org.nd4j.evaluation.classification.Evaluation;
 import org.nd4j.linalg.activations.Activation;
 import org.nd4j.linalg.dataset.DataSet;
 import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
@@ -29,15 +29,13 @@ import org.apache.spark.mllib.regression.LabeledPoint;
 import org.apache.spark.mllib.util.MLUtils;
 import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
 import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
-import org.deeplearning4j.eval.Evaluation;
-import org.deeplearning4j.eval.ROC;
-import org.deeplearning4j.eval.ROCMultiClass;
 import org.deeplearning4j.nn.api.Layer;
 import org.deeplearning4j.nn.api.OptimizationAlgorithm;
 import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
 import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
 import org.deeplearning4j.nn.conf.layers.BaseLayer;
+import org.deeplearning4j.nn.conf.layers.BatchNormalization;
 import org.deeplearning4j.nn.conf.layers.DenseLayer;
 import org.deeplearning4j.nn.conf.layers.OutputLayer;
 import org.deeplearning4j.nn.conf.layers.variational.GaussianReconstructionDistribution;

@@ -56,6 +54,9 @@ import org.junit.Ignore;
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.rules.TemporaryFolder;
+import org.nd4j.evaluation.classification.Evaluation;
+import org.nd4j.evaluation.classification.ROC;
+import org.nd4j.evaluation.classification.ROCMultiClass;
 import org.nd4j.linalg.activations.Activation;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.dataset.DataSet;

@@ -63,6 +64,7 @@ import org.nd4j.linalg.dataset.MultiDataSet;
 import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
 import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.linalg.io.ClassPathResource;
+import org.nd4j.linalg.learning.config.Adam;
 import org.nd4j.linalg.learning.config.IUpdater;
 import org.nd4j.linalg.learning.config.Nesterovs;
 import org.nd4j.linalg.learning.config.RmsProp;
@ -70,7 +72,6 @@ import org.nd4j.linalg.lossfunctions.LossFunctions;
|
||||||
import scala.Tuple2;
|
import scala.Tuple2;
|
||||||
|
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
import java.nio.file.Files;
|
|
||||||
import java.nio.file.Path;
|
import java.nio.file.Path;
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
|
||||||
|
@@ -121,11 +122,6 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
                 new ParameterAveragingTrainingMaster(true, numExecutors(), 1, 5, 1, 0));
 
         MultiLayerNetwork network2 = master.fitLabeledPoint(data);
-        Evaluation evaluation = new Evaluation();
-        evaluation.eval(d.getLabels(), network2.output(d.getFeatures()));
-        System.out.println(evaluation.stats());
-
-
     }
 
 
@@ -137,20 +133,15 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
                 .getAbsolutePath())
                 .toJavaRDD().map(new TestFn());
 
-        DataSet d = new IrisDataSetIterator(150, 150).next();
         MultiLayerConfiguration conf =
                 new NeuralNetConfiguration.Builder().seed(123)
-                        .optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT)
-                        .miniBatch(true).maxNumLineSearchIterations(10)
-                        .list().layer(0,
-                                new DenseLayer.Builder().nIn(4).nOut(100)
-                                        .weightInit(WeightInit.XAVIER)
-                                        .activation(Activation.RELU)
-                                        .build())
-                        .layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(
-                                LossFunctions.LossFunction.MCXENT).nIn(100).nOut(3)
-                                        .activation(Activation.SOFTMAX)
-                                        .weightInit(WeightInit.XAVIER).build())
+                        .updater(new Adam(1e-6))
+                        .weightInit(WeightInit.XAVIER)
+                        .list()
+                        .layer(new BatchNormalization.Builder().nIn(4).nOut(4).build())
+                        .layer(new DenseLayer.Builder().nIn(4).nOut(32).activation(Activation.RELU).build())
+                        .layer(new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(32).nOut(3)
+                                .activation(Activation.SOFTMAX).build())
                 .build();
 
 
@@ -161,10 +152,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
         SparkDl4jMultiLayer master = new SparkDl4jMultiLayer(sc, getBasicConf(),
                 new ParameterAveragingTrainingMaster(true, numExecutors(), 1, 5, 1, 0));
 
-        MultiLayerNetwork network2 = master.fitLabeledPoint(data);
-        Evaluation evaluation = new Evaluation();
-        evaluation.eval(d.getLabels(), network2.output(d.getFeatures()));
-        System.out.println(evaluation.stats());
+        master.fitLabeledPoint(data);
     }
 
     @Test(timeout = 120000L)
@@ -465,8 +453,8 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
         tempDirF.deleteOnExit();
 
         int dataSetObjSize = 1;
-        int batchSizePerExecutor = 25;
-        int numSplits = 10;
+        int batchSizePerExecutor = 16;
+        int numSplits = 5;
         int averagingFrequency = 3;
         int totalExamples = numExecutors() * batchSizePerExecutor * numSplits * averagingFrequency;
         DataSetIterator iter = new MnistDataSetIterator(dataSetObjSize, totalExamples, false);
@@ -22,7 +22,7 @@
     </parent>
     <modelVersion>4.0.0</modelVersion>
     <artifactId>spark_2.11</artifactId>
-    <version>1.0.0_spark_2-SNAPSHOT</version>
+    <version>1.0.0-SNAPSHOT</version>
     <packaging>pom</packaging>
 
     <name>Spark parent</name>
@@ -36,7 +36,7 @@
     <properties>
        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
        <project.reporting.outputEncoding>UTF-8</project.reporting.outputEncoding>
-       <datavec.spark.version>1.0.0_spark_2-SNAPSHOT</datavec.spark.version>
+       <datavec.spark.version>1.0.0-SNAPSHOT</datavec.spark.version>
 
 
        <scala.macros.version>2.1.0</scala.macros.version>
@@ -24,11 +24,13 @@
 #include <map>
 #include <array/ConstantDescriptor.h>
 #include <array/ConstantDataBuffer.h>
+#include <mutex>
 
 namespace nd4j {
     class ConstantHolder {
     private:
         int _deviceId = 0;
+        std::mutex _mutex;
 
         std::map<nd4j::DataType, ConstantDataBuffer> _buffers;
     public:
@@ -53,6 +55,8 @@ namespace nd4j {
 
         template <typename T>
         ConstantDataBuffer* getConstantDataBuffer();
 
+        std::mutex* mutex();
+
     };
 }
 
@@ -16,6 +16,10 @@ namespace nd4j {
         return _buffers.count(dataType) > 0;
     }
 
+    std::mutex* ConstantHolder::mutex() {
+        return &_mutex;
+    }
+
     template <typename T>
     bool ConstantHolder::hasBuffer() {
         return hasBuffer(DataTypeUtils::fromT<T>());
@@ -47,7 +47,7 @@ namespace nd4j {
 
             _currentMutex.unlock();
 
-            setCurrentDevice(globalThreadToDevice);
+            setCurrentNativeDevice(globalThreadToDevice);
         }
 
         // if we already know affinity - just return it
@@ -92,6 +92,8 @@ namespace nd4j {
 
     void AffinityManager::setCurrentNativeDevice(int deviceId) {
         auto res = cudaSetDevice(deviceId);
+        if (res != 0)
+            throw cuda_exception::build("setCurrentDevice failed", res);
     }
 
     void AffinityManager::setCurrentDevice(int deviceId) {
@@ -104,17 +106,22 @@ namespace nd4j {
             res = cudaStreamSynchronize(*LaunchContext::defaultContext()->getCudaSpecialStream());
             if (res != 0)
                 throw cuda_exception::build("setCurrentDevice -> specialSync failed", res);
 
+            if (deviceId != previousDeviceId) {
+                // discard existing stuff
+                nd4j_printf("AffinityManager::setCurrentDevice() was invoked, releasing buffers\n", "");
+                LaunchContext::releaseBuffers();
+            }
         }
 
-        auto res = cudaSetDevice(deviceId);
-        if (res != 0)
-            throw cuda_exception::build("cudaSetDevice failed", res);
+        if (deviceId != previousDeviceId) {
+            auto res = cudaSetDevice(deviceId);
+            if (res != 0)
+                throw cuda_exception::build("cudaSetDevice failed", res);
+        }
 
         // update thread-device affinity
         globalThreadToDevice = deviceId;
 
-        // discard existing stuff
-        LaunchContext::releaseBuffers();
     }
 
     std::atomic<int> AffinityManager::_lastDevice;// = std::atomic<int>(initialV);
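
Note on the hunk above — a minimal standalone sketch of the pattern it adopts (names such as ThreadLocalBuffers and cudaSetDeviceStub are illustrative, not the real AffinityManager/CUDA symbols): per-thread buffers are released only when the target device actually differs from the thread's current one, so repeated calls with the same device stay cheap and never drop live buffers.

    #include <stdexcept>

    struct ThreadLocalBuffers {
        void release() { /* free per-device scratch buffers here */ }
    };

    thread_local int currentDevice = 0;
    thread_local ThreadLocalBuffers buffers;

    int cudaSetDeviceStub(int) { return 0; }   // stand-in for cudaSetDevice

    void setCurrentDevice(int deviceId) {
        if (deviceId != currentDevice) {
            // affinity really changes: old per-device buffers are now invalid
            buffers.release();
            if (cudaSetDeviceStub(deviceId) != 0)
                throw std::runtime_error("cudaSetDevice failed");
        }
        currentDevice = deviceId;
    }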
@@ -107,7 +107,6 @@ namespace nd4j {
 
         //////
         _allocated = false;
-        _initialized = false;
         _deviceId = -1;
 
         this->_specialStream = nullptr;
@@ -116,6 +115,8 @@ namespace nd4j {
             this->_reductionPointer = nullptr;
             this->_scalarPointer = nullptr;
         }
 
+        _initialized = false;
     }
 
     ContextBuffers::~ContextBuffers() {
@@ -163,21 +164,21 @@ namespace nd4j {
     }
 
     void* ContextBuffers::reductionBuffer() {
-        if (_reductionPointer == nullptr)
+        if (!_initialized)
             initialize();
 
         return _reductionPointer;
     }
 
     void* ContextBuffers::scalarBuffer() {
-        if (_scalarPointer == nullptr)
+        if (!_initialized)
             initialize();
 
         return _scalarPointer;
     }
 
     void* ContextBuffers::allocationBuffer() {
-        if (_allocationPointer == nullptr)
+        if (!_initialized)
             initialize();
 
         return _allocationPointer;
@@ -204,15 +205,23 @@ namespace nd4j {
     }
 
     void* ContextBuffers::execStream() {
-        if (_execStream == nullptr)
+        if (!_initialized) {
+            //nd4j_printf("execStream not initialized\n", "");
             initialize();
+        } else {
+            //nd4j_printf("execStream is initialized\n", "");
+        }
 
         return _execStream;
     }
 
     void* ContextBuffers::specialStream() {
-        if (_specialStream == nullptr)
+        if (!_initialized) {
+            //nd4j_printf("specialStream not initialized\n", "");
             initialize();
+        } else {
+            //nd4j_printf("specialStream is initialized\n", "");
+        }
 
         return _specialStream;
     }
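
A sketch of why a single flag replaces the per-pointer null checks above (assumed simplification, not the real ContextBuffers): if release() tears the buffers down, checking each pointer separately can leave some accessors re-initializing while others still hand out stale state; one `_initialized` flag keeps the whole buffer set consistent.

    class Buffers {
        void* _reduction = nullptr;
        void* _scalar = nullptr;
        bool _initialized = false;

        void initialize() {
            _reduction = new char[1024];
            _scalar = new char[64];
            _initialized = true;
        }
    public:
        void* reductionBuffer() {
            if (!_initialized)          // one flag guards every member
                initialize();
            return _reduction;
        }
        void release() {
            delete[] static_cast<char*>(_reduction);
            delete[] static_cast<char*>(_scalar);
            _reduction = _scalar = nullptr;
            _initialized = false;       // next access re-initializes everything
        }
    };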
@@ -57,10 +57,6 @@ LaunchContext::LaunchContext() {
     _deviceID = 0;
 
     _isAllocated = true;
 
-    _cublasHandle = CublasHelper::getInstance()->handle();
-
-    _cusolverHandle = CublasHelper::getInstance()->solver();
 }
 
 LaunchContext::LaunchContext(Nd4jPointer cudaStream, Nd4jPointer reductionPointer, Nd4jPointer scalarPointer, Nd4jPointer allocationPointer) {
@@ -89,13 +85,13 @@ LaunchContext::LaunchContext() {
 
         _contexts.resize(numDevices);
         for (int e = 0; e < numDevices; e++) {
-            AffinityManager::setCurrentDevice(e);
+            AffinityManager::setCurrentNativeDevice(e);
 
             LaunchContext::_contexts[e] = std::make_shared<LaunchContext>();
         }
 
         // don't forget to restore device back again
-        AffinityManager::setCurrentDevice(deviceId);
+        AffinityManager::setCurrentNativeDevice(deviceId);
     }
     _mutex.unlock();
 
@@ -117,11 +113,11 @@ LaunchContext::LaunchContext() {
    };
 
    void* LaunchContext::getCublasHandle() const {
-       return _cublasHandle;
+       return CublasHelper::getInstance()->handle();
    };
 
    void* LaunchContext::getCusolverHandle() const {
-       return _cusolverHandle;
+       return CublasHelper::getInstance()->solver();
    };
 
    cudaStream_t* LaunchContext::getCudaStream() const {
@@ -162,6 +158,7 @@ LaunchContext::LaunchContext() {
    };
 
    void LaunchContext::releaseBuffers() {
+       nd4j_printf("LaunchContext::releaseBuffers() was invoked\n", "");
        contextBuffers.release();
    }
 
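
Sketch of the handle-lookup change above, with illustrative names (Registry, g_registry, currentDeviceId are stand-ins, not libnd4j symbols): instead of caching a cuBLAS handle inside each context, the context asks a per-device registry on every call, so the handle always matches whatever device the calling thread is currently bound to.

    #include <vector>

    struct Registry {
        std::vector<void*> handles;                       // one handle per device, filled at startup
        void* handleForDevice(int deviceId) { return handles.at(deviceId); }
    };

    Registry g_registry;                                  // analogous to CublasHelper::getInstance()
    int currentDeviceId() { return 0; }                   // stand-in for the affinity query

    void* getCublasHandle() {
        return g_registry.handleForDevice(currentDeviceId());
    }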
@@ -38,12 +38,13 @@ namespace nd4j {
         static ConstantHelper* _INSTANCE;
         ConstantHelper();
 
-        std::vector<std::map<ConstantDescriptor, ConstantHolder>> _cache;
+        std::vector<std::map<ConstantDescriptor, ConstantHolder*>> _cache;
 
         // tracking of per-device constant memory buffers (CUDA only atm)
         std::vector<Nd4jPointer> _devicePointers;
         std::vector<Nd4jLong> _deviceOffsets;
         std::mutex _mutex;
+        std::mutex _mutexHolder;
 
         std::vector<Nd4jLong> _counters;
     public:
@@ -48,10 +48,10 @@ namespace nd4j {
         static ConstantShapeHelper* getInstance();
 
 
-        ConstantDataBuffer& bufferForShapeInfo(nd4j::DataType dataType, char order, const std::vector<Nd4jLong> &shape);
-        ConstantDataBuffer& bufferForShapeInfo(const ShapeDescriptor &descriptor);
-        ConstantDataBuffer& bufferForShapeInfo(const Nd4jLong *shapeInfo);
-        ConstantDataBuffer& bufferForShapeInfo(const nd4j::DataType dataType, const char order, const int rank, const Nd4jLong* shape);
+        ConstantDataBuffer bufferForShapeInfo(nd4j::DataType dataType, char order, const std::vector<Nd4jLong> &shape);
+        ConstantDataBuffer bufferForShapeInfo(const ShapeDescriptor &descriptor);
+        ConstantDataBuffer bufferForShapeInfo(const Nd4jLong *shapeInfo);
+        ConstantDataBuffer bufferForShapeInfo(const nd4j::DataType dataType, const char order, const int rank, const Nd4jLong* shape);
 
 
         Nd4jLong* emptyShapeInfo(const nd4j::DataType dataType);
@@ -54,11 +54,11 @@ namespace nd4j {
      * @param keepUnitiesInShape
      * @return
      */
-    TadPack& tadForDimensions(const Nd4jLong *originalShape, const std::vector<int> &dimensions, const bool keepUnitiesInShape = false);
-    TadPack& tadForDimensions(const Nd4jLong *originalShape, int* dimensions, int dimLength, const bool keepUnitiesInShape = false);
-    TadPack& tadForDimensions(const Nd4jLong *originalShape, int dimensions, const bool keepUnitiesInShape = false);
-    TadPack& tadForDimensions(ShapeDescriptor &descriptor, std::vector<int> &dimensions, const bool keepUnitiesInShape = false);
-    TadPack& tadForDimensions(TadDescriptor &descriptor);
+    TadPack tadForDimensions(const Nd4jLong *originalShape, const std::vector<int> &dimensions, const bool keepUnitiesInShape = false);
+    TadPack tadForDimensions(const Nd4jLong *originalShape, int* dimensions, int dimLength, const bool keepUnitiesInShape = false);
+    TadPack tadForDimensions(const Nd4jLong *originalShape, int dimensions, const bool keepUnitiesInShape = false);
+    TadPack tadForDimensions(ShapeDescriptor &descriptor, std::vector<int> &dimensions, const bool keepUnitiesInShape = false);
+    TadPack tadForDimensions(TadDescriptor &descriptor);
 
     /**
      * This method returns number of cached TAD shapes/offsets on specific device
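
The signature change above (return by value instead of by reference) removes a concurrency hazard; a minimal sketch of that hazard, using an illustrative std::map cache rather than the real shape/TAD caches: a reference into a mutex-protected map hands the caller a view of shared state that another thread may modify or rehash after the lock is dropped, whereas returning by value copies the entry while the lock is still held.

    #include <map>
    #include <mutex>
    #include <string>

    std::mutex cacheMutex;
    std::map<int, std::string> cache;

    std::string lookup(int key) {                    // by value: copy made under the lock
        std::lock_guard<std::mutex> lock(cacheMutex);
        return cache[key];
    }

    // const std::string& lookupRef(int key);        // by reference: the caller would keep
    //                                               // reading the map entry after unlock -> data race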
@@ -33,7 +33,8 @@ namespace nd4j {
         _cache.resize(numDevices);
         _counters.resize(numDevices);
         for (int e = 0; e < numDevices; e++) {
-            std::map<ConstantDescriptor, ConstantHolder> map;
+            std::map<ConstantDescriptor, ConstantHolder*> map;
+
             _cache[e] = map;
             _counters[e] = 0L;
         }
@@ -70,15 +71,26 @@ namespace nd4j {
     ConstantDataBuffer* ConstantHelper::constantBuffer(const ConstantDescriptor &descriptor, nd4j::DataType dataType) {
         const auto deviceId = getCurrentDevice();
 
+        // we're locking away cache modification
+        _mutexHolder.lock();
+
         if (_cache[deviceId].count(descriptor) == 0) {
-            ConstantHolder holder;
-            _cache[deviceId][descriptor] = holder;
+            _cache[deviceId][descriptor] = new ConstantHolder();
         }
 
-        ConstantHolder* holder = &_cache[deviceId][descriptor];
+        auto holder = _cache[deviceId][descriptor];
+
+        // releasing cache lock
+        _mutexHolder.unlock();
+
+        ConstantDataBuffer* result;
+
+        // access to this holder instance is synchronous
+        holder->mutex()->lock();
 
         if (holder->hasBuffer(dataType))
-            return holder->getConstantDataBuffer(dataType);
+            result = holder->getConstantDataBuffer(dataType);
         else {
             auto size = descriptor.length() * DataTypeUtils::sizeOf(dataType);
             auto cbuff = new int8_t[size];
@@ -94,8 +106,11 @@ namespace nd4j {
             ConstantDataBuffer dataBuffer(cbuff, nullptr, descriptor.length(), DataTypeUtils::sizeOf(dataType));
             holder->addBuffer(dataBuffer, dataType);
 
-            return holder->getConstantDataBuffer(dataType);
+            result = holder->getConstantDataBuffer(dataType);
         }
+        holder->mutex()->unlock();
+
+        return result;
     }
 
     Nd4jLong ConstantHelper::getCachedAmount(int deviceId) {
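
A compact sketch of the two-level locking scheme used above (Holder, holderFor and the int descriptor key are illustrative stand-ins, not the real ConstantDescriptor/ConstantHolder types): a short-lived lock protects only the descriptor-to-holder map, and a per-holder mutex serializes buffer construction, so threads asking for different descriptors no longer block each other while a buffer is being built.

    #include <map>
    #include <memory>
    #include <mutex>

    struct Holder {
        std::mutex m;
        bool built = false;
        void buildIfNeeded() {
            std::lock_guard<std::mutex> lock(m);   // per-entry lock, may cover slow work
            if (!built) { /* allocate and fill the constant buffer */ built = true; }
        }
    };

    std::mutex mapMutex;
    std::map<int, std::unique_ptr<Holder>> holders;

    Holder* holderFor(int descriptor) {
        std::lock_guard<std::mutex> lock(mapMutex); // held only for the map access
        auto& slot = holders[descriptor];
        if (!slot) slot = std::make_unique<Holder>();
        return slot.get();
    }

    void constantBuffer(int descriptor) {
        holderFor(descriptor)->buildIfNeeded();     // heavy work happens outside the map lock
    }

Storing holders by pointer (as the diff does) also keeps a holder's address stable while other threads insert into the map.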
@@ -41,18 +41,18 @@ namespace nd4j {
         return _INSTANCE;
     }
 
-    ConstantDataBuffer& ConstantShapeHelper::bufferForShapeInfo(nd4j::DataType dataType, char order, const std::vector<Nd4jLong> &shape) {
+    ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo(nd4j::DataType dataType, char order, const std::vector<Nd4jLong> &shape) {
         ShapeDescriptor descriptor(dataType, order, shape);
         return bufferForShapeInfo(descriptor);
     }
 
-    ConstantDataBuffer& ConstantShapeHelper::bufferForShapeInfo(const nd4j::DataType dataType, const char order, const int rank, const Nd4jLong* shape) {
+    ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo(const nd4j::DataType dataType, const char order, const int rank, const Nd4jLong* shape) {
         ShapeDescriptor descriptor(dataType, order, shape, rank);
         return bufferForShapeInfo(descriptor);
     }
 
-    ConstantDataBuffer& ConstantShapeHelper::bufferForShapeInfo(const ShapeDescriptor &descriptor) {
+    ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo(const ShapeDescriptor &descriptor) {
         int deviceId = 0;
 
         _mutex.lock();
@@ -62,19 +62,19 @@ namespace nd4j {
             ConstantDataBuffer buffer(hPtr, nullptr, shape::shapeInfoLength(hPtr)*sizeof(Nd4jLong), DataType::INT64);
             ShapeDescriptor descriptor1(descriptor);
             _cache[deviceId][descriptor1] = buffer;
-            ConstantDataBuffer &r = _cache[deviceId][descriptor1];
+            auto r = _cache[deviceId][descriptor1];
             _mutex.unlock();
 
             return r;
         } else {
-            ConstantDataBuffer &r = _cache[deviceId].at(descriptor);
+            auto r = _cache[deviceId].at(descriptor);
             _mutex.unlock();
 
             return r;
         }
     }
 
-    ConstantDataBuffer& ConstantShapeHelper::bufferForShapeInfo(const Nd4jLong *shapeInfo) {
+    ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo(const Nd4jLong *shapeInfo) {
         ShapeDescriptor descriptor(shapeInfo);
         return bufferForShapeInfo(descriptor);
     }
@@ -38,25 +38,25 @@ namespace nd4j {
         return _INSTANCE;
     }
 
-    TadPack& ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int dimension, const bool keepUnitiesInShape) {
+    TadPack ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int dimension, const bool keepUnitiesInShape) {
         return tadForDimensions(originalShape, &dimension, 1, keepUnitiesInShape);
     }
 
-    TadPack& ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, const std::vector<int> &dimensions, const bool keepUnitiesInShape) {
+    TadPack ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, const std::vector<int> &dimensions, const bool keepUnitiesInShape) {
         return tadForDimensions(originalShape, const_cast<int *>(dimensions.data()), dimensions.size(), keepUnitiesInShape);
     }
 
-    TadPack& ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int* dimensions, int dimLength, const bool keepUnitiesInShape) {
+    TadPack ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int* dimensions, int dimLength, const bool keepUnitiesInShape) {
         TadDescriptor tadDescriptor(originalShape, dimensions, dimLength, keepUnitiesInShape);
         return tadForDimensions(tadDescriptor);
     }
 
-    TadPack& ConstantTadHelper::tadForDimensions(ShapeDescriptor &descriptor, std::vector<int> &dimensions, const bool keepUnitiesInShape) {
+    TadPack ConstantTadHelper::tadForDimensions(ShapeDescriptor &descriptor, std::vector<int> &dimensions, const bool keepUnitiesInShape) {
         TadDescriptor tadDescriptor(descriptor, dimensions, keepUnitiesInShape);
         return tadForDimensions(tadDescriptor);
     }
 
-    TadPack& ConstantTadHelper::tadForDimensions(TadDescriptor &descriptor) {
+    TadPack ConstantTadHelper::tadForDimensions(TadDescriptor &descriptor) {
         const int deviceId = 0;
 
         _mutex.lock();
@@ -105,7 +105,7 @@ namespace nd4j {
 
             return r;
         } else {
-            TadPack &r = _cache[deviceId][descriptor];
+            TadPack r = _cache[deviceId][descriptor];
             _mutex.unlock();
 
             return r;
@@ -24,11 +24,13 @@
 #include <dll.h>
 #include <pointercast.h>
 #include <vector>
+#include <mutex>
 
 namespace nd4j {
     class CublasHelper {
     private:
         static CublasHelper *_INSTANCE;
+        static std::mutex _mutex;
 
         std::vector<void*> _cache;
         std::vector<void*> _solvers;
@@ -68,7 +68,7 @@ namespace nd4j {
                 throw cuda_exception::build("cudaSetDevice failed", res);
             auto constant = getConstantSpace();
 
-            std::map<ConstantDescriptor, ConstantHolder> devCache;
+            std::map<ConstantDescriptor, ConstantHolder*> devCache;
 
             _devicePointers[e] = constant;
             _deviceOffsets[e] = 0;
@@ -136,15 +136,24 @@ namespace nd4j {
     ConstantDataBuffer* ConstantHelper::constantBuffer(const ConstantDescriptor &descriptor, nd4j::DataType dataType) {
         const auto deviceId = getCurrentDevice();
 
-        if (_cache[deviceId].count(descriptor) == 0) {
-            ConstantHolder holder;
-            _cache[deviceId][descriptor] = holder;
-        }
+        // all cache modifications are synchronous
+        _mutexHolder.lock();
 
-        ConstantHolder* holder = &_cache[deviceId][descriptor];
+        if (_cache[deviceId].count(descriptor) == 0) {
+            _cache[deviceId][descriptor] = new ConstantHolder();
+        }
+        auto holder = _cache[deviceId][descriptor];
 
+        // release cache lock
+        _mutexHolder.unlock();
+
+        ConstantDataBuffer* result;
+
+        // access to this holder instance is synchronous
+        holder->mutex()->lock();
 
         if (holder->hasBuffer(dataType)) {
-            return holder->getConstantDataBuffer(dataType);
+            result = holder->getConstantDataBuffer(dataType);
         } else {
             auto numBytes = descriptor.length() * DataTypeUtils::sizeOf(dataType);
             auto cbuff = new int8_t[numBytes];
@@ -160,10 +169,14 @@ namespace nd4j {
             auto dbuff = replicatePointer(cbuff, descriptor.length() * DataTypeUtils::sizeOf(dataType));
 
             ConstantDataBuffer dataBuffer(cbuff, dbuff, descriptor.length(), DataTypeUtils::sizeOf(dataType));
-            holder->addBuffer(dataBuffer, dataType);
 
-            return holder->getConstantDataBuffer(dataType);
+            holder->addBuffer(dataBuffer, dataType);
+            result = holder->getConstantDataBuffer(dataType);
         }
+        // release holder lock
+        holder->mutex()->unlock();
 
+        return result;
     }
 
     Nd4jLong ConstantHelper::getCachedAmount(int deviceId) {
@@ -44,17 +44,17 @@ namespace nd4j {
         return _INSTANCE;
     }
 
-    ConstantDataBuffer& ConstantShapeHelper::bufferForShapeInfo(nd4j::DataType dataType, char order, const std::vector<Nd4jLong> &shape) {
+    ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo(nd4j::DataType dataType, char order, const std::vector<Nd4jLong> &shape) {
         ShapeDescriptor descriptor(dataType, order, shape);
         return bufferForShapeInfo(descriptor);
     }
 
-    ConstantDataBuffer& ConstantShapeHelper::bufferForShapeInfo(const nd4j::DataType dataType, const char order, const int rank, const Nd4jLong* shape) {
+    ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo(const nd4j::DataType dataType, const char order, const int rank, const Nd4jLong* shape) {
         ShapeDescriptor descriptor(dataType, order, shape, rank);
         return bufferForShapeInfo(descriptor);
     }
 
-    ConstantDataBuffer& ConstantShapeHelper::bufferForShapeInfo(const ShapeDescriptor &descriptor) {
+    ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo(const ShapeDescriptor &descriptor) {
         int deviceId = AffinityManager::currentDeviceId();
 
         _mutex.lock();
@@ -65,19 +65,19 @@ namespace nd4j {
             ConstantDataBuffer buffer(hPtr, dPtr, shape::shapeInfoLength(hPtr) * sizeof(Nd4jLong), DataType::INT64);
             ShapeDescriptor descriptor1(descriptor);
             _cache[deviceId][descriptor1] = buffer;
-            ConstantDataBuffer &r = _cache[deviceId][descriptor1];
+            auto r = _cache[deviceId][descriptor1];
             _mutex.unlock();
 
             return r;
         } else {
-            ConstantDataBuffer &r = _cache[deviceId].at(descriptor);
+            ConstantDataBuffer r = _cache[deviceId].at(descriptor);
             _mutex.unlock();
 
             return r;
         }
     }
 
-    ConstantDataBuffer& ConstantShapeHelper::bufferForShapeInfo(const Nd4jLong *shapeInfo) {
+    ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo(const Nd4jLong *shapeInfo) {
         ShapeDescriptor descriptor(shapeInfo);
         return bufferForShapeInfo(descriptor);
     }
@@ -43,25 +43,25 @@ namespace nd4j {
         return _INSTANCE;
     }
 
-    TadPack& ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int dimension, const bool keepUnitiesInShape) {
+    TadPack ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int dimension, const bool keepUnitiesInShape) {
         return tadForDimensions(originalShape, &dimension, 1, keepUnitiesInShape);
     }
 
-    TadPack& ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, const std::vector<int> &dimensions, const bool keepUnitiesInShape) {
+    TadPack ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, const std::vector<int> &dimensions, const bool keepUnitiesInShape) {
         return tadForDimensions(originalShape, const_cast<int *>(dimensions.data()), dimensions.size(), keepUnitiesInShape);
     }
 
-    TadPack& ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int* dimensions, int dimLength, const bool keepUnitiesInShape) {
+    TadPack ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int* dimensions, int dimLength, const bool keepUnitiesInShape) {
         TadDescriptor tadDescriptor(originalShape, dimensions, dimLength, keepUnitiesInShape);
         return tadForDimensions(tadDescriptor);
     }
 
-    TadPack& ConstantTadHelper::tadForDimensions(ShapeDescriptor &descriptor, std::vector<int> &dimensions, const bool keepUnitiesInShape) {
+    TadPack ConstantTadHelper::tadForDimensions(ShapeDescriptor &descriptor, std::vector<int> &dimensions, const bool keepUnitiesInShape) {
         TadDescriptor tadDescriptor(descriptor, dimensions, keepUnitiesInShape);
         return tadForDimensions(tadDescriptor);
     }
 
-    TadPack& ConstantTadHelper::tadForDimensions(TadDescriptor &descriptor) {
+    TadPack ConstantTadHelper::tadForDimensions(TadDescriptor &descriptor) {
         const int deviceId = AffinityManager::currentDeviceId();
 
         _mutex.lock();
@@ -96,14 +96,14 @@ namespace nd4j {
             TadPack t(shapesBuffer, offsetsBuffer, numOfSubArrs);
             _cache[deviceId][descriptor] = t;
 
-            TadPack &r = _cache[deviceId][descriptor];
+            TadPack r = _cache[deviceId][descriptor];
             _mutex.unlock();
 
             delete[] shapeInfo;
 
             return r;
         } else {
-            TadPack &r = _cache[deviceId][descriptor];
+            TadPack r = _cache[deviceId][descriptor];
             _mutex.unlock();
 
             return r;
@@ -27,6 +27,7 @@
 #include <execution/AffinityManager.h>
 
 namespace nd4j {
+    std::mutex CublasHelper::_mutex;
 
     static void* handle_() {
         auto _handle = new cublasHandle_t();
@@ -56,22 +57,24 @@ namespace nd4j {
     }
 
     CublasHelper::CublasHelper() {
+        //nd4j_printf("Initializing cuBLAS\n","");
         auto numDevices = AffinityManager::numberOfDevices();
         auto currentDevice = AffinityManager::currentDeviceId();
         _cache.resize(numDevices);
         _solvers.resize(numDevices);
         for (int e = 0; e < numDevices; e++) {
-            AffinityManager::setCurrentDevice(e);
+            AffinityManager::setCurrentNativeDevice(e);
 
             _cache[e] = handle_();
             _solvers[e] = solver_();
         }
 
         // don't forget to restore back original device
-        AffinityManager::setCurrentDevice(currentDevice);
+        AffinityManager::setCurrentNativeDevice(currentDevice);
     }
 
     CublasHelper::~CublasHelper() {
+        nd4j_printf("Releasing cuBLAS\n","");
         auto numDevices = AffinityManager::numberOfDevices();
 
         for (int e = 0; e < numDevices; e++)
@@ -79,8 +82,10 @@ namespace nd4j {
     }
 
     CublasHelper* CublasHelper::getInstance() {
+        _mutex.lock();
        if (!_INSTANCE)
            _INSTANCE = new nd4j::CublasHelper();
+        _mutex.unlock();
 
        return _INSTANCE;
     }
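
For reference, the same lazy-singleton race can also be closed without an explicit mutex; this is a generic C++11 sketch, not the variant the diff uses: a function-local static is initialized exactly once and thread-safely by the runtime, which is often the simpler choice when the singleton has no ordering constraints against other global state.

    class Helper {
    public:
        static Helper& getInstance() {
            static Helper instance;   // initialized once, thread-safe since C++11
            return instance;
        }
    private:
        Helper() = default;
    };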
@@ -162,9 +162,6 @@ namespace functions {
                 }
             }
         }
-
-        // update rng state
-        rng->rewindH(length);
     };
 
 
@@ -223,8 +220,6 @@ namespace functions {
                 }
             }
         }
-        // update rng state
-        rng->rewindH(length);
     }
 
 
@@ -256,9 +251,6 @@ namespace functions {
                 z[offset] = OpClass::op(i+threadOffset, length, rng, extraArguments);
             }
         }
-
-        // update rng state
-        rng->rewindH(length);
     }
 
     template<typename X>
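
These removals implement the "fix double consumption of rng on cpu" item from the commit message: the caller already advances the generator after the loop, so advancing it again inside the helper moved the state twice per op. A minimal sketch of the bug class, with an illustrative counter rather than the real RandomGenerator:

    #include <cassert>
    #include <cstdint>

    struct Rng {
        uint64_t position = 0;
        void rewind(uint64_t consumed) { position += consumed; }   // advance the stream
    };

    void helperWithoutRewind(Rng& rng, uint64_t length) {
        (void)rng; (void)length;
        // ... generate `length` values starting at rng.position ...
        // the state advance is left to the caller
    }

    int main() {
        Rng rng;
        const uint64_t length = 100;
        helperWithoutRewind(rng, length);
        rng.rewind(length);               // advanced exactly once
        assert(rng.position == length);   // with a second rewind inside the helper this would be 200
    }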
@@ -15,7 +15,7 @@
 ******************************************************************************/
 
 //
-// @author Yurii Shyrma (iuriish@yahoo.com), created on 07.12.2017
+// @author Yurii Shyrma (iuriish@yahoo.com)
 //
 
 #include <op_boilerplate.h>
@@ -38,10 +38,9 @@ CONFIGURABLE_OP_IMPL(matrix_set_diag, 2, 1, false, 0, 0) {
     for(int i = 0; i < diagonal->rankOf() - 1; ++i)
         REQUIRE_TRUE(diagonal->sizeAt(i) == input->sizeAt(i), 0, "MATRIX_SET_DIAG op: the shapes of diagonal and input arrays must be equal till last diagonal dimension but one, however got diagonal=%s and input=%s instead !", ShapeUtils::shapeAsString(diagonal).c_str(), ShapeUtils::shapeAsString(input).c_str());
 
-    REQUIRE_TRUE(diagonal->sizeAt(-1) == (int)nd4j::math::nd4j_min<Nd4jLong>(input->sizeAt(-1), input->sizeAt(-2)),
-                 0, "MATRIX_SET_DIAG op: the value of last dimension of diagonal array must be equal to min(input_last_shape=%i, input_last_but_one_shape=%i), but got %i instead !", input->sizeAt(-1), input->sizeAt(-2), diagonal->sizeAt(-1));
+    REQUIRE_TRUE(diagonal->sizeAt(-1) == (int)nd4j::math::nd4j_min<Nd4jLong>(input->sizeAt(-1), input->sizeAt(-2)), 0, "MATRIX_SET_DIAG op: the value of last dimension of diagonal array must be equal to min(input_last_shape=%i, input_last_but_one_shape=%i), but got %i instead !", input->sizeAt(-1), input->sizeAt(-2), diagonal->sizeAt(-1));
 
-    helpers::matrixSetDiag(block.launchContext(), input, diagonal, output);
+    helpers::matrixSetDiag(block.launchContext(), *input, *diagonal, *output, false);
 
     return Status::OK();
 }
@@ -15,49 +15,53 @@
 ******************************************************************************/
 
 //
-// Created to use with batched tensor by GS <sgazeos@gmail.com> 3/21/2018
+// @author GS <sgazeos@gmail.com> 3/21/2018
+// @author Yurii Shyrma (iuriish@yahoo.com)
 //
 
 #include <ops/declarable/CustomOperations.h>
-#include <ops/declarable/helpers/matrix_diag.h>
+#include <ops/declarable/helpers/matrixSetDiag.h>
 
 namespace nd4j {
 namespace ops {
 
 CUSTOM_OP_IMPL(matrix_diag, 1, 1, false, 0, 0) {
 
-    auto input = INPUT_VARIABLE(0);
+    auto diagonal = INPUT_VARIABLE(0);
     auto output = OUTPUT_VARIABLE(0);
 
-    REQUIRE_TRUE(!input->isScalar(), 0, "CUSTOM_OP matrix_diag: input array must be at list a vector, but scalar was given!");
+    REQUIRE_TRUE(!diagonal->isScalar(), 0, "CUSTOM_OP matrix_diag: input diagonal array must be at list a vector, but scalar was given!");
 
-    output->nullify();
-    return helpers::matrixDiag(block.launchContext(), input, output);
+    helpers::matrixSetDiag(block.launchContext(), *output, *diagonal, *output, true);
+
+    return Status::OK();
 }
 
 DECLARE_SHAPE_FN(matrix_diag) {
 
     Nd4jLong* outShapeInfo = nullptr;
     auto in = inputShape->at(0);
     int inRank = shape::rank(in);
 
     int outRank = inRank + 1;
     auto lastDimension = shape::sizeAt(in, -1);
 
     ALLOCATE(outShapeInfo, block.getWorkspace(), shape::shapeInfoLength(outRank), Nd4jLong);
     outShapeInfo[0] = outRank;
     for(int i = 0; i < inRank; ++i)
         outShapeInfo[i + 1] = shape::sizeAt(in, i);
     outShapeInfo[outRank] = lastDimension;
 
     ShapeUtils::updateStridesAndType(outShapeInfo, in, shape::order(in));
 
     return SHAPELIST(CONSTANT(outShapeInfo));
 }
 
 DECLARE_TYPES(matrix_diag) {
     getOpDescriptor()
         ->setAllowedInputTypes(nd4j::DataType::ANY)
         ->setSameMode(true);
 }
 }
 }
@@ -76,8 +76,20 @@ namespace nd4j {
         #endif
 
        /**
-        * Returns a batched matrix tensor with new batched diagonal values.
-        */
+        * Inserts elements provided by diagonal array into the main diagonal of innermost matrices of input array
+        *
+        * Input arrays:
+        *   input: input array, considered as batch of matrices
+        *   diagonal: array containing elements to be inserted into input array,
+        *             following rank condition should be satisfied: diagonal_rank = input_rank - 1,
+        *             the shapes of diagonal and input arrays must be equal except last dimension of input array,
+        *             for example if input_shape = [A,B,C,D] then diagonal_shape = [A,B,C],
+        *             also last dimension of diagonal array should be equal to smaller of last and last but one input dimensions
+        *             that is: diagonal_shape[-1] = min(input_shape[-1], input_shape[-2])
+        *
+        * Output array:
+        *   has the same shape as input, corresponding diagonal elements are substituted
+        */
         #if NOT_EXCLUDED(OP_matrix_set_diag)
         DECLARE_CONFIGURABLE_OP(matrix_set_diag, 2, 1, false, 0, 0);
         #endif
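
A tiny worked example of the documented shape rule, independent of the library API: for an input of shape [2,3] the diagonal must have shape [2] (min(2,3) = 2), and only positions (0,0) and (1,1) of the innermost matrix are replaced.

    #include <cstdio>

    int main() {
        float input[2][3] = {{1, 2, 3}, {4, 5, 6}};
        float diagonal[2] = {9, 8};

        for (int r = 0; r < 2; ++r)
            input[r][r] = diagonal[r];        // main diagonal of the innermost matrix

        for (int r = 0; r < 2; ++r)
            std::printf("%g %g %g\n", input[r][0], input[r][1], input[r][2]);
        // prints:  9 2 3
        //          4 8 6
    }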
@@ -2411,7 +2411,7 @@ void ConvolutionUtils::getMKLDNNMemoryDescConv3d(
                             for (Nd4jLong kd = dstart; kd < dend; kd += iStep2)
                                 for (Nd4jLong kh = hstart; kh < hend; kh += iStep3)
                                     for (Nd4jLong kw = wstart; kw < wend; kw += iStep4)
-                                        pgI[kd + kh + kw] += valO * nd4j::math::nd4j_pow<T,T,T>(nd4j::math::nd4j_abs<T>(pIn[kd + kh + kw]), extraParam0 - (T)1.f);
+                                        pgI[kd + kh + kw] += valO * nd4j::math::nd4j_pow<T,T,T>(nd4j::math::nd4j_abs<T>(pIn[kd + kh + kw]), extraParam0 - (T)1.f) * nd4j::math::nd4j_sgn<T,T>(pIn[kd + kh + kw]);
                         }
                         else {
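
The added nd4j_sgn factor matches the textbook derivative of p-norm pooling; as a check, for one pooling window with exponent p (extraParam0 in the kernels):

    \[
    y = \Big(\sum_j |x_j|^p\Big)^{1/p},
    \qquad
    \frac{\partial y}{\partial x_i}
      = |x_i|^{\,p-1}\,\operatorname{sgn}(x_i)\,\Big(\sum_j |x_j|^p\Big)^{\frac{1-p}{p}} .
    \]

The previous code dropped the sgn(x_i) term, so gradients flowing through negative activations had the wrong sign; the CPU path above and the CUDA pooling2dBPCuda/pooling3dBPCuda kernels further down now agree with this formula.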
@@ -15,7 +15,7 @@
 ******************************************************************************/
 
 //
-// Created by Yurii Shyrma on 07.12.2017.
+// @author Yurii Shyrma (iuriish@yahoo.com)
 //
 
 #include "ResultSet.h"
@@ -27,31 +27,48 @@ namespace helpers {
 
 
 //////////////////////////////////////////////////////////////////////////
-// Returns a batched matrix tensor with new batched diagonal values.
-// for detailed explanations please take a look on web page: https://www.tensorflow.org/api_docs/python/tf/matrix_set_diag
-template <typename T>
-static void _matrixSetDiag(const NDArray* input, const NDArray* diagonal, NDArray* output) {
-
-    *output = *input;
-
-    const int lastDimSize = input->sizeAt(-1);
-    const int last2DimSize = input->sizeAt(-1) * input->sizeAt(-2);
-    const int lastSmallDim = diagonal->sizeAt(-1);
-    const int batchSize = input->lengthOf()/last2DimSize;
-
-    for(int i = 0; i < batchSize; ++i )
-        for(int j = 0; j < lastSmallDim; ++j) {
-            output->p(i*last2DimSize + j*(lastDimSize + 1), diagonal->e<T>(i*lastSmallDim + j));
-        }
-}
-
-void matrixSetDiag(nd4j::LaunchContext * context, const NDArray* input, const NDArray* diagonal, NDArray* output) {
-    BUILD_SINGLE_SELECTOR(input->dataType(), _matrixSetDiag, (input, diagonal, output), LIBND4J_TYPES);
-}
-
-BUILD_SINGLE_TEMPLATE(template void _matrixSetDiag, (const NDArray* input, const NDArray* diagonal, NDArray* output), LIBND4J_TYPES);
+template<typename T>
+void matrixSetDiag_(const NDArray& input, const NDArray& diagonal, NDArray& output, const bool zeroPad) {
+
+    // input and output are the same array (x == z) when zeroPad = true
+    // xRank = zRank, xRank = yRank + 1
+    // xLen = zLen
+
+    const T* x = input.bufferAsT<T>();
+    const T* y = diagonal.bufferAsT<T>();
+    T* z = output.bufferAsT<T>();
+
+    const Nd4jLong* xShapeInfo = input.getShapeInfo();
+    const Nd4jLong* yShapeInfo = diagonal.getShapeInfo();
+    const Nd4jLong* zShapeInfo = output.getShapeInfo();
+
+    const bool areSameOffsets = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo);    // shapes are definitely the same, but strides might not
+
+    const int  xRank = input.rankOf();
+    const auto xLen  = input.lengthOf();
+
+    std::vector<Nd4jLong> coords(xRank);  // we use the same coordinates storage both for input and output since their ranks are the same
+
+    PRAGMA_OMP_PARALLEL_FOR_ARGS(firstprivate(coords))
+    for (Nd4jLong i = 0; i < xLen; ++i) {
+
+        shape::index2coords(xRank, xShapeInfo + 1, i, xLen, coords.data());
+
+        const auto xOffset = shape::getOffset(0, xShapeInfo + 1, xShapeInfo + xRank + 1, coords.data(), xRank);
+        const auto zOffset = areSameOffsets ? xOffset : shape::getOffset(0, zShapeInfo + 1, zShapeInfo + xRank + 1, coords.data(), xRank);
+
+        // condition to be on diagonal of innermost matrix
+        if(coords[xRank - 2] == coords[xRank - 1])
+            z[zOffset] = y[shape::getOffset(0, yShapeInfo + 1, yShapeInfo + xRank, coords.data(), xRank - 1)];
+        else
+            z[zOffset] = zeroPad ? static_cast<T>(0) : x[xOffset];
+    }
+}
+
+//////////////////////////////////////////////////////////////////////////
+void matrixSetDiag(nd4j::LaunchContext* context, const NDArray& input, const NDArray& diagonal, NDArray& output, const bool zeroPad) {
+    BUILD_SINGLE_SELECTOR(input.dataType(), matrixSetDiag_, (input, diagonal, output, zeroPad), LIBND4J_TYPES);
+}
 
 }
 }
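
The same traversal idea, sketched on a plain row-major buffer so the logic is visible without the NDArray/shapeInfo machinery (setDiagBatch and its layout assumptions are illustrative, not a library function): every element is visited exactly once, and its (row, col) coordinates decide whether it takes the diagonal value, the original value, or zero when matrix_diag-style padding is requested.

    #include <algorithm>
    #include <vector>

    void setDiagBatch(const std::vector<float>& x, const std::vector<float>& diag,
                      std::vector<float>& z, int batch, int rows, int cols, bool zeroPad) {
        for (int b = 0; b < batch; ++b)
            for (int r = 0; r < rows; ++r)
                for (int c = 0; c < cols; ++c) {
                    const int idx = (b * rows + r) * cols + c;          // row-major offset
                    if (r == c)
                        z[idx] = diag[b * std::min(rows, cols) + r];    // on the main diagonal
                    else
                        z[idx] = zeroPad ? 0.0f : x[idx];               // off-diagonal: copy or zero
                }
    }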
@@ -1,65 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-//
-// Created by GS <sgazeos@gmail.com> on 3/21/2018.
-//
-
-#include "ResultSet.h"
-#include <ops/declarable/helpers/matrix_diag.h>
-#include <Status.h>
-
-namespace nd4j {
-namespace ops {
-namespace helpers {
-
-
-//////////////////////////////////////////////////////////////////////////
-// Returns a batched matrix tensor with new batched diagonal values.
-// for detailed explanations please take a look on web page: https://www.tensorflow.org/api_docs/python/tf/matrix_set_diag
-template <typename T>
-static int _matrixDiag(const NDArray* input, NDArray* output) {
-
-    auto listOut  = output->allTensorsAlongDimension({output->rankOf() - 2, output->rankOf() - 1});
-    auto listDiag = input->allTensorsAlongDimension({input->rankOf() - 1});
-
-    if (listOut->size() != listDiag->size()) {
-        nd4j_printf("matrix_diag: Input matrix has wrong shape.", "");
-        return ND4J_STATUS_VALIDATION;
-    }
-    int lastDimension = input->sizeAt(-1);
-    // TODO: tune this properlys
-    int lO = listOut->size();
-    PRAGMA_OMP_PARALLEL_FOR_IF(lO > Environment::getInstance()->tadThreshold())
-    for(int i = 0; i < lO; ++i)
-        for (int e = 0; e < lastDimension; e++)
-            listOut->at(i)->p(e, e, listDiag->at(i)->e<T>(e));
-
-    delete listOut;
-    delete listDiag;
-
-    return Status::OK();
-}
-
-int matrixDiag(nd4j::LaunchContext * context, const NDArray* input, NDArray* output) {
-    BUILD_SINGLE_SELECTOR(input->dataType(), return _matrixDiag, (input, output), LIBND4J_TYPES);
-}
-
-BUILD_SINGLE_TEMPLATE(template int _matrixDiag, (const NDArray* input, NDArray* output), LIBND4J_TYPES);
-
-}
-}
-}
@@ -957,9 +957,13 @@ __global__ static void pooling2dBPCuda(const void* vx, const Nd4jLong* xShapeInf

                 val *= nd4j::math::nd4j_pow<T,T,T>(sum, ((T)1.f - extraParam0) / extraParam0);

-                for (coords[2] = hstart; coords[2] < hend; coords[2] += dH)
-                    for (coords[3] = wstart; coords[3] < wend; coords[3] += dW)
-                        nd4j::math::atomics::nd4j_atomicAdd<T>(&z[shape::getOffset(0, zShapeInfo + 1, zShapeInfo + rank + 1, coords, rank)], val * nd4j::math::nd4j_pow<T,T,T>(nd4j::math::nd4j_abs<T>(x[shape::getOffset(0, xShapeInfo + 1, xShapeInfo + rank + 1, coords, rank)]), extraParam0 - 1.f));
+                for (coords[2] = hstart; coords[2] < hend; coords[2] += dH) {
+                    for (coords[3] = wstart; coords[3] < wend; coords[3] += dW) {
+                        const auto xOffset = shape::getOffset(0, xShapeInfo + 1, xShapeInfo + rank + 1, coords, rank);
+                        const auto zOffset = shape::getOffset(0, zShapeInfo + 1, zShapeInfo + rank + 1, coords, rank);
+                        nd4j::math::atomics::nd4j_atomicAdd<T>(&z[zOffset], val * nd4j::math::nd4j_pow<T,T,T>(nd4j::math::nd4j_abs<T>(x[xOffset]), extraParam0 - 1.f) * nd4j::math::nd4j_sgn<T,T>(x[xOffset]));
+                    }
+                }
                 break;
             }

@@ -1123,10 +1127,15 @@ __global__ static void pooling3dBPCuda(const void* vx, const Nd4jLong* xShapeInf

                 val *= nd4j::math::nd4j_pow<T,T,T>(sum, ((T)1.f - extraParam0) / extraParam0);

-                for (coords[2] = dstart; coords[2] < dend; coords[2] += dD)
-                    for (coords[3] = hstart; coords[3] < hend; coords[3] += dH)
-                        for (coords[4] = wstart; coords[4] < wend; coords[4] += dW)
-                            nd4j::math::atomics::nd4j_atomicAdd<T>(&z[shape::getOffset(0, zShapeInfo + 1, zShapeInfo + rank + 1, coords, rank)], val * nd4j::math::nd4j_pow<T,T,T>(nd4j::math::nd4j_abs<T>(x[shape::getOffset(0, xShapeInfo + 1, xShapeInfo + rank + 1, coords, rank)]), extraParam0 - 1.f));
+                for (coords[2] = dstart; coords[2] < dend; coords[2] += dD) {
+                    for (coords[3] = hstart; coords[3] < hend; coords[3] += dH) {
+                        for (coords[4] = wstart; coords[4] < wend; coords[4] += dW) {
+                            const auto xOffset = shape::getOffset(0, xShapeInfo + 1, xShapeInfo + rank + 1, coords, rank);
+                            const auto zOffset = shape::getOffset(0, zShapeInfo + 1, zShapeInfo + rank + 1, coords, rank);
+                            nd4j::math::atomics::nd4j_atomicAdd<T>(&z[zOffset], val * nd4j::math::nd4j_pow<T,T,T>(nd4j::math::nd4j_abs<T>(x[xOffset]), extraParam0 - 1.f) * nd4j::math::nd4j_sgn<T,T>(x[xOffset]));
+                        }
+                    }
+                }
                 break;
             }
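Both backprop hunks add the same missing factor. For p-norm pooling the forward output is s = (Σ_i |x_i|^p)^{1/p}, and differentiating |x_j|^p introduces a sign term, so the input gradient must carry sgn(x_j). A one-line derivation (standard calculus, with extraParam0 playing the role of p):

\frac{\partial s}{\partial x_j}
  = \frac{1}{p}\Big(\sum_i |x_i|^p\Big)^{\frac{1}{p}-1} \cdot p\,|x_j|^{p-1}\,\operatorname{sgn}(x_j)
  = \Big(\sum_i |x_i|^p\Big)^{\frac{1-p}{p}} |x_j|^{p-1}\,\operatorname{sgn}(x_j)

The first factor is the existing val *= nd4j_pow(sum, (1 - extraParam0)/extraParam0) term, the second is nd4j_pow(nd4j_abs(x[xOffset]), extraParam0 - 1), and the newly added nd4j_sgn(x[xOffset]) supplies the sign that was previously dropped.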
@@ -15,63 +15,87 @@
 ******************************************************************************/

 //
-// Created by Yurii Shyrma on 07.12.2017.
+// @author Yurii Shyrma (iuriish@yahoo.com)
 //

 #include "ResultSet.h"
 #include <ops/declarable/helpers/matrixSetDiag.h>
+#include <PointersManager.h>

 namespace nd4j {
 namespace ops {
 namespace helpers {

-template <typename T>
-static __global__ void matrixSetDiagKernel(void* outputBuffer, Nd4jLong* outputShape, void const* diagonalBuffer, Nd4jLong* diagonalShape, Nd4jLong lastDimSize, Nd4jLong last2DimSize, Nd4jLong lastSmallDim, Nd4jLong batchSize) {
-    __shared__ T* z;
-    __shared__ T const* x;
-    __shared__ Nd4jLong outLength, diagonalLen;
-    if (threadIdx.x == 0) {
-        z = reinterpret_cast<T*>(outputBuffer);
-        x = reinterpret_cast<T const*>(diagonalBuffer);
-        outLength = shape::length(outputShape);
-        diagonalLen = shape::length(diagonalShape);
-    }
-    __syncthreads();
-
-    for(int i = blockIdx.x; i < batchSize; i+= gridDim.x )
-        for(int j = threadIdx.x; j < lastSmallDim; j += blockDim.x) {
-            // z[i * last2DimSize + j * (lastDimSize + 1)] = x[i * lastSmallDim + j];
-            z[shape::getIndexOffset(i * last2DimSize + j * (lastDimSize + 1), outputShape, outLength)] = x[shape::getIndexOffset(i * lastSmallDim + j, diagonalShape, diagonalLen)];
-        }
-}
-
-//////////////////////////////////////////////////////////////////////////
-// Returns a batched matrix tensor with new batched diagonal values.
-// for detailed explanations please take a look on web page: https://www.tensorflow.org/api_docs/python/tf/matrix_set_diag
-template <typename T>
-static void _matrixSetDiag(nd4j::LaunchContext * context, const NDArray* input, const NDArray* diagonal, NDArray* output) {
-
-    *output = *input;
-
-    const int lastDimSize = input->sizeAt(-1);
-    const int last2DimSize = input->sizeAt(-1) * input->sizeAt(-2);
-    const int lastSmallDim = diagonal->sizeAt(-1);
-    const int batchSize = input->lengthOf()/last2DimSize;
-    auto stream = context->getCudaStream();
-    dim3 launchDims(256, 512, 8192);
-    matrixSetDiagKernel<T><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(output->specialBuffer(), output->specialShapeInfo(), diagonal->getSpecialBuffer(), diagonal->getSpecialShapeInfo(), lastDimSize, last2DimSize, lastSmallDim, batchSize);
-//// #pragma omp parallel for if(batchSize > Environment::getInstance()->elementwiseThreshold()) schedule(static)
-//    for(int i = 0; i < batchSize; ++i )
-//        for(int j = 0; j < lastSmallDim; ++j) {
-//            output->p(i*last2DimSize + j*(lastDimSize + 1), diagonal->e<T>(i*lastSmallDim + j));
-//        }
-}
-
-void matrixSetDiag(nd4j::LaunchContext * context, const NDArray* input, const NDArray* diagonal, NDArray* output) {
-    BUILD_SINGLE_SELECTOR(input->dataType(), _matrixSetDiag, (context, input, diagonal, output), LIBND4J_TYPES);
-}
-
-BUILD_SINGLE_TEMPLATE(template void _matrixSetDiag, (nd4j::LaunchContext * context, const NDArray* input, const NDArray* diagonal, NDArray* output), LIBND4J_TYPES);
+///////////////////////////////////////////////////////////////////
+template<typename T>
+__global__ static void matrixSetDiagCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vy, const Nd4jLong* yShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const bool zeroPad) {
+
+    // x - input,    shape [A,B,C]
+    // y - diagonal, shape [A,B]
+    // z - output,   shape [A,B,C]
+    // input and output are the same array (x == z) when zeroPad = true
+
+    const auto x = reinterpret_cast<const T*>(vx);
+    const auto y = reinterpret_cast<const T*>(vy);
+          auto z = reinterpret_cast<T*>(vz);
+
+    __shared__ int xRank;                   // xRank = zRank, xRank = yRank + 1
+    __shared__ Nd4jLong xLen, *sharedMem;   // xLen = zLen
+    __shared__ bool areSameOffsets;
+
+    if (threadIdx.x == 0) {
+        extern __shared__ unsigned char shmem[];
+        sharedMem = reinterpret_cast<Nd4jLong*>(shmem);
+
+        areSameOffsets = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo);    // shapes are definitely the same, but strides might not
+
+        xRank = shape::rank(xShapeInfo);
+        xLen  = shape::length(xShapeInfo);
+    }
+    __syncthreads();
+
+    auto coords = sharedMem + threadIdx.x * xRank;  // we provide (xRank * sizeof(Nd4jLong) * threadIdx.x) amount of shared memory per each thread
+    const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
+
+    for (Nd4jLong i = tid; i < xLen; i += gridDim.x * blockDim.x) {
+
+        shape::index2coords(xRank, xShapeInfo + 1, i, xLen, coords);
+
+        const auto xOffset = shape::getOffset(0, xShapeInfo + 1, xShapeInfo + xRank + 1, coords, xRank);
+        const auto zOffset = areSameOffsets ? xOffset : shape::getOffset(0, zShapeInfo + 1, zShapeInfo + xRank + 1, coords, xRank);
+
+        // condition to be on diagonal of innermost matrix
+        if(coords[xRank - 2] == coords[xRank - 1])
+            z[zOffset] = y[shape::getOffset(0, yShapeInfo + 1, yShapeInfo + xRank, coords, xRank - 1)];
+        else
+            z[zOffset] = zeroPad ? static_cast<T>(0) : x[xOffset];
+    }
+}
+
+///////////////////////////////////////////////////////////////////
+template<typename T>
+static void matrixSetDiagCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, const void* vy, const Nd4jLong* yShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const bool zeroPad) {
+
+    matrixSetDiagCuda<T><<<blocksPerGrid, threadsPerBlock, sharedMem, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, zeroPad);
+}
+
+///////////////////////////////////////////////////////////////////
+void matrixSetDiag(nd4j::LaunchContext* context, const NDArray& input, const NDArray& diagonal, NDArray& output, const bool zeroPad) {
+
+    const int threadsPerBlock = MAX_NUM_THREADS / 2;
+    const int blocksPerGrid = (input.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
+    const int sharedMem = threadsPerBlock * sizeof(Nd4jLong) * input.rankOf() + 128;
+
+    PointersManager manager(context, "matrixSetDiag");
+
+    NDArray::prepareSpecialUse({&output}, {&input, &diagonal});
+    BUILD_SINGLE_SELECTOR(input.dataType(), matrixSetDiagCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), diagonal.getSpecialBuffer(), diagonal.getSpecialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), zeroPad), LIBND4J_TYPES);
+    NDArray::registerSpecialUse({&output}, {&input, &diagonal});
+
+    manager.synchronize();
+}

 }
 }
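A minimal host-side sketch of the semantics matrixSetDiagCuda implements, assuming contiguous row-major buffers of shape [A,B,C] for x/z and [A,B] for y; names and the plain-loop layout are illustrative only, not the library code.

#include <cstddef>

template <typename T>
void matrixSetDiagReference(const T* x, const T* y, T* z,
                            std::size_t A, std::size_t B, std::size_t C, bool zeroPad) {
    for (std::size_t a = 0; a < A; ++a)
        for (std::size_t b = 0; b < B; ++b)
            for (std::size_t c = 0; c < C; ++c) {
                const std::size_t idx = (a * B + b) * C + c;
                if (b == c)
                    z[idx] = y[a * B + b];                           // write the new diagonal
                else
                    z[idx] = zeroPad ? static_cast<T>(0) : x[idx];   // keep, or zero-pad, the rest
            }
}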
@@ -1,95 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-//
-// Created by GS <sgazeos@gmail.com> on 3/21/2018.
-//
-
-#include "ResultSet.h"
-#include <ops/declarable/helpers/matrix_diag.h>
-#include <Status.h>
-#include <ShapeUtils.h>
-#include <ShapeUtils.h>
-#include <TAD.h>
-#include <cuda_exception.h>
-#include <helpers/ConstantTadHelper.h>
-
-namespace nd4j {
-namespace ops {
-namespace helpers {
-
-    template <typename T>
-    static __global__ void matrixDiagKernel(void const* inputBuffer, void* outputBuffer, Nd4jLong numTads, Nd4jLong inputLength,
-                                            Nd4jLong* tadOnlyInputShapeInfo, Nd4jLong *tadInputOffsets,
-                                            Nd4jLong* tadOnlyOutputShapeInfo, Nd4jLong *tadOutputOffsets) {
-        int totalThreads = blockDim.x;
-        for (Nd4jLong i = blockIdx.x; i < numTads; i += gridDim.x) {
-            auto yOffset = tadInputOffsets[i];
-            auto xOffset = tadOutputOffsets[i];
-            for (Nd4jLong j = threadIdx.x; j < inputLength; j += totalThreads) {
-                Nd4jLong coords[2] = {j, j};
-                Nd4jLong tadOffset = shape::getOffset(0, shape::shapeOf(tadOnlyOutputShapeInfo), shape::stride(tadOnlyOutputShapeInfo), coords, 2);
-                //shape::getIndexOffset(j, tadOnlyOutputShapeInfo, inputLength)
-                *(reinterpret_cast<T*>(outputBuffer) + xOffset + tadOffset) = *(reinterpret_cast<T const*>(inputBuffer) + yOffset + shape::getIndexOffset(j, tadOnlyInputShapeInfo, inputLength));
-            }
-        }
-    }
-
-    //////////////////////////////////////////////////////////////////////////
-    // Returns a batched matrix tensor with new batched diagonal values.
-    // for detailed explanations please take a look on web page: https://www.tensorflow.org/api_docs/python/tf/matrix_set_diag
-
-    template <typename T>
-    static int _matrixDiag(nd4j::LaunchContext * context, const NDArray* input, NDArray* output) {
-        cudaStream_t* stream = context->getCudaStream();
-        //auto listOut = output->allTensorsAlongDimension({output->rankOf() - 2, output->rankOf() - 1});
-        //auto listDiag = input->allTensorsAlongDimension({input->rankOf() - 1});
-
-        //auto repeatDelta = shape::prodLong(newShape.data(), rank) / this->lengthOf();
-        std::vector<int> dimsToExclude = ShapeUtils::evalDimsToExclude(input->rankOf(), {input->rankOf() - 1});
-        const Nd4jLong numTads = ShapeUtils::getNumOfSubArrs(input->getShapeInfo(), dimsToExclude); //this->tensorsAlongDimension({dimension});
-        //printf("Repeat delta %lld, numTads %lld\n", repeatDelta, numTads);
-        //tadOnlyInputShapeInfo, tadInputOffsets, tadOnlyOutputShapeInfo, tadOutputOffsets;
-        std::vector<int> inputDims({input->rankOf() - 1});
-        std::vector<int> outputDims({output->rankOf() - 2, output->rankOf() - 1});
-
-        auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), inputDims);
-        auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), outputDims);
-
-        if (!input->isActualOnDeviceSide())
-            input->syncToDevice();
-
-        if (!output->isActualOnDeviceSide())
-            output->syncToDevice();
-
-        // create cuda stream and LaunchContext
-        cudaError_t cudaResult;
-
-        dim3 launchDims(256, 512, 8192);
-        matrixDiagKernel<T><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(input->getSpecialBuffer(), output->getSpecialBuffer(), numTads, input->sizeAt(-1), packX.specialShapeInfo(), packX.specialOffsets(), packZ.specialShapeInfo(), packZ.specialOffsets());
-
-        return Status::OK();
-    }
-
-    int matrixDiag(nd4j::LaunchContext * context, const NDArray* input, NDArray* output) {
-        BUILD_SINGLE_SELECTOR(input->dataType(), return _matrixDiag, (context, input, output), LIBND4J_TYPES);
-    }
-
-    BUILD_SINGLE_TEMPLATE(template int _matrixDiag, (nd4j::LaunchContext * context, const NDArray* input, NDArray* output), LIBND4J_TYPES);
-
-}
-}
-}

@@ -28,8 +28,7 @@ namespace nd4j {
 namespace ops {
 namespace helpers {

-    void matrixSetDiag(nd4j::LaunchContext * context, const NDArray* input, const NDArray* diagonal, NDArray* output);
+    void matrixSetDiag(nd4j::LaunchContext* context, const NDArray& input, const NDArray& diagonal, NDArray& output, const bool zeroPad);

 }
 }

@@ -1,34 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-//
-// @author GS <sgazeos@gmail.com>
-//
-#ifndef __MATRIX_DIAG_HELPERS__
-#define __MATRIX_DIAG_HELPERS__
-#include <op_boilerplate.h>
-#include <NDArray.h>
-
-namespace nd4j {
-namespace ops {
-namespace helpers {
-
-    int matrixDiag(nd4j::LaunchContext * context, NDArray const* input, NDArray* output);
-
-}
-}
-}
-#endif
@@ -117,9 +117,9 @@ TEST_F(DeclarableOpsTests3, Test_Unique_1) {

     auto v = result->at(0);
     auto i = result->at(1);
-    v->printIndexedBuffer("Values");
-    i->printIndexedBuffer("Indices");
-    i->printShapeInfo("Indices shape");
+    // v->printIndexedBuffer("Values");
+    // i->printIndexedBuffer("Indices");
+    // i->printShapeInfo("Indices shape");
     ASSERT_TRUE(expV.isSameShape(v));
     ASSERT_TRUE(expV.equalsTo(v));

@@ -145,12 +145,12 @@ TEST_F(DeclarableOpsTests3, Test_Unique_2) {
     auto i = result->at(1);
     auto c = result->at(2);

-    v->printShapeInfo();
-    v->printIndexedBuffer("Values");
-    i->printShapeInfo();
-    i->printIndexedBuffer("Indices");
-    c->printShapeInfo();
-    c->printIndexedBuffer("Counts");
+    // v->printShapeInfo();
+    // v->printIndexedBuffer("Values");
+    // i->printShapeInfo();
+    // i->printIndexedBuffer("Indices");
+    // c->printShapeInfo();
+    // c->printIndexedBuffer("Counts");

     ASSERT_TRUE(expV.isSameShape(v));
     ASSERT_TRUE(expV.equalsTo(v));

@@ -200,11 +200,11 @@ TEST_F(DeclarableOpsTests3, Test_Norm_1) {
     auto result1 = op.execute({&x}, {1.}, {1});
     ASSERT_EQ(result1->status(), ND4J_STATUS_OK);
     auto z1 = result1->at(0);
-    z1->printIndexedBuffer("Z1");
+    // z1->printIndexedBuffer("Z1");
     auto exp1 = x.reduceAlongDims(reduce::Norm2, dims, false, false);
-    exp1.printIndexedBuffer("EXP1");
-    z1->printShapeInfo("Z1 shape");
-    exp1.printShapeInfo("EXP1 shape");
+    // exp1.printIndexedBuffer("EXP1");
+    // z1->printShapeInfo("Z1 shape");
+    // exp1.printShapeInfo("EXP1 shape");
     ASSERT_TRUE(exp1.isSameShape(z1));
     ASSERT_TRUE(exp1.equalsTo(z1));

@@ -714,7 +714,7 @@ TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_7) {

     auto exp = MmulHelper::mmul(&x, &y);

-    exp->printShapeInfo("exp shape");
+    // exp->printShapeInfo("exp shape");

     nd4j::ops::batched_gemm op;
     auto result = op.execute({&a, &b, &x, &x, &x, &y, &y, &y}, {}, {112, 112, 2, 3, 5, 5, 3, 2, 3});

@@ -79,7 +79,7 @@ TEST_F(SortCudaTests, test_linear_sort_by_val_2) {
     sortByValue(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), true);
     k.tickWriteDevice();
     v.tickWriteDevice();
-    k.printIndexedBuffer("KEYS");
+    // k.printIndexedBuffer("KEYS");
     ASSERT_EQ(ek, k);
     ASSERT_EQ(ev, v);
 }

@@ -98,8 +98,8 @@ TEST_F(SortCudaTests, test_tad_sort_by_key_1) {
     k.tickWriteDevice();
     v.tickWriteDevice();

-    k.printIndexedBuffer("k");
-    v.printIndexedBuffer("v");
+    // k.printIndexedBuffer("k");
+    // v.printIndexedBuffer("v");

     ASSERT_EQ(ek, k);
     ASSERT_EQ(ev, v);
@@ -1562,8 +1562,8 @@ public class DifferentialFunctionFactory {

     }

-    public SDVariable eluBp(SDVariable in, SDVariable epsilon) {
-        return new EluBp(sameDiff(), in, epsilon).outputVariable();
+    public SDVariable eluBp(SDVariable in, SDVariable epsilon, double alpha) {
+        return new EluBp(sameDiff(), in, epsilon, alpha).outputVariable();
     }
@@ -18,14 +18,12 @@ package org.nd4j.linalg.activations.impl;

 import lombok.EqualsAndHashCode;
 import lombok.Getter;
-import org.nd4j.linalg.api.ops.impl.transforms.gradient.EluBp;
-import org.nd4j.linalg.primitives.Pair;
 import org.nd4j.linalg.activations.BaseActivationFunction;
 import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.impl.transforms.gradient.EluBp;
 import org.nd4j.linalg.api.ops.impl.transforms.strict.ELU;
 import org.nd4j.linalg.factory.Nd4j;
-import org.nd4j.linalg.indexing.BooleanIndexing;
-import org.nd4j.linalg.indexing.conditions.Conditions;
+import org.nd4j.linalg.primitives.Pair;

 /**
  * f(x) = alpha * (exp(x) - 1.0); x < 0

@@ -55,15 +53,7 @@ public class ActivationELU extends BaseActivationFunction {
      */
     @Override
     public INDArray getActivation(INDArray in, boolean training) {
-        // no support in ELU native to override alpha
-        if (this.alpha != 1.00) {
-            INDArray alphaMultiple = Nd4j.getExecutioner().exec(new ELU(in.dup()))[0];
-            alphaMultiple.muli(alpha);
-            BooleanIndexing.replaceWhere(in, alphaMultiple, Conditions.lessThan(0));
-        } else {
-            Nd4j.getExecutioner().execAndReturn(new ELU(in));
-        }
-        return in;
+        return Nd4j.exec(new ELU(in, in, alpha))[0];
     }

     /*
@@ -1195,7 +1195,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
             }
         }

-
         return this;
     }

@@ -3089,12 +3088,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
         return mmuli(other, result);
     }

-    /**
-     * in place (element wise) division of two matrices
-     *
-     * @param other the second ndarray to divide
-     * @return the result of the divide
-     */
     @Override
     public INDArray div(INDArray other) {
         if (Shape.areShapesBroadcastable(this.shape(), other.shape())) {

@@ -3104,25 +3097,12 @@ public abstract class BaseNDArray implements INDArray, Iterable {
         }
     }

-    /**
-     * copy (element wise) division of two matrices
-     *
-     * @param other the second ndarray to divide
-     * @param result the result ndarray
-     * @return the result of the divide
-     */
     @Override
     public INDArray div(INDArray other, INDArray result) {
         validateNumericalArray("div", true);
         return divi(other, result);
     }

-    /**
-     * copy (element wise) multiplication of two matrices
-     *
-     * @param other the second ndarray to multiply
-     * @return the result of the addition
-     */
     @Override
     public INDArray mul(INDArray other) {
         validateNumericalArray("mul", false);

@@ -3134,24 +3114,11 @@ public abstract class BaseNDArray implements INDArray, Iterable {
         }
     }

-    /**
-     * copy (element wise) multiplication of two matrices
-     *
-     * @param other the second ndarray to multiply
-     * @param result the result ndarray
-     * @return the result of the multiplication
-     */
     @Override
     public INDArray mul(INDArray other, INDArray result) {
         return muli(other, result);
     }

-    /**
-     * copy subtraction of two matrices
-     *
-     * @param other the second ndarray to subtract
-     * @return the result of the addition
-     */
     @Override
     public INDArray sub(INDArray other) {
         validateNumericalArray("sub", false);

@@ -3162,24 +3129,11 @@ public abstract class BaseNDArray implements INDArray, Iterable {
         }
     }

-    /**
-     * copy subtraction of two matrices
-     *
-     * @param other the second ndarray to subtract
-     * @param result the result ndarray
-     * @return the result of the subtraction
-     */
     @Override
     public INDArray sub(INDArray other, INDArray result) {
         return subi(other, result);
     }

-    /**
-     * copy addition of two matrices
-     *
-     * @param other the second ndarray to add
-     * @return the result of the addition
-     */
     @Override
     public INDArray add(INDArray other) {
         validateNumericalArray("add", false);

@@ -3190,65 +3144,29 @@ public abstract class BaseNDArray implements INDArray, Iterable {
         }
     }

-    /**
-     * copy addition of two matrices
-     *
-     * @param other the second ndarray to add
-     * @param result the result ndarray
-     * @return the result of the addition
-     */
     @Override
     public INDArray add(INDArray other, INDArray result) {
         validateNumericalArray("add", false);
         return addi(other, result);
     }


-    /**
-     * Perform an copy matrix multiplication
-     *
-     * @param other the other matrix to perform matrix multiply with
-     * @param transpose the transpose status of each ndarray
-     * @return the result of the matrix multiplication
-     */
     @Override
     public INDArray mmuli(INDArray other, MMulTranspose transpose) {
         validateNumericalArray("mmuli", false);
         return dup().mmuli(other, this,transpose);
     }

-    /**
-     * Perform an copy matrix multiplication
-     *
-     * @param other the other matrix to perform matrix multiply with
-     * @return the result of the matrix multiplication
-     */
     @Override
     public INDArray mmuli(INDArray other) {
         validateNumericalArray("mmuli", false);
         return dup().mmuli(other, this);
     }


-    /**
-     * Perform an in place matrix multiplication
-     *
-     * @param other the other matrix to perform matrix multiply with
-     * @param result the result ndarray
-     * @return the result of the matrix multiplication
-     */
     @Override
     public INDArray mmuli(INDArray other, INDArray result, MMulTranspose transpose) {
         return transpose.exec(this, other, result);
     }

-    /**
-     * Perform an copy matrix multiplication
-     *
-     * @param other the other matrix to perform matrix multiply with
-     * @param result the result ndarray
-     * @return the result of the matrix multiplication
-     */
     @Override
     public INDArray mmuli(INDArray other, INDArray result) {
         validateNumericalArray("mmuli", false);

@@ -3347,24 +3265,11 @@ public abstract class BaseNDArray implements INDArray, Iterable {
         return Nd4j.create(shape, stride);
     }

-    /**
-     * in place (element wise) division of two matrices
-     *
-     * @param other the second ndarray to divide
-     * @return the result of the divide
-     */
     @Override
     public INDArray divi(INDArray other) {
         return divi(other, this);
     }

-    /**
-     * in place (element wise) division of two matrices
-     *
-     * @param other the second ndarray to divide
-     * @param result the result ndarray
-     * @return the result of the divide
-     */
     @Override
     public INDArray divi(INDArray other, INDArray result) {
         validateNumericalArray("divi", false);

@@ -3373,24 +3278,11 @@ public abstract class BaseNDArray implements INDArray, Iterable {
         return result;
     }

-    /**
-     * in place (element wise) multiplication of two matrices
-     *
-     * @param other the second ndarray to multiply
-     * @return the result of the multiplication
-     */
     @Override
     public INDArray muli(INDArray other) {
         return muli(other, this);
     }

-    /**
-     * in place (element wise) multiplication of two matrices
-     *
-     * @param other the second ndarray to multiply
-     * @param result the result ndarray
-     * @return the result of the multiplication
-     */
     @Override
     public INDArray muli(INDArray other, INDArray result) {
         validateNumericalArray("muli", false);

@@ -3399,12 +3291,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
         return result;
     }

-    /**
-     * in place subtraction of two matrices
-     *
-     * @param other the second ndarray to subtract
-     * @return the result of the addition
-     */
     @Override
     public INDArray subi(INDArray other) {
         return subi(other, this);

@@ -3425,24 +3311,11 @@ public abstract class BaseNDArray implements INDArray, Iterable {
         return result;
     }

-    /**
-     * in place addition of two matrices
-     *
-     * @param other the second ndarray to add
-     * @return the result of the addition
-     */
     @Override
     public INDArray addi(INDArray other) {
         return addi(other, this);
     }

-    /**
-     * in place addition of two matrices
-     *
-     * @param other the second ndarray to add
-     * @param result the result ndarray
-     * @return the result of the addition
-     */
     @Override
     public INDArray addi(INDArray other, INDArray result) {
         validateNumericalArray("addi", false);

@@ -3451,25 +3324,12 @@ public abstract class BaseNDArray implements INDArray, Iterable {
         return result;
     }

-    /**
-     * Returns the normmax along the specified dimension
-     *
-     * @param dimension the dimension to getScalar the norm1 along
-     * @param keepDims whether to keep reduced dimensions as dimensions of size 1
-     * @return the norm1 along the specified dimension
-     */
     @Override
     public INDArray normmax(boolean keepDims, int... dimension) {
         validateNumericalArray("normmax", false);
         return Nd4j.getExecutioner().exec(new NormMax(this, keepDims, dimension));
     }

-    /**
-     * Returns the normmax along the specified dimension
-     *
-     * @param dimension the dimension to getScalar the norm1 along
-     * @return the norm1 along the specified dimension
-     */
     @Override
     public INDArray normmax(int... dimension) {
         return normmax(false, dimension);

@@ -4071,49 +3931,23 @@ public abstract class BaseNDArray implements INDArray, Iterable {
         return reshape(Nd4j.order(), shape);
     }

-    /**
-     * Returns the product along a given dimension
-     *
-     * @param dimension the dimension to getScalar the product along
-     * @param keepDims whether to keep reduced dimensions as dimensions of size 1
-     * @return the product along the specified dimension
-     */
     @Override
     public INDArray prod(boolean keepDims, int... dimension) {
         validateNumericalArray("prod", false);
         return Nd4j.getExecutioner().exec(new Prod(this, keepDims, dimension));
     }

-    /**
-     * Returns the product along a given dimension
-     *
-     * @param dimension the dimension to getScalar the product along
-     * @return the product along the specified dimension
-     */
     @Override
     public INDArray prod(int... dimension) {
         return prod(false, dimension);
     }

-    /**
-     * Returns the overall mean of this ndarray
-     *
-     * @param dimension the dimension to getScalar the mean along
-     * @param keepDims whether to keep reduced dimensions as dimensions of size 1
-     * @return the mean along the specified dimension of this ndarray
-     */
     @Override
     public INDArray mean(boolean keepDims, int... dimension) {
         validateNumericalArray("mean", false);
         return Nd4j.getExecutioner().exec(new Mean(this, keepDims, dimension));
     }

-    /**
-     * Returns the overall mean of this ndarray
-     *
-     * @param dimension the dimension to getScalar the mean along
-     * @return the mean along the specified dimension of this ndarray
-     */
     @Override
     public INDArray mean(int... dimension) {
         return mean(false, dimension);

@@ -4136,50 +3970,24 @@ public abstract class BaseNDArray implements INDArray, Iterable {
         return mean(result, false, dimension);
     }

-    /**
-     * Returns the overall variance of this ndarray
-     *
-     * @param dimension the dimension to getScalar the mean along
-     * @return the mean along the specified dimension of this ndarray
-     */
     @Override
     public INDArray var(int... dimension) {
         validateNumericalArray("var", false);
         return Nd4j.getExecutioner().exec(new Variance(this, dimension));
     }

-    /**
-     * Returns the overall variance of this ndarray
-     *
-     * @param biasCorrected boolean on whether to apply corrected bias
-     * @param dimension the dimension to getScalar the mean along
-     * @return the mean along the specified dimension of this ndarray
-     */
     @Override
     public INDArray var(boolean biasCorrected, int... dimension) {
         validateNumericalArray("var", false);
         return Nd4j.getExecutioner().exec(new Variance(this, biasCorrected, dimension));
     }

-    /**
-     * Returns the overall max of this ndarray
-     *
-     * @param dimension the dimension to getScalar the mean along
-     * @param keepDims whether to keep reduced dimensions as dimensions of size 1
-     * @return the mean along the specified dimension of this ndarray
-     */
     @Override
     public INDArray max(boolean keepDims, int... dimension) {
         validateNumericalArray("max", false);
         return Nd4j.getExecutioner().exec(new Max(this, keepDims, dimension));
     }

-    /**
-     * Returns the overall max of this ndarray
-     *
-     * @param dimension the dimension to getScalar the mean along
-     * @return the mean along the specified dimension of this ndarray
-     */
     @Override
     public INDArray max(int... dimension) {
         return max(false, dimension);

@@ -4191,25 +3999,12 @@ public abstract class BaseNDArray implements INDArray, Iterable {
         return Nd4j.getExecutioner().exec(new AMax(this, dimension));
     }

-    /**
-     * Returns the overall min of this ndarray
-     *
-     * @param dimension the dimension to getScalar the mean along
-     * @param keepDims whether to keep reduced dimensions as dimensions of size 1
-     * @return the mean along the specified dimension of this ndarray
-     */
     @Override
     public INDArray min(boolean keepDims, int... dimension) {
         validateNumericalArray("min", false);
         return Nd4j.getExecutioner().exec(new Min(this, keepDims, dimension));
     }

-    /**
-     * Returns the overall min of this ndarray
-     *
-     * @param dimension the dimension to getScalar the mean along
-     * @return the mean along the specified dimension of this ndarray
-     */
     @Override
     public INDArray min(int... dimension) {
         return min(false, dimension);

@@ -4290,39 +4085,17 @@ public abstract class BaseNDArray implements INDArray, Iterable {
         return sum(result, false, dimension);
     }


-    /**
-     * Returns the norm1 along the specified dimension
-     *
-     * @param dimension the dimension to getScalar the norm1 along
-     * @return the norm1 along the specified dimension
-     */
     @Override
     public INDArray norm1(int... dimension) {
         return norm1(false, dimension);
     }


-    /**
-     * Returns the norm1 along the specified dimension
-     *
-     * @param dimension the dimension to getScalar the norm1 along
-     * @param keepDims whether to keep reduced dimensions as dimensions of size 1
-     * @return the norm1 along the specified dimension
-     */
     @Override
     public INDArray norm1(boolean keepDims, int... dimension) {
         validateNumericalArray("norm1", false);
         return Nd4j.getExecutioner().exec(new Norm1(this, keepDims, dimension));
     }


-    /**
-     * Standard deviation of an ndarray along a dimension
-     *
-     * @param dimension the dimension to getScalar the std along
-     * @return the standard deviation along a particular dimension
-     */
     @Override
     public INDArray std(int... dimension) {
         return std(true, dimension);

@@ -4345,32 +4118,17 @@ public abstract class BaseNDArray implements INDArray, Iterable {
         return Nd4j.getExecutioner().exec(new StandardDeviation(this, biasCorrected)).getDouble(0);
     }

-    /**
-     * Returns the norm2 along the specified dimension
-     *
-     * @param dimension the dimension to getScalar the norm2 along
-     * @param keepDims whether to keep reduced dimensions as dimensions of size 1
-     * @return the norm2 along the specified dimension
-     */
     @Override
     public INDArray norm2(boolean keepDims, int... dimension) {
         validateNumericalArray("norm2", false);
         return Nd4j.getExecutioner().exec(new Norm2(this, keepDims, dimension));
     }

-    /**
-     * Returns the norm2 along the specified dimension
-     *
-     * @param dimension the dimension to getScalar the norm2 along
-     * @return the norm2 along the specified dimension
-     */
     @Override
     public INDArray norm2(int... dimension) {
         return norm2(false, dimension);
     }


     /**
      * Number of columns (shape[1]), throws an exception when
      * called when not 2d
@@ -1232,8 +1232,6 @@ public abstract class BaseSparseNDArray implements ISparseNDArray {
         return null;
     }

-
-
     @Override
     public INDArray normmax(boolean keepDims, int... dimension) {
         return null;
@@ -1404,7 +1404,13 @@ public interface INDArray extends Serializable, AutoCloseable {
      */
    INDArray add(INDArray other, INDArray result);

+    /**
+     * Perform an copy matrix multiplication
+     *
+     * @param other the other matrix to perform matrix multiply with
+     * @param transpose the transpose status of each ndarray
+     * @return the result of the matrix multiplication
+     */
    INDArray mmuli(INDArray other, MMulTranspose transpose);

    /**

@@ -1415,7 +1421,13 @@ public interface INDArray extends Serializable, AutoCloseable {
      */
    INDArray mmuli(INDArray other);

+    /**
+     * Perform an in place matrix multiplication
+     *
+     * @param other the other matrix to perform matrix multiply with
+     * @param result the result ndarray
+     * @return the result of the matrix multiplication
+     */
    INDArray mmuli(INDArray other, INDArray result, MMulTranspose transpose);

    /**

@@ -1497,7 +1509,6 @@ public interface INDArray extends Serializable, AutoCloseable {
      */
    INDArray addi(INDArray other, INDArray result);

-
    /**
     * Returns the max norm (aka infinity norm, equal to the maximum absolute value) along the specified dimension(s)
     *

@@ -1506,7 +1517,6 @@ public interface INDArray extends Serializable, AutoCloseable {
      */
    INDArray normmax(int... dimension);

-
    /**
     * Returns the max norm (aka infinity norm, equal to the maximum absolute value) along the specified dimension(s)
     *

@@ -1585,7 +1595,7 @@ public interface INDArray extends Serializable, AutoCloseable {
    /**
     * Calculate the standard deviation for the entire array
     *
-     * @return
+     * @return standard deviation
     */
    Number stdNumber();
@@ -33,8 +33,9 @@ public class EluBp extends DynamicCustomOp {

    public EluBp(){ }

-    public EluBp(SameDiff sd, SDVariable input, SDVariable gradient){
+    public EluBp(SameDiff sd, SDVariable input, SDVariable gradient, double alpha){
        super(sd, new SDVariable[]{input, gradient});
+        addTArgument(alpha);
    }

    public EluBp(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output) {
@@ -23,13 +23,9 @@ import org.nd4j.imports.NoOpNameFoundException;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.DynamicCustomOp;
-import org.tensorflow.framework.AttrValue;
-import org.tensorflow.framework.GraphDef;
-import org.tensorflow.framework.NodeDef;

 import java.util.Collections;
 import java.util.List;
-import java.util.Map;

 /**
  * ELU: Exponential Linear Unit (alpha=1.0)<br>

@@ -41,19 +37,31 @@ import java.util.Map;
  * @author Alex Black
  */
 public class ELU extends DynamicCustomOp {
+    public static final double DEFAULT_ALPHA = 1.0;
+
+    protected double alpha;
+
    public ELU(SameDiff sameDiff, SDVariable i_v) {
        super(sameDiff, new SDVariable[]{i_v});
+        this.alpha = DEFAULT_ALPHA;
+        addTArgument(alpha);
    }

    public ELU() {
    }

    public ELU(INDArray x, INDArray z) {
+        this(x, z, DEFAULT_ALPHA);
+    }
+
+    public ELU(INDArray x, INDArray z, double alpha) {
        super(null, wrapOrNull(x), wrapOrNull(z));
+        this.alpha = alpha;
+        addTArgument(alpha);
    }

    public ELU(INDArray x) {
-        this(x, null);
+        this(x, null, DEFAULT_ALPHA);
    }

    @Override

@@ -75,7 +83,7 @@ public class ELU extends DynamicCustomOp {
    public List<SDVariable> doDiff(List<SDVariable> i_v) {
        //ELU: e^x-1 if x<0, x otherwise
        //dL/dIn = dL/Out * dOut/dIn
-        return Collections.singletonList(f().eluBp(arg(), i_v.get(0)));
+        return Collections.singletonList(f().eluBp(arg(), i_v.get(0), alpha));
    }

    @Override
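For reference, these are the forward and backward rules that the new alpha argument parameterises, written as a small self-contained sketch; it follows the formula in the class comment (f(x) = alpha * (exp(x) - 1) for x < 0, x otherwise) and is not the library kernel itself.

#include <cmath>

// ELU forward with an explicit alpha
double elu(double x, double alpha) {
    return x < 0.0 ? alpha * (std::exp(x) - 1.0) : x;
}

// dL/dIn = dL/dOut * dOut/dIn, where dOut/dIn = alpha * exp(x) for x < 0, 1 otherwise
double eluBp(double x, double gradOut, double alpha) {
    const double dOutdIn = x < 0.0 ? alpha * std::exp(x) : 1.0;
    return gradOut * dOutdIn;
}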
@@ -18,6 +18,7 @@ package org.nd4j.jita.allocator.pointers.cuda;

 import lombok.Getter;
 import lombok.Setter;
+import lombok.val;
 import org.bytedeco.javacpp.Pointer;
 import org.nd4j.jita.allocator.pointers.CudaPointer;
 import org.nd4j.linalg.exception.ND4JException;

@@ -69,8 +70,9 @@ public class cudaEvent_t extends CudaPointer {
            if (res == 0)
                throw new ND4JException("CUDA exception happened. Terminating. Last op: [" + Nd4j.getExecutioner().getLastOp() +"]");

-            if (NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorCode() != 0)
-                throw new RuntimeException(NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorMessage());
+            val code = NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorCode();
+            if (code != 0)
+                throw new RuntimeException(NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorMessage() + "; Error code: " + code);
        }
    }

@@ -78,8 +80,9 @@ public class cudaEvent_t extends CudaPointer {
        if (!isDestroyed()) {
            int res = NativeOpsHolder.getInstance().getDeviceNativeOps().registerEvent(this, stream);

-            if (NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorCode() != 0)
-                throw new RuntimeException(NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorMessage());
+            val code = NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorCode();
+            if (code != 0)
+                throw new RuntimeException(NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorMessage() + "; Error code: " + code);
        }
    }
 }
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.nd4j.jita.handler.impl;
|
package org.nd4j.jita.handler.impl;
|
||||||
|
|
||||||
|
import org.nd4j.nativeblas.OpaqueLaunchContext;
|
||||||
import org.nd4j.shade.guava.collect.HashBasedTable;
|
import org.nd4j.shade.guava.collect.HashBasedTable;
|
||||||
import org.nd4j.shade.guava.collect.Table;
|
 import org.nd4j.shade.guava.collect.Table;
 import lombok.Getter;

@@ -105,6 +106,8 @@ public class CudaZeroHandler implements MemoryHandler {
 
     private final AllocationStatus INITIAL_LOCATION;
 
+    private final List<cublasHandle_t> cublasHandles = new ArrayList<>();
+
     private final AffinityManager affinityManager = Nd4j.getAffinityManager();
 
     /*
@@ -162,6 +165,7 @@ public class CudaZeroHandler implements MemoryHandler {
         int numDevices = NativeOpsHolder.getInstance().getDeviceNativeOps().getAvailableDevices();
         for (int i = 0; i < numDevices; i++) {
             deviceAllocations.add(new ConcurrentHashMap<Long, Long>());
+            cublasHandles.add(null);
         }
 
         if (NativeOpsHolder.getInstance().getDeviceNativeOps().getDeviceMajor(0) < 3) {
@@ -1176,6 +1180,25 @@ public class CudaZeroHandler implements MemoryHandler {
         return getCudaContext();
     }
 
+    //
+    private final ReentrantReadWriteLock lock = new ReentrantReadWriteLock();
+
+    protected cublasHandle_t getCudaCublasHandle(OpaqueLaunchContext lc) {
+        val deviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread();
+        try {
+            lock.writeLock().lock();
+
+            if (cublasHandles.get(deviceId) == null) {
+                cublasHandles.remove(deviceId);
+                cublasHandles.add(deviceId, new cublasHandle_t(nativeOps.lcBlasHandle(lc)));
+            }
+
+            return cublasHandles.get(deviceId);
+        } finally {
+            lock.writeLock().unlock();
+        }
+    }
+
     /**
      * This method returns CudaContext for current thread. If context doesn't exist - it gets created first.
      * @return
@@ -1183,8 +1206,6 @@ public class CudaZeroHandler implements MemoryHandler {
     public CudaContext getCudaContext() {
         val lc = nativeOps.defaultLaunchContext();
 
-        // TODO: maybe make ThreadLocal cache for context?
-
         return CudaContext.builder()
                 .bufferScalar(nativeOps.lcScalarPointer(lc))
                 .bufferReduction(nativeOps.lcReductionPointer(lc))
@@ -1192,7 +1213,7 @@ public class CudaZeroHandler implements MemoryHandler {
                 .bufferSpecial(nativeOps.lcScalarPointer(lc))
                 .oldStream(new cudaStream_t(nativeOps.lcExecutionStream(lc)))
                 .specialStream(new cudaStream_t(nativeOps.lcCopyStream(lc)))
-                .cublasHandle(new cublasHandle_t(nativeOps.lcBlasHandle(lc)))
+                .cublasHandle(getCudaCublasHandle(lc))
                 .solverHandle(new cusolverDnHandle_t(nativeOps.lcSolverHandle(lc)))
                 .build();
     }
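Note: the hunks above replace the per-call `new cublasHandle_t(...)` in CudaContext construction with a single handle per device, created lazily under a write lock. The following self-contained sketch illustrates that caching pattern; the class name PerDeviceHandleCache, the Supplier factory, and the use of List.set instead of the diff's remove/add pair are illustrative choices, not code from this commit.

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import java.util.function.Supplier;

// Toy analogue of the cublasHandles cache added to CudaZeroHandler above:
// one slot per device, filled lazily, guarded by a write lock.
public class PerDeviceHandleCache<T> {
    private final List<T> handles = new ArrayList<>();
    private final ReentrantReadWriteLock lock = new ReentrantReadWriteLock();

    public PerDeviceHandleCache(int numDevices) {
        for (int i = 0; i < numDevices; i++)
            handles.add(null);                        // mirrors cublasHandles.add(null) in the constructor hunk
    }

    public T get(int deviceId, Supplier<T> factory) {
        lock.writeLock().lock();
        try {
            if (handles.get(deviceId) == null)
                handles.set(deviceId, factory.get()); // create the handle once per device
            return handles.get(deviceId);
        } finally {
            lock.writeLock().unlock();
        }
    }
}

Taking the write lock on every lookup is the simplest correct choice; a read-lock fast path would also work, but it is not what the diff does.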
@@ -17,6 +17,7 @@
 package org.nd4j.linalg.jcublas.blas;
 
 
+import lombok.extern.slf4j.Slf4j;
 import lombok.val;
 import org.bytedeco.javacpp.DoublePointer;
 import org.bytedeco.javacpp.FloatPointer;
@@ -52,6 +53,7 @@ import static org.nd4j.linalg.jcublas.blas.CudaBlas.*;
  *
  * @author Adam Gibson
  */
+@Slf4j
 public class JcublasLevel3 extends BaseLevel3 {
     private Allocator allocator = AtomicAllocator.getInstance();
     private Nd4jBlas nd4jBlas = (Nd4jBlas) Nd4j.factory().blas();
@@ -78,7 +80,7 @@ public class JcublasLevel3 extends BaseLevel3 {
 
         int arch = CudaEnvironment.getInstance().getCurrentDeviceArchitecture();
 
-        if ((CUDA_VERSION >= 8000 && (arch == 53 || arch == 60 || arch == 70)) || (CUDA_VERSION >= 8000 && CUDA_VERSION < 9020)) {
+        if ((CUDA_VERSION >= 8000 && (arch == 53 || arch == 60 || arch >= 70)) || (CUDA_VERSION >= 8000 && CUDA_VERSION < 9020)) {
             // on these selected archs we run with cublasHgemm
             __half alphaHalf = new __half();
             __half betaHalf = new __half();
@@ -96,7 +98,11 @@ public class JcublasLevel3 extends BaseLevel3 {
                         new FloatPointer(alpha), (ShortPointer) cAPointer.getDevicePointer(), 2, lda,
                         (ShortPointer) cBPointer.getDevicePointer(), 2, ldb, new FloatPointer(beta),
                         (ShortPointer) cCPointer.getDevicePointer(), 2, ldc);
+
+
             }
+
+            ctx.getOldStream().synchronize();
         }
 
         allocator.registerAction(ctx, C, A, B);
@@ -114,18 +120,24 @@ public class JcublasLevel3 extends BaseLevel3 {
 
         val ctx = allocator.getFlowController().prepareAction(C, A, B);
 
+        //log.info("Synchronizing CUDA stream");
+        ctx.getOldStream().synchronize();
+
         val cAPointer = new CublasPointer(A, ctx);
         val cBPointer = new CublasPointer(B, ctx);
         val cCPointer = new CublasPointer(C, ctx);
 
         val handle = ctx.getCublasHandle();
         synchronized (handle) {
+            //log.info("Handle: {}; Stream: {}", handle.address(), ctx.getCublasStream().address());
             cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream()));
 
             cublasSgemm_v2(new cublasContext(handle), convertTranspose(TransA), convertTranspose(TransB), M, N, K,
                     new FloatPointer(alpha), (FloatPointer) cAPointer.getDevicePointer(), lda,
                     (FloatPointer) cBPointer.getDevicePointer(), ldb, new FloatPointer(beta),
                     (FloatPointer) cCPointer.getDevicePointer(), ldc);
 
+            ctx.getOldStream().synchronize();
         }
 
         allocator.registerAction(ctx, C, A, B);
@@ -244,6 +256,8 @@ public class JcublasLevel3 extends BaseLevel3 {
                     new DoublePointer(alpha), (DoublePointer) cAPointer.getDevicePointer(), lda,
                     (DoublePointer) cBPointer.getDevicePointer(), ldb, new DoublePointer(beta),
                     (DoublePointer) cCPointer.getDevicePointer(), ldc);
 
+
+            ctx.getOldStream().synchronize();
         }
 
         allocator.registerAction(ctx, C, A, B);
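Note: because the cublas handle is now shared by every thread on a device, each gemm path above follows the same discipline: take the handle's monitor, bind the handle to the calling thread's stream, launch, and synchronize the stream before releasing the monitor. A schematic, self-contained sketch of that ordering follows; the Runnable placeholders and the SharedHandleGemm class are illustrative stand-ins for the real cublas/JavaCPP types, not code from this commit.

// Schematic of the locking order used in the hgemm/sgemm/dgemm hunks above.
public class SharedHandleGemm {
    private final Object handle = new Object();   // stands in for the shared per-device cublasHandle_t

    public void gemm(Runnable bindHandleToMyStream, Runnable launchKernel, Runnable syncMyStream) {
        synchronized (handle) {
            bindHandleToMyStream.run();   // cublasSetStream_v2(handle, thisThreadStream)
            launchKernel.run();           // cublasSgemm_v2 / cublasDgemm_v2 / cublasHgemm
            syncMyStream.run();           // ctx.getOldStream().synchronize()
        }
        // allocator.registerAction(ctx, C, A, B) then runs outside the critical section
    }
}

Synchronizing the stream inside the lock is the conservative choice the diff makes: the launched kernel has finished before another thread can re-bind the same handle to a different stream.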
@@ -2548,6 +2548,9 @@ public class CudaExecutioner extends DefaultOpExecutioner {
 
     @Override
     public DataBuffer createShapeInfo(long[] shape, long[] stride, long elementWiseStride, char order, DataType dtype, boolean empty) {
+        if (nativeOps.lastErrorCode() != 0)
+            throw new RuntimeException(nativeOps.lastErrorMessage());
+
         OpaqueConstantDataBuffer dbf = nativeOps.shapeBuffer(shape.length, new LongPointer(shape), new LongPointer(stride), dtype.toInt(), order, elementWiseStride, empty);
 
         if (nativeOps.lastErrorCode() != 0)
@@ -2562,6 +2565,9 @@ public class CudaExecutioner extends DefaultOpExecutioner {
 
     @Override
     public TadPack tadShapeInfoAndOffsets(INDArray array, int[] dimension) {
+        if (nativeOps.lastErrorCode() != 0)
+            throw new RuntimeException(nativeOps.lastErrorMessage());
+
         OpaqueTadPack pack = nativeOps.tadOnlyShapeInfo((LongPointer) array.shapeInfoDataBuffer().addressPointer(), new IntPointer(dimension), dimension.length);
 
         if (nativeOps.lastErrorCode() != 0)
@@ -2577,6 +2583,9 @@ public class CudaExecutioner extends DefaultOpExecutioner {
 
     @Override
     public DataBuffer createConstantBuffer(long[] values, DataType desiredType) {
+        if (nativeOps.lastErrorCode() != 0)
+            throw new RuntimeException(nativeOps.lastErrorMessage());
+
         OpaqueConstantDataBuffer dbf = nativeOps.constantBufferLong(desiredType.toInt(), new LongPointer(values), values.length);
 
         if (nativeOps.lastErrorCode() != 0)
@@ -2590,6 +2599,9 @@ public class CudaExecutioner extends DefaultOpExecutioner {
 
     @Override
    public DataBuffer createConstantBuffer(double[] values, DataType desiredType) {
+        if (nativeOps.lastErrorCode() != 0)
+            throw new RuntimeException(nativeOps.lastErrorMessage());
+
         OpaqueConstantDataBuffer dbf = nativeOps.constantBufferDouble(desiredType.toInt(), new DoublePointer(values), values.length);
 
         if (nativeOps.lastErrorCode() != 0)
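Note: each CudaExecutioner method touched above gains the same guard: check lastErrorCode() before issuing the native call (to fail fast on a stale error left by an earlier call) and again after it. A self-contained sketch of that pattern is below; the NativeErrorSource and NativeGuard names are invented for illustration, while the guard logic itself matches the diff.

// Illustrative only: nd4j's real interface is NativeOps; the check mirrors the hunks above.
interface NativeErrorSource {
    int lastErrorCode();
    String lastErrorMessage();
}

final class NativeGuard {
    static void check(NativeErrorSource ops) {
        if (ops.lastErrorCode() != 0)
            throw new RuntimeException(ops.lastErrorMessage());
    }

    private NativeGuard() { }
}

// Usage inside a wrapper method:
//   NativeGuard.check(nativeOps);      // stale error from a previous native call
//   OpaqueConstantDataBuffer dbf = nativeOps.shapeBuffer(...);
//   NativeGuard.check(nativeOps);      // error produced by this call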
@@ -3830,6 +3830,9 @@ public native @Cast("Nd4jPointer") Pointer lcSolverHandle(OpaqueLaunchContext lc
  * @param writeList
  * @param readList
  */
+// TODO: it would be nice to have NDArray::registerSpecialUse signature that accepts something else beyond initializer_list
 
+// TODO: it would be nice to have NDArray::registerSpecialUse signature that accepts something else beyond initializer_list
+
 
 /**

@@ -3830,6 +3830,9 @@ public native @Cast("Nd4jPointer") Pointer lcSolverHandle(OpaqueLaunchContext lc
  * @param writeList
  * @param readList
  */
+// TODO: it would be nice to have NDArray::registerSpecialUse signature that accepts something else beyond initializer_list
 
+// TODO: it would be nice to have NDArray::registerSpecialUse signature that accepts something else beyond initializer_list
+
 
 /**
@@ -16982,8 +16985,20 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
 // #endif
 
 /**
- * Returns a batched matrix tensor with new batched diagonal values.
- */
+ * Inserts elements provided by diagonal array into the main diagonal of innermost matrices of input array
+ *
+ * Input arrays:
+ *   input: input array, considered as batch of matrices
+ *   diagonal: array containing elements to be inserted into input array,
+ *     following rank condition should be satisfied: diagonal_rank = input_rank - 1,
+ *     the shapes of diagonal and input arrays must be equal except last dimension of input array,
+ *     for example if input_shape = [A,B,C,D] then diagonal_shape = [A,B,C],
+ *     also last dimension of diagonal array should be equal to smaller of last and last but one input dimensions
+ *     that is: diagonal_shape[-1] = min(input_shape[-1], input_shape[-2])
+ *
+ * Output array:
+ *   has the same shape as input, corresponding diagonal elements are substituted
+ */
 // #if NOT_EXCLUDED(OP_matrix_set_diag)
 @Namespace("nd4j::ops") public static class matrix_set_diag extends DeclarableOp {
     static { Loader.load(); }
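Note: the rewritten comment above spells out the shape contract for matrix_set_diag: the diagonal array drops the input's last dimension, and its own last dimension must equal min(input_shape[-1], input_shape[-2]). A small usage sketch is below; invoking the op by name through DynamicCustomOp and Nd4j.exec is an assumption for illustration, not something this diff shows.

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.factory.Nd4j;

public class MatrixSetDiagExample {
    public static void main(String[] args) {
        // Batch of two 3x4 matrices: input_shape = [2, 3, 4],
        // so diagonal_shape = [2, 3] and 3 == min(3, 4), as the comment above requires.
        INDArray input = Nd4j.zeros(2, 3, 4);
        INDArray diagonal = Nd4j.ones(2, 3);

        // Assumed invocation path (op name taken from the header above; builder usage is illustrative).
        INDArray result = Nd4j.exec(DynamicCustomOp.builder("matrix_set_diag")
                .addInputs(input, diagonal)
                .build())[0];

        System.out.println(result);   // ones on each matrix diagonal, zeros elsewhere
    }
}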