diff --git a/arbiter/pom.xml b/arbiter/pom.xml
index 364c6d904..93f877968 100644
--- a/arbiter/pom.xml
+++ b/arbiter/pom.xml
@@ -192,7 +192,7 @@
             <artifactId>maven-surefire-plugin</artifactId>
             <version>${maven-surefire-plugin.version}</version>
-            <argLine>-Ddtype=double -Xmx3024m -Xms3024m</argLine>
+            <argLine>-Ddtype=double -Dfile.encoding=UTF-8 -Xmx3024m -Xms3024m</argLine>
             <include>*.java</include>
diff --git a/deeplearning4j/deeplearning4j-dataimport-solrj/src/test/java/org/deeplearning4j/nn/dataimport/solr/client/solrj/io/stream/TupleStreamDataSetIteratorTest.java b/deeplearning4j/deeplearning4j-dataimport-solrj/src/test/java/org/deeplearning4j/nn/dataimport/solr/client/solrj/io/stream/TupleStreamDataSetIteratorTest.java
index a35c612b1..b0757e67f 100644
--- a/deeplearning4j/deeplearning4j-dataimport-solrj/src/test/java/org/deeplearning4j/nn/dataimport/solr/client/solrj/io/stream/TupleStreamDataSetIteratorTest.java
+++ b/deeplearning4j/deeplearning4j-dataimport-solrj/src/test/java/org/deeplearning4j/nn/dataimport/solr/client/solrj/io/stream/TupleStreamDataSetIteratorTest.java
@@ -18,6 +18,8 @@ package org.deeplearning4j.nn.dataimport.solr.client.solrj.io.stream;
import com.carrotsearch.randomizedtesting.annotations.ThreadLeakFilters;
import com.carrotsearch.randomizedtesting.ThreadFilter;
+
+import java.security.SecureRandom;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
@@ -44,6 +46,18 @@ import org.nd4j.rng.deallocator.NativeRandomDeallocator;
})
public class TupleStreamDataSetIteratorTest extends SolrCloudTestCase {
+ static {
+ /*
+    This is a workaround for the platform-dependent nature of SecureRandom implementations.
+    Although the SecureRandom algorithm can be pinned in our pom.xml files (via maven-surefire and test.solr.allowed.securerandom),
+    no single setting is completely platform independent: NativePRNG, for example, passes on Linux but fails on some JVMs on Windows.
+    For testing purposes we don't need strict RNG guarantees, so we don't enforce a particular algorithm
+    and instead allow whatever the current JVM selects by default.
+ */
+ String algorithm = new SecureRandom().getAlgorithm();
+ System.setProperty("test.solr.allowed.securerandom", algorithm);
+ }
+
public static class PrivateDeallocatorThreadsFilter implements ThreadFilter {
/**
* Reject deallocator threads over whose cleanup this test has no control.
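For illustration, a minimal standalone sketch of the idea behind the static block added above; the property name test.solr.allowed.securerandom is taken from the patch, while the Solr test-framework check that reads it is assumed and not shown here.

import java.security.SecureRandom;

public class AllowDefaultSecureRandomSketch {
    public static void main(String[] args) {
        // Ask the JVM which SecureRandom implementation it selected for this platform,
        // e.g. NativePRNG on Linux or a different algorithm on Windows JVMs.
        String algorithm = new SecureRandom().getAlgorithm();

        // Allow exactly that algorithm instead of hard-coding one in the surefire argLine,
        // so the same build passes on both Linux and Windows.
        System.setProperty("test.solr.allowed.securerandom", algorithm);
        System.out.println("Allowed SecureRandom algorithm: " + algorithm);
    }
}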
diff --git a/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/models/deepwalk/DeepWalkGradientCheck.java b/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/models/deepwalk/DeepWalkGradientCheck.java
index 39e91921a..c1aedd47a 100644
--- a/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/models/deepwalk/DeepWalkGradientCheck.java
+++ b/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/models/deepwalk/DeepWalkGradientCheck.java
@@ -66,7 +66,7 @@ public class DeepWalkGradientCheck extends BaseDL4JTest {
for (int i = 0; i < 7; i++) {
INDArray vector = deepWalk.getVertexVector(i);
assertArrayEquals(new long[] {vectorSize}, vector.shape());
- System.out.println(Arrays.toString(vector.dup().data().asFloat()));
+// System.out.println(Arrays.toString(vector.dup().data().asFloat()));
}
GraphWalkIterator iter = new RandomWalkIterator<>(graph, 8);
@@ -182,10 +182,10 @@ public class DeepWalkGradientCheck extends BaseDL4JTest {
if (relError > MAX_REL_ERROR && absErr > MIN_ABS_ERROR)
fail(msg);
- else
- System.out.println(msg);
+// else
+// System.out.println(msg);
}
- System.out.println();
+// System.out.println();
}
}
@@ -216,7 +216,7 @@ public class DeepWalkGradientCheck extends BaseDL4JTest {
for (int i = 0; i < nVertices; i++) {
INDArray vector = deepWalk.getVertexVector(i);
assertArrayEquals(new long[] {vectorSize}, vector.shape());
- System.out.println(Arrays.toString(vector.dup().data().asFloat()));
+// System.out.println(Arrays.toString(vector.dup().data().asFloat()));
}
GraphWalkIterator iter = new RandomWalkIterator<>(graph, 10);
@@ -295,8 +295,8 @@ public class DeepWalkGradientCheck extends BaseDL4JTest {
if (relError > MAX_REL_ERROR && absErr > minAbsError)
fail(msg);
- else
- System.out.println(msg);
+// else
+// System.out.println(msg);
}
}
diff --git a/deeplearning4j/deeplearning4j-modelexport-solr/pom.xml b/deeplearning4j/deeplearning4j-modelexport-solr/pom.xml
index 02ce30a40..383eb1c8c 100644
--- a/deeplearning4j/deeplearning4j-modelexport-solr/pom.xml
+++ b/deeplearning4j/deeplearning4j-modelexport-solr/pom.xml
@@ -33,7 +33,7 @@
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-surefire-plugin</artifactId>
-               <argLine>-Ddtype=float -Xmx8g -Dtest.solr.allowed.securerandom=NativePRNG</argLine>
+               <argLine>-Ddtype=float -Dfile.encoding=UTF-8 -Xmx8g -Dtest.solr.allowed.securerandom=NativePRNG</argLine>
                <include>*.java</include>
diff --git a/deeplearning4j/deeplearning4j-modelexport-solr/src/test/java/org/deeplearning4j/nn/modelexport/solr/handler/ModelTupleStreamIntegrationTest.java b/deeplearning4j/deeplearning4j-modelexport-solr/src/test/java/org/deeplearning4j/nn/modelexport/solr/handler/ModelTupleStreamIntegrationTest.java
index 899e3f8fd..31633889a 100644
--- a/deeplearning4j/deeplearning4j-modelexport-solr/src/test/java/org/deeplearning4j/nn/modelexport/solr/handler/ModelTupleStreamIntegrationTest.java
+++ b/deeplearning4j/deeplearning4j-modelexport-solr/src/test/java/org/deeplearning4j/nn/modelexport/solr/handler/ModelTupleStreamIntegrationTest.java
@@ -18,6 +18,7 @@ package org.deeplearning4j.nn.modelexport.solr.handler;
import java.io.File;
import java.nio.file.Path;
+import java.security.SecureRandom;
import com.carrotsearch.randomizedtesting.ThreadFilter;
import com.carrotsearch.randomizedtesting.annotations.ThreadLeakFilters;
@@ -49,6 +50,19 @@ import org.nd4j.rng.deallocator.NativeRandomDeallocator;
})
public class ModelTupleStreamIntegrationTest extends SolrCloudTestCase {
+ static {
+ /*
+    This is a workaround for the platform-dependent nature of SecureRandom implementations.
+    Although the SecureRandom algorithm can be pinned in our pom.xml files (via maven-surefire and test.solr.allowed.securerandom),
+    no single setting is completely platform independent: NativePRNG, for example, passes on Linux but fails on some JVMs on Windows.
+    For testing purposes we don't need strict RNG guarantees, so we don't enforce a particular algorithm
+    and instead allow whatever the current JVM selects by default.
+ */
+ String algorithm = new SecureRandom().getAlgorithm();
+ System.setProperty("test.solr.allowed.securerandom", algorithm);
+ }
+
+
public static class PrivateDeallocatorThreadsFilter implements ThreadFilter {
/**
* Reject deallocator threads over whose cleanup this test has no control.
diff --git a/deeplearning4j/deeplearning4j-modelexport-solr/src/test/java/org/deeplearning4j/nn/modelexport/solr/handler/ModelTupleStreamTest.java b/deeplearning4j/deeplearning4j-modelexport-solr/src/test/java/org/deeplearning4j/nn/modelexport/solr/handler/ModelTupleStreamTest.java
index 1cafc143d..80073677f 100644
--- a/deeplearning4j/deeplearning4j-modelexport-solr/src/test/java/org/deeplearning4j/nn/modelexport/solr/handler/ModelTupleStreamTest.java
+++ b/deeplearning4j/deeplearning4j-modelexport-solr/src/test/java/org/deeplearning4j/nn/modelexport/solr/handler/ModelTupleStreamTest.java
@@ -19,6 +19,7 @@ package org.deeplearning4j.nn.modelexport.solr.handler;
import java.io.File;
import java.nio.file.Files;
import java.nio.file.Path;
+import java.security.SecureRandom;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
@@ -58,6 +59,18 @@ import org.nd4j.linalg.lossfunctions.LossFunctions;
public class ModelTupleStreamTest {
+ static {
+ /*
+    This is a workaround for the platform-dependent nature of SecureRandom implementations.
+    Although the SecureRandom algorithm can be pinned in our pom.xml files (via maven-surefire and test.solr.allowed.securerandom),
+    no single setting is completely platform independent: NativePRNG, for example, passes on Linux but fails on some JVMs on Windows.
+    For testing purposes we don't need strict RNG guarantees, so we don't enforce a particular algorithm
+    and instead allow whatever the current JVM selects by default.
+ */
+ String algorithm = new SecureRandom().getAlgorithm();
+ System.setProperty("test.solr.allowed.securerandom", algorithm);
+ }
+
    protected List<Float> floatsList(int numFloats) {
        final List<Float> floatsList = new ArrayList();
final float[] floats0 = new float[numFloats];
diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/pom.xml b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/pom.xml
index 2d4a4da14..911432cf0 100644
--- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/pom.xml
+++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/pom.xml
@@ -36,7 +36,7 @@
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-surefire-plugin</artifactId>
-               <argLine>-Ddtype=float -Xmx8g</argLine>
+               <argLine>-Ddtype=float -Dfile.encoding=UTF-8 -Xmx8g</argLine>
                <include>*.java</include>
diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/kmeans/KMeansTest.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/kmeans/KMeansTest.java
index 2f2619e78..abbfa04bc 100644
--- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/kmeans/KMeansTest.java
+++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/kmeans/KMeansTest.java
@@ -38,6 +38,11 @@ public class KMeansTest extends BaseDL4JTest {
private boolean[] useKMeansPlusPlus = {true, false};
+ @Override
+ public long getTimeoutMilliseconds() {
+ return 60000L;
+ }
+
@Test
public void testKMeans() {
Nd4j.getRandom().setSeed(7);
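The getTimeoutMilliseconds() override added above uses the per-class timeout mechanism of BaseDL4JTest; the sketch below only illustrates how such a value can be turned into a JUnit 4 rule, and the real BaseDL4JTest wiring may differ.

import java.util.concurrent.TimeUnit;
import org.junit.Rule;
import org.junit.rules.Timeout;

public abstract class TimeoutOverrideSketch {

    // Subclasses (like KMeansTest above) raise this when their tests legitimately need more time.
    public long getTimeoutMilliseconds() {
        return 30_000L;
    }

    // JUnit applies this rule to every test method in the class; the overridden value is
    // picked up through normal virtual dispatch when the field is initialised.
    @Rule
    public Timeout timeout = Timeout.builder()
            .withTimeout(getTimeoutMilliseconds(), TimeUnit.MILLISECONDS)
            .build();
}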
diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/WordVectorSerializerTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/WordVectorSerializerTest.java
index 7807ff711..0154bc732 100755
--- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/WordVectorSerializerTest.java
+++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/WordVectorSerializerTest.java
@@ -16,25 +16,19 @@
package org.deeplearning4j.models;
-import org.junit.rules.Timeout;
-import org.nd4j.shade.guava.io.Files;
-import org.nd4j.shade.guava.primitives.Doubles;
import lombok.val;
import org.apache.commons.io.FileUtils;
import org.apache.commons.lang.ArrayUtils;
import org.apache.commons.lang3.RandomUtils;
import org.deeplearning4j.BaseDL4JTest;
-import org.deeplearning4j.models.sequencevectors.SequenceVectors;
-import org.deeplearning4j.models.sequencevectors.serialization.VocabWordFactory;
-import org.junit.Rule;
-import org.junit.rules.TemporaryFolder;
-import org.nd4j.linalg.io.ClassPathResource;
import org.deeplearning4j.models.embeddings.WeightLookupTable;
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration;
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
import org.deeplearning4j.models.paragraphvectors.ParagraphVectors;
+import org.deeplearning4j.models.sequencevectors.SequenceVectors;
+import org.deeplearning4j.models.sequencevectors.serialization.VocabWordFactory;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.Word2Vec;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
@@ -48,11 +42,16 @@ import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFac
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.junit.Before;
import org.junit.Ignore;
+import org.junit.Rule;
import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+import org.junit.rules.Timeout;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.io.ClassPathResource;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.resources.Resources;
+import org.nd4j.shade.guava.primitives.Doubles;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -272,7 +271,14 @@ public class WordVectorSerializerTest extends BaseDL4JTest {
@Test
public void testFullModelSerialization() throws Exception {
+ String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
+ if(!isIntegrationTests() && "CUDA".equalsIgnoreCase(backend)) {
+ skipUnlessIntegrationTests(); //AB 2020/02/06 Skip CUDA except for integration tests due to very slow test speed - > 5 minutes on Titan X
+ }
+
File inputFile = Resources.asFile("big/raw_sentences.txt");
+
+
SentenceIterator iter = UimaSentenceIterator.createWithPath(inputFile.getAbsolutePath());
// Split on white spaces in the line to get words
TokenizerFactory t = new DefaultTokenizerFactory();
@@ -892,5 +898,4 @@ public class WordVectorSerializerTest extends BaseDL4JTest {
fail(e.getMessage());
}
}
-
}
diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/word2vec/Word2VecTests.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/word2vec/Word2VecTests.java
index e50a95443..7dcfb160a 100755
--- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/word2vec/Word2VecTests.java
+++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/word2vec/Word2VecTests.java
@@ -159,6 +159,11 @@ public class Word2VecTests extends BaseDL4JTest {
@Test
public void testWord2VecCBOW() throws Exception {
+ String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
+ if(!isIntegrationTests() && "CUDA".equalsIgnoreCase(backend)) {
+ skipUnlessIntegrationTests(); //AB 2020/02/06 Skip CUDA except for integration tests due to very slow test speed - > 5 minutes on Titan X
+ }
+
SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath());
TokenizerFactory t = new DefaultTokenizerFactory();
@@ -188,6 +193,11 @@ public class Word2VecTests extends BaseDL4JTest {
@Test
public void testWord2VecMultiEpoch() throws Exception {
+ String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
+ if(!isIntegrationTests() && "CUDA".equalsIgnoreCase(backend)) {
+ skipUnlessIntegrationTests(); //AB 2020/02/06 Skip CUDA except for integration tests due to very slow test speed - > 5 minutes on Titan X
+ }
+
SentenceIterator iter;
if(isIntegrationTests()){
iter = new BasicLineIterator(inputFile.getAbsolutePath());
@@ -220,6 +230,11 @@ public class Word2VecTests extends BaseDL4JTest {
@Test
public void reproducibleResults_ForMultipleRuns() throws Exception {
+ String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
+ if(!isIntegrationTests() && "CUDA".equalsIgnoreCase(backend)) {
+ skipUnlessIntegrationTests(); //AB 2020/02/06 Skip CUDA except for integration tests due to very slow test speed - > 5 minutes on Titan X
+ }
+
log.info("reproducibleResults_ForMultipleRuns");
val shakespear = new ClassPathResource("big/rnj.txt");
val basic = new ClassPathResource("big/rnj.txt");
@@ -274,6 +289,11 @@ public class Word2VecTests extends BaseDL4JTest {
@Test
public void testRunWord2Vec() throws Exception {
+ String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
+ if(!isIntegrationTests() && "CUDA".equalsIgnoreCase(backend)) {
+ skipUnlessIntegrationTests(); //AB 2020/02/06 Skip CUDA except for integration tests due to very slow test speed - > 5 minutes on Titan X
+ }
+
// Strip white space before and after for each line
/*val shakespear = new ClassPathResource("big/rnj.txt");
SentenceIterator iter = new BasicLineIterator(shakespear.getFile());*/
@@ -363,6 +383,11 @@ public class Word2VecTests extends BaseDL4JTest {
@Test
public void testLoadingWordVectors() throws Exception {
+ String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
+ if(!isIntegrationTests() && "CUDA".equalsIgnoreCase(backend)) {
+ skipUnlessIntegrationTests(); //AB 2020/02/06 Skip CUDA except for integration tests due to very slow test speed - > 5 minutes on Titan X
+ }
+
File modelFile = new File(pathToWriteto);
if (!modelFile.exists()) {
testRunWord2Vec();
@@ -396,6 +421,11 @@ public class Word2VecTests extends BaseDL4JTest {
@Test
public void testW2VnegativeOnRestore() throws Exception {
+ String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
+ if(!isIntegrationTests() && "CUDA".equalsIgnoreCase(backend)) {
+ skipUnlessIntegrationTests(); //AB 2020/02/06 Skip CUDA except for integration tests due to very slow test speed - > 5 minutes on Titan X
+ }
+
// Strip white space before and after for each line
SentenceIterator iter;
if(isIntegrationTests()){
@@ -453,6 +483,11 @@ public class Word2VecTests extends BaseDL4JTest {
@Test
public void testUnknown1() throws Exception {
+ String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
+ if(!isIntegrationTests() && "CUDA".equalsIgnoreCase(backend)) {
+ skipUnlessIntegrationTests(); //AB 2020/02/06 Skip CUDA except for integration tests due to very slow test speed - > 5 minutes on Titan X
+ }
+
// Strip white space before and after for each line
SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath());
// Split on white spaces in the line to get words
@@ -688,6 +723,10 @@ public class Word2VecTests extends BaseDL4JTest {
@Test
public void testWordVectorsPartiallyAbsentLabels() throws Exception {
+ String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
+ if(!isIntegrationTests() && "CUDA".equalsIgnoreCase(backend)) {
+ skipUnlessIntegrationTests(); //AB 2020/02/06 Skip CUDA except for integration tests due to very slow test speed - > 5 minutes on Titan X
+ }
SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath());
// Split on white spaces in the line to get words
@@ -720,6 +759,10 @@ public class Word2VecTests extends BaseDL4JTest {
@Test
public void testWordVectorsAbsentLabels() throws Exception {
+ String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
+ if(!isIntegrationTests() && "CUDA".equalsIgnoreCase(backend)) {
+ skipUnlessIntegrationTests(); //AB 2020/02/06 Skip CUDA except for integration tests due to very slow test speed - > 5 minutes on Titan X
+ }
SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath());
// Split on white spaces in the line to get words
@@ -745,6 +788,10 @@ public class Word2VecTests extends BaseDL4JTest {
@Test
public void testWordVectorsAbsentLabels_WithUnknown() throws Exception {
+ String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
+ if(!isIntegrationTests() && "CUDA".equalsIgnoreCase(backend)) {
+ skipUnlessIntegrationTests(); //AB 2020/02/06 Skip CUDA except for integration tests due to very slow test speed - > 5 minutes on Titan X
+ }
SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath());
// Split on white spaces in the line to get words
@@ -814,6 +861,10 @@ public class Word2VecTests extends BaseDL4JTest {
@Test
public void weightsNotUpdated_WhenLocked_CBOW() throws Exception {
+ String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
+ if(!isIntegrationTests() && "CUDA".equalsIgnoreCase(backend)) {
+ skipUnlessIntegrationTests(); //AB 2020/02/06 Skip CUDA except for integration tests due to very slow test speed - > 5 minutes on Titan X
+ }
SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath());
@@ -851,6 +902,11 @@ public class Word2VecTests extends BaseDL4JTest {
@Test
public void testWordsNearestSum() throws IOException {
+ String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
+ if(!isIntegrationTests() && "CUDA".equalsIgnoreCase(backend)) {
+ skipUnlessIntegrationTests(); //AB 2020/02/06 Skip CUDA except for integration tests due to very slow test speed - > 5 minutes on Titan X
+ }
+
log.info("Load & Vectorize Sentences....");
SentenceIterator iter = new BasicLineIterator(inputFile);
TokenizerFactory t = new DefaultTokenizerFactory();
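The backend guard added at the top of each test above always has the same three lines. Purely as a hypothetical refactor that is not part of this patch, the check could be centralised in one helper; isIntegrationTests() and skipUnlessIntegrationTests() remain the BaseDL4JTest methods used in the patch.

import org.nd4j.linalg.factory.Nd4j;

// Hypothetical helper, not part of this patch.
public final class CudaBackendCheck {

    private CudaBackendCheck() { }

    // True when the active ND4J backend reports itself as CUDA.
    public static boolean isCudaBackend() {
        String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
        return "CUDA".equalsIgnoreCase(backend);
    }
}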
diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/TsneTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/TsneTest.java
index c99cb3b9a..cf0e7c7a3 100644
--- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/TsneTest.java
+++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/TsneTest.java
@@ -48,12 +48,22 @@ public class TsneTest extends BaseDL4JTest {
@Override
public long getTimeoutMilliseconds() {
- return 60000L;
+ return 180000L;
}
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
+ @Override
+ public DataType getDataType() {
+ return DataType.FLOAT;
+ }
+
+ @Override
+ public DataType getDefaultFPDataType() {
+ return DataType.FLOAT;
+ }
+
@Test
public void testSimple() throws Exception {
//Simple sanity check
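The two DataType overrides added above feed into the data-type setup that BaseDL4JTest performs around each test; the exact wiring is not visible in this patch, so the following is only a sketch of the general mechanism under that assumption.

import org.junit.Before;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.factory.Nd4j;

public abstract class DataTypeOverrideSketch {

    // Data type the test's arrays should use; FLOAT keeps heavy tests such as t-SNE fast.
    public DataType getDataType() {
        return DataType.FLOAT;
    }

    // Default floating-point type for arrays created without an explicit type.
    public DataType getDefaultFPDataType() {
        return DataType.FLOAT;
    }

    @Before
    public void applyDataTypes() {
        // Reset ND4J's global defaults before each test so one test class cannot leak
        // a data-type change into another.
        Nd4j.setDefaultDataTypes(getDataType(), getDefaultFPDataType());
    }
}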
diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/paragraphvectors/ParagraphVectorsTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/paragraphvectors/ParagraphVectorsTest.java
index 95cd4e9a6..14495ffaf 100644
--- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/paragraphvectors/ParagraphVectorsTest.java
+++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/paragraphvectors/ParagraphVectorsTest.java
@@ -32,6 +32,7 @@ import org.deeplearning4j.models.sequencevectors.transformers.impl.iterables.Par
import org.deeplearning4j.text.sentenceiterator.*;
import org.junit.Rule;
import org.junit.rules.TemporaryFolder;
+import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.io.ClassPathResource;
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
import org.deeplearning4j.models.embeddings.learning.impl.elements.SkipGram;
@@ -80,12 +81,21 @@ public class ParagraphVectorsTest extends BaseDL4JTest {
@Override
public long getTimeoutMilliseconds() {
- return 240000;
+ return isIntegrationTests() ? 600_000 : 240_000;
}
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
+ @Override
+ public DataType getDataType() {
+ return DataType.FLOAT;
+ }
+
+ @Override
+ public DataType getDefaultFPDataType() {
+ return DataType.FLOAT;
+ }
/*
@Test
@@ -359,8 +369,13 @@ public class ParagraphVectorsTest extends BaseDL4JTest {
}
- @Test(timeout = 300000)
+ @Test
public void testParagraphVectorsDM() throws Exception {
+ String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
+ if(!isIntegrationTests() && "CUDA".equalsIgnoreCase(backend)) {
+ skipUnlessIntegrationTests(); //Skip CUDA except for integration tests due to very slow test speed
+ }
+
File file = Resources.asFile("/big/raw_sentences.txt");
SentenceIterator iter = new BasicLineIterator(file);
@@ -372,10 +387,10 @@ public class ParagraphVectorsTest extends BaseDL4JTest {
LabelsSource source = new LabelsSource("DOC_");
ParagraphVectors vec = new ParagraphVectors.Builder().minWordFrequency(1).iterations(2).seed(119).epochs(1)
- .layerSize(100).learningRate(0.025).labelsSource(source).windowSize(5).iterate(iter)
- .trainWordVectors(true).vocabCache(cache).tokenizerFactory(t).negativeSample(0)
- .useHierarchicSoftmax(true).sampling(0).workers(1).usePreciseWeightInit(true)
- .sequenceLearningAlgorithm(new DM()).build();
+ .layerSize(100).learningRate(0.025).labelsSource(source).windowSize(5).iterate(iter)
+ .trainWordVectors(true).vocabCache(cache).tokenizerFactory(t).negativeSample(0)
+ .useHierarchicSoftmax(true).sampling(0).workers(1).usePreciseWeightInit(true)
+ .sequenceLearningAlgorithm(new DM()).build();
vec.fit();
@@ -404,7 +419,9 @@ public class ParagraphVectorsTest extends BaseDL4JTest {
double similarityX = vec.similarity("DOC_3720", "DOC_9852");
log.info("3720/9852 similarity: " + similarityX);
- assertTrue(similarityX < 0.5d);
+ if(isIntegrationTests()) {
+ assertTrue(similarityX < 0.5d);
+ }
// testing DM inference now
@@ -418,7 +435,6 @@ public class ParagraphVectorsTest extends BaseDL4JTest {
log.info("Cos O/A: {}", cosAO1);
log.info("Cos A/B: {}", cosAB1);
-
}
@@ -501,6 +517,11 @@ public class ParagraphVectorsTest extends BaseDL4JTest {
@Test(timeout = 300000)
public void testParagraphVectorsWithWordVectorsModelling1() throws Exception {
+ String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
+ if(!isIntegrationTests() && "CUDA".equalsIgnoreCase(backend)) {
+ skipUnlessIntegrationTests(); //Skip CUDA except for integration tests due to very slow test speed
+ }
+
File file = Resources.asFile("/big/raw_sentences.txt");
SentenceIterator iter = new BasicLineIterator(file);
@@ -705,8 +726,12 @@ public class ParagraphVectorsTest extends BaseDL4JTest {
In this test we'll build w2v model, and will use it's vocab and weights for ParagraphVectors.
there's no need in this test within travis, use it manually only for problems detection
*/
- @Test(timeout = 300000)
+ @Test
public void testParagraphVectorsOverExistingWordVectorsModel() throws Exception {
+ String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
+ if(!isIntegrationTests() && "CUDA".equalsIgnoreCase(backend)) {
+ skipUnlessIntegrationTests(); //Skip CUDA except for integration tests due to very slow test speed
+ }
// we build w2v from multiple sources, to cover everything
File resource_sentences = Resources.asFile("/big/raw_sentences.txt");
@@ -997,14 +1022,18 @@ public class ParagraphVectorsTest extends BaseDL4JTest {
log.info("SimilarityB: {}", simB);
}
- @Test(timeout = 300000)
+ @Test
+ @Ignore //AB 2020/02/06 - https://github.com/eclipse/deeplearning4j/issues/8677
public void testDirectInference() throws Exception {
- File resource_sentences = Resources.asFile("/big/raw_sentences.txt");
+ boolean isIntegration = isIntegrationTests();
+ File resource = Resources.asFile("/big/raw_sentences.txt");
+ SentenceIterator sentencesIter = getIterator(isIntegration, resource);
+
ClassPathResource resource_mixed = new ClassPathResource("paravec/");
File local_resource_mixed = testDir.newFolder();
resource_mixed.copyDirectory(local_resource_mixed);
SentenceIterator iter = new AggregatingSentenceIterator.Builder()
- .addSentenceIterator(new BasicLineIterator(resource_sentences))
+ .addSentenceIterator(sentencesIter)
.addSentenceIterator(new FileSentenceIterator(local_resource_mixed)).build();
TokenizerFactory t = new DefaultTokenizerFactory();
@@ -1154,24 +1183,7 @@ public class ParagraphVectorsTest extends BaseDL4JTest {
public void testDoubleFit() throws Exception {
boolean isIntegration = isIntegrationTests();
File resource = Resources.asFile("/big/raw_sentences.txt");
- SentenceIterator iter;
- if(isIntegration){
- iter = new BasicLineIterator(resource);
- } else {
- List lines = new ArrayList<>();
- try(InputStream is = new BufferedInputStream(new FileInputStream(resource))){
- LineIterator lineIter = IOUtils.lineIterator(is, StandardCharsets.UTF_8);
- try{
- for( int i=0; i<500 && lineIter.hasNext(); i++ ){
- lines.add(lineIter.next());
- }
- } finally {
- lineIter.close();
- }
- }
-
- iter = new CollectionSentenceIterator(lines);
- }
+ SentenceIterator iter = getIterator(isIntegration, resource);
TokenizerFactory t = new DefaultTokenizerFactory();
@@ -1197,6 +1209,30 @@ public class ParagraphVectorsTest extends BaseDL4JTest {
assertEquals(num1, num2);
}
+
+ public static SentenceIterator getIterator(boolean isIntegration, File file) throws IOException {
+ return getIterator(isIntegration, file, 500);
+ }
+
+ public static SentenceIterator getIterator(boolean isIntegration, File file, int linesForUnitTest) throws IOException {
+ if(isIntegration){
+ return new BasicLineIterator(file);
+ } else {
+            List<String> lines = new ArrayList<>();
+ try(InputStream is = new BufferedInputStream(new FileInputStream(file))){
+ LineIterator lineIter = IOUtils.lineIterator(is, StandardCharsets.UTF_8);
+                for( int i=0; i<linesForUnitTest && lineIter.hasNext(); i++ ){
+                    lines.add(lineIter.next());
+                }
        JavaRDD<LabeledPoint> data = MLUtils
                .loadLibSVMFile(sc.sc(),
.loadLibSVMFile(sc.sc(),
@@ -125,7 +142,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
}
- @Test(timeout = 120000L)
+ @Test
public void testFromSvmLight() throws Exception {
        JavaRDD<LabeledPoint> data = MLUtils
.loadLibSVMFile(sc.sc(),
@@ -155,7 +172,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
master.fitLabeledPoint(data);
}
- @Test(timeout = 120000L)
+ @Test
public void testRunIteration() {
DataSet dataSet = new IrisDataSetIterator(5, 5).next();
@@ -175,7 +192,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
assertEquals(expectedParams.size(1), actualParams.size(1));
}
- @Test(timeout = 120000L)
+ @Test
public void testUpdaters() {
SparkDl4jMultiLayer sparkNet = getBasicNetwork();
MultiLayerNetwork netCopy = sparkNet.getNetwork().clone();
@@ -197,7 +214,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
}
- @Test(timeout = 120000L)
+ @Test
public void testEvaluation() {
SparkDl4jMultiLayer sparkNet = getBasicNetwork();
@@ -228,7 +245,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
}
}
- @Test(timeout = 120000L)
+ @Test
public void testSmallAmountOfData() {
//Idea: Test spark training where some executors don't get any data
//in this case: by having fewer examples (2 DataSets) than executors (local[*])
@@ -255,7 +272,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
}
- @Test(timeout = 120000L)
+ @Test
public void testDistributedScoring() {
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().l1(0.1).l2(0.1)
@@ -333,7 +350,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
- @Test(timeout = 120000L)
+ @Test
public void testParameterAveragingMultipleExamplesPerDataSet() throws Exception {
int dataSetObjSize = 5;
int batchSizePerExecutor = 25;
@@ -382,7 +399,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
}
- @Test(timeout = 120000L)
+ @Test
public void testFitViaStringPaths() throws Exception {
Path tempDir = testDir.newFolder("DL4J-testFitViaStringPaths").toPath();
@@ -445,7 +462,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
sparkNet.getTrainingMaster().deleteTempFiles(sc);
}
- @Test(timeout = 120000L)
+ @Test
public void testFitViaStringPathsSize1() throws Exception {
Path tempDir = testDir.newFolder("DL4J-testFitViaStringPathsSize1").toPath();
@@ -525,7 +542,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
}
- @Test(timeout = 120000L)
+ @Test
public void testFitViaStringPathsCompGraph() throws Exception {
Path tempDir = testDir.newFolder("DL4J-testFitViaStringPathsCG").toPath();
@@ -618,7 +635,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
}
- @Test(timeout = 120000L)
+ @Test
@Ignore("AB 2019/05/23 - Failing on CI only - passing locally. Possible precision or threading issue")
public void testSeedRepeatability() throws Exception {
@@ -691,7 +708,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
}
- @Test(timeout = 120000L)
+ @Test
public void testIterationCounts() throws Exception {
int dataSetObjSize = 5;
int batchSizePerExecutor = 25;
@@ -737,7 +754,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
}
}
- @Test(timeout = 120000L)
+ @Test
public void testIterationCountsGraph() throws Exception {
int dataSetObjSize = 5;
int batchSizePerExecutor = 25;
@@ -783,7 +800,8 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
}
- @Test(timeout = 120000L) @Ignore //Ignored 2019/04/09 - low priority: https://github.com/deeplearning4j/deeplearning4j/issues/6656
+ @Test
+ @Ignore //Ignored 2019/04/09 - low priority: https://github.com/deeplearning4j/deeplearning4j/issues/6656
public void testVaePretrainSimple() {
//Simple sanity check on pretraining
int nIn = 8;
@@ -818,7 +836,8 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
sparkNet.fit(data);
}
- @Test(timeout = 120000L) @Ignore //Ignored 2019/04/09 - low priority: https://github.com/deeplearning4j/deeplearning4j/issues/6656
+ @Test
+ @Ignore //Ignored 2019/04/09 - low priority: https://github.com/deeplearning4j/deeplearning4j/issues/6656
public void testVaePretrainSimpleCG() {
//Simple sanity check on pretraining
int nIn = 8;
@@ -854,7 +873,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
}
- @Test(timeout = 120000L)
+ @Test
public void testROC() {
int nArrays = 100;
@@ -909,7 +928,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
}
- @Test(timeout = 120000L)
+ @Test
public void testROCMultiClass() {
int nArrays = 100;
diff --git a/deeplearning4j/deeplearning4j-zoo/src/test/java/org/deeplearning4j/zoo/MiscTests.java b/deeplearning4j/deeplearning4j-zoo/src/test/java/org/deeplearning4j/zoo/MiscTests.java
index c92f5acdc..9dea6629a 100644
--- a/deeplearning4j/deeplearning4j-zoo/src/test/java/org/deeplearning4j/zoo/MiscTests.java
+++ b/deeplearning4j/deeplearning4j-zoo/src/test/java/org/deeplearning4j/zoo/MiscTests.java
@@ -34,7 +34,7 @@ public class MiscTests extends BaseDL4JTest {
@Override
public long getTimeoutMilliseconds() {
- return 120000L;
+ return 240000L;
}
@Test
diff --git a/deeplearning4j/pom.xml b/deeplearning4j/pom.xml
index bcf633e3a..17c89f931 100644
--- a/deeplearning4j/pom.xml
+++ b/deeplearning4j/pom.xml
@@ -380,7 +380,7 @@
-->
true
false
-                        <argLine>-Ddtype=float -Xmx8g</argLine>
+                        <argLine>-Ddtype=float -Dfile.encoding=UTF-8 -Xmx8g</argLine>
*.java
diff --git a/libnd4j/blas/NativeOps.h b/libnd4j/blas/NativeOps.h
index 01b656861..f4d1d261b 100755
--- a/libnd4j/blas/NativeOps.h
+++ b/libnd4j/blas/NativeOps.h
@@ -1601,6 +1601,7 @@ ND4J_EXPORT OpaqueRandomGenerator* getGraphContextRandomGenerator(OpaqueContext*
ND4J_EXPORT void ctxAllowHelpers(OpaqueContext* ptr, bool reallyAllow);
ND4J_EXPORT void ctxShapeFunctionOverride(OpaqueContext* ptr, bool reallyOverride);
ND4J_EXPORT void ctxSetExecutionMode(OpaqueContext* ptr, int execMode);
+ND4J_EXPORT void ctxPurge(OpaqueContext* ptr);
ND4J_EXPORT void markGraphContextInplace(OpaqueContext* ptr, bool reallyInplace);
ND4J_EXPORT void setGraphContextCudaContext(OpaqueContext* ptr, void *stream, void *reductionPointer, void *allocationPointer);
ND4J_EXPORT void setGraphContextInputArray(OpaqueContext* ptr, int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo);
diff --git a/libnd4j/blas/cpu/NativeOps.cpp b/libnd4j/blas/cpu/NativeOps.cpp
index 0410c833b..b945c5bcf 100644
--- a/libnd4j/blas/cpu/NativeOps.cpp
+++ b/libnd4j/blas/cpu/NativeOps.cpp
@@ -2815,6 +2815,10 @@ void ctxSetExecutionMode(OpaqueContext* ptr, int execMode) {
ptr->setExecutionMode((samediff::ExecutionMode) execMode);
}
+void ctxPurge(OpaqueContext* ptr) {
+ ptr->clearFastPath();
+}
+
nd4j::graph::RandomGenerator* createRandomGenerator(Nd4jLong rootSeed, Nd4jLong nodeSeed) {
return new nd4j::graph::RandomGenerator(rootSeed, nodeSeed);
}
diff --git a/libnd4j/blas/cuda/NativeOps.cu b/libnd4j/blas/cuda/NativeOps.cu
index d65dcaed5..07ce876ea 100755
--- a/libnd4j/blas/cuda/NativeOps.cu
+++ b/libnd4j/blas/cuda/NativeOps.cu
@@ -3771,6 +3771,10 @@ void ctxShapeFunctionOverride(OpaqueContext* ptr, bool reallyOverride) {
ptr->setShapeFunctionOverride(reallyOverride);
}
+void ctxPurge(OpaqueContext* ptr) {
+ ptr->clearFastPath();
+}
+
int binaryLevel() {
return 0;
}
diff --git a/libnd4j/include/array/impl/DataBuffer.cpp b/libnd4j/include/array/impl/DataBuffer.cpp
index 49527026c..36758c684 100644
--- a/libnd4j/include/array/impl/DataBuffer.cpp
+++ b/libnd4j/include/array/impl/DataBuffer.cpp
@@ -305,12 +305,17 @@ namespace nd4j {
if (_primaryBuffer != nullptr && _isOwnerPrimary) {
deletePrimary();
}
+
_primaryBuffer = buffer;
_isOwnerPrimary = false;
_lenInBytes = length * DataTypeUtils::sizeOf(_dataType);
}
void DataBuffer::setSpecialBuffer(void *buffer, size_t length) {
+ if (_specialBuffer != nullptr && _isOwnerSpecial) {
+ deleteSpecial();
+ }
+
this->setSpecial(buffer, false);
_lenInBytes = length * DataTypeUtils::sizeOf(_dataType);
}
diff --git a/libnd4j/include/graph/Context.h b/libnd4j/include/graph/Context.h
index 96b7e1c79..d1e8a4dad 100644
--- a/libnd4j/include/graph/Context.h
+++ b/libnd4j/include/graph/Context.h
@@ -204,6 +204,13 @@ namespace nd4j {
void setBArguments(const std::vector &tArgs);
void setDArguments(const std::vector &dArgs);
+ /**
+ * This method purges fastpath in/out contents and releases all the handles.
+ *
+ * PLEASE NOTE: I/T/B/D args will stay intact
+ */
+ void clearFastPath();
+
void setCudaContext(Nd4jPointer cudaStream, Nd4jPointer reductionPointer, Nd4jPointer allocationPointer);
void allowHelpers(bool reallyAllow);
diff --git a/libnd4j/include/graph/impl/Context.cpp b/libnd4j/include/graph/impl/Context.cpp
index 4c7a19133..5add8280d 100644
--- a/libnd4j/include/graph/impl/Context.cpp
+++ b/libnd4j/include/graph/impl/Context.cpp
@@ -563,6 +563,16 @@ namespace nd4j {
for (auto d:dArgs)
_dArgs.emplace_back(d);
}
+
+ void Context::clearFastPath() {
+ _fastpath_in.clear();
+ _fastpath_out.clear();
+
+ for (auto v:_handles)
+ delete v;
+
+ _handles.clear();
+ }
}
}
diff --git a/libnd4j/include/helpers/cuda/ConstantHelper.cu b/libnd4j/include/helpers/cuda/ConstantHelper.cu
index 6c8eaa21d..47e276f4a 100644
--- a/libnd4j/include/helpers/cuda/ConstantHelper.cu
+++ b/libnd4j/include/helpers/cuda/ConstantHelper.cu
@@ -92,7 +92,7 @@ namespace nd4j {
}
void* ConstantHelper::replicatePointer(void *src, size_t numBytes, memory::Workspace *workspace) {
- _mutex.lock();
+        std::lock_guard<std::mutex> lock(_mutex);
auto deviceId = getCurrentDevice();
Nd4jPointer constantPtr = nullptr;
@@ -116,7 +116,6 @@ namespace nd4j {
if (res != 0)
throw cuda_exception::build("cudaMemcpy failed", res);
- _mutex.unlock();
return ptr;
} else {
auto originalBytes = numBytes;
@@ -130,7 +129,6 @@ namespace nd4j {
if (res != 0)
throw cuda_exception::build("cudaMemcpyToSymbol failed", res);
- _mutex.unlock();
return reinterpret_cast(constantPtr) + constantOffset;
}
}
@@ -152,7 +150,7 @@ namespace nd4j {
ConstantDataBuffer* result;
// access to this holder instance is synchronous
- holder->mutex()->lock();
+        std::lock_guard<std::mutex> lock(*holder->mutex());
if (holder->hasBuffer(dataType)) {
result = holder->getConstantDataBuffer(dataType);
@@ -175,8 +173,6 @@ namespace nd4j {
holder->addBuffer(dataBuffer, dataType);
result = holder->getConstantDataBuffer(dataType);
}
- // release holder lock
- holder->mutex()->unlock();
return result;
}
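The std::lock_guard changes above matter because the previous code called _mutex.lock() and could then return or throw (for example via cuda_exception::build) without reaching the matching unlock(); a scope-bound guard releases the mutex on every exit path. The same reasoning expressed in Java, purely as an illustration of the pattern:

import java.util.concurrent.locks.ReentrantLock;

public class ScopedLockSketch {

    private final ReentrantLock mutex = new ReentrantLock();

    public int replicateSomething(int value) {
        mutex.lock();
        try {
            if (value < 0) {
                // Early exit: the finally block still releases the lock, just as
                // std::lock_guard releases the mutex when the C++ scope ends.
                throw new IllegalArgumentException("negative value");
            }
            return value * 2;
        } finally {
            mutex.unlock();
        }
    }
}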
diff --git a/libnd4j/include/helpers/cuda/ConstantShapeHelper.cu b/libnd4j/include/helpers/cuda/ConstantShapeHelper.cu
index aae62594c..4f7a4a485 100644
--- a/libnd4j/include/helpers/cuda/ConstantShapeHelper.cu
+++ b/libnd4j/include/helpers/cuda/ConstantShapeHelper.cu
@@ -57,7 +57,7 @@ namespace nd4j {
ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo(const ShapeDescriptor &descriptor) {
int deviceId = AffinityManager::currentDeviceId();
- _mutex.lock();
+        std::lock_guard<std::mutex> lock(_mutex);
if (_cache[deviceId].count(descriptor) == 0) {
auto hPtr = descriptor.toShapeInfo();
@@ -65,15 +65,9 @@ namespace nd4j {
ConstantDataBuffer buffer(hPtr, dPtr, shape::shapeInfoLength(hPtr) * sizeof(Nd4jLong), DataType::INT64);
ShapeDescriptor descriptor1(descriptor);
_cache[deviceId][descriptor1] = buffer;
- auto r = _cache[deviceId][descriptor1];
- _mutex.unlock();
-
- return r;
+ return _cache[deviceId][descriptor1];
} else {
- ConstantDataBuffer r = _cache[deviceId].at(descriptor);
- _mutex.unlock();
-
- return r;
+ return _cache[deviceId].at(descriptor);
}
}
@@ -83,18 +77,10 @@ namespace nd4j {
}
bool ConstantShapeHelper::checkBufferExistenceForShapeInfo(ShapeDescriptor &descriptor) {
- bool result;
auto deviceId = AffinityManager::currentDeviceId();
- _mutex.lock();
+        std::lock_guard<std::mutex> lock(_mutex);
- if (_cache[deviceId].count(descriptor) == 0)
- result = false;
- else
- result = true;
-
- _mutex.unlock();
-
- return result;
+ return _cache[deviceId].count(descriptor) != 0;
}
Nd4jLong* ConstantShapeHelper::createShapeInfo(const nd4j::DataType dataType, const char order, const int rank, const Nd4jLong* shape) {
diff --git a/libnd4j/include/helpers/cuda/ConstantTadHelper.cu b/libnd4j/include/helpers/cuda/ConstantTadHelper.cu
index 8ea4067f3..747e295e2 100644
--- a/libnd4j/include/helpers/cuda/ConstantTadHelper.cu
+++ b/libnd4j/include/helpers/cuda/ConstantTadHelper.cu
@@ -64,7 +64,7 @@ namespace nd4j {
TadPack ConstantTadHelper::tadForDimensions(TadDescriptor &descriptor) {
const int deviceId = AffinityManager::currentDeviceId();
- _mutex.lock();
+        std::lock_guard<std::mutex> lock(_mutex);
if (_cache[deviceId].count(descriptor) == 0) {
const auto shapeInfo = descriptor.originalShape().toShapeInfo();
@@ -97,14 +97,12 @@ namespace nd4j {
_cache[deviceId][descriptor] = t;
TadPack r = _cache[deviceId][descriptor];
- _mutex.unlock();
delete[] shapeInfo;
return r;
} else {
TadPack r = _cache[deviceId][descriptor];
- _mutex.unlock();
return r;
}
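The ConstantShapeHelper and ConstantTadHelper changes above share one shape: hold the lock for the whole lookup-or-insert and return directly, instead of copying the cached value out before a manual unlock. For illustration only, the equivalent pattern written in Java:

import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;

public class GuardedCacheSketch<K, V> {

    private final Map<K, V> cache = new HashMap<>();

    // Holding the lock for the entire method is the analogue of std::lock_guard:
    // no path can return without releasing it.
    public synchronized V getOrCompute(K key, Function<K, V> factory) {
        V value = cache.get(key);
        if (value == null) {
            value = factory.apply(key);
            cache.put(key, value);
        }
        return value;
    }
}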
diff --git a/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool3d.cpp b/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool3d.cpp
index b82d5306a..be905e22f 100644
--- a/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool3d.cpp
+++ b/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool3d.cpp
@@ -169,8 +169,8 @@ CUSTOM_OP_IMPL(maxpool3dnew_bp, 2, 1, false, 0, 14) {
// int extraParam0 = INT_ARG(13); // unnecessary for max case, required only for avg and pnorm cases
int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 1-NDHWC, 0-NCDHW
- REQUIRE_TRUE(input->rankOf() == 5, 0, "MAXPOOL3D_BP op: input should have rank of 5, but got %i instead", input->rankOf());
- REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "MAXPOOL3DNEW op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW);
+ REQUIRE_TRUE(input->rankOf() == 5, 0, "MAXPOOL3DNEW_BP op: input should have rank of 5, but got %i instead", input->rankOf());
+ REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "MAXPOOL3DNEW_BP op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW);
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
@@ -178,8 +178,8 @@ CUSTOM_OP_IMPL(maxpool3dnew_bp, 2, 1, false, 0, 14) {
std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}));
std::string expectedGradIShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iD,iH,iW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}));
- REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, "MAXPOOL3D_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
- REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0, "MAXPOOL3D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str());
+ REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, "MAXPOOL3DNEW_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
+ REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0, "MAXPOOL3DNEW_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str());
if(!isNCDHW) {
input = new NDArray(input->permute({0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/sequence_mask.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/sequence_mask.cpp
index e7694b409..477b298a3 100644
--- a/libnd4j/include/ops/declarable/generic/parity_ops/sequence_mask.cpp
+++ b/libnd4j/include/ops/declarable/generic/parity_ops/sequence_mask.cpp
@@ -58,30 +58,31 @@ namespace nd4j {
int outRank = shape::rank(in) + 1;
auto input = INPUT_VARIABLE(0);
auto dtype = DataType::BOOL;
-            Nd4jLong maxInd = input->argMax();
-            Nd4jLong max = input->e<Nd4jLong>(maxInd);
+            auto argMaxInd = input->argMax();
+            Nd4jLong max = input->e<Nd4jLong>(argMaxInd);
+            Nd4jLong maxInd = max;
- if (block.getIArguments()->size() > 0) {
- if (block.width() < 2) {
- maxInd = INT_ARG(0);
- if (maxInd < max)
- maxInd = static_cast(max);
- if (block.getIArguments()->size() > 1)
- dtype = (DataType)INT_ARG(1);
- }
- else {
- dtype = (DataType)INT_ARG(0);
- }
- }
+ if (block.numD() > 0)
+ dtype = D_ARG(0);
if (block.width() > 1) {
auto maxlen = INPUT_VARIABLE(1);
                Nd4jLong tmaxlen = maxlen->e<Nd4jLong>(0);
                if (tmaxlen > max)
                    maxInd = static_cast<Nd4jLong>(tmaxlen);
+ if (block.numI() > 0) {
+ dtype = (DataType) INT_ARG(0);
+ }
+ }
+ else {
+ if (block.numI() > 0) {
+ maxInd = INT_ARG(0);
+ }
+ if (maxInd < max)
+ maxInd = max;
+ if (block.numI() > 1)
+ dtype = (DataType)INT_ARG(1); // to work with legacy code
}
- else
- maxInd = static_cast(max);
int lastDimension = maxInd;
ALLOCATE(outShapeInfo, block.getWorkspace(), shape::shapeInfoLength(outRank), Nd4jLong);
diff --git a/libnd4j/include/ops/declarable/helpers/cpu/sequence_mask.cpp b/libnd4j/include/ops/declarable/helpers/cpu/sequence_mask.cpp
index bf3463afe..c175fd96d 100644
--- a/libnd4j/include/ops/declarable/helpers/cpu/sequence_mask.cpp
+++ b/libnd4j/include/ops/declarable/helpers/cpu/sequence_mask.cpp
@@ -38,10 +38,10 @@ namespace helpers {
}
void sequenceMask(nd4j::LaunchContext * context, NDArray* input, NDArray* output, int maxIndex) {
- BUILD_DOUBLE_SELECTOR(input->dataType(), output->dataType(), sequenceMask_, (input, output, maxIndex), INTEGER_TYPES, BOOL_TYPES);
+ BUILD_DOUBLE_SELECTOR(input->dataType(), output->dataType(), sequenceMask_, (input, output, maxIndex), INTEGER_TYPES, LIBND4J_TYPES_EXTENDED);
}
- BUILD_DOUBLE_TEMPLATE(template void sequenceMask_, (NDArray* input, NDArray* output, int maxIndex), INTEGER_TYPES, BOOL_TYPES);
+ BUILD_DOUBLE_TEMPLATE(template void sequenceMask_, (NDArray* input, NDArray* output, int maxIndex), INTEGER_TYPES, LIBND4J_TYPES_EXTENDED);
}
}
}
\ No newline at end of file
diff --git a/libnd4j/include/ops/declarable/helpers/cpu/solve.cpp b/libnd4j/include/ops/declarable/helpers/cpu/solve.cpp
index 8583d9cba..48f7f0d9a 100644
--- a/libnd4j/include/ops/declarable/helpers/cpu/solve.cpp
+++ b/libnd4j/include/ops/declarable/helpers/cpu/solve.cpp
@@ -36,10 +36,12 @@ namespace helpers {
static void adjointMatrix_(nd4j::LaunchContext* context, NDArray const* input, NDArray* output) {
auto inputPart = input->allTensorsAlongDimension({-2, -1});
auto outputPart = output->allTensorsAlongDimension({-2, -1});
+ auto rows = input->sizeAt(-2);
output->assign(input);
+
auto batchLoop = PRAGMA_THREADS_FOR {
for (auto batch = start; batch < stop; batch += increment) {
- for (auto r = 0; r < input->rows(); r++) {
+ for (auto r = 0; r < rows; r++) {
for (auto c = 0; c < r; c++) {
math::nd4j_swap(outputPart[batch]->t(r, c) , outputPart[batch]->t(c, r));
}
diff --git a/libnd4j/include/ops/declarable/helpers/cpu/triangular_solve.cpp b/libnd4j/include/ops/declarable/helpers/cpu/triangular_solve.cpp
index e904d219c..ceb228439 100644
--- a/libnd4j/include/ops/declarable/helpers/cpu/triangular_solve.cpp
+++ b/libnd4j/include/ops/declarable/helpers/cpu/triangular_solve.cpp
@@ -108,17 +108,20 @@ namespace helpers {
static void adjointTriangularMatrix_(nd4j::LaunchContext* context, NDArray const* input, bool const lower, NDArray* output) {
auto inputPart = input->allTensorsAlongDimension({-2, -1});
auto outputPart = output->allTensorsAlongDimension({-2, -1});
+ auto cols = input->sizeAt(-1);
+ auto rows = input->sizeAt(-2);
+
auto batchLoop = PRAGMA_THREADS_FOR {
for (auto batch = start; batch < stop; batch += increment) {
if (!lower) {
- for (auto r = 0; r < input->rows(); r++) {
+ for (auto r = 0; r < rows; r++) {
for (auto c = 0; c <= r; c++) {
outputPart[batch]->t(r, c) = inputPart[batch]->t(c, r);
}
}
} else {
- for (auto r = 0; r < input->rows(); r++) {
- for (auto c = r; c < input->columns(); c++) {
+ for (auto r = 0; r < rows; r++) {
+ for (auto c = r; c < cols; c++) {
outputPart[batch]->t(r, c) = inputPart[batch]->t(c, r);
}
}
diff --git a/libnd4j/include/ops/declarable/helpers/cuda/sequence_mask.cu b/libnd4j/include/ops/declarable/helpers/cuda/sequence_mask.cu
index c07db1b95..6b33a384e 100644
--- a/libnd4j/include/ops/declarable/helpers/cuda/sequence_mask.cu
+++ b/libnd4j/include/ops/declarable/helpers/cuda/sequence_mask.cu
@@ -55,10 +55,10 @@ namespace helpers {
}
void sequenceMask(nd4j::LaunchContext * context, NDArray* input, NDArray* output, int maxIndex) {
- BUILD_DOUBLE_SELECTOR(input->dataType(), output->dataType(), sequenceMask_, (context, input, output, maxIndex), INTEGER_TYPES, BOOL_TYPES);
+ BUILD_DOUBLE_SELECTOR(input->dataType(), output->dataType(), sequenceMask_, (context, input, output, maxIndex), INTEGER_TYPES, LIBND4J_TYPES_EXTENDED);
}
- BUILD_DOUBLE_TEMPLATE(template void sequenceMask_, (nd4j::LaunchContext* context, NDArray* input, NDArray* output, int maxIndex), INTEGER_TYPES, BOOL_TYPES);
+ BUILD_DOUBLE_TEMPLATE(template void sequenceMask_, (nd4j::LaunchContext* context, NDArray* input, NDArray* output, int maxIndex), INTEGER_TYPES, LIBND4J_TYPES_EXTENDED);
}
}
}
\ No newline at end of file
diff --git a/libnd4j/include/ops/declarable/platform/cudnn/cudnnUtils.cu b/libnd4j/include/ops/declarable/platform/cudnn/cudnnUtils.cu
index fa7b1ecfa..02a302e61 100644
--- a/libnd4j/include/ops/declarable/platform/cudnn/cudnnUtils.cu
+++ b/libnd4j/include/ops/declarable/platform/cudnn/cudnnUtils.cu
@@ -250,7 +250,7 @@ void pooling3dCUDNN(const LaunchContext* context,
auto handle = reinterpret_cast(context->getCuDnnHandle());
cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream());
if (err != 0) throw nd4j::cuda_exception::build("pooling3dCUDNN: can't set stream for cuDNN", err);
-printf("fffffffffff\n");
+
const int numDims = 5;
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling2d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling2d.cpp
index 1c1e9d6a4..4c8a582f0 100644
--- a/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling2d.cpp
+++ b/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling2d.cpp
@@ -17,6 +17,7 @@
//
// @author saudet
// @author raver119@gmail.com
+// @author Yurii Shyrma (iuriish@yahoo.com)
//
#include
@@ -36,103 +37,44 @@ namespace platforms {
//////////////////////////////////////////////////////////////////////////
PLATFORM_IMPL(avgpool2d, ENGINE_CPU) {
- auto input = INPUT_VARIABLE(0);
- REQUIRE_TRUE(input->rankOf() == 4, 0, "Input should have rank of 4, but got %i instead",
- input->rankOf());
+ auto input = INPUT_VARIABLE(0);
+ auto output = OUTPUT_VARIABLE(0);
// 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode;
- auto argI = *(block.getIArguments());
- auto output = OUTPUT_VARIABLE(0);
const auto kH = INT_ARG(0);
const auto kW = INT_ARG(1);
const auto sH = INT_ARG(2);
const auto sW = INT_ARG(3);
- int pH = INT_ARG(4);
- int pW = INT_ARG(5);
+ auto pH = INT_ARG(4);
+ auto pW = INT_ARG(5);
const auto dH = INT_ARG(6);
const auto dW = INT_ARG(7);
- const auto isSameMode = static_cast(INT_ARG(8));
+ const auto paddingMode = INT_ARG(8);
const auto extraParam0 = INT_ARG(9);
+ const int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC
- REQUIRE_TRUE(dH != 0 && dW != 0, 0, "AVGPOOL2D op: dilation must not be zero, but got instead {%i, %i}",
- dH, dW);
+ REQUIRE_TRUE(input->rankOf() == 4, 0, "AVGPOOL2D MKLDNN op: input should have rank of 4, but got %i instead", input->rankOf());
+ REQUIRE_TRUE(dH != 0 && dW != 0, 0, "AVGPOOL2D MKLDNN op: dilation must not be zero, but got instead {%i, %i}", dH, dW);
- int oH = 0;
- int oW = 0;
+ int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
+ int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
+ ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
- int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC
-
- const int iH = static_cast(isNCHW ? input->sizeAt(2) : input->sizeAt(1));
- const int iW = static_cast(isNCHW ? input->sizeAt(3) : input->sizeAt(2));
-
- if (!isNCHW) {
- input = new NDArray(
- input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
- output = new NDArray(
- output->permute({0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW]
- }
-
- ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode);
-
- if (isSameMode)
+ if (paddingMode)
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
- const int bS = input->sizeAt(0);
- const int iC = input->sizeAt(1);
- const int oC = output->sizeAt(1);
+ auto mode = (extraParam0 == 0) ? algorithm::pooling_avg_exclude_padding : algorithm::pooling_avg_include_padding;
- auto poolingMode = PoolingType::AVG_POOL;
-
- dnnl_memory_desc_t empty;
- dnnl::memory::desc pool_src_md(empty), pool_dst_md(empty);
- dnnl::memory::desc user_src_md(empty), user_dst_md(empty);
- dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
- dnnl::algorithm algorithm;
- mkldnnUtils::getMKLDNNMemoryDescPool2d(kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0,
- true,
- bS, iC, iH, iW, oC, oH, oW, input, nullptr, output,
- algorithm,
- &pool_src_md, nullptr, &pool_dst_md, &user_src_md, nullptr,
- &user_dst_md,
- pool_strides, pool_kernel, pool_padding, pool_padding_r);
- auto pool_desc = pooling_forward::desc(prop_kind::forward_inference, algorithm, pool_src_md,
- pool_dst_md,
- pool_strides, pool_kernel, pool_padding, pool_padding_r);
- auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
- auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
- auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer());
- auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer());
- auto pool_src_memory = user_src_memory;
- dnnl::stream stream(engine);
- if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) {
- pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), engine);
- reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory);
- }
- auto pool_dst_memory = user_dst_memory;
- if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
- pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine);
- }
- pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory},
- {DNNL_ARG_DST, pool_dst_memory}});
- if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
- reorder(pool_dst_memory, user_dst_memory).execute(stream, pool_dst_memory, user_dst_memory);
- }
- stream.wait();
-
- //streams[0].submitAndWait();
-
- if (!isNCHW) {
- delete input;
- delete output;
- }
+ mkldnnUtils::poolingMKLDNN(input, output, 0,kH,kW, 0,sH,sW, 0,pH,pW, isNCHW, mode);
return Status::OK();
}
//////////////////////////////////////////////////////////////////////////
PLATFORM_CHECK(avgpool2d, ENGINE_CPU) {
+
auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0);
@@ -141,12 +83,10 @@ PLATFORM_CHECK(avgpool2d, ENGINE_CPU) {
//////////////////////////////////////////////////////////////////////////
PLATFORM_IMPL(avgpool2d_bp, ENGINE_CPU) {
- auto input = INPUT_VARIABLE(
- 0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
- auto gradO = INPUT_VARIABLE(
- 1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
- auto gradI = OUTPUT_VARIABLE(
- 0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
+
+ auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
+ auto gradO = INPUT_VARIABLE(1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
+ auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
int kH = INT_ARG(0); // filter(kernel) height
int kW = INT_ARG(1); // filter(kernel) width
@@ -156,92 +96,26 @@ PLATFORM_IMPL(avgpool2d_bp, ENGINE_CPU) {
int pW = INT_ARG(5); // paddings width
int dH = INT_ARG(6); // dilations height
int dW = INT_ARG(7); // dilations width
- int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME
+ int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME
int extraParam0 = INT_ARG(9);
- int isNCHW =
- block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC
+ int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC
- REQUIRE_TRUE(input->rankOf() == 4, 0,
- "AVGPOOL2D_BP op: input should have rank of 4, but got %i instead", input->rankOf());
- REQUIRE_TRUE(dH != 0 && dW != 0, 0,
- "AVGPOOL2D_BP op: dilation must not be zero, but got instead {%i, %i}", dH, dW);
+ REQUIRE_TRUE(input->rankOf() == 4, 0, "AVGPOOL2D_BP MKLDNN op: input should have rank of 4, but got %i instead", input->rankOf());
+ REQUIRE_TRUE(dH != 0 && dW != 0, 0, "AVGPOOL2D_BP MKLDNN op: dilation must not be zero, but got instead {%i, %i}", dH, dW);
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
- ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC,
- indIiH, indWiC, indWoC, indWkH, indOoH);
+ ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
- std::string expectedGradOShape = ShapeUtils::shapeAsString(
- ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, oH, oW, 0, indIOioC, indIiH, indIiH + 1}));
- std::string expectedGradIShape = ShapeUtils::shapeAsString(
- ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, iH, iW, 0, indIOioC, indIiH, indIiH + 1}));
- REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0,
- "AVGPOOL2D_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !",
- expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
- REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0,
- "AVGPOOL2D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !",
- expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str());
+ std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oH,oW, 0,indIOioC,indIiH,indIiH+1});
+ REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "AVGPOOL2D_BP MKLDNN op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
-
- if (!isNCHW) {
- input = new NDArray(input->permute(
- {0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
- gradI = new NDArray(gradI->permute(
- {0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
- gradO = new NDArray(gradO->permute(
- {0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW]
- }
-
- if (isSameMode) // SAME
+ if(paddingMode) // SAME
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
- auto poolingMode = PoolingType::AVG_POOL;
-
- dnnl_memory_desc_t empty;
- dnnl::memory::desc pool_src_md(empty), pool_diff_src_md(empty), pool_dst_md(empty);
- dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_dst_md(empty);
- dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
- dnnl::algorithm algorithm;
- mkldnnUtils::getMKLDNNMemoryDescPool2d(kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0,
- true,
- bS, iC, iH, iW, oC, oH, oW, input, gradI, gradO, algorithm,
- &pool_src_md, &pool_diff_src_md, &pool_dst_md, &user_src_md,
- &user_diff_src_md, &user_dst_md,
- pool_strides, pool_kernel, pool_padding, pool_padding_r);
- auto pool_desc = pooling_forward::desc(prop_kind::forward, algorithm,
- input->buffer() != nullptr ? pool_src_md : pool_diff_src_md,
- pool_dst_md, pool_strides, pool_kernel, pool_padding,
- pool_padding_r);
- auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
- auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
- auto poolB_desc = pooling_backward::desc(algorithm, pool_diff_src_md, pool_dst_md, pool_strides,
- pool_kernel, pool_padding, pool_padding_r);
- auto poolB_prim_desc = pooling_backward::primitive_desc(poolB_desc, engine, pool_prim_desc);
- auto userB_src_memory = dnnl::memory(user_src_md, engine, gradI->buffer());
- auto userB_dst_memory = dnnl::memory(user_dst_md, engine, gradO->buffer());
- auto poolB_src_memory = userB_src_memory;
- dnnl::stream stream(engine);
- if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
- poolB_src_memory = dnnl::memory(poolB_prim_desc.diff_src_desc(), engine);
- }
- auto poolB_dst_memory = userB_dst_memory;
- if (poolB_prim_desc.diff_dst_desc() != userB_dst_memory.get_desc()) {
- poolB_dst_memory = dnnl::memory(poolB_prim_desc.diff_dst_desc(), engine);
- reorder(userB_dst_memory, poolB_dst_memory).execute(stream, userB_dst_memory, poolB_dst_memory);
- }
- pooling_backward(poolB_prim_desc).execute(stream, {{DNNL_ARG_DIFF_DST, poolB_dst_memory},
- {DNNL_ARG_DIFF_SRC, poolB_src_memory}});
- if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
- reorder(poolB_src_memory, userB_src_memory).execute(stream, poolB_src_memory, userB_src_memory);
- }
- stream.wait();
-
- if (!isNCHW) {
- delete input;
- delete gradI;
- delete gradO;
- }
+ auto mode = (extraParam0 == 0) ? algorithm::pooling_avg_exclude_padding : algorithm::pooling_avg_include_padding;
+ mkldnnUtils::poolingBpMKLDNN(input, gradO, gradI, 0,kH,kW, 0,sH,sW, 0,pH,pW, isNCHW, mode);
return Status::OK();
}
diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling3d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling3d.cpp
index 2456625ef..39e85de98 100644
--- a/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling3d.cpp
+++ b/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling3d.cpp
@@ -17,6 +17,7 @@
//
// @author saudet
// @author raver119@gmail.com
+// @author Yurii Shyrma (iuriish@yahoo.com)
//
#include
@@ -29,113 +30,110 @@
using namespace dnnl;
-namespace nd4j {
- namespace ops {
- namespace platforms {
- PLATFORM_IMPL(avgpool3dnew, ENGINE_CPU) {
- auto input = INPUT_VARIABLE(
- 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
- auto output = OUTPUT_VARIABLE(
- 0); // [bS, oD, oH, oW, iC] (NDHWC) or [bS, iC, oD, oH, oW] (NCDHW)
+namespace nd4j {
+namespace ops {
+namespace platforms {
- int kD = INT_ARG(0); // filter(kernel) depth
- int kH = INT_ARG(1); // filter(kernel) height
- int kW = INT_ARG(2); // filter(kernel) width
- int sD = INT_ARG(3); // strides depth
- int sH = INT_ARG(4); // strides height
- int sW = INT_ARG(5); // strides width
- int pD = INT_ARG(6); // paddings depth
- int pH = INT_ARG(7); // paddings height
- int pW = INT_ARG(8); // paddings width
- int dD = INT_ARG(9); // dilations depth
- int dH = INT_ARG(10); // dilations height
- int dW = INT_ARG(11); // dilations width
- int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID
- int extraParam0 = INT_ARG(13); // unnecessary for max case, required only for avg and pnorm cases
- int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 1-NDHWC, 0-NCDHW
+//////////////////////////////////////////////////////////////////////
+PLATFORM_IMPL(avgpool3dnew, ENGINE_CPU) {
- REQUIRE_TRUE(input->rankOf() == 5, 0,
- "MAXPOOL3DNEW OP: rank of input array must be equal to 5, but got %i instead !",
- input->rankOf());
- REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0,
- "MAXPOOL3DNEW op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW);
+ auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
+ auto output = OUTPUT_VARIABLE(0); // [bS, oD, oH, oW, iC] (NDHWC) or [bS, iC, oD, oH, oW] (NCDHW)
- int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
- int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
- ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW,
- indIOioC, indIOioD, indWiC, indWoC, indWkD);
+ int kD = INT_ARG(0); // filter(kernel) depth
+ int kH = INT_ARG(1); // filter(kernel) height
+ int kW = INT_ARG(2); // filter(kernel) width
+ int sD = INT_ARG(3); // strides depth
+ int sH = INT_ARG(4); // strides height
+ int sW = INT_ARG(5); // strides width
+ int pD = INT_ARG(6); // paddings depth
+ int pH = INT_ARG(7); // paddings height
+ int pW = INT_ARG(8); // paddings width
+ int dD = INT_ARG(9); // dilations depth
+ int dH = INT_ARG(10); // dilations height
+ int dW = INT_ARG(11); // dilations width
+ int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID
+ int extraParam0 = INT_ARG(13);
+ int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 0-NCDHW, 1-NDHWC
- std::string expectedOutputShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx(
- {bS, iC, oD, oH, oW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2}));
- REQUIRE_TRUE(expectedOutputShape == ShapeUtils::shapeAsString(output), 0,
- "MAXPOOL3D op: wrong shape of output array, expected is %s, but got %s instead !",
- expectedOutputShape.c_str(), ShapeUtils::shapeAsString(output).c_str());
- // REQUIRE_TRUE(iD >= kD && iH >= kH && iW >= kW, 0, "MAXPOOL3D OP: the input depth/height/width must be greater or equal to kernel(filter) depth/height/width, but got [%i, %i, %i] and [%i, %i, %i] correspondingly !", iD,iH,iW, kD,kH,kW);
- // REQUIRE_TRUE(kD/2 >= pD && kH/2 >= pH && kW/2 >= pW, 0, "MAXPOOL3D OP: pad depth/height/width must not be greater than half of kernel depth/height/width, but got [%i, %i, %i] and [%i, %i, %i] correspondingly !", pD,pH,pW, kD,kH,kW);
+ REQUIRE_TRUE(input->rankOf() == 5, 0, "AVGPOOL3DNEW MKLDNN OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf());
+ REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "AVGPOOL3DNEW MKLDNN OP: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW);
- if (!isNCDHW) {
- input = new NDArray(
- input->permute({0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
- output = new NDArray(
- output->permute({0, 4, 1, 2, 3})); // [bS, oD, oH, oW, iC] -> [bS, iC, oD, oH, oW]
- }
+ int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
+ int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
+ ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
- if (isSameMode) // SAME
- ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
+ if(paddingMode) // SAME
+ ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
+
+ auto mode = (extraParam0 == 0) ? algorithm::pooling_avg_exclude_padding : algorithm::pooling_avg_include_padding;
+
+ mkldnnUtils::poolingMKLDNN(input, output, kD,kH,kW, sD,sH,sW, pD,pH,pW, isNCDHW, mode);
+
+ return Status::OK();
+}
+
+//////////////////////////////////////////////////////////////////////
+PLATFORM_CHECK(avgpool3dnew, ENGINE_CPU) {
+ auto input = INPUT_VARIABLE(0);
+ auto output = OUTPUT_VARIABLE(0);
+
+ return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output});
+}
+
+//////////////////////////////////////////////////////////////////////
+PLATFORM_IMPL(avgpool3dnew_bp, ENGINE_CPU) {
+
+ auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
+ auto gradO = INPUT_VARIABLE(1); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
+ auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon
+
+ const int kD = INT_ARG(0); // filter(kernel) depth
+ const int kH = INT_ARG(1); // filter(kernel) height
+ const int kW = INT_ARG(2); // filter(kernel) width
+ const int sD = INT_ARG(3); // strides depth
+ const int sH = INT_ARG(4); // strides height
+ const int sW = INT_ARG(5); // strides width
+ int pD = INT_ARG(6); // paddings depth
+ int pH = INT_ARG(7); // paddings height
+ int pW = INT_ARG(8); // paddings width
+ const int dD = INT_ARG(9); // dilations depth
+ const int dH = INT_ARG(10); // dilations height
+ const int dW = INT_ARG(11); // dilations width
+ const int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID
+ const int extraParam0 = INT_ARG(13); // defines which divisor to use when averaging
+ const int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 0-NCDHW, 1-NDHWC
+
+ REQUIRE_TRUE(input->rankOf() == 5, 0, "AVGPOOL3DNEW_BP MKLDNN op: input should have rank of 5, but got %i instead", input->rankOf());
+ REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "AVGPOOL3DNEW_BP MKLDNN op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW);
+
+ int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
+ int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
+ ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
+
+ std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2});
+ REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "AVGPOOL3DNEW_BP MKLDNN op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
+
+ if(paddingMode) // SAME
+ ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
+
+ auto mode = (extraParam0 == 0) ? algorithm::pooling_avg_exclude_padding : algorithm::pooling_avg_include_padding;
+
+ mkldnnUtils::poolingBpMKLDNN(input, gradO, gradI, kD,kH,kW, sD,sH,sW, pD,pH,pW, isNCDHW, mode);
+
+ return Status::OK();
+}
+
+//////////////////////////////////////////////////////////////////////
+PLATFORM_CHECK(avgpool3dnew_bp, ENGINE_CPU) {
+
+ auto input = INPUT_VARIABLE(0);
+ auto output = OUTPUT_VARIABLE(0);
+
+ return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output});
+}
- auto poolingMode = PoolingType::AVG_POOL;
-
- dnnl_memory_desc_t empty;
- dnnl::memory::desc pool_src_md(empty), pool_dst_md(empty);
- dnnl::memory::desc user_src_md(empty), user_dst_md(empty);
- dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
- dnnl::algorithm algorithm;
- mkldnnUtils::getMKLDNNMemoryDescPool3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode,
- extraParam0, true,
- bS, iC, iD, iH, iW, oC, oD, oH, oW, input, nullptr, output,
- algorithm,
- &pool_src_md, nullptr, &pool_dst_md, &user_src_md, nullptr,
- &user_dst_md,
- pool_strides, pool_kernel, pool_padding, pool_padding_r);
- auto pool_desc = pooling_forward::desc(prop_kind::forward_inference, algorithm, pool_src_md,
- pool_dst_md,
- pool_strides, pool_kernel, pool_padding, pool_padding_r);
- auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
- dnnl::stream stream(engine);
- auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
- auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer());
- auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer());
- auto pool_src_memory = user_src_memory;
- if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) {
- pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), engine);
- reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory);
- }
- auto pool_dst_memory = user_dst_memory;
- if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
- pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine);
- }
- pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory},
- {DNNL_ARG_DST, pool_dst_memory}});
- if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
- reorder(pool_dst_memory, user_dst_memory).execute(stream, pool_dst_memory, user_dst_memory);
- }
- stream.wait();
-
- if (!isNCDHW) {
- delete input;
- delete output;
- }
-
- return Status::OK();
- }
-
- PLATFORM_CHECK(avgpool3dnew, ENGINE_CPU) {
- auto input = INPUT_VARIABLE(0);
- auto output = OUTPUT_VARIABLE(0);
-
- return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output});
- }
- }
- }
+}
+}
}
\ No newline at end of file
diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling3d_bp.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling3d_bp.cpp
deleted file mode 100644
index 3fd8ab293..000000000
--- a/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling3d_bp.cpp
+++ /dev/null
@@ -1,154 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-//
-// @author raver119@gmail.com
-//
-
-#include
-#include
-#include
-
-#include
-#include "mkldnnUtils.h"
-#include
-
-using namespace dnnl;
-
-namespace nd4j {
- namespace ops {
- namespace platforms {
- PLATFORM_IMPL(avgpool3dnew_bp, ENGINE_CPU) {
- auto input = INPUT_VARIABLE(
- 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
- auto gradO = INPUT_VARIABLE(
- 1); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
- auto gradI = OUTPUT_VARIABLE(
- 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon
-
- const int kD = INT_ARG(0); // filter(kernel) depth
- const int kH = INT_ARG(1); // filter(kernel) height
- const int kW = INT_ARG(2); // filter(kernel) width
- const int sD = INT_ARG(3); // strides depth
- const int sH = INT_ARG(4); // strides height
- const int sW = INT_ARG(5); // strides width
- int pD = INT_ARG(6); // paddings depth
- int pH = INT_ARG(7); // paddings height
- int pW = INT_ARG(8); // paddings width
- const int dD = INT_ARG(9); // dilations depth
- const int dH = INT_ARG(10); // dilations height
- const int dW = INT_ARG(11); // dilations width
- const int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID
- int extraParam0 = INT_ARG(13); // unnecessary for max case, required only for avg and pnorm cases
- int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 1-NDHWC, 0-NCDHW
-
- REQUIRE_TRUE(input->rankOf() == 5, 0,
- "MAXPOOL3D_BP op: input should have rank of 5, but got %i instead", input->rankOf());
- REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0,
- "MAXPOOL3DNEW op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW);
-
- int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
- int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
- ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW,
- indIOioC, indIOioD, indWiC, indWoC, indWkD);
-
- std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx(
- {bS, iC, oD, oH, oW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2}));
- std::string expectedGradIShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx(
- {bS, iC, iD, iH, iW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2}));
- REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0,
- "MAXPOOL3D_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !",
- expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
- REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0,
- "MAXPOOL3D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !",
- expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str());
-
- if (!isNCDHW) {
- input = new NDArray(input->permute(
- {0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
- gradI = new NDArray(gradI->permute(
- {0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
- gradO = new NDArray(gradO->permute(
- {0, 4, 1, 2, 3})); // [bS, oD, oH, oW, iC] -> [bS, iC, oD, oH, oW]
- }
-
- if (isSameMode) // SAME
- ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
-
-
-
- auto poolingMode = PoolingType::AVG_POOL;
-
- dnnl_memory_desc_t empty;
- dnnl::memory::desc pool_src_md(empty), pool_diff_src_md(empty), pool_dst_md(empty);
- dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_dst_md(empty);
- dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
- dnnl::algorithm algorithm;
- mkldnnUtils::getMKLDNNMemoryDescPool3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode,
- extraParam0, true,
- bS, iC, iD, iH, iW, oC, oD, oH, oW, input, gradI, gradO,
- algorithm,
- &pool_src_md, &pool_diff_src_md, &pool_dst_md, &user_src_md,
- &user_diff_src_md, &user_dst_md,
- pool_strides, pool_kernel, pool_padding, pool_padding_r);
- if (input->buffer() == nullptr) {
- pool_src_md = pool_diff_src_md;
- user_src_md = user_diff_src_md;
- }
- auto pool_desc = pooling_forward::desc(prop_kind::forward, algorithm, pool_src_md, pool_dst_md,
- pool_strides, pool_kernel, pool_padding, pool_padding_r);
- auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
- dnnl::stream stream(engine);
- auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
- auto poolB_desc = pooling_backward::desc(algorithm, pool_diff_src_md, pool_dst_md, pool_strides,
- pool_kernel, pool_padding, pool_padding_r);
- auto poolB_prim_desc = pooling_backward::primitive_desc(poolB_desc, engine, pool_prim_desc);
- auto userB_src_memory = dnnl::memory(user_diff_src_md, engine, gradI->buffer());
- auto userB_dst_memory = dnnl::memory(user_dst_md, engine, gradO->buffer());
- auto poolB_src_memory = userB_src_memory;
- if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
- poolB_src_memory = dnnl::memory(poolB_prim_desc.diff_src_desc(), engine);
- }
- auto poolB_dst_memory = userB_dst_memory;
- if (poolB_prim_desc.diff_dst_desc() != userB_dst_memory.get_desc()) {
- poolB_dst_memory = dnnl::memory(poolB_prim_desc.diff_dst_desc(), engine);
- reorder(userB_dst_memory, poolB_dst_memory).execute(stream, userB_dst_memory, poolB_dst_memory);
- }
- pooling_backward(poolB_prim_desc).execute(stream, {{DNNL_ARG_DIFF_DST, poolB_dst_memory},
- {DNNL_ARG_DIFF_SRC, poolB_src_memory}});
- if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
- reorder(poolB_src_memory, userB_src_memory).execute(stream, poolB_src_memory, userB_src_memory);
- }
- stream.wait();
-
- if (!isNCDHW) {
- delete input;
- delete gradI;
- delete gradO;
- }
-
- return Status::OK();
- }
-
- PLATFORM_CHECK(avgpool3dnew_bp, ENGINE_CPU) {
- auto input = INPUT_VARIABLE(0);
- auto output = OUTPUT_VARIABLE(0);
-
- return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output});
- }
- }
- }
-}
\ No newline at end of file
diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/batchnorm.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/batchnorm.cpp
index 8974cef14..f63690e81 100644
--- a/libnd4j/include/ops/declarable/platform/mkldnn/batchnorm.cpp
+++ b/libnd4j/include/ops/declarable/platform/mkldnn/batchnorm.cpp
@@ -37,12 +37,12 @@ namespace platforms {
//////////////////////////////////////////////////////////////////////////
-static void batchnormMKLDNN(const NDArray* x, const NDArray* mean, const NDArray* variance, const NDArray* weights, const float epsilon, NDArray* z) {
+static void batchnormMKLDNN(const NDArray* x, const NDArray* mean, const NDArray* variance, const NDArray* weights, NDArray* z,
+ const float epsilon, const bool isNCHW) {
- // unfortunately mkl dnn doesn't support any format (dnnl::memory::format_tag::any)
- // also it gives wrong results for formats nhwc and ndhwc
+ // unfortunately mkl dnn doesn't support any format (dnnl::memory::format_tag::any) for x
- // x -> 2D:nc, 4D:nchw, 5D:ncdhw
+ // x -> 2D:nc, 4D:nchw/nhwc, 5D:ncdhw/ndhwc
// mean -> 1D [c]
// variance -> 1D [c]
// weights 2D [2, c], weights({0,1, 0,0}) contains gamma and weights({1,2, 0,0}) contains beta
@@ -50,8 +50,6 @@ static void batchnormMKLDNN(const NDArray* x, const NDArray* mean, const NDArray
const int xRank = x->rankOf();
- auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
-
// input type
dnnl::memory::data_type type = dnnl::memory::data_type::f32;
@@ -63,17 +61,28 @@ static void batchnormMKLDNN(const NDArray* x, const NDArray* mean, const NDArray
dnnl::memory::dims dims;
dnnl::memory::format_tag format;
+ const int indHW = isNCHW ? 2 : 1;
+ const int bS = x->sizeAt(0);
+ const int iC = isNCHW ? x->sizeAt(1) : x->sizeAt(-1);
+
+ int iD, iH, iW;
+
if(xRank == 2) {
- dims = {x->sizeAt(0), x->sizeAt(1)};
+ dims = {bS, iC};
format = dnnl::memory::format_tag::nc;
}
else if(xRank == 4) {
- dims = {x->sizeAt(0), x->sizeAt(1), x->sizeAt(2), x->sizeAt(3)};
- format = dnnl::memory::format_tag::nchw;
+ iH = x->sizeAt(indHW);
+ iW = x->sizeAt(indHW + 1);
+ dims = {bS, iC, iH, iW};
+ format = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
}
else { // xRank = 5
- dims = {x->sizeAt(0), x->sizeAt(1), x->sizeAt(2), x->sizeAt(3), x->sizeAt(4)};
- format = dnnl::memory::format_tag::ncdhw;
+ iD = x->sizeAt(indHW);
+ iH = x->sizeAt(indHW + 1);
+ iW = x->sizeAt(indHW + 2);
+ dims = {bS, iC, iD, iH, iW};
+ format = isNCHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc;
}
// memory descriptors for arrays
@@ -81,29 +90,34 @@ static void batchnormMKLDNN(const NDArray* x, const NDArray* mean, const NDArray
// x
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(dims, type, format);
dnnl::memory::desc x_user_md = dnnl::memory::desc(dims, type, format);
- x_user_md.data.format_kind = dnnl_blocked; // overrides format
- x_user_md.data.format_desc.blocking.strides[0] = x->stridesOf()[0];
- x_user_md.data.format_desc.blocking.strides[1] = x->stridesOf()[1];
- if(xRank > 2) {
- x_user_md.data.format_desc.blocking.strides[2] = x->stridesOf()[2];
- x_user_md.data.format_desc.blocking.strides[3] = x->stridesOf()[3];
+ if(x->ews() != 1 || x->ordering() != 'c') {
+ x_user_md.data.format_kind = dnnl_blocked; // overrides format
+ x_user_md.data.format_desc.blocking.strides[0] = x->strideAt(0);
+ x_user_md.data.format_desc.blocking.strides[1] = x->strideAt(1);
+ if(xRank > 2) {
+ x_user_md.data.format_desc.blocking.strides[2] = x->strideAt(2);
+ x_user_md.data.format_desc.blocking.strides[3] = x->strideAt(3);
+ }
+ if(xRank > 4)
+ x_user_md.data.format_desc.blocking.strides[4] = x->strideAt(4);
}
- if(xRank > 4)
- x_user_md.data.format_desc.blocking.strides[4] = x->stridesOf()[4];
// z, output
- dnnl::memory::desc z_mkl_md = dnnl::memory::desc(dims, type, format);
+ dnnl::memory::desc z_mkl_md = dnnl::memory::desc(dims, type, dnnl::memory::format_tag::any);
dnnl::memory::desc z_user_md = dnnl::memory::desc(dims, type, format);
- z_user_md.data.format_kind = dnnl_blocked; // overrides format
- z_user_md.data.format_desc.blocking.strides[0] = z->stridesOf()[0];
- z_user_md.data.format_desc.blocking.strides[1] = z->stridesOf()[1];
- if(xRank > 2) {
- z_user_md.data.format_desc.blocking.strides[2] = z->stridesOf()[2];
- z_user_md.data.format_desc.blocking.strides[3] = z->stridesOf()[3];
+ if(z->ews() != 1 || z->ordering() != 'c') {
+ z_user_md.data.format_kind = dnnl_blocked; // overrides format
+ z_user_md.data.format_desc.blocking.strides[0] = z->strideAt(0);
+ z_user_md.data.format_desc.blocking.strides[1] = z->strideAt(1);
+ if(xRank > 2) {
+ z_user_md.data.format_desc.blocking.strides[2] = z->strideAt(2);
+ z_user_md.data.format_desc.blocking.strides[3] = z->strideAt(3);
+ }
+ if(xRank > 4)
+ z_user_md.data.format_desc.blocking.strides[4] = z->strideAt(4);
}
- if(xRank > 4)
- z_user_md.data.format_desc.blocking.strides[4] = z->stridesOf()[4];
+ auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
// batchnorm forward description
dnnl::batch_normalization_forward::desc op_ff_desc(dnnl::prop_kind::forward_inference, x_mkl_md, epsilon, flags);
@@ -162,12 +176,11 @@ static void batchnormMKLDNN(const NDArray* x, const NDArray* mean, const NDArray
//////////////////////////////////////////////////////////////////////////
static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const NDArray* variance, const NDArray* dLdO, const NDArray* weights,
- const float epsilon, NDArray* dLdI, NDArray* dLdW) {
+ NDArray* dLdI, NDArray* dLdW, const float epsilon, const bool isNCHW) {
- // unfortunately mkl dnn doesn't support any format (dnnl::memory::format_tag::any)
- // also it gives wrong results for formats nhwc and ndhwc
+ // unfortunately mkl dnn doesn't support any format (dnnl::memory::format_tag::any) for x
- // x -> 2D:nc, 4D:nchw, 5D:ncdhw
+ // x -> 2D:nc, 4D:nchw/nhwc, 5D:ncdhw/ndhwc
// mean -> 1D [c]
// variance -> 1D [c]
// dLdO - same shape as x
@@ -177,8 +190,6 @@ static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const
const int xRank = x->rankOf();
- auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
-
// input type
dnnl::memory::data_type type = dnnl::memory::data_type::f32;
@@ -190,17 +201,28 @@ static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const
dnnl::memory::dims dims;
dnnl::memory::format_tag format;
+ const int indHW = isNCHW ? 2 : 1;
+ const int bS = x->sizeAt(0);
+ const int iC = isNCHW ? x->sizeAt(1) : x->sizeAt(-1);
+
+ int iD, iH, iW;
+
if(xRank == 2) {
- dims = {x->sizeAt(0), x->sizeAt(1)};
+ dims = {bS, iC};
format = dnnl::memory::format_tag::nc;
}
else if(xRank == 4) {
- dims = {x->sizeAt(0), x->sizeAt(1), x->sizeAt(2), x->sizeAt(3)};
- format = dnnl::memory::format_tag::nchw;
+ iH = x->sizeAt(indHW);
+ iW = x->sizeAt(indHW + 1);
+ dims = {bS, iC, iH, iW};
+ format = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
}
else { // xRank = 5
- dims = {x->sizeAt(0), x->sizeAt(1), x->sizeAt(2), x->sizeAt(3), x->sizeAt(4)};
- format = dnnl::memory::format_tag::ncdhw;
+ iD = x->sizeAt(indHW);
+ iH = x->sizeAt(indHW + 1);
+ iW = x->sizeAt(indHW + 2);
+ dims = {bS, iC, iD, iH, iW};
+ format = isNCHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc;
}
// memory descriptors for arrays
@@ -208,41 +230,49 @@ static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const
// x
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(dims, type, format);
dnnl::memory::desc x_user_md = dnnl::memory::desc(dims, type, format);
- x_user_md.data.format_kind = dnnl_blocked; // overrides format
- x_user_md.data.format_desc.blocking.strides[0] = x->stridesOf()[0];
- x_user_md.data.format_desc.blocking.strides[1] = x->stridesOf()[1];
- if(xRank > 2) {
- x_user_md.data.format_desc.blocking.strides[2] = x->stridesOf()[2];
- x_user_md.data.format_desc.blocking.strides[3] = x->stridesOf()[3];
+ if(x->ews() != 1 || x->ordering() != 'c') {
+ x_user_md.data.format_kind = dnnl_blocked; // overrides format
+ x_user_md.data.format_desc.blocking.strides[0] = x->strideAt(0);
+ x_user_md.data.format_desc.blocking.strides[1] = x->strideAt(1);
+ if(xRank > 2) {
+ x_user_md.data.format_desc.blocking.strides[2] = x->strideAt(2);
+ x_user_md.data.format_desc.blocking.strides[3] = x->strideAt(3);
+ }
+ if(xRank > 4)
+ x_user_md.data.format_desc.blocking.strides[4] = x->strideAt(4);
}
- if(xRank > 4)
- x_user_md.data.format_desc.blocking.strides[4] = x->stridesOf()[4];
// dLdO
- dnnl::memory::desc dLdO_mkl_md = dnnl::memory::desc(dims, type, format);
+ dnnl::memory::desc dLdO_mkl_md = dnnl::memory::desc(dims, type, dnnl::memory::format_tag::any);
dnnl::memory::desc dLdO_user_md = dnnl::memory::desc(dims, type, format);
- dLdO_user_md.data.format_kind = dnnl_blocked; // overrides format
- dLdO_user_md.data.format_desc.blocking.strides[0] = dLdO->stridesOf()[0];
- dLdO_user_md.data.format_desc.blocking.strides[1] = dLdO->stridesOf()[1];
- if(xRank > 2) {
- dLdO_user_md.data.format_desc.blocking.strides[2] = dLdO->stridesOf()[2];
- dLdO_user_md.data.format_desc.blocking.strides[3] = dLdO->stridesOf()[3];
+ if(dLdO->ews() != 1 || dLdO->ordering() != 'c') {
+ dLdO_user_md.data.format_kind = dnnl_blocked; // overrides format
+ dLdO_user_md.data.format_desc.blocking.strides[0] = dLdO->strideAt(0);
+ dLdO_user_md.data.format_desc.blocking.strides[1] = dLdO->strideAt(1);
+ if(xRank > 2) {
+ dLdO_user_md.data.format_desc.blocking.strides[2] = dLdO->strideAt(2);
+ dLdO_user_md.data.format_desc.blocking.strides[3] = dLdO->strideAt(3);
+ }
+ if(xRank > 4)
+ dLdO_user_md.data.format_desc.blocking.strides[4] = dLdO->strideAt(4);
}
- if(xRank > 4)
- dLdO_user_md.data.format_desc.blocking.strides[4] = dLdO->stridesOf()[4];
// dLdI
- dnnl::memory::desc dLdI_mkl_md = dnnl::memory::desc(dims, type, format);
+ dnnl::memory::desc dLdI_mkl_md = dnnl::memory::desc(dims, type, dnnl::memory::format_tag::any);
dnnl::memory::desc dLdI_user_md = dnnl::memory::desc(dims, type, format);
- dLdI_user_md.data.format_kind = dnnl_blocked; // overrides format
- dLdI_user_md.data.format_desc.blocking.strides[0] = dLdI->stridesOf()[0];
- dLdI_user_md.data.format_desc.blocking.strides[1] = dLdI->stridesOf()[1];
- if(xRank > 2) {
- dLdI_user_md.data.format_desc.blocking.strides[2] = dLdI->stridesOf()[2];
- dLdI_user_md.data.format_desc.blocking.strides[3] = dLdI->stridesOf()[3];
+ if(dLdI->ews() != 1 || dLdI->ordering() != 'c') {
+ dLdI_user_md.data.format_kind = dnnl_blocked; // overrides format
+ dLdI_user_md.data.format_desc.blocking.strides[0] = dLdI->strideAt(0);
+ dLdI_user_md.data.format_desc.blocking.strides[1] = dLdI->strideAt(1);
+ if(xRank > 2) {
+ dLdI_user_md.data.format_desc.blocking.strides[2] = dLdI->strideAt(2);
+ dLdI_user_md.data.format_desc.blocking.strides[3] = dLdI->strideAt(3);
+ }
+ if(xRank > 4)
+ dLdI_user_md.data.format_desc.blocking.strides[4] = dLdI->strideAt(4);
}
- if(xRank > 4)
- dLdI_user_md.data.format_desc.blocking.strides[4] = dLdI->stridesOf()[4];
+
+ auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
// batchnorm forward description
dnnl::batch_normalization_forward::desc op_ff_desc(dnnl::prop_kind::forward_inference, x_mkl_md, epsilon, flags);
@@ -331,7 +361,7 @@ static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const
// dLdI = dfdm / N + (2/N) * dfdv * (dvdm/2 + (x - m))
// dLdI = gamma * ( stdInv * -g_sum/N + (2/N) * dfdv * (dvdm/2 + (x - m)) )
- std::vector<int> axes = {1};
+ std::vector<int> axes = isNCHW ? std::vector<int>{1} : std::vector<int>{xRank - 1};
const auto excludedAxes = ShapeUtils::evalDimsToExclude(x->rankOf(), axes);
// inversed batch size 1 / N
@@ -377,7 +407,7 @@ static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const
PLATFORM_IMPL(batchnorm, ENGINE_CPU) {
- auto input = INPUT_VARIABLE(0); // 2D:nc, 4D:nchw, 5D:ncdhw
+ auto input = INPUT_VARIABLE(0); // 2D:nc, 4D:nchw/nhwc, 5D:ncdhw/ndhwc
auto mean = INPUT_VARIABLE(1); // [c]
auto variance = INPUT_VARIABLE(2); // [c]
NDArray* gamma = nullptr; // [c]
@@ -436,31 +466,19 @@ PLATFORM_IMPL(batchnorm, ENGINE_CPU) {
(*weights)({1,2, 0,0}).assign(0);
}
- if(axes[0] == inRank - 1 && inRank > 2) { // if nhwc or ndhwc
- std::vector<int> permut = inRank == 4 ? std::vector<int>({0,3,1,2}) : std::vector<int>({0,4,1,2,3});
- input = new NDArray(input->permute(permut));
- output = new NDArray(output->permute(permut));
- }
+ const bool isNCHW = !(axes[0] == inRank - 1 && inRank > 2);
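+ // normalizing over the last axis of a rank-4/5 input implies NHWC/NDHWC data; batchnormMKLDNN now handles that layout directly, so no permute to NCHW is needed here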
- batchnormMKLDNN(input, mean, variance, weights, epsilon, output);
+ batchnormMKLDNN(input, mean, variance, weights, output, epsilon, isNCHW);
delete weights;
- if(axes[0] == inRank - 1 && inRank > 2) {
- delete input;
- delete output;
- }
-
return Status::OK();
}
//////////////////////////////////////////////////////////////////////////
PLATFORM_CHECK(batchnorm, ENGINE_CPU) {
- // we don't want to use mkldnn if cpu doesn't support avx/avx2
- // if (::optimalLevel() < 2)
- // return false;
- auto input = INPUT_VARIABLE(0); // 2D:nc, 4D:nchw, 5D:ncdhw
+ auto input = INPUT_VARIABLE(0); // 2D:nc, 4D:nchw/nhwc, 5D:ncdhw/ndhwc
auto mean = INPUT_VARIABLE(1); // [c]
auto variance = INPUT_VARIABLE(2); // [c]
NDArray* gamma = nullptr; // [c]
@@ -634,7 +652,7 @@ PLATFORM_CHECK(batchnorm, ENGINE_CPU) {
//////////////////////////////////////////////////////////////////////////
PLATFORM_IMPL(batchnorm_bp, ENGINE_CPU) {
- NDArray* input = INPUT_VARIABLE(0); // 2D:nc, 4D:nchw, 5D:ncdhw
+ NDArray* input = INPUT_VARIABLE(0); // 2D:nc, 4D:nchw/nhwc, 5D:ncdhw/ndhwc
NDArray* mean = INPUT_VARIABLE(1); // [c]
NDArray* variance = INPUT_VARIABLE(2); // [c]
NDArray* gamma = nullptr; // [c]
@@ -702,15 +720,9 @@ PLATFORM_IMPL(batchnorm_bp, ENGINE_CPU) {
(*weights)({1,2, 0,0}).assign(0);
}
+ const bool isNCHW = !(axes[0] == inRank - 1 && inRank > 2);
- if(axes[0] == inRank - 1 && inRank > 2) { // if nhwc or ndhwc
- std::vector<int> permut = inRank == 4 ? std::vector<int>({0,3,1,2}) : std::vector<int>({0,4,1,2,3});
- input = new NDArray(input->permute(permut));
- dLdO = new NDArray(dLdO->permute(permut));
- dLdI = new NDArray(dLdI->permute(permut));
- }
-
- batchnormBackPropMKLDNN(input, mean, variance, dLdO, weights, epsilon, dLdI, dLdW);
+ batchnormBackPropMKLDNN(input, mean, variance, dLdO, weights, dLdI, dLdW, epsilon, isNCHW);
*dLdM = 0;
*dLdV = 0;
@@ -725,17 +737,12 @@ PLATFORM_IMPL(batchnorm_bp, ENGINE_CPU) {
delete dLdW;
}
- if(axes[0] == inRank - 1 && inRank > 2) {
- delete input;
- delete dLdO;
- delete dLdI;
- }
-
return Status::OK();
}
//////////////////////////////////////////////////////////////////////////
PLATFORM_CHECK(batchnorm_bp, ENGINE_CPU) {
+
NDArray* input = INPUT_VARIABLE(0); // 2D:nc, 4D:nchw, 5D:ncdhw
NDArray* mean = INPUT_VARIABLE(1); // [c]
NDArray* variance = INPUT_VARIABLE(2); // [c]
diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/conv2d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/conv2d.cpp
index 559edf2cd..2d88a73ef 100644
--- a/libnd4j/include/ops/declarable/platform/mkldnn/conv2d.cpp
+++ b/libnd4j/include/ops/declarable/platform/mkldnn/conv2d.cpp
@@ -17,6 +17,7 @@
//
// @author saudet
// @author raver119@gmail.com
+// @author Yurii Shyrma (iuriish@yahoo.com)
//
#include
@@ -33,6 +34,298 @@ namespace nd4j {
namespace ops {
namespace platforms {
+//////////////////////////////////////////////////////////////////////
+static void conv2dMKLDNN(const NDArray *input, const NDArray *weights,
+ const NDArray *bias, NDArray *output,
+ const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW,
+ const int paddingMode, const int isNCHW) {
+
+ // weights [kH, kW, iC, oC], we'll perform a permutation since mkl supports [oC, iC, kH, kW]
+
+ int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
+ int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
+ ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
+
+ const int pWSame = (paddingMode == 2 && dW > 1) ? ((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 : pW; // dH == 1 for causal mode in conv1d
+
+ dnnl::memory::dims strides = { sH, sW };
+ dnnl::memory::dims padding = { pH, pW };
+ dnnl::memory::dims padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pWSame };
+ dnnl::memory::dims dilation = { dH-1, dW-1};
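+ // note: oneDNN counts dilation as the number of extra gaps between kernel elements (0 == dense kernel), hence dH-1/dW-1 here;
+ // padding_r is the bottom/right padding the primitive needs so that oH/oW are produced from iH/iW given the top/left padding above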
+
+ auto xzFrmat = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
+ dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::oihw;
+
+ dnnl::memory::dims xDims = {bS, iC, iH, iW};
+ dnnl::memory::dims wDims = {oC, iC, kH, kW};
+ dnnl::memory::dims zDims = {bS, oC, oH, oW};
+
+ auto type = dnnl::memory::data_type::f32;
+
+ // memory descriptors for arrays
+
+ // input
+ dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any);
+ dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFrmat);
+ if(input->ews() != 1 || input->ordering() != 'c') {
+ x_user_md.data.format_kind = dnnl_blocked; // overrides format
+ x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0);
+ x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(1);
+ x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(2);
+ x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(3);
+ }
+
+ // weights
+ dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any);
+ dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, type, wFormat);
+ w_user_md.data.format_kind = dnnl_blocked; // overrides format
+ w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(3); // permute [kH, kW, iC, oC] -> [oC, iC, kH, kW]
+ w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(2);
+ w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(0);
+ w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(1);
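+ // the user weights descriptor keeps the original [kH, kW, iC, oC] buffer and merely reinterprets it via strides as [oC, iC, kH, kW];
+ // the reorder below copies it into oneDNN's preferred layout only if required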
+
+ // bias
+ dnnl::memory::desc b_mkl_md;
+ if(bias != nullptr)
+ b_mkl_md = dnnl::memory::desc({oC}, type, dnnl::memory::format_tag::x);
+
+ // output
+ dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any);
+ dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, type, xzFrmat);
+ if(output->ews() != 1 || output->ordering() != 'c') {
+ z_user_md.data.format_kind = dnnl_blocked; // overrides format
+ z_user_md.data.format_desc.blocking.strides[0] = output->strideAt(0);
+ z_user_md.data.format_desc.blocking.strides[1] = output->strideAt(1);
+ z_user_md.data.format_desc.blocking.strides[2] = output->strideAt(2);
+ z_user_md.data.format_desc.blocking.strides[3] = output->strideAt(3);
+ }
+
+ auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
+
+ // operation primitive description
+ dnnl::convolution_forward::desc op_desc(dnnl::prop_kind::forward_inference, dnnl::algorithm::convolution_auto, x_mkl_md, w_mkl_md, b_mkl_md, z_mkl_md, strides, dilation, padding, padding_r);
+ dnnl::convolution_forward::primitive_desc op_prim_desc(op_desc, engine);
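+ // convolution_auto lets oneDNN pick the concrete algorithm (e.g. direct or Winograd) best suited to these shapes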
+
+ // arguments (memory buffers) necessary for calculations
+ std::unordered_map<int, dnnl::memory> args;
+
+ dnnl::stream stream(engine);
+
+ // provide memory buffers and check whether reorder is required
+
+ // input
+ auto x_user_mem = dnnl::memory(x_user_md, engine, input->getBuffer());
+ const bool xReorder = op_prim_desc.src_desc() != x_user_mem.get_desc();
+ auto x_mkl_mem = xReorder ? dnnl::memory(op_prim_desc.src_desc(), engine) : x_user_mem;
+ if (xReorder)
+ dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
+ args[DNNL_ARG_SRC] = x_mkl_mem;
+
+ // weights
+ auto w_user_mem = dnnl::memory(w_user_md, engine, weights->getBuffer());
+ const bool wReorder = op_prim_desc.weights_desc() != w_user_mem.get_desc();
+ auto w_mkl_mem = wReorder ? dnnl::memory(op_prim_desc.weights_desc(), engine) : w_user_mem;
+ if (wReorder)
+ dnnl::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem);
+ args[DNNL_ARG_WEIGHTS] = w_mkl_mem;
+
+ // bias
+ if(bias != nullptr) {
+ auto b_mkl_mem = dnnl::memory(b_mkl_md, engine, bias->getBuffer());
+ args[DNNL_ARG_BIAS] = b_mkl_mem;
+ }
+
+ // output
+ auto z_user_mem = dnnl::memory(z_user_md, engine, output->getBuffer());
+ const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc();
+ auto z_mkl_mem = zReorder ? dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem;
+ args[DNNL_ARG_DST] = z_mkl_mem;
+
+ // run calculations
+ dnnl::convolution_forward(op_prim_desc).execute(stream, args);
+
+ // reorder outputs if necessary
+ if (zReorder)
+ dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem);
+
+ stream.wait();
+ // shape::printArray(z_mkl_mem.map_data(),8);
+}
+
+//////////////////////////////////////////////////////////////////////
+static void conv2dBpMKLDNN(const NDArray *input, const NDArray *weights, const NDArray *bias, const NDArray *gradO,
+ NDArray *gradI, NDArray *gradW, NDArray *gradB,
+ const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW,
+ const int paddingMode, const int isNCHW) {
+
+ // weights/gradW [kH, kW, iC, oC], we'll perform a permutation since mkl supports [oC, iC, kH, kW]
+
+ int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
+ int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
+ ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
+
+ const int pWSame = (paddingMode == 2 && dW > 1) ? ((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 : pW; // dH == 1 for causal mode in conv1d
+
+ dnnl::memory::dims strides = { sH, sW };
+ dnnl::memory::dims padding = { pH, pW };
+ dnnl::memory::dims padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pWSame };
+ dnnl::memory::dims dilation = { dH-1, dW-1};
+
+ auto xzFrmat = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
+ dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::oihw;
+
+ dnnl::memory::dims xDims = {bS, iC, iH, iW};
+ dnnl::memory::dims wDims = {oC, iC, kH, kW};
+ dnnl::memory::dims zDims = {bS, oC, oH, oW};
+
+ auto type = dnnl::memory::data_type::f32;
+
+ // memory descriptors for arrays
+
+ // input
+ dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any);
+ dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFrmat);
+ if(input->ews() != 1 || input->ordering() != 'c') {
+ x_user_md.data.format_kind = dnnl_blocked; // overrides format
+ x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0);
+ x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(1);
+ x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(2);
+ x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(3);
+ }
+
+ // weights
+ dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any);
+ dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, type, wFormat);
+ w_user_md.data.format_kind = dnnl_blocked; // overrides format
+ w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(3); // permute [kH, kW, iC, oC] -> [oC, iC, kH, kW]
+ w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(2);
+ w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(0);
+ w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(1);
+
+ // gradO
+ dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any);
+ dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, type, xzFrmat);
+ if(gradO->ews() != 1 || gradO->ordering() != 'c') {
+ gradO_user_md.data.format_kind = dnnl_blocked; // overrides format
+ gradO_user_md.data.format_desc.blocking.strides[0] = gradO->strideAt(0);
+ gradO_user_md.data.format_desc.blocking.strides[1] = gradO->strideAt(1);
+ gradO_user_md.data.format_desc.blocking.strides[2] = gradO->strideAt(2);
+ gradO_user_md.data.format_desc.blocking.strides[3] = gradO->strideAt(3);
+ }
+
+ // gradI
+ dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any);
+ dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, type, xzFrmat);
+ if(gradI->ews() != 1 || gradI->ordering() != 'c') {
+ gradI_user_md.data.format_kind = dnnl_blocked; // overrides format
+ gradI_user_md.data.format_desc.blocking.strides[0] = gradI->strideAt(0);
+ gradI_user_md.data.format_desc.blocking.strides[1] = gradI->strideAt(1);
+ gradI_user_md.data.format_desc.blocking.strides[2] = gradI->strideAt(2);
+ gradI_user_md.data.format_desc.blocking.strides[3] = gradI->strideAt(3);
+ }
+
+ // gradW
+ dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any);
+ dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, type, wFormat);
+ gradW_user_md.data.format_kind = dnnl_blocked; // overrides format
+ gradW_user_md.data.format_desc.blocking.strides[0] = gradW->strideAt(3); // permute [kH, kW, iC, oC] -> [oC, iC, kH, kW]
+ gradW_user_md.data.format_desc.blocking.strides[1] = gradW->strideAt(2);
+ gradW_user_md.data.format_desc.blocking.strides[2] = gradW->strideAt(0);
+ gradW_user_md.data.format_desc.blocking.strides[3] = gradW->strideAt(1);
+
+ // gradB
+ dnnl::memory::desc gradB_mkl_md;
+ if(gradB != nullptr)
+ gradB_mkl_md = dnnl::memory::desc({oC}, type, dnnl::memory::format_tag::x);
+
+ auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
+
+ // forward primitive description
+ dnnl::convolution_forward::desc op_ff_desc(dnnl::prop_kind::forward_inference, dnnl::algorithm::convolution_auto, x_mkl_md, w_mkl_md, gradB_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r);
+ dnnl::convolution_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine);
+
+ // backward data primitive description
+ dnnl::convolution_backward_data::desc op_data_bp_desc(dnnl::algorithm::convolution_auto, gradI_mkl_md, w_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r);
+ dnnl::convolution_backward_data::primitive_desc op_data_bp_prim_desc(op_data_bp_desc, engine, op_ff_prim_desc);
+
+ // backward weights primitive description
+ dnnl::convolution_backward_weights::desc op_weights_bp_desc(dnnl::algorithm::convolution_auto, x_mkl_md, gradW_mkl_md, gradB_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r);
+ dnnl::convolution_backward_weights::primitive_desc op_weights_bp_prim_desc(op_weights_bp_desc, engine, op_ff_prim_desc);
+
+ // arguments (memory buffers) necessary for calculations
+ std::unordered_map<int, dnnl::memory> args;
+
+ dnnl::stream stream(engine);
+
+ // provide memory buffers and check whether reorder is required
+
+ // input
+ auto x_user_mem = dnnl::memory(x_user_md, engine, input->getBuffer());
+ const bool xReorder = op_weights_bp_prim_desc.src_desc() != x_user_mem.get_desc();
+ auto x_mkl_mem = xReorder ? dnnl::memory(op_weights_bp_prim_desc.src_desc(), engine) : x_user_mem;
+ if (xReorder)
+ dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
+ args[DNNL_ARG_SRC] = x_mkl_mem;
+
+ // weights
+ auto w_user_mem = dnnl::memory(w_user_md, engine, weights->getBuffer());
+ const bool wReorder = op_data_bp_prim_desc.weights_desc() != w_user_mem.get_desc();
+ auto w_mkl_mem = wReorder ? dnnl::memory(op_data_bp_prim_desc.weights_desc(), engine) : w_user_mem;
+ if (wReorder)
+ dnnl::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem);
+ args[DNNL_ARG_WEIGHTS] = w_mkl_mem;
+
+ // gradO
+ auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer());
+ const bool gradOReorderW = op_weights_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc();
+ const bool gradOReorderD = op_data_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc();
+ auto gradO_mkl_memW = gradOReorderW ? dnnl::memory(op_weights_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem;
+ auto gradO_mkl_memD = gradOReorderD ? dnnl::memory(op_data_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem;
+ if (gradOReorderW)
+ dnnl::reorder(gradO_user_mem, gradO_mkl_memW).execute(stream, gradO_user_mem, gradO_mkl_memW);
+ if (gradOReorderD)
+ dnnl::reorder(gradO_user_mem, gradO_mkl_memD).execute(stream, gradO_user_mem, gradO_mkl_memD);
+ args[DNNL_ARG_DIFF_DST] = gradO_mkl_memD;
+
+ // gradI
+ auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->getBuffer());
+ const bool gradIReorder = op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc();
+ auto gradI_mkl_mem = gradIReorder ? dnnl::memory(op_data_bp_prim_desc.diff_src_desc(), engine) : gradI_user_mem;
+ args[DNNL_ARG_DIFF_SRC] = gradI_mkl_mem;
+
+ // gradW
+ auto gradW_user_mem = dnnl::memory(gradW_user_md, engine, gradW->getBuffer());
+ const bool gradWReorder = op_weights_bp_prim_desc.diff_weights_desc() != gradW_user_mem.get_desc();
+ auto gradW_mkl_mem = gradWReorder ? dnnl::memory(op_weights_bp_prim_desc.diff_weights_desc(), engine) : gradW_user_mem;
+ args[DNNL_ARG_DIFF_WEIGHTS] = gradW_mkl_mem;
+
+ // gradB
+ if(gradB != nullptr) {
+ auto gradB_mkl_mem = dnnl::memory(gradB_mkl_md, engine, gradB->getBuffer());
+ args[DNNL_ARG_DIFF_BIAS] = gradB_mkl_mem;
+ }
+
+ // run backward data calculations
+ dnnl::convolution_backward_data(op_data_bp_prim_desc).execute(stream, args);
+
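+ // backward-data and backward-weights primitives may prefer different gradO layouts, hence the two separately reordered buffers and this swap before the weights pass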
+ if(gradOReorderW || gradOReorderD)
+ args[DNNL_ARG_DIFF_DST] = gradO_mkl_memW;
+
+ // run backward weights calculations
+ dnnl::convolution_backward_weights(op_weights_bp_prim_desc).execute(stream, args);
+
+ // reorder gradI if necessary
+ if (gradIReorder)
+ dnnl::reorder(gradI_mkl_mem, gradI_user_mem).execute(stream, gradI_mkl_mem, gradI_user_mem);
+ if (gradWReorder)
+ dnnl::reorder(gradW_mkl_mem, gradW_user_mem).execute(stream, gradW_mkl_mem, gradW_user_mem);
+
+ stream.wait();
+
+ // shape::printArray(z_mkl_mem.map_data(),8);
+}
+
+/*
//////////////////////////////////////////////////////////////////////
static void conv2dMKLDNN(nd4j::graph::Context &block, const NDArray *input, const NDArray *weights,
const NDArray *bias, NDArray *output, const int kH, const int kW, const int sH,
@@ -46,37 +339,37 @@ static void conv2dMKLDNN(nd4j::graph::Context &block, const NDArray *input, cons
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode);
dnnl_memory_desc_t empty;
- dnnl::memory::desc conv_src_md(empty), conv_weights_md(empty), conv_bias_md(empty), conv_dst_md(empty);
- dnnl::memory::desc user_src_md(empty), user_weights_md(empty), user_bias_md(empty), user_dst_md(empty);
+ dnnl::memory::desc x_mkl_md(empty), w_mkl_md(empty), b_mkl_md(empty), z_mkl_md(empty);
+ dnnl::memory::desc x_user_md(empty), w_user_md(empty), b_user_md(empty), z_user_md(empty);
- dnnl::memory::dims conv_strides, conv_padding, conv_padding_r, conv_dilation;
+ dnnl::memory::dims strides, padding, padding_r, dilation;
mkldnnUtils::getMKLDNNMemoryDescConv2d(kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW,
bS, iC, iH, iW, oC, oH, oW, input, nullptr, weights, nullptr,
bias, output,
- &conv_src_md, nullptr, &conv_weights_md, nullptr,
- &conv_bias_md, &conv_dst_md,
- &user_src_md, nullptr, &user_weights_md, nullptr,
- &user_bias_md, &user_dst_md,
- conv_strides, conv_padding, conv_padding_r, conv_dilation);
+ &x_mkl_md, nullptr, &w_mkl_md, nullptr,
+ &b_mkl_md, &z_mkl_md,
+ &x_user_md, nullptr, &w_user_md, nullptr,
+ &b_user_md, &z_user_md,
+ strides, padding, padding_r, dilation);
auto conv_desc = bias != nullptr ? convolution_forward::desc(prop_kind::forward,
- algorithm::convolution_auto, conv_src_md,
- conv_weights_md, conv_bias_md,
- conv_dst_md, conv_strides, conv_dilation, conv_padding,
- conv_padding_r)
+ algorithm::convolution_auto, x_mkl_md,
+ w_mkl_md, b_mkl_md,
+ z_mkl_md, strides, dilation, padding,
+ padding_r)
: convolution_forward::desc(prop_kind::forward,
- algorithm::convolution_auto, conv_src_md,
- conv_weights_md,
- conv_dst_md, conv_strides, conv_dilation, conv_padding,
- conv_padding_r);
+ algorithm::convolution_auto, x_mkl_md,
+ w_mkl_md,
+ z_mkl_md, strides, dilation, padding,
+ padding_r);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
dnnl::stream stream(engine);
auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, engine);
- auto user_src_memory = dnnl::memory(user_src_md, engine, const_cast<NDArray *>(input)->buffer());
- auto user_weights_memory = dnnl::memory(user_weights_md, engine,
+ auto user_src_memory = dnnl::memory(x_user_md, engine, const_cast<NDArray *>(input)->buffer());
+ auto user_weights_memory = dnnl::memory(w_user_md, engine,
const_cast<NDArray *>(weights)->buffer());
- auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer());
+ auto user_dst_memory = dnnl::memory(z_user_md, engine, output->buffer());
auto conv_src_memory = user_src_memory;
if (conv_prim_desc.src_desc() != user_src_memory.get_desc()) {
conv_src_memory = dnnl::memory(conv_prim_desc.src_desc(), engine);
@@ -239,13 +532,16 @@ static void conv2dBpMKLDNN(nd4j::graph::Context &block,
}
}
+*/
+
//////////////////////////////////////////////////////////////////////
PLATFORM_IMPL(conv2d, ENGINE_CPU) {
- auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
- auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always
- auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC]
- auto output = OUTPUT_VARIABLE(0); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW)
+ auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
+ auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always
+ auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC]
+
+ auto output = OUTPUT_VARIABLE(0); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW)
int sH = INT_ARG(2); // strides height
int sW = INT_ARG(3); // strides width
@@ -254,21 +550,29 @@ PLATFORM_IMPL(conv2d, ENGINE_CPU) {
int dH = INT_ARG(6); // dilations height
int dW = INT_ARG(7); // dilations width
int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME
- bool isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC
+ bool isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC
int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0)); // filter(kernel) height
int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(weights->sizeAt(1)); // filter(kernel) width
- conv2dMKLDNN(block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW);
+ int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
+ int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
+ ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
+
+ ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode);
+
+ std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, iC, oC};
+ REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CONV2D MKLDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
+ if (bias)
+ REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CONV2D MKLDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
+
+ conv2dMKLDNN(input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW);
return Status::OK();
}
-PLATFORM_CHECK(conv2d, ENGINE_CPU) {
- // we don't want to use mkldnn if cpu doesn't support avx/avx2
- if (::optimalLevel() < 2)
- return false;
+PLATFORM_CHECK(conv2d, ENGINE_CPU) {
auto input = INPUT_VARIABLE(0);
auto weights = INPUT_VARIABLE(1);
@@ -280,10 +584,10 @@ PLATFORM_CHECK(conv2d, ENGINE_CPU) {
//////////////////////////////////////////////////////////////////////
PLATFORM_IMPL(conv2d_bp, ENGINE_CPU) {
- auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
+ auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always
- auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
- auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
+ auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
+ auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
auto gradW = OUTPUT_VARIABLE(1); // [kH, kW, iC, oC] always
@@ -297,19 +601,33 @@ PLATFORM_IMPL(conv2d_bp, ENGINE_CPU) {
int pW = INT_ARG(5); // paddings width
int dH = INT_ARG(6); // dilations height
int dW = INT_ARG(7); // dilations width
- int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME
- int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC
+ int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME
+ int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC
- REQUIRE_TRUE(input->rankOf() == 4, 0,"CUSTOM CONV2D_BP OP: rank of input array must be equal to 4, but got %i instead !",input->rankOf());
- REQUIRE_TRUE(weights->rankOf() == 4, 0,"CUSTOM CONV2D_BP OP: rank of weights array must be equal to 4, but got %i instead !",weights->rankOf());
- REQUIRE_TRUE(gradO->rankOf() == 4, 0,"CUSTOM CONV2D_BP OP: rank of output's gradients (next epsilon) array must be equal to 4, but got %i instead !",gradO->rankOf());
+ int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
+ int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
+ ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
- conv2dBpMKLDNN(block, input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW);
+ int trueoH, trueoW; // true output height, width
+ ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, paddingMode);
+
+ if(paddingMode) // SAME
+ ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode);
+
+ std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1});
+ std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, iC, oC};
+ REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CONV2D_BP MKLDNN OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
+ REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CONV2D_BP MKLDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
+ if(bias)
+ REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CONV2D_BP MKLDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
+
+ conv2dBpMKLDNN(input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW);
return Status::OK();
}
PLATFORM_CHECK(conv2d_bp, ENGINE_CPU) {
+
auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always
auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/conv3d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/conv3d.cpp
index 747d84c36..7c10b0d1e 100644
--- a/libnd4j/include/ops/declarable/platform/mkldnn/conv3d.cpp
+++ b/libnd4j/include/ops/declarable/platform/mkldnn/conv3d.cpp
@@ -33,6 +33,314 @@ namespace nd4j {
namespace ops {
namespace platforms {
+//////////////////////////////////////////////////////////////////////
+static void conv3dMKLDNN(const NDArray *input, const NDArray *weights,
+ const NDArray *bias, NDArray *output,
+ const int kD, const int kH, const int kW,
+ const int sD, const int sH, const int sW,
+ const int pD, const int pH, const int pW,
+ const int dD, const int dH, const int dW,
+ const int paddingMode, const int isNCDHW) {
+
+ // weights [kD, kH, kW, iC, oC], we'll perform permutation since mkl supports [oC, iC, kD, kH, kW]
+
+ int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
+ int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
+ ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
+
+ // const int pWSame = (paddingMode == 2 && dW > 1) ? ((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 : pW; // dH == 1 for causal mode in conv1d
+
+ dnnl::memory::dims strides = {sD, sH, sW};
+ dnnl::memory::dims padding = {pD, pH, pW};
+ // dnnl::memory::dims padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pWSame };
+ dnnl::memory::dims padding_r = {(oD - 1) * sD - iD + kD - pD, (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pW};
+ dnnl::memory::dims dilation = {dD-1, dH-1, dW-1};
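+ // note: oneDNN dilations are zero-based (dD-1 means dD-1 skipped elements) and padding_r is the
+ // right/bottom padding required for the primitive to reproduce oD/oH/oW exactly,
+ // e.g. iD=4, kD=2, sD=1, dD=1, pD=0 gives oD=3 and padding_r depth = (3-1)*1 - 4 + 2 - 0 = 0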
+
+ auto xzFrmat = isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc;
+ dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::oidhw;
+
+ dnnl::memory::dims xDims = {bS, iC, iD, iH, iW};
+ dnnl::memory::dims wDims = {oC, iC, kD, kH, kW};
+ dnnl::memory::dims zDims = {bS, oC, oD, oH, oW};
+
+ auto type = dnnl::memory::data_type::f32;
+
+ // memory descriptors for arrays
+
+ // input
+ dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any);
+ dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFrmat);
+ if(input->ews() != 1 || input->ordering() != 'c') {
+ x_user_md.data.format_kind = dnnl_blocked; // overrides format
+ x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0);
+ x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(1);
+ x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(2);
+ x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(3);
+ x_user_md.data.format_desc.blocking.strides[4] = input->strideAt(4);
+ }
+
+ // weights
+ dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any);
+ dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, type, wFormat);
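+ // the permutation below only rewrites the descriptor's blocking strides, the weights buffer itself is untouched;
+ // if the primitive picks another layout, the data is reordered into w_mkl_mem further down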
+ w_user_md.data.format_kind = dnnl_blocked; // overrides format
+ w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(4); // permute [kD, kH, kW, iC, oC] -> [oC, iC, kD, kH, kW]
+ w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(3);
+ w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(0);
+ w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(1);
+ w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(2);
+
+ // bias
+ dnnl::memory::desc b_mkl_md;
+ if(bias != nullptr)
+ b_mkl_md = dnnl::memory::desc({oC}, type, dnnl::memory::format_tag::x);
+
+ // output
+ dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any);
+ dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, type, xzFrmat);
+ if(output->ews() != 1 || output->ordering() != 'c') {
+ z_user_md.data.format_kind = dnnl_blocked; // overrides format
+ z_user_md.data.format_desc.blocking.strides[0] = output->strideAt(0);
+ z_user_md.data.format_desc.blocking.strides[1] = output->strideAt(1);
+ z_user_md.data.format_desc.blocking.strides[2] = output->strideAt(2);
+ z_user_md.data.format_desc.blocking.strides[3] = output->strideAt(3);
+ z_user_md.data.format_desc.blocking.strides[4] = output->strideAt(4);
+ }
+
+ auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
+
+ // operation primitive description
+ dnnl::convolution_forward::desc op_desc(dnnl::prop_kind::forward_inference, dnnl::algorithm::convolution_auto, x_mkl_md, w_mkl_md, b_mkl_md, z_mkl_md, strides, dilation, padding, padding_r);
+ dnnl::convolution_forward::primitive_desc op_prim_desc(op_desc, engine);
+
+ // arguments (memory buffers) necessary for calculations
+ std::unordered_map<int, dnnl::memory> args;
+
+ dnnl::stream stream(engine);
+
+ // provide memory buffers and check whether reorder is required
+
+ // input
+ auto x_user_mem = dnnl::memory(x_user_md, engine, input->getBuffer());
+ const bool xReorder = op_prim_desc.src_desc() != x_user_mem.get_desc();
+ auto x_mkl_mem = xReorder ? dnnl::memory(op_prim_desc.src_desc(), engine) : x_user_mem;
+ if (xReorder)
+ dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
+ args[DNNL_ARG_SRC] = x_mkl_mem;
+
+ // weights
+ auto w_user_mem = dnnl::memory(w_user_md, engine, weights->getBuffer());
+ const bool wReorder = op_prim_desc.weights_desc() != w_user_mem.get_desc();
+ auto w_mkl_mem = wReorder ? dnnl::memory(op_prim_desc.weights_desc(), engine) : w_user_mem;
+ if (wReorder)
+ dnnl::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem);
+ args[DNNL_ARG_WEIGHTS] = w_mkl_mem;
+
+ // bias
+ if(bias != nullptr) {
+ auto b_mkl_mem = dnnl::memory(b_mkl_md, engine, bias->getBuffer());
+ args[DNNL_ARG_BIAS] = b_mkl_mem;
+ }
+
+ // output
+ auto z_user_mem = dnnl::memory(z_user_md, engine, output->getBuffer());
+ const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc();
+ auto z_mkl_mem = zReorder ? dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem;
+ args[DNNL_ARG_DST] = z_mkl_mem;
+
+ // run calculations
+ dnnl::convolution_forward(op_prim_desc).execute(stream, args);
+
+ // reorder outputs if necessary
+ if (zReorder)
+ dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem);
+
+ stream.wait();
+}
+
+//////////////////////////////////////////////////////////////////////
+static void conv3dBpMKLDNN(const NDArray *input, const NDArray *weights, const NDArray *bias, const NDArray *gradO,
+ NDArray *gradI, NDArray *gradW, NDArray *gradB,
+ const int kD, const int kH, const int kW,
+ const int sD, const int sH, const int sW,
+ const int pD, const int pH, const int pW,
+ const int dD, const int dH, const int dW,
+ const int paddingMode, const int isNCDHW) {
+
+ // weights/gradW [kD, kH, kW, iC, oC], we'll perform permutation since mkl supports [oC, iC, kD, kH, kW]
+
+ int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
+ int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
+ ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
+
+ // const int pWSame = (paddingMode == 2 && dW > 1) ? ((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 : pW; // dH == 1 for causal mode in conv1d
+
+ dnnl::memory::dims strides = {sD, sH, sW};
+ dnnl::memory::dims padding = {pD, pH, pW};
+ // dnnl::memory::dims padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pWSame };
+ dnnl::memory::dims padding_r = {(oD - 1) * sD - iD + kD - pD, (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pW};
+ dnnl::memory::dims dilation = {dD-1, dH-1, dW-1};
+
+ auto xzFrmat = isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc;
+ dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::oidhw;
+
+ dnnl::memory::dims xDims = {bS, iC, iD, iH, iW};
+ dnnl::memory::dims wDims = {oC, iC, kD, kH, kW};
+ dnnl::memory::dims zDims = {bS, oC, oD, oH, oW};
+
+ auto type = dnnl::memory::data_type::f32;
+
+ // memory descriptors for arrays
+
+ // input
+ dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any);
+ dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFrmat);
+ if(input->ews() != 1 || input->ordering() != 'c') {
+ x_user_md.data.format_kind = dnnl_blocked; // overrides format
+ x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0);
+ x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(1);
+ x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(2);
+ x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(3);
+ x_user_md.data.format_desc.blocking.strides[4] = input->strideAt(4);
+ }
+
+ // weights
+ dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any);
+ dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, type, wFormat);
+ w_user_md.data.format_kind = dnnl_blocked; // overrides format
+ w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(4); // permute [kD, kH, kW, iC, oC] -> [oC, iC, kD, kH, kW]
+ w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(3);
+ w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(0);
+ w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(1);
+ w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(2);
+
+ // gradO
+ dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any);
+ dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, type, xzFrmat);
+ if(gradO->ews() != 1 || gradO->ordering() != 'c') {
+ gradO_user_md.data.format_kind = dnnl_blocked; // overrides format
+ gradO_user_md.data.format_desc.blocking.strides[0] = gradO->strideAt(0);
+ gradO_user_md.data.format_desc.blocking.strides[1] = gradO->strideAt(1);
+ gradO_user_md.data.format_desc.blocking.strides[2] = gradO->strideAt(2);
+ gradO_user_md.data.format_desc.blocking.strides[3] = gradO->strideAt(3);
+ gradO_user_md.data.format_desc.blocking.strides[4] = gradO->strideAt(4);
+ }
+
+ // gradI
+ dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any);
+ dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, type, xzFrmat);
+ if(gradI->ews() != 1 || gradI->ordering() != 'c') {
+ gradI_user_md.data.format_kind = dnnl_blocked; // overrides format
+ gradI_user_md.data.format_desc.blocking.strides[0] = gradI->strideAt(0);
+ gradI_user_md.data.format_desc.blocking.strides[1] = gradI->strideAt(1);
+ gradI_user_md.data.format_desc.blocking.strides[2] = gradI->strideAt(2);
+ gradI_user_md.data.format_desc.blocking.strides[3] = gradI->strideAt(3);
+ gradI_user_md.data.format_desc.blocking.strides[4] = gradI->strideAt(4);
+ }
+
+ // gradW
+ dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any);
+ dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, type, wFormat);
+ gradW_user_md.data.format_kind = dnnl_blocked; // overrides format
+ gradW_user_md.data.format_desc.blocking.strides[0] = gradW->strideAt(4); // permute [kD, kH, kW, iC, oC] -> [oC, iC, kD, kH, kW]
+ gradW_user_md.data.format_desc.blocking.strides[1] = gradW->strideAt(3);
+ gradW_user_md.data.format_desc.blocking.strides[2] = gradW->strideAt(0);
+ gradW_user_md.data.format_desc.blocking.strides[3] = gradW->strideAt(1);
+ gradW_user_md.data.format_desc.blocking.strides[4] = gradW->strideAt(2);
+
+ // gradB
+ dnnl::memory::desc gradB_mkl_md;
+ if(gradB != nullptr)
+ gradB_mkl_md = dnnl::memory::desc({oC}, type, dnnl::memory::format_tag::x);
+
+ auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
+
+ // forward primitive description
+ dnnl::convolution_forward::desc op_ff_desc(dnnl::prop_kind::forward_inference, dnnl::algorithm::convolution_auto, x_mkl_md, w_mkl_md, gradB_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r);
+ dnnl::convolution_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine);
+
+ // backward data primitive description
+ dnnl::convolution_backward_data::desc op_data_bp_desc(dnnl::algorithm::convolution_auto, gradI_mkl_md, w_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r);
+ dnnl::convolution_backward_data::primitive_desc op_data_bp_prim_desc(op_data_bp_desc, engine, op_ff_prim_desc);
+
+ // backward weights primitive description
+ dnnl::convolution_backward_weights::desc op_weights_bp_desc(dnnl::algorithm::convolution_auto, x_mkl_md, gradW_mkl_md, gradB_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r);
+ dnnl::convolution_backward_weights::primitive_desc op_weights_bp_prim_desc(op_weights_bp_desc, engine, op_ff_prim_desc);
+
+ // arguments (memory buffers) necessary for calculations
+ std::unordered_map<int, dnnl::memory> args;
+
+ dnnl::stream stream(engine);
+
+ // provide memory buffers and check whether reorder is required
+
+ // input
+ auto x_user_mem = dnnl::memory(x_user_md, engine, input->getBuffer());
+ const bool xReorder = op_weights_bp_prim_desc.src_desc() != x_user_mem.get_desc();
+ auto x_mkl_mem = xReorder ? dnnl::memory(op_weights_bp_prim_desc.src_desc(), engine) : x_user_mem;
+ if (xReorder)
+ dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
+ args[DNNL_ARG_SRC] = x_mkl_mem;
+
+ // weights
+ auto w_user_mem = dnnl::memory(w_user_md, engine, weights->getBuffer());
+ const bool wReorder = op_data_bp_prim_desc.weights_desc() != w_user_mem.get_desc();
+ auto w_mkl_mem = wReorder ? dnnl::memory(op_data_bp_prim_desc.weights_desc(), engine) : w_user_mem;
+ if (wReorder)
+ dnnl::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem);
+ args[DNNL_ARG_WEIGHTS] = w_mkl_mem;
+
+ // gradO
+ auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer());
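+ // backward-data and backward-weights primitives may each prefer a different gradO layout,
+ // so two candidate mkl memories are prepared: gradO_mkl_memD for the data pass, gradO_mkl_memW for the weights pass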
+ const bool gradOReorderW = op_weights_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc();
+ const bool gradOReorderD = op_data_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc();
+ auto gradO_mkl_memW = gradOReorderW ? dnnl::memory(op_weights_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem;
+ auto gradO_mkl_memD = gradOReorderD ? dnnl::memory(op_data_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem;
+ if (gradOReorderW)
+ dnnl::reorder(gradO_user_mem, gradO_mkl_memW).execute(stream, gradO_user_mem, gradO_mkl_memW);
+ if (gradOReorderD)
+ dnnl::reorder(gradO_user_mem, gradO_mkl_memD).execute(stream, gradO_user_mem, gradO_mkl_memD);
+ args[DNNL_ARG_DIFF_DST] = gradO_mkl_memD;
+
+ // gradI
+ auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->getBuffer());
+ const bool gradIReorder = op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc();
+ auto gradI_mkl_mem = gradIReorder ? dnnl::memory(op_data_bp_prim_desc.diff_src_desc(), engine) : gradI_user_mem;
+ args[DNNL_ARG_DIFF_SRC] = gradI_mkl_mem;
+
+ // gradW
+ auto gradW_user_mem = dnnl::memory(gradW_user_md, engine, gradW->getBuffer());
+ const bool gradWReorder = op_weights_bp_prim_desc.diff_weights_desc() != gradW_user_mem.get_desc();
+ auto gradW_mkl_mem = gradWReorder ? dnnl::memory(op_weights_bp_prim_desc.diff_weights_desc(), engine) : gradW_user_mem;
+ args[DNNL_ARG_DIFF_WEIGHTS] = gradW_mkl_mem;
+
+ // gradB
+ if(gradB != nullptr) {
+ auto gradB_mkl_mem = dnnl::memory(gradB_mkl_md, engine, gradB->getBuffer());
+ args[DNNL_ARG_DIFF_BIAS] = gradB_mkl_mem;
+ }
+
+ // run backward data calculations
+ dnnl::convolution_backward_data(op_data_bp_prim_desc).execute(stream, args);
+
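+ // switch DIFF_DST to the weights-preferred gradO memory before the backward-weights pass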
+ if(gradOReorderW || gradOReorderD)
+ args[DNNL_ARG_DIFF_DST] = gradO_mkl_memW;
+
+ // run backward weights calculations
+ dnnl::convolution_backward_weights(op_weights_bp_prim_desc).execute(stream, args);
+
+ // reorder gradI if necessary
+ if (gradIReorder)
+ dnnl::reorder(gradI_mkl_mem, gradI_user_mem).execute(stream, gradI_mkl_mem, gradI_user_mem);
+ if (gradWReorder)
+ dnnl::reorder(gradW_mkl_mem, gradW_user_mem).execute(stream, gradW_mkl_mem, gradW_user_mem);
+
+ stream.wait();
+
+ // shape::printArray(z_mkl_mem.map_data(),8);
+}
+
+
+/*
//////////////////////////////////////////////////////////////////////
static void conv3dMKLDNN(nd4j::graph::Context &block,
const NDArray *input, const NDArray *weights, const NDArray *bias,
@@ -225,6 +533,7 @@ static void conv3dBpMKLDNN(nd4j::graph::Context &block,
reorder(convI_src_memory, userI_src_memory).execute(stream, convI_src_memory, userI_src_memory);
}
}
+*/
//////////////////////////////////////////////////////////////////////
PLATFORM_IMPL(conv3dnew, ENGINE_CPU) {
@@ -256,24 +565,20 @@ PLATFORM_IMPL(conv3dnew, ENGINE_CPU) {
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
- std::string expectedWeightsShape = ShapeUtils::shapeAsString({kD, kH, kW, iC, oC});
- REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weights), 0, "CUSTOM CONV3D MKLDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weights).c_str());
+ std::vector<Nd4jLong> expectedWeightsShape = {kD, kH, kW, iC, oC};
+ REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV3D MKLDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
if (bias)
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV3D MKLDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
if (paddingMode) // SAME
ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
- conv3dMKLDNN(block, input, weights, bias, output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, paddingMode, isNCDHW);
+ conv3dMKLDNN(input, weights, bias, output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, paddingMode, isNCDHW);
return Status::OK();
}
PLATFORM_CHECK(conv3dnew, ENGINE_CPU) {
- // we don't want to use mkldnn if cpu doesn't support avx/avx2
- if (::optimalLevel() < 2)
- return false;
-
auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always
auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC]
@@ -284,6 +589,7 @@ PLATFORM_CHECK(conv3dnew, ENGINE_CPU) {
//////////////////////////////////////////////////////////////////////
PLATFORM_IMPL(conv3dnew_bp, ENGINE_CPU) {
+
auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always
auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
@@ -322,20 +628,19 @@ PLATFORM_IMPL(conv3dnew_bp, ENGINE_CPU) {
int trueoD, trueoH, trueoW; // true output depth/height/width
ConvolutionUtils::calcOutSizePool3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, paddingMode);
- std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx( {bS, oC, trueoD, trueoH, trueoW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2}));
- std::string expectedWeightsShape = ShapeUtils::shapeAsString({kD, kH, kW, iC, oC});
- REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, "CUSTOM CONV3D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
- REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weights), 0, "CUSTOM CONV3D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weights).c_str());
+ std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx( {bS, oC, trueoD, trueoH, trueoW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2});
+ std::vector<Nd4jLong> expectedWeightsShape = {kD, kH, kW, iC, oC};
+ REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM CONV3D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
+ REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV3D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
if (bias)
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV3D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
- conv3dBpMKLDNN(block, input, weights, bias, gradO, gradI, gradW, gradB, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, paddingMode, isNCDHW);
+ conv3dBpMKLDNN(input, weights, bias, gradO, gradI, gradW, gradB, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, paddingMode, isNCDHW);
return Status::OK();
}
PLATFORM_CHECK(conv3dnew_bp, ENGINE_CPU) {
-
auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always
auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d.cpp
index d95052c5a..1879ef8fb 100644
--- a/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d.cpp
+++ b/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d.cpp
@@ -34,17 +34,13 @@ namespace platforms {
//////////////////////////////////////////////////////////////////////////
static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output,
const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW,
- const int paddingMode) {
+ const int paddingMode, const bool isNCHW) {
- // input [bS, iC, iH, iW] nchw, mkl doesn't support format nhwc
- // weights [oC, iC, kH, kW] always, mkl doesn't support weights format [kH, kW, oC, iC]
- // bias [oC], may be nullptr
-
- // output [bS, oC, oH, oW] nchw, mkl doesn't support format nhwc
+ // weights [oC, iC, kH, kW] always, mkl doesn't support [kH, kW, oC, iC], so we'll perform permutation
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
- ConvolutionUtils::getSizesAndIndexesConv2d(true, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH);
+ ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH);
dnnl::memory::dims strides = { sH, sW };
dnnl::memory::dims padding = { pH, pW };
@@ -80,8 +76,7 @@ static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const N
else
zType = dnnl::memory::data_type::s32;
-
- dnnl::memory::format_tag xFormat = dnnl::memory::format_tag::nchw; // isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
+ dnnl::memory::format_tag xFormat = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::oihw;
dnnl::memory::dims xDims = {bS, iC, iH, iW};
@@ -93,20 +88,22 @@ static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const N
// input
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any);
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormat);
- x_user_md.data.format_kind = dnnl_blocked; // overrides format
- x_user_md.data.format_desc.blocking.strides[0] = input->stridesOf()[0];
- x_user_md.data.format_desc.blocking.strides[1] = input->stridesOf()[1];
- x_user_md.data.format_desc.blocking.strides[2] = input->stridesOf()[2];
- x_user_md.data.format_desc.blocking.strides[3] = input->stridesOf()[3];
+ if(input->ews() != 1 || input->ordering() != 'c') {
+ x_user_md.data.format_kind = dnnl_blocked; // overrides format
+ x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0);
+ x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(1);
+ x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(2);
+ x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(3);
+ }
// weights
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormat);
w_user_md.data.format_kind = dnnl_blocked; // overrides format
- w_user_md.data.format_desc.blocking.strides[0] = weights->stridesOf()[0];
- w_user_md.data.format_desc.blocking.strides[1] = weights->stridesOf()[1];
- w_user_md.data.format_desc.blocking.strides[2] = weights->stridesOf()[2];
- w_user_md.data.format_desc.blocking.strides[3] = weights->stridesOf()[3];
+ w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(2); // [kH, kW, oC, iC] -> [oC, iC, kH, kW]
+ w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(3);
+ w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(0);
+ w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(1);
// bias
dnnl::memory::desc b_mkl_md;
@@ -116,11 +113,13 @@ static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const N
// output
dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, zType, dnnl::memory::format_tag::any);
dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, zType, xFormat);
- z_user_md.data.format_kind = dnnl_blocked; // overrides format
- z_user_md.data.format_desc.blocking.strides[0] = output->stridesOf()[0];
- z_user_md.data.format_desc.blocking.strides[1] = output->stridesOf()[1];
- z_user_md.data.format_desc.blocking.strides[2] = output->stridesOf()[2];
- z_user_md.data.format_desc.blocking.strides[3] = output->stridesOf()[3];
+ if(output->ews() != 1 || output->ordering() != 'c') {
+ z_user_md.data.format_kind = dnnl_blocked; // overrides format
+ z_user_md.data.format_desc.blocking.strides[0] = output->strideAt(0);
+ z_user_md.data.format_desc.blocking.strides[1] = output->strideAt(1);
+ z_user_md.data.format_desc.blocking.strides[2] = output->strideAt(2);
+ z_user_md.data.format_desc.blocking.strides[3] = output->strideAt(3);
+ }
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
@@ -179,21 +178,19 @@ static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const N
//////////////////////////////////////////////////////////////////////////
static void deconv2dBpMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB,
const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW,
- const int paddingMode) {
+ const int paddingMode, const bool isNCHW) {
- // input and gradI [bS, iC, iH, iW], mkl doesn't support ndhwc format
- // weights and gradW [oC, iC, kH, kW] always, mkl doesn't support weights format [kH, kW, oC, iC]
- // gradB [oC], may be nullptr
- // gradO [bS, oC, oH, oW]
+ // weights and gradW [oC, iC, kH, kW] always, mkl doesn't support [kH, kW, oC, iC], so we'll perform permutation
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
- ConvolutionUtils::getSizesAndIndexesConv2d(true, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH);
+ ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH);
dnnl::memory::dims strides = { sH, sW };
dnnl::memory::dims padding = { pH, pW };
dnnl::memory::dims padding_r = { (iH - 1) * sH - oH + kH - pH, (iW - 1) * sW - oW + kW - pW };
dnnl::memory::dims dilation = { dH-1, dW-1 };
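+ // for deconvolution the output is the larger array, so padding_r follows from oH = (iH-1)*sH + kH - pH - padding_r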
+
// input type
dnnl::memory::data_type xType = input->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16;
// weights type
@@ -207,7 +204,7 @@ static void deconv2dBpMKLDNN(const NDArray* input, const NDArray* weights, const
// gradB type
dnnl::memory::data_type gradBType = gradB != nullptr ? (gradB->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16) : dnnl::memory::data_type::f32;
- dnnl::memory::format_tag xFormat = dnnl::memory::format_tag::nchw; // isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
+ dnnl::memory::format_tag xFormat = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::oihw;
dnnl::memory::dims xDims = {bS, iC, iH, iW};
@@ -219,54 +216,59 @@ static void deconv2dBpMKLDNN(const NDArray* input, const NDArray* weights, const
// input
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any);
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormat);
- x_user_md.data.format_kind = dnnl_blocked; // overrides format
- x_user_md.data.format_desc.blocking.strides[0] = input->stridesOf()[0];
- x_user_md.data.format_desc.blocking.strides[1] = input->stridesOf()[1];
- x_user_md.data.format_desc.blocking.strides[2] = input->stridesOf()[2];
- x_user_md.data.format_desc.blocking.strides[3] = input->stridesOf()[3];
+ if(input->ews() != 1 || input->ordering() != 'c') {
+ x_user_md.data.format_kind = dnnl_blocked; // overrides format
+ x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0);
+ x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(1);
+ x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(2);
+ x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(3);
+ }
// weights
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormat);
w_user_md.data.format_kind = dnnl_blocked; // overrides format
- w_user_md.data.format_desc.blocking.strides[0] = weights->stridesOf()[0];
- w_user_md.data.format_desc.blocking.strides[1] = weights->stridesOf()[1];
- w_user_md.data.format_desc.blocking.strides[2] = weights->stridesOf()[2];
- w_user_md.data.format_desc.blocking.strides[3] = weights->stridesOf()[3];
+ w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(2); // [kH, kW, oC, iC] -> [oC, iC, kH, kW]
+ w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(3);
+ w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(0);
+ w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(1);
// gradO
dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any);
dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xFormat);
- gradO_user_md.data.format_kind = dnnl_blocked; // overrides format
- gradO_user_md.data.format_desc.blocking.strides[0] = gradO->stridesOf()[0];
- gradO_user_md.data.format_desc.blocking.strides[1] = gradO->stridesOf()[1];
- gradO_user_md.data.format_desc.blocking.strides[2] = gradO->stridesOf()[2];
- gradO_user_md.data.format_desc.blocking.strides[3] = gradO->stridesOf()[3];
+ if(gradO->ews() != 1 || gradO->ordering() != 'c') {
+ gradO_user_md.data.format_kind = dnnl_blocked; // overrides format
+ gradO_user_md.data.format_desc.blocking.strides[0] = gradO->strideAt(0);
+ gradO_user_md.data.format_desc.blocking.strides[1] = gradO->strideAt(1);
+ gradO_user_md.data.format_desc.blocking.strides[2] = gradO->strideAt(2);
+ gradO_user_md.data.format_desc.blocking.strides[3] = gradO->strideAt(3);
+ }
// gradI
dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any);
dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xFormat);
- gradI_user_md.data.format_kind = dnnl_blocked; // overrides format
- gradI_user_md.data.format_desc.blocking.strides[0] = gradI->stridesOf()[0];
- gradI_user_md.data.format_desc.blocking.strides[1] = gradI->stridesOf()[1];
- gradI_user_md.data.format_desc.blocking.strides[2] = gradI->stridesOf()[2];
- gradI_user_md.data.format_desc.blocking.strides[3] = gradI->stridesOf()[3];
+ if(gradI->ews() != 1 || gradI->ordering() != 'c') {
+ gradI_user_md.data.format_kind = dnnl_blocked; // overrides format
+ gradI_user_md.data.format_desc.blocking.strides[0] = gradI->strideAt(0);
+ gradI_user_md.data.format_desc.blocking.strides[1] = gradI->strideAt(1);
+ gradI_user_md.data.format_desc.blocking.strides[2] = gradI->strideAt(2);
+ gradI_user_md.data.format_desc.blocking.strides[3] = gradI->strideAt(3);
+ }
// gradW
dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, gradWType, dnnl::memory::format_tag::any);
dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, gradWType, wFormat);
gradW_user_md.data.format_kind = dnnl_blocked; // overrides format
- gradW_user_md.data.format_desc.blocking.strides[0] = gradW->stridesOf()[0];
- gradW_user_md.data.format_desc.blocking.strides[1] = gradW->stridesOf()[1];
- gradW_user_md.data.format_desc.blocking.strides[2] = gradW->stridesOf()[2];
- gradW_user_md.data.format_desc.blocking.strides[3] = gradW->stridesOf()[3];
+ gradW_user_md.data.format_desc.blocking.strides[0] = gradW->strideAt(2); // [kH, kW, oC, iC] -> [oC, iC, kH, kW]
+ gradW_user_md.data.format_desc.blocking.strides[1] = gradW->strideAt(3);
+ gradW_user_md.data.format_desc.blocking.strides[2] = gradW->strideAt(0);
+ gradW_user_md.data.format_desc.blocking.strides[3] = gradW->strideAt(1);
// gradB
dnnl::memory::desc gradB_mkl_md;
if(gradB != nullptr)
gradB_mkl_md = dnnl::memory::desc({oC}, gradBType, dnnl::memory::format_tag::x);
-
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
// forward primitive description
@@ -306,11 +308,15 @@ static void deconv2dBpMKLDNN(const NDArray* input, const NDArray* weights, const
// gradO
auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer());
- const bool gradOReorder = op_data_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc();
- auto gradO_mkl_mem = gradOReorder ? dnnl::memory(op_data_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem;
- if (gradOReorder)
- dnnl::reorder(gradO_user_mem, gradO_mkl_mem).execute(stream, gradO_user_mem, gradO_mkl_mem);
- args[DNNL_ARG_DIFF_DST] = gradO_mkl_mem;
+ const bool gradOReorderW = op_weights_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc();
+ const bool gradOReorderD = op_data_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc();
+ auto gradO_mkl_memW = gradOReorderW ? dnnl::memory(op_weights_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem;
+ auto gradO_mkl_memD = gradOReorderD ? dnnl::memory(op_data_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem;
+ if (gradOReorderW)
+ dnnl::reorder(gradO_user_mem, gradO_mkl_memW).execute(stream, gradO_user_mem, gradO_mkl_memW);
+ if (gradOReorderD)
+ dnnl::reorder(gradO_user_mem, gradO_mkl_memD).execute(stream, gradO_user_mem, gradO_mkl_memD);
+ args[DNNL_ARG_DIFF_DST] = gradO_mkl_memD;
// gradI
auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->getBuffer());
@@ -333,6 +339,9 @@ static void deconv2dBpMKLDNN(const NDArray* input, const NDArray* weights, const
// run backward data calculations
dnnl::deconvolution_backward_data(op_data_bp_prim_desc).execute(stream, args);
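+ // the deconvolution backward-weights primitive may want gradO in a different layout than backward-data,
+ // hence DIFF_DST is pointed at the weights-preferred memory before that pass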
+ if(gradOReorderW || gradOReorderD)
+ args[DNNL_ARG_DIFF_DST] = gradO_mkl_memW;
+
// run backward weights calculations
dnnl::deconvolution_backward_weights(op_weights_bp_prim_desc).execute(stream, args);
@@ -385,32 +394,12 @@ PLATFORM_IMPL(deconv2d, ENGINE_CPU) {
ConvolutionUtils::calcPadding2D(pH, pW, iH, iW, oH, oW, kH, kW, sH, sW, dH, dW);
}
- // mkl supports only [oC, iC, kH, kW] format for weights
- weights = new NDArray(weights->permute({2,3,0,1})); // [kH, kW, oC, iC] -> [oC, iC, kH, kW]
-
- // mkl supports only NCHW
- if(!isNCHW) {
- input = new NDArray(input->permute({0,3,1,2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
- output = new NDArray(output->permute({0,3,1,2})); // [bS, oH, oW, oC] -> [bS, oC, oH, oW]
- }
-
- deconv2dMKLDNN(input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode);
-
- delete weights;
-
- if(!isNCHW) {
- delete input;
- delete output;
- }
+ deconv2dMKLDNN(input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW);
return Status::OK();
}
PLATFORM_CHECK(deconv2d, ENGINE_CPU) {
- // we don't want to use mkldnn if cpu doesn't support avx/avx2
- // if (::optimalLevel() < 2)
- // return false;
-
auto input = INPUT_VARIABLE(0);
auto weights = INPUT_VARIABLE(1);
auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr;
@@ -481,27 +470,7 @@ PLATFORM_IMPL(deconv2d_bp, ENGINE_CPU) {
ConvolutionUtils::calcPadding2D(pH, pW, iH, iW, oH, oW, kH, kW, sH, sW, dH, dW);
}
- // mkl supports only [oC, iC, kH, kW] for weights
- weights = new NDArray(weights->permute({2,3,0,1})); // [kH, kW, oC, iC] -> [oC, iC, kH, kW]
- gradW = new NDArray(gradW->permute({2,3,0,1})); // [kH, kW, oC, iC] -> [oC, iC, kH, kW]
-
- // mkl supports NCHW format only
- if(!isNCHW) {
- input = new NDArray(input->permute({0,3,1,2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
- gradI = new NDArray(gradI->permute({0,3,1,2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
- gradO = new NDArray(gradO->permute({0,3,1,2})); // [bS, oH, oW, oC] -> [bS, oC, oH, oW]
- }
-
- deconv2dBpMKLDNN(input, weights, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode);
-
- delete weights;
- delete gradW;
-
- if(!isNCHW) {
- delete input;
- delete gradI;
- delete gradO;
- }
+ deconv2dBpMKLDNN(input, weights, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW);
return Status::OK();
}
diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d_tf.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d_tf.cpp
index 90ddb828e..7c6582ab4 100644
--- a/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d_tf.cpp
+++ b/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d_tf.cpp
@@ -33,7 +33,8 @@ namespace platforms {
//////////////////////////////////////////////////////////////////////////
static void deconv2TFdBackPropMKLDNN(const NDArray* weights, const NDArray* gradO, NDArray* gradI,
const int bS, const int iC, const int iH, const int iW, const int oC, const int oH, const int oW,
- const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW) {
+ const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW,
+ const bool isNCHW) {
// gradI [bS, iH, iW, iC], mkl doesn't support ndhwc format
// weights [oC, iC, kH, kW] always, mkl doesn't support weights format [kH, kW, iC, oC]
@@ -51,7 +52,7 @@ static void deconv2TFdBackPropMKLDNN(const NDArray* weights, const NDArray* grad
// gradI type
dnnl::memory::data_type gradIType = gradI->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16;
- dnnl::memory::format_tag xFormat = dnnl::memory::format_tag::nchw; // isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
+ dnnl::memory::format_tag xFormat = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::oihw;
dnnl::memory::dims xDims = {bS, iC, iH, iW};
@@ -67,29 +68,32 @@ static void deconv2TFdBackPropMKLDNN(const NDArray* weights, const NDArray* grad
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormat);
w_user_md.data.format_kind = dnnl_blocked; // overrides format
- w_user_md.data.format_desc.blocking.strides[0] = weights->stridesOf()[0];
- w_user_md.data.format_desc.blocking.strides[1] = weights->stridesOf()[1];
- w_user_md.data.format_desc.blocking.strides[2] = weights->stridesOf()[2];
- w_user_md.data.format_desc.blocking.strides[3] = weights->stridesOf()[3];
+ w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(3); // permute [kH, kW, iC, oC] -> [oC, iC, kH, kW]
+ w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(2);
+ w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(0);
+ w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(1);
// gradO
dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any);
dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xFormat);
- gradO_user_md.data.format_kind = dnnl_blocked; // overrides format
- gradO_user_md.data.format_desc.blocking.strides[0] = gradO->stridesOf()[0];
- gradO_user_md.data.format_desc.blocking.strides[1] = gradO->stridesOf()[1];
- gradO_user_md.data.format_desc.blocking.strides[2] = gradO->stridesOf()[2];
- gradO_user_md.data.format_desc.blocking.strides[3] = gradO->stridesOf()[3];
+ if(gradO->ews() != 1 || gradO->ordering() != 'c') {
+ gradO_user_md.data.format_kind = dnnl_blocked; // overrides format
+ gradO_user_md.data.format_desc.blocking.strides[0] = gradO->strideAt(0);
+ gradO_user_md.data.format_desc.blocking.strides[1] = gradO->strideAt(1);
+ gradO_user_md.data.format_desc.blocking.strides[2] = gradO->strideAt(2);
+ gradO_user_md.data.format_desc.blocking.strides[3] = gradO->strideAt(3);
+ }
// gradI
dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any);
dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xFormat);
- gradI_user_md.data.format_kind = dnnl_blocked; // overrides format
- gradI_user_md.data.format_desc.blocking.strides[0] = gradI->stridesOf()[0];
- gradI_user_md.data.format_desc.blocking.strides[1] = gradI->stridesOf()[1];
- gradI_user_md.data.format_desc.blocking.strides[2] = gradI->stridesOf()[2];
- gradI_user_md.data.format_desc.blocking.strides[3] = gradI->stridesOf()[3];
-
+ if(gradI->ews() != 1 || gradI->ordering() != 'c') {
+ gradI_user_md.data.format_kind = dnnl_blocked; // overrides format
+ gradI_user_md.data.format_desc.blocking.strides[0] = gradI->strideAt(0);
+ gradI_user_md.data.format_desc.blocking.strides[1] = gradI->strideAt(1);
+ gradI_user_md.data.format_desc.blocking.strides[2] = gradI->strideAt(2);
+ gradI_user_md.data.format_desc.blocking.strides[3] = gradI->strideAt(3);
+ }
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
@@ -166,9 +170,9 @@ PLATFORM_IMPL(deconv2d_tf, ENGINE_CPU) {
const int rank = gradO->rankOf();
- REQUIRE_TRUE(weights->rankOf() == rank, 0, "CUSTOM DECONV2D_TF OP: rank of weights array must be equal to 4, but got %i instead !", weights->rankOf());
- REQUIRE_TRUE(gradIShape->rankOf() == 1, 0, "CUSTOM DECONV2D_TF OP: rank of array with output shape must be equal to 1, but got %i instead !", gradIShape->rankOf());
- REQUIRE_TRUE(gradIShape->lengthOf() == rank, 0, "CUSTOM DECONV2D_TF OP: length of array with output shape must be equal to 4, but got %i instead !", gradIShape->lengthOf());
+ REQUIRE_TRUE(weights->rankOf() == rank, 0, "CUSTOM DECONV2D_TF MKLDNN OP: rank of weights array must be equal to 4, but got %i instead !", weights->rankOf());
+ REQUIRE_TRUE(gradIShape->rankOf() == 1, 0, "CUSTOM DECONV2D_TF MKLDNN OP: rank of array with output shape must be equal to 1, but got %i instead !", gradIShape->rankOf());
+ REQUIRE_TRUE(gradIShape->lengthOf() == rank, 0, "CUSTOM DECONV2D_TF MKLDNN OP: length of array with output shape must be equal to 4, but got %i instead !", gradIShape->lengthOf());
int indIOioC, indIiH, indWoC(3), indOoH;
if(!isNCHW) {
@@ -193,29 +197,29 @@ PLATFORM_IMPL(deconv2d_tf, ENGINE_CPU) {
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1});
std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, iC, oC};
- REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM DECONV2D_TF OP: wrong shape of input array, basing on array with output shape expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
- REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV2D_TF OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
+ REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM DECONV2D_TF MKLDNN OP: wrong shape of input array, basing on array with output shape expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
+ REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV2D_TF MKLDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
if(isSameMode) // SAME
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
- // mkl supports only [oC, iC, kH, kW] for weights
- weights = new NDArray(weights->permute({3,2,0,1})); // [kH, kW, iC, oC] -> [oC, iC, kH, kW]
+ // // mkl supports only [oC, iC, kH, kW] for weights
+ // weights = new NDArray(weights->permute({3,2,0,1})); // [kH, kW, iC, oC] -> [oC, iC, kH, kW]
- // mkl supports NCHW format only
- if(!isNCHW) {
- gradI = new NDArray(gradI->permute({0,3,1,2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
- gradO = new NDArray(gradO->permute({0,3,1,2})); // [bS, oH, oW, oC] -> [bS, oC, oH, oW]
- }
+ // // mkl supports NCHW format only
+ // if(!isNCHW) {
+ // gradI = new NDArray(gradI->permute({0,3,1,2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
+ // gradO = new NDArray(gradO->permute({0,3,1,2})); // [bS, oH, oW, oC] -> [bS, oC, oH, oW]
+ // }
- deconv2TFdBackPropMKLDNN(weights, gradO, gradI, bS, iC, iH, iW, oC, oH, oW, kH, kW, sH, sW, pH, pW, dH, dW);
+ deconv2TFdBackPropMKLDNN(weights, gradO, gradI, bS, iC, iH, iW, oC, oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, isNCHW);
- delete weights;
+ // delete weights;
- if(!isNCHW) {
- delete gradI;
- delete gradO;
- }
+ // if(!isNCHW) {
+ // delete gradI;
+ // delete gradO;
+ // }
// ConvolutionUtils::conv2dBP(block, &input, weights, nullptr, gradO, gradI, nullptr, nullptr, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW);
diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/deconv3d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/deconv3d.cpp
index a678e0185..5daab8228 100644
--- a/libnd4j/include/ops/declarable/platform/mkldnn/deconv3d.cpp
+++ b/libnd4j/include/ops/declarable/platform/mkldnn/deconv3d.cpp
@@ -34,17 +34,14 @@ namespace platforms {
//////////////////////////////////////////////////////////////////////////
static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output,
const int kD, const int kH, const int kW, const int sD, const int sH, const int sW,
- const int pD, const int pH, const int pW, const int dD, const int dH, const int dW) {
+ const int pD, const int pH, const int pW, const int dD, const int dH, const int dW,
+ const bool isNCDHW) {
- // input [bS, iD, iH, iW, iC] ncdhw, mkl doesn't support format ndhwc
- // weights [oC, iC, kD, kH, kW] always, mkl doesn't support weights format [kD, kH, kW, oC, iC]
- // bias [oC], may be nullptr
-
- // output [bS, oD, oH, oW, oC] ncdhw, mkl doesn't support format ndhwc
+ // weights [oC, iC, kD, kH, kW] always, mkl doesn't support [kD, kH, kW, oC, iC], so we'll perform permutation
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
- ConvolutionUtils::getSizesAndIndexesConv3d(true, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD);
+ ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD);
dnnl::memory::dims strides = { sD, sH, sW };
dnnl::memory::dims padding = { pD, pH, pW };
@@ -80,8 +77,7 @@ static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const N
else
zType = dnnl::memory::data_type::s32;
-
- dnnl::memory::format_tag xFormat = dnnl::memory::format_tag::ncdhw;
+ dnnl::memory::format_tag xFormat = isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc;
dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::oidhw;
dnnl::memory::dims xDims = {bS, iC, iD, iH, iW};
@@ -93,22 +89,24 @@ static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const N
// input
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any);
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormat);
- x_user_md.data.format_kind = dnnl_blocked; // overrides format
- x_user_md.data.format_desc.blocking.strides[0] = input->stridesOf()[0];
- x_user_md.data.format_desc.blocking.strides[1] = input->stridesOf()[1];
- x_user_md.data.format_desc.blocking.strides[2] = input->stridesOf()[2];
- x_user_md.data.format_desc.blocking.strides[3] = input->stridesOf()[3];
- x_user_md.data.format_desc.blocking.strides[4] = input->stridesOf()[4];
+ if(input->ews() != 1 || input->ordering() != 'c') {
+ x_user_md.data.format_kind = dnnl_blocked; // overrides format
+ x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0);
+ x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(1);
+ x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(2);
+ x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(3);
+ x_user_md.data.format_desc.blocking.strides[4] = input->strideAt(4);
+ }
// weights
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormat);
w_user_md.data.format_kind = dnnl_blocked; // overrides format
- w_user_md.data.format_desc.blocking.strides[0] = weights->stridesOf()[0];
- w_user_md.data.format_desc.blocking.strides[1] = weights->stridesOf()[1];
- w_user_md.data.format_desc.blocking.strides[2] = weights->stridesOf()[2];
- w_user_md.data.format_desc.blocking.strides[3] = weights->stridesOf()[3];
- w_user_md.data.format_desc.blocking.strides[4] = weights->stridesOf()[4];
+ w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(3); // [kD, kH, kW, oC, iC] -> [oC, iC, kD, kH, kW]
+ w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(4);
+ w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(0);
+ w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(1);
+ w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(2);
// bias
dnnl::memory::desc b_mkl_md;
@@ -118,12 +116,14 @@ static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const N
// output
dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, zType, dnnl::memory::format_tag::any);
dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, zType, xFormat);
- z_user_md.data.format_kind = dnnl_blocked; // overrides format
- z_user_md.data.format_desc.blocking.strides[0] = output->stridesOf()[0];
- z_user_md.data.format_desc.blocking.strides[1] = output->stridesOf()[1];
- z_user_md.data.format_desc.blocking.strides[2] = output->stridesOf()[2];
- z_user_md.data.format_desc.blocking.strides[3] = output->stridesOf()[3];
- z_user_md.data.format_desc.blocking.strides[4] = output->stridesOf()[4];
+ if(output->ews() !=1 || output->ordering() != 'c') {
+ z_user_md.data.format_kind = dnnl_blocked; // overrides format
+ z_user_md.data.format_desc.blocking.strides[0] = output->strideAt(0);
+ z_user_md.data.format_desc.blocking.strides[1] = output->strideAt(1);
+ z_user_md.data.format_desc.blocking.strides[2] = output->strideAt(2);
+ z_user_md.data.format_desc.blocking.strides[3] = output->strideAt(3);
+ z_user_md.data.format_desc.blocking.strides[4] = output->strideAt(4);
+ }
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
@@ -184,16 +184,14 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights,
const int kD, const int kH, const int kW,
const int sD, const int sH, const int sW,
const int pD, const int pH, const int pW,
- const int dD, const int dH, const int dW) {
+ const int dD, const int dH, const int dW,
+ const bool isNCDHW) {
- // input and gradI [bS, iD, iH, iW, iC], mkl doesn't support ndhwc format
- // weights and gradW [oC, iC, kD, kH, kW] always, mkl doesn't support weights format [kD, kH, kW, oC, iC]
- // gradB [oC], may be nullptr
- // gradO [bS, oD, oH, oW, oC]
+ // weights and gradW [oC, iC, kD, kH, kW] always, mkl doesn't support [kD, kH, kW, oC, iC], so we'll perform permutation
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
- ConvolutionUtils::getSizesAndIndexesConv3d(true, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD);
+ ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD);
dnnl::memory::dims strides = { sD, sH, sW };
dnnl::memory::dims padding = { pD, pH, pW };
@@ -213,7 +211,7 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights,
// gradB type
dnnl::memory::data_type gradBType = gradB != nullptr ? (gradB->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16) : dnnl::memory::data_type::f32;
- dnnl::memory::format_tag xFormat = dnnl::memory::format_tag::ncdhw; // isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc;
+ dnnl::memory::format_tag xFormat = isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc;
dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::oidhw;
dnnl::memory::dims xDims = {bS, iC, iD, iH, iW};
@@ -225,52 +223,58 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights,
// input
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any);
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormat);
- x_user_md.data.format_kind = dnnl_blocked; // overrides format
- x_user_md.data.format_desc.blocking.strides[0] = input->stridesOf()[0];
- x_user_md.data.format_desc.blocking.strides[1] = input->stridesOf()[1];
- x_user_md.data.format_desc.blocking.strides[2] = input->stridesOf()[2];
- x_user_md.data.format_desc.blocking.strides[3] = input->stridesOf()[3];
- x_user_md.data.format_desc.blocking.strides[4] = input->stridesOf()[4];
+ if(input->ews() != 1 || input->ordering() != 'c') {
+ x_user_md.data.format_kind = dnnl_blocked; // overrides format
+ x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0);
+ x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(1);
+ x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(2);
+ x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(3);
+ x_user_md.data.format_desc.blocking.strides[4] = input->strideAt(4);
+ }
// weights
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormat);
w_user_md.data.format_kind = dnnl_blocked; // overrides format
- w_user_md.data.format_desc.blocking.strides[0] = weights->stridesOf()[0];
- w_user_md.data.format_desc.blocking.strides[1] = weights->stridesOf()[1];
- w_user_md.data.format_desc.blocking.strides[2] = weights->stridesOf()[2];
- w_user_md.data.format_desc.blocking.strides[3] = weights->stridesOf()[3];
- w_user_md.data.format_desc.blocking.strides[4] = weights->stridesOf()[4];
+ w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(3); // [kD, kH, kW, oC, iC] -> [oC, iC, kD, kH, kW]
+ w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(4);
+ w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(0);
+ w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(1);
+ w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(2);
// gradO
dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any);
dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xFormat);
- gradO_user_md.data.format_kind = dnnl_blocked; // overrides format
- gradO_user_md.data.format_desc.blocking.strides[0] = gradO->stridesOf()[0];
- gradO_user_md.data.format_desc.blocking.strides[1] = gradO->stridesOf()[1];
- gradO_user_md.data.format_desc.blocking.strides[2] = gradO->stridesOf()[2];
- gradO_user_md.data.format_desc.blocking.strides[3] = gradO->stridesOf()[3];
- gradO_user_md.data.format_desc.blocking.strides[4] = gradO->stridesOf()[4];
+ if(gradO->ews() != 1 || gradO->ordering() != 'c') {
+ gradO_user_md.data.format_kind = dnnl_blocked; // overrides format
+ gradO_user_md.data.format_desc.blocking.strides[0] = gradO->strideAt(0);
+ gradO_user_md.data.format_desc.blocking.strides[1] = gradO->strideAt(1);
+ gradO_user_md.data.format_desc.blocking.strides[2] = gradO->strideAt(2);
+ gradO_user_md.data.format_desc.blocking.strides[3] = gradO->strideAt(3);
+ gradO_user_md.data.format_desc.blocking.strides[4] = gradO->strideAt(4);
+ }
// gradI
dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any);
dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xFormat);
- gradI_user_md.data.format_kind = dnnl_blocked; // overrides format
- gradI_user_md.data.format_desc.blocking.strides[0] = gradI->stridesOf()[0];
- gradI_user_md.data.format_desc.blocking.strides[1] = gradI->stridesOf()[1];
- gradI_user_md.data.format_desc.blocking.strides[2] = gradI->stridesOf()[2];
- gradI_user_md.data.format_desc.blocking.strides[3] = gradI->stridesOf()[3];
- gradI_user_md.data.format_desc.blocking.strides[4] = gradI->stridesOf()[4];
+ if(gradI->ews() != 1 || gradI->ordering() != 'c') {
+ gradI_user_md.data.format_kind = dnnl_blocked; // overrides format
+ gradI_user_md.data.format_desc.blocking.strides[0] = gradI->strideAt(0);
+ gradI_user_md.data.format_desc.blocking.strides[1] = gradI->strideAt(1);
+ gradI_user_md.data.format_desc.blocking.strides[2] = gradI->strideAt(2);
+ gradI_user_md.data.format_desc.blocking.strides[3] = gradI->strideAt(3);
+ gradI_user_md.data.format_desc.blocking.strides[4] = gradI->strideAt(4);
+ }
// gradW
dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, gradWType, wFormat);
dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, gradWType, wFormat);
gradW_user_md.data.format_kind = dnnl_blocked; // overrides format
- gradW_user_md.data.format_desc.blocking.strides[0] = gradW->stridesOf()[0];
- gradW_user_md.data.format_desc.blocking.strides[1] = gradW->stridesOf()[1];
- gradW_user_md.data.format_desc.blocking.strides[2] = gradW->stridesOf()[2];
- gradW_user_md.data.format_desc.blocking.strides[3] = gradW->stridesOf()[3];
- gradW_user_md.data.format_desc.blocking.strides[4] = gradW->stridesOf()[4];
+ gradW_user_md.data.format_desc.blocking.strides[0] = gradW->strideAt(3); // [kD, kH, kW, oC, iC] -> [oC, iC, kD, kH, kW]
+ gradW_user_md.data.format_desc.blocking.strides[1] = gradW->strideAt(4);
+ gradW_user_md.data.format_desc.blocking.strides[2] = gradW->strideAt(0);
+ gradW_user_md.data.format_desc.blocking.strides[3] = gradW->strideAt(1);
+ gradW_user_md.data.format_desc.blocking.strides[4] = gradW->strideAt(2);
// gradB
dnnl::memory::desc gradB_mkl_md;
@@ -317,11 +321,15 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights,
// gradO
auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer());
- const bool gradOReorder = op_data_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc();
- auto gradO_mkl_mem = gradOReorder ? dnnl::memory(op_data_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem;
- if (gradOReorder)
- dnnl::reorder(gradO_user_mem, gradO_mkl_mem).execute(stream, gradO_user_mem, gradO_mkl_mem);
- args[DNNL_ARG_DIFF_DST] = gradO_mkl_mem;
+ const bool gradOReorderW = op_weights_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc();
+ const bool gradOReorderD = op_data_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc();
+ auto gradO_mkl_memW = gradOReorderW ? dnnl::memory(op_weights_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem;
+ auto gradO_mkl_memD = gradOReorderD ? dnnl::memory(op_data_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem;
+ if (gradOReorderW)
+ dnnl::reorder(gradO_user_mem, gradO_mkl_memW).execute(stream, gradO_user_mem, gradO_mkl_memW);
+ if (gradOReorderD)
+ dnnl::reorder(gradO_user_mem, gradO_mkl_memD).execute(stream, gradO_user_mem, gradO_mkl_memD);
+ args[DNNL_ARG_DIFF_DST] = gradO_mkl_memD;
// gradI
auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->getBuffer());
@@ -344,6 +352,9 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights,
// run backward data calculations
dnnl::deconvolution_backward_data(op_data_bp_prim_desc).execute(stream, args);
+ if(gradOReorderW || gradOReorderD)
+ args[DNNL_ARG_DIFF_DST] = gradO_mkl_memW;
+
// run backward weights calculations
dnnl::deconvolution_backward_weights(op_weights_bp_prim_desc).execute(stream, args);
@@ -400,32 +411,12 @@ PLATFORM_IMPL(deconv3d, ENGINE_CPU) {
ConvolutionUtils::calcPadding3D(pD, pH, pW, iD, iH, iW, oD, oH, oW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
}
- // mkl supports only [oC, iC, kD, kH, kW] format for weights
- weights = new NDArray(weights->permute({3,4,0,1,2})); // [kD, kH, kW, oC, iC] -> [oC, iC, kD, kH, kW]
-
- // mkl supports only NCDHW
- if(!isNCDHW) {
- input = new NDArray(input->permute({0,4,1,2,3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
- output = new NDArray(output->permute({0,4,1,2,3})); // [bS, oD, oH, oW, oC] -> [bS, oC, oD, oH, oW]
- }
-
- deconv3dMKLDNN(input, weights, bias, output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW);
-
- delete weights;
-
- if(!isNCDHW) {
- delete input;
- delete output;
- }
+ deconv3dMKLDNN(input, weights, bias, output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isNCDHW);
return Status::OK();
}
PLATFORM_CHECK(deconv3d, ENGINE_CPU) {
- // we don't want to use mkldnn if cpu doesn't support avx/avx2
- // if (::optimalLevel() < 2)
- // return false;
-
auto input = INPUT_VARIABLE(0);
auto weights = INPUT_VARIABLE(1);
auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr;
@@ -499,27 +490,7 @@ PLATFORM_IMPL(deconv3d_bp, ENGINE_CPU) {
if(isSameMode) // Note: we're intentionally swapping iH and oH, to calculated the padding for a"normal" conv (not deconv) forward pass
ConvolutionUtils::calcPadding3D(pD, pH, pW, iD, iH, iW, oD, oH, oW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
- // mkl supports only [oC, iC, kD, kH, kW] for weights
- weights = new NDArray(weights->permute({3,4,0,1,2})); // [kD, kH, kW, oC, iC] -> [oC, iC, kD, kH, kW]
- gradW = new NDArray(gradW->permute({3,4,0,1,2})); // [kD, kH, kW, oC, iC] -> [oC, iC, kD, kH, kW]
-
- // mkl supports NCDHW format only
- if(!isNCDHW) {
- input = new NDArray(input->permute({0,4,1,2,3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
- gradI = new NDArray(gradI->permute({0,4,1,2,3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
- gradO = new NDArray(gradO->permute({0,4,1,2,3})); // [bS, oD, oH, oW, oC] -> [bS, oC, oD, oH, oW]
- }
-
- deconv3dBackPropMKLDNN(input, weights, gradO, gradI, gradW, gradB, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW);
-
- delete weights;
- delete gradW;
-
- if(!isNCDHW) {
- delete input;
- delete gradI;
- delete gradO;
- }
+ deconv3dBackPropMKLDNN(input, weights, gradO, gradI, gradW, gradB, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isNCDHW);
return Status::OK();
}
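A minimal sketch (not part of the patch) of the gradO handling introduced above: the backward-data and backward-weights primitives may each choose a different diff_dst layout, so the single reorder was split into one per primitive descriptor, with DNNL_ARG_DIFF_DST swapped between the two execute calls. The helper name and signature below are illustrative.

#include <dnnl.hpp>

// Sketch only: reorder a user-side buffer into whatever layout a given primitive
// descriptor expects, or reuse it unchanged when the layouts already match.
dnnl::memory reorderIfNeeded(const dnnl::memory::desc& wanted, dnnl::memory& userMem,
                             const dnnl::engine& engine, dnnl::stream& stream) {
    if (wanted == userMem.get_desc())
        return userMem;                       // same layout, no copy needed
    dnnl::memory mklMem(wanted, engine);      // scratch buffer in the layout the primitive wants
    dnnl::reorder(userMem, mklMem).execute(stream, userMem, mklMem);
    return mklMem;
}
// usage (mirrors the deconv3d backprop hunk above):
//   args[DNNL_ARG_DIFF_DST] = reorderIfNeeded(op_data_bp_prim_desc.diff_dst_desc(), gradO_user_mem, engine, stream);
//   dnnl::deconvolution_backward_data(op_data_bp_prim_desc).execute(stream, args);
//   args[DNNL_ARG_DIFF_DST] = reorderIfNeeded(op_weights_bp_prim_desc.diff_dst_desc(), gradO_user_mem, engine, stream);
//   dnnl::deconvolution_backward_weights(op_weights_bp_prim_desc).execute(stream, args);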
diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/depthwiseConv2d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/depthwiseConv2d.cpp
index fc7a1e9e3..4da2c2cb0 100644
--- a/libnd4j/include/ops/declarable/platform/mkldnn/depthwiseConv2d.cpp
+++ b/libnd4j/include/ops/declarable/platform/mkldnn/depthwiseConv2d.cpp
@@ -86,7 +86,7 @@ static void depthwiseConv2dMKLDNN(const NDArray* input, const NDArray* weights,
else
zType = dnnl::memory::data_type::s32;
- dnnl::memory::format_tag xzFrmat = dnnl::memory::format_tag::nchw;
+ dnnl::memory::format_tag xzFrmat = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::goihw;
dnnl::memory::dims xDims = {bS, iC, iH, iW};
@@ -98,11 +98,13 @@ static void depthwiseConv2dMKLDNN(const NDArray* input, const NDArray* weights,
// input
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any);
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xzFrmat);
- x_user_md.data.format_kind = dnnl_blocked; // overrides format NHWC -> NCHW
- x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0);
- x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(isNCHW ? 1 : 3);
- x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(isNCHW ? 2 : 1);
- x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(isNCHW ? 3 : 2);
+ if(input->ews() != 1 || input->ordering() != 'c') {
+ x_user_md.data.format_kind = dnnl_blocked; // overrides format
+ x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0);
+ x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(1); // do permutation NHWC -> NCHW
+ x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(2);
+ x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(3);
+ }
// weights, make permute [kH, kW, iC, mC] -> [iC, mC, 1, kH, kW];
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
@@ -122,11 +124,13 @@ static void depthwiseConv2dMKLDNN(const NDArray* input, const NDArray* weights,
// output
dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, zType, dnnl::memory::format_tag::any);
dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, zType, xzFrmat);
- z_user_md.data.format_kind = dnnl_blocked; // overrides format
- z_user_md.data.format_desc.blocking.strides[0] = output->strideAt(0);
- z_user_md.data.format_desc.blocking.strides[1] = output->strideAt(isNCHW ? 1 : 3);
- z_user_md.data.format_desc.blocking.strides[2] = output->strideAt(isNCHW ? 2 : 1);
- z_user_md.data.format_desc.blocking.strides[3] = output->strideAt(isNCHW ? 3 : 2);
+ if(output->ews() != 1 || output->ordering() != 'c') {
+ z_user_md.data.format_kind = dnnl_blocked; // overrides format
+ z_user_md.data.format_desc.blocking.strides[0] = output->strideAt(0);
+ z_user_md.data.format_desc.blocking.strides[1] = output->strideAt(1); // do permutation NHWC -> NCHW
+ z_user_md.data.format_desc.blocking.strides[2] = output->strideAt(2);
+ z_user_md.data.format_desc.blocking.strides[3] = output->strideAt(3);
+ }
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
@@ -219,7 +223,7 @@ static void depthwiseConv2dNackPropMKLDNN(const NDArray* input, const NDArray* w
// gradB type
dnnl::memory::data_type gradBType = gradB != nullptr ? (gradB->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16) : dnnl::memory::data_type::f32;
- dnnl::memory::format_tag xFormat = dnnl::memory::format_tag::nchw; // isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
+ dnnl::memory::format_tag xzFrmat = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::goihw;
dnnl::memory::dims xDims = {bS, iC, iH, iW};
@@ -230,12 +234,14 @@ static void depthwiseConv2dNackPropMKLDNN(const NDArray* input, const NDArray* w
// input
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any);
- dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormat);
- x_user_md.data.format_kind = dnnl_blocked; // overrides format
- x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0);
- x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(isNCHW ? 1 : 3);
- x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(isNCHW ? 2 : 1);
- x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(isNCHW ? 3 : 2);
+ dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xzFrmat);
+ if(input->ews() != 1 || input->ordering() != 'c') {
+ x_user_md.data.format_kind = dnnl_blocked; // overrides format
+ x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0);
+ x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(1);
+ x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(2);
+ x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(3);
+ }
// weights, make permute [kH, kW, iC, mC] -> [iC, mC, 1, kH, kW];
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
@@ -249,21 +255,25 @@ static void depthwiseConv2dNackPropMKLDNN(const NDArray* input, const NDArray* w
// gradO
dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any);
- dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xFormat);
- gradO_user_md.data.format_kind = dnnl_blocked; // overrides format
- gradO_user_md.data.format_desc.blocking.strides[0] = gradO->strideAt(0);
- gradO_user_md.data.format_desc.blocking.strides[1] = gradO->strideAt(isNCHW ? 1 : 3);
- gradO_user_md.data.format_desc.blocking.strides[2] = gradO->strideAt(isNCHW ? 2 : 1);
- gradO_user_md.data.format_desc.blocking.strides[3] = gradO->strideAt(isNCHW ? 3 : 2);
+ dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xzFrmat);
+ if(gradO->ews() != 1 || gradO->ordering() != 'c') {
+ gradO_user_md.data.format_kind = dnnl_blocked; // overrides format
+ gradO_user_md.data.format_desc.blocking.strides[0] = gradO->strideAt(0);
+ gradO_user_md.data.format_desc.blocking.strides[1] = gradO->strideAt(1);
+ gradO_user_md.data.format_desc.blocking.strides[2] = gradO->strideAt(2);
+ gradO_user_md.data.format_desc.blocking.strides[3] = gradO->strideAt(3);
+ }
// gradI
dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any);
- dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xFormat);
- gradI_user_md.data.format_kind = dnnl_blocked; // overrides format
- gradI_user_md.data.format_desc.blocking.strides[0] = gradI->strideAt(0);
- gradI_user_md.data.format_desc.blocking.strides[1] = gradI->strideAt(isNCHW ? 1 : 3);
- gradI_user_md.data.format_desc.blocking.strides[2] = gradI->strideAt(isNCHW ? 2 : 1);
- gradI_user_md.data.format_desc.blocking.strides[3] = gradI->strideAt(isNCHW ? 3 : 2);
+ dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xzFrmat);
+ if(gradI->ews() != 1 || gradI->ordering() != 'c') {
+ gradI_user_md.data.format_kind = dnnl_blocked; // overrides format
+ gradI_user_md.data.format_desc.blocking.strides[0] = gradI->strideAt(0);
+ gradI_user_md.data.format_desc.blocking.strides[1] = gradI->strideAt(1);
+ gradI_user_md.data.format_desc.blocking.strides[2] = gradI->strideAt(2);
+ gradI_user_md.data.format_desc.blocking.strides[3] = gradI->strideAt(3);
+ }
// gradW, make permute [kH, kW, iC, mC] -> [iC, mC, 1, kH, kW];
dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, gradWType, dnnl::memory::format_tag::any);
@@ -319,11 +329,15 @@ static void depthwiseConv2dNackPropMKLDNN(const NDArray* input, const NDArray* w
// gradO
auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer());
- const bool gradOReorder = op_data_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc();
- auto gradO_mkl_mem = gradOReorder ? dnnl::memory(op_data_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem;
- if (gradOReorder)
- dnnl::reorder(gradO_user_mem, gradO_mkl_mem).execute(stream, gradO_user_mem, gradO_mkl_mem);
- args[DNNL_ARG_DIFF_DST] = gradO_mkl_mem;
+ const bool gradOReorderW = op_weights_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc();
+ const bool gradOReorderD = op_data_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc();
+ auto gradO_mkl_memW = gradOReorderW ? dnnl::memory(op_weights_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem;
+ auto gradO_mkl_memD = gradOReorderD ? dnnl::memory(op_data_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem;
+ if (gradOReorderW)
+ dnnl::reorder(gradO_user_mem, gradO_mkl_memW).execute(stream, gradO_user_mem, gradO_mkl_memW);
+ if (gradOReorderD)
+ dnnl::reorder(gradO_user_mem, gradO_mkl_memD).execute(stream, gradO_user_mem, gradO_mkl_memD);
+ args[DNNL_ARG_DIFF_DST] = gradO_mkl_memD;
// gradI
auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->getBuffer());
@@ -346,6 +360,9 @@ static void depthwiseConv2dNackPropMKLDNN(const NDArray* input, const NDArray* w
// run backward data calculations
dnnl::convolution_backward_data(op_data_bp_prim_desc).execute(stream, args);
+ if(gradOReorderW || gradOReorderD)
+ args[DNNL_ARG_DIFF_DST] = gradO_mkl_memW;
+
// run backward weights calculations
dnnl::convolution_backward_weights(op_weights_bp_prim_desc).execute(stream, args);
@@ -401,9 +418,6 @@ PLATFORM_IMPL(depthwise_conv2d, ENGINE_CPU) {
//////////////////////////////////////////////////////////////////////
PLATFORM_CHECK(depthwise_conv2d, ENGINE_CPU) {
- // we don't want to use mkldnn if cpu doesn't support avx/avx2
- if (::optimalLevel() < 2)
- return false;
auto input = INPUT_VARIABLE(0);
auto weights = INPUT_VARIABLE(1);
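A minimal sketch (not part of the patch) of the guard these hunks add around every stride override: a dense, C-ordered NDArray is already described exactly by its plain nchw/nhwc format tag, so the blocked-stride override is only needed for views and other non-contiguous arrays. The helper name is illustrative; ews() and ordering() are the NDArray accessors used in the hunks above.

// Sketch only: decide whether an explicit blocked-stride descriptor is required.
static bool needsStrideOverride(const NDArray* arr) {
    return arr->ews() != 1 || arr->ordering() != 'c';
}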
diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling2d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling2d.cpp
index 69aee8fad..3e7979f2f 100644
--- a/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling2d.cpp
+++ b/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling2d.cpp
@@ -17,6 +17,7 @@
//
// @author saudet
// @author raver119@gmail.com
+// @author Yurii Shyrma (iuriish@yahoo.com)
//
#include
@@ -33,105 +34,38 @@ namespace nd4j {
namespace ops {
namespace platforms {
+
//////////////////////////////////////////////////////////////////////////
PLATFORM_IMPL(maxpool2d, ENGINE_CPU) {
+
auto input = INPUT_VARIABLE(0);
-
- REQUIRE_TRUE(input->rankOf() == 4, 0, "Input should have rank of 4, but got %i instead",
- input->rankOf());
-
- // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode;
- auto argI = *(block.getIArguments());
auto output = OUTPUT_VARIABLE(0);
- const auto kH = INT_ARG(0);
- const auto kW = INT_ARG(1);
- const auto sH = INT_ARG(2);
- const auto sW = INT_ARG(3);
- int pH = INT_ARG(4);
- int pW = INT_ARG(5);
- const auto dH = INT_ARG(6);
- const auto dW = INT_ARG(7);
- const auto isSameMode = static_cast(INT_ARG(8));
+ REQUIRE_TRUE(input->rankOf() == 4, 0, "MAXPOOL2D MKLDNN OP: input array should have rank of 4, but got %i instead", input->rankOf());
- REQUIRE_TRUE(dH != 0 && dW != 0, 0, "AVGPOOL2D op: dilation must not be zero, but got instead {%i, %i}",
- dH, dW);
+ // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode;
+ const int kH = INT_ARG(0);
+ const int kW = INT_ARG(1);
+ const int sH = INT_ARG(2);
+ const int sW = INT_ARG(3);
+ int pH = INT_ARG(4);
+ int pW = INT_ARG(5);
+ const int dH = INT_ARG(6);
+ const int dW = INT_ARG(7);
+ const int paddingMode = INT_ARG(8);
+ // const int extraParam0 = INT_ARG(9);
+ const int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 1-NHWC, 0-NCHW
- int oH = 0;
- int oW = 0;
+ REQUIRE_TRUE(dH != 0 && dW != 0, 0, "MAXPOOL2D MKLDNN op: dilation must not be zero, but got instead {%i, %i}", dH, dW);
- int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC
+ int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
+ int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
+ ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
- const int iH = static_cast(isNCHW ? input->sizeAt(2) : input->sizeAt(1));
- const int iW = static_cast(isNCHW ? input->sizeAt(3) : input->sizeAt(2));
-
- if (!isNCHW) {
- input = new NDArray(
- input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
- output = new NDArray(
- output->permute({0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW]
- }
-
- ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode);
-
- if (isSameMode)
+ if (paddingMode)
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
- const int bS = input->sizeAt(0);
- const int iC = input->sizeAt(1);
- const int oC = output->sizeAt(1);
-
- auto poolingMode = PoolingType::MAX_POOL;
- int extraParam0 = 1;
-
- dnnl_memory_desc_t empty;
- dnnl::memory::desc pool_src_md(empty), pool_dst_md(empty);
- dnnl::memory::desc user_src_md(empty), user_dst_md(empty);
- dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
- dnnl::algorithm algorithm;
-
- mkldnnUtils::getMKLDNNMemoryDescPool2d(kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0,
- true,
- bS, iC, iH, iW, oC, oH, oW, input, nullptr, output,
- algorithm,
- &pool_src_md, nullptr, &pool_dst_md, &user_src_md, nullptr,
- &user_dst_md,
- pool_strides, pool_kernel, pool_padding, pool_padding_r);
-
- auto pool_desc = pooling_forward::desc(prop_kind::forward_inference, algorithm, pool_src_md,
- pool_dst_md,
- pool_strides, pool_kernel, pool_padding, pool_padding_r);
-
- auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
- auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
- auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer());
- auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer());
-
- auto pool_src_memory = user_src_memory;
- dnnl::stream stream(engine);
- if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) {
- pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), engine);
- reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory);
- }
-
- auto pool_dst_memory = user_dst_memory;
- if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
- pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine);
- }
-
- pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory},
- {DNNL_ARG_DST, pool_dst_memory}});
-
- if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
- reorder(pool_dst_memory, user_dst_memory).execute(stream, pool_dst_memory, user_dst_memory);
- }
-
- stream.wait();
-
- if (!isNCHW) {
- delete input;
- delete output;
- }
+ mkldnnUtils::poolingMKLDNN(input, output, 0,kH,kW, 0,sH,sW, 0,pH,pW, isNCHW, algorithm::pooling_max);
return Status::OK();
}
@@ -159,117 +93,24 @@ PLATFORM_IMPL(maxpool2d_bp, ENGINE_CPU) {
int pW = INT_ARG(5); // paddings width
int dH = INT_ARG(6); // dilations height
int dW = INT_ARG(7); // dilations width
- int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME
- int extraParam0 = INT_ARG(9);
- int isNCHW =
- block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC
+ int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME
+ // int extraParam0 = INT_ARG(9);
+ int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC
- REQUIRE_TRUE(input->rankOf() == 4, 0,
- "AVGPOOL2D_BP op: input should have rank of 4, but got %i instead", input->rankOf());
- REQUIRE_TRUE(dH != 0 && dW != 0, 0,
- "AVGPOOL2D_BP op: dilation must not be zero, but got instead {%i, %i}", dH, dW);
+ REQUIRE_TRUE(input->rankOf() == 4, 0, "MAXPOOL2D_BP MKLDNN op: input should have rank of 4, but got %i instead", input->rankOf());
+ REQUIRE_TRUE(dH != 0 && dW != 0, 0, "MAXPOOL2D_BP MKLDNN op: dilation must not be zero, but got instead {%i, %i}", dH, dW);
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
- ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC,
- indIiH, indWiC, indWoC, indWkH, indOoH);
+ ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
- std::string expectedGradOShape = ShapeUtils::shapeAsString(
- ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, oH, oW, 0, indIOioC, indIiH, indIiH + 1}));
- std::string expectedGradIShape = ShapeUtils::shapeAsString(
- ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, iH, iW, 0, indIOioC, indIiH, indIiH + 1}));
- REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0,
- "AVGPOOL2D_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !",
- expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
- REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0,
- "AVGPOOL2D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !",
- expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str());
+ std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, oH, oW, 0, indIOioC, indIiH, indIiH + 1});
+ REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "MAXPOOL2D_BP MKLDNN op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
-
- if (!isNCHW) {
- input = new NDArray(input->permute(
- {0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
- gradI = new NDArray(gradI->permute(
- {0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
- gradO = new NDArray(gradO->permute(
- {0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW]
- }
-
- if (isSameMode) // SAME
+ if (paddingMode) // SAME
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
- auto poolingMode = PoolingType::MAX_POOL;
-
- dnnl_memory_desc_t empty;
- dnnl::memory::desc pool_src_md(empty), pool_diff_src_md(empty), pool_dst_md(empty);
- dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_dst_md(empty);
- dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
- dnnl::algorithm algorithm;
-
- mkldnnUtils::getMKLDNNMemoryDescPool2d(kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0,
- true,
- bS, iC, iH, iW, oC, oH, oW, input, gradI, gradO, algorithm,
- &pool_src_md, &pool_diff_src_md, &pool_dst_md, &user_src_md,
- &user_diff_src_md, &user_dst_md,
- pool_strides, pool_kernel, pool_padding, pool_padding_r);
-
- // input is sometimes null, so we can't rely on pool_src_md being valid
- auto pool_desc = pooling_forward::desc(prop_kind::forward, algorithm,
- input->buffer() != nullptr ? pool_src_md : pool_diff_src_md,
- pool_dst_md, pool_strides, pool_kernel, pool_padding,
- pool_padding_r);
-
- auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
- dnnl::stream stream(engine);
- auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
-
- auto poolB_desc = pooling_backward::desc(algorithm, pool_diff_src_md, pool_dst_md,
- pool_strides, pool_kernel, pool_padding, pool_padding_r);
-
- auto poolB_prim_desc = pooling_backward::primitive_desc(poolB_desc, engine, pool_prim_desc);
- auto userB_src_memory = dnnl::memory(user_src_md, engine, gradI->buffer());
- auto userB_dst_memory = dnnl::memory(user_dst_md, engine, gradO->buffer());
-
- auto poolB_src_memory = userB_src_memory;
- if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
- poolB_src_memory = dnnl::memory(poolB_prim_desc.diff_src_desc(), engine);
- }
-
- auto poolB_dst_memory = userB_dst_memory;
- if (poolB_prim_desc.diff_dst_desc() != userB_dst_memory.get_desc()) {
- poolB_dst_memory = dnnl::memory(poolB_prim_desc.diff_dst_desc(), engine);
- reorder(userB_dst_memory, poolB_dst_memory).execute(stream, userB_dst_memory, poolB_dst_memory);
- }
-
- auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer());
- auto pool_src_memory = user_src_memory;
- if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) {
- pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), engine);
- reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory);
- }
-
- auto pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine);
- auto pool_workspace_memory = dnnl::memory(pool_prim_desc.workspace_desc(), engine);
-
- pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory},
- {DNNL_ARG_DST, pool_dst_memory},
- {DNNL_ARG_WORKSPACE, pool_workspace_memory}});
- // probably wrong, fix that
- pooling_backward(poolB_prim_desc).execute(stream, {{DNNL_ARG_DIFF_DST, poolB_dst_memory},
- {DNNL_ARG_WORKSPACE, pool_workspace_memory},
- {DNNL_ARG_DIFF_SRC, poolB_src_memory}});
-
- if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
- reorder(poolB_src_memory, userB_src_memory).execute(stream, poolB_src_memory, userB_src_memory);
- }
-
- stream.wait();
-
- if (!isNCHW) {
- delete input;
- delete gradI;
- delete gradO;
- }
+ mkldnnUtils::poolingBpMKLDNN(input, gradO, gradI, 0,kH,kW, 0,sH,sW, 0,pH,pW, isNCHW, algorithm::pooling_max);
return Status::OK();
}
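A minimal sketch (not part of the patch) of the workspace handshake the removed inline maxpool2d_bp code performed and that the new shared poolingBpMKLDNN helper is assumed to preserve: oneDNN's max-pooling backward needs the arg-max workspace produced by a forward run. All names below are illustrative.

#include <dnnl.hpp>

// Sketch only: run forward to capture the max positions, then backward reusing them.
void maxpoolBpSketch(const dnnl::pooling_forward::primitive_desc& fwdPd,
                     const dnnl::pooling_backward::primitive_desc& bwdPd,
                     dnnl::memory& src, dnnl::memory& dst,
                     dnnl::memory& diffDst, dnnl::memory& diffSrc,
                     const dnnl::engine& engine, dnnl::stream& stream) {
    dnnl::memory workspace(fwdPd.workspace_desc(), engine);   // stores arg-max indices
    dnnl::pooling_forward(fwdPd).execute(stream, {{DNNL_ARG_SRC, src},
                                                  {DNNL_ARG_DST, dst},
                                                  {DNNL_ARG_WORKSPACE, workspace}});
    dnnl::pooling_backward(bwdPd).execute(stream, {{DNNL_ARG_DIFF_DST, diffDst},
                                                   {DNNL_ARG_WORKSPACE, workspace},
                                                   {DNNL_ARG_DIFF_SRC, diffSrc}});
    stream.wait();
}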
diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling3d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling3d.cpp
index a37422c55..7f6e95418 100644
--- a/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling3d.cpp
+++ b/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling3d.cpp
@@ -16,6 +16,7 @@
//
// @author raver119@gmail.com
+// @author Yurii Shyrma (iuriish@yahoo.com)
//
#include
@@ -34,10 +35,9 @@ namespace platforms {
//////////////////////////////////////////////////////////////////////////
PLATFORM_IMPL(maxpool3dnew, ENGINE_CPU) {
- auto input = INPUT_VARIABLE(
- 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
- auto output = OUTPUT_VARIABLE(
- 0); // [bS, oD, oH, oW, iC] (NDHWC) or [bS, iC, oD, oH, oW] (NCDHW)
+
+ auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
+ auto output = OUTPUT_VARIABLE(0); // [bS, oD, oH, oW, iC] (NDHWC) or [bS, iC, oD, oH, oW] (NCDHW)
int kD = INT_ARG(0); // filter(kernel) depth
int kH = INT_ARG(1); // filter(kernel) height
@@ -51,95 +51,24 @@ PLATFORM_IMPL(maxpool3dnew, ENGINE_CPU) {
int dD = INT_ARG(9); // dilations depth
int dH = INT_ARG(10); // dilations height
int dW = INT_ARG(11); // dilations width
- int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID
+ int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID
// int extraParam0 = INT_ARG(13); // unnecessary for max case, required only for avg and pnorm cases
- int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 1-NDHWC, 0-NCDHW
+ int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 1-NDHWC, 0-NCDHW
- REQUIRE_TRUE(input->rankOf() == 5, 0,
- "MAXPOOL3DNEW OP: rank of input array must be equal to 5, but got %i instead !",
- input->rankOf());
- REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0,
- "MAXPOOL3DNEW op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW);
+ REQUIRE_TRUE(input->rankOf() == 5, 0, "MAXPOOL3DNEW MKLDNN OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf());
+ REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "MAXPOOL3DNEW MKLDNN op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW);
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
- ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW,
- indIOioC, indIOioD, indWiC, indWoC, indWkD);
+ ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
- std::string expectedOutputShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx(
- {bS, iC, oD, oH, oW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2}));
- REQUIRE_TRUE(expectedOutputShape == ShapeUtils::shapeAsString(output), 0,
- "MAXPOOL3D op: wrong shape of output array, expected is %s, but got %s instead !",
- expectedOutputShape.c_str(), ShapeUtils::shapeAsString(output).c_str());
- // REQUIRE_TRUE(iD >= kD && iH >= kH && iW >= kW, 0, "MAXPOOL3D OP: the input depth/height/width must be greater or equal to kernel(filter) depth/height/width, but got [%i, %i, %i] and [%i, %i, %i] correspondingly !", iD,iH,iW, kD,kH,kW);
- // REQUIRE_TRUE(kD/2 >= pD && kH/2 >= pH && kW/2 >= pW, 0, "MAXPOOL3D OP: pad depth/height/width must not be greater than half of kernel depth/height/width, but got [%i, %i, %i] and [%i, %i, %i] correspondingly !", pD,pH,pW, kD,kH,kW);
+ if(paddingMode) // SAME
+ ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
- if (!isNCDHW) {
- input = new NDArray(
- input->permute({0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
- output = new NDArray(
- output->permute({0, 4, 1, 2, 3})); // [bS, oD, oH, oW, iC] -> [bS, iC, oD, oH, oW]
- }
-
- if (isSameMode) // SAME
- ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH,
- dW);
-
-
- auto poolingMode = PoolingType::MAX_POOL;
- auto extraParam0 = 1;
-
- dnnl_memory_desc_t empty;
- dnnl::memory::desc pool_src_md(empty), pool_dst_md(empty);
- dnnl::memory::desc user_src_md(empty), user_dst_md(empty);
- dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
- dnnl::algorithm algorithm;
-
- mkldnnUtils::getMKLDNNMemoryDescPool3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode,
- extraParam0, true,
- bS, iC, iD, iH, iW, oC, oD, oH, oW, input, nullptr, output,
- algorithm,
- &pool_src_md, nullptr, &pool_dst_md, &user_src_md, nullptr,
- &user_dst_md,
- pool_strides, pool_kernel, pool_padding, pool_padding_r);
-
- auto pool_desc = pooling_forward::desc(prop_kind::forward_inference, algorithm, pool_src_md,
- pool_dst_md, pool_strides, pool_kernel, pool_padding,
- pool_padding_r);
-
- auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
- dnnl::stream stream(engine);
- auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
- auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer());
- auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer());
-
- auto pool_src_memory = user_src_memory;
- if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) {
- pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), engine);
- reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory);
- }
-
- auto pool_dst_memory = user_dst_memory;
- if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
- pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine);
- }
-
- pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory},
- {DNNL_ARG_DST, pool_dst_memory}});
-
- if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
- reorder(pool_dst_memory, user_dst_memory).execute(stream, pool_dst_memory, user_dst_memory);
- }
-
- stream.wait();
-
-
- if (!isNCDHW) {
- delete input;
- delete output;
- }
+ mkldnnUtils::poolingMKLDNN(input, output, kD,kH,kW, sD,sH,sW, pD,pH,pW, isNCDHW, algorithm::pooling_max);
return Status::OK();
+
}
//////////////////////////////////////////////////////////////////////////
@@ -152,6 +81,7 @@ PLATFORM_CHECK(maxpool3dnew, ENGINE_CPU) {
//////////////////////////////////////////////////////////////////////////
PLATFORM_IMPL(maxpool3dnew_bp, ENGINE_CPU) {
+
auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
auto gradO = INPUT_VARIABLE(1); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon
@@ -162,127 +92,30 @@ PLATFORM_IMPL(maxpool3dnew_bp, ENGINE_CPU) {
const int sD = INT_ARG(3); // strides depth
const int sH = INT_ARG(4); // strides height
const int sW = INT_ARG(5); // strides width
- int pD = INT_ARG(6); // paddings depth
- int pH = INT_ARG(7); // paddings height
- int pW = INT_ARG(8); // paddings width
+ int pD = INT_ARG(6); // paddings depth
+ int pH = INT_ARG(7); // paddings height
+ int pW = INT_ARG(8); // paddings width
const int dD = INT_ARG(9); // dilations depth
const int dH = INT_ARG(10); // dilations height
const int dW = INT_ARG(11); // dilations width
- const int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID
+ const int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID
// int extraParam0 = INT_ARG(13); // unnecessary for max case, required only for avg and pnorm cases
- int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 1-NDHWC, 0-NCDHW
+ int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 1-NDHWC, 0-NCDHW
- REQUIRE_TRUE(input->rankOf() == 5, 0,
- "MAXPOOL3D_BP op: input should have rank of 5, but got %i instead", input->rankOf());
- REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0,
- "MAXPOOL3DNEW op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW);
+ REQUIRE_TRUE(input->rankOf() == 5, 0, "MAXPOOL3DNEW_BP MKLDNN op: input should have rank of 5, but got %i instead", input->rankOf());
+ REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "MAXPOOL3DNEW_BP MKLDNN op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW);
- int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
+ int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
- ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW,
- indIOioC, indIOioD, indWiC, indWoC, indWkD);
+ ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
- std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx(
- {bS, iC, oD, oH, oW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2}));
- std::string expectedGradIShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx(
- {bS, iC, iD, iH, iW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2}));
- REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0,
- "MAXPOOL3D_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !",
- expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
- REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0,
- "MAXPOOL3D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !",
- expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str());
+ std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2});
+ REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "MAXPOOL3DNEW_BP MKLDNN op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
- if (!isNCDHW) {
- input = new NDArray(input->permute(
- {0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
- gradI = new NDArray(gradI->permute(
- {0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
- gradO = new NDArray(gradO->permute(
- {0, 4, 1, 2, 3})); // [bS, oD, oH, oW, iC] -> [bS, iC, oD, oH, oW]
- }
+ if(paddingMode) // SAME
+ ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
- if (isSameMode) // SAME
- ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH,
- dW);
-
-
- auto poolingMode = PoolingType::MAX_POOL;
- auto extraParam0 = 1;
-
- dnnl_memory_desc_t empty;
- dnnl::memory::desc pool_src_md(empty), pool_diff_src_md(empty), pool_dst_md(empty);
- dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_dst_md(empty);
- dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
- dnnl::algorithm algorithm;
-
- mkldnnUtils::getMKLDNNMemoryDescPool3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode,
- extraParam0, true,
- bS, iC, iD, iH, iW, oC, oD, oH, oW, input, gradI, gradO,
- algorithm,
- &pool_src_md, &pool_diff_src_md, &pool_dst_md, &user_src_md,
- &user_diff_src_md, &user_dst_md,
- pool_strides, pool_kernel, pool_padding, pool_padding_r);
-
- // input is sometimes null, so we can't rely on pool_src_md being valid
- if (input->buffer() == nullptr) {
- pool_src_md = pool_diff_src_md;
- user_src_md = user_diff_src_md;
- }
- auto pool_desc = pooling_forward::desc(prop_kind::forward, algorithm, pool_src_md, pool_dst_md, pool_strides, pool_kernel, pool_padding, pool_padding_r);
-
- auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
- dnnl::stream stream(engine);
- auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
-
- auto poolB_desc = pooling_backward::desc(algorithm, pool_diff_src_md, pool_dst_md, pool_strides, pool_kernel, pool_padding, pool_padding_r);
-
- auto poolB_prim_desc = pooling_backward::primitive_desc(poolB_desc, engine, pool_prim_desc);
- auto userB_src_memory = dnnl::memory(user_diff_src_md, engine, gradI->buffer());
- auto userB_dst_memory = dnnl::memory(user_dst_md, engine, gradO->buffer());
-
- auto poolB_src_memory = userB_src_memory;
- if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
- poolB_src_memory = dnnl::memory(poolB_prim_desc.diff_src_desc(), engine);
- }
-
- auto poolB_dst_memory = userB_dst_memory;
- if (poolB_prim_desc.diff_dst_desc() != userB_dst_memory.get_desc()) {
- poolB_dst_memory = dnnl::memory(poolB_prim_desc.diff_dst_desc(), engine);
- reorder(userB_dst_memory, poolB_dst_memory).execute(stream, userB_dst_memory, poolB_dst_memory);
- }
-
-
- auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer());
-
- auto pool_src_memory = user_src_memory;
- if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) {
- pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), engine);
- reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory);
- }
-
- auto pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine);
- auto pool_workspace_memory = dnnl::memory(pool_prim_desc.workspace_desc(), engine);
-
- pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory},
- {DNNL_ARG_DST, pool_dst_memory},
- {DNNL_ARG_WORKSPACE, pool_workspace_memory}});
- pooling_backward(poolB_prim_desc).execute(stream, {{DNNL_ARG_DIFF_DST, poolB_dst_memory},
- {DNNL_ARG_WORKSPACE, pool_workspace_memory},
- {DNNL_ARG_DIFF_SRC, poolB_src_memory}});
-
-
- if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
- reorder(poolB_src_memory, userB_src_memory).execute(stream, poolB_src_memory, userB_src_memory);
- }
-
- stream.wait();
-
- if (!isNCDHW) {
- delete input;
- delete gradI;
- delete gradO;
- }
+ mkldnnUtils::poolingBpMKLDNN(input, gradO, gradI, kD,kH,kW, sD,sH,sW, pD,pH,pW, isNCDHW, algorithm::pooling_max);
return Status::OK();
}
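A minimal sketch (not part of the patch) of where the padding_r expressions in the new mkldnnUtils::poolingMKLDNN / poolingBpMKLDNN below come from: oneDNN pooling takes asymmetric left/right padding and, with no dilation, the output extent obeys o = (i + p_l + p_r - k) / s + 1, so the right padding that makes a given o exact is p_r = (o - 1) * s - i + k - p_l.

// Sketch only: the same arithmetic as padding_r = { (oH - 1) * sH - iH + kH - pH, ... } below.
int rightPadding(int o, int i, int k, int s, int pl) {
    return (o - 1) * s - i + k - pl;
}
// e.g. i = 7, k = 3, s = 2, pl = 1 gives o = 4 and rightPadding(4, 7, 3, 2, 1) = 1,
// and indeed (7 + 1 + 1 - 3) / 2 + 1 == 4.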
diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.cpp
index 0b81de76d..02bba4300 100644
--- a/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.cpp
+++ b/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.cpp
@@ -16,9 +16,11 @@
//
// @author saudet
+// @author Yurii Shyrma (iuriish@yahoo.com)
//
#include
+#include
#include "mkldnnUtils.h"
using namespace dnnl;
@@ -26,6 +28,314 @@ using namespace dnnl;
namespace nd4j {
namespace mkldnnUtils {
+//////////////////////////////////////////////////////////////////////
+void poolingMKLDNN(const NDArray *input, NDArray *output,
+ const int kD, const int kH, const int kW,
+ const int sD, const int sH, const int sW,
+ const int pD, const int pH, const int pW,
+ const int isNCHW, const dnnl::algorithm mode) {
+
+ // unfortunately mkl dnn doesn't support any format (dnnl::memory::format_tag::any) for input
+ const int rank = input->rankOf();
+
+ int bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;
+ dnnl::memory::dims strides, kernel, padding, padding_r, xDims, zDims;
+ dnnl::memory::format_tag xzFrmat;
+
+ const auto type = dnnl::memory::data_type::f32;
+
+ if(rank == 4) { // 2d
+
+ ops::ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
+
+ strides = { sH, sW };
+ kernel = { kH, kW };
+ padding = { pH, pW };
+ padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pW };
+ xDims = {bS, iC, iH, iW};
+ zDims = {bS, oC, oH, oW};
+
+ xzFrmat = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
+ }
+ else { // 3d
+
+ ops::ConvolutionUtils::getSizesAndIndexesConv3d(isNCHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH);
+
+ strides = { sD, sH, sW };
+ kernel = { kD, kH, kW };
+ padding = { pD, pH, pW };
+ padding_r = { (oD - 1) * sD - iD + kD - pD, (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pW };
+ xDims = {bS, iC, iD, iH, iW};
+ zDims = {bS, oC, oD, oH, oW};
+
+ xzFrmat = isNCHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc;
+ }
+
+ // memory descriptors for arrays
+
+ // input
+ dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, type, xzFrmat);
+ dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFrmat);
+ if(input->ews() != 1 || input->ordering() != 'c') {
+ x_user_md.data.format_kind = dnnl_blocked; // overrides format
+ x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0);
+ x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(isNCHW ? 1 :-1);
+ x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(isNCHW ? 2 : 1);
+ x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(isNCHW ? 3 : 2);
+ if(rank == 5)
+ x_user_md.data.format_desc.blocking.strides[4] = input->strideAt(isNCHW ? 4 : 3);
+ }
+
+ // output
+ dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any);
+ dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, type, xzFrmat);
+ if(output->ews() != 1 || output->ordering() != 'c') {
+ z_user_md.data.format_kind = dnnl_blocked; // overrides format
+ z_user_md.data.format_desc.blocking.strides[0] = output->strideAt(0);
+ z_user_md.data.format_desc.blocking.strides[1] = output->strideAt(isNCHW ? 1 :-1);
+ z_user_md.data.format_desc.blocking.strides[2] = output->strideAt(isNCHW ? 2 : 1);
+ z_user_md.data.format_desc.blocking.strides[3] = output->strideAt(isNCHW ? 3 : 2);
+ if(rank == 5)
+ z_user_md.data.format_desc.blocking.strides[4] = output->strideAt(isNCHW ? 4 : 3);
+ }
+
+ auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
+
+ // operation primitive description
+ dnnl::pooling_forward::desc op_desc(dnnl::prop_kind::forward_inference, mode, x_mkl_md, z_mkl_md, strides, kernel, padding, padding_r);
+ dnnl::pooling_forward::primitive_desc op_prim_desc(op_desc, engine);
+
+ // arguments (memory buffers) necessary for calculations
+ std::unordered_map<int, dnnl::memory> args;
+
+ dnnl::stream stream(engine);
+
+ // provide memory buffers and check whether reorder is required
+
+ // input
+ auto x_user_mem = dnnl::memory(x_user_md, engine, input->getBuffer());
+ const bool xReorder = op_prim_desc.src_desc() != x_user_mem.get_desc();
+ auto x_mkl_mem = xReorder ? dnnl::memory(op_prim_desc.src_desc(), engine) : x_user_mem;
+ if (xReorder)
+ dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
+ args[DNNL_ARG_SRC] = x_mkl_mem;
+
+ // output
+ auto z_user_mem = dnnl::memory(z_user_md, engine, output->getBuffer());
+ const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc();
+ auto z_mkl_mem = zReorder ? dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem;
+ args[DNNL_ARG_DST] = z_mkl_mem;
+
+ // run calculations
+ dnnl::pooling_forward(op_prim_desc).execute(stream, args);
+
+ // reorder outputs if necessary
+ if (zReorder)
+ dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem);
+
+ stream.wait();
+}
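The padding_r values above are the trailing ("right") paddings oneDNN needs so that its own output-size rule reproduces the oH/oW that getSizesAndIndexesConv2d/3d already computed: oH = (iH + pH + padding_r - kH)/sH + 1, hence padding_r = (oH - 1)*sH - iH + kH - pH. A quick standalone check of that arithmetic (illustrative values only):

    #include <cstdio>

    int main() {
        const int iH = 28, kH = 5, sH = 1, pH = 0;
        const int oH = (iH + 2 * pH - kH) / sH + 1;            // 24 for this VALID-style setup
        const int padding_r = (oH - 1) * sH - iH + kH - pH;    // 0 here; positive when oH is rounded up (SAME-style shapes)
        std::printf("oH=%d, padding_r=%d\n", oH, padding_r);
        return 0;
    }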
+
+//////////////////////////////////////////////////////////////////////
+void poolingBpMKLDNN(const NDArray *input, const NDArray *gradO, NDArray *gradI,
+ const int kD, const int kH, const int kW,
+ const int sD, const int sH, const int sW,
+ const int pD, const int pH, const int pW,
+ const int isNCHW, const dnnl::algorithm mode) {
+
+ // unfortunately mkl dnn doesn't accept dnnl::memory::format_tag::any for the input, so a concrete format has to be chosen
+
+ const int rank = input->rankOf();
+
+ int bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;
+ dnnl::memory::dims strides, kernel, padding, padding_r, xDims, zDims;
+ dnnl::memory::format_tag xzFrmat;
+
+ const auto type = dnnl::memory::data_type::f32;
+
+ if(rank == 4) { // 2d
+
+ ops::ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
+
+ strides = { sH, sW };
+ kernel = { kH, kW };
+ padding = { pH, pW };
+ padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pW };
+ xDims = {bS, iC, iH, iW};
+ zDims = {bS, oC, oH, oW};
+
+ xzFrmat = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
+ }
+ else { // 3d
+
+ ops::ConvolutionUtils::getSizesAndIndexesConv3d(isNCHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH);
+
+ strides = { sD, sH, sW };
+ kernel = { kD, kH, kW };
+ padding = { pD, pH, pW };
+ padding_r = { (oD - 1) * sD - iD + kD - pD, (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pW };
+ xDims = {bS, iC, iD, iH, iW};
+ zDims = {bS, oC, oD, oH, oW};
+
+ xzFrmat = isNCHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc;
+ }
+
+ // memory descriptors for arrays
+
+ // input
+ dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, type, xzFrmat);
+ dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFrmat);
+ if(input->ews() != 1 || input->ordering() != 'c') {
+ x_user_md.data.format_kind = dnnl_blocked; // overrides format
+ x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0);
+ x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(isNCHW ? 1 :-1);
+ x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(isNCHW ? 2 : 1);
+ x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(isNCHW ? 3 : 2);
+ if(rank == 5)
+ x_user_md.data.format_desc.blocking.strides[4] = input->strideAt(isNCHW ? 4 : 3);
+ }
+
+ // gradO
+ dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any);
+ dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, type, xzFrmat);
+ if(gradO->ews() != 1 || gradO->ordering() != 'c') {
+ gradO_user_md.data.format_kind = dnnl_blocked; // overrides format
+ gradO_user_md.data.format_desc.blocking.strides[0] = gradO->strideAt(0);
+ gradO_user_md.data.format_desc.blocking.strides[1] = gradO->strideAt(isNCHW ? 1 :-1);
+ gradO_user_md.data.format_desc.blocking.strides[2] = gradO->strideAt(isNCHW ? 2 : 1);
+ gradO_user_md.data.format_desc.blocking.strides[3] = gradO->strideAt(isNCHW ? 3 : 2);
+ if(rank == 5)
+ gradO_user_md.data.format_desc.blocking.strides[4] = gradO->strideAt(isNCHW ? 4 : 3);
+ }
+
+ // gradI
+ dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any);
+ dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, type, xzFrmat);
+ if(gradI->ews() != 1 || gradI->ordering() != 'c') {
+ gradI_user_md.data.format_kind = dnnl_blocked; // overrides format
+ gradI_user_md.data.format_desc.blocking.strides[0] = gradI->strideAt(0);
+ gradI_user_md.data.format_desc.blocking.strides[1] = gradI->strideAt(isNCHW ? 1 :-1);
+ gradI_user_md.data.format_desc.blocking.strides[2] = gradI->strideAt(isNCHW ? 2 : 1);
+ gradI_user_md.data.format_desc.blocking.strides[3] = gradI->strideAt(isNCHW ? 3 : 2);
+ if(rank == 5)
+ gradI_user_md.data.format_desc.blocking.strides[4] = gradI->strideAt(isNCHW ? 4 : 3);
+ }
+
+ auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
+ dnnl::stream stream(engine);
+
+ // forward primitive description
+ dnnl::pooling_forward::desc op_ff_desc(dnnl::prop_kind::forward, mode, x_mkl_md, gradO_mkl_md, strides, kernel, padding, padding_r);
+ dnnl::pooling_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine);
+
+ // backward primitive description
+ dnnl::pooling_backward::desc op_bp_desc(mode, gradI_mkl_md, gradO_mkl_md, strides, kernel, padding, padding_r);
+ dnnl::pooling_backward::primitive_desc op_bp_prim_desc(op_bp_desc, engine, op_ff_prim_desc);
+
+ // arguments (memory buffers) necessary for calculations
+ std::unordered_map<int, dnnl::memory> args;
+
+ // gradO
+ auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer());
+ const bool gradOReorder = op_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc();
+ auto gradO_mkl_mem = gradOReorder ? dnnl::memory(op_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem;
+ if (gradOReorder)
+ dnnl::reorder(gradO_user_mem, gradO_mkl_mem).execute(stream, gradO_user_mem, gradO_mkl_mem);
+ args[DNNL_ARG_DIFF_DST] = gradO_mkl_mem;
+
+ // gradI
+ auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->getBuffer());
+ const bool gradIReorder = op_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc();
+ auto gradI_mkl_mem = gradIReorder ? dnnl::memory(op_bp_prim_desc.diff_src_desc(), engine) : gradI_user_mem;
+ args[DNNL_ARG_DIFF_SRC] = gradI_mkl_mem;
+
+ if(mode == algorithm::pooling_max) {
+
+ // input
+ auto x_user_mem = dnnl::memory(x_user_md, engine, input->getBuffer());
+ const bool xReorder = op_ff_prim_desc.src_desc() != x_user_mem.get_desc();
+ auto x_mkl_mem = xReorder ? dnnl::memory(op_ff_prim_desc.src_desc(), engine) : x_user_mem;
+ if (xReorder)
+ dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
+ args[DNNL_ARG_SRC] = x_mkl_mem;
+
+ // z
+ auto z_mkl_mem = dnnl::memory(op_ff_prim_desc.dst_desc(), engine);
+ args[DNNL_ARG_DST] = z_mkl_mem;
+
+ // auxiliary memory allocation
+ auto workspace = dnnl::memory(op_ff_prim_desc.workspace_desc(), engine);
+ args[DNNL_ARG_WORKSPACE] = workspace;
+
+ // run forward calculations
+ dnnl::pooling_forward(op_ff_prim_desc).execute(stream, args);
+ }
+
+ // run backward calculations
+ dnnl::pooling_backward(op_bp_prim_desc).execute(stream, args);
+
+
+ // reorder gradI if necessary
+ if (gradIReorder)
+ dnnl::reorder(gradI_mkl_mem, gradI_user_mem).execute(stream, gradI_mkl_mem, gradI_user_mem);
+
+ stream.wait();
+}
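One behavioural detail in the backward helper above: dnnl::pooling_backward for the max algorithm needs the workspace (arg-max positions) that only a forward run produces, which is why the helper re-executes pooling_forward inside the pooling_max branch before launching the backward primitive; the average-pooling algorithms skip that block entirely. The arguments the backward primitive actually consumes boil down to:

    // same keys as in the code above
    // args[DNNL_ARG_DIFF_DST]  -> gradO (incoming gradient, reordered to the primitive's layout if needed)
    // args[DNNL_ARG_WORKSPACE] -> workspace filled by the preceding pooling_forward run (max pooling only)
    // args[DNNL_ARG_DIFF_SRC]  -> gradI (result, reordered back to the user layout afterwards)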
+
+//////////////////////////////////////////////////////////////////////////
+void getMKLDNNMemoryDescLrn(const NDArray* src, const NDArray* diff_src, const NDArray* dst,
+ dnnl::memory::desc* lrn_src_md, dnnl::memory::desc* lrn_diff_src_md, dnnl::memory::desc* lrn_dst_md,
+ dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, int axis) {
+ const Nd4jLong* shape = src->getShapeInfo();
+ long rank = shape[0];
+ long dim1 = axis; // MKL-DNN supports only 1 axis, which has to be the "channel" one
+ long dim2 = axis >= 2 ? 1 : 2;
+ long dim3 = axis >= 3 ? 2 : 3;
+ dnnl::memory::dims lrn_src_tz = { (int)shape[1], (int)shape[dim1 + 1], rank > 2 ? (int)shape[dim2 + 1] : 1, rank > 3 ? (int)shape[dim3 + 1] : 1};
+
+ auto type = dnnl::memory::data_type::f32;
+ auto format = axis == 1 ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
+ auto supposed_to_be_any_format = format; // doesn't work with "any"
+
+ if (src != nullptr && src->getBuffer() != nullptr && lrn_src_md != nullptr) {
+ *lrn_src_md = dnnl::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format);
+ *user_src_md = dnnl::memory::desc({ lrn_src_tz }, type, format);
+ user_src_md->data.format_kind = dnnl_blocked;
+ user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[0];
+ user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[dim1];
+ user_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? src->stridesOf()[dim2] : 1;
+ user_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? src->stridesOf()[dim3] : 1;
+ }
+
+ if (diff_src != nullptr && diff_src->getBuffer() != nullptr && lrn_diff_src_md != nullptr) {
+ *lrn_diff_src_md = dnnl::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format);
+ *user_diff_src_md = dnnl::memory::desc({ lrn_src_tz }, type, format);
+ user_diff_src_md->data.format_kind = dnnl_blocked;
+ user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[0];
+ user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[dim1];
+ user_diff_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? diff_src->stridesOf()[dim2] : 1;
+ user_diff_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? diff_src->stridesOf()[dim3] : 1;
+ }
+
+ if (dst != nullptr && dst->getBuffer() != nullptr && lrn_dst_md != nullptr) {
+ *lrn_dst_md = dnnl::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format);
+ *user_dst_md = dnnl::memory::desc({ lrn_src_tz }, type, format);
+ user_dst_md->data.format_kind = dnnl_blocked;
+ user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[0];
+ user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[dim1];
+ user_dst_md->data.format_desc.blocking.strides[2] = rank > 2 ? dst->stridesOf()[dim2] : 1;
+ user_dst_md->data.format_desc.blocking.strides[3] = rank > 3 ? dst->stridesOf()[dim3] : 1;
+ }
+}
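In short, this helper canonicalises whatever comes in to a 4D {n, c, h, w} view: the axis argument picks which dimension acts as the channel, missing trailing dimensions are padded with size 1 (the rank > 2 / rank > 3 ternaries), and the blocking strides are copied straight from the NDArray, so non-contiguous views keep their real strides rather than being assumed dense.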
+
+//////////////////////////////////////////////////////////////////////////
+dnnl::engine& getEngine(void *ptr) {
+ auto eng = reinterpret_cast<dnnl::engine*>(ptr);
+ return *eng;
+}
+
+
+/*
//////////////////////////////////////////////////////////////////////////
void getMKLDNNMemoryDescPool2d(
int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, int poolingMode, int extraParam0, bool isNCHW,
@@ -307,104 +617,51 @@ void getMKLDNNMemoryDescConv3d(
}
};
-
-// void getMKLDNNMemoryDescBatchNorm(const NDArray* src, const NDArray* diff_src, const NDArray* dst,
-// dnnl::memory::desc* batchnorm_src_md, dnnl::memory::desc* batchnorm_diff_src_md, dnnl::memory::desc* batchnorm_dst_md,
-// dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, int axis) {
-// const Nd4jLong* shape = src->getShapeInfo();
-// Nd4jLong rank = shape[0];
-// Nd4jLong dim1 = axis; // MKL-DNN supports only 1 axis, which has to be the "channel" one
-// Nd4jLong dim2 = axis >= 2 ? 1 : 2;
-// Nd4jLong dim3 = axis >= 3 ? 2 : 3;
-// dnnl::memory::dims batchnorm_src_tz = { (int)shape[1], (int)shape[dim1 + 1], rank > 2 ? (int)shape[dim2 + 1] : 1, rank > 3 ? (int)shape[dim3 + 1] : 1};
-
-// auto type = dnnl::memory::data_type::f32;
-// auto format = dnnl::memory::format_tag::nchw;
-// auto supposed_to_be_any_format = dnnl::memory::format_tag::nChw8c; // doesn't work with "any"
-
-// if (src != nullptr && src->getBuffer() != nullptr && batchnorm_src_md != nullptr) {
-// *batchnorm_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format);
-// *user_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, format);
-// user_src_md->data.format_kind = dnnl_blocked; // overrides format
-// user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[0];
-// user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[dim1];
-// user_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? src->stridesOf()[dim2] : 1;
-// user_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? src->stridesOf()[dim3] : 1;
-// }
-
-// if (diff_src != nullptr && diff_src->getBuffer() != nullptr && batchnorm_diff_src_md != nullptr) {
-// *batchnorm_diff_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format);
-// *user_diff_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, format);
-// user_diff_src_md->data.format_kind = dnnl_blocked; // overrides format
-// user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[0];
-// user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[dim1];
-// user_diff_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? diff_src->stridesOf()[dim2] : 1;
-// user_diff_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? diff_src->stridesOf()[dim3] : 1;
-// }
-
-// if (dst != nullptr && dst->getBuffer() != nullptr && batchnorm_dst_md != nullptr) {
-// *batchnorm_dst_md = dnnl::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format);
-// *user_dst_md = dnnl::memory::desc({ batchnorm_src_tz }, type, format);
-// user_dst_md->data.format_kind = dnnl_blocked; // overrides format
-// user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[0];
-// user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[dim1];
-// user_dst_md->data.format_desc.blocking.strides[2] = rank > 2 ? dst->stridesOf()[dim2] : 1;
-// user_dst_md->data.format_desc.blocking.strides[3] = rank > 3 ? dst->stridesOf()[dim3] : 1;
-// }
-// };
-
-//////////////////////////////////////////////////////////////////////////
-void getMKLDNNMemoryDescLrn(const NDArray* src, const NDArray* diff_src, const NDArray* dst,
- dnnl::memory::desc* lrn_src_md, dnnl::memory::desc* lrn_diff_src_md, dnnl::memory::desc* lrn_dst_md,
- dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, int axis) {
+void getMKLDNNMemoryDescBatchNorm(const NDArray* src, const NDArray* diff_src, const NDArray* dst,
+ dnnl::memory::desc* batchnorm_src_md, dnnl::memory::desc* batchnorm_diff_src_md, dnnl::memory::desc* batchnorm_dst_md,
+ dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, int axis) {
const Nd4jLong* shape = src->getShapeInfo();
- long rank = shape[0];
- long dim1 = axis; // MKL-DNN supports only 1 axis, which has to be the "channel" one
- long dim2 = axis >= 2 ? 1 : 2;
- long dim3 = axis >= 3 ? 2 : 3;
- dnnl::memory::dims lrn_src_tz = { (int)shape[1], (int)shape[dim1 + 1], rank > 2 ? (int)shape[dim2 + 1] : 1, rank > 3 ? (int)shape[dim3 + 1] : 1};
+ Nd4jLong rank = shape[0];
+ Nd4jLong dim1 = axis; // MKL-DNN supports only 1 axis, which has to be the "channel" one
+ Nd4jLong dim2 = axis >= 2 ? 1 : 2;
+ Nd4jLong dim3 = axis >= 3 ? 2 : 3;
+ dnnl::memory::dims batchnorm_src_tz = { (int)shape[1], (int)shape[dim1 + 1], rank > 2 ? (int)shape[dim2 + 1] : 1, rank > 3 ? (int)shape[dim3 + 1] : 1};
auto type = dnnl::memory::data_type::f32;
- auto format = axis == 1 ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
- auto supposed_to_be_any_format = format; // doesn't work with "any"
+ auto format = dnnl::memory::format_tag::nchw;
+ auto supposed_to_be_any_format = dnnl::memory::format_tag::nChw8c; // doesn't work with "any"
- if (src != nullptr && src->getBuffer() != nullptr && lrn_src_md != nullptr) {
- *lrn_src_md = dnnl::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format);
- *user_src_md = dnnl::memory::desc({ lrn_src_tz }, type, format);
- user_src_md->data.format_kind = dnnl_blocked;
+ if (src != nullptr && src->getBuffer() != nullptr && batchnorm_src_md != nullptr) {
+ *batchnorm_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format);
+ *user_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, format);
+ user_src_md->data.format_kind = dnnl_blocked; // overrides format
user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[0];
user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[dim1];
user_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? src->stridesOf()[dim2] : 1;
user_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? src->stridesOf()[dim3] : 1;
}
- if (diff_src != nullptr && diff_src->getBuffer() != nullptr && lrn_diff_src_md != nullptr) {
- *lrn_diff_src_md = dnnl::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format);
- *user_diff_src_md = dnnl::memory::desc({ lrn_src_tz }, type, format);
- user_diff_src_md->data.format_kind = dnnl_blocked;
+ if (diff_src != nullptr && diff_src->getBuffer() != nullptr && batchnorm_diff_src_md != nullptr) {
+ *batchnorm_diff_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format);
+ *user_diff_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, format);
+ user_diff_src_md->data.format_kind = dnnl_blocked; // overrides format
user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[0];
user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[dim1];
user_diff_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? diff_src->stridesOf()[dim2] : 1;
user_diff_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? diff_src->stridesOf()[dim3] : 1;
}
- if (dst != nullptr && dst->getBuffer() != nullptr && lrn_dst_md != nullptr) {
- *lrn_dst_md = dnnl::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format);
- *user_dst_md = dnnl::memory::desc({ lrn_src_tz }, type, format);
- user_dst_md->data.format_kind = dnnl_blocked;
+ if (dst != nullptr && dst->getBuffer() != nullptr && batchnorm_dst_md != nullptr) {
+ *batchnorm_dst_md = dnnl::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format);
+ *user_dst_md = dnnl::memory::desc({ batchnorm_src_tz }, type, format);
+ user_dst_md->data.format_kind = dnnl_blocked; // overrides format
user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[0];
user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[dim1];
user_dst_md->data.format_desc.blocking.strides[2] = rank > 2 ? dst->stridesOf()[dim2] : 1;
user_dst_md->data.format_desc.blocking.strides[3] = rank > 3 ? dst->stridesOf()[dim3] : 1;
}
-}
-
-//////////////////////////////////////////////////////////////////////////
-dnnl::engine& getEngine(void *ptr) {
- auto eng = reinterpret_cast<dnnl::engine*>(ptr);
- return *eng;
-}
-
+};
+*/
}
}
\ No newline at end of file
diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.h b/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.h
index b55103a02..c8b34a6c0 100644
--- a/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.h
+++ b/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.h
@@ -16,6 +16,7 @@
//
// @author saudet
+// @author Yurii Shyrma (iuriish@yahoo.com)
//
#ifndef DEV_TESTS_MKLDNNUTILS_H
@@ -81,17 +82,27 @@ namespace nd4j{
DECLARE_PLATFORM(deconv3d_bp, ENGINE_CPU);
DECLARE_PLATFORM(depthwise_conv2d, ENGINE_CPU);
-
+
DECLARE_PLATFORM(depthwise_conv2d_bp, ENGINE_CPU);
}
}
namespace mkldnnUtils {
+ void poolingMKLDNN(const NDArray *input, NDArray *output, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int isNCHW, const dnnl::algorithm mode);
+
+ void poolingBpMKLDNN(const NDArray *input, const NDArray *gradO, NDArray *gradI, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int isNCHW, const dnnl::algorithm mode);
+
+ void getMKLDNNMemoryDescLrn(const NDArray* src, const NDArray* diff_src, const NDArray* dst,
+ dnnl::memory::desc* lrn_src_md, dnnl::memory::desc* lrn_diff_src_md, dnnl::memory::desc* lrn_dst_md,
+ dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, int axis);
+
+ dnnl::engine& getEngine(void *ptr);
+
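For orientation, the pooling platform ops are the intended callers of these new helpers; a 2d invocation would look roughly like the sketch below (the depth parameters are placeholders in the 2d case, and the exact argument values are an assumption rather than a quote from the platform implementations):

    // hypothetical call from an avgpool2d platform implementation
    // mkldnnUtils::poolingMKLDNN(input, output,
    //                            /*kD*/ 0, kH, kW, /*sD*/ 0, sH, sW, /*pD*/ 0, pH, pW,
    //                            isNCHW, dnnl::algorithm::pooling_avg_exclude_padding);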
/**
* Utility methods for MKLDNN
*/
- void getMKLDNNMemoryDescConv2d(
+/* void getMKLDNNMemoryDescConv2d(
int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, const int paddingMode, bool isNCHW,
int bS, int iC, int iH, int iW, int oC, int oH, int oW, const NDArray* src, const NDArray* diff_src,
const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst,
@@ -130,12 +141,7 @@ namespace nd4j{
void getMKLDNNMemoryDescBatchNorm(const NDArray* src, const NDArray* diff_src, const NDArray* dst,
dnnl::memory::desc* batchnorm_src_md, dnnl::memory::desc* batchnorm_diff_src_md, dnnl::memory::desc* batchnorm_dst_md,
dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, int axis);
-
- void getMKLDNNMemoryDescLrn(const NDArray* src, const NDArray* diff_src, const NDArray* dst,
- dnnl::memory::desc* lrn_src_md, dnnl::memory::desc* lrn_diff_src_md, dnnl::memory::desc* lrn_dst_md,
- dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, int axis);
-
- dnnl::engine& getEngine(void *ptr);
+*/
}
}
diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp
index 15524a901..795a7da4d 100644
--- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp
+++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp
@@ -2031,121 +2031,6 @@ TEST_F(DeclarableOpsTests1, Sum1) {
}
*/
-//////////////////////////////////////////////////////////////////////
-TEST_F(DeclarableOpsTests1, Avgpool2d_test1) {
-
- auto x = NDArrayFactory::create_<float>('c', {bS,iD,iH,iW});
- auto exp = NDArrayFactory::create<float>('c',{bS,iD,oH,oW});
- // auto z('c',{bS,iD,oH,oW});
-
- auto variableSpace = new VariableSpace();
- variableSpace->putVariable(-1, x);
- // variableSpace->putVariable(1, &z);
-
- auto block = new Context(1, variableSpace, false);
- block->fillInputs({-1});
- std::vector<Nd4jLong>* argI = block->getIArguments();
- *argI = {kH,kW, sH,sW, pH,pW, dW,dH, 0, 0, 0}; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode;
-
- nd4j::ops::avgpool2d pooling;
- Nd4jStatus status = pooling.execute(block);
- ASSERT_EQ(ND4J_STATUS_OK, status);
-
- auto result = variableSpace->getVariable(block->getNodeId())->getNDArray();
- ASSERT_TRUE(exp.isSameShape(result));
-
-
- delete variableSpace;
- delete block;
-}
-
-//////////////////////////////////////////////////////////////////////
-TEST_F(DeclarableOpsTests1, Avgpool2d_test2) {
- const int bS = 2;
- const int iD = 1;
- const int iH = 28;
- const int iW = 28;
- const int kH = 5;
- const int kW = 5;
- const int sH = 1;
- const int sW = 1;
- const int pH = 0;
- const int pW = 0;
- const int dH = 1;
- const int dW = 1;
- const int oH = (iH - kH - (kH-1)*(dH-1) + 2*pH)/sH + 1; // output height
- const int oW = (iW - kW - (kW-1)*(dW-1) + 2*pW)/sW + 1; // output width
-
-
- auto x = NDArrayFactory::create_<float>('c', {bS,iD,iH,iW});
- auto exp = NDArrayFactory::create<float>('c',{bS,iD,oH,oW});
- // auto z('c',{bS,iD,oH,oW});
-
- auto variableSpace = new VariableSpace();
- variableSpace->putVariable(-1, x);
- // variableSpace->putVariable(1, &z);
-
- auto block = new Context(1, variableSpace, false);
- block->fillInputs({-1});
- std::vector<Nd4jLong>* argI = block->getIArguments();
- *argI = {kH,kW, sH,sW, pH,pW, dW,dH, 0, 0, 0}; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode;
-
- nd4j::ops::avgpool2d pooling;
- Nd4jStatus status = pooling.execute(block);
- ASSERT_EQ(ND4J_STATUS_OK, status);
-
- auto result = variableSpace->getVariable(block->getNodeId())->getNDArray();
- // result->printShapeInfo();
- ASSERT_TRUE(exp.isSameShape(result));
-
- delete variableSpace;
- delete block;
-}
-
-//////////////////////////////////////////////////////////////////////
-TEST_F(DeclarableOpsTests1, Avgpool2d_test3) {
- const int bS = 2;
- const int iD = 1;
- const int iH = 28;
- const int iW = 28;
- const int kH = 5;
- const int kW = 5;
- const int sH = 1;
- const int sW = 1;
- const int pH = 0;
- const int pW = 0;
- const int dH = 1;
- const int dW = 1;
- const int oH = (int) nd4j::math::nd4j_ceil(iH * 1.f / sH);
- const int oW = (int) nd4j::math::nd4j_ceil(iW * 1.f / sW);
-
-
- auto x = NDArrayFactory::create_<float>('c', {bS,iD,iH,iW});
- auto exp = NDArrayFactory::create<float>('c',{bS,iD,oH,oW});
- // auto z('c',{bS,iD,oH,oW});
-
- auto variableSpace = new VariableSpace();
- variableSpace->putVariable(-1, x);
- // variableSpace->putVariable(1, &z);
-
- auto block = new Context(1, variableSpace, false);
- block->fillInputs({-1});
- std::vector<Nd4jLong>* argI = block->getIArguments();
- *argI = {kH,kW, sH,sW, pH,pW, dW,dH, 1, 0, 0}; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode;
-
- nd4j::ops::avgpool2d pooling;
- Nd4jStatus status = pooling.execute(block);
- ASSERT_EQ(ND4J_STATUS_OK, status);
-
- auto result = variableSpace->getVariable(block->getNodeId())->getNDArray();
- // result->printShapeInfo();
- ASSERT_TRUE(exp.isSameShape(result));
-
- delete variableSpace;
- delete block;
-}
-
-
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests1, Pnormpool2d1) {
diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests11.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests11.cpp
index de4bdc31b..465703768 100644
--- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests11.cpp
+++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests11.cpp
@@ -1667,6 +1667,241 @@ TEST_F(DeclarableOpsTests11, Solve_Test_4) {
ASSERT_TRUE(exp.equalsTo(z));
delete res;
}
+////////////////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests11, Solve_Test_4_1) {
+
+ auto a = NDArrayFactory::create<float>('c', {2, 2, 2}, {
+ 0.7788f, 0.8012f, 0.7244f, 0.2309f,
+ 0.7271f, 0.1804f, 0.5056f, 0.8925f
+ });
+
+ auto b = NDArrayFactory::create<float>('c', {2, 2, 2}, {
+ 0.7717f, 0.9281f, 0.9846f, 0.4838f, 0.6433f, 0.6041f, 0.6501f, 0.7612f
+ });
+
+ auto exp = NDArrayFactory::create<float>('c', {2, 2, 2}, {
+ 1.3357621f, 0.3399364f, -0.37077796f, 0.91573375f,
+ 0.4400987f, 0.2766527f, 0.6394467f, 0.79696566f
+ });
+
+ nd4j::ops::solve op;
+
+ auto res = op.evaluate({&a, &b}, {true});
+ ASSERT_EQ(res->status(), ND4J_STATUS_OK);
+ auto z = res->at(0);
+
+// z->printBuffer("4 Solve 4x4");
+// exp.printBuffer("4 Expec 4x4");
+
+ ASSERT_TRUE(exp.equalsTo(z));
+ delete res;
+}
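The expected values in Solve_Test_4_1 are consistent with the boolean argument requesting a solve against the transposed (adjoint) coefficient matrix, i.e. solving aᵀ·x = b per 2x2 batch. A standalone Cramer's-rule check of the first batch, first column (numbers copied from the test):

    #include <cstdio>

    int main() {
        // transpose of the first 2x2 slice of `a`, and the first column of the first slice of `b`
        const double a00 = 0.7788, a01 = 0.7244,
                     a10 = 0.8012, a11 = 0.2309;
        const double b0 = 0.7717, b1 = 0.9846;
        const double det = a00 * a11 - a01 * a10;
        const double x0 = (b0 * a11 - a01 * b1) / det;   // ~ 1.3357621
        const double x1 = (a00 * b1 - b0 * a10) / det;   // ~ -0.37077796
        std::printf("x0=%f x1=%f\n", x0, x1);
        return 0;
    }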
+////////////////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests11, Solve_Test_4_2) {
+
+ auto a = NDArrayFactory::create<float>('c', {3, 3}, {
+ 0.7788f, 0.8012f, 0.7244f,
+ 0.2309f, 0.7271f, 0.1804f,
+ 0.5056f, 0.8925f, 0.5461f
+ });
+
+ auto b = NDArrayFactory::create<float>('c', {3, 3}, {
+ 0.7717f, 0.9281f, 0.9846f,
+ 0.4838f, 0.6433f, 0.6041f,
+ 0.6501f, 0.7612f, 0.7605f
+ });
+
+ auto exp = NDArrayFactory::create<float>('c', {3, 3}, {
+ 0.99088347f, 1.1917052f, 1.2642528f,
+ 0.35071516f, 0.50630623f, 0.42935497f,
+ -0.30013534f, -0.53690606f, -0.47959247f
+ });
+
+ nd4j::ops::triangular_solve op;
+
+ auto res = op.evaluate({&a, &b}, {true, false});
+ ASSERT_EQ(res->status(), ND4J_STATUS_OK);
+ auto z = res->at(0);
+
+// z->printBuffer("4_2 Triangular_Solve 3x3");
+// exp.printBuffer("4_2 Triangular_Expec 3x3");
+
+ ASSERT_TRUE(exp.equalsTo(z));
+ delete res;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests11, Solve_Test_4_3) {
+
+ auto a = NDArrayFactory::create<float>('c', {3, 3}, {
+ 0.7788f, 0.8012f, 0.7244f,
+ 0.2309f, 0.7271f, 0.1804f,
+ 0.5056f, 0.8925f, 0.5461f
+ });
+
+ auto b = NDArrayFactory::create<float>('c', {3, 3}, {
+ 0.7717f, 0.9281f, 0.9846f,
+ 0.4838f, 0.6433f, 0.6041f,
+ 0.6501f, 0.7612f, 0.7605f
+ });
+
+ auto exp = NDArrayFactory::create<float>('c', {3, 3}, {
+ 0.45400196f, 0.53174824f, 0.62064564f,
+ -0.79585856f, -0.82621557f, -0.87855506f,
+ 1.1904413f, 1.3938838f, 1.3926021f
+ });
+
+ nd4j::ops::triangular_solve op;
+
+ auto res = op.evaluate({&a, &b}, {true, true});
+ ASSERT_EQ(res->status(), ND4J_STATUS_OK);
+ auto z = res->at(0);
+
+// z->printBuffer("4_3 Triangular_Solve 3x3");
+// exp.printBuffer("4_3 Triangular_Expec 3x3");
+
+ ASSERT_TRUE(exp.equalsTo(z));
+ delete res;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests11, Solve_Test_4_4) {
+
+ auto a = NDArrayFactory::create<float>('c', {3, 3}, {
+ 0.7788f, 0.8012f, 0.7244f,
+ 0.2309f, 0.7271f, 0.1804f,
+ 0.5056f, 0.8925f, 0.5461f
+ });
+
+ auto b = NDArrayFactory::create<float>('c', {3, 3}, {
+ 0.7717f, 0.9281f, 0.9846f,
+ 0.4838f, 0.6433f, 0.6041f,
+ 0.6501f, 0.7612f, 0.7605f
+ });
+
+ auto exp = NDArrayFactory::create<float>('c', {3, 3}, {
+ 0.8959121f, 1.6109066f, 1.7501404f,
+ 0.49000582f, 0.66842675f, 0.5577021f,
+ -0.4398522f, -1.1899745f, -1.1392052f
+ });
+
+ nd4j::ops::solve op;
+
+ auto res = op.evaluate({&a, &b}, {false});
+ ASSERT_EQ(res->status(), ND4J_STATUS_OK);
+ auto z = res->at(0);
+
+// z->printBuffer("4_4 Solve 3x3");
+// exp.printBuffer("4_4 Expec 3x3");
+
+ ASSERT_TRUE(exp.equalsTo(z));
+ delete res;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests11, Solve_Test_4_5) {
+
+ auto a = NDArrayFactory::create<float>('c', {3, 3}, {
+ 0.7788f, 0.8012f, 0.7244f,
+ 0.2309f, 0.7271f, 0.1804f,
+ 0.5056f, 0.8925f, 0.5461f
+ });
+
+ auto b = NDArrayFactory::create<float>('c', {3, 3}, {
+ 0.7717f, 0.9281f, 0.9846f,
+ 0.4838f, 0.6433f, 0.6041f,
+ 0.6501f, 0.7612f, 0.7605f
+ });
+
+ auto exp = NDArrayFactory::create<float>('c', {3, 3}, {
+ 1.5504692f, 1.8953944f, 2.2765768f,
+ 0.03399149f, 0.2883001f, 0.5377323f,
+ -0.8774802f, -1.2155888f, -1.8049058f
+ });
+
+ nd4j::ops::solve op;
+
+ auto res = op.evaluate({&a, &b}, {true, true});
+ ASSERT_EQ(res->status(), ND4J_STATUS_OK);
+ auto z = res->at(0);
+
+// z->printBuffer("4_5 Solve 3x3");
+// exp.printBuffer("4_5 Expec 3x3");
+
+ ASSERT_TRUE(exp.equalsTo(z));
+ delete res;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests11, Solve_Test_4_6) {
+
+ auto a = NDArrayFactory::create<float>('c', {3, 3}, {
+ 0.7788f, 0.8012f, 0.7244f,
+ 0.2309f, 0.7271f, 0.1804f,
+ 0.5056f, 0.8925f, 0.5461f
+ });
+
+ auto b = NDArrayFactory::create<float>('c', {3, 3}, {
+ 0.7717f, 0.9281f, 0.9846f,
+ 0.4838f, 0.6433f, 0.6041f,
+ 0.6501f, 0.7612f, 0.7605f
+ });
+
+ auto exp = NDArrayFactory::create<float>('c', {3, 3}, {
+ 0.99088347f, 1.1917052f, 1.2642528f,
+ -0.426483f, -0.42840624f, -0.5622601f,
+ 0.01692283f, -0.04538865f, -0.09868701f
+ });
+
+ nd4j::ops::triangular_solve op;
+
+ auto res = op.evaluate({&a, &b}, {false, true});
+ ASSERT_EQ(res->status(), ND4J_STATUS_OK);
+ auto z = res->at(0);
+
+ z->printBuffer("4_6 Solve 3x3");
+ exp.printBuffer("4_6 Expec 3x3");
+
+ ASSERT_TRUE(exp.equalsTo(z));
+ delete res;
+}
+////////////////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests11, Solve_Test_4_7) {
+
+ auto a = NDArrayFactory::create<float>('c', {3, 3}, {
+// 0.7788f, 0.2309f, 0.5056f,
+// 0.8012f, 0.7271f, 0.8925f,
+// 0.7244f, 0.1804f, 0.5461f
+
+ 0.7788f, 0.2309f, 0.5056f,
+ 0.8012f, 0.7271f, 0.8925f,
+ 0.7244f, 0.1804f, 0.5461f
+ });
+
+ auto b = NDArrayFactory::create<float>('c', {3, 3}, {
+ 0.7717f, 0.9281f, 0.9846f,
+ 0.4838f, 0.6433f, 0.6041f,
+ 0.6501f, 0.7612f, 0.7605f
+ });
+
+ auto exp = NDArrayFactory::create<float>('c', {3, 3}, {
+ 0.99088347f, 1.1917052f, 1.2642528f,
+ -0.426483f, -0.42840624f, -0.5622601f,
+ 0.01692283f, -0.04538865f, -0.09868701f
+ });
+
+ nd4j::ops::triangular_solve op;
+
+ auto res = op.evaluate({&a, &b}, {true, false});
+ ASSERT_EQ(res->status(), ND4J_STATUS_OK);
+ auto z = res->at(0);
+
+ z->printBuffer("4_7 Solve 3x3");
+ exp.printBuffer("4_7 Expec 3x3");
+
+ ASSERT_TRUE(exp.equalsTo(z));
+ delete res;
+}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, Solve_Test_5) {
diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests4.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests4.cpp
index 1e085d46c..f04d24395 100644
--- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests4.cpp
+++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests4.cpp
@@ -360,7 +360,6 @@ TEST_F(DeclarableOpsTests4, avgpool2d_12) {
917.5, 918.5, 919.5, 925. , 926. , 927. , 934. , 935. , 936. , 941.5, 942.5, 943.5, 992.5, 993.5, 994.5,1000. , 1001. , 1002. ,1009. , 1010. , 1011. ,1016.5, 1017.5, 1018.5,
1082.5, 1083.5, 1084.5,1090. , 1091. , 1092. ,1099. , 1100. , 1101. ,1106.5, 1107.5, 1108.5,1157.5, 1158.5, 1159.5,1165. , 1166. , 1167. ,1174. , 1175. , 1176. ,1181.5, 1182.5, 1183.5});
input.linspace(1.);
- input.syncToDevice();
nd4j::ops::avgpool2d op;
auto results = op.evaluate({&input}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, 0, dataFormat});
@@ -377,6 +376,160 @@ TEST_F(DeclarableOpsTests4, avgpool2d_12) {
delete results;
}
+//////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests4, avgpool2d_13) {
+
+ const int bS = 2; // batch size
+ const int iD = 1; // input depth (number of picture channels, for example rgb=3)
+ const int iH = 28; // picture height in pixels
+ const int iW = 28; // picture width in pixels
+ const int kH = 5; // kernel height in pixels
+ const int kW = 5; // kernel width in pixels
+ const int sH = 1; // stride step in vertical direction (along picture height)
+ const int sW = 1; // stride step in horizontal direction (along picture width)
+ const int pH = 0; // padding height
+ const int pW = 0; // padding width
+ const int dH = 2; // dilation height
+ const int dW = 2; // dilation width
+ const int oH = (iH - kH - (kH-1)*(dH-1) + 2*pH)/sH + 1; // output height
+ const int oW = (iW - kW - (kW-1)*(dW-1) + 2*pW)/sW + 1; // output width
+
+ auto x = NDArrayFactory::create_<float>('c', {bS,iD,iH,iW});
+ auto exp = NDArrayFactory::create<float>('c',{bS,iD,oH,oW});
+ // auto z('c',{bS,iD,oH,oW});
+
+ auto variableSpace = new VariableSpace();
+ variableSpace->putVariable(-1, x);
+ // variableSpace->putVariable(1, &z);
+
+ auto block = new Context(1, variableSpace, false);
+ block->fillInputs({-1});
+ std::vector<Nd4jLong>* argI = block->getIArguments();
+ *argI = {kH,kW, sH,sW, pH,pW, dW,dH, 0, 0, 0}; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode;
+
+ nd4j::ops::avgpool2d pooling;
+ Nd4jStatus status = pooling.execute(block);
+ ASSERT_EQ(ND4J_STATUS_OK, status);
+
+ auto result = variableSpace->getVariable(block->getNodeId())->getNDArray();
+ ASSERT_TRUE(exp.isSameShape(result));
+
+
+ delete variableSpace;
+ delete block;
+}
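avgpool2d_13 is the only one of the newly added pooling tests that uses a dilation larger than 1, so the output-size expression deserves a sanity check: the dilated kernel spans kH + (kH-1)*(dH-1) = 5 + 4 = 9 rows, giving oH = (28 - 9 + 0)/1 + 1 = 20, which is exactly what the oH/oW formulas in the test evaluate to.

    #include <cstdio>

    int main() {
        const int iH = 28, kH = 5, sH = 1, pH = 0, dH = 2;
        const int effKH = kH + (kH - 1) * (dH - 1);        // dilated kernel extent: 9
        const int oH = (iH - effKH + 2 * pH) / sH + 1;     // 20, same as the test's expression
        std::printf("effKH=%d, oH=%d\n", effKH, oH);
        return 0;
    }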
+
+//////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests4, avgpool2d_14) {
+ const int bS = 2;
+ const int iD = 1;
+ const int iH = 28;
+ const int iW = 28;
+ const int kH = 5;
+ const int kW = 5;
+ const int sH = 1;
+ const int sW = 1;
+ const int pH = 0;
+ const int pW = 0;
+ const int dH = 1;
+ const int dW = 1;
+ const int oH = (iH - kH - (kH-1)*(dH-1) + 2*pH)/sH + 1; // output height
+ const int oW = (iW - kW - (kW-1)*(dW-1) + 2*pW)/sW + 1; // output width
+
+
+ auto x = NDArrayFactory::create_<float>('c', {bS,iD,iH,iW});
+ auto exp = NDArrayFactory::create<float>('c',{bS,iD,oH,oW});
+ // auto z('c',{bS,iD,oH,oW});
+
+ auto variableSpace = new VariableSpace();
+ variableSpace->putVariable(-1, x);
+ // variableSpace->putVariable(1, &z);
+
+ auto block = new Context(1, variableSpace, false);
+ block->fillInputs({-1});
+ std::vector<Nd4jLong>* argI = block->getIArguments();
+ *argI = {kH,kW, sH,sW, pH,pW, dW,dH, 0, 0, 0}; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode;
+
+ nd4j::ops::avgpool2d pooling;
+ Nd4jStatus status = pooling.execute(block);
+ ASSERT_EQ(ND4J_STATUS_OK, status);
+
+ auto result = variableSpace->getVariable(block->getNodeId())->getNDArray();
+ // result->printShapeInfo();
+ ASSERT_TRUE(exp.isSameShape(result));
+
+ delete variableSpace;
+ delete block;
+}
+
+//////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests4, Avgpool2d_test15) {
+ const int bS = 2;
+ const int iD = 1;
+ const int iH = 28;
+ const int iW = 28;
+ const int kH = 5;
+ const int kW = 5;
+ const int sH = 1;
+ const int sW = 1;
+ const int pH = 0;
+ const int pW = 0;
+ const int dH = 1;
+ const int dW = 1;
+ const int oH = (int) nd4j::math::nd4j_ceil(iH * 1.f / sH);
+ const int oW = (int) nd4j::math::nd4j_ceil(iW * 1.f / sW);
+
+
+ auto x = NDArrayFactory::create_<float>('c', {bS,iD,iH,iW});
+ auto exp = NDArrayFactory::create<float>('c',{bS,iD,oH,oW});
+ // auto z('c',{bS,iD,oH,oW});
+
+ auto variableSpace = new VariableSpace();
+ variableSpace->putVariable(-1, x);
+ // variableSpace->putVariable(1, &z);
+
+ auto block = new Context(1, variableSpace, false);
+ block->fillInputs({-1});
+ std::vector<Nd4jLong>* argI = block->getIArguments();
+ *argI = {kH,kW, sH,sW, pH,pW, dW,dH, 1, 0, 0}; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode;
+
+ nd4j::ops::avgpool2d pooling;
+ Nd4jStatus status = pooling.execute(block);
+ ASSERT_EQ(ND4J_STATUS_OK, status);
+
+ auto result = variableSpace->getVariable(block->getNodeId())->getNDArray();
+ // result->printShapeInfo();
+ ASSERT_TRUE(exp.isSameShape(result));
+
+ delete variableSpace;
+ delete block;
+}
+
+//////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests4, avgpool2d_16) {
+
+ int bS=2, iH=4,iW=4, iC=2, kH=2,kW=2, sH=2,sW=2, pH=0,pW=0, dH=1,dW=1;
+ int oH=2,oW=2;
+ int paddingMode = 1; // 1-SAME, 0-VALID
+ int dataFormat = 1; // 1-NHWC, 0-NCHW
+
+ NDArray input('c', {bS, iH, iW, iC}, nd4j::DataType::FLOAT32);
+ NDArray output('f', {bS, oH, oW, iC}, nd4j::DataType::FLOAT32);
+ NDArray expected('c', {bS, oH, oW, iC}, {6.f, 7.f, 10.f, 11.f, 22.f, 23.f, 26.f, 27.f, 38.f, 39.f, 42.f, 43.f, 54.f, 55.f, 58.f, 59.f}, nd4j::DataType::FLOAT32);
+
+ input.linspace(1.);
+
+ nd4j::ops::avgpool2d op;
+ auto status = op.execute({&input}, {&output}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, 0, dataFormat}, {});
+
+ ASSERT_EQ(Status::OK(), status);
+
+ // output.printBuffer();
+ //expected.printIndexedBuffer("expected");
+
+ ASSERT_TRUE(expected.equalsTo(output));
+}
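The expected buffer in avgpool2d_16 can be reproduced by hand: the input is linspace(1) over an NHWC {2,4,4,2} array, so channel 0 of the first 2x2 window holds 1, 3, 9, 11, and their average is 6, the first expected value. A standalone recomputation of the full expected array (plain window averaging; with these shapes the SAME padding never actually pads):

    #include <cstdio>

    int main() {
        const int bS = 2, iH = 4, iW = 4, iC = 2, kH = 2, kW = 2, sH = 2, sW = 2;
        const int oH = iH / sH, oW = iW / sW;
        float in[bS][iH][iW][iC];
        float v = 1.f;                                        // mirrors input.linspace(1.)
        for (int b = 0; b < bS; ++b)
            for (int h = 0; h < iH; ++h)
                for (int w = 0; w < iW; ++w)
                    for (int c = 0; c < iC; ++c)
                        in[b][h][w][c] = v++;
        for (int b = 0; b < bS; ++b)
            for (int oh = 0; oh < oH; ++oh)
                for (int ow = 0; ow < oW; ++ow)
                    for (int c = 0; c < iC; ++c) {
                        float sum = 0.f;
                        for (int kh = 0; kh < kH; ++kh)
                            for (int kw = 0; kw < kW; ++kw)
                                sum += in[b][oh * sH + kh][ow * sW + kw][c];
                        std::printf("%.0f ", sum / (kH * kW)); // 6 7 10 11 22 23 ... 58 59
                    }
        std::printf("\n");
        return 0;
    }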
+
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests4, biasadd_1) {
auto x = NDArrayFactory::create('c', {2, 3, 3, 2});
diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp
index 39761ecb3..0a6f8e5e8 100644
--- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp
+++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp
@@ -758,7 +758,7 @@ TEST_F(DeclarableOpsTests7, Test_Dynamic_Partition_119_2) {
TEST_F(DeclarableOpsTests7, Test_SequenceMask_1) {
auto input = NDArrayFactory::create('c', {4, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
- auto exp = NDArrayFactory::create('c', {4, 4, 16}, {1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ auto exp = NDArrayFactory::create('c', {4, 4, 16}, {1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
@@ -802,6 +802,66 @@ TEST_F(DeclarableOpsTests7, Test_SequenceMask_2) {
delete result;
}
+TEST_F(DeclarableOpsTests7, Test_SequenceMask_3) {
+ auto input = NDArrayFactory::create('c', {2, 2, 2}, {10, 20, 30, 4, 0, 6, 7, 8});
+ auto exp = NDArrayFactory::create('c', {2, 2, 2, 30}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+ 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
+
+ nd4j::ops::sequence_mask op;
+ auto result = op.evaluate({&input}, {nd4j::DataType::INT32});
+ ASSERT_EQ(Status::OK(), result->status());
+
+ auto z = result->at(0);
+// z->printBuffer("Output");
+// z->printShapeInfo("Shape");
+ ASSERT_TRUE(exp.isSameShape(z));
+ ASSERT_TRUE(exp.equalsTo(z));
+
+ delete result;
+}
+
+TEST_F(DeclarableOpsTests7, Test_SequenceMask_4) {
+ auto input = NDArrayFactory::create({1, 3, 2});
+ auto maxLen = NDArrayFactory::create(5);
+ auto exp = NDArrayFactory::create<float>('c', {3,5}, {
+ 1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 1.f, 1.f, 0.f, 0.f, 1.f, 1.f, 0.f, 0.f, 0.f
+ });
+
+ nd4j::ops::sequence_mask op;
+ auto result = op.evaluate({&input, &maxLen}, {nd4j::DataType::FLOAT32});
+ ASSERT_EQ(Status::OK(), result->status());
+
+ auto z = result->at(0);
+// z->printBuffer("Output");
+// z->printShapeInfo("Shape");
+ ASSERT_TRUE(exp.isSameShape(z));
+ ASSERT_TRUE(exp.equalsTo(z));
+
+ delete result;
+}
+
+TEST_F(DeclarableOpsTests7, Test_SequenceMask_5) {
+ auto input = NDArrayFactory::create({1, 3, 2});
+ auto exp = NDArrayFactory::create<float>('c', {3,5}, {
+ 1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 1.f, 1.f, 0.f, 0.f, 1.f, 1.f, 0.f, 0.f, 0.f
+ });
+
+ nd4j::ops::sequence_mask op;
+ auto result = op.evaluate({&input}, {5, (int)nd4j::DataType::FLOAT32});
+ ASSERT_EQ(Status::OK(), result->status());
+
+ auto z = result->at(0);
+// z->printBuffer("Output");
+// z->printShapeInfo("Shape");
+ ASSERT_TRUE(exp.isSameShape(z));
+ ASSERT_TRUE(exp.equalsTo(z));
+
+ delete result;
+}
+
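All three new sequence_mask tests follow the same rule, mask[i][j] = 1 exactly when j < lengths[i]; for the {1, 3, 2} input of Test_SequenceMask_4/5 with a maximum length of 5 that yields rows 10000, 11100 and 11000, matching the expected {3,5} arrays. A standalone check:

    #include <cstdio>

    int main() {
        const int lengths[3] = {1, 3, 2};   // the `input` vector of Test_SequenceMask_4/5
        const int maxLen = 5;
        for (int i = 0; i < 3; ++i) {
            for (int j = 0; j < maxLen; ++j)
                std::printf("%d ", j < lengths[i] ? 1 : 0);
            std::printf("\n");              // 1 0 0 0 0 / 1 1 1 0 0 / 1 1 0 0 0
        }
        return 0;
    }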
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests7, TestSegmentMax_1) {
auto x = NDArrayFactory::create({1.8, 2.5,4., 9., 2.1, 2.4,3.,9., 2.1, 2.1,0.7, 0.1, 3., 4.2, 2.2, 1.});
diff --git a/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp b/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp
index 970c119ca..4d7a0f783 100644
--- a/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp
+++ b/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp
@@ -422,50 +422,38 @@ TEST_F(PlaygroundTests, my) {
delete variableSpace;
}
-
-#include
-
TEST_F(PlaygroundTests, my) {
- const int N = 10000;
- const Nd4jLong dim0(128), dim1(128), dim2(128);
+ int N = 100;
+ int bS=16, iH=128,iW=128, iC=32,oC=64, kH=4,kW=4, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1;
+ int oH=128,oW=128;
- NDArray input('c', {dim0,dim1,dim2}, nd4j::DataType::DOUBLE);
- NDArray mean('c', {dim1}, nd4j::DataType::DOUBLE);
- NDArray variance('c', {dim1}, nd4j::DataType::DOUBLE);
- NDArray gamma('c', {dim1}, nd4j::DataType::DOUBLE);
- NDArray beta ('c', {dim1}, nd4j::DataType::DOUBLE);
+ int paddingMode = 1; // 1-SAME, 0-VALID;
+ int dataFormat = 1; // 1-NHWC, 0-NCHW
- NDArray output('c', {dim0,dim1,dim2}, nd4j::DataType::DOUBLE);
+ // NDArray input('c', {bS, iC, iH, iW}, nd4j::DataType::FLOAT32);
+ // NDArray output('c', {bS, oC, oH, oW}, nd4j::DataType::FLOAT32);
+ NDArray input('c', {bS, iH, iW, iC}, nd4j::DataType::FLOAT32);
+ NDArray output('c', {bS, oH, oW, oC}, nd4j::DataType::FLOAT32);
+ // NDArray weights('c', {kH, kW, iC, oC}, nd4j::DataType::FLOAT32); // permute [kH, kW, iC, oC] -> [oC, iC, kH, kW]
+ NDArray weights('c', {oC, iC, kH, kW}, nd4j::DataType::FLOAT32);
+ NDArray bias('c', {oC}, nd4j::DataType::FLOAT32);
- input.linspace(-100, 0.1);
- mean.linspace(-50, 0.15);
- variance.linspace(-5, 0.2);
- gamma = 1.5;
- beta = -2.5;
+ input = 5.;
+ weights = 3.;
+ bias = 1.;
- // warm up
- ops::helpers::batchnorm(&input, &mean, &variance, &gamma, &beta, &output, {1}, 1e-5);
+ nd4j::ops::conv2d op;
+ auto err = op.execute({&input, &weights, &bias}, {&output}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
auto timeStart = std::chrono::system_clock::now();
for (int i = 0; i < N; ++i)
- ops::helpers::batchnorm(&input, &mean, &variance, &gamma, &beta, &output, {1}, 1e-5);
-
+ err = op.execute({&input, &weights, &bias}, {&output}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
auto timeEnd = std::chrono::system_clock::now();
- auto time = std::chrono::duration_cast ((timeEnd - timeStart)/N).count();
-
- printf("time: %li \n", time);
+ auto time = std::chrono::duration_cast ((timeEnd - timeStart) / N).count();
+ printf("time: %i \n", time);
}
*/
-
-
-
-
-
-
-
-
-
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/buffer/BaseDataBuffer.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/buffer/BaseDataBuffer.java
index 78b12e7fc..12e27e1c2 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/buffer/BaseDataBuffer.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/buffer/BaseDataBuffer.java
@@ -780,7 +780,7 @@ public abstract class BaseDataBuffer implements DataBuffer {
throw new IllegalArgumentException("Unable to create array of length " + length);
float[] ret = new float[(int) length];
for (int i = 0; i < length; i++)
- ret[i] = getFloat(i);
+ ret[i] = getFloatUnsynced(i);
return ret;
}
@@ -790,7 +790,7 @@ public abstract class BaseDataBuffer implements DataBuffer {
throw new IllegalArgumentException("Unable to create array of length " + length);
double[] ret = new double[(int) length];
for (int i = 0; i < length; i++)
- ret[i] = getDouble(i);
+ ret[i] = getDoubleUnsynced(i);
return ret;
}
@@ -800,7 +800,7 @@ public abstract class BaseDataBuffer implements DataBuffer {
throw new IllegalArgumentException("Unable to create array of length " + length);
int[] ret = new int[(int) length];
for (int i = 0; i < length; i++)
- ret[i] = getInt(i);
+ ret[i] = getIntUnsynced(i);
return ret;
}
@@ -810,7 +810,7 @@ public abstract class BaseDataBuffer implements DataBuffer {
throw new IllegalArgumentException("Unable to create array of length " + length);
long[] ret = new long[(int) length];
for (int i = 0; i < length; i++)
- ret[i] = getLong(i);
+ ret[i] = getLongUnsynced(i);
return ret;
}
@@ -1662,6 +1662,11 @@ public abstract class BaseDataBuffer implements DataBuffer {
}
+ protected abstract double getDoubleUnsynced(long index);
+ protected abstract float getFloatUnsynced(long index);
+ protected abstract long getLongUnsynced(long index);
+ protected abstract int getIntUnsynced(long index);
+
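The intent behind these new Unsynced accessors is visible in the hunks above: the bulk readers (asFloat, asDouble, asInt, asLong) and write() are expected to bring the host copy up to date once, after which their per-element reads can skip the synchronization that the public getDouble/getFloat/getInt/getLong perform. Both subclass implementations later in this patch simply forward to super.getX(), which for the CUDA buffer means the plain BaseDataBuffer read rather than the synchronizing override.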
@Override
public void write(DataOutputStream out) throws IOException {
out.writeUTF(allocationMode.name());
@@ -1670,43 +1675,43 @@ public abstract class BaseDataBuffer implements DataBuffer {
switch (dataType()) {
case DOUBLE:
for (long i = 0; i < length(); i++)
- out.writeDouble(getDouble(i));
+ out.writeDouble(getDoubleUnsynced(i));
break;
case UINT64:
case LONG:
for (long i = 0; i < length(); i++)
- out.writeLong(getLong(i));
+ out.writeLong(getLongUnsynced(i));
break;
case UINT32:
case INT:
for (long i = 0; i < length(); i++)
- out.writeInt(getInt(i));
+ out.writeInt(getIntUnsynced(i));
break;
case UINT16:
case SHORT:
for (long i = 0; i < length(); i++)
- out.writeShort((short) getInt(i));
+ out.writeShort((short) getIntUnsynced(i));
break;
case UBYTE:
case BYTE:
for (long i = 0; i < length(); i++)
- out.writeByte((byte) getInt(i));
+ out.writeByte((byte) getIntUnsynced(i));
break;
case BOOL:
for (long i = 0; i < length(); i++)
- out.writeByte(getInt(i) == 0 ? (byte) 0 : (byte) 1);
+ out.writeByte(getIntUnsynced(i) == 0 ? (byte) 0 : (byte) 1);
break;
case BFLOAT16:
for (long i = 0; i < length(); i++)
- out.writeShort((short) Bfloat16Indexer.fromFloat(getFloat(i)));
+ out.writeShort((short) Bfloat16Indexer.fromFloat(getFloatUnsynced(i)));
break;
case HALF:
for (long i = 0; i < length(); i++)
- out.writeShort((short) HalfIndexer.fromFloat(getFloat(i)));
+ out.writeShort((short) HalfIndexer.fromFloat(getFloatUnsynced(i)));
break;
case FLOAT:
for (long i = 0; i < length(); i++)
- out.writeFloat(getFloat(i));
+ out.writeFloat(getFloatUnsynced(i));
break;
}
}
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/memory/deallocation/DeallocatorService.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/memory/deallocation/DeallocatorService.java
index dffb93a7b..ded5bc938 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/memory/deallocation/DeallocatorService.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/memory/deallocation/DeallocatorService.java
@@ -43,7 +43,7 @@ public class DeallocatorService {
private Map<String, DeallocatableReference> referenceMap = new ConcurrentHashMap<>();
private List<List<ReferenceQueue<Deallocatable>>> deviceMap = new ArrayList<>();
- private AtomicLong counter = new AtomicLong(0);
+ private final transient AtomicLong counter = new AtomicLong(0);
public DeallocatorService() {
// we need to have at least 2 threads, but for CUDA we'd need at least numDevices threads, due to thread->device affinity
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseOpContext.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseOpContext.java
index 4a56e2a88..0139a9db5 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseOpContext.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseOpContext.java
@@ -153,4 +153,10 @@ public abstract class BaseOpContext implements OpContext {
for (int e = 0; e < arrays.length; e++)
setOutputArray(e, arrays[e]);
}
+
+ @Override
+ public void purge() {
+ fastpath_in.clear();
+ fastpath_out.clear();
+ }
}
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/OpContext.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/OpContext.java
index 4063746b3..62a4906a7 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/OpContext.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/OpContext.java
@@ -162,4 +162,9 @@ public interface OpContext extends AutoCloseable {
* @param mode
*/
void setExecutionMode(ExecutionMode mode);
+
+ /**
+ * This method removes all in/out arrays from this OpContext
+ */
+ void purge();
}
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/compression/CompressedDataBuffer.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/compression/CompressedDataBuffer.java
index 0c822ce0a..f1c9ed6d9 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/compression/CompressedDataBuffer.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/compression/CompressedDataBuffer.java
@@ -210,4 +210,24 @@ public class CompressedDataBuffer extends BaseDataBuffer {
public DataBuffer reallocate(long length) {
throw new UnsupportedOperationException("This method isn't supported by CompressedDataBuffer");
}
+
+ @Override
+ protected double getDoubleUnsynced(long index) {
+ return super.getDouble(index);
+ }
+
+ @Override
+ protected float getFloatUnsynced(long index) {
+ return super.getFloat(index);
+ }
+
+ @Override
+ protected long getLongUnsynced(long index) {
+ return super.getLong(index);
+ }
+
+ @Override
+ protected int getIntUnsynced(long index) {
+ return super.getInt(index);
+ }
}
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOps.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOps.java
index d284974eb..1a01bf278 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOps.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOps.java
@@ -1161,6 +1161,7 @@ public interface NativeOps {
void ctxAllowHelpers(OpaqueContext ptr, boolean reallyAllow);
void ctxSetExecutionMode(OpaqueContext ptr, int execMode);
void ctxShapeFunctionOverride(OpaqueContext ptr, boolean reallyOverride);
+ void ctxPurge(OpaqueContext ptr);
void deleteGraphContext(OpaqueContext ptr);
OpaqueRandomGenerator createRandomGenerator(long rootSeed, long nodeSeed);
diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/pom.xml b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/pom.xml
index d98c7a6d1..36f25d636 100644
--- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/pom.xml
+++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/pom.xml
@@ -60,7 +60,7 @@
Maximum heap size was set to 6g, as a minimum required value for tests run.
Depending on a build machine, default value is not always enough.
-->
- -Ddtype=float -Xmx8g
+ -Ddtype=float -Dfile.encoding=UTF-8 -Xmx8g
diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/impl/AtomicAllocator.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/impl/AtomicAllocator.java
index aaccf9a34..46964c8f4 100644
--- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/impl/AtomicAllocator.java
+++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/impl/AtomicAllocator.java
@@ -469,8 +469,8 @@ public class AtomicAllocator implements Allocator {
memoryHandler.purgeZeroObject(bucketId, objectId, point, copyback);
- getFlowController().getEventsProvider().storeEvent(point.getLastWriteEvent());
- getFlowController().getEventsProvider().storeEvent(point.getLastReadEvent());
+ //getFlowController().getEventsProvider().storeEvent(point.getLastWriteEvent());
+ //getFlowController().getEventsProvider().storeEvent(point.getLastReadEvent());
}
/**
diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/concurrency/EventsProvider.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/concurrency/EventsProvider.java
index a7412fd76..7cc3e6838 100644
--- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/concurrency/EventsProvider.java
+++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/concurrency/EventsProvider.java
@@ -26,11 +26,11 @@ import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.atomic.AtomicLong;
/**
+ *
* @author raver119@gmail.com
*/
+@Deprecated
public class EventsProvider {
- //private static final EventsProvider INSTANCE = new EventsProvider();
-
private List> queue = new ArrayList<>();
private AtomicLong newCounter = new AtomicLong(0);
private AtomicLong cacheCounter = new AtomicLong(0);
diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/flow/impl/SynchronousFlowController.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/flow/impl/SynchronousFlowController.java
index f5f68ea76..030ccad30 100644
--- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/flow/impl/SynchronousFlowController.java
+++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/flow/impl/SynchronousFlowController.java
@@ -72,12 +72,7 @@ public class SynchronousFlowController implements FlowController {
@Override
public void waitTillFinished(AllocationPoint point) {
- /*CudaContext context = point.getCurrentContext(); //(CudaContext) allocator.getDeviceContext().getContext();
- if (context == null)
- context = (CudaContext) allocator.getDeviceContext().getContext();
- context.syncOldStream();
- */
-
+ // this should always be null, since synchronization happens in C++ now
if (point.getLastWriteEvent() != null) {
point.getLastWriteEvent().synchronize();
}
@@ -181,8 +176,8 @@ public class SynchronousFlowController implements FlowController {
@Override
public void registerAction(CudaContext context, AllocationPoint result, AllocationPoint... operands) {
-
-
+ // this method is irrelevant now: everything happens in C++
+ /*
eventsProvider.storeEvent(result.getLastWriteEvent());
result.setLastWriteEvent(eventsProvider.getEvent());
result.getLastWriteEvent().register(context.getOldStream());
@@ -194,6 +189,7 @@ public class SynchronousFlowController implements FlowController {
operand.getLastReadEvent().register(context.getOldStream());
}
// context.syncOldStream();
+ */
}
@Override
@@ -204,9 +200,6 @@ public class SynchronousFlowController implements FlowController {
val pointOperand = allocator.getAllocationPoint(operand);
pointOperand.tickDeviceWrite();
- eventsProvider.storeEvent(pointOperand.getLastWriteEvent());
- pointOperand.setLastWriteEvent(eventsProvider.getEvent());
- pointOperand.getLastWriteEvent().register(context.getOldStream());
}
}
@@ -216,18 +209,13 @@ public class SynchronousFlowController implements FlowController {
val point = allocator.getAllocationPoint(result);
point.tickDeviceWrite();
- eventsProvider.storeEvent(point.getLastWriteEvent());
- point.setLastWriteEvent(eventsProvider.getEvent());
- point.getLastWriteEvent().register(context.getOldStream());
for (INDArray operand : operands) {
if (operand == null || operand.isEmpty())
continue;
val pointOperand = allocator.getAllocationPoint(operand);
- eventsProvider.storeEvent(pointOperand.getLastReadEvent());
- pointOperand.setLastReadEvent(eventsProvider.getEvent());
- pointOperand.getLastReadEvent().register(context.getOldStream());
+ pointOperand.tickDeviceRead();
}
}
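
The SynchronousFlowController changes above drop per-operation CUDA event creation: the registerAction-style methods now only tick the read/write state of the affected allocation points, and actual stream synchronization is left to the C++ backend. A self-contained sketch of that simplified bookkeeping, with an assumed TrackedBuffer interface standing in for AllocationPoint:

    // Sketch only: tickDeviceWrite/tickDeviceRead mirror the calls visible in the diff,
    // the surrounding types are illustrative.
    interface TrackedBuffer {
        void tickDeviceWrite();
        void tickDeviceRead();
    }

    final class RegisterActionSketch {
        static void registerAction(TrackedBuffer result, TrackedBuffer... operands) {
            if (result != null)
                result.tickDeviceWrite();       // output buffer is now freshest on the device
            for (TrackedBuffer operand : operands) {
                if (operand == null)
                    continue;
                operand.tickDeviceRead();       // inputs were read on the device; no event needed
            }
        }
    }
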
diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBuffer.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBuffer.java
index 02b857f7f..2f1cab334 100644
--- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBuffer.java
+++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBuffer.java
@@ -307,7 +307,6 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda
if (allocationPoint.getHostPointer() == null) {
val location = allocationPoint.getAllocationStatus();
if (parentWorkspace == null) {
- //log.info("dbAllocate step");
// let cpp allocate primary buffer
NativeOpsHolder.getInstance().getDeviceNativeOps().dbAllocatePrimaryBuffer(ptrDataBuffer);
} else {
@@ -1288,6 +1287,26 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda
@Override
public void destroy() {}
+ @Override
+ protected double getDoubleUnsynced(long index) {
+ return super.getDouble(index);
+ }
+
+ @Override
+ protected float getFloatUnsynced(long index) {
+ return super.getFloat(index);
+ }
+
+ @Override
+ protected long getLongUnsynced(long index) {
+ return super.getLong(index);
+ }
+
+ @Override
+ protected int getIntUnsynced(long index) {
+ return super.getInt(index);
+ }
+
@Override
public void write(DataOutputStream out) throws IOException {
lazyAllocateHostPointer();
@@ -1511,6 +1530,13 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda
return super.asInt();
}
+ @Override
+ public long[] asLong() {
+ lazyAllocateHostPointer();
+ allocator.synchronizeHostData(this);
+ return super.asLong();
+ }
+
@Override
public ByteBuffer asNio() {
lazyAllocateHostPointer();
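
The BaseCudaDataBuffer additions follow one pattern: public accessors such as asLong() first make sure a host pointer exists and host data is synchronized, while the new get*Unsynced() variants delegate straight to the superclass for callers that have already synchronized. A minimal sketch of that split, assuming the method names shown in the diff:

    // Sketch under assumed names; not the real BaseCudaDataBuffer hierarchy.
    abstract class SyncedBufferSketch {
        protected abstract void lazyAllocateHostPointer();   // allocate host memory on first use
        protected abstract void synchronizeHostData();       // copy device -> host if host copy is stale
        protected abstract long readLongRaw(long index);     // raw host read, no synchronization

        public long getLong(long index) {
            lazyAllocateHostPointer();
            synchronizeHostData();
            return readLongRaw(index);
        }

        protected long getLongUnsynced(long index) {
            return readLongRaw(index);                       // caller guarantees host data is current
        }
    }
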
diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaOpContext.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaOpContext.java
index 01127e891..5e26b3ea3 100644
--- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaOpContext.java
+++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaOpContext.java
@@ -23,6 +23,8 @@ import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.allocator.pointers.cuda.cudaStream_t;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.concurrency.AffinityManager;
+import org.nd4j.linalg.api.memory.Deallocatable;
+import org.nd4j.linalg.api.memory.Deallocator;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseOpContext;
import org.nd4j.linalg.api.ops.ExecutionMode;
@@ -40,14 +42,19 @@ import org.nd4j.nativeblas.OpaqueRandomGenerator;
* CUDA wrapper for op Context
* @author raver119@gmail.com
*/
-public class CudaOpContext extends BaseOpContext implements OpContext {
+public class CudaOpContext extends BaseOpContext implements OpContext, Deallocatable {
// we might want to have configurable
private NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
private OpaqueContext context = nativeOps.createGraphContext(1);
+ private final transient long id = Nd4j.getDeallocatorService().nextValue();
+
+ public CudaOpContext() {
+ Nd4j.getDeallocatorService().pickObject(this);
+ }
@Override
public void close() {
- nativeOps.deleteGraphContext(context);
+ // no-op
}
@Override
@@ -143,4 +150,25 @@ public class CudaOpContext extends BaseOpContext implements OpContext {
super.setExecutionMode(mode);
nativeOps.ctxSetExecutionMode(context, mode.ordinal());
}
+
+ @Override
+ public void purge() {
+ super.purge();
+ nativeOps.ctxPurge(context);
+ }
+
+ @Override
+ public String getUniqueId() {
+ return new String("CTX_" + id);
+ }
+
+ @Override
+ public Deallocator deallocator() {
+ return new CudaOpContextDeallocator(this);
+ }
+
+ @Override
+ public int targetDevice() {
+ return 0;
+ }
}
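
CudaOpContext now implements Deallocatable: it registers itself with the deallocator service in its constructor, close() becomes a no-op, and the native graph context is released later by a dedicated Deallocator (next file). A hedged, standalone sketch of the pattern; the names below are illustrative, not the nd4j API:

    // Sketch only: a "deallocator service" is modeled as a Consumer<Runnable>.
    interface NativeHandle { void free(); }

    final class OpContextSketch implements AutoCloseable {
        private final NativeHandle handle;

        OpContextSketch(NativeHandle handle, java.util.function.Consumer<Runnable> deallocatorService) {
            this.handle = handle;
            // the registered callback captures only the handle, never `this`,
            // so it cannot keep the context reachable and block collection
            deallocatorService.accept(handle::free);
        }

        @Override
        public void close() {
            // no-op: releasing the native handle is deferred to the deallocator service
        }
    }
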
diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaOpContextDeallocator.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaOpContextDeallocator.java
new file mode 100644
index 000000000..62b5e4a00
--- /dev/null
+++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaOpContextDeallocator.java
@@ -0,0 +1,34 @@
+/*******************************************************************************
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+package org.nd4j.linalg.jcublas.ops.executioner;
+
+import org.nd4j.linalg.api.memory.Deallocator;
+import org.nd4j.nativeblas.NativeOpsHolder;
+import org.nd4j.nativeblas.OpaqueContext;
+
+public class CudaOpContextDeallocator implements Deallocator {
+ private transient final OpaqueContext context;
+
+ public CudaOpContextDeallocator(CudaOpContext ctx) {
+ context = (OpaqueContext) ctx.contextPointer();
+ }
+
+ @Override
+ public void deallocate() {
+ NativeOpsHolder.getInstance().getDeviceNativeOps().deleteGraphContext(context);
+ }
+}
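
The deallocator deliberately captures only the OpaqueContext pointer, never the CudaOpContext itself, so the registered callback cannot keep the context reachable. How the DeallocatorService drives such callbacks is not shown in this patch; purely as an illustration, and only as an assumption about the mechanism, a reference-based service could look like this:

    // Illustrative only: java.lang.ref.Cleaner is one way to run a deallocator once the
    // owner becomes unreachable; whether nd4j's DeallocatorService works this way is an assumption.
    import java.lang.ref.Cleaner;

    final class CleanerBackedServiceSketch {
        private static final Cleaner CLEANER = Cleaner.create();

        static void pickObject(Object owner, Runnable deallocate) {
            CLEANER.register(owner, deallocate);   // runs deallocate after owner is garbage collected
        }
    }
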
diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java
index f85ae9cf1..e7ddcda11 100644
--- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java
+++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java
@@ -3090,6 +3090,7 @@ public native OpaqueRandomGenerator getGraphContextRandomGenerator(OpaqueContext
public native void ctxAllowHelpers(OpaqueContext ptr, @Cast("bool") boolean reallyAllow);
public native void ctxShapeFunctionOverride(OpaqueContext ptr, @Cast("bool") boolean reallyOverride);
public native void ctxSetExecutionMode(OpaqueContext ptr, int execMode);
+public native void ctxPurge(OpaqueContext ptr);
public native void markGraphContextInplace(OpaqueContext ptr, @Cast("bool") boolean reallyInplace);
public native void setGraphContextCudaContext(OpaqueContext ptr, Pointer stream, Pointer reductionPointer, Pointer allocationPointer);
public native void setGraphContextInputArray(OpaqueContext ptr, int index, Pointer buffer, Pointer shapeInfo, Pointer specialBuffer, Pointer specialShapeInfo);
@@ -6453,6 +6454,13 @@ public native @Cast("bool") boolean isOptimalRequirementsMet();
public native void setDArguments(@Cast("nd4j::DataType*") @StdVector IntBuffer dArgs);
public native void setDArguments(@Cast("nd4j::DataType*") @StdVector int[] dArgs);
+ /**
+ * This method purges fastpath in/out contents and releases all the handles.
+ *
+ * PLEASE NOTE: I/T/B/D args will stay intact
+ */
+ public native void clearFastPath();
+
public native void setCudaContext(@Cast("Nd4jPointer") Pointer cudaStream, @Cast("Nd4jPointer") Pointer reductionPointer, @Cast("Nd4jPointer") Pointer allocationPointer);
public native void allowHelpers(@Cast("bool") boolean reallyAllow);
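
With ctxPurge()/clearFastPath() exposed, a single op context can be reused across many executions: purge drops the fast-path input/output handles while the I/T/B/D arguments stay intact. A hedged, self-contained sketch of that reuse pattern (the interface below is illustrative, not the OpContext API):

    // Sketch only: purge() clears per-run inputs/outputs, scalar arguments persist.
    interface ReusableContextSketch extends AutoCloseable {
        void setInput(int index, Object array);
        void setOutput(int index, Object array);
        void purge();

        static void runBatches(ReusableContextSketch ctx, Object[][] batches, Object[] outputs) {
            for (int i = 0; i < batches.length; i++) {
                ctx.purge();                          // drop previous fast-path in/out handles
                for (int j = 0; j < batches[i].length; j++)
                    ctx.setInput(j, batches[i][j]);
                ctx.setOutput(0, outputs[i]);
                // ... execute the op with ctx here ...
            }
        }
    }
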
diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/BaseCpuDataBuffer.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/BaseCpuDataBuffer.java
index 71583638a..a5ddc7aef 100644
--- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/BaseCpuDataBuffer.java
+++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/BaseCpuDataBuffer.java
@@ -43,7 +43,7 @@ public abstract class BaseCpuDataBuffer extends BaseDataBuffer implements Deallo
protected transient OpaqueDataBuffer ptrDataBuffer;
- private final long instanceId = Nd4j.getDeallocatorService().nextValue();
+ private transient final long instanceId = Nd4j.getDeallocatorService().nextValue();
protected BaseCpuDataBuffer() {
@@ -52,7 +52,7 @@ public abstract class BaseCpuDataBuffer extends BaseDataBuffer implements Deallo
@Override
public String getUniqueId() {
- return "BCDB_" + instanceId;
+ return new String("BCDB_" + instanceId);
}
@Override
@@ -208,6 +208,26 @@ public abstract class BaseCpuDataBuffer extends BaseDataBuffer implements Deallo
Pointer.memcpy(ptr, temp, length * Nd4j.sizeOfDataType(dtype));
}
+ @Override
+ protected double getDoubleUnsynced(long index) {
+ return super.getDouble(index);
+ }
+
+ @Override
+ protected float getFloatUnsynced(long index) {
+ return super.getFloat(index);
+ }
+
+ @Override
+ protected long getLongUnsynced(long index) {
+ return super.getLong(index);
+ }
+
+ @Override
+ protected int getIntUnsynced(long index) {
+ return super.getInt(index);
+ }
+
@Override
public void pointerIndexerByCurrentType(DataType currentType) {
diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/CpuDeallocator.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/CpuDeallocator.java
index 3b8a46fa6..e808ebaa3 100644
--- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/CpuDeallocator.java
+++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/CpuDeallocator.java
@@ -28,7 +28,7 @@ import org.nd4j.nativeblas.OpaqueDataBuffer;
*/
@Slf4j
public class CpuDeallocator implements Deallocator {
- private OpaqueDataBuffer opaqueDataBuffer;
+ private final transient OpaqueDataBuffer opaqueDataBuffer;
public CpuDeallocator(BaseCpuDataBuffer buffer) {
opaqueDataBuffer = buffer.getOpaqueDataBuffer();
diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/LongBuffer.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/LongBuffer.java
index 898a125f2..19ad6f907 100644
--- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/LongBuffer.java
+++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/LongBuffer.java
@@ -28,6 +28,7 @@ import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.memory.pointers.PagedPointer;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.nativeblas.NativeOpsHolder;
+import org.nd4j.nativeblas.OpaqueDataBuffer;
import java.nio.ByteBuffer;
@@ -123,7 +124,7 @@ public class LongBuffer extends BaseCpuDataBuffer {
// we still want this buffer to have native representation
- ptrDataBuffer = NativeOpsHolder.getInstance().getDeviceNativeOps().allocateDataBuffer(0, DataType.INT64.toInt(), false);
+ ptrDataBuffer = OpaqueDataBuffer.allocateDataBuffer(0, DataType.INT64, false);
NativeOpsHolder.getInstance().getDeviceNativeOps().dbSetPrimaryBuffer(ptrDataBuffer, this.pointer, numberOfElements);
Nd4j.getDeallocatorService().pickObject(this);
diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/CpuOpContext.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/CpuOpContext.java
index 461646311..9d79e6545 100644
--- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/CpuOpContext.java
+++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/CpuOpContext.java
@@ -20,11 +20,14 @@ import lombok.NonNull;
import lombok.val;
import org.bytedeco.javacpp.*;
import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.memory.Deallocatable;
+import org.nd4j.linalg.api.memory.Deallocator;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseOpContext;
import org.nd4j.linalg.api.ops.ExecutionMode;
import org.nd4j.linalg.api.ops.OpContext;
import org.nd4j.linalg.cpu.nativecpu.buffer.BaseCpuDataBuffer;
+import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.nativeblas.NativeOps;
import org.nd4j.nativeblas.NativeOpsHolder;
@@ -38,14 +41,19 @@ import java.util.List;
*
* @author raver119@gmail.com
*/
-public class CpuOpContext extends BaseOpContext implements OpContext {
+public class CpuOpContext extends BaseOpContext implements OpContext, Deallocatable {
// we might want to have configurable
private NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
private OpaqueContext context = nativeOps.createGraphContext(1);
+ private final transient long id = Nd4j.getDeallocatorService().nextValue();
+
+ public CpuOpContext() {
+ Nd4j.getDeallocatorService().pickObject(this);
+ }
@Override
public void close() {
- nativeOps.deleteGraphContext(context);
+ // no-op
}
@Override
@@ -136,4 +144,25 @@ public class CpuOpContext extends BaseOpContext implements OpContext {
super.setExecutionMode(mode);
nativeOps.ctxSetExecutionMode(context, mode.ordinal());
}
+
+ @Override
+ public void purge() {
+ super.purge();
+ nativeOps.ctxPurge(context);
+ }
+
+ @Override
+ public String getUniqueId() {
+ return new String("CTX_" + id);
+ }
+
+ @Override
+ public Deallocator deallocator() {
+ return new CpuOpContextDeallocator(this);
+ }
+
+ @Override
+ public int targetDevice() {
+ return 0;
+ }
}
diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/CpuOpContextDeallocator.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/CpuOpContextDeallocator.java
new file mode 100644
index 000000000..621f882bd
--- /dev/null
+++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/CpuOpContextDeallocator.java
@@ -0,0 +1,34 @@
+/*******************************************************************************
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+package org.nd4j.linalg.cpu.nativecpu.ops;
+
+import org.nd4j.linalg.api.memory.Deallocator;
+import org.nd4j.nativeblas.NativeOpsHolder;
+import org.nd4j.nativeblas.OpaqueContext;
+
+public class CpuOpContextDeallocator implements Deallocator {
+ private transient final OpaqueContext context;
+
+ public CpuOpContextDeallocator(CpuOpContext ctx) {
+ context = (OpaqueContext) ctx.contextPointer();
+ }
+
+ @Override
+ public void deallocate() {
+ NativeOpsHolder.getInstance().getDeviceNativeOps().deleteGraphContext(context);
+ }
+}
diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java
index 5522141be..49d088f27 100644
--- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java
+++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java
@@ -3093,6 +3093,7 @@ public native OpaqueRandomGenerator getGraphContextRandomGenerator(OpaqueContext
public native void ctxAllowHelpers(OpaqueContext ptr, @Cast("bool") boolean reallyAllow);
public native void ctxShapeFunctionOverride(OpaqueContext ptr, @Cast("bool") boolean reallyOverride);
public native void ctxSetExecutionMode(OpaqueContext ptr, int execMode);
+public native void ctxPurge(OpaqueContext ptr);
public native void markGraphContextInplace(OpaqueContext ptr, @Cast("bool") boolean reallyInplace);
public native void setGraphContextCudaContext(OpaqueContext ptr, Pointer stream, Pointer reductionPointer, Pointer allocationPointer);
public native void setGraphContextInputArray(OpaqueContext ptr, int index, Pointer buffer, Pointer shapeInfo, Pointer specialBuffer, Pointer specialShapeInfo);
@@ -6456,6 +6457,13 @@ public native @Cast("bool") boolean isOptimalRequirementsMet();
public native void setDArguments(@Cast("nd4j::DataType*") @StdVector IntBuffer dArgs);
public native void setDArguments(@Cast("nd4j::DataType*") @StdVector int[] dArgs);
+ /**
+ * This method purges fastpath in/out contents and releases all the handles.
+ *
+ * PLEASE NOTE: I/T/B/D args will stay intact
+ */
+ public native void clearFastPath();
+
public native void setCudaContext(@Cast("Nd4jPointer") Pointer cudaStream, @Cast("Nd4jPointer") Pointer reductionPointer, @Cast("Nd4jPointer") Pointer allocationPointer);
public native void allowHelpers(@Cast("bool") boolean reallyAllow);
@@ -19169,6 +19177,38 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
}
// #endif
+ /**
+ * solve op. - solve systems of linear equations - general method.
+ *
+ * input params:
+ * 0 - the tensor with dimension (x * y * z * ::: * M * M) - left parts of equations
+ * 1 - the tensor with dimension (x * y * z * ::: * M * K) - right parts of equations
+ *
+ * boolean args:
+ * 0 - adjoint - default is false (optional) - indicates whether the input matrix or its adjoint (Hermitian transpose) should be used
+ *
+ * return value:
+ * tensor with dimension (x * y * z * ::: * M * K) with solutions
+ *
+ */
+// #if NOT_EXCLUDED(OP_solve)
+ @Namespace("nd4j::ops") public static class solve extends DeclarableCustomOp {
+ static { Loader.load(); }
+ /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
+ public solve(Pointer p) { super(p); }
+ /** Native array allocator. Access with {@link Pointer#position(long)}. */
+ public solve(long size) { super((Pointer)null); allocateArray(size); }
+ private native void allocateArray(long size);
+ @Override public solve position(long position) {
+ return (solve)super.position(position);
+ }
+
+ public solve() { super((Pointer)null); allocate(); }
+ private native void allocate();
+ public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
+ }
+// #endif
+
/**
* lu op. - make LUP decomposition of given batch of 2D square matrices
*
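
A hedged usage sketch for the newly mapped solve op, driving it by name through nd4j's generic DynamicCustomOp builder: the op name, the input order and the optional adjoint flag come from the javadoc above, while the exact builder and factory calls are an assumption about the public API rather than something this patch adds.

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.api.ops.DynamicCustomOp;
    import org.nd4j.linalg.factory.Nd4j;

    public class SolveOpSketch {
        public static void main(String[] args) {
            INDArray a = Nd4j.create(new double[][] {{2, 0}, {0, 4}});   // M x M left-hand sides
            INDArray b = Nd4j.create(new double[][] {{2}, {8}});         // M x K right-hand sides
            INDArray x = Nd4j.create(b.dataType(), b.shape());           // M x K solutions

            DynamicCustomOp op = DynamicCustomOp.builder("solve")
                    .addInputs(a, b)
                    .addOutputs(x)
                    .addBooleanArguments(false)    // adjoint = false, per the op doc above
                    .build();
            Nd4j.exec(op);
            System.out.println(x);                 // expected solution: [[1.0], [2.0]]
        }
    }
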
diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/pom.xml b/nd4j/nd4j-backends/nd4j-backend-impls/pom.xml
index 86ce07ff7..c6da5e6f0 100644
--- a/nd4j/nd4j-backends/nd4j-backend-impls/pom.xml
+++ b/nd4j/nd4j-backends/nd4j-backend-impls/pom.xml
@@ -117,7 +117,7 @@
Maximum heap size was set to 8g, as a minimum required value for tests run.
Depending on a build machine, default value is not always enough.
-->
- -Ddtype=float -Xmx8g
+ -Ddtype=float -Dfile.encoding=UTF-8 -Xmx8g
diff --git a/nd4j/nd4j-backends/nd4j-tests-tensorflow/pom.xml b/nd4j/nd4j-backends/nd4j-tests-tensorflow/pom.xml
index 5f5c5fa90..6a3cc6eda 100644
--- a/nd4j/nd4j-backends/nd4j-tests-tensorflow/pom.xml
+++ b/nd4j/nd4j-backends/nd4j-tests-tensorflow/pom.xml
@@ -224,7 +224,7 @@
Depending on a build machine, default value is not always enough.
-->
false
- -Xmx6g
+ -Xmx6g -Dfile.encoding=UTF-8
@@ -296,7 +296,7 @@
Maximum heap size was set to 6g, as a minimum required value for tests run.
Depending on a build machine, default value is not always enough.
-->
- -Xmx6g
+ -Xmx6g -Dfile.encoding=UTF-8
false
false
diff --git a/nd4j/nd4j-backends/nd4j-tests/pom.xml b/nd4j/nd4j-backends/nd4j-tests/pom.xml
index 9d098189f..8d250629a 100644
--- a/nd4j/nd4j-backends/nd4j-tests/pom.xml
+++ b/nd4j/nd4j-backends/nd4j-tests/pom.xml
@@ -252,7 +252,7 @@
Maximum heap size was set to 6g, as a minimum required value for tests run.
Depending on a build machine, default value is not always enough.
-->
- -Ddtype=float -Xmx6g
+ -Ddtype=float -Dfile.encoding=UTF-8 -Xmx6g
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/TestOpMapping.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/TestOpMapping.java
index ab56ae281..e88f195c0 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/TestOpMapping.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/TestOpMapping.java
@@ -44,6 +44,11 @@ public class TestOpMapping extends BaseNd4jTest {
return 'c';
}
+ @Override
+ public long getTimeoutMilliseconds() {
+ return 60000L;
+ }
+
@Test
public void testOpMappingCoverage() throws Exception {
Reflections reflections = new Reflections("org.nd4j");
diff --git a/nd4j/nd4j-serde/nd4j-aeron/pom.xml b/nd4j/nd4j-serde/nd4j-aeron/pom.xml
index 87e9347dd..827afb23a 100644
--- a/nd4j/nd4j-serde/nd4j-aeron/pom.xml
+++ b/nd4j/nd4j-serde/nd4j-aeron/pom.xml
@@ -91,7 +91,7 @@
For testing large zoo models, this may not be enough (so comment it out).
-->
- -Ddtype=float -Xmx8g
+ -Ddtype=float -Dfile.encoding=UTF-8 -Xmx8g
diff --git a/nd4j/nd4j-serde/nd4j-arrow/pom.xml b/nd4j/nd4j-serde/nd4j-arrow/pom.xml
index ddadc2df1..3a768c1a5 100644
--- a/nd4j/nd4j-serde/nd4j-arrow/pom.xml
+++ b/nd4j/nd4j-serde/nd4j-arrow/pom.xml
@@ -109,7 +109,7 @@
For testing large zoo models, this may not be enough (so comment it out).
-->
- -Ddtype=float -Xmx8g
+ -Ddtype=float -Dfile.encoding=UTF-8 -Xmx8g
diff --git a/nd4j/nd4j-serde/nd4j-gson/pom.xml b/nd4j/nd4j-serde/nd4j-gson/pom.xml
index f7215436a..f488bfde5 100644
--- a/nd4j/nd4j-serde/nd4j-gson/pom.xml
+++ b/nd4j/nd4j-serde/nd4j-gson/pom.xml
@@ -100,7 +100,7 @@
For testing large zoo models, this may not be enough (so comment it out).
-->
- -Ddtype=float -Xmx8g
+ -Ddtype=float -Dfile.encoding=UTF-8 -Xmx8g
diff --git a/nd4j/nd4j-serde/nd4j-kryo/pom.xml b/nd4j/nd4j-serde/nd4j-kryo/pom.xml
index 25acac26f..02970d5e2 100644
--- a/nd4j/nd4j-serde/nd4j-kryo/pom.xml
+++ b/nd4j/nd4j-serde/nd4j-kryo/pom.xml
@@ -220,7 +220,7 @@
Maximum heap size was set to 6g, as a minimum required value for tests run.
Depending on a build machine, default value is not always enough.
-->
- -Ddtype=float -Xmx6g
+ -Ddtype=float -Dfile.encoding=UTF-8 -Xmx6g