From 7abc574eeb92627a20213e47e63dbce05bf05fda Mon Sep 17 00:00:00 2001 From: raver119 Date: Tue, 3 Sep 2019 22:02:02 +0300 Subject: [PATCH] Snapshot update (#8194) * fix double consumption of rng on cpu Signed-off-by: raver119 * Shyrma docs (#222) * - documenting and profiling matrix_set_diag cuda kernel Signed-off-by: Yurii * - correct formula of pnorm pooling in cuda 2d/3d kernels - remove helper matrix_diag which duplicates work of helper matrix_set_diag Signed-off-by: Yurii * cublasHandle sharing + lock Signed-off-by: raver119 * cublasHandle sharing + lock Signed-off-by: raver119 * Documentation from serialization/deserialization in NLP (#221) * refactoring Signed-off-by: Alexander Stoyakin * Javadocs Signed-off-by: Alexander Stoyakin * Javadoc fixed Signed-off-by: Alexander Stoyakin * Cleanup Signed-off-by: Alexander Stoyakin * dedicated lock for getCudaCublasHandle Signed-off-by: raver119 * Small fixes (#223) Signed-off-by: AlexDBlack * ELU DL4J fixes (#224) Signed-off-by: AlexDBlack * javadoc (#225) Signed-off-by: Robert Altena * Small test compilation fix (#226) Signed-off-by: AlexDBlack * #8182 remove spark version suffix (#227) Signed-off-by: AlexDBlack * [WIP] Thread safety (#229) * sync after cublas*gemm Signed-off-by: raver119 * mutex for CublasHelper Signed-off-by: raver119 * don't store cublasHandle in LaunchContext, it's per-device anyway Signed-off-by: raver119 * some printout Signed-off-by: raver119 * check for field instead Signed-off-by: raver119 * pew-pew Signed-off-by: raver119 * don't release ContextBuffers until device changed Signed-off-by: raver119 * small tweak Signed-off-by: raver119 * some logging in sgemm Signed-off-by: raver119 * stream sync Signed-off-by: raver119 * some more logging Signed-off-by: raver119 * some more error checks Signed-off-by: raver119 * one fancy test Signed-off-by: raver119 * one fancy test Signed-off-by: raver119 * minor AffinityManager fix Signed-off-by: raver119 * cudaEvent error logging improvement Signed-off-by: raver119 * ConstantHelper thread safety Signed-off-by: raver119 * - minor corrections in ConstantTadHelper Signed-off-by: Yurii * ConstantShapeHelper thread safety Signed-off-by: raver119 * ConstantTadHelper.cu updated Signed-off-by: raver119 * logging off Signed-off-by: raver119 * logging off Signed-off-by: raver119 --- .../datavec-spark-inference-client/pom.xml | 2 +- .../datavec-spark-inference-server/pom.xml | 2 +- datavec/datavec-spark/pom.xml | 2 +- .../java/org/deeplearning4j/RandomTests.java | 63 +++ .../models/WordVectorSerializerTest.java | 12 +- .../loader/WordVectorSerializer.java | 415 +++++++++++++----- .../deeplearning4j-aws/pom.xml | 2 +- .../spark/dl4j-spark-nlp-java8/pom.xml | 2 +- .../spark/dl4j-spark-nlp/pom.xml | 2 +- .../spark/dl4j-spark-parameterserver/pom.xml | 2 +- .../spark/dl4j-spark/pom.xml | 2 +- .../org/deeplearning4j/spark/TestKryo.java | 4 +- .../multilayer/TestSparkDl4jMultiLayer.java | 2 +- ...TestSparkMultiLayerParameterAveraging.java | 42 +- .../deeplearning4j-scaleout/spark/pom.xml | 4 +- libnd4j/include/array/ConstantHolder.h | 4 + libnd4j/include/array/impl/ConstantHolder.cpp | 4 + .../include/execution/cuda/AffinityManager.cu | 21 +- .../include/execution/cuda/ContextBuffers.cu | 21 +- .../include/execution/cuda/LaunchContext.cu | 13 +- libnd4j/include/helpers/ConstantHelper.h | 3 +- libnd4j/include/helpers/ConstantShapeHelper.h | 8 +- libnd4j/include/helpers/ConstantTadHelper.h | 10 +- .../include/helpers/cpu/ConstantHelper.cpp | 27 +- .../helpers/cpu/ConstantShapeHelper.cpp | 12 +- 
.../include/helpers/cpu/ConstantTadHelper.cpp | 12 +- libnd4j/include/helpers/cublasHelper.h | 2 + .../include/helpers/cuda/ConstantHelper.cu | 31 +- .../helpers/cuda/ConstantShapeHelper.cu | 12 +- .../include/helpers/cuda/ConstantTadHelper.cu | 14 +- .../include/helpers/cuda_off/cublasHelper.cu | 9 +- libnd4j/include/loops/cpu/random.cpp | 8 - .../generic/parity_ops/matrixSetDiag.cpp | 7 +- .../generic/parity_ops/matrix_diag.cpp | 64 +-- .../ops/declarable/headers/parity_ops.h | 16 +- .../declarable/helpers/cpu/convolutions.cpp | 2 +- .../declarable/helpers/cpu/matrixSetDiag.cpp | 57 ++- .../declarable/helpers/cpu/matrix_diag.cpp | 65 --- .../declarable/helpers/cuda/convolutions.cu | 23 +- .../declarable/helpers/cuda/matrixSetDiag.cu | 110 +++-- .../declarable/helpers/cuda/matrix_diag.cu | 95 ---- .../ops/declarable/helpers/matrixSetDiag.h | 3 +- .../ops/declarable/helpers/matrix_diag.h | 34 -- .../layers_tests/DeclarableOpsTests3.cpp | 28 +- .../tests_cpu/layers_tests/SortCudaTests.cu | 6 +- .../DifferentialFunctionFactory.java | 4 +- .../activations/impl/ActivationELU.java | 16 +- .../nd4j/linalg/api/ndarray/BaseNDArray.java | 242 ---------- .../linalg/api/ndarray/BaseSparseNDArray.java | 2 - .../org/nd4j/linalg/api/ndarray/INDArray.java | 20 +- .../ops/impl/transforms/gradient/EluBp.java | 3 +- .../api/ops/impl/transforms/strict/ELU.java | 20 +- .../allocator/pointers/cuda/cudaEvent_t.java | 11 +- .../jita/handler/impl/CudaZeroHandler.java | 27 +- .../linalg/jcublas/blas/JcublasLevel3.java | 16 +- .../ops/executioner/CudaExecutioner.java | 12 + .../java/org/nd4j/nativeblas/Nd4jCuda.java | 3 + .../java/org/nd4j/nativeblas/Nd4jCpu.java | 19 +- 58 files changed, 835 insertions(+), 839 deletions(-) create mode 100644 deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/RandomTests.java delete mode 100644 libnd4j/include/ops/declarable/helpers/cpu/matrix_diag.cpp delete mode 100644 libnd4j/include/ops/declarable/helpers/cuda/matrix_diag.cu delete mode 100644 libnd4j/include/ops/declarable/helpers/matrix_diag.h diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/pom.xml b/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/pom.xml index db110703b..076c22ab9 100644 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/pom.xml +++ b/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/pom.xml @@ -38,7 +38,7 @@ org.datavec datavec-spark-inference-server_2.11 - 1.0.0_spark_2-SNAPSHOT + 1.0.0-SNAPSHOT test diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/pom.xml b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/pom.xml index 605b13b70..8bef216a7 100644 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/pom.xml +++ b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/pom.xml @@ -25,7 +25,7 @@ datavec-spark-inference-server_2.11 jar - 1.0.0_spark_2-SNAPSHOT + 1.0.0-SNAPSHOT datavec-spark-inference-server diff --git a/datavec/datavec-spark/pom.xml b/datavec/datavec-spark/pom.xml index 05c505cac..f7143c6ea 100644 --- a/datavec/datavec-spark/pom.xml +++ b/datavec/datavec-spark/pom.xml @@ -24,7 +24,7 @@ 4.0.0 - 1.0.0_spark_2-SNAPSHOT + 1.0.0-SNAPSHOT datavec-spark_2.11 diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/RandomTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/RandomTests.java new file mode 100644 index 000000000..8f727fdf9 --- 
/dev/null +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/RandomTests.java @@ -0,0 +1,63 @@ +package org.deeplearning4j; + +import org.deeplearning4j.datasets.iterator.EarlyTerminationDataSetIterator; +import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; +import org.deeplearning4j.nn.api.OptimizationAlgorithm; +import org.deeplearning4j.nn.conf.MultiLayerConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; +import org.junit.Ignore; +import org.junit.Test; +import org.nd4j.linalg.activations.Activation; +import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.learning.config.RmsProp; +import org.nd4j.linalg.lossfunctions.LossFunctions; + +import java.util.concurrent.CountDownLatch; + +@Ignore +public class RandomTests { + + @Test + public void testReproduce() throws Exception { + + final MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new RmsProp()) + .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).list() + .layer(0, new org.deeplearning4j.nn.conf.layers.DenseLayer.Builder().nIn(28 * 28).nOut(10) + .activation(Activation.TANH).build()) + .layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder( + LossFunctions.LossFunction.MCXENT).nIn(10).nOut(10) + .activation(Activation.SOFTMAX).build()) + .build(); + + for (int e = 0; e < 3; e++) { + + int nThreads = 10; + final CountDownLatch l = new CountDownLatch(nThreads); + for (int i = 0; i < nThreads; i++) { + final int j = i; + Thread t = new Thread(new Runnable() { + @Override + public void run() { + try { + MultiLayerNetwork net = new MultiLayerNetwork(conf.clone()); + net.init(); + DataSetIterator iter = new EarlyTerminationDataSetIterator(new MnistDataSetIterator(10, false, 12345), 100); + net.fit(iter); + } catch (Throwable t) { + System.out.println("Thread failed: " + j); + t.printStackTrace(); + } finally { + l.countDown(); + } + } + }); + t.start(); + } + + l.await(); + System.out.println("DONE " + e + "\n"); + } + } +} diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/WordVectorSerializerTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/WordVectorSerializerTest.java index f4dd1a6c5..69eae7307 100755 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/WordVectorSerializerTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/WordVectorSerializerTest.java @@ -833,14 +833,14 @@ public class WordVectorSerializerTest extends BaseDL4JTest { public void testB64_1() throws Exception { String wordA = "night"; String wordB = "night day"; - String encA = WordVectorSerializer.encodeB64(wordA); - String encB = WordVectorSerializer.encodeB64(wordB); + String encA = WordVectorSerializer.ReadHelper.encodeB64(wordA); + String encB = WordVectorSerializer.ReadHelper.encodeB64(wordB); - assertEquals(wordA, WordVectorSerializer.decodeB64(encA)); - assertEquals(wordB, WordVectorSerializer.decodeB64(encB)); + assertEquals(wordA, WordVectorSerializer.ReadHelper.decodeB64(encA)); + assertEquals(wordB, WordVectorSerializer.ReadHelper.decodeB64(encB)); - assertEquals(wordA, WordVectorSerializer.decodeB64(wordA)); - assertEquals(wordB, 
WordVectorSerializer.decodeB64(wordB)); + assertEquals(wordA, WordVectorSerializer.ReadHelper.decodeB64(wordA)); + assertEquals(wordB, WordVectorSerializer.ReadHelper.decodeB64(wordB)); } diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/loader/WordVectorSerializer.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/loader/WordVectorSerializer.java index cce6a740a..80ce0bf34 100755 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/loader/WordVectorSerializer.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/loader/WordVectorSerializer.java @@ -24,7 +24,6 @@ import org.apache.commons.io.FileUtils; import org.apache.commons.io.IOUtils; import org.apache.commons.io.LineIterator; import org.apache.commons.io.output.CloseShieldOutputStream; -import org.deeplearning4j.exception.DL4JException; import org.deeplearning4j.exception.DL4JInvalidInputException; import org.deeplearning4j.models.embeddings.WeightLookupTable; import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable; @@ -52,7 +51,6 @@ import org.deeplearning4j.text.sentenceiterator.BasicLineIterator; import org.deeplearning4j.text.tokenization.tokenizer.TokenPreProcess; import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory; import org.deeplearning4j.util.DL4JFileUtils; -import org.nd4j.base.Preconditions; import org.nd4j.compression.impl.NoOp; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.exception.ND4JIllegalStateException; @@ -68,8 +66,6 @@ import org.nd4j.util.OneTimeLogger; import java.io.*; import java.nio.charset.StandardCharsets; -import java.nio.file.Files; -import java.nio.file.Paths; import java.util.ArrayList; import java.util.List; import java.util.concurrent.atomic.AtomicInteger; @@ -78,6 +74,80 @@ import java.util.zip.*; /** * This is utility class, providing various methods for WordVectors serialization * + * List of available serialization methods (please keep this list consistent with source code): + * + *
    + *
  • Serializers for Word2Vec:
 + * {@link #writeWordVectors(WeightLookupTable, File)} + * {@link #writeWordVectors(WeightLookupTable, OutputStream)} + * {@link #writeWord2VecModel(Word2Vec, File)} + * {@link #writeWord2VecModel(Word2Vec, String)} + * {@link #writeWord2VecModel(Word2Vec, OutputStream)} + * + *
  • Deserializers for Word2Vec:
 + * {@link #readWord2VecModel(File)} + * {@link #readWord2VecModel(String)} + * {@link #readWord2VecModel(File, boolean)} + * {@link #readWord2VecModel(String, boolean)} + * {@link #readAsBinaryNoLineBreaks(File)} + * {@link #readAsBinary(File)} + * {@link #readAsCsv(File)} + * {@link #readBinaryModel(File, boolean, boolean)} + * {@link #readWord2VecFromText(File, File, File, File, VectorsConfiguration)} + * {@link #readWord2Vec(String, boolean)} + * {@link #readWord2Vec(File, boolean)} + * {@link #readWord2Vec(InputStream, boolean)} + * + *
  • Serializers for ParaVec:
 + * {@link #writeParagraphVectors(ParagraphVectors, File)} + * {@link #writeParagraphVectors(ParagraphVectors, String)} + * {@link #writeParagraphVectors(ParagraphVectors, OutputStream)} + * + *
  • Deserializers for ParaVec:
 + * {@link #readParagraphVectors(File)} + * {@link #readParagraphVectors(String)} + * {@link #readParagraphVectors(InputStream)} + * + *
  • Serializers for GloVe:
 + * {@link #writeWordVectors(Glove, File)} + * {@link #writeWordVectors(Glove, String)} + * {@link #writeWordVectors(Glove, OutputStream)} + * + *
  • Adapters:
 + * {@link #fromTableAndVocab(WeightLookupTable, VocabCache)} + * {@link #fromPair(Pair)} + * {@link #loadTxt(File)} + * + *
  • Serializers to tSNE format:
 + * {@link #writeTsneFormat(Glove, INDArray, File)} + * {@link #writeTsneFormat(Word2Vec, INDArray, File)} + * + *
  • FastText serializer:
 + * {@link #writeWordVectors(FastText, File)} + * + *
  • FastText deserializer:
 + * {@link #readWordVectors(File)} + * + *
  • SequenceVectors serializers:
 + * {@link #writeSequenceVectors(SequenceVectors, OutputStream)} + * {@link #writeSequenceVectors(SequenceVectors, SequenceElementFactory, File)} + * {@link #writeSequenceVectors(SequenceVectors, SequenceElementFactory, String)} + * {@link #writeSequenceVectors(SequenceVectors, SequenceElementFactory, OutputStream)} + * {@link #writeLookupTable(WeightLookupTable, File)} + * {@link #writeVocabCache(VocabCache, File)} + * {@link #writeVocabCache(VocabCache, OutputStream)} + * + *
  • SequenceVectors deserializers:
 + * {@link #readSequenceVectors(File, boolean)} + * {@link #readSequenceVectors(String, boolean)} + * {@link #readSequenceVectors(SequenceElementFactory, File)} + * {@link #readSequenceVectors(InputStream, boolean)} + * {@link #readSequenceVectors(SequenceElementFactory, InputStream)} + * {@link #readLookupTable(File)} + * {@link #readLookupTable(InputStream)} + * + *
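A minimal usage sketch of the write/read pair listed above (not part of the patch itself). It assumes an already-trained Word2Vec instance and a hypothetical local file name, and only exercises writeWord2VecModel(Word2Vec, File) and readWord2VecModel(File, boolean) from the list:

    import java.io.File;
    import java.io.IOException;
    import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
    import org.deeplearning4j.models.word2vec.Word2Vec;

    class Word2VecRoundTripSketch {
        // Persists a trained model in the full DL4J (zip) format and reads it back.
        static Word2Vec roundTrip(Word2Vec trained) throws IOException {
            File target = new File("word2vec-model.zip");                  // hypothetical path
            WordVectorSerializer.writeWord2VecModel(trained, target);      // write full model
            return WordVectorSerializer.readWord2VecModel(target, true);   // true: load extended tables
        }
    }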
+ * * @author Adam Gibson * @author raver119 * @author alexander@skymind.io @@ -97,7 +167,7 @@ public class WordVectorSerializer { * @throws IOException * @throws NumberFormatException */ - private static Word2Vec readTextModel(File modelFile) throws IOException, NumberFormatException { + /*private static Word2Vec readTextModel(File modelFile) throws IOException, NumberFormatException { InMemoryLookupTable lookupTable; VocabCache cache; INDArray syn0; @@ -142,7 +212,7 @@ public class WordVectorSerializer { ret.setLookupTable(lookupTable); } return ret; - } + }*/ /** * Read a binary word2vec file. @@ -173,8 +243,8 @@ public class WordVectorSerializer { try (BufferedInputStream bis = new BufferedInputStream(GzipUtils.isCompressedFilename(modelFile.getName()) ? new GZIPInputStream(new FileInputStream(modelFile)) : new FileInputStream(modelFile)); DataInputStream dis = new DataInputStream(bis)) { - words = Integer.parseInt(readString(dis)); - size = Integer.parseInt(readString(dis)); + words = Integer.parseInt(ReadHelper.readString(dis)); + size = Integer.parseInt(ReadHelper.readString(dis)); syn0 = Nd4j.create(words, size); cache = new AbstractCache<>(); @@ -188,11 +258,11 @@ public class WordVectorSerializer { float[] vector = new float[size]; for (int i = 0; i < words; i++) { - word = readString(dis); + word = ReadHelper.readString(dis); log.trace("Loading " + word + " with word " + i); for (int j = 0; j < size; j++) { - vector[j] = readFloat(dis); + vector[j] = ReadHelper.readFloat(dis); } if (cache.containsWord(word)) @@ -236,64 +306,6 @@ public class WordVectorSerializer { } - /** - * Read a float from a data input stream Credit to: - * https://github.com/NLPchina/Word2VEC_java/blob/master/src/com/ansj/vec/Word2VEC.java - * - * @param is - * @return - * @throws IOException - */ - public static float readFloat(InputStream is) throws IOException { - byte[] bytes = new byte[4]; - is.read(bytes); - return getFloat(bytes); - } - - /** - * Read a string from a data input stream Credit to: - * https://github.com/NLPchina/Word2VEC_java/blob/master/src/com/ansj/vec/Word2VEC.java - * - * @param b - * @return - * @throws IOException - */ - public static float getFloat(byte[] b) { - int accum = 0; - accum = accum | (b[0] & 0xff) << 0; - accum = accum | (b[1] & 0xff) << 8; - accum = accum | (b[2] & 0xff) << 16; - accum = accum | (b[3] & 0xff) << 24; - return Float.intBitsToFloat(accum); - } - - /** - * Read a string from a data input stream Credit to: - * https://github.com/NLPchina/Word2VEC_java/blob/master/src/com/ansj/vec/Word2VEC.java - * - * @param dis - * @return - * @throws IOException - */ - public static String readString(DataInputStream dis) throws IOException { - byte[] bytes = new byte[MAX_SIZE]; - byte b = dis.readByte(); - int i = -1; - StringBuilder sb = new StringBuilder(); - while (b != 32 && b != 10) { - i++; - bytes[i] = b; - b = dis.readByte(); - if (i == 49) { - sb.append(new String(bytes, "UTF-8")); - i = -1; - bytes = new byte[MAX_SIZE]; - } - } - sb.append(new String(bytes, 0, i + 1, "UTF-8")); - return sb.toString(); - } - /** * This method writes word vectors to the given path. * Please note: this method doesn't load whole vocab/lookupTable into memory, so it's able to process large vocabularies served over network. 
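The binary reader above obtains the header and vector values through ReadHelper.readString and ReadHelper.readFloat, which assemble a little-endian IEEE-754 float from four raw bytes. A self-contained sketch of that decoding, mirroring the getFloat logic shown in this patch (JDK only, class name chosen for illustration):

    import java.io.ByteArrayInputStream;
    import java.io.IOException;
    import java.io.InputStream;

    class LittleEndianFloatSketch {
        // Same bit assembly as ReadHelper.getFloat(byte[]): lowest-order byte first.
        static float getFloat(byte[] b) {
            int accum = 0;
            accum |= (b[0] & 0xff);
            accum |= (b[1] & 0xff) << 8;
            accum |= (b[2] & 0xff) << 16;
            accum |= (b[3] & 0xff) << 24;
            return Float.intBitsToFloat(accum);
        }

        static float readFloat(InputStream is) throws IOException {
            byte[] bytes = new byte[4];
            if (is.read(bytes) != 4)
                throw new IOException("unexpected end of stream");
            return getFloat(bytes);
        }

        public static void main(String[] args) throws IOException {
            // 1.0f is 0x3F800000; stored little-endian on disk as 00 00 80 3F
            InputStream in = new ByteArrayInputStream(new byte[]{0x00, 0x00, (byte) 0x80, 0x3F});
            System.out.println(readFloat(in)); // prints 1.0
        }
    }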
@@ -355,7 +367,7 @@ public class WordVectorSerializer { val builder = new StringBuilder(); val l = element.getLabel(); - builder.append(encodeB64(l)).append(" "); + builder.append(ReadHelper.encodeB64(l)).append(" "); val vec = lookupTable.vector(element.getLabel()); for (int i = 0; i < vec.length(); i++) { builder.append(vec.getDouble(i)); @@ -518,7 +530,7 @@ public class WordVectorSerializer { try (PrintWriter writer = new PrintWriter(new FileWriter(tempFileCodes))) { for (int i = 0; i < vectors.getVocab().numWords(); i++) { VocabWord word = vectors.getVocab().elementAtIndex(i); - StringBuilder builder = new StringBuilder(encodeB64(word.getLabel())).append(" "); + StringBuilder builder = new StringBuilder(ReadHelper.encodeB64(word.getLabel())).append(" "); for (int code : word.getCodes()) { builder.append(code).append(" "); } @@ -536,7 +548,7 @@ public class WordVectorSerializer { try (PrintWriter writer = new PrintWriter(new FileWriter(tempFileHuffman))) { for (int i = 0; i < vectors.getVocab().numWords(); i++) { VocabWord word = vectors.getVocab().elementAtIndex(i); - StringBuilder builder = new StringBuilder(encodeB64(word.getLabel())).append(" "); + StringBuilder builder = new StringBuilder(ReadHelper.encodeB64(word.getLabel())).append(" "); for (int point : word.getPoints()) { builder.append(point).append(" "); } @@ -554,7 +566,7 @@ public class WordVectorSerializer { try (PrintWriter writer = new PrintWriter(new FileWriter(tempFileFreqs))) { for (int i = 0; i < vectors.getVocab().numWords(); i++) { VocabWord word = vectors.getVocab().elementAtIndex(i); - StringBuilder builder = new StringBuilder(encodeB64(word.getLabel())).append(" ") + StringBuilder builder = new StringBuilder(ReadHelper.encodeB64(word.getLabel())).append(" ") .append(word.getElementFrequency()).append(" ") .append(vectors.getVocab().docAppearedIn(word.getLabel())); @@ -638,7 +650,7 @@ public class WordVectorSerializer { try (PrintWriter writer = new PrintWriter(new FileWriter(tempFileCodes))) { for (int i = 0; i < vectors.getVocab().numWords(); i++) { VocabWord word = vectors.getVocab().elementAtIndex(i); - StringBuilder builder = new StringBuilder(encodeB64(word.getLabel())).append(" "); + StringBuilder builder = new StringBuilder(ReadHelper.encodeB64(word.getLabel())).append(" "); for (int code : word.getCodes()) { builder.append(code).append(" "); } @@ -656,7 +668,7 @@ public class WordVectorSerializer { try (PrintWriter writer = new PrintWriter(new FileWriter(tempFileHuffman))) { for (int i = 0; i < vectors.getVocab().numWords(); i++) { VocabWord word = vectors.getVocab().elementAtIndex(i); - StringBuilder builder = new StringBuilder(encodeB64(word.getLabel())).append(" "); + StringBuilder builder = new StringBuilder(ReadHelper.encodeB64(word.getLabel())).append(" "); for (int point : word.getPoints()) { builder.append(point).append(" "); } @@ -677,7 +689,7 @@ public class WordVectorSerializer { StringBuilder builder = new StringBuilder(); for (VocabWord word : vectors.getVocab().tokens()) { if (word.isLabel()) - builder.append(encodeB64(word.getLabel())).append("\n"); + builder.append(ReadHelper.encodeB64(word.getLabel())).append("\n"); } IOUtils.write(builder.toString().trim(), zipfile, StandardCharsets.UTF_8); @@ -688,7 +700,7 @@ public class WordVectorSerializer { try (PrintWriter writer = new PrintWriter(new FileWriter(tempFileFreqs))) { for (int i = 0; i < vectors.getVocab().numWords(); i++) { VocabWord word = vectors.getVocab().elementAtIndex(i); - builder = new 
StringBuilder(encodeB64(word.getLabel())).append(" ").append(word.getElementFrequency()) + builder = new StringBuilder(ReadHelper.encodeB64(word.getLabel())).append(" ").append(word.getElementFrequency()) .append(" ").append(vectors.getVocab().docAppearedIn(word.getLabel())); writer.println(builder.toString().trim()); @@ -744,7 +756,7 @@ public class WordVectorSerializer { try (BufferedReader reader = new BufferedReader(new InputStreamReader(stream, StandardCharsets.UTF_8))) { String line; while ((line = reader.readLine()) != null) { - VocabWord word = vectors.getVocab().tokenFor(decodeB64(line.trim())); + VocabWord word = vectors.getVocab().tokenFor(ReadHelper.decodeB64(line.trim())); if (word != null) { word.markAsLabel(true); } @@ -836,7 +848,7 @@ public class WordVectorSerializer { String line; while ((line = reader.readLine()) != null) { String[] split = line.split(" "); - VocabWord word = w2v.getVocab().tokenFor(decodeB64(split[0])); + VocabWord word = w2v.getVocab().tokenFor(ReadHelper.decodeB64(split[0])); word.setElementFrequency((long) Double.parseDouble(split[1])); word.setSequencesCount((long) Double.parseDouble(split[2])); } @@ -946,7 +958,7 @@ public class WordVectorSerializer { reader = new BufferedReader(new FileReader(h_points)); while ((line = reader.readLine()) != null) { String[] split = line.split(" "); - VocabWord word = vocab.wordFor(decodeB64(split[0])); + VocabWord word = vocab.wordFor(ReadHelper.decodeB64(split[0])); List points = new ArrayList<>(); for (int i = 1; i < split.length; i++) { points.add(Integer.parseInt(split[i])); @@ -960,7 +972,7 @@ public class WordVectorSerializer { reader = new BufferedReader(new FileReader(h_codes)); while ((line = reader.readLine()) != null) { String[] split = line.split(" "); - VocabWord word = vocab.wordFor(decodeB64(split[0])); + VocabWord word = vocab.wordFor(ReadHelper.decodeB64(split[0])); List codes = new ArrayList<>(); for (int i = 1; i < split.length; i++) { codes.add(Byte.parseByte(split[i])); @@ -1704,7 +1716,7 @@ public class WordVectorSerializer { if (line.isEmpty()) line = iter.nextLine(); String[] split = line.split(" "); - String word = decodeB64(split[0]); //split[0].replaceAll(whitespaceReplacement, " "); + String word = ReadHelper.decodeB64(split[0]); //split[0].replaceAll(whitespaceReplacement, " "); VocabWord word1 = new VocabWord(1.0, word); word1.setIndex(cache.numWords()); @@ -1994,7 +2006,13 @@ public class WordVectorSerializer { private static final String SYN1_ENTRY = "syn1.bin"; private static final String SYN1_NEG_ENTRY = "syn1neg.bin"; - + /** + * This method saves specified SequenceVectors model to target OutputStream + * + * @param vectors SequenceVectors model + * @param stream Target output stream + * @param + */ public static void writeSequenceVectors(@NonNull SequenceVectors vectors, @NonNull OutputStream stream) throws IOException { @@ -2040,7 +2058,13 @@ public class WordVectorSerializer { } } - + /** + * This method loads SequenceVectors from specified file path + * + * @param path String + * @param readExtendedTables boolean + * @param + */ public static SequenceVectors readSequenceVectors(@NonNull String path, boolean readExtendedTables) throws IOException { @@ -2050,6 +2074,14 @@ public class WordVectorSerializer { return vectors; } + /** + * This method loads SequenceVectors from specified file path + * + * @param file File + * @param readExtendedTables boolean + * @param + */ + public static SequenceVectors readSequenceVectors(@NonNull File file, boolean readExtendedTables) throws 
IOException { @@ -2058,6 +2090,13 @@ public class WordVectorSerializer { return vectors; } + /** + * This method loads SequenceVectors from specified input stream + * + * @param stream InputStream + * @param readExtendedTables boolean + * @param + */ public static SequenceVectors readSequenceVectors(@NonNull InputStream stream, boolean readExtendedTables) throws IOException { @@ -2381,6 +2420,12 @@ public class WordVectorSerializer { } } + /** + * This method loads Word2Vec model from binary file + * + * @param file File + * @return Word2Vec + */ public static Word2Vec readAsBinary(@NonNull File file) { boolean originalPeriodic = Nd4j.getMemoryManager().isPeriodicGcActive(); int originalFreq = Nd4j.getMemoryManager().getOccasionalGcFrequency(); @@ -2403,6 +2448,12 @@ public class WordVectorSerializer { } } + /** + * This method loads Word2Vec model from csv file + * + * @param file File + * @return Word2Vec + */ public static Word2Vec readAsCsv(@NonNull File file) { Word2Vec vec; @@ -2491,7 +2542,7 @@ public class WordVectorSerializer { String line; while ((line = reader.readLine()) != null) { String[] split = line.split(" "); - VocabWord word = new VocabWord(Double.valueOf(split[1]), decodeB64(split[0])); + VocabWord word = new VocabWord(Double.valueOf(split[1]), ReadHelper.decodeB64(split[0])); word.setIndex(cnt.getAndIncrement()); word.incrementSequencesCount(Long.valueOf(split[2])); @@ -2669,7 +2720,7 @@ public class WordVectorSerializer { * * In return you get StaticWord2Vec model, which might be used as lookup table only in multi-gpu environment. * - * @param file File should point to previously saved w2v model + * @param inputStream InputStream should point to previously saved w2v model * @return */ public static WordVectors loadStaticModel(InputStream inputStream) throws IOException { @@ -2685,6 +2736,17 @@ public class WordVectorSerializer { } // TODO: this method needs better name :) + /** + * This method restores previously saved w2v model. File can be in one of the following formats: + * 1) Binary model, either compressed or not. Like well-known Google Model + * 2) Popular CSV word2vec text format + * 3) DL4j compressed format + * + * In return you get StaticWord2Vec model, which might be used as lookup table only in multi-gpu environment. 
+ * + * @param file File + * @return + */ public static WordVectors loadStaticModel(@NonNull File file) { if (!file.exists() || file.isDirectory()) throw new RuntimeException( @@ -2843,8 +2905,8 @@ public class WordVectorSerializer { throw new RuntimeException(e); } try { - numWords = Integer.parseInt(readString(stream)); - vectorLength = Integer.parseInt(readString(stream)); + numWords = Integer.parseInt(ReadHelper.readString(stream)); + vectorLength = Integer.parseInt(ReadHelper.readString(stream)); } catch (IOException e) { throw new RuntimeException(e); } @@ -2858,13 +2920,13 @@ public class WordVectorSerializer { @Override public Pair next() { try { - String word = readString(stream); + String word = ReadHelper.readString(stream); VocabWord element = new VocabWord(1.0, word); element.setIndex(idxCounter.getAndIncrement()); float[] vector = new float[vectorLength]; for (int i = 0; i < vectorLength; i++) { - vector[i] = readFloat(stream); + vector[i] = ReadHelper.readFloat(stream); } return Pair.makePair(element, vector); @@ -2913,7 +2975,7 @@ public class WordVectorSerializer { String[] split = nextLine.split(" "); - VocabWord word = new VocabWord(1.0, decodeB64(split[0])); + VocabWord word = new VocabWord(1.0, ReadHelper.decodeB64(split[0])); word.setIndex(idxCounter.getAndIncrement()); float[] vector = new float[split.length - 1]; @@ -2937,26 +2999,12 @@ public class WordVectorSerializer { } } - public static String encodeB64(String word) { - try { - return "B64:" + Base64.encodeBase64String(word.getBytes("UTF-8")).replaceAll("(\r|\n)", ""); - } catch (Exception e) { - throw new RuntimeException(e); - } - } - - public static String decodeB64(String word) { - if (word.startsWith("B64:")) { - String arp = word.replaceFirst("B64:", ""); - try { - return new String(Base64.decodeBase64(arp), "UTF-8"); - } catch (Exception e) { - throw new RuntimeException(e); - } - } else - return word; - } - + /** + * This method saves Word2Vec model to output stream + * + * @param word2Vec Word2Vec + * @param stream OutputStream + */ public static void writeWord2Vec(@NonNull Word2Vec word2Vec, @NonNull OutputStream stream) throws IOException { @@ -2968,6 +3016,13 @@ public class WordVectorSerializer { writeSequenceVectors(vectors, stream); } + /** + * This method restores Word2Vec model from file + * + * @param path String + * @param readExtendedTables booleab + * @return Word2Vec + */ public static Word2Vec readWord2Vec(@NonNull String path, boolean readExtendedTables) throws IOException { @@ -2976,6 +3031,12 @@ public class WordVectorSerializer { return word2Vec; } + /** + * This method saves table of weights to file + * + * @param weightLookupTable WeightLookupTable + * @param file File + */ public static void writeLookupTable(WeightLookupTable weightLookupTable, @NonNull File file) throws IOException { try (BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(file), @@ -3038,7 +3099,7 @@ public class WordVectorSerializer { headerRead = true; weightLookupTable = new InMemoryLookupTable.Builder().cache(vocabCache).vectorLength(layerSize).build(); } else { - String label = decodeB64(tokens[0]); + String label = ReadHelper.decodeB64(tokens[0]); int freq = Integer.parseInt(tokens[1]); int rows = Integer.parseInt(tokens[2]); int cols = Integer.parseInt(tokens[3]); @@ -3071,6 +3132,13 @@ public class WordVectorSerializer { return weightLookupTable; } + /** + * This method loads Word2Vec model from file + * + * @param file File + * @param readExtendedTables boolean + * 
@return Word2Vec + */ public static Word2Vec readWord2Vec(@NonNull File file, boolean readExtendedTables) throws IOException { @@ -3078,6 +3146,13 @@ public class WordVectorSerializer { return word2Vec; } + /** + * This method loads Word2Vec model from input stream + * + * @param stream InputStream + * @param readExtendedTable boolean + * @return Word2Vec + */ public static Word2Vec readWord2Vec(@NonNull InputStream stream, boolean readExtendedTable) throws IOException { SequenceVectors vectors = readSequenceVectors(stream, readExtendedTable); @@ -3087,7 +3162,13 @@ public class WordVectorSerializer { word2Vec.setModelUtils(vectors.getModelUtils()); return word2Vec; } - + + /** + * This method loads FastText model to file + * + * @param vectors FastText + * @param path File + */ public static void writeWordVectors(@NonNull FastText vectors, @NonNull File path) throws IOException { ObjectOutputStream outputStream = null; try { @@ -3106,6 +3187,11 @@ public class WordVectorSerializer { } } + /** + * This method unloads FastText model from file + * + * @param path File + */ public static FastText readWordVectors(File path) { FastText result = null; try { @@ -3124,6 +3210,13 @@ public class WordVectorSerializer { return result; } + /** + * This method prints memory usage to log + * + * @param numWords + * @param vectorLength + * @param numTables + */ public static void printOutProjectedMemoryUse(long numWords, int vectorLength, int numTables) { double memSize = numWords * vectorLength * Nd4j.sizeOfDataType() * numTables; @@ -3144,4 +3237,102 @@ public class WordVectorSerializer { OneTimeLogger.info(log, "Projected memory use for model: [{} {}]", String.format("%.2f", value), sfx); } + + /** + * Helper static methods to read data from input stream. + */ + public static class ReadHelper { + /** + * Read a float from a data input stream Credit to: + * https://github.com/NLPchina/Word2VEC_java/blob/master/src/com/ansj/vec/Word2VEC.java + * + * @param is + * @return + * @throws IOException + */ + private static float readFloat(InputStream is) throws IOException { + byte[] bytes = new byte[4]; + is.read(bytes); + return getFloat(bytes); + } + + /** + * Read a string from a data input stream Credit to: + * https://github.com/NLPchina/Word2VEC_java/blob/master/src/com/ansj/vec/Word2VEC.java + * + * @param b + * @return + * @throws IOException + */ + private static float getFloat(byte[] b) { + int accum = 0; + accum = accum | (b[0] & 0xff) << 0; + accum = accum | (b[1] & 0xff) << 8; + accum = accum | (b[2] & 0xff) << 16; + accum = accum | (b[3] & 0xff) << 24; + return Float.intBitsToFloat(accum); + } + + /** + * Read a string from a data input stream Credit to: + * https://github.com/NLPchina/Word2VEC_java/blob/master/src/com/ansj/vec/Word2VEC.java + * + * @param dis + * @return + * @throws IOException + */ + private static String readString(DataInputStream dis) throws IOException { + byte[] bytes = new byte[MAX_SIZE]; + byte b = dis.readByte(); + int i = -1; + StringBuilder sb = new StringBuilder(); + while (b != 32 && b != 10) { + i++; + bytes[i] = b; + b = dis.readByte(); + if (i == 49) { + sb.append(new String(bytes, "UTF-8")); + i = -1; + bytes = new byte[MAX_SIZE]; + } + } + sb.append(new String(bytes, 0, i + 1, "UTF-8")); + return sb.toString(); + } + + private static final String B64 = "B64:"; + + /** + * Encode input string + * + * @param word String + * @return String + */ + public static String encodeB64(String word) { + try { + return B64 + 
Base64.encodeBase64String(word.getBytes("UTF-8")).replaceAll("(\r|\n)", ""); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + /** + * Encode input string + * + * @param word String + * @return String + */ + + public static String decodeB64(String word) { + if (word.startsWith(B64)) { + String arp = word.replaceFirst(B64, ""); + try { + return new String(Base64.decodeBase64(arp), "UTF-8"); + } catch (Exception e) { + throw new RuntimeException(e); + } + } else + return word; + } + } } diff --git a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/pom.xml b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/pom.xml index 0b6b05c26..7c9967ef8 100644 --- a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/pom.xml +++ b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/pom.xml @@ -24,7 +24,7 @@ deeplearning4j-aws_2.11 DeepLearning4j-AWS - 1.0.0_spark_2-SNAPSHOT + 1.0.0-SNAPSHOT diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/pom.xml b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/pom.xml index 3fded3e4a..8a19b3b68 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/pom.xml +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/pom.xml @@ -18,7 +18,7 @@ spark_2.11 org.deeplearning4j - 1.0.0_spark_2-SNAPSHOT + 1.0.0-SNAPSHOT 4.0.0 diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/pom.xml b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/pom.xml index a5aff014e..16c4ac298 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/pom.xml +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/pom.xml @@ -18,7 +18,7 @@ spark_2.11 org.deeplearning4j - 1.0.0_spark_2-SNAPSHOT + 1.0.0-SNAPSHOT 4.0.0 dl4j-spark-nlp_2.11 diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/pom.xml b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/pom.xml index d8f425286..9192bb877 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/pom.xml +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/pom.xml @@ -19,7 +19,7 @@ spark_2.11 org.deeplearning4j - 1.0.0_spark_2-SNAPSHOT + 1.0.0-SNAPSHOT 4.0.0 diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/pom.xml b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/pom.xml index 8b31872c5..d84947f1e 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/pom.xml +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/pom.xml @@ -18,7 +18,7 @@ spark_2.11 org.deeplearning4j - 1.0.0_spark_2-SNAPSHOT + 1.0.0-SNAPSHOT 4.0.0 dl4j-spark_2.11 diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/TestKryo.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/TestKryo.java index e6688a215..8c5188b70 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/TestKryo.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/TestKryo.java @@ -17,7 +17,6 @@ package org.deeplearning4j.spark; import org.apache.spark.serializer.SerializerInstance; -import org.deeplearning4j.eval.*; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import 
org.deeplearning4j.nn.conf.NeuralNetConfiguration; @@ -28,6 +27,9 @@ import org.deeplearning4j.nn.conf.graph.rnn.LastTimeStepVertex; import org.deeplearning4j.nn.conf.layers.*; import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor; import org.junit.Test; +import org.nd4j.evaluation.IEvaluation; +import org.nd4j.evaluation.classification.*; +import org.nd4j.evaluation.regression.RegressionEvaluation; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.learning.config.Adam; diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/multilayer/TestSparkDl4jMultiLayer.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/multilayer/TestSparkDl4jMultiLayer.java index f8fe1f4f0..ecf9b937b 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/multilayer/TestSparkDl4jMultiLayer.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/multilayer/TestSparkDl4jMultiLayer.java @@ -19,7 +19,6 @@ package org.deeplearning4j.spark.impl.multilayer; import lombok.extern.slf4j.Slf4j; import org.apache.spark.api.java.JavaRDD; import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; -import org.deeplearning4j.eval.Evaluation; import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; @@ -30,6 +29,7 @@ import org.deeplearning4j.spark.BaseSparkTest; import org.deeplearning4j.spark.api.TrainingMaster; import org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster; import org.junit.Test; +import org.nd4j.evaluation.classification.Evaluation; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestSparkMultiLayerParameterAveraging.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestSparkMultiLayerParameterAveraging.java index ed56af9ee..abfd39060 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestSparkMultiLayerParameterAveraging.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestSparkMultiLayerParameterAveraging.java @@ -29,15 +29,13 @@ import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.mllib.util.MLUtils; import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator; import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; -import org.deeplearning4j.eval.Evaluation; -import org.deeplearning4j.eval.ROC; -import org.deeplearning4j.eval.ROCMultiClass; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.BaseLayer; +import org.deeplearning4j.nn.conf.layers.BatchNormalization; import org.deeplearning4j.nn.conf.layers.DenseLayer; import 
org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.conf.layers.variational.GaussianReconstructionDistribution; @@ -56,6 +54,9 @@ import org.junit.Ignore; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; +import org.nd4j.evaluation.classification.Evaluation; +import org.nd4j.evaluation.classification.ROC; +import org.nd4j.evaluation.classification.ROCMultiClass; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; @@ -63,6 +64,7 @@ import org.nd4j.linalg.dataset.MultiDataSet; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.io.ClassPathResource; +import org.nd4j.linalg.learning.config.Adam; import org.nd4j.linalg.learning.config.IUpdater; import org.nd4j.linalg.learning.config.Nesterovs; import org.nd4j.linalg.learning.config.RmsProp; @@ -70,7 +72,6 @@ import org.nd4j.linalg.lossfunctions.LossFunctions; import scala.Tuple2; import java.io.File; -import java.nio.file.Files; import java.nio.file.Path; import java.util.*; @@ -121,11 +122,6 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { new ParameterAveragingTrainingMaster(true, numExecutors(), 1, 5, 1, 0)); MultiLayerNetwork network2 = master.fitLabeledPoint(data); - Evaluation evaluation = new Evaluation(); - evaluation.eval(d.getLabels(), network2.output(d.getFeatures())); - System.out.println(evaluation.stats()); - - } @@ -137,20 +133,15 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { .getAbsolutePath()) .toJavaRDD().map(new TestFn()); - DataSet d = new IrisDataSetIterator(150, 150).next(); MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(123) - .optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT) - .miniBatch(true).maxNumLineSearchIterations(10) - .list().layer(0, - new DenseLayer.Builder().nIn(4).nOut(100) - .weightInit(WeightInit.XAVIER) - .activation(Activation.RELU) - .build()) - .layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder( - LossFunctions.LossFunction.MCXENT).nIn(100).nOut(3) - .activation(Activation.SOFTMAX) - .weightInit(WeightInit.XAVIER).build()) + .updater(new Adam(1e-6)) + .weightInit(WeightInit.XAVIER) + .list() + .layer(new BatchNormalization.Builder().nIn(4).nOut(4).build()) + .layer(new DenseLayer.Builder().nIn(4).nOut(32).activation(Activation.RELU).build()) + .layer(new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(32).nOut(3) + .activation(Activation.SOFTMAX).build()) .build(); @@ -161,10 +152,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { SparkDl4jMultiLayer master = new SparkDl4jMultiLayer(sc, getBasicConf(), new ParameterAveragingTrainingMaster(true, numExecutors(), 1, 5, 1, 0)); - MultiLayerNetwork network2 = master.fitLabeledPoint(data); - Evaluation evaluation = new Evaluation(); - evaluation.eval(d.getLabels(), network2.output(d.getFeatures())); - System.out.println(evaluation.stats()); + master.fitLabeledPoint(data); } @Test(timeout = 120000L) @@ -465,8 +453,8 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { tempDirF.deleteOnExit(); int dataSetObjSize = 1; - int batchSizePerExecutor = 25; - int numSplits = 10; + int batchSizePerExecutor = 16; + int numSplits = 5; int averagingFrequency = 3; int totalExamples = numExecutors() * batchSizePerExecutor * numSplits * averagingFrequency; 
DataSetIterator iter = new MnistDataSetIterator(dataSetObjSize, totalExamples, false); diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/pom.xml b/deeplearning4j/deeplearning4j-scaleout/spark/pom.xml index f753fefae..bd7226b0e 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/pom.xml +++ b/deeplearning4j/deeplearning4j-scaleout/spark/pom.xml @@ -22,7 +22,7 @@ 4.0.0 spark_2.11 - 1.0.0_spark_2-SNAPSHOT + 1.0.0-SNAPSHOT pom Spark parent @@ -36,7 +36,7 @@ UTF-8 UTF-8 - 1.0.0_spark_2-SNAPSHOT + 1.0.0-SNAPSHOT 2.1.0 diff --git a/libnd4j/include/array/ConstantHolder.h b/libnd4j/include/array/ConstantHolder.h index 89be279e4..137d26f29 100644 --- a/libnd4j/include/array/ConstantHolder.h +++ b/libnd4j/include/array/ConstantHolder.h @@ -24,11 +24,13 @@ #include #include #include +#include namespace nd4j { class ConstantHolder { private: int _deviceId = 0; + std::mutex _mutex; std::map _buffers; public: @@ -53,6 +55,8 @@ namespace nd4j { template ConstantDataBuffer* getConstantDataBuffer(); + + std::mutex* mutex(); }; } diff --git a/libnd4j/include/array/impl/ConstantHolder.cpp b/libnd4j/include/array/impl/ConstantHolder.cpp index 92cc9df23..5913d57a9 100644 --- a/libnd4j/include/array/impl/ConstantHolder.cpp +++ b/libnd4j/include/array/impl/ConstantHolder.cpp @@ -16,6 +16,10 @@ namespace nd4j { return _buffers.count(dataType) > 0; } + std::mutex* ConstantHolder::mutex() { + return &_mutex; + } + template bool ConstantHolder::hasBuffer() { return hasBuffer(DataTypeUtils::fromT()); diff --git a/libnd4j/include/execution/cuda/AffinityManager.cu b/libnd4j/include/execution/cuda/AffinityManager.cu index 1f028b011..d28c0d6d0 100644 --- a/libnd4j/include/execution/cuda/AffinityManager.cu +++ b/libnd4j/include/execution/cuda/AffinityManager.cu @@ -47,7 +47,7 @@ namespace nd4j { _currentMutex.unlock(); - setCurrentDevice(globalThreadToDevice); + setCurrentNativeDevice(globalThreadToDevice); } // if we already know affinity - just return it @@ -92,6 +92,8 @@ namespace nd4j { void AffinityManager::setCurrentNativeDevice(int deviceId) { auto res = cudaSetDevice(deviceId); + if (res != 0) + throw cuda_exception::build("setCurrentDevice failed", res); } void AffinityManager::setCurrentDevice(int deviceId) { @@ -104,17 +106,22 @@ namespace nd4j { res = cudaStreamSynchronize(*LaunchContext::defaultContext()->getCudaSpecialStream()); if (res != 0) throw cuda_exception::build("setCurrentDevice -> specialSync failed", res); + + if (deviceId != previousDeviceId) { + // discard existing stuff + nd4j_printf("AffinityManager::setCurrentDevice() was invoked, releasing buffers\n", ""); + LaunchContext::releaseBuffers(); + } } - auto res = cudaSetDevice(deviceId); - if (res != 0) - throw cuda_exception::build("cudaSetDevice failed", res); + if (deviceId != previousDeviceId) { + auto res = cudaSetDevice(deviceId); + if (res != 0) + throw cuda_exception::build("cudaSetDevice failed", res); + } // update thread-device affinity globalThreadToDevice = deviceId; - - // discard existing stuff - LaunchContext::releaseBuffers(); } std::atomic AffinityManager::_lastDevice;// = std::atomic(initialV); diff --git a/libnd4j/include/execution/cuda/ContextBuffers.cu b/libnd4j/include/execution/cuda/ContextBuffers.cu index 895bb6623..435858462 100644 --- a/libnd4j/include/execution/cuda/ContextBuffers.cu +++ b/libnd4j/include/execution/cuda/ContextBuffers.cu @@ -107,7 +107,6 @@ namespace nd4j { ////// _allocated = false; - _initialized = false; _deviceId = -1; this->_specialStream = nullptr; @@ -116,6 +115,8 @@ namespace nd4j 
{ this->_reductionPointer = nullptr; this->_scalarPointer = nullptr; } + + _initialized = false; } ContextBuffers::~ContextBuffers() { @@ -163,21 +164,21 @@ namespace nd4j { } void* ContextBuffers::reductionBuffer() { - if (_reductionPointer == nullptr) + if (!_initialized) initialize(); return _reductionPointer; } void* ContextBuffers::scalarBuffer() { - if (_scalarPointer == nullptr) + if (!_initialized) initialize(); return _scalarPointer; } void* ContextBuffers::allocationBuffer() { - if (_allocationPointer == nullptr) + if (!_initialized) initialize(); return _allocationPointer; @@ -204,15 +205,23 @@ namespace nd4j { } void* ContextBuffers::execStream() { - if (_execStream == nullptr) + if (!_initialized) { + //nd4j_printf("execStream not initialized\n", ""); initialize(); + } else { + //nd4j_printf("execStream is initialized\n", ""); + } return _execStream; } void* ContextBuffers::specialStream() { - if (_specialStream == nullptr) + if (!_initialized) { + //nd4j_printf("specialStream not initialized\n", ""); initialize(); + } else { + //nd4j_printf("specialStream is initialized\n", ""); + } return _specialStream; } diff --git a/libnd4j/include/execution/cuda/LaunchContext.cu b/libnd4j/include/execution/cuda/LaunchContext.cu index 9d9f2c506..7d1691982 100644 --- a/libnd4j/include/execution/cuda/LaunchContext.cu +++ b/libnd4j/include/execution/cuda/LaunchContext.cu @@ -57,10 +57,6 @@ LaunchContext::LaunchContext() { _deviceID = 0; _isAllocated = true; - - _cublasHandle = CublasHelper::getInstance()->handle(); - - _cusolverHandle = CublasHelper::getInstance()->solver(); } LaunchContext::LaunchContext(Nd4jPointer cudaStream, Nd4jPointer reductionPointer, Nd4jPointer scalarPointer, Nd4jPointer allocationPointer) { @@ -89,13 +85,13 @@ LaunchContext::LaunchContext() { _contexts.resize(numDevices); for (int e = 0; e < numDevices; e++) { - AffinityManager::setCurrentDevice(e); + AffinityManager::setCurrentNativeDevice(e); LaunchContext::_contexts[e] = std::make_shared(); } // don't forget to restore device back again - AffinityManager::setCurrentDevice(deviceId); + AffinityManager::setCurrentNativeDevice(deviceId); } _mutex.unlock(); @@ -117,11 +113,11 @@ LaunchContext::LaunchContext() { }; void* LaunchContext::getCublasHandle() const { - return _cublasHandle; + return CublasHelper::getInstance()->handle(); }; void* LaunchContext::getCusolverHandle() const { - return _cusolverHandle; + return CublasHelper::getInstance()->solver(); }; cudaStream_t* LaunchContext::getCudaStream() const { @@ -162,6 +158,7 @@ LaunchContext::LaunchContext() { }; void LaunchContext::releaseBuffers() { + nd4j_printf("LaunchContext::releaseBuffers() was invoked\n", ""); contextBuffers.release(); } diff --git a/libnd4j/include/helpers/ConstantHelper.h b/libnd4j/include/helpers/ConstantHelper.h index a7f7d0c00..6aad7c387 100644 --- a/libnd4j/include/helpers/ConstantHelper.h +++ b/libnd4j/include/helpers/ConstantHelper.h @@ -38,12 +38,13 @@ namespace nd4j { static ConstantHelper* _INSTANCE; ConstantHelper(); - std::vector> _cache; + std::vector> _cache; // tracking of per-device constant memory buffers (CUDA only atm) std::vector _devicePointers; std::vector _deviceOffsets; std::mutex _mutex; + std::mutex _mutexHolder; std::vector _counters; public: diff --git a/libnd4j/include/helpers/ConstantShapeHelper.h b/libnd4j/include/helpers/ConstantShapeHelper.h index fe0e52ce5..585db0198 100644 --- a/libnd4j/include/helpers/ConstantShapeHelper.h +++ b/libnd4j/include/helpers/ConstantShapeHelper.h @@ -48,10 +48,10 @@ namespace 
nd4j { static ConstantShapeHelper* getInstance(); - ConstantDataBuffer& bufferForShapeInfo(nd4j::DataType dataType, char order, const std::vector &shape); - ConstantDataBuffer& bufferForShapeInfo(const ShapeDescriptor &descriptor); - ConstantDataBuffer& bufferForShapeInfo(const Nd4jLong *shapeInfo); - ConstantDataBuffer& bufferForShapeInfo(const nd4j::DataType dataType, const char order, const int rank, const Nd4jLong* shape); + ConstantDataBuffer bufferForShapeInfo(nd4j::DataType dataType, char order, const std::vector &shape); + ConstantDataBuffer bufferForShapeInfo(const ShapeDescriptor &descriptor); + ConstantDataBuffer bufferForShapeInfo(const Nd4jLong *shapeInfo); + ConstantDataBuffer bufferForShapeInfo(const nd4j::DataType dataType, const char order, const int rank, const Nd4jLong* shape); Nd4jLong* emptyShapeInfo(const nd4j::DataType dataType); diff --git a/libnd4j/include/helpers/ConstantTadHelper.h b/libnd4j/include/helpers/ConstantTadHelper.h index d2790998b..79ee7dcd4 100644 --- a/libnd4j/include/helpers/ConstantTadHelper.h +++ b/libnd4j/include/helpers/ConstantTadHelper.h @@ -54,11 +54,11 @@ namespace nd4j { * @param keepUnitiesInShape * @return */ - TadPack& tadForDimensions(const Nd4jLong *originalShape, const std::vector &dimensions, const bool keepUnitiesInShape = false); - TadPack& tadForDimensions(const Nd4jLong *originalShape, int* dimensions, int dimLength, const bool keepUnitiesInShape = false); - TadPack& tadForDimensions(const Nd4jLong *originalShape, int dimensions, const bool keepUnitiesInShape = false); - TadPack& tadForDimensions(ShapeDescriptor &descriptor, std::vector &dimensions, const bool keepUnitiesInShape = false); - TadPack& tadForDimensions(TadDescriptor &descriptor); + TadPack tadForDimensions(const Nd4jLong *originalShape, const std::vector &dimensions, const bool keepUnitiesInShape = false); + TadPack tadForDimensions(const Nd4jLong *originalShape, int* dimensions, int dimLength, const bool keepUnitiesInShape = false); + TadPack tadForDimensions(const Nd4jLong *originalShape, int dimensions, const bool keepUnitiesInShape = false); + TadPack tadForDimensions(ShapeDescriptor &descriptor, std::vector &dimensions, const bool keepUnitiesInShape = false); + TadPack tadForDimensions(TadDescriptor &descriptor); /** * This method returns number of cached TAD shapes/offsets on specific device diff --git a/libnd4j/include/helpers/cpu/ConstantHelper.cpp b/libnd4j/include/helpers/cpu/ConstantHelper.cpp index 43a4f97c1..b2549e93f 100644 --- a/libnd4j/include/helpers/cpu/ConstantHelper.cpp +++ b/libnd4j/include/helpers/cpu/ConstantHelper.cpp @@ -33,7 +33,8 @@ namespace nd4j { _cache.resize(numDevices); _counters.resize(numDevices); for (int e = 0; e < numDevices; e++) { - std::map map; + std::map map; + _cache[e] = map; _counters[e] = 0L; } @@ -70,15 +71,26 @@ namespace nd4j { ConstantDataBuffer* ConstantHelper::constantBuffer(const ConstantDescriptor &descriptor, nd4j::DataType dataType) { const auto deviceId = getCurrentDevice(); + // we're locking away cache modification + _mutexHolder.lock(); + if (_cache[deviceId].count(descriptor) == 0) { - ConstantHolder holder; - _cache[deviceId][descriptor] = holder; + _cache[deviceId][descriptor] = new ConstantHolder(); } - ConstantHolder* holder = &_cache[deviceId][descriptor]; + auto holder = _cache[deviceId][descriptor]; + + // releasing cache lock + _mutexHolder.unlock(); + + + ConstantDataBuffer* result; + + // access to this holder instance is synchronous + holder->mutex()->lock(); if (holder->hasBuffer(dataType)) - 
return holder->getConstantDataBuffer(dataType); + result = holder->getConstantDataBuffer(dataType); else { auto size = descriptor.length() * DataTypeUtils::sizeOf(dataType); auto cbuff = new int8_t[size]; @@ -94,8 +106,11 @@ namespace nd4j { ConstantDataBuffer dataBuffer(cbuff, nullptr, descriptor.length(), DataTypeUtils::sizeOf(dataType)); holder->addBuffer(dataBuffer, dataType); - return holder->getConstantDataBuffer(dataType); + result = holder->getConstantDataBuffer(dataType); } + holder->mutex()->unlock(); + + return result; } Nd4jLong ConstantHelper::getCachedAmount(int deviceId) { diff --git a/libnd4j/include/helpers/cpu/ConstantShapeHelper.cpp b/libnd4j/include/helpers/cpu/ConstantShapeHelper.cpp index bdb77ccaa..531b68004 100644 --- a/libnd4j/include/helpers/cpu/ConstantShapeHelper.cpp +++ b/libnd4j/include/helpers/cpu/ConstantShapeHelper.cpp @@ -41,18 +41,18 @@ namespace nd4j { return _INSTANCE; } - ConstantDataBuffer& ConstantShapeHelper::bufferForShapeInfo(nd4j::DataType dataType, char order, const std::vector &shape) { + ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo(nd4j::DataType dataType, char order, const std::vector &shape) { ShapeDescriptor descriptor(dataType, order, shape); return bufferForShapeInfo(descriptor); } - ConstantDataBuffer& ConstantShapeHelper::bufferForShapeInfo(const nd4j::DataType dataType, const char order, const int rank, const Nd4jLong* shape) { + ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo(const nd4j::DataType dataType, const char order, const int rank, const Nd4jLong* shape) { ShapeDescriptor descriptor(dataType, order, shape, rank); return bufferForShapeInfo(descriptor); } - ConstantDataBuffer& ConstantShapeHelper::bufferForShapeInfo(const ShapeDescriptor &descriptor) { + ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo(const ShapeDescriptor &descriptor) { int deviceId = 0; _mutex.lock(); @@ -62,19 +62,19 @@ namespace nd4j { ConstantDataBuffer buffer(hPtr, nullptr, shape::shapeInfoLength(hPtr)*sizeof(Nd4jLong), DataType::INT64); ShapeDescriptor descriptor1(descriptor); _cache[deviceId][descriptor1] = buffer; - ConstantDataBuffer &r = _cache[deviceId][descriptor1]; + auto r = _cache[deviceId][descriptor1]; _mutex.unlock(); return r; } else { - ConstantDataBuffer &r = _cache[deviceId].at(descriptor); + auto r = _cache[deviceId].at(descriptor); _mutex.unlock(); return r; } } - ConstantDataBuffer& ConstantShapeHelper::bufferForShapeInfo(const Nd4jLong *shapeInfo) { + ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo(const Nd4jLong *shapeInfo) { ShapeDescriptor descriptor(shapeInfo); return bufferForShapeInfo(descriptor); } diff --git a/libnd4j/include/helpers/cpu/ConstantTadHelper.cpp b/libnd4j/include/helpers/cpu/ConstantTadHelper.cpp index 5100ca3ff..822b5ad0d 100644 --- a/libnd4j/include/helpers/cpu/ConstantTadHelper.cpp +++ b/libnd4j/include/helpers/cpu/ConstantTadHelper.cpp @@ -38,25 +38,25 @@ namespace nd4j { return _INSTANCE; } - TadPack& ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int dimension, const bool keepUnitiesInShape) { + TadPack ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int dimension, const bool keepUnitiesInShape) { return tadForDimensions(originalShape, &dimension, 1, keepUnitiesInShape); } - TadPack& ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, const std::vector &dimensions, const bool keepUnitiesInShape) { + TadPack ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, const std::vector &dimensions, const 
bool keepUnitiesInShape) { return tadForDimensions(originalShape, const_cast(dimensions.data()), dimensions.size(), keepUnitiesInShape); } - TadPack& ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int* dimensions, int dimLength, const bool keepUnitiesInShape) { + TadPack ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int* dimensions, int dimLength, const bool keepUnitiesInShape) { TadDescriptor tadDescriptor(originalShape, dimensions, dimLength, keepUnitiesInShape); return tadForDimensions(tadDescriptor); } - TadPack& ConstantTadHelper::tadForDimensions(ShapeDescriptor &descriptor, std::vector &dimensions, const bool keepUnitiesInShape) { + TadPack ConstantTadHelper::tadForDimensions(ShapeDescriptor &descriptor, std::vector &dimensions, const bool keepUnitiesInShape) { TadDescriptor tadDescriptor(descriptor, dimensions, keepUnitiesInShape); return tadForDimensions(tadDescriptor); } - TadPack& ConstantTadHelper::tadForDimensions(TadDescriptor &descriptor) { + TadPack ConstantTadHelper::tadForDimensions(TadDescriptor &descriptor) { const int deviceId = 0; _mutex.lock(); @@ -105,7 +105,7 @@ namespace nd4j { return r; } else { - TadPack &r = _cache[deviceId][descriptor]; + TadPack r = _cache[deviceId][descriptor]; _mutex.unlock(); return r; diff --git a/libnd4j/include/helpers/cublasHelper.h b/libnd4j/include/helpers/cublasHelper.h index d4f92881e..94cd2446b 100644 --- a/libnd4j/include/helpers/cublasHelper.h +++ b/libnd4j/include/helpers/cublasHelper.h @@ -24,11 +24,13 @@ #include #include #include +#include namespace nd4j { class CublasHelper { private: static CublasHelper *_INSTANCE; + static std::mutex _mutex; std::vector _cache; std::vector _solvers; diff --git a/libnd4j/include/helpers/cuda/ConstantHelper.cu b/libnd4j/include/helpers/cuda/ConstantHelper.cu index 0c7f2cbc1..0d7bdf64c 100644 --- a/libnd4j/include/helpers/cuda/ConstantHelper.cu +++ b/libnd4j/include/helpers/cuda/ConstantHelper.cu @@ -68,7 +68,7 @@ namespace nd4j { throw cuda_exception::build("cudaSetDevice failed", res); auto constant = getConstantSpace(); - std::map devCache; + std::map devCache; _devicePointers[e] = constant; _deviceOffsets[e] = 0; @@ -136,15 +136,24 @@ namespace nd4j { ConstantDataBuffer* ConstantHelper::constantBuffer(const ConstantDescriptor &descriptor, nd4j::DataType dataType) { const auto deviceId = getCurrentDevice(); - if (_cache[deviceId].count(descriptor) == 0) { - ConstantHolder holder; - _cache[deviceId][descriptor] = holder; - } + // all cache modifications are synchronous + _mutexHolder.lock(); - ConstantHolder* holder = &_cache[deviceId][descriptor]; + if (_cache[deviceId].count(descriptor) == 0) { + _cache[deviceId][descriptor] = new ConstantHolder(); + } + auto holder = _cache[deviceId][descriptor]; + + // release cache lock + _mutexHolder.unlock(); + + ConstantDataBuffer* result; + + // access to this holder instance is synchronous + holder->mutex()->lock(); if (holder->hasBuffer(dataType)) { - return holder->getConstantDataBuffer(dataType); + result = holder->getConstantDataBuffer(dataType); } else { auto numBytes = descriptor.length() * DataTypeUtils::sizeOf(dataType); auto cbuff = new int8_t[numBytes]; @@ -160,10 +169,14 @@ namespace nd4j { auto dbuff = replicatePointer(cbuff, descriptor.length() * DataTypeUtils::sizeOf(dataType)); ConstantDataBuffer dataBuffer(cbuff, dbuff, descriptor.length(), DataTypeUtils::sizeOf(dataType)); - holder->addBuffer(dataBuffer, dataType); - return holder->getConstantDataBuffer(dataType); + 
holder->addBuffer(dataBuffer, dataType); + result = holder->getConstantDataBuffer(dataType); } + // release holder lock + holder->mutex()->unlock(); + + return result; } Nd4jLong ConstantHelper::getCachedAmount(int deviceId) { diff --git a/libnd4j/include/helpers/cuda/ConstantShapeHelper.cu b/libnd4j/include/helpers/cuda/ConstantShapeHelper.cu index a1217e0e3..4004b9895 100644 --- a/libnd4j/include/helpers/cuda/ConstantShapeHelper.cu +++ b/libnd4j/include/helpers/cuda/ConstantShapeHelper.cu @@ -44,17 +44,17 @@ namespace nd4j { return _INSTANCE; } - ConstantDataBuffer& ConstantShapeHelper::bufferForShapeInfo(nd4j::DataType dataType, char order, const std::vector &shape) { + ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo(nd4j::DataType dataType, char order, const std::vector &shape) { ShapeDescriptor descriptor(dataType, order, shape); return bufferForShapeInfo(descriptor); } - ConstantDataBuffer& ConstantShapeHelper::bufferForShapeInfo(const nd4j::DataType dataType, const char order, const int rank, const Nd4jLong* shape) { + ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo(const nd4j::DataType dataType, const char order, const int rank, const Nd4jLong* shape) { ShapeDescriptor descriptor(dataType, order, shape, rank); return bufferForShapeInfo(descriptor); } - ConstantDataBuffer& ConstantShapeHelper::bufferForShapeInfo(const ShapeDescriptor &descriptor) { + ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo(const ShapeDescriptor &descriptor) { int deviceId = AffinityManager::currentDeviceId(); _mutex.lock(); @@ -65,19 +65,19 @@ namespace nd4j { ConstantDataBuffer buffer(hPtr, dPtr, shape::shapeInfoLength(hPtr) * sizeof(Nd4jLong), DataType::INT64); ShapeDescriptor descriptor1(descriptor); _cache[deviceId][descriptor1] = buffer; - ConstantDataBuffer &r = _cache[deviceId][descriptor1]; + auto r = _cache[deviceId][descriptor1]; _mutex.unlock(); return r; } else { - ConstantDataBuffer &r = _cache[deviceId].at(descriptor); + ConstantDataBuffer r = _cache[deviceId].at(descriptor); _mutex.unlock(); return r; } } - ConstantDataBuffer& ConstantShapeHelper::bufferForShapeInfo(const Nd4jLong *shapeInfo) { + ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo(const Nd4jLong *shapeInfo) { ShapeDescriptor descriptor(shapeInfo); return bufferForShapeInfo(descriptor); } diff --git a/libnd4j/include/helpers/cuda/ConstantTadHelper.cu b/libnd4j/include/helpers/cuda/ConstantTadHelper.cu index da66975c3..8ea4067f3 100644 --- a/libnd4j/include/helpers/cuda/ConstantTadHelper.cu +++ b/libnd4j/include/helpers/cuda/ConstantTadHelper.cu @@ -43,25 +43,25 @@ namespace nd4j { return _INSTANCE; } - TadPack& ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int dimension, const bool keepUnitiesInShape) { + TadPack ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int dimension, const bool keepUnitiesInShape) { return tadForDimensions(originalShape, &dimension, 1, keepUnitiesInShape); } - TadPack& ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, const std::vector &dimensions, const bool keepUnitiesInShape) { + TadPack ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, const std::vector &dimensions, const bool keepUnitiesInShape) { return tadForDimensions(originalShape, const_cast(dimensions.data()), dimensions.size(), keepUnitiesInShape); } - TadPack& ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int* dimensions, int dimLength, const bool keepUnitiesInShape) { + TadPack 
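constantBuffer above uses two locks: a short-lived one around the cache map and a per-ConstantHolder mutex around the expensive buffer construction. A condensed sketch of that scheme, with Holder and Buffer as hypothetical stand-ins for the libnd4j types:

    #include <map>
    #include <mutex>

    struct Buffer { };
    struct Holder {
        std::mutex mtx;
        bool built = false;
        Buffer buf;
    };

    static std::mutex cacheMutex;          // guards the map only
    static std::map<int, Holder*> cache;   // descriptor -> holder

    Buffer* getOrBuild(int key) {
        cacheMutex.lock();                 // short critical section: find-or-insert the holder
        if (cache.count(key) == 0)
            cache[key] = new Holder();
        Holder* holder = cache[key];
        cacheMutex.unlock();

        holder->mtx.lock();                // per-entry lock: only callers asking for the
        if (!holder->built) {              // same descriptor wait on the slow build
            holder->buf = Buffer{};        // stand-in for allocation plus host/device copy
            holder->built = true;
        }
        Buffer* result = &holder->buf;
        holder->mtx.unlock();
        return result;
    }

Splitting the locks keeps unrelated descriptors from serialising behind one global mutex while a buffer is being materialised.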
ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int* dimensions, int dimLength, const bool keepUnitiesInShape) { TadDescriptor tadDescriptor(originalShape, dimensions, dimLength, keepUnitiesInShape); return tadForDimensions(tadDescriptor); } - TadPack& ConstantTadHelper::tadForDimensions(ShapeDescriptor &descriptor, std::vector &dimensions, const bool keepUnitiesInShape) { + TadPack ConstantTadHelper::tadForDimensions(ShapeDescriptor &descriptor, std::vector &dimensions, const bool keepUnitiesInShape) { TadDescriptor tadDescriptor(descriptor, dimensions, keepUnitiesInShape); return tadForDimensions(tadDescriptor); } - TadPack& ConstantTadHelper::tadForDimensions(TadDescriptor &descriptor) { + TadPack ConstantTadHelper::tadForDimensions(TadDescriptor &descriptor) { const int deviceId = AffinityManager::currentDeviceId(); _mutex.lock(); @@ -96,14 +96,14 @@ namespace nd4j { TadPack t(shapesBuffer, offsetsBuffer, numOfSubArrs); _cache[deviceId][descriptor] = t; - TadPack &r = _cache[deviceId][descriptor]; + TadPack r = _cache[deviceId][descriptor]; _mutex.unlock(); delete[] shapeInfo; return r; } else { - TadPack &r = _cache[deviceId][descriptor]; + TadPack r = _cache[deviceId][descriptor]; _mutex.unlock(); return r; diff --git a/libnd4j/include/helpers/cuda_off/cublasHelper.cu b/libnd4j/include/helpers/cuda_off/cublasHelper.cu index 6f2cf2084..d9784eaa2 100644 --- a/libnd4j/include/helpers/cuda_off/cublasHelper.cu +++ b/libnd4j/include/helpers/cuda_off/cublasHelper.cu @@ -27,6 +27,7 @@ #include namespace nd4j { + std::mutex CublasHelper::_mutex; static void* handle_() { auto _handle = new cublasHandle_t(); @@ -56,22 +57,24 @@ namespace nd4j { } CublasHelper::CublasHelper() { + //nd4j_printf("Initializing cuBLAS\n",""); auto numDevices = AffinityManager::numberOfDevices(); auto currentDevice = AffinityManager::currentDeviceId(); _cache.resize(numDevices); _solvers.resize(numDevices); for (int e = 0; e < numDevices; e++) { - AffinityManager::setCurrentDevice(e); + AffinityManager::setCurrentNativeDevice(e); _cache[e] = handle_(); _solvers[e] = solver_(); } // don't forget to restore back original device - AffinityManager::setCurrentDevice(currentDevice); + AffinityManager::setCurrentNativeDevice(currentDevice); } CublasHelper::~CublasHelper() { + nd4j_printf("Releasing cuBLAS\n",""); auto numDevices = AffinityManager::numberOfDevices(); for (int e = 0; e < numDevices; e++) @@ -79,8 +82,10 @@ namespace nd4j { } CublasHelper* CublasHelper::getInstance() { + _mutex.lock(); if (!_INSTANCE) _INSTANCE = new nd4j::CublasHelper(); + _mutex.unlock(); return _INSTANCE; } diff --git a/libnd4j/include/loops/cpu/random.cpp b/libnd4j/include/loops/cpu/random.cpp index 889e48181..30bab1327 100644 --- a/libnd4j/include/loops/cpu/random.cpp +++ b/libnd4j/include/loops/cpu/random.cpp @@ -162,9 +162,6 @@ namespace functions { } } } - - // update rng state - rng->rewindH(length); }; @@ -223,8 +220,6 @@ namespace functions { } } } - // update rng state - rng->rewindH(length); } @@ -256,9 +251,6 @@ namespace functions { z[offset] = OpClass::op(i+threadOffset, length, rng, extraArguments); } } - - // update rng state - rng->rewindH(length); } template diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/matrixSetDiag.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/matrixSetDiag.cpp index f63469817..3a52057a5 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/matrixSetDiag.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/matrixSetDiag.cpp @@ -15,7 +15,7 @@ 
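getInstance in the cuBLAS helper above is made thread safe by wrapping the check-and-create in a static mutex. The same pattern in isolation, as a sketch (the class name is illustrative):

    #include <mutex>

    class LazySingleton {
        static LazySingleton* _INSTANCE;
        static std::mutex _mutex;
        LazySingleton() = default;
    public:
        static LazySingleton* getInstance() {
            _mutex.lock();                  // two threads racing into the first call
            if (!_INSTANCE)                 // can no longer both construct an instance
                _INSTANCE = new LazySingleton();
            _mutex.unlock();
            return _INSTANCE;
        }
    };

    LazySingleton* LazySingleton::_INSTANCE = nullptr;
    std::mutex LazySingleton::_mutex;

A function-local static or std::call_once would achieve the same effect; the explicit mutex matches the existing _INSTANCE field layout with the least churn.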
******************************************************************************/ // -// @author Yurii Shyrma (iuriish@yahoo.com), created on 07.12.2017 +// @author Yurii Shyrma (iuriish@yahoo.com) // #include @@ -38,10 +38,9 @@ CONFIGURABLE_OP_IMPL(matrix_set_diag, 2, 1, false, 0, 0) { for(int i = 0; i < diagonal->rankOf() - 1; ++i) REQUIRE_TRUE(diagonal->sizeAt(i) == input->sizeAt(i), 0, "MATRIX_SET_DIAG op: the shapes of diagonal and input arrays must be equal till last diagonal dimension but one, however got diagonal=%s and input=%s instead !", ShapeUtils::shapeAsString(diagonal).c_str(), ShapeUtils::shapeAsString(input).c_str()); - REQUIRE_TRUE(diagonal->sizeAt(-1) == (int)nd4j::math::nd4j_min(input->sizeAt(-1), input->sizeAt(-2)), - 0, "MATRIX_SET_DIAG op: the value of last dimension of diagonal array must be equal to min(input_last_shape=%i, input_last_but_one_shape=%i), but got %i instead !", input->sizeAt(-1), input->sizeAt(-2), diagonal->sizeAt(-1)); + REQUIRE_TRUE(diagonal->sizeAt(-1) == (int)nd4j::math::nd4j_min(input->sizeAt(-1), input->sizeAt(-2)), 0, "MATRIX_SET_DIAG op: the value of last dimension of diagonal array must be equal to min(input_last_shape=%i, input_last_but_one_shape=%i), but got %i instead !", input->sizeAt(-1), input->sizeAt(-2), diagonal->sizeAt(-1)); - helpers::matrixSetDiag(block.launchContext(), input, diagonal, output); + helpers::matrixSetDiag(block.launchContext(), *input, *diagonal, *output, false); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/matrix_diag.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/matrix_diag.cpp index 8fa5bfa41..c430fd4d2 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/matrix_diag.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/matrix_diag.cpp @@ -15,49 +15,53 @@ ******************************************************************************/ // -// Created to use with batched tensor by GS 3/21/2018 +// @author GS 3/21/2018 +// @author Yurii Shyrma (iuriish@yahoo.com) // #include -#include - +#include namespace nd4j { - namespace ops { - CUSTOM_OP_IMPL(matrix_diag, 1, 1, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); +namespace ops { - REQUIRE_TRUE(!input->isScalar(), 0, "CUSTOM_OP matrix_diag: input array must be at list a vector, but scalar was given!"); +CUSTOM_OP_IMPL(matrix_diag, 1, 1, false, 0, 0) { - output->nullify(); - return helpers::matrixDiag(block.launchContext(), input, output); - } + auto diagonal = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - DECLARE_SHAPE_FN(matrix_diag) { - Nd4jLong* outShapeInfo = nullptr; - auto in = inputShape->at(0); - int inRank = shape::rank(in); + REQUIRE_TRUE(!diagonal->isScalar(), 0, "CUSTOM_OP matrix_diag: input diagonal array must be at list a vector, but scalar was given!"); - int outRank = inRank + 1; - auto lastDimension = shape::sizeAt(in, -1); + helpers::matrixSetDiag(block.launchContext(), *output, *diagonal, *output, true); - ALLOCATE(outShapeInfo, block.getWorkspace(), shape::shapeInfoLength(outRank), Nd4jLong); - outShapeInfo[0] = outRank; - for(int i = 0; i < inRank; ++i) - outShapeInfo[i + 1] = shape::sizeAt(in, i); - outShapeInfo[outRank] = lastDimension; + return Status::OK(); +} - ShapeUtils::updateStridesAndType(outShapeInfo, in, shape::order(in)); +DECLARE_SHAPE_FN(matrix_diag) { - return SHAPELIST(CONSTANT(outShapeInfo)); - } + Nd4jLong* outShapeInfo = nullptr; + auto in = inputShape->at(0); + int inRank = shape::rank(in); - 
DECLARE_TYPES(matrix_diag) { - getOpDescriptor() - ->setAllowedInputTypes(nd4j::DataType::ANY) - ->setSameMode(true); - } + int outRank = inRank + 1; + auto lastDimension = shape::sizeAt(in, -1); + + ALLOCATE(outShapeInfo, block.getWorkspace(), shape::shapeInfoLength(outRank), Nd4jLong); + outShapeInfo[0] = outRank; + for(int i = 0; i < inRank; ++i) + outShapeInfo[i + 1] = shape::sizeAt(in, i); + outShapeInfo[outRank] = lastDimension; + + ShapeUtils::updateStridesAndType(outShapeInfo, in, shape::order(in)); + + return SHAPELIST(CONSTANT(outShapeInfo)); +} + +DECLARE_TYPES(matrix_diag) { + getOpDescriptor() + ->setAllowedInputTypes(nd4j::DataType::ANY) + ->setSameMode(true); +} } } diff --git a/libnd4j/include/ops/declarable/headers/parity_ops.h b/libnd4j/include/ops/declarable/headers/parity_ops.h index f9278fb36..c86f28499 100644 --- a/libnd4j/include/ops/declarable/headers/parity_ops.h +++ b/libnd4j/include/ops/declarable/headers/parity_ops.h @@ -76,8 +76,20 @@ namespace nd4j { #endif /** - * Returns a batched matrix tensor with new batched diagonal values. - */ + * Inserts elements provided by diagonal array into the main diagonal of innermost matrices of input array + * + * Input arrays: + * input: input array, considered as batch of matrices + * diagonal: array containing elements to be inserted into input array, + * following rank condition should be satisfied: diagonal_rank = input_rank - 1, + * the shapes of diagonal and input arrays must be equal except last dimension of input array, + * for example if input_shape = [A,B,C,D] then diagonal_shape = [A,B,C], + * also last dimension of diagonal array should be equal to smaller of last and last but one input dimensions + * that is: diagonal_shape[-1] = min(input_shape[-1], input_shape[-2]) + * + * Output array: + * has the same shape as input, corresponding diagonal elements are substituted + */ #if NOT_EXCLUDED(OP_matrix_set_diag) DECLARE_CONFIGURABLE_OP(matrix_set_diag, 2, 1, false, 0, 0); #endif diff --git a/libnd4j/include/ops/declarable/helpers/cpu/convolutions.cpp b/libnd4j/include/ops/declarable/helpers/cpu/convolutions.cpp index dd5516461..3d04bc129 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/convolutions.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/convolutions.cpp @@ -2411,7 +2411,7 @@ void ConvolutionUtils::getMKLDNNMemoryDescConv3d( for (Nd4jLong kd = dstart; kd < dend; kd += iStep2) for (Nd4jLong kh = hstart; kh < hend; kh += iStep3) for (Nd4jLong kw = wstart; kw < wend; kw += iStep4) - pgI[kd + kh + kw] += valO * nd4j::math::nd4j_pow(nd4j::math::nd4j_abs(pIn[kd + kh + kw]), extraParam0 - (T)1.f); + pgI[kd + kh + kw] += valO * nd4j::math::nd4j_pow(nd4j::math::nd4j_abs(pIn[kd + kh + kw]), extraParam0 - (T)1.f) * nd4j::math::nd4j_sgn(pIn[kd + kh + kw]); } else { diff --git a/libnd4j/include/ops/declarable/helpers/cpu/matrixSetDiag.cpp b/libnd4j/include/ops/declarable/helpers/cpu/matrixSetDiag.cpp index 7180a88b3..e974755ac 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/matrixSetDiag.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/matrixSetDiag.cpp @@ -15,7 +15,7 @@ ******************************************************************************/ // -// Created by Yurii Shyrma on 07.12.2017. +// @author Yurii Shyrma (iuriish@yahoo.com) // #include "ResultSet.h" @@ -27,31 +27,48 @@ namespace helpers { ////////////////////////////////////////////////////////////////////////// -// Returns a batched matrix tensor with new batched diagonal values. 
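Concretely, the shape contract documented above works out as follows (shapes are illustrative only):

    // input    : [2, 3, 4, 5]   -> a 2x3 batch of 4x5 innermost matrices
    // diagonal : [2, 3, 4]      -> last dimension = min(4, 5) = 4
    // output   : [2, 3, 4, 5]   -> same shape as input; element (..., i, i) is
    //                              taken from diagonal, everything else from input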
-// for detailed explanations please take a look on web page: https://www.tensorflow.org/api_docs/python/tf/matrix_set_diag -template -static void _matrixSetDiag(const NDArray* input, const NDArray* diagonal, NDArray* output) { +template +void matrixSetDiag_(const NDArray& input, const NDArray& diagonal, NDArray& output, const bool zeroPad) { - *output = *input; + // input and output are the same array (x == z) when zeroPad = true + // xRank = zRank, xRank = yRank + 1 + // xLen = zLen - const int lastDimSize = input->sizeAt(-1); - const int last2DimSize = input->sizeAt(-1) * input->sizeAt(-2); - const int lastSmallDim = diagonal->sizeAt(-1); - const int batchSize = input->lengthOf()/last2DimSize; + const T* x = input.bufferAsT(); + const T* y = diagonal.bufferAsT(); + T* z = output.bufferAsT(); - for(int i = 0; i < batchSize; ++i ) - for(int j = 0; j < lastSmallDim; ++j) { - output->p(i*last2DimSize + j*(lastDimSize + 1), diagonal->e(i*lastSmallDim + j)); - } - + const Nd4jLong* xShapeInfo = input.getShapeInfo(); + const Nd4jLong* yShapeInfo = diagonal.getShapeInfo(); + const Nd4jLong* zShapeInfo = output.getShapeInfo(); + const bool areSameOffsets = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); // shapes are definitely the same, but strides might not + + const int xRank = input.rankOf(); + const auto xLen = input.lengthOf(); + + std::vector coords(xRank); // we use the same coordinates storage both for input and output since their ranks are the same + + PRAGMA_OMP_PARALLEL_FOR_ARGS(firstprivate(coords)) + for (Nd4jLong i = 0; i < xLen; ++i) { + + shape::index2coords(xRank, xShapeInfo + 1, i, xLen, coords.data()); + + const auto xOffset = shape::getOffset(0, xShapeInfo + 1, xShapeInfo + xRank + 1, coords.data(), xRank); + const auto zOffset = areSameOffsets ? xOffset : shape::getOffset(0, zShapeInfo + 1, zShapeInfo + xRank + 1, coords.data(), xRank); + + // condition to be on diagonal of innermost matrix + if(coords[xRank - 2] == coords[xRank - 1]) + z[zOffset] = y[shape::getOffset(0, yShapeInfo + 1, yShapeInfo + xRank, coords.data(), xRank - 1)]; + else + z[zOffset] = zeroPad ? static_cast(0) : x[xOffset]; + } } - void matrixSetDiag(nd4j::LaunchContext * context, const NDArray* input, const NDArray* diagonal, NDArray* output) { - BUILD_SINGLE_SELECTOR(input->dataType(), _matrixSetDiag, (input, diagonal, output), LIBND4J_TYPES); - } - - BUILD_SINGLE_TEMPLATE(template void _matrixSetDiag, (const NDArray* input, const NDArray* diagonal, NDArray* output), LIBND4J_TYPES); +////////////////////////////////////////////////////////////////////////// +void matrixSetDiag(nd4j::LaunchContext* context, const NDArray& input, const NDArray& diagonal, NDArray& output, const bool zeroPad) { + BUILD_SINGLE_SELECTOR(input.dataType(), matrixSetDiag_, (input, diagonal, output, zeroPad), LIBND4J_TYPES); +} } } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/matrix_diag.cpp b/libnd4j/include/ops/declarable/helpers/cpu/matrix_diag.cpp deleted file mode 100644 index 3f9883b54..000000000 --- a/libnd4j/include/ops/declarable/helpers/cpu/matrix_diag.cpp +++ /dev/null @@ -1,65 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. 
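As a small worked example of the diagonal test in the loop above, take a single 3x3 matrix in 'c' order, so xRank == 2 and the diagonal index is simply the row (numbers are illustrative):

    // flat i -> coords (row, col) -> action
    // i = 0  -> (0, 0) -> diagonal : z = y[0]
    // i = 1  -> (0, 1) -> off      : z = x (or 0 when zeroPad == true)
    // i = 4  -> (1, 1) -> diagonal : z = y[1]
    // i = 8  -> (2, 2) -> diagonal : z = y[2]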
- * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// Created by GS on 3/21/2018. -// - -#include "ResultSet.h" -#include -#include - -namespace nd4j { -namespace ops { -namespace helpers { - - -////////////////////////////////////////////////////////////////////////// -// Returns a batched matrix tensor with new batched diagonal values. -// for detailed explanations please take a look on web page: https://www.tensorflow.org/api_docs/python/tf/matrix_set_diag -template -static int _matrixDiag(const NDArray* input, NDArray* output) { - - auto listOut = output->allTensorsAlongDimension({output->rankOf() - 2, output->rankOf() - 1}); - auto listDiag = input->allTensorsAlongDimension({input->rankOf() - 1}); - - if (listOut->size() != listDiag->size()) { - nd4j_printf("matrix_diag: Input matrix has wrong shape.", ""); - return ND4J_STATUS_VALIDATION; - } - int lastDimension = input->sizeAt(-1); - // TODO: tune this properlys - int lO = listOut->size(); - PRAGMA_OMP_PARALLEL_FOR_IF(lO > Environment::getInstance()->tadThreshold()) - for(int i = 0; i < lO; ++i) - for (int e = 0; e < lastDimension; e++) - listOut->at(i)->p(e, e, listDiag->at(i)->e(e)); - - delete listOut; - delete listDiag; - - return Status::OK(); -} - - int matrixDiag(nd4j::LaunchContext * context, const NDArray* input, NDArray* output) { - BUILD_SINGLE_SELECTOR(input->dataType(), return _matrixDiag, (input, output), LIBND4J_TYPES); - } - - BUILD_SINGLE_TEMPLATE(template int _matrixDiag, (const NDArray* input, NDArray* output), LIBND4J_TYPES); - -} -} -} \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/convolutions.cu b/libnd4j/include/ops/declarable/helpers/cuda/convolutions.cu index 87e7c4f08..c08551318 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/convolutions.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/convolutions.cu @@ -957,9 +957,13 @@ __global__ static void pooling2dBPCuda(const void* vx, const Nd4jLong* xShapeInf val *= nd4j::math::nd4j_pow(sum, ((T)1.f - extraParam0) / extraParam0); - for (coords[2] = hstart; coords[2] < hend; coords[2] += dH) - for (coords[3] = wstart; coords[3] < wend; coords[3] += dW) - nd4j::math::atomics::nd4j_atomicAdd(&z[shape::getOffset(0, zShapeInfo + 1, zShapeInfo + rank + 1, coords, rank)], val * nd4j::math::nd4j_pow(nd4j::math::nd4j_abs(x[shape::getOffset(0, xShapeInfo + 1, xShapeInfo + rank + 1, coords, rank)]), extraParam0 - 1.f)); + for (coords[2] = hstart; coords[2] < hend; coords[2] += dH) { + for (coords[3] = wstart; coords[3] < wend; coords[3] += dW) { + const auto xOffset = shape::getOffset(0, xShapeInfo + 1, xShapeInfo + rank + 1, coords, rank); + const auto zOffset = shape::getOffset(0, zShapeInfo + 1, zShapeInfo + rank + 1, coords, rank); + nd4j::math::atomics::nd4j_atomicAdd(&z[zOffset], val * nd4j::math::nd4j_pow(nd4j::math::nd4j_abs(x[xOffset]), extraParam0 - 1.f) * nd4j::math::nd4j_sgn(x[xOffset])); + } + } } break; } @@ -1123,10 +1127,15 @@ __global__ static void pooling3dBPCuda(const void* vx, const Nd4jLong* xShapeInf val *= nd4j::math::nd4j_pow(sum, ((T)1.f - extraParam0) / extraParam0); - for (coords[2] = dstart; 
coords[2] < dend; coords[2] += dD) - for (coords[3] = hstart; coords[3] < hend; coords[3] += dH) - for (coords[4] = wstart; coords[4] < wend; coords[4] += dW) - nd4j::math::atomics::nd4j_atomicAdd(&z[shape::getOffset(0, zShapeInfo + 1, zShapeInfo + rank + 1, coords, rank)], val * nd4j::math::nd4j_pow(nd4j::math::nd4j_abs(x[shape::getOffset(0, xShapeInfo + 1, xShapeInfo + rank + 1, coords, rank)]), extraParam0 - 1.f)); + for (coords[2] = dstart; coords[2] < dend; coords[2] += dD) { + for (coords[3] = hstart; coords[3] < hend; coords[3] += dH) { + for (coords[4] = wstart; coords[4] < wend; coords[4] += dW) { + const auto xOffset = shape::getOffset(0, xShapeInfo + 1, xShapeInfo + rank + 1, coords, rank); + const auto zOffset = shape::getOffset(0, zShapeInfo + 1, zShapeInfo + rank + 1, coords, rank); + nd4j::math::atomics::nd4j_atomicAdd(&z[zOffset], val * nd4j::math::nd4j_pow(nd4j::math::nd4j_abs(x[xOffset]), extraParam0 - 1.f) * nd4j::math::nd4j_sgn(x[xOffset])); + } + } + } } break; } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/matrixSetDiag.cu b/libnd4j/include/ops/declarable/helpers/cuda/matrixSetDiag.cu index 95eb5f439..01baaffb4 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/matrixSetDiag.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/matrixSetDiag.cu @@ -15,63 +15,87 @@ ******************************************************************************/ // -// Created by Yurii Shyrma on 07.12.2017. +// @author Yurii Shyrma (iuriish@yahoo.com) // #include "ResultSet.h" #include +#include -namespace nd4j { -namespace ops { +namespace nd4j { +namespace ops { namespace helpers { +/////////////////////////////////////////////////////////////////// +template +__global__ static void matrixSetDiagCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vy, const Nd4jLong* yShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const bool zeroPad) { - template - static __global__ void matrixSetDiagKernel(void* outputBuffer, Nd4jLong* outputShape, void const* diagonalBuffer, Nd4jLong* diagonalShape, Nd4jLong lastDimSize, Nd4jLong last2DimSize, Nd4jLong lastSmallDim, Nd4jLong batchSize) { - __shared__ T* z; - __shared__ T const* x; - __shared__ Nd4jLong outLength, diagonalLen; - if (threadIdx.x == 0) { - z = reinterpret_cast(outputBuffer); - x = reinterpret_cast(diagonalBuffer); - outLength = shape::length(outputShape); - diagonalLen = shape::length(diagonalShape); - } - __syncthreads(); + // x - input, shape [A,B,C] + // y - diagonal, shape [A,B] + // z - output, shape [A,B,C] + // input and output are the same array (x == z) when zeroPad = true - for(int i = blockIdx.x; i < batchSize; i+= gridDim.x ) - for(int j = threadIdx.x; j < lastSmallDim; j += blockDim.x) { -// z[i * last2DimSize + j * (lastDimSize + 1)] = x[i * lastSmallDim + j]; - z[shape::getIndexOffset(i * last2DimSize + j * (lastDimSize + 1), outputShape, outLength)] = x[shape::getIndexOffset(i * lastSmallDim + j, diagonalShape, diagonalLen)]; - } - } - ////////////////////////////////////////////////////////////////////////// - // Returns a batched matrix tensor with new batched diagonal values. 
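The extra nd4j_sgn(x) factor added to the 2d and 3d kernels above (and to the CPU path) follows from differentiating p-norm pooling; writing p for extraParam0 and keeping the code's sum = sum_i |x_i|^p:

    \frac{\partial}{\partial x_k}\Big(\sum_i |x_i|^p\Big)^{1/p}
      \;=\; \Big(\sum_i |x_i|^p\Big)^{\frac{1-p}{p}} \; |x_k|^{\,p-1} \, \operatorname{sgn}(x_k)

The sum^{(1-p)/p} factor is already folded into val before the loops, so each input's contribution only needs |x_k|^{p-1} * sgn(x_k); the previous expression dropped the sign, which is wrong whenever an input inside the pooling window is negative.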
- // for detailed explanations please take a look on web page: https://www.tensorflow.org/api_docs/python/tf/matrix_set_diag - template - static void _matrixSetDiag(nd4j::LaunchContext * context, const NDArray* input, const NDArray* diagonal, NDArray* output) { - *output = *input; + const auto x = reinterpret_cast(vx); + const auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); - const int lastDimSize = input->sizeAt(-1); - const int last2DimSize = input->sizeAt(-1) * input->sizeAt(-2); - const int lastSmallDim = diagonal->sizeAt(-1); - const int batchSize = input->lengthOf()/last2DimSize; - auto stream = context->getCudaStream(); - dim3 launchDims(256, 512, 8192); - matrixSetDiagKernel<<>>(output->specialBuffer(), output->specialShapeInfo(), diagonal->getSpecialBuffer(), diagonal->getSpecialShapeInfo(), lastDimSize, last2DimSize, lastSmallDim, batchSize); -//// #pragma omp parallel for if(batchSize > Environment::getInstance()->elementwiseThreshold()) schedule(static) -// for(int i = 0; i < batchSize; ++i ) -// for(int j = 0; j < lastSmallDim; ++j) { -// output->p(i*last2DimSize + j*(lastDimSize + 1), diagonal->e(i*lastSmallDim + j)); -// } + __shared__ int xRank; // xRank = zRank, xRank = yRank + 1 + __shared__ Nd4jLong xLen, *sharedMem; // xLen = zLen + __shared__ bool areSameOffsets; + if (threadIdx.x == 0) { + + extern __shared__ unsigned char shmem[]; + sharedMem = reinterpret_cast(shmem); + + areSameOffsets = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); // shapes are definitely the same, but strides might not + + xRank = shape::rank(xShapeInfo); + xLen = shape::length(xShapeInfo); } - void matrixSetDiag(nd4j::LaunchContext * context, const NDArray* input, const NDArray* diagonal, NDArray* output) { - BUILD_SINGLE_SELECTOR(input->dataType(), _matrixSetDiag, (context, input, diagonal, output), LIBND4J_TYPES); - } + __syncthreads(); - BUILD_SINGLE_TEMPLATE(template void _matrixSetDiag, (nd4j::LaunchContext * context, const NDArray* input, const NDArray* diagonal, NDArray* output), LIBND4J_TYPES); + auto coords = sharedMem + threadIdx.x * xRank; // we provide (xRank * sizeof(Nd4jLong) * threadIdx.x) amount of shared memory per each thread + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + + for (Nd4jLong i = tid; i < xLen; i += gridDim.x * blockDim.x) { + + shape::index2coords(xRank, xShapeInfo + 1, i, xLen, coords); + + const auto xOffset = shape::getOffset(0, xShapeInfo + 1, xShapeInfo + xRank + 1, coords, xRank); + const auto zOffset = areSameOffsets ? xOffset : shape::getOffset(0, zShapeInfo + 1, zShapeInfo + xRank + 1, coords, xRank); + + // condition to be on diagonal of innermost matrix + if(coords[xRank - 2] == coords[xRank - 1]) + z[zOffset] = y[shape::getOffset(0, yShapeInfo + 1, yShapeInfo + xRank, coords, xRank - 1)]; + else + z[zOffset] = zeroPad ? 
static_cast(0) : x[xOffset]; + } +} + +/////////////////////////////////////////////////////////////////// +template +static void matrixSetDiagCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, const void* vy, const Nd4jLong* yShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const bool zeroPad) { + + matrixSetDiagCuda<<>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, zeroPad); +} + +/////////////////////////////////////////////////////////////////// +void matrixSetDiag(nd4j::LaunchContext* context, const NDArray& input, const NDArray& diagonal, NDArray& output, const bool zeroPad) { + + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = (input.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + const int sharedMem = threadsPerBlock * sizeof(Nd4jLong) * input.rankOf() + 128; + + PointersManager manager(context, "matrixSetDiag"); + + NDArray::prepareSpecialUse({&output}, {&input, &diagonal}); + BUILD_SINGLE_SELECTOR(input.dataType(), matrixSetDiagCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), diagonal.getSpecialBuffer(), diagonal.getSpecialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), zeroPad), LIBND4J_TYPES); + NDArray::registerSpecialUse({&output}, {&input, &diagonal}); + + manager.synchronize(); +} } } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/matrix_diag.cu b/libnd4j/include/ops/declarable/helpers/cuda/matrix_diag.cu deleted file mode 100644 index 78304510d..000000000 --- a/libnd4j/include/ops/declarable/helpers/cuda/matrix_diag.cu +++ /dev/null @@ -1,95 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// Created by GS on 3/21/2018. 
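The launch configuration above sizes dynamic shared memory so that every thread owns one Nd4jLong coordinate array; spelled out with illustrative numbers (plain host-side arithmetic, not the actual launcher):

    // threadsPerBlock = MAX_NUM_THREADS / 2                      e.g. 512
    // rank            = input.rankOf()                           e.g. 3
    // sharedMem       = threadsPerBlock * sizeof(Nd4jLong) * rank + 128
    //                 = 512 * 8 * 3 + 128 = 12416 bytes per block
    // blocksPerGrid   = ceil(input.lengthOf() / threadsPerBlock)
    // inside the kernel each thread takes its slice as
    //   coords = sharedMemBase + threadIdx.x * rank;
    // the trailing + 128 leaves a small safety margin.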
-// - -#include "ResultSet.h" -#include -#include -#include -#include -#include -#include -#include - -namespace nd4j { -namespace ops { -namespace helpers { - - - template - static __global__ void matrixDiagKernel(void const* inputBuffer, void* outputBuffer, Nd4jLong numTads, Nd4jLong inputLength, - Nd4jLong* tadOnlyInputShapeInfo, Nd4jLong *tadInputOffsets, - Nd4jLong* tadOnlyOutputShapeInfo, Nd4jLong *tadOutputOffsets) { - int totalThreads = blockDim.x; - for (Nd4jLong i = blockIdx.x; i < numTads; i += gridDim.x) { - auto yOffset = tadInputOffsets[i]; - auto xOffset = tadOutputOffsets[i]; - for (Nd4jLong j = threadIdx.x; j < inputLength; j += totalThreads) { - Nd4jLong coords[2] = {j, j}; - Nd4jLong tadOffset = shape::getOffset(0, shape::shapeOf(tadOnlyOutputShapeInfo), shape::stride(tadOnlyOutputShapeInfo), coords, 2); - //shape::getIndexOffset(j, tadOnlyOutputShapeInfo, inputLength) - *(reinterpret_cast(outputBuffer) + xOffset + tadOffset) = *(reinterpret_cast(inputBuffer) + yOffset + shape::getIndexOffset(j, tadOnlyInputShapeInfo, inputLength)); - } - } - } - ////////////////////////////////////////////////////////////////////////// - // Returns a batched matrix tensor with new batched diagonal values. - // for detailed explanations please take a look on web page: https://www.tensorflow.org/api_docs/python/tf/matrix_set_diag - - template - static int _matrixDiag(nd4j::LaunchContext * context, const NDArray* input, NDArray* output) { - cudaStream_t* stream = context->getCudaStream(); - //auto listOut = output->allTensorsAlongDimension({output->rankOf() - 2, output->rankOf() - 1}); - //auto listDiag = input->allTensorsAlongDimension({input->rankOf() - 1}); - - //auto repeatDelta = shape::prodLong(newShape.data(), rank) / this->lengthOf(); - std::vector dimsToExclude = ShapeUtils::evalDimsToExclude(input->rankOf(), {input->rankOf() - 1}); - const Nd4jLong numTads = ShapeUtils::getNumOfSubArrs(input->getShapeInfo(), dimsToExclude); //this->tensorsAlongDimension({dimension}); - //printf("Repeat delta %lld, numTads %lld\n", repeatDelta, numTads); - //tadOnlyInputShapeInfo, tadInputOffsets, tadOnlyOutputShapeInfo, tadOutputOffsets; - std::vector inputDims({input->rankOf() - 1}); - std::vector outputDims({output->rankOf() - 2, output->rankOf() - 1}); - - auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), inputDims); - auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), outputDims); - - if (!input->isActualOnDeviceSide()) - input->syncToDevice(); - - if (!output->isActualOnDeviceSide()) - output->syncToDevice(); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - - dim3 launchDims(256, 512, 8192); - matrixDiagKernel<<>>(input->getSpecialBuffer(), output->getSpecialBuffer(), numTads, input->sizeAt(-1), packX.specialShapeInfo(), packX.specialOffsets(), packZ.specialShapeInfo(), packZ.specialOffsets()); - - return Status::OK(); - } - - int matrixDiag(nd4j::LaunchContext * context, const NDArray* input, NDArray* output) { - BUILD_SINGLE_SELECTOR(input->dataType(), return _matrixDiag, (context, input, output), LIBND4J_TYPES); - } - - BUILD_SINGLE_TEMPLATE(template int _matrixDiag, (nd4j::LaunchContext * context, const NDArray* input, NDArray* output), LIBND4J_TYPES); - -} -} -} \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/matrixSetDiag.h b/libnd4j/include/ops/declarable/helpers/matrixSetDiag.h index ea5a1a4ad..fb7d57d18 100644 --- 
a/libnd4j/include/ops/declarable/helpers/matrixSetDiag.h +++ b/libnd4j/include/ops/declarable/helpers/matrixSetDiag.h @@ -28,8 +28,7 @@ namespace nd4j { namespace ops { namespace helpers { - void matrixSetDiag(nd4j::LaunchContext * context, const NDArray* input, const NDArray* diagonal, NDArray* output); - + void matrixSetDiag(nd4j::LaunchContext* context, const NDArray& input, const NDArray& diagonal, NDArray& output, const bool zeroPad); } } diff --git a/libnd4j/include/ops/declarable/helpers/matrix_diag.h b/libnd4j/include/ops/declarable/helpers/matrix_diag.h deleted file mode 100644 index 0cbbcef16..000000000 --- a/libnd4j/include/ops/declarable/helpers/matrix_diag.h +++ /dev/null @@ -1,34 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// @author GS -// -#ifndef __MATRIX_DIAG_HELPERS__ -#define __MATRIX_DIAG_HELPERS__ -#include -#include - -namespace nd4j { -namespace ops { -namespace helpers { - - int matrixDiag(nd4j::LaunchContext * context, NDArray const* input, NDArray* output); - -} -} -} -#endif diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp index 1ec9650f9..7d166f831 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp @@ -117,9 +117,9 @@ TEST_F(DeclarableOpsTests3, Test_Unique_1) { auto v = result->at(0); auto i = result->at(1); - v->printIndexedBuffer("Values"); - i->printIndexedBuffer("Indices"); - i->printShapeInfo("Indices shape"); + // v->printIndexedBuffer("Values"); + // i->printIndexedBuffer("Indices"); + // i->printShapeInfo("Indices shape"); ASSERT_TRUE(expV.isSameShape(v)); ASSERT_TRUE(expV.equalsTo(v)); @@ -145,12 +145,12 @@ TEST_F(DeclarableOpsTests3, Test_Unique_2) { auto i = result->at(1); auto c = result->at(2); - v->printShapeInfo(); - v->printIndexedBuffer("Values"); - i->printShapeInfo(); - i->printIndexedBuffer("Indices"); - c->printShapeInfo(); - c->printIndexedBuffer("Counts"); + // v->printShapeInfo(); + // v->printIndexedBuffer("Values"); + // i->printShapeInfo(); + // i->printIndexedBuffer("Indices"); + // c->printShapeInfo(); + // c->printIndexedBuffer("Counts"); ASSERT_TRUE(expV.isSameShape(v)); ASSERT_TRUE(expV.equalsTo(v)); @@ -200,11 +200,11 @@ TEST_F(DeclarableOpsTests3, Test_Norm_1) { auto result1 = op.execute({&x}, {1.}, {1}); ASSERT_EQ(result1->status(), ND4J_STATUS_OK); auto z1 = result1->at(0); - z1->printIndexedBuffer("Z1"); + // z1->printIndexedBuffer("Z1"); auto exp1 = x.reduceAlongDims(reduce::Norm2, dims, false, false); - exp1.printIndexedBuffer("EXP1"); - z1->printShapeInfo("Z1 shape"); - exp1.printShapeInfo("EXP1 shape"); + // exp1.printIndexedBuffer("EXP1"); + // z1->printShapeInfo("Z1 shape"); + // exp1.printShapeInfo("EXP1 shape"); 
ASSERT_TRUE(exp1.isSameShape(z1)); ASSERT_TRUE(exp1.equalsTo(z1)); @@ -714,7 +714,7 @@ TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_7) { auto exp = MmulHelper::mmul(&x, &y); - exp->printShapeInfo("exp shape"); + // exp->printShapeInfo("exp shape"); nd4j::ops::batched_gemm op; auto result = op.execute({&a, &b, &x, &x, &x, &y, &y, &y}, {}, {112, 112, 2, 3, 5, 5, 3, 2, 3}); diff --git a/libnd4j/tests_cpu/layers_tests/SortCudaTests.cu b/libnd4j/tests_cpu/layers_tests/SortCudaTests.cu index 49c1f7a95..6913722be 100644 --- a/libnd4j/tests_cpu/layers_tests/SortCudaTests.cu +++ b/libnd4j/tests_cpu/layers_tests/SortCudaTests.cu @@ -79,7 +79,7 @@ TEST_F(SortCudaTests, test_linear_sort_by_val_2) { sortByValue(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), true); k.tickWriteDevice(); v.tickWriteDevice(); - k.printIndexedBuffer("KEYS"); + // k.printIndexedBuffer("KEYS"); ASSERT_EQ(ek, k); ASSERT_EQ(ev, v); } @@ -98,8 +98,8 @@ TEST_F(SortCudaTests, test_tad_sort_by_key_1) { k.tickWriteDevice(); v.tickWriteDevice(); - k.printIndexedBuffer("k"); - v.printIndexedBuffer("v"); + // k.printIndexedBuffer("k"); + // v.printIndexedBuffer("v"); ASSERT_EQ(ek, k); ASSERT_EQ(ev, v); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java index 49e760961..ac017beef 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java @@ -1562,8 +1562,8 @@ public class DifferentialFunctionFactory { } - public SDVariable eluBp(SDVariable in, SDVariable epsilon) { - return new EluBp(sameDiff(), in, epsilon).outputVariable(); + public SDVariable eluBp(SDVariable in, SDVariable epsilon, double alpha) { + return new EluBp(sameDiff(), in, epsilon, alpha).outputVariable(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationELU.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationELU.java index b7ac3887c..b714b1f06 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationELU.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationELU.java @@ -18,14 +18,12 @@ package org.nd4j.linalg.activations.impl; import lombok.EqualsAndHashCode; import lombok.Getter; -import org.nd4j.linalg.api.ops.impl.transforms.gradient.EluBp; -import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.activations.BaseActivationFunction; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.EluBp; import org.nd4j.linalg.api.ops.impl.transforms.strict.ELU; import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.indexing.BooleanIndexing; -import org.nd4j.linalg.indexing.conditions.Conditions; +import org.nd4j.linalg.primitives.Pair; /** * f(x) = alpha * (exp(x) - 1.0); x < 0 @@ -55,15 +53,7 @@ public class ActivationELU extends BaseActivationFunction { */ @Override public INDArray getActivation(INDArray in, boolean training) { - // no support in ELU native to override alpha - if (this.alpha != 
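For reference, the piecewise definition behind the f(x) shown above and its derivative, which is why eluBp(...) and the EluBp op now receive alpha explicitly:

    f(x) \;=\; \begin{cases} x, & x > 0 \\ \alpha \,(e^{x} - 1), & x \le 0 \end{cases}
    \qquad
    f'(x) \;=\; \begin{cases} 1, & x > 0 \\ \alpha \, e^{x}, & x \le 0 \end{cases}

With alpha = 1.0 nothing changes, but for any other alpha the backward pass needs the same scaling as the forward pass, hence the extra T argument on EluBp.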
1.00) { - INDArray alphaMultiple = Nd4j.getExecutioner().exec(new ELU(in.dup()))[0]; - alphaMultiple.muli(alpha); - BooleanIndexing.replaceWhere(in, alphaMultiple, Conditions.lessThan(0)); - } else { - Nd4j.getExecutioner().execAndReturn(new ELU(in)); - } - return in; + return Nd4j.exec(new ELU(in, in, alpha))[0]; } /* diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java index 0d0af0788..ac642872c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java @@ -1195,7 +1195,6 @@ public abstract class BaseNDArray implements INDArray, Iterable { } } - return this; } @@ -3089,12 +3088,6 @@ public abstract class BaseNDArray implements INDArray, Iterable { return mmuli(other, result); } - /** - * in place (element wise) division of two matrices - * - * @param other the second ndarray to divide - * @return the result of the divide - */ @Override public INDArray div(INDArray other) { if (Shape.areShapesBroadcastable(this.shape(), other.shape())) { @@ -3104,25 +3097,12 @@ public abstract class BaseNDArray implements INDArray, Iterable { } } - /** - * copy (element wise) division of two matrices - * - * @param other the second ndarray to divide - * @param result the result ndarray - * @return the result of the divide - */ @Override public INDArray div(INDArray other, INDArray result) { validateNumericalArray("div", true); return divi(other, result); } - /** - * copy (element wise) multiplication of two matrices - * - * @param other the second ndarray to multiply - * @return the result of the addition - */ @Override public INDArray mul(INDArray other) { validateNumericalArray("mul", false); @@ -3134,24 +3114,11 @@ public abstract class BaseNDArray implements INDArray, Iterable { } } - /** - * copy (element wise) multiplication of two matrices - * - * @param other the second ndarray to multiply - * @param result the result ndarray - * @return the result of the multiplication - */ @Override public INDArray mul(INDArray other, INDArray result) { return muli(other, result); } - /** - * copy subtraction of two matrices - * - * @param other the second ndarray to subtract - * @return the result of the addition - */ @Override public INDArray sub(INDArray other) { validateNumericalArray("sub", false); @@ -3162,24 +3129,11 @@ public abstract class BaseNDArray implements INDArray, Iterable { } } - /** - * copy subtraction of two matrices - * - * @param other the second ndarray to subtract - * @param result the result ndarray - * @return the result of the subtraction - */ @Override public INDArray sub(INDArray other, INDArray result) { return subi(other, result); } - /** - * copy addition of two matrices - * - * @param other the second ndarray to add - * @return the result of the addition - */ @Override public INDArray add(INDArray other) { validateNumericalArray("add", false); @@ -3190,65 +3144,29 @@ public abstract class BaseNDArray implements INDArray, Iterable { } } - /** - * copy addition of two matrices - * - * @param other the second ndarray to add - * @param result the result ndarray - * @return the result of the addition - */ @Override public INDArray add(INDArray other, INDArray result) { validateNumericalArray("add", false); return addi(other, result); } - - /** - * Perform 
an copy matrix multiplication - * - * @param other the other matrix to perform matrix multiply with - * @param transpose the transpose status of each ndarray - * @return the result of the matrix multiplication - */ @Override public INDArray mmuli(INDArray other, MMulTranspose transpose) { validateNumericalArray("mmuli", false); return dup().mmuli(other, this,transpose); } - /** - * Perform an copy matrix multiplication - * - * @param other the other matrix to perform matrix multiply with - * @return the result of the matrix multiplication - */ @Override public INDArray mmuli(INDArray other) { validateNumericalArray("mmuli", false); return dup().mmuli(other, this); } - - /** - * Perform an in place matrix multiplication - * - * @param other the other matrix to perform matrix multiply with - * @param result the result ndarray - * @return the result of the matrix multiplication - */ @Override public INDArray mmuli(INDArray other, INDArray result, MMulTranspose transpose) { return transpose.exec(this, other, result); } - /** - * Perform an copy matrix multiplication - * - * @param other the other matrix to perform matrix multiply with - * @param result the result ndarray - * @return the result of the matrix multiplication - */ @Override public INDArray mmuli(INDArray other, INDArray result) { validateNumericalArray("mmuli", false); @@ -3347,24 +3265,11 @@ public abstract class BaseNDArray implements INDArray, Iterable { return Nd4j.create(shape, stride); } - /** - * in place (element wise) division of two matrices - * - * @param other the second ndarray to divide - * @return the result of the divide - */ @Override public INDArray divi(INDArray other) { return divi(other, this); } - /** - * in place (element wise) division of two matrices - * - * @param other the second ndarray to divide - * @param result the result ndarray - * @return the result of the divide - */ @Override public INDArray divi(INDArray other, INDArray result) { validateNumericalArray("divi", false); @@ -3373,24 +3278,11 @@ public abstract class BaseNDArray implements INDArray, Iterable { return result; } - /** - * in place (element wise) multiplication of two matrices - * - * @param other the second ndarray to multiply - * @return the result of the multiplication - */ @Override public INDArray muli(INDArray other) { return muli(other, this); } - /** - * in place (element wise) multiplication of two matrices - * - * @param other the second ndarray to multiply - * @param result the result ndarray - * @return the result of the multiplication - */ @Override public INDArray muli(INDArray other, INDArray result) { validateNumericalArray("muli", false); @@ -3399,12 +3291,6 @@ public abstract class BaseNDArray implements INDArray, Iterable { return result; } - /** - * in place subtraction of two matrices - * - * @param other the second ndarray to subtract - * @return the result of the addition - */ @Override public INDArray subi(INDArray other) { return subi(other, this); @@ -3425,24 +3311,11 @@ public abstract class BaseNDArray implements INDArray, Iterable { return result; } - /** - * in place addition of two matrices - * - * @param other the second ndarray to add - * @return the result of the addition - */ @Override public INDArray addi(INDArray other) { return addi(other, this); } - /** - * in place addition of two matrices - * - * @param other the second ndarray to add - * @param result the result ndarray - * @return the result of the addition - */ @Override public INDArray addi(INDArray other, INDArray result) { 
validateNumericalArray("addi", false); @@ -3451,25 +3324,12 @@ public abstract class BaseNDArray implements INDArray, Iterable { return result; } - /** - * Returns the normmax along the specified dimension - * - * @param dimension the dimension to getScalar the norm1 along - * @param keepDims whether to keep reduced dimensions as dimensions of size 1 - * @return the norm1 along the specified dimension - */ @Override public INDArray normmax(boolean keepDims, int... dimension) { validateNumericalArray("normmax", false); return Nd4j.getExecutioner().exec(new NormMax(this, keepDims, dimension)); } - /** - * Returns the normmax along the specified dimension - * - * @param dimension the dimension to getScalar the norm1 along - * @return the norm1 along the specified dimension - */ @Override public INDArray normmax(int... dimension) { return normmax(false, dimension); @@ -4071,49 +3931,23 @@ public abstract class BaseNDArray implements INDArray, Iterable { return reshape(Nd4j.order(), shape); } - /** - * Returns the product along a given dimension - * - * @param dimension the dimension to getScalar the product along - * @param keepDims whether to keep reduced dimensions as dimensions of size 1 - * @return the product along the specified dimension - */ @Override public INDArray prod(boolean keepDims, int... dimension) { validateNumericalArray("prod", false); return Nd4j.getExecutioner().exec(new Prod(this, keepDims, dimension)); } - /** - * Returns the product along a given dimension - * - * @param dimension the dimension to getScalar the product along - * @return the product along the specified dimension - */ @Override public INDArray prod(int... dimension) { return prod(false, dimension); } - /** - * Returns the overall mean of this ndarray - * - * @param dimension the dimension to getScalar the mean along - * @param keepDims whether to keep reduced dimensions as dimensions of size 1 - * @return the mean along the specified dimension of this ndarray - */ @Override public INDArray mean(boolean keepDims, int... dimension) { validateNumericalArray("mean", false); return Nd4j.getExecutioner().exec(new Mean(this, keepDims, dimension)); } - /** - * Returns the overall mean of this ndarray - * - * @param dimension the dimension to getScalar the mean along - * @return the mean along the specified dimension of this ndarray - */ @Override public INDArray mean(int... dimension) { return mean(false, dimension); @@ -4136,50 +3970,24 @@ public abstract class BaseNDArray implements INDArray, Iterable { return mean(result, false, dimension); } - /** - * Returns the overall variance of this ndarray - * - * @param dimension the dimension to getScalar the mean along - * @return the mean along the specified dimension of this ndarray - */ @Override public INDArray var(int... dimension) { validateNumericalArray("var", false); return Nd4j.getExecutioner().exec(new Variance(this, dimension)); } - /** - * Returns the overall variance of this ndarray - * - * @param biasCorrected boolean on whether to apply corrected bias - * @param dimension the dimension to getScalar the mean along - * @return the mean along the specified dimension of this ndarray - */ @Override public INDArray var(boolean biasCorrected, int... 
dimension) { validateNumericalArray("var", false); return Nd4j.getExecutioner().exec(new Variance(this, biasCorrected, dimension)); } - /** - * Returns the overall max of this ndarray - * - * @param dimension the dimension to getScalar the mean along - * @param keepDims whether to keep reduced dimensions as dimensions of size 1 - * @return the mean along the specified dimension of this ndarray - */ @Override public INDArray max(boolean keepDims, int... dimension) { validateNumericalArray("max", false); return Nd4j.getExecutioner().exec(new Max(this, keepDims, dimension)); } - /** - * Returns the overall max of this ndarray - * - * @param dimension the dimension to getScalar the mean along - * @return the mean along the specified dimension of this ndarray - */ @Override public INDArray max(int... dimension) { return max(false, dimension); @@ -4191,25 +3999,12 @@ public abstract class BaseNDArray implements INDArray, Iterable { return Nd4j.getExecutioner().exec(new AMax(this, dimension)); } - /** - * Returns the overall min of this ndarray - * - * @param dimension the dimension to getScalar the mean along - * @param keepDims whether to keep reduced dimensions as dimensions of size 1 - * @return the mean along the specified dimension of this ndarray - */ @Override public INDArray min(boolean keepDims, int... dimension) { validateNumericalArray("min", false); return Nd4j.getExecutioner().exec(new Min(this, keepDims, dimension)); } - /** - * Returns the overall min of this ndarray - * - * @param dimension the dimension to getScalar the mean along - * @return the mean along the specified dimension of this ndarray - */ @Override public INDArray min(int... dimension) { return min(false, dimension); @@ -4290,39 +4085,17 @@ public abstract class BaseNDArray implements INDArray, Iterable { return sum(result, false, dimension); } - - /** - * Returns the norm1 along the specified dimension - * - * @param dimension the dimension to getScalar the norm1 along - * @return the norm1 along the specified dimension - */ @Override public INDArray norm1(int... dimension) { return norm1(false, dimension); } - - /** - * Returns the norm1 along the specified dimension - * - * @param dimension the dimension to getScalar the norm1 along - * @param keepDims whether to keep reduced dimensions as dimensions of size 1 - * @return the norm1 along the specified dimension - */ @Override public INDArray norm1(boolean keepDims, int... dimension) { validateNumericalArray("norm1", false); return Nd4j.getExecutioner().exec(new Norm1(this, keepDims, dimension)); } - - /** - * Standard deviation of an ndarray along a dimension - * - * @param dimension the dimension to getScalar the std along - * @return the standard deviation along a particular dimension - */ @Override public INDArray std(int... dimension) { return std(true, dimension); @@ -4345,32 +4118,17 @@ public abstract class BaseNDArray implements INDArray, Iterable { return Nd4j.getExecutioner().exec(new StandardDeviation(this, biasCorrected)).getDouble(0); } - /** - * Returns the norm2 along the specified dimension - * - * @param dimension the dimension to getScalar the norm2 along - * @param keepDims whether to keep reduced dimensions as dimensions of size 1 - * @return the norm2 along the specified dimension - */ @Override public INDArray norm2(boolean keepDims, int... 
dimension) { validateNumericalArray("norm2", false); return Nd4j.getExecutioner().exec(new Norm2(this, keepDims, dimension)); } - /** - * Returns the norm2 along the specified dimension - * - * @param dimension the dimension to getScalar the norm2 along - * @return the norm2 along the specified dimension - */ @Override public INDArray norm2(int... dimension) { return norm2(false, dimension); } - - /** * Number of columns (shape[1]), throws an exception when * called when not 2d diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseSparseNDArray.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseSparseNDArray.java index 6a112b868..1e0772494 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseSparseNDArray.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseSparseNDArray.java @@ -1232,8 +1232,6 @@ public abstract class BaseSparseNDArray implements ISparseNDArray { return null; } - - @Override public INDArray normmax(boolean keepDims, int... dimension) { return null; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/INDArray.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/INDArray.java index b842797f9..47e259b94 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/INDArray.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/INDArray.java @@ -1404,7 +1404,13 @@ public interface INDArray extends Serializable, AutoCloseable { */ INDArray add(INDArray other, INDArray result); - + /** + * Perform an copy matrix multiplication + * + * @param other the other matrix to perform matrix multiply with + * @param transpose the transpose status of each ndarray + * @return the result of the matrix multiplication + */ INDArray mmuli(INDArray other, MMulTranspose transpose); /** @@ -1415,7 +1421,13 @@ public interface INDArray extends Serializable, AutoCloseable { */ INDArray mmuli(INDArray other); - + /** + * Perform an in place matrix multiplication + * + * @param other the other matrix to perform matrix multiply with + * @param result the result ndarray + * @return the result of the matrix multiplication + */ INDArray mmuli(INDArray other, INDArray result, MMulTranspose transpose); /** @@ -1497,7 +1509,6 @@ public interface INDArray extends Serializable, AutoCloseable { */ INDArray addi(INDArray other, INDArray result); - /** * Returns the max norm (aka infinity norm, equal to the maximum absolute value) along the specified dimension(s) * @@ -1506,7 +1517,6 @@ public interface INDArray extends Serializable, AutoCloseable { */ INDArray normmax(int... 
dimension); - /** * Returns the max norm (aka infinity norm, equal to the maximum absolute value) along the specified dimension(s) * @@ -1585,7 +1595,7 @@ public interface INDArray extends Serializable, AutoCloseable { /** * Calculate the standard deviation for the entire array * - * @return + * @return standard deviation */ Number stdNumber(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/EluBp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/EluBp.java index f4624a6ee..0e2a4c6b9 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/EluBp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/EluBp.java @@ -33,8 +33,9 @@ public class EluBp extends DynamicCustomOp { public EluBp(){ } - public EluBp(SameDiff sd, SDVariable input, SDVariable gradient){ + public EluBp(SameDiff sd, SDVariable input, SDVariable gradient, double alpha){ super(sd, new SDVariable[]{input, gradient}); + addTArgument(alpha); } public EluBp(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ELU.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ELU.java index a144e868b..6923639fd 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ELU.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ELU.java @@ -23,13 +23,9 @@ import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; -import org.tensorflow.framework.AttrValue; -import org.tensorflow.framework.GraphDef; -import org.tensorflow.framework.NodeDef; import java.util.Collections; import java.util.List; -import java.util.Map; /** * ELU: Exponential Linear Unit (alpha=1.0)
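The alpha threaded through to EluBp matters because the ELU derivative depends on it: f(x) = x for x >= 0 and alpha * (exp(x) - 1) for x < 0, so the backward pass needs alpha * exp(x) on the negative side. A minimal sketch exercising the alpha-aware ELU constructor introduced by this patch (the class name EluAlphaExample and the sample values with alpha = 0.5 are illustrative only; Nd4j.createFromArray, ulike and Nd4j.exec are assumed from the existing nd4j API):

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.api.ops.impl.transforms.strict.ELU;
    import org.nd4j.linalg.factory.Nd4j;

    public class EluAlphaExample {
        public static void main(String[] args) {
            INDArray x = Nd4j.createFromArray(new float[]{-2f, -0.5f, 0f, 1.5f});
            INDArray out = x.ulike();
            // Forward pass with alpha = 0.5: negative inputs map to 0.5 * (exp(x) - 1), non-negative inputs pass through.
            Nd4j.exec(new ELU(x, out, 0.5));
            System.out.println(out);
        }
    }

In the SameDiff path the same alpha is registered as a t-argument on both ELU and EluBp, which keeps the forward output and the gradient consistent.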
@@ -41,19 +37,31 @@ import java.util.Map; * @author Alex Black */ public class ELU extends DynamicCustomOp { + public static final double DEFAULT_ALPHA = 1.0; + + protected double alpha; + public ELU(SameDiff sameDiff, SDVariable i_v) { super(sameDiff, new SDVariable[]{i_v}); + this.alpha = DEFAULT_ALPHA; + addTArgument(alpha); } public ELU() { } public ELU(INDArray x, INDArray z) { + this(x, z, DEFAULT_ALPHA); + } + + public ELU(INDArray x, INDArray z, double alpha) { super(null, wrapOrNull(x), wrapOrNull(z)); + this.alpha = alpha; + addTArgument(alpha); } public ELU(INDArray x) { - this(x, null); + this(x, null, DEFAULT_ALPHA); } @Override @@ -75,7 +83,7 @@ public class ELU extends DynamicCustomOp { public List doDiff(List i_v) { //ELU: e^x-1 if x<0, x otherwise //dL/dIn = dL/Out * dOut/dIn - return Collections.singletonList(f().eluBp(arg(), i_v.get(0))); + return Collections.singletonList(f().eluBp(arg(), i_v.get(0), alpha)); } @Override diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/pointers/cuda/cudaEvent_t.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/pointers/cuda/cudaEvent_t.java index 52b7d7332..de1920f0a 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/pointers/cuda/cudaEvent_t.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/pointers/cuda/cudaEvent_t.java @@ -18,6 +18,7 @@ package org.nd4j.jita.allocator.pointers.cuda; import lombok.Getter; import lombok.Setter; +import lombok.val; import org.bytedeco.javacpp.Pointer; import org.nd4j.jita.allocator.pointers.CudaPointer; import org.nd4j.linalg.exception.ND4JException; @@ -69,8 +70,9 @@ public class cudaEvent_t extends CudaPointer { if (res == 0) throw new ND4JException("CUDA exception happened. Terminating. 
Last op: [" + Nd4j.getExecutioner().getLastOp() +"]"); - if (NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorCode() != 0) - throw new RuntimeException(NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorMessage()); + val code = NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorCode(); + if (code != 0) + throw new RuntimeException(NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorMessage() + "; Error code: " + code); } } @@ -78,8 +80,9 @@ public class cudaEvent_t extends CudaPointer { if (!isDestroyed()) { int res = NativeOpsHolder.getInstance().getDeviceNativeOps().registerEvent(this, stream); - if (NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorCode() != 0) - throw new RuntimeException(NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorMessage()); + val code = NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorCode(); + if (code != 0) + throw new RuntimeException(NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorMessage() + "; Error code: " + code); } } } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/impl/CudaZeroHandler.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/impl/CudaZeroHandler.java index fdd40f8cb..9b8c1012c 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/impl/CudaZeroHandler.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/impl/CudaZeroHandler.java @@ -16,6 +16,7 @@ package org.nd4j.jita.handler.impl; +import org.nd4j.nativeblas.OpaqueLaunchContext; import org.nd4j.shade.guava.collect.HashBasedTable; import org.nd4j.shade.guava.collect.Table; import lombok.Getter; @@ -105,6 +106,8 @@ public class CudaZeroHandler implements MemoryHandler { private final AllocationStatus INITIAL_LOCATION; + private final List cublasHandles = new ArrayList<>(); + private final AffinityManager affinityManager = Nd4j.getAffinityManager(); /* @@ -162,6 +165,7 @@ public class CudaZeroHandler implements MemoryHandler { int numDevices = NativeOpsHolder.getInstance().getDeviceNativeOps().getAvailableDevices(); for (int i = 0; i < numDevices; i++) { deviceAllocations.add(new ConcurrentHashMap()); + cublasHandles.add(null); } if (NativeOpsHolder.getInstance().getDeviceNativeOps().getDeviceMajor(0) < 3) { @@ -1176,6 +1180,25 @@ public class CudaZeroHandler implements MemoryHandler { return getCudaContext(); } + // + private final ReentrantReadWriteLock lock = new ReentrantReadWriteLock(); + + protected cublasHandle_t getCudaCublasHandle(OpaqueLaunchContext lc) { + val deviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread(); + try { + lock.writeLock().lock(); + + if (cublasHandles.get(deviceId) == null) { + cublasHandles.remove(deviceId); + cublasHandles.add(deviceId, new cublasHandle_t(nativeOps.lcBlasHandle(lc))); + } + + return cublasHandles.get(deviceId); + } finally { + lock.writeLock().unlock(); + } + } + /** * This method returns CudaContext for current thread. If context doesn't exist - it gets created first. * @return @@ -1183,8 +1206,6 @@ public class CudaZeroHandler implements MemoryHandler { public CudaContext getCudaContext() { val lc = nativeOps.defaultLaunchContext(); - // TODO: maybe make ThreadLocal cache for context? 
- return CudaContext.builder() .bufferScalar(nativeOps.lcScalarPointer(lc)) .bufferReduction(nativeOps.lcReductionPointer(lc)) @@ -1192,7 +1213,7 @@ public class CudaZeroHandler implements MemoryHandler { .bufferSpecial(nativeOps.lcScalarPointer(lc)) .oldStream(new cudaStream_t(nativeOps.lcExecutionStream(lc))) .specialStream(new cudaStream_t(nativeOps.lcCopyStream(lc))) - .cublasHandle(new cublasHandle_t(nativeOps.lcBlasHandle(lc))) + .cublasHandle(getCudaCublasHandle(lc)) .solverHandle(new cusolverDnHandle_t(nativeOps.lcSolverHandle(lc))) .build(); } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/blas/JcublasLevel3.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/blas/JcublasLevel3.java index 7f8f9bb51..b06211545 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/blas/JcublasLevel3.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/blas/JcublasLevel3.java @@ -17,6 +17,7 @@ package org.nd4j.linalg.jcublas.blas; +import lombok.extern.slf4j.Slf4j; import lombok.val; import org.bytedeco.javacpp.DoublePointer; import org.bytedeco.javacpp.FloatPointer; @@ -52,6 +53,7 @@ import static org.nd4j.linalg.jcublas.blas.CudaBlas.*; * * @author Adam Gibson */ +@Slf4j public class JcublasLevel3 extends BaseLevel3 { private Allocator allocator = AtomicAllocator.getInstance(); private Nd4jBlas nd4jBlas = (Nd4jBlas) Nd4j.factory().blas(); @@ -78,7 +80,7 @@ public class JcublasLevel3 extends BaseLevel3 { int arch = CudaEnvironment.getInstance().getCurrentDeviceArchitecture(); - if ((CUDA_VERSION >= 8000 && (arch == 53 || arch == 60 || arch == 70)) || (CUDA_VERSION >= 8000 && CUDA_VERSION < 9020)) { + if ((CUDA_VERSION >= 8000 && (arch == 53 || arch == 60 || arch >= 70)) || (CUDA_VERSION >= 8000 && CUDA_VERSION < 9020)) { // on these selected archs we run with cublasHgemm __half alphaHalf = new __half(); __half betaHalf = new __half(); @@ -96,7 +98,11 @@ public class JcublasLevel3 extends BaseLevel3 { new FloatPointer(alpha), (ShortPointer) cAPointer.getDevicePointer(), 2, lda, (ShortPointer) cBPointer.getDevicePointer(), 2, ldb, new FloatPointer(beta), (ShortPointer) cCPointer.getDevicePointer(), 2, ldc); + + } + + ctx.getOldStream().synchronize(); } allocator.registerAction(ctx, C, A, B); @@ -114,18 +120,24 @@ public class JcublasLevel3 extends BaseLevel3 { val ctx = allocator.getFlowController().prepareAction(C, A, B); + //log.info("Synchronizing CUDA stream"); + ctx.getOldStream().synchronize(); + val cAPointer = new CublasPointer(A, ctx); val cBPointer = new CublasPointer(B, ctx); val cCPointer = new CublasPointer(C, ctx); val handle = ctx.getCublasHandle(); synchronized (handle) { + //log.info("Handle: {}; Stream: {}", handle.address(), ctx.getCublasStream().address()); cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream())); cublasSgemm_v2(new cublasContext(handle), convertTranspose(TransA), convertTranspose(TransB), M, N, K, new FloatPointer(alpha), (FloatPointer) cAPointer.getDevicePointer(), lda, (FloatPointer) cBPointer.getDevicePointer(), ldb, new FloatPointer(beta), (FloatPointer) cCPointer.getDevicePointer(), ldc); + + ctx.getOldStream().synchronize(); } allocator.registerAction(ctx, C, A, B); @@ -244,6 +256,8 @@ public class JcublasLevel3 extends BaseLevel3 { new DoublePointer(alpha), (DoublePointer) cAPointer.getDevicePointer(), lda, (DoublePointer) 
cBPointer.getDevicePointer(), ldb, new DoublePointer(beta), (DoublePointer) cCPointer.getDevicePointer(), ldc); + + ctx.getOldStream().synchronize(); } allocator.registerAction(ctx, C, A, B); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java index c5b02a82f..43bbfbdca 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java @@ -2548,6 +2548,9 @@ public class CudaExecutioner extends DefaultOpExecutioner { @Override public DataBuffer createShapeInfo(long[] shape, long[] stride, long elementWiseStride, char order, DataType dtype, boolean empty) { + if (nativeOps.lastErrorCode() != 0) + throw new RuntimeException(nativeOps.lastErrorMessage()); + OpaqueConstantDataBuffer dbf = nativeOps.shapeBuffer(shape.length, new LongPointer(shape), new LongPointer(stride), dtype.toInt(), order, elementWiseStride, empty); if (nativeOps.lastErrorCode() != 0) @@ -2562,6 +2565,9 @@ public class CudaExecutioner extends DefaultOpExecutioner { @Override public TadPack tadShapeInfoAndOffsets(INDArray array, int[] dimension) { + if (nativeOps.lastErrorCode() != 0) + throw new RuntimeException(nativeOps.lastErrorMessage()); + OpaqueTadPack pack = nativeOps.tadOnlyShapeInfo((LongPointer) array.shapeInfoDataBuffer().addressPointer(), new IntPointer(dimension), dimension.length); if (nativeOps.lastErrorCode() != 0) @@ -2577,6 +2583,9 @@ public class CudaExecutioner extends DefaultOpExecutioner { @Override public DataBuffer createConstantBuffer(long[] values, DataType desiredType) { + if (nativeOps.lastErrorCode() != 0) + throw new RuntimeException(nativeOps.lastErrorMessage()); + OpaqueConstantDataBuffer dbf = nativeOps.constantBufferLong(desiredType.toInt(), new LongPointer(values), values.length); if (nativeOps.lastErrorCode() != 0) @@ -2590,6 +2599,9 @@ public class CudaExecutioner extends DefaultOpExecutioner { @Override public DataBuffer createConstantBuffer(double[] values, DataType desiredType) { + if (nativeOps.lastErrorCode() != 0) + throw new RuntimeException(nativeOps.lastErrorMessage()); + OpaqueConstantDataBuffer dbf = nativeOps.constantBufferDouble(desiredType.toInt(), new DoublePointer(values), values.length); if (nativeOps.lastErrorCode() != 0) diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java index 15f6c52ef..f3080f05a 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java @@ -3830,6 +3830,9 @@ public native @Cast("Nd4jPointer") Pointer lcSolverHandle(OpaqueLaunchContext lc * @param writeList * @param readList */ + // TODO: it would be nice to have NDArray::registerSpecialUse signature that accepts something else beyond initializer_list + + // TODO: it would be nice to have NDArray::registerSpecialUse signature that accepts something else beyond initializer_list /** diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java 
b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java index 6983e20f0..9554a94e9 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java @@ -3830,6 +3830,9 @@ public native @Cast("Nd4jPointer") Pointer lcSolverHandle(OpaqueLaunchContext lc * @param writeList * @param readList */ + // TODO: it would be nice to have NDArray::registerSpecialUse signature that accepts something else beyond initializer_list + + // TODO: it would be nice to have NDArray::registerSpecialUse signature that accepts something else beyond initializer_list /** @@ -16982,8 +16985,20 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); // #endif /** - * Returns a batched matrix tensor with new batched diagonal values. - */ + * Inserts elements provided by diagonal array into the main diagonal of innermost matrices of input array + * + * Input arrays: + * input: input array, considered as batch of matrices + * diagonal: array containing elements to be inserted into input array, + * following rank condition should be satisfied: diagonal_rank = input_rank - 1, + * the shapes of diagonal and input arrays must be equal except last dimension of input array, + * for example if input_shape = [A,B,C,D] then diagonal_shape = [A,B,C], + * also last dimension of diagonal array should be equal to smaller of last and last but one input dimensions + * that is: diagonal_shape[-1] = min(input_shape[-1], input_shape[-2]) + * + * Output array: + * has the same shape as input, corresponding diagonal elements are substituted + */ // #if NOT_EXCLUDED(OP_matrix_set_diag) @Namespace("nd4j::ops") public static class matrix_set_diag extends DeclarableOp { static { Loader.load(); }
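To make the shape rule above concrete: for input_shape = [2, 3, 4] the diagonal must have shape [2, 3] (rank 3 - 1 = 2, and its last dimension equals min(4, 3) = 3); element [i, j, j] of each inner 3x4 matrix is then overwritten. A hedged sketch of calling the op from Java through nd4j's generic DynamicCustomOp builder (the class name MatrixSetDiagExample is hypothetical; the generated matrix_set_diag wrapper above could be used instead):

    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.api.ops.DynamicCustomOp;
    import org.nd4j.linalg.factory.Nd4j;

    public class MatrixSetDiagExample {
        public static void main(String[] args) {
            INDArray input    = Nd4j.zeros(DataType.FLOAT, 2, 3, 4);  // batch of two 3x4 matrices
            INDArray diagonal = Nd4j.ones(DataType.FLOAT, 2, 3);      // rank 2, last dim = min(4, 3) = 3
            INDArray output   = input.ulike();
            Nd4j.exec(DynamicCustomOp.builder("matrix_set_diag")
                    .addInputs(input, diagonal)
                    .addOutputs(output)
                    .build());
            // output matches input everywhere except output[i, j, j] == diagonal[i, j] for j in 0..2
            System.out.println(output);
        }
    }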