diff --git a/arbiter/arbiter-core/pom.xml b/arbiter/arbiter-core/pom.xml index 04a1fc0f0..6ce3c9c1f 100644 --- a/arbiter/arbiter-core/pom.xml +++ b/arbiter/arbiter-core/pom.xml @@ -84,6 +84,13 @@ jackson ${nd4j.version} + + + org.deeplearning4j + deeplearning4j-common-tests + ${project.version} + test + diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/TestGeneticSearch.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/TestGeneticSearch.java index e7d07f81a..abfedac5e 100644 --- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/TestGeneticSearch.java +++ b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/TestGeneticSearch.java @@ -16,6 +16,7 @@ package org.deeplearning4j.arbiter.optimize; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.arbiter.optimize.api.CandidateGenerator; import org.deeplearning4j.arbiter.optimize.api.score.ScoreFunction; import org.deeplearning4j.arbiter.optimize.api.termination.MaxCandidatesCondition; @@ -32,7 +33,7 @@ import org.deeplearning4j.arbiter.optimize.runner.listener.impl.LoggingStatusLis import org.junit.Assert; import org.junit.Test; -public class TestGeneticSearch { +public class TestGeneticSearch extends BaseDL4JTest { public class TestSelectionOperator extends SelectionOperator { @Override diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/TestGridSearch.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/TestGridSearch.java index 7f9906abf..24f80bf23 100644 --- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/TestGridSearch.java +++ b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/TestGridSearch.java @@ -16,6 +16,7 @@ package org.deeplearning4j.arbiter.optimize; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.arbiter.optimize.api.CandidateGenerator; import org.deeplearning4j.arbiter.optimize.api.data.DataSetIteratorFactoryProvider; import org.deeplearning4j.arbiter.optimize.generator.GridSearchCandidateGenerator; @@ -26,7 +27,7 @@ import java.util.Map; import static org.junit.Assert.*; -public class TestGridSearch { +public class TestGridSearch extends BaseDL4JTest { @Test public void testIndexing() { diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/TestJson.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/TestJson.java index 6f1b336bb..cf5b1e1c7 100644 --- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/TestJson.java +++ b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/TestJson.java @@ -19,6 +19,7 @@ package org.deeplearning4j.arbiter.optimize; import org.apache.commons.math3.distribution.LogNormalDistribution; import org.apache.commons.math3.distribution.NormalDistribution; import org.apache.commons.math3.distribution.UniformIntegerDistribution; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.arbiter.optimize.api.CandidateGenerator; import org.deeplearning4j.arbiter.optimize.api.ParameterSpace; import org.deeplearning4j.arbiter.optimize.api.data.DataSetIteratorFactoryProvider; @@ -49,7 +50,7 @@ import static org.junit.Assert.assertEquals; /** * Created by Alex on 02/02/2017. */ -public class TestJson { +public class TestJson extends BaseDL4JTest { protected static ObjectMapper getObjectMapper(JsonFactory factory) { ObjectMapper om = new ObjectMapper(factory); diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/TestRandomSearch.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/TestRandomSearch.java index 305480420..34916ebdc 100644 --- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/TestRandomSearch.java +++ b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/TestRandomSearch.java @@ -16,6 +16,7 @@ package org.deeplearning4j.arbiter.optimize; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.arbiter.optimize.api.CandidateGenerator; import org.deeplearning4j.arbiter.optimize.api.data.DataSetIteratorFactoryProvider; import org.deeplearning4j.arbiter.optimize.api.termination.MaxCandidatesCondition; @@ -34,7 +35,7 @@ import java.util.Map; * Test random search on the Branin Function: * http://www.sfu.ca/~ssurjano/branin.html */ -public class TestRandomSearch { +public class TestRandomSearch extends BaseDL4JTest { @Test public void test() throws Exception { diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/distribution/TestLogUniform.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/distribution/TestLogUniform.java index 85b048df7..6ca842637 100644 --- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/distribution/TestLogUniform.java +++ b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/distribution/TestLogUniform.java @@ -17,12 +17,13 @@ package org.deeplearning4j.arbiter.optimize.distribution; import org.apache.commons.math3.distribution.RealDistribution; +import org.deeplearning4j.BaseDL4JTest; import org.junit.Test; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; -public class TestLogUniform { +public class TestLogUniform extends BaseDL4JTest { @Test public void testSimple(){ diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/ArithmeticCrossoverTests.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/ArithmeticCrossoverTests.java index 4342f32ef..252b4304f 100644 --- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/ArithmeticCrossoverTests.java +++ b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/ArithmeticCrossoverTests.java @@ -17,6 +17,7 @@ package org.deeplearning4j.arbiter.optimize.genetic.crossover; import org.apache.commons.math3.random.RandomGenerator; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.ArithmeticCrossover; import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverResult; import org.deeplearning4j.arbiter.optimize.genetic.TestParentSelection; @@ -24,7 +25,7 @@ import org.deeplearning4j.arbiter.optimize.genetic.TestRandomGenerator; import org.junit.Assert; import org.junit.Test; -public class ArithmeticCrossoverTests { +public class ArithmeticCrossoverTests extends BaseDL4JTest { @Test public void ArithmeticCrossover_Crossover_OutsideCrossoverRate_ShouldReturnParent0() { diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/CrossoverOperatorTests.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/CrossoverOperatorTests.java index 96e3fda7b..50ae7f729 100644 --- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/CrossoverOperatorTests.java +++ b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/CrossoverOperatorTests.java @@ -16,6 +16,7 @@ package org.deeplearning4j.arbiter.optimize.genetic.crossover; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationInitializer; import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationModel; import org.deeplearning4j.arbiter.optimize.genetic.TestCrossoverOperator; @@ -23,7 +24,7 @@ import org.deeplearning4j.arbiter.optimize.genetic.TestPopulationInitializer; import org.junit.Assert; import org.junit.Test; -public class CrossoverOperatorTests { +public class CrossoverOperatorTests extends BaseDL4JTest { @Test public void CrossoverOperator_initializeInstance_ShouldInitPopulationModel() throws IllegalAccessException { diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/CrossoverPointsGeneratorTests.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/CrossoverPointsGeneratorTests.java index 0b1ec5271..5c1bebb51 100644 --- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/CrossoverPointsGeneratorTests.java +++ b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/CrossoverPointsGeneratorTests.java @@ -17,6 +17,7 @@ package org.deeplearning4j.arbiter.optimize.genetic.crossover; import org.apache.commons.math3.random.RandomGenerator; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.utils.CrossoverPointsGenerator; import org.deeplearning4j.arbiter.optimize.genetic.TestRandomGenerator; import org.junit.Assert; @@ -24,7 +25,7 @@ import org.junit.Test; import java.util.Deque; -public class CrossoverPointsGeneratorTests { +public class CrossoverPointsGeneratorTests extends BaseDL4JTest { @Test public void CrossoverPointsGenerator_FixedNumberCrossovers() { diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/KPointCrossoverTests.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/KPointCrossoverTests.java index fbafb37a5..2399d256a 100644 --- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/KPointCrossoverTests.java +++ b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/KPointCrossoverTests.java @@ -17,6 +17,7 @@ package org.deeplearning4j.arbiter.optimize.genetic.crossover; import org.apache.commons.math3.random.RandomGenerator; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverResult; import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.KPointCrossover; import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.TwoParentSelection; @@ -25,7 +26,7 @@ import org.deeplearning4j.arbiter.optimize.genetic.TestRandomGenerator; import org.junit.Assert; import org.junit.Test; -public class KPointCrossoverTests { +public class KPointCrossoverTests extends BaseDL4JTest { @Test public void KPointCrossover_BelowCrossoverRate_ShouldReturnParent0() { diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/ParentSelectionTests.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/ParentSelectionTests.java index ed2bec0ba..6976d8dd4 100644 --- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/ParentSelectionTests.java +++ b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/ParentSelectionTests.java @@ -16,6 +16,7 @@ package org.deeplearning4j.arbiter.optimize.genetic.crossover; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome; import org.deeplearning4j.arbiter.optimize.genetic.TestParentSelection; import org.junit.Assert; @@ -24,7 +25,7 @@ import org.junit.Test; import java.util.ArrayList; import java.util.List; -public class ParentSelectionTests { +public class ParentSelectionTests extends BaseDL4JTest { @Test public void ParentSelection_InitializeInstance_ShouldInitPopulation() { diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/RandomTwoParentSelectionTests.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/RandomTwoParentSelectionTests.java index 2e8a166a7..09b244ab3 100644 --- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/RandomTwoParentSelectionTests.java +++ b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/RandomTwoParentSelectionTests.java @@ -17,6 +17,7 @@ package org.deeplearning4j.arbiter.optimize.genetic.crossover; import org.apache.commons.math3.random.RandomGenerator; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome; import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.RandomTwoParentSelection; import org.deeplearning4j.arbiter.optimize.genetic.TestRandomGenerator; @@ -26,7 +27,7 @@ import org.junit.Test; import java.util.ArrayList; import java.util.List; -public class RandomTwoParentSelectionTests { +public class RandomTwoParentSelectionTests extends BaseDL4JTest { @Test public void RandomTwoParentSelection_ShouldReturnTwoDifferentParents() { RandomGenerator rng = new TestRandomGenerator(new int[] {1, 1, 1, 0}, null); diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/SinglePointCrossoverTests.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/SinglePointCrossoverTests.java index 36c11e00d..32dfb136c 100644 --- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/SinglePointCrossoverTests.java +++ b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/SinglePointCrossoverTests.java @@ -17,6 +17,7 @@ package org.deeplearning4j.arbiter.optimize.genetic.crossover; import org.apache.commons.math3.random.RandomGenerator; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverResult; import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.SinglePointCrossover; import org.deeplearning4j.arbiter.optimize.genetic.TestParentSelection; @@ -24,7 +25,7 @@ import org.deeplearning4j.arbiter.optimize.genetic.TestRandomGenerator; import org.junit.Assert; import org.junit.Test; -public class SinglePointCrossoverTests { +public class SinglePointCrossoverTests extends BaseDL4JTest { @Test public void SinglePointCrossover_BelowCrossoverRate_ShouldReturnParent0() { RandomGenerator rng = new TestRandomGenerator(null, new double[] {1.0}); diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/TwoParentsCrossoverOperatorTests.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/TwoParentsCrossoverOperatorTests.java index 170984a68..9efe89620 100644 --- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/TwoParentsCrossoverOperatorTests.java +++ b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/TwoParentsCrossoverOperatorTests.java @@ -16,6 +16,7 @@ package org.deeplearning4j.arbiter.optimize.genetic.crossover; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverResult; import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.TwoParentsCrossoverOperator; import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.TwoParentSelection; @@ -27,7 +28,7 @@ import org.junit.Assert; import org.junit.Test; import sun.reflect.generics.reflectiveObjects.NotImplementedException; -public class TwoParentsCrossoverOperatorTests { +public class TwoParentsCrossoverOperatorTests extends BaseDL4JTest { class TestTwoParentsCrossoverOperator extends TwoParentsCrossoverOperator { diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/UniformCrossoverTests.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/UniformCrossoverTests.java index fc4524fdc..76a395c28 100644 --- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/UniformCrossoverTests.java +++ b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/UniformCrossoverTests.java @@ -17,6 +17,7 @@ package org.deeplearning4j.arbiter.optimize.genetic.crossover; import org.apache.commons.math3.random.RandomGenerator; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverResult; import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.UniformCrossover; import org.deeplearning4j.arbiter.optimize.genetic.TestParentSelection; @@ -24,7 +25,7 @@ import org.deeplearning4j.arbiter.optimize.genetic.TestRandomGenerator; import org.junit.Assert; import org.junit.Test; -public class UniformCrossoverTests { +public class UniformCrossoverTests extends BaseDL4JTest { @Test public void UniformCrossover_BelowCrossoverRate_ShouldReturnParent0() { diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/culling/LeastFitCullOperatorTests.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/culling/LeastFitCullOperatorTests.java index fbd5465c4..ccdb434e8 100644 --- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/culling/LeastFitCullOperatorTests.java +++ b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/culling/LeastFitCullOperatorTests.java @@ -16,6 +16,7 @@ package org.deeplearning4j.arbiter.optimize.genetic.culling; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome; import org.deeplearning4j.arbiter.optimize.generator.genetic.culling.LeastFitCullOperator; import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationInitializer; @@ -27,7 +28,7 @@ import org.junit.Test; import java.util.ArrayList; import java.util.List; -public class LeastFitCullOperatorTests { +public class LeastFitCullOperatorTests extends BaseDL4JTest { @Test public void LeastFitCullingOperation_ShouldCullLastElements() { diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/culling/RatioCullOperatorTests.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/culling/RatioCullOperatorTests.java index 4e41268be..093ffd486 100644 --- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/culling/RatioCullOperatorTests.java +++ b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/culling/RatioCullOperatorTests.java @@ -16,6 +16,7 @@ package org.deeplearning4j.arbiter.optimize.genetic.culling; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome; import org.deeplearning4j.arbiter.optimize.generator.genetic.culling.RatioCullOperator; import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationInitializer; @@ -27,7 +28,7 @@ import sun.reflect.generics.reflectiveObjects.NotImplementedException; import java.util.List; -public class RatioCullOperatorTests { +public class RatioCullOperatorTests extends BaseDL4JTest { class TestRatioCullOperator extends RatioCullOperator { diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/mutation/RandomMutationOperatorTests.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/mutation/RandomMutationOperatorTests.java index 80773c39e..38e2ba87b 100644 --- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/mutation/RandomMutationOperatorTests.java +++ b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/mutation/RandomMutationOperatorTests.java @@ -16,6 +16,7 @@ package org.deeplearning4j.arbiter.optimize.genetic.mutation; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.arbiter.optimize.generator.genetic.mutation.RandomMutationOperator; import org.deeplearning4j.arbiter.optimize.genetic.TestRandomGenerator; import org.junit.Assert; @@ -24,7 +25,7 @@ import org.junit.Test; import java.lang.reflect.Field; import java.util.Arrays; -public class RandomMutationOperatorTests { +public class RandomMutationOperatorTests extends BaseDL4JTest { @Test public void RandomMutationOperator_DefaultBuild_ShouldNotBeNull() { RandomMutationOperator sut = new RandomMutationOperator.Builder().build(); diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/population/PopulationModelTests.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/population/PopulationModelTests.java index 760b198b0..e185b8164 100644 --- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/population/PopulationModelTests.java +++ b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/population/PopulationModelTests.java @@ -16,6 +16,7 @@ package org.deeplearning4j.arbiter.optimize.genetic.population; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome; import org.deeplearning4j.arbiter.optimize.generator.genetic.culling.CullOperator; import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationInitializer; @@ -27,7 +28,7 @@ import org.junit.Test; import java.util.List; -public class PopulationModelTests { +public class PopulationModelTests extends BaseDL4JTest { private class TestCullOperator implements CullOperator { diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/selection/GeneticSelectionOperatorTests.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/selection/GeneticSelectionOperatorTests.java index e5df87549..1d2b74de9 100644 --- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/selection/GeneticSelectionOperatorTests.java +++ b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/selection/GeneticSelectionOperatorTests.java @@ -17,6 +17,7 @@ package org.deeplearning4j.arbiter.optimize.genetic.selection; import org.apache.commons.math3.random.RandomGenerator; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.arbiter.optimize.generator.genetic.ChromosomeFactory; import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverOperator; import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverResult; @@ -36,7 +37,7 @@ import sun.reflect.generics.reflectiveObjects.NotImplementedException; import static org.junit.Assert.assertArrayEquals; -public class GeneticSelectionOperatorTests { +public class GeneticSelectionOperatorTests extends BaseDL4JTest { private class TestCullOperator implements CullOperator { diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/selection/SelectionOperatorTests.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/selection/SelectionOperatorTests.java index 5cc61d744..3f64279ee 100644 --- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/selection/SelectionOperatorTests.java +++ b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/selection/SelectionOperatorTests.java @@ -16,6 +16,7 @@ package org.deeplearning4j.arbiter.optimize.genetic.selection; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.arbiter.optimize.generator.genetic.ChromosomeFactory; import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationInitializer; import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationModel; @@ -25,7 +26,7 @@ import org.junit.Assert; import org.junit.Test; import sun.reflect.generics.reflectiveObjects.NotImplementedException; -public class SelectionOperatorTests { +public class SelectionOperatorTests extends BaseDL4JTest { private class TestSelectionOperator extends SelectionOperator { public PopulationModel getPopulationModel() { diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/parameter/TestParameterSpaces.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/parameter/TestParameterSpaces.java index 4a203f770..98396a941 100644 --- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/parameter/TestParameterSpaces.java +++ b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/parameter/TestParameterSpaces.java @@ -17,6 +17,7 @@ package org.deeplearning4j.arbiter.optimize.parameter; import org.apache.commons.math3.distribution.NormalDistribution; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.arbiter.optimize.api.ParameterSpace; import org.deeplearning4j.arbiter.optimize.parameter.continuous.ContinuousParameterSpace; import org.deeplearning4j.arbiter.optimize.parameter.discrete.DiscreteParameterSpace; @@ -25,7 +26,7 @@ import org.junit.Test; import static org.junit.Assert.assertEquals; -public class TestParameterSpaces { +public class TestParameterSpaces extends BaseDL4JTest { @Test diff --git a/arbiter/arbiter-deeplearning4j/pom.xml b/arbiter/arbiter-deeplearning4j/pom.xml index b163e2ae4..ec7e22d3c 100644 --- a/arbiter/arbiter-deeplearning4j/pom.xml +++ b/arbiter/arbiter-deeplearning4j/pom.xml @@ -63,6 +63,13 @@ gson ${gson.version} + + + org.deeplearning4j + deeplearning4j-common-tests + ${project.version} + test + diff --git a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/computationgraph/TestComputationGraphSpace.java b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/computationgraph/TestComputationGraphSpace.java index af5d04c0a..7c4ec38f4 100644 --- a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/computationgraph/TestComputationGraphSpace.java +++ b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/computationgraph/TestComputationGraphSpace.java @@ -16,6 +16,7 @@ package org.deeplearning4j.arbiter.computationgraph; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.arbiter.ComputationGraphSpace; import org.deeplearning4j.arbiter.TestUtils; import org.deeplearning4j.arbiter.conf.updater.SgdSpace; @@ -44,7 +45,7 @@ import java.util.Random; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; -public class TestComputationGraphSpace { +public class TestComputationGraphSpace extends BaseDL4JTest { @Test public void testBasic() { diff --git a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/computationgraph/TestGraphLocalExecution.java b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/computationgraph/TestGraphLocalExecution.java index c64a06040..1747b45f9 100644 --- a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/computationgraph/TestGraphLocalExecution.java +++ b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/computationgraph/TestGraphLocalExecution.java @@ -17,6 +17,7 @@ package org.deeplearning4j.arbiter.computationgraph; import lombok.extern.slf4j.Slf4j; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.arbiter.ComputationGraphSpace; import org.deeplearning4j.arbiter.conf.updater.AdamSpace; import org.deeplearning4j.arbiter.conf.updater.SgdSpace; @@ -85,7 +86,7 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; @Slf4j -public class TestGraphLocalExecution { +public class TestGraphLocalExecution extends BaseDL4JTest { @Rule public TemporaryFolder testDir = new TemporaryFolder(); @@ -126,7 +127,7 @@ public class TestGraphLocalExecution { if(dataApproach == 0){ ds = TestDL4JLocalExecution.MnistDataSource.class; dsP = new Properties(); - dsP.setProperty("minibatch", "8"); + dsP.setProperty("minibatch", "2"); candidateGenerator = new RandomSearchGenerator(mls); } else if(dataApproach == 1) { //DataProvider approach @@ -150,8 +151,8 @@ public class TestGraphLocalExecution { .dataSource(ds, dsP) .modelSaver(new FileModelSaver(modelSave)) .scoreFunction(new TestSetLossScoreFunction()) - .terminationConditions(new MaxTimeCondition(2, TimeUnit.MINUTES), - new MaxCandidatesCondition(5)) + .terminationConditions(new MaxTimeCondition(5, TimeUnit.SECONDS), + new MaxCandidatesCondition(3)) .build(); IOptimizationRunner runner = new LocalOptimizationRunner(configuration,new ComputationGraphTaskCreator(new ClassificationEvaluator())); @@ -159,7 +160,7 @@ public class TestGraphLocalExecution { runner.execute(); List results = runner.getResults(); - assertEquals(5, results.size()); + assertTrue(results.size() > 0); System.out.println("----- COMPLETE - " + results.size() + " results -----"); } @@ -203,8 +204,8 @@ public class TestGraphLocalExecution { OptimizationConfiguration configuration = new OptimizationConfiguration.Builder() .candidateGenerator(candidateGenerator).dataProvider(dataProvider) .modelSaver(new FileModelSaver(modelSavePath)).scoreFunction(ScoreFunctions.testSetLoss(true)) - .terminationConditions(new MaxTimeCondition(20, TimeUnit.SECONDS), - new MaxCandidatesCondition(10)) + .terminationConditions(new MaxTimeCondition(5, TimeUnit.SECONDS), + new MaxCandidatesCondition(3)) .build(); IOptimizationRunner runner = new LocalOptimizationRunner(configuration, @@ -223,7 +224,7 @@ public class TestGraphLocalExecution { .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .updater(new SgdSpace(new ContinuousParameterSpace(0.0001, 0.1))) .l2(new ContinuousParameterSpace(0.0001, 0.01)).addInputs("in") - .setInputTypes(InputType.feedForward(4)) + .setInputTypes(InputType.feedForward(784)) .addLayer("layer0", new DenseLayerSpace.Builder().nIn(784).nOut(new IntegerParameterSpace(2, 10)) .activation(new DiscreteParameterSpace<>(Activation.RELU, Activation.TANH)) @@ -250,8 +251,8 @@ public class TestGraphLocalExecution { .candidateGenerator(candidateGenerator) .dataProvider(new TestMdsDataProvider(1, 32)) .modelSaver(new FileModelSaver(modelSavePath)).scoreFunction(ScoreFunctions.testSetLoss(true)) - .terminationConditions(new MaxTimeCondition(20, TimeUnit.SECONDS), - new MaxCandidatesCondition(10)) + .terminationConditions(new MaxTimeCondition(5, TimeUnit.SECONDS), + new MaxCandidatesCondition(3)) .scoreFunction(ScoreFunctions.testSetAccuracy()) .build(); @@ -279,7 +280,7 @@ public class TestGraphLocalExecution { @Override public Object trainData(Map dataParameters) { try { - DataSetIterator underlying = new MnistDataSetIterator(batchSize, Math.min(60000, 10 * batchSize), false, true, true, 12345); + DataSetIterator underlying = new MnistDataSetIterator(batchSize, Math.min(60000, 3 * batchSize), false, true, true, 12345); return new MultiDataSetIteratorAdapter(new MultipleEpochsIterator(numEpochs, underlying)); } catch (IOException e) { throw new RuntimeException(e); @@ -289,7 +290,7 @@ public class TestGraphLocalExecution { @Override public Object testData(Map dataParameters) { try { - DataSetIterator underlying = new MnistDataSetIterator(batchSize, Math.min(10000, 5 * batchSize), false, false, false, 12345); + DataSetIterator underlying = new MnistDataSetIterator(batchSize, Math.min(10000, 2 * batchSize), false, false, false, 12345); return new MultiDataSetIteratorAdapter(underlying); } catch (IOException e) { throw new RuntimeException(e); @@ -305,7 +306,7 @@ public class TestGraphLocalExecution { @Test public void testLocalExecutionEarlyStopping() throws Exception { EarlyStoppingConfiguration esConf = new EarlyStoppingConfiguration.Builder() - .epochTerminationConditions(new MaxEpochsTerminationCondition(4)) + .epochTerminationConditions(new MaxEpochsTerminationCondition(2)) .scoreCalculator(new ScoreProvider()) .modelSaver(new InMemoryModelSaver()).build(); Map commands = new HashMap<>(); @@ -348,8 +349,8 @@ public class TestGraphLocalExecution { .dataProvider(dataProvider) .scoreFunction(ScoreFunctions.testSetF1()) .modelSaver(new FileModelSaver(modelSavePath)) - .terminationConditions(new MaxTimeCondition(45, TimeUnit.SECONDS), - new MaxCandidatesCondition(10)) + .terminationConditions(new MaxTimeCondition(15, TimeUnit.SECONDS), + new MaxCandidatesCondition(3)) .build(); @@ -364,7 +365,7 @@ public class TestGraphLocalExecution { @Override public ScoreCalculator get() { try { - return new DataSetLossCalculatorCG(new MnistDataSetIterator(128, 1280), true); + return new DataSetLossCalculatorCG(new MnistDataSetIterator(4, 8), true); } catch (Exception e){ throw new RuntimeException(e); } diff --git a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/computationgraph/TestGraphLocalExecutionGenetic.java b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/computationgraph/TestGraphLocalExecutionGenetic.java index e67e854b8..2b9c5696d 100644 --- a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/computationgraph/TestGraphLocalExecutionGenetic.java +++ b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/computationgraph/TestGraphLocalExecutionGenetic.java @@ -17,6 +17,7 @@ package org.deeplearning4j.arbiter.computationgraph; import lombok.extern.slf4j.Slf4j; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.arbiter.ComputationGraphSpace; import org.deeplearning4j.arbiter.conf.updater.AdamSpace; import org.deeplearning4j.arbiter.conf.updater.SgdSpace; @@ -79,11 +80,16 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; @Slf4j -public class TestGraphLocalExecutionGenetic { +public class TestGraphLocalExecutionGenetic extends BaseDL4JTest { @Rule public TemporaryFolder testDir = new TemporaryFolder(); + @Override + public long getTimeoutMilliseconds() { + return 45000L; + } + @Test public void testLocalExecutionDataSources() throws Exception { for (int dataApproach = 0; dataApproach < 3; dataApproach++) { @@ -115,7 +121,7 @@ public class TestGraphLocalExecutionGenetic { if (dataApproach == 0) { ds = TestDL4JLocalExecution.MnistDataSource.class; dsP = new Properties(); - dsP.setProperty("minibatch", "8"); + dsP.setProperty("minibatch", "2"); candidateGenerator = new GeneticSearchCandidateGenerator.Builder(mls, scoreFunction) .populationModel(new PopulationModel.Builder().populationSize(5).build()) @@ -148,7 +154,7 @@ public class TestGraphLocalExecutionGenetic { .dataSource(ds, dsP) .modelSaver(new FileModelSaver(modelSave)) .scoreFunction(new TestSetLossScoreFunction()) - .terminationConditions(new MaxTimeCondition(2, TimeUnit.MINUTES), + .terminationConditions(new MaxTimeCondition(5, TimeUnit.SECONDS), new MaxCandidatesCondition(10)) .build(); @@ -157,7 +163,7 @@ public class TestGraphLocalExecutionGenetic { runner.execute(); List results = runner.getResults(); - assertEquals(10, results.size()); + assertTrue(results.size() > 0); System.out.println("----- COMPLETE - " + results.size() + " results -----"); } diff --git a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/json/TestJson.java b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/json/TestJson.java index d6e77ccf4..12ccc71a5 100644 --- a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/json/TestJson.java +++ b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/json/TestJson.java @@ -16,6 +16,7 @@ package org.deeplearning4j.arbiter.json; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.arbiter.ComputationGraphSpace; import org.deeplearning4j.arbiter.MultiLayerSpace; import org.deeplearning4j.arbiter.conf.updater.AdaMaxSpace; @@ -71,7 +72,7 @@ import static org.junit.Assert.assertNotNull; /** * Created by Alex on 14/02/2017. */ -public class TestJson { +public class TestJson extends BaseDL4JTest { @Test public void testMultiLayerSpaceJson() { diff --git a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/MNISTOptimizationTest.java b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/MNISTOptimizationTest.java index 501501e65..ea754990a 100644 --- a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/MNISTOptimizationTest.java +++ b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/MNISTOptimizationTest.java @@ -16,6 +16,7 @@ package org.deeplearning4j.arbiter.multilayernetwork; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.arbiter.MultiLayerSpace; import org.deeplearning4j.arbiter.conf.updater.SgdSpace; import org.deeplearning4j.arbiter.layers.ConvolutionLayerSpace; @@ -59,7 +60,7 @@ import java.util.concurrent.TimeUnit; // import org.deeplearning4j.arbiter.optimize.ui.listener.UIOptimizationRunnerStatusListener; /** Not strictly a unit test. Rather: part example, part debugging on MNIST */ -public class MNISTOptimizationTest { +public class MNISTOptimizationTest extends BaseDL4JTest { public static void main(String[] args) throws Exception { EarlyStoppingConfiguration esConf = diff --git a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/TestDL4JLocalExecution.java b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/TestDL4JLocalExecution.java index 0f4d384ba..554ef346e 100644 --- a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/TestDL4JLocalExecution.java +++ b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/TestDL4JLocalExecution.java @@ -17,6 +17,7 @@ package org.deeplearning4j.arbiter.multilayernetwork; import lombok.extern.slf4j.Slf4j; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.arbiter.MultiLayerSpace; import org.deeplearning4j.arbiter.conf.updater.SgdSpace; import org.deeplearning4j.arbiter.evaluator.multilayer.ClassificationEvaluator; @@ -72,9 +73,10 @@ import java.util.Properties; import java.util.concurrent.TimeUnit; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; @Slf4j -public class TestDL4JLocalExecution { +public class TestDL4JLocalExecution extends BaseDL4JTest { @Rule public TemporaryFolder testDir = new TemporaryFolder(); @@ -112,7 +114,7 @@ public class TestDL4JLocalExecution { if(dataApproach == 0){ ds = MnistDataSource.class; dsP = new Properties(); - dsP.setProperty("minibatch", "8"); + dsP.setProperty("minibatch", "2"); candidateGenerator = new RandomSearchGenerator(mls); } else if(dataApproach == 1) { //DataProvider approach @@ -136,7 +138,7 @@ public class TestDL4JLocalExecution { .dataSource(ds, dsP) .modelSaver(new FileModelSaver(modelSave)) .scoreFunction(new TestSetLossScoreFunction()) - .terminationConditions(new MaxTimeCondition(2, TimeUnit.MINUTES), + .terminationConditions(new MaxTimeCondition(5, TimeUnit.SECONDS), new MaxCandidatesCondition(5)) .build(); @@ -146,7 +148,7 @@ public class TestDL4JLocalExecution { runner.execute(); List results = runner.getResults(); - assertEquals(5, results.size()); + assertTrue(results.size() > 0); System.out.println("----- COMPLETE - " + results.size() + " results -----"); } diff --git a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/TestErrors.java b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/TestErrors.java index bb8806f23..a62997a33 100644 --- a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/TestErrors.java +++ b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/TestErrors.java @@ -16,6 +16,7 @@ package org.deeplearning4j.arbiter.multilayernetwork; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.arbiter.ComputationGraphSpace; import org.deeplearning4j.arbiter.MultiLayerSpace; import org.deeplearning4j.arbiter.layers.DenseLayerSpace; @@ -39,7 +40,7 @@ import org.nd4j.linalg.lossfunctions.LossFunctions; import java.io.File; -public class TestErrors { +public class TestErrors extends BaseDL4JTest { @Rule public TemporaryFolder temp = new TemporaryFolder(); @@ -60,7 +61,7 @@ public class TestErrors { CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls); OptimizationConfiguration configuration = new OptimizationConfiguration.Builder() - .candidateGenerator(candidateGenerator).dataProvider(new TestDataProviderMnist(32, 10)) + .candidateGenerator(candidateGenerator).dataProvider(new TestDataProviderMnist(32, 3)) .modelSaver(new FileModelSaver(f)).scoreFunction(new TestSetLossScoreFunction(true)) .terminationConditions( new MaxCandidatesCondition(5)) @@ -87,7 +88,7 @@ public class TestErrors { CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls); OptimizationConfiguration configuration = new OptimizationConfiguration.Builder() - .candidateGenerator(candidateGenerator).dataProvider(new TestDataProviderMnist(32, 10)) + .candidateGenerator(candidateGenerator).dataProvider(new TestDataProviderMnist(32, 3)) .modelSaver(new FileModelSaver(f)).scoreFunction(new TestSetLossScoreFunction(true)) .terminationConditions( new MaxCandidatesCondition(5)) @@ -116,7 +117,7 @@ public class TestErrors { CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls); OptimizationConfiguration configuration = new OptimizationConfiguration.Builder() - .candidateGenerator(candidateGenerator).dataProvider(new TestDataProviderMnist(32, 10)) + .candidateGenerator(candidateGenerator).dataProvider(new TestDataProviderMnist(32, 3)) .modelSaver(new FileModelSaver(f)).scoreFunction(new TestSetLossScoreFunction(true)) .terminationConditions(new MaxCandidatesCondition(5)) .build(); @@ -143,7 +144,7 @@ public class TestErrors { CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls); OptimizationConfiguration configuration = new OptimizationConfiguration.Builder() - .candidateGenerator(candidateGenerator).dataProvider(new TestDataProviderMnist(32, 10)) + .candidateGenerator(candidateGenerator).dataProvider(new TestDataProviderMnist(32, 3)) .modelSaver(new FileModelSaver(f)).scoreFunction(new TestSetLossScoreFunction(true)) .terminationConditions( new MaxCandidatesCondition(5)) diff --git a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/TestLayerSpace.java b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/TestLayerSpace.java index e3338efcb..6a5458e65 100644 --- a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/TestLayerSpace.java +++ b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/TestLayerSpace.java @@ -17,6 +17,7 @@ package org.deeplearning4j.arbiter.multilayernetwork; import org.apache.commons.lang3.ArrayUtils; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.arbiter.TestUtils; import org.deeplearning4j.arbiter.conf.updater.SgdSpace; import org.deeplearning4j.arbiter.layers.*; @@ -44,7 +45,7 @@ import java.util.Random; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; -public class TestLayerSpace { +public class TestLayerSpace extends BaseDL4JTest { @Test public void testBasic1() { diff --git a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/TestMultiLayerSpace.java b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/TestMultiLayerSpace.java index 48055ed4b..99dc79f42 100644 --- a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/TestMultiLayerSpace.java +++ b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/TestMultiLayerSpace.java @@ -16,6 +16,7 @@ package org.deeplearning4j.arbiter.multilayernetwork; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.arbiter.DL4JConfiguration; import org.deeplearning4j.arbiter.MultiLayerSpace; import org.deeplearning4j.arbiter.TestUtils; @@ -86,7 +87,7 @@ import java.util.*; import static org.junit.Assert.*; -public class TestMultiLayerSpace { +public class TestMultiLayerSpace extends BaseDL4JTest { @Rule public TemporaryFolder testDir = new TemporaryFolder(); diff --git a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/TestScoreFunctions.java b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/TestScoreFunctions.java index 01adefd18..f2dc5d180 100644 --- a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/TestScoreFunctions.java +++ b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/TestScoreFunctions.java @@ -18,6 +18,7 @@ package org.deeplearning4j.arbiter.multilayernetwork; import lombok.AllArgsConstructor; import lombok.extern.slf4j.Slf4j; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.arbiter.MultiLayerSpace; import org.deeplearning4j.arbiter.conf.updater.AdamSpace; import org.deeplearning4j.arbiter.layers.OutputLayerSpace; @@ -60,7 +61,13 @@ import java.util.Map; import static org.junit.Assert.assertEquals; @Slf4j -public class TestScoreFunctions { +public class TestScoreFunctions extends BaseDL4JTest { + + + @Override + public long getTimeoutMilliseconds() { + return 60000L; + } @Test public void testROCScoreFunctions() throws Exception { @@ -107,7 +114,7 @@ public class TestScoreFunctions { List list = runner.getResults(); for (ResultReference rr : list) { - DataSetIterator testIter = new MnistDataSetIterator(32, 2000, false, false, true, 12345); + DataSetIterator testIter = new MnistDataSetIterator(4, 16, false, false, false, 12345); testIter.setPreProcessor(new PreProc(rocType)); OptimizationResult or = rr.getResult(); @@ -141,10 +148,10 @@ public class TestScoreFunctions { } - DataSetIterator iter = new MnistDataSetIterator(32, 8000, false, true, true, 12345); + DataSetIterator iter = new MnistDataSetIterator(4, 16, false, false, false, 12345); iter.setPreProcessor(new PreProc(rocType)); - assertEquals(msg, expScore, or.getScore(), 1e-5); + assertEquals(msg, expScore, or.getScore(), 1e-4); } } } @@ -158,7 +165,7 @@ public class TestScoreFunctions { @Override public Object trainData(Map dataParameters) { try { - DataSetIterator iter = new MnistDataSetIterator(32, 8000, false, true, true, 12345); + DataSetIterator iter = new MnistDataSetIterator(4, 16, false, false, false, 12345); iter.setPreProcessor(new PreProc(rocType)); return iter; } catch (IOException e){ @@ -169,7 +176,7 @@ public class TestScoreFunctions { @Override public Object testData(Map dataParameters) { try { - DataSetIterator iter = new MnistDataSetIterator(32, 2000, false, false, true, 12345); + DataSetIterator iter = new MnistDataSetIterator(4, 16, false, false, false, 12345); iter.setPreProcessor(new PreProc(rocType)); return iter; } catch (IOException e){ diff --git a/arbiter/arbiter-server/pom.xml b/arbiter/arbiter-server/pom.xml index bdea61138..c4306b967 100644 --- a/arbiter/arbiter-server/pom.xml +++ b/arbiter/arbiter-server/pom.xml @@ -49,6 +49,13 @@ ${junit.version} test + + + org.deeplearning4j + deeplearning4j-common-tests + ${project.version} + test + diff --git a/arbiter/arbiter-server/src/test/java/org/deeplearning4j/arbiter/server/ArbiterCLIRunnerTest.java b/arbiter/arbiter-server/src/test/java/org/deeplearning4j/arbiter/server/ArbiterCLIRunnerTest.java index dd2b53409..21e4e402a 100644 --- a/arbiter/arbiter-server/src/test/java/org/deeplearning4j/arbiter/server/ArbiterCLIRunnerTest.java +++ b/arbiter/arbiter-server/src/test/java/org/deeplearning4j/arbiter/server/ArbiterCLIRunnerTest.java @@ -18,6 +18,7 @@ package org.deeplearning4j.arbiter.server; import lombok.extern.slf4j.Slf4j; import org.apache.commons.io.FileUtils; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.arbiter.MultiLayerSpace; import org.deeplearning4j.arbiter.conf.updater.SgdSpace; import org.deeplearning4j.arbiter.layers.DenseLayerSpace; @@ -52,7 +53,7 @@ import static org.junit.Assert.assertEquals; * Created by agibsonccc on 3/12/17. */ @Slf4j -public class ArbiterCLIRunnerTest { +public class ArbiterCLIRunnerTest extends BaseDL4JTest { @Test diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVRecordReaderTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVRecordReaderTest.java index 8a26e0d5c..534cc986e 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVRecordReaderTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVRecordReaderTest.java @@ -311,7 +311,11 @@ public class CSVRecordReaderTest { rr.reset(); fail("Expected exception"); } catch (Exception e){ - e.printStackTrace(); + String msg = e.getMessage(); + String msg2 = e.getCause().getMessage(); + assertTrue(msg, msg.contains("Error during LineRecordReader reset")); + assertTrue(msg2, msg2.contains("Reset not supported from streams")); +// e.printStackTrace(); } } diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/LineReaderTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/LineReaderTest.java index a209abb91..5027357eb 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/LineReaderTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/LineReaderTest.java @@ -55,8 +55,7 @@ public class LineReaderTest { @Test public void testLineReader() throws Exception { - String tempDir = System.getProperty("java.io.tmpdir"); - File tmpdir = new File(tempDir, "tmpdir-testLineReader"); + File tmpdir = testDir.newFolder(); if (tmpdir.exists()) tmpdir.delete(); tmpdir.mkdir(); @@ -84,12 +83,6 @@ public class LineReaderTest { } assertEquals(9, count); - - try { - FileUtils.deleteDirectory(tmpdir); - } catch (Exception e) { - e.printStackTrace(); - } } @Test @@ -145,13 +138,6 @@ public class LineReaderTest { assertEquals(2, subset.size()); assertEquals(out3.get(4), subset.get(0)); assertEquals(out3.get(7), subset.get(1)); - - - try { - FileUtils.deleteDirectory(tmpdir); - } catch (Exception e) { - e.printStackTrace(); - } } @Test @@ -177,11 +163,5 @@ public class LineReaderTest { } assertEquals(9, count); - - try { - FileUtils.deleteDirectory(tmpdir); - } catch (Exception e) { - e.printStackTrace(); - } } } diff --git a/datavec/datavec-camel/pom.xml b/datavec/datavec-camel/pom.xml deleted file mode 100644 index 3390242bc..000000000 --- a/datavec/datavec-camel/pom.xml +++ /dev/null @@ -1,116 +0,0 @@ - - - - - - datavec-parent - org.datavec - 1.0.0-SNAPSHOT - - - 4.0.0 - - datavec-camel - - DataVec Camel Component - http://deeplearning4j.org - - - - - org.apache.camel - camel-csv - ${camel.version} - test - - - org.datavec - datavec-api - ${project.version} - - - org.apache.camel - camel-core - ${camel.version} - - - - - - - org.apache.camel - camel-test - ${camel.version} - test - - - - - install - - - - - maven-compiler-plugin - - 1.7 - 1.7 - - - - - maven-resources-plugin - 3.0.1 - - UTF-8 - - - - - - org.apache.camel - camel-package-maven-plugin - ${camel.version} - - - prepare - - prepare-components - - generate-resources - - - - - - - - - - test-nd4j-native - - - test-nd4j-cuda-10.2 - - - diff --git a/datavec/datavec-camel/src/main/java/org/datavec/camel/component/DataVecComponent.java b/datavec/datavec-camel/src/main/java/org/datavec/camel/component/DataVecComponent.java deleted file mode 100644 index 2abb73558..000000000 --- a/datavec/datavec-camel/src/main/java/org/datavec/camel/component/DataVecComponent.java +++ /dev/null @@ -1,45 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.datavec.camel.component; - -import org.apache.camel.CamelContext; -import org.apache.camel.Endpoint; -import org.apache.camel.impl.UriEndpointComponent; - -import java.util.Map; - -/** - * Represents the component that manages {@link DataVecEndpoint}. - */ -public class DataVecComponent extends UriEndpointComponent { - - public DataVecComponent() { - super(DataVecEndpoint.class); - } - - public DataVecComponent(CamelContext context) { - super(context, DataVecEndpoint.class); - } - - @Override - protected Endpoint createEndpoint(String uri, String remaining, Map parameters) throws Exception { - DataVecEndpoint endpoint = new DataVecEndpoint(uri, this); - setProperties(endpoint, parameters); - endpoint.setInputFormat(remaining); - return endpoint; - } -} diff --git a/datavec/datavec-camel/src/main/java/org/datavec/camel/component/DataVecConsumer.java b/datavec/datavec-camel/src/main/java/org/datavec/camel/component/DataVecConsumer.java deleted file mode 100644 index e17d022ef..000000000 --- a/datavec/datavec-camel/src/main/java/org/datavec/camel/component/DataVecConsumer.java +++ /dev/null @@ -1,93 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.datavec.camel.component; - - -import org.apache.camel.Exchange; -import org.apache.camel.Processor; -import org.apache.camel.impl.ScheduledPollConsumer; -import org.datavec.api.conf.Configuration; -import org.datavec.api.formats.input.InputFormat; -import org.datavec.api.records.reader.RecordReader; -import org.datavec.api.split.InputSplit; - -/** - * The DataVec consumer. - * @author Adam Gibson - */ -public class DataVecConsumer extends ScheduledPollConsumer { - private final DataVecEndpoint endpoint; - private Class inputFormatClazz; - private Class marshallerClazz; - private InputFormat inputFormat; - private Configuration configuration; - private DataVecMarshaller marshaller; - - - public DataVecConsumer(DataVecEndpoint endpoint, Processor processor) { - super(endpoint, processor); - this.endpoint = endpoint; - - try { - inputFormatClazz = (Class) Class.forName(endpoint.getInputFormat()); - inputFormat = inputFormatClazz.newInstance(); - marshallerClazz = (Class) Class.forName(endpoint.getInputMarshaller()); - marshaller = marshallerClazz.newInstance(); - configuration = new Configuration(); - for (String prop : endpoint.getConsumerProperties().keySet()) - configuration.set(prop, endpoint.getConsumerProperties().get(prop).toString()); - - } catch (Exception e) { - throw new RuntimeException(e); - } - } - - //stub, still need to fill out more of the end point yet..endpoint will likely be initialized with a split - protected InputSplit inputFromExchange(Exchange exchange) { - return marshaller.getSplit(exchange); - } - - @Override - protected int poll() throws Exception { - Exchange exchange = endpoint.createExchange(); - InputSplit split = inputFromExchange(exchange); - RecordReader reader = inputFormat.createReader(split, configuration); - int numMessagesPolled = 0; - while (reader.hasNext()) { - // create a message body - while (reader.hasNext()) { - exchange.getIn().setBody(reader.next()); - - try { - // send message to next processor in the route - getProcessor().process(exchange); - numMessagesPolled++; // number of messages polled - } finally { - // log exception if an exception occurred and was not handled - if (exchange.getException() != null) { - getExceptionHandler().handleException("Error processing exchange", exchange, - exchange.getException()); - } - } - } - - - } - - return numMessagesPolled; - } -} diff --git a/datavec/datavec-camel/src/main/java/org/datavec/camel/component/DataVecEndpoint.java b/datavec/datavec-camel/src/main/java/org/datavec/camel/component/DataVecEndpoint.java deleted file mode 100644 index 930f968a1..000000000 --- a/datavec/datavec-camel/src/main/java/org/datavec/camel/component/DataVecEndpoint.java +++ /dev/null @@ -1,68 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.datavec.camel.component; - -import lombok.Data; -import org.apache.camel.Consumer; -import org.apache.camel.Processor; -import org.apache.camel.Producer; -import org.apache.camel.impl.DefaultEndpoint; -import org.apache.camel.spi.Metadata; -import org.apache.camel.spi.UriEndpoint; -import org.apache.camel.spi.UriParam; -import org.apache.camel.spi.UriPath; - -/** - * Represents a DataVec endpoint. - * @author Adam Gibson - */ -@UriEndpoint(scheme = "datavec", title = "datavec", syntax = "datavec:inputFormat/?outputFormat=?&inputMarshaller=?", - consumerClass = DataVecConsumer.class, label = "datavec") -@Data -public class DataVecEndpoint extends DefaultEndpoint { - @UriPath - @Metadata(required = "true") - private String inputFormat; - @UriParam(defaultValue = "") - private String outputFormat; - @UriParam - @Metadata(required = "true") - private String inputMarshaller; - @UriParam(defaultValue = "org.datavec.api.io.converters.SelfWritableConverter") - private String writableConverter; - - public DataVecEndpoint(String uri, DataVecComponent component) { - super(uri, component); - } - - public DataVecEndpoint(String endpointUri) { - super(endpointUri); - } - - public Producer createProducer() throws Exception { - return new DataVecProducer(this); - } - - public Consumer createConsumer(Processor processor) throws Exception { - return new DataVecConsumer(this, processor); - } - - public boolean isSingleton() { - return true; - } - -} diff --git a/datavec/datavec-camel/src/main/java/org/datavec/camel/component/DataVecProducer.java b/datavec/datavec-camel/src/main/java/org/datavec/camel/component/DataVecProducer.java deleted file mode 100644 index 45697ff1c..000000000 --- a/datavec/datavec-camel/src/main/java/org/datavec/camel/component/DataVecProducer.java +++ /dev/null @@ -1,109 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.datavec.camel.component; - -import org.apache.camel.Exchange; -import org.apache.camel.impl.DefaultProducer; -import org.datavec.api.conf.Configuration; -import org.datavec.api.formats.input.InputFormat; -import org.datavec.api.io.WritableConverter; -import org.datavec.api.io.converters.SelfWritableConverter; -import org.datavec.api.records.reader.RecordReader; -import org.datavec.api.split.InputSplit; -import org.datavec.api.writable.Writable; - -import java.util.ArrayList; -import java.util.Collection; - - -/** - * The DataVec producer. - * Converts input records in to their final form - * based on the input split generated from - * the given exchange. - * - * @author Adam Gibson - */ -public class DataVecProducer extends DefaultProducer { - private Class inputFormatClazz; - private Class marshallerClazz; - private InputFormat inputFormat; - private Configuration configuration; - private WritableConverter writableConverter; - private DataVecMarshaller marshaller; - - - public DataVecProducer(DataVecEndpoint endpoint) { - super(endpoint); - if (endpoint.getInputFormat() != null) { - try { - inputFormatClazz = (Class) Class.forName(endpoint.getInputFormat()); - inputFormat = inputFormatClazz.newInstance(); - marshallerClazz = (Class) Class.forName(endpoint.getInputMarshaller()); - Class converterClazz = - (Class) Class.forName(endpoint.getWritableConverter()); - writableConverter = converterClazz.newInstance(); - marshaller = marshallerClazz.newInstance(); - configuration = new Configuration(); - for (String prop : endpoint.getConsumerProperties().keySet()) - configuration.set(prop, endpoint.getConsumerProperties().get(prop).toString()); - - } catch (Exception e) { - throw new RuntimeException(e); - } - } - - } - - - //stub, still need to fill out more of the end point yet..endpoint will likely be initialized with a split - protected InputSplit inputFromExchange(Exchange exchange) { - return marshaller.getSplit(exchange); - } - - - @Override - public void process(Exchange exchange) throws Exception { - InputSplit split = inputFromExchange(exchange); - RecordReader reader = inputFormat.createReader(split, configuration); - Collection> newRecord = new ArrayList<>(); - if (!(writableConverter instanceof SelfWritableConverter)) { - newRecord = new ArrayList<>(); - while (reader.hasNext()) { - Collection newRecordAdd = new ArrayList<>(); - // create a message body - Collection next = reader.next(); - for (Writable writable : next) { - newRecordAdd.add(writableConverter.convert(writable)); - } - - - newRecord.add(newRecordAdd); - } - } else { - while (reader.hasNext()) { - // create a message body - Collection next = reader.next(); - newRecord.add(next); - } - } - - - exchange.getIn().setBody(newRecord); - exchange.getOut().setBody(newRecord); - } -} diff --git a/datavec/datavec-camel/src/main/java/org/datavec/camel/component/csv/marshaller/ListStringInputMarshaller.java b/datavec/datavec-camel/src/main/java/org/datavec/camel/component/csv/marshaller/ListStringInputMarshaller.java deleted file mode 100644 index 9ff8913be..000000000 --- a/datavec/datavec-camel/src/main/java/org/datavec/camel/component/csv/marshaller/ListStringInputMarshaller.java +++ /dev/null @@ -1,42 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.datavec.camel.component.csv.marshaller; - -import org.apache.camel.Exchange; -import org.datavec.api.split.InputSplit; -import org.datavec.api.split.ListStringSplit; -import org.datavec.camel.component.DataVecMarshaller; - -import java.util.List; - -/** - * Marshals List> - * - * @author Adam Gibson - */ -public class ListStringInputMarshaller implements DataVecMarshaller { - /** - * @param exchange - * @return - */ - @Override - public InputSplit getSplit(Exchange exchange) { - List> data = (List>) exchange.getIn().getBody(); - InputSplit listSplit = new ListStringSplit(data); - return listSplit; - } -} diff --git a/datavec/datavec-camel/src/main/resources/META-INF/services/org/apache/camel/component/datavec b/datavec/datavec-camel/src/main/resources/META-INF/services/org/apache/camel/component/datavec deleted file mode 100644 index 2a324f975..000000000 --- a/datavec/datavec-camel/src/main/resources/META-INF/services/org/apache/camel/component/datavec +++ /dev/null @@ -1 +0,0 @@ -class=org.datavec.camel.component.DataVecComponent diff --git a/datavec/datavec-camel/src/test/java/org/datavec/camel/component/DataVecComponentTest.java b/datavec/datavec-camel/src/test/java/org/datavec/camel/component/DataVecComponentTest.java deleted file mode 100644 index de4aa35cf..000000000 --- a/datavec/datavec-camel/src/test/java/org/datavec/camel/component/DataVecComponentTest.java +++ /dev/null @@ -1,82 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.datavec.camel.component; - -import org.apache.camel.builder.RouteBuilder; -import org.apache.camel.component.mock.MockEndpoint; -import org.apache.camel.test.junit4.CamelTestSupport; -import org.apache.commons.io.FileUtils; -import org.datavec.api.records.reader.RecordReader; -import org.datavec.api.records.reader.impl.csv.CSVRecordReader; -import org.datavec.api.split.FileSplit; -import org.datavec.api.writable.Writable; -import org.junit.BeforeClass; -import org.junit.ClassRule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; -import org.nd4j.linalg.io.ClassPathResource; - -import java.io.File; -import java.util.ArrayList; -import java.util.Collection; - -public class DataVecComponentTest extends CamelTestSupport { - - @ClassRule - public static TemporaryFolder testDir = new TemporaryFolder(); - private static File dir; - private static File irisFile; - - - @BeforeClass - public static void before() throws Exception { - dir = testDir.newFolder(); - File iris = new ClassPathResource("iris.dat").getFile(); - irisFile = new File(dir, "iris.dat"); - FileUtils.copyFile(iris, irisFile ); - } - - - - @Test - public void testDataVec() throws Exception { - MockEndpoint mock = getMockEndpoint("mock:result"); - //1 - mock.expectedMessageCount(1); - - RecordReader reader = new CSVRecordReader(); - reader.initialize(new FileSplit(new ClassPathResource("iris.dat").getFile())); - Collection> recordAssertion = new ArrayList<>(); - while (reader.hasNext()) - recordAssertion.add(reader.next()); - mock.expectedBodiesReceived(recordAssertion); - assertMockEndpointsSatisfied(); - } - - @Override - protected RouteBuilder createRouteBuilder() throws Exception { - - - return new RouteBuilder() { - public void configure() { - from("file:" + dir.getAbsolutePath() + "?fileName=iris.dat&noop=true").unmarshal().csv() - .to("datavec://org.datavec.api.formats.input.impl.ListStringInputFormat?inputMarshaller=org.datavec.camel.component.ListStringInputMarshaller&writableConverter=org.datavec.api.io.converters.SelfWritableConverter") - .to("mock:result"); - } - }; - } -} diff --git a/datavec/datavec-data/datavec-data-image/pom.xml b/datavec/datavec-data/datavec-data-image/pom.xml index fc85482a4..0e4267659 100644 --- a/datavec/datavec-data/datavec-data-image/pom.xml +++ b/datavec/datavec-data/datavec-data-image/pom.xml @@ -37,11 +37,6 @@ ${logback.version} test - - org.nd4j - nd4j-buffer - ${nd4j.version} - com.github.jai-imageio jai-imageio-core diff --git a/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/loader/TestNativeImageLoader.java b/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/loader/TestNativeImageLoader.java index 5f634bab8..b8fa0c43d 100644 --- a/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/loader/TestNativeImageLoader.java +++ b/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/loader/TestNativeImageLoader.java @@ -570,6 +570,7 @@ public class TestNativeImageLoader { try(InputStream is = new FileInputStream(f)){ nil.asMatrix(is); + fail("Expected exception"); } catch (IOException e){ String msg = e.getMessage(); assertTrue(msg, msg.contains("decode image")); @@ -577,6 +578,7 @@ public class TestNativeImageLoader { try(InputStream is = new FileInputStream(f)){ nil.asImageMatrix(is); + fail("Expected exception"); } catch (IOException e){ String msg = e.getMessage(); assertTrue(msg, msg.contains("decode image")); @@ -584,6 +586,7 @@ public class TestNativeImageLoader { try(InputStream is = new FileInputStream(f)){ nil.asRowVector(is); + fail("Expected exception"); } catch (IOException e){ String msg = e.getMessage(); assertTrue(msg, msg.contains("decode image")); @@ -592,6 +595,7 @@ public class TestNativeImageLoader { try(InputStream is = new FileInputStream(f)){ INDArray arr = Nd4j.create(DataType.FLOAT, 1, 3, 32, 32); nil.asMatrixView(is, arr); + fail("Expected exception"); } catch (IOException e){ String msg = e.getMessage(); assertTrue(msg, msg.contains("decode image")); diff --git a/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/transform/JsonYamlTest.java b/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/transform/JsonYamlTest.java index 3f0155097..9825e6899 100644 --- a/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/transform/JsonYamlTest.java +++ b/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/transform/JsonYamlTest.java @@ -66,9 +66,9 @@ public class JsonYamlTest { String asJson = itp.toJson(); String asYaml = itp.toYaml(); - System.out.println(asJson); - System.out.println("\n\n\n"); - System.out.println(asYaml); +// System.out.println(asJson); +// System.out.println("\n\n\n"); +// System.out.println(asYaml); ImageWritable img = TestImageTransform.makeRandomImage(0, 0, 3); ImageWritable imgJson = new ImageWritable(img.getFrame().clone()); diff --git a/datavec/datavec-geo/pom.xml b/datavec/datavec-data/datavec-geo/pom.xml similarity index 100% rename from datavec/datavec-geo/pom.xml rename to datavec/datavec-data/datavec-geo/pom.xml diff --git a/datavec/datavec-geo/src/main/java/org/datavec/api/transform/geo/LocationType.java b/datavec/datavec-data/datavec-geo/src/main/java/org/datavec/api/transform/geo/LocationType.java similarity index 100% rename from datavec/datavec-geo/src/main/java/org/datavec/api/transform/geo/LocationType.java rename to datavec/datavec-data/datavec-geo/src/main/java/org/datavec/api/transform/geo/LocationType.java diff --git a/datavec/datavec-geo/src/main/java/org/datavec/api/transform/reduce/geo/CoordinatesReduction.java b/datavec/datavec-data/datavec-geo/src/main/java/org/datavec/api/transform/reduce/geo/CoordinatesReduction.java similarity index 100% rename from datavec/datavec-geo/src/main/java/org/datavec/api/transform/reduce/geo/CoordinatesReduction.java rename to datavec/datavec-data/datavec-geo/src/main/java/org/datavec/api/transform/reduce/geo/CoordinatesReduction.java diff --git a/datavec/datavec-geo/src/main/java/org/datavec/api/transform/transform/geo/CoordinatesDistanceTransform.java b/datavec/datavec-data/datavec-geo/src/main/java/org/datavec/api/transform/transform/geo/CoordinatesDistanceTransform.java similarity index 100% rename from datavec/datavec-geo/src/main/java/org/datavec/api/transform/transform/geo/CoordinatesDistanceTransform.java rename to datavec/datavec-data/datavec-geo/src/main/java/org/datavec/api/transform/transform/geo/CoordinatesDistanceTransform.java diff --git a/datavec/datavec-geo/src/main/java/org/datavec/api/transform/transform/geo/GeoIPFetcher.java b/datavec/datavec-data/datavec-geo/src/main/java/org/datavec/api/transform/transform/geo/GeoIPFetcher.java similarity index 100% rename from datavec/datavec-geo/src/main/java/org/datavec/api/transform/transform/geo/GeoIPFetcher.java rename to datavec/datavec-data/datavec-geo/src/main/java/org/datavec/api/transform/transform/geo/GeoIPFetcher.java diff --git a/datavec/datavec-geo/src/main/java/org/datavec/api/transform/transform/geo/IPAddressToCoordinatesTransform.java b/datavec/datavec-data/datavec-geo/src/main/java/org/datavec/api/transform/transform/geo/IPAddressToCoordinatesTransform.java similarity index 100% rename from datavec/datavec-geo/src/main/java/org/datavec/api/transform/transform/geo/IPAddressToCoordinatesTransform.java rename to datavec/datavec-data/datavec-geo/src/main/java/org/datavec/api/transform/transform/geo/IPAddressToCoordinatesTransform.java diff --git a/datavec/datavec-geo/src/main/java/org/datavec/api/transform/transform/geo/IPAddressToLocationTransform.java b/datavec/datavec-data/datavec-geo/src/main/java/org/datavec/api/transform/transform/geo/IPAddressToLocationTransform.java similarity index 100% rename from datavec/datavec-geo/src/main/java/org/datavec/api/transform/transform/geo/IPAddressToLocationTransform.java rename to datavec/datavec-data/datavec-geo/src/main/java/org/datavec/api/transform/transform/geo/IPAddressToLocationTransform.java diff --git a/datavec/datavec-geo/src/test/java/org/datavec/api/transform/reduce/TestGeoReduction.java b/datavec/datavec-data/datavec-geo/src/test/java/org/datavec/api/transform/reduce/TestGeoReduction.java similarity index 100% rename from datavec/datavec-geo/src/test/java/org/datavec/api/transform/reduce/TestGeoReduction.java rename to datavec/datavec-data/datavec-geo/src/test/java/org/datavec/api/transform/reduce/TestGeoReduction.java diff --git a/datavec/datavec-geo/src/test/java/org/datavec/api/transform/transform/TestGeoTransforms.java b/datavec/datavec-data/datavec-geo/src/test/java/org/datavec/api/transform/transform/TestGeoTransforms.java similarity index 100% rename from datavec/datavec-geo/src/test/java/org/datavec/api/transform/transform/TestGeoTransforms.java rename to datavec/datavec-data/datavec-geo/src/test/java/org/datavec/api/transform/transform/TestGeoTransforms.java diff --git a/datavec/datavec-hadoop/pom.xml b/datavec/datavec-data/datavec-hadoop/pom.xml similarity index 100% rename from datavec/datavec-hadoop/pom.xml rename to datavec/datavec-data/datavec-hadoop/pom.xml diff --git a/datavec/datavec-hadoop/src/main/java/org/datavec/hadoop/conf/ConfigurationUtil.java b/datavec/datavec-data/datavec-hadoop/src/main/java/org/datavec/hadoop/conf/ConfigurationUtil.java similarity index 100% rename from datavec/datavec-hadoop/src/main/java/org/datavec/hadoop/conf/ConfigurationUtil.java rename to datavec/datavec-data/datavec-hadoop/src/main/java/org/datavec/hadoop/conf/ConfigurationUtil.java diff --git a/datavec/datavec-hadoop/src/main/java/org/datavec/hadoop/records/reader/mapfile/IndexToKey.java b/datavec/datavec-data/datavec-hadoop/src/main/java/org/datavec/hadoop/records/reader/mapfile/IndexToKey.java similarity index 100% rename from datavec/datavec-hadoop/src/main/java/org/datavec/hadoop/records/reader/mapfile/IndexToKey.java rename to datavec/datavec-data/datavec-hadoop/src/main/java/org/datavec/hadoop/records/reader/mapfile/IndexToKey.java diff --git a/datavec/datavec-hadoop/src/main/java/org/datavec/hadoop/records/reader/mapfile/MapFileReader.java b/datavec/datavec-data/datavec-hadoop/src/main/java/org/datavec/hadoop/records/reader/mapfile/MapFileReader.java similarity index 100% rename from datavec/datavec-hadoop/src/main/java/org/datavec/hadoop/records/reader/mapfile/MapFileReader.java rename to datavec/datavec-data/datavec-hadoop/src/main/java/org/datavec/hadoop/records/reader/mapfile/MapFileReader.java diff --git a/datavec/datavec-hadoop/src/main/java/org/datavec/hadoop/records/reader/mapfile/MapFileRecordReader.java b/datavec/datavec-data/datavec-hadoop/src/main/java/org/datavec/hadoop/records/reader/mapfile/MapFileRecordReader.java similarity index 100% rename from datavec/datavec-hadoop/src/main/java/org/datavec/hadoop/records/reader/mapfile/MapFileRecordReader.java rename to datavec/datavec-data/datavec-hadoop/src/main/java/org/datavec/hadoop/records/reader/mapfile/MapFileRecordReader.java diff --git a/datavec/datavec-hadoop/src/main/java/org/datavec/hadoop/records/reader/mapfile/MapFileSequenceRecordReader.java b/datavec/datavec-data/datavec-hadoop/src/main/java/org/datavec/hadoop/records/reader/mapfile/MapFileSequenceRecordReader.java similarity index 100% rename from datavec/datavec-hadoop/src/main/java/org/datavec/hadoop/records/reader/mapfile/MapFileSequenceRecordReader.java rename to datavec/datavec-data/datavec-hadoop/src/main/java/org/datavec/hadoop/records/reader/mapfile/MapFileSequenceRecordReader.java diff --git a/datavec/datavec-hadoop/src/main/java/org/datavec/hadoop/records/reader/mapfile/index/LongIndexToKey.java b/datavec/datavec-data/datavec-hadoop/src/main/java/org/datavec/hadoop/records/reader/mapfile/index/LongIndexToKey.java similarity index 100% rename from datavec/datavec-hadoop/src/main/java/org/datavec/hadoop/records/reader/mapfile/index/LongIndexToKey.java rename to datavec/datavec-data/datavec-hadoop/src/main/java/org/datavec/hadoop/records/reader/mapfile/index/LongIndexToKey.java diff --git a/datavec/datavec-hadoop/src/main/java/org/datavec/hadoop/records/reader/mapfile/record/RecordWritable.java b/datavec/datavec-data/datavec-hadoop/src/main/java/org/datavec/hadoop/records/reader/mapfile/record/RecordWritable.java similarity index 100% rename from datavec/datavec-hadoop/src/main/java/org/datavec/hadoop/records/reader/mapfile/record/RecordWritable.java rename to datavec/datavec-data/datavec-hadoop/src/main/java/org/datavec/hadoop/records/reader/mapfile/record/RecordWritable.java diff --git a/datavec/datavec-hadoop/src/main/java/org/datavec/hadoop/records/reader/mapfile/record/SequenceRecordWritable.java b/datavec/datavec-data/datavec-hadoop/src/main/java/org/datavec/hadoop/records/reader/mapfile/record/SequenceRecordWritable.java similarity index 100% rename from datavec/datavec-hadoop/src/main/java/org/datavec/hadoop/records/reader/mapfile/record/SequenceRecordWritable.java rename to datavec/datavec-data/datavec-hadoop/src/main/java/org/datavec/hadoop/records/reader/mapfile/record/SequenceRecordWritable.java diff --git a/datavec/datavec-hadoop/src/main/java/org/datavec/hadoop/records/writer/mapfile/AbstractMapFileWriter.java b/datavec/datavec-data/datavec-hadoop/src/main/java/org/datavec/hadoop/records/writer/mapfile/AbstractMapFileWriter.java similarity index 100% rename from datavec/datavec-hadoop/src/main/java/org/datavec/hadoop/records/writer/mapfile/AbstractMapFileWriter.java rename to datavec/datavec-data/datavec-hadoop/src/main/java/org/datavec/hadoop/records/writer/mapfile/AbstractMapFileWriter.java diff --git a/datavec/datavec-hadoop/src/main/java/org/datavec/hadoop/records/writer/mapfile/MapFileRecordWriter.java b/datavec/datavec-data/datavec-hadoop/src/main/java/org/datavec/hadoop/records/writer/mapfile/MapFileRecordWriter.java similarity index 100% rename from datavec/datavec-hadoop/src/main/java/org/datavec/hadoop/records/writer/mapfile/MapFileRecordWriter.java rename to datavec/datavec-data/datavec-hadoop/src/main/java/org/datavec/hadoop/records/writer/mapfile/MapFileRecordWriter.java diff --git a/datavec/datavec-hadoop/src/main/java/org/datavec/hadoop/records/writer/mapfile/MapFileSequenceRecordWriter.java b/datavec/datavec-data/datavec-hadoop/src/main/java/org/datavec/hadoop/records/writer/mapfile/MapFileSequenceRecordWriter.java similarity index 100% rename from datavec/datavec-hadoop/src/main/java/org/datavec/hadoop/records/writer/mapfile/MapFileSequenceRecordWriter.java rename to datavec/datavec-data/datavec-hadoop/src/main/java/org/datavec/hadoop/records/writer/mapfile/MapFileSequenceRecordWriter.java diff --git a/datavec/datavec-hadoop/src/test/java/org/datavec/hadoop/conf/TestConfigurationUtil.java b/datavec/datavec-data/datavec-hadoop/src/test/java/org/datavec/hadoop/conf/TestConfigurationUtil.java similarity index 100% rename from datavec/datavec-hadoop/src/test/java/org/datavec/hadoop/conf/TestConfigurationUtil.java rename to datavec/datavec-data/datavec-hadoop/src/test/java/org/datavec/hadoop/conf/TestConfigurationUtil.java diff --git a/datavec/datavec-hadoop/src/test/java/org/datavec/hadoop/records/reader/TestMapFileRecordReader.java b/datavec/datavec-data/datavec-hadoop/src/test/java/org/datavec/hadoop/records/reader/TestMapFileRecordReader.java similarity index 100% rename from datavec/datavec-hadoop/src/test/java/org/datavec/hadoop/records/reader/TestMapFileRecordReader.java rename to datavec/datavec-data/datavec-hadoop/src/test/java/org/datavec/hadoop/records/reader/TestMapFileRecordReader.java diff --git a/datavec/datavec-hadoop/src/test/java/org/datavec/hadoop/records/reader/TestMapFileRecordReaderMultipleParts.java b/datavec/datavec-data/datavec-hadoop/src/test/java/org/datavec/hadoop/records/reader/TestMapFileRecordReaderMultipleParts.java similarity index 100% rename from datavec/datavec-hadoop/src/test/java/org/datavec/hadoop/records/reader/TestMapFileRecordReaderMultipleParts.java rename to datavec/datavec-data/datavec-hadoop/src/test/java/org/datavec/hadoop/records/reader/TestMapFileRecordReaderMultipleParts.java diff --git a/datavec/datavec-hadoop/src/test/java/org/datavec/hadoop/records/reader/TestMapFileRecordReaderMultiplePartsSomeEmpty.java b/datavec/datavec-data/datavec-hadoop/src/test/java/org/datavec/hadoop/records/reader/TestMapFileRecordReaderMultiplePartsSomeEmpty.java similarity index 100% rename from datavec/datavec-hadoop/src/test/java/org/datavec/hadoop/records/reader/TestMapFileRecordReaderMultiplePartsSomeEmpty.java rename to datavec/datavec-data/datavec-hadoop/src/test/java/org/datavec/hadoop/records/reader/TestMapFileRecordReaderMultiplePartsSomeEmpty.java diff --git a/datavec/datavec-hadoop/src/test/java/org/datavec/hadoop/records/writer/TestMapFileRecordWriter.java b/datavec/datavec-data/datavec-hadoop/src/test/java/org/datavec/hadoop/records/writer/TestMapFileRecordWriter.java similarity index 100% rename from datavec/datavec-hadoop/src/test/java/org/datavec/hadoop/records/writer/TestMapFileRecordWriter.java rename to datavec/datavec-data/datavec-hadoop/src/test/java/org/datavec/hadoop/records/writer/TestMapFileRecordWriter.java diff --git a/datavec/datavec-hadoop/src/test/resources/log4j.properties b/datavec/datavec-data/datavec-hadoop/src/test/resources/log4j.properties similarity index 100% rename from datavec/datavec-hadoop/src/test/resources/log4j.properties rename to datavec/datavec-data/datavec-hadoop/src/test/resources/log4j.properties diff --git a/datavec/datavec-hadoop/src/test/resources/logback.xml b/datavec/datavec-data/datavec-hadoop/src/test/resources/logback.xml similarity index 100% rename from datavec/datavec-hadoop/src/test/resources/logback.xml rename to datavec/datavec-data/datavec-hadoop/src/test/resources/logback.xml diff --git a/datavec/datavec-perf/pom.xml b/datavec/datavec-perf/pom.xml deleted file mode 100644 index a51b9aba1..000000000 --- a/datavec/datavec-perf/pom.xml +++ /dev/null @@ -1,65 +0,0 @@ - - - - - - - datavec-parent - org.datavec - 1.0.0-SNAPSHOT - - 4.0.0 - - datavec-perf - - datavec-perf - - - UTF-8 - 1.7 - 1.7 - - - - - org.slf4j - slf4j-api - ${slf4j.version} - - - org.datavec - datavec-data-image - ${project.version} - test - - - org.datavec - datavec-api - ${project.version} - - - - - - test-nd4j-native - - - test-nd4j-cuda-10.2 - - - diff --git a/datavec/datavec-perf/src/main/java/org/datavec/perf/timing/IOTiming.java b/datavec/datavec-perf/src/main/java/org/datavec/perf/timing/IOTiming.java deleted file mode 100644 index aa82194e1..000000000 --- a/datavec/datavec-perf/src/main/java/org/datavec/perf/timing/IOTiming.java +++ /dev/null @@ -1,112 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.datavec.perf.timing; - -import lombok.val; -import org.datavec.api.records.reader.RecordReader; -import org.datavec.api.split.InputStreamInputSplit; -import org.datavec.api.writable.Writable; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.performance.PerformanceTracker; -import org.nd4j.linalg.memory.MemcpyDirection; - -import java.io.BufferedInputStream; -import java.io.File; -import java.io.FileInputStream; -import java.io.InputStream; -import java.util.List; -import java.util.Map; - - - -/** - * Timing components of a data vec pipeline - * consisting of: - * {@link RecordReader}, {@link InputStreamInputSplit} - * (note that this uses input stream input split, - * the record reader must support {@link InputStreamInputSplit} for this to work) - * - * @author Adam Gibson - */ -public class IOTiming { - - - /** - * Returns statistics for components of a datavec pipeline - * averaged over the specified number of times - * @param nTimes the number of times to run the pipeline for averaging - * @param recordReader the record reader - * @param file the file to read - * @param function the function - * @return the averaged {@link TimingStatistics} for input/output on a record - * reader and ndarray creation (based on the given function - * @throws Exception - */ - public static TimingStatistics averageFileRead(long nTimes, RecordReader recordReader, File file, INDArrayCreationFunction function) throws Exception { - TimingStatistics timingStatistics = null; - for(int i = 0; i < nTimes; i++) { - TimingStatistics timingStatistics1 = timeNDArrayCreation(recordReader,new BufferedInputStream(new FileInputStream(file)),function); - if(timingStatistics == null) - timingStatistics = timingStatistics1; - else { - timingStatistics = timingStatistics.add(timingStatistics1); - } - - } - - return timingStatistics.average(nTimes); - } - - /** - * - * @param reader - * @param inputStream - * @param function - * @return - * @throws Exception - */ - public static TimingStatistics timeNDArrayCreation(RecordReader reader, - InputStream inputStream, - INDArrayCreationFunction function) throws Exception { - - - reader.initialize(new InputStreamInputSplit(inputStream)); - long longNanos = System.nanoTime(); - List next = reader.next(); - long endNanos = System.nanoTime(); - long etlDiff = endNanos - longNanos; - long startArrCreation = System.nanoTime(); - INDArray arr = function.createFromRecord(next); - long endArrCreation = System.nanoTime(); - long endCreationDiff = endArrCreation - startArrCreation; - Map> currentBandwidth = PerformanceTracker.getInstance().getCurrentBandwidth(); - val bw = currentBandwidth.get(0).get(MemcpyDirection.HOST_TO_DEVICE); - val deviceToHost = currentBandwidth.get(0).get(MemcpyDirection.HOST_TO_DEVICE); - - return TimingStatistics.builder() - .diskReadingTimeNanos(etlDiff) - .bandwidthNanosHostToDevice(bw) - .bandwidthDeviceToHost(deviceToHost) - .ndarrayCreationTimeNanos(endCreationDiff) - .build(); - } - - public interface INDArrayCreationFunction { - INDArray createFromRecord(List record); - } - -} diff --git a/datavec/datavec-perf/src/main/java/org/datavec/perf/timing/TimingStatistics.java b/datavec/datavec-perf/src/main/java/org/datavec/perf/timing/TimingStatistics.java deleted file mode 100644 index 615d9659a..000000000 --- a/datavec/datavec-perf/src/main/java/org/datavec/perf/timing/TimingStatistics.java +++ /dev/null @@ -1,74 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.datavec.perf.timing; - -import lombok.Builder; -import lombok.Data; - - -/** - * The timing statistics for a data pipeline including: - * ndarray creation time in nanoseconds - * disk reading time in nanoseconds - * bandwidth used in host to device in nano seconds - * bandwidth device to host in nanoseconds - * - * @author Adam Gibson - */ -@Builder -@Data -public class TimingStatistics { - - private long ndarrayCreationTimeNanos; - private long diskReadingTimeNanos; - private long bandwidthNanosHostToDevice; - private long bandwidthDeviceToHost; - - - /** - * Accumulate the given statistics - * @param timingStatistics the statistics to add - * @return the added statistics - */ - public TimingStatistics add(TimingStatistics timingStatistics) { - return TimingStatistics.builder() - .ndarrayCreationTimeNanos(ndarrayCreationTimeNanos + timingStatistics.ndarrayCreationTimeNanos) - .bandwidthNanosHostToDevice(bandwidthNanosHostToDevice + timingStatistics.bandwidthNanosHostToDevice) - .diskReadingTimeNanos(diskReadingTimeNanos + timingStatistics.diskReadingTimeNanos) - .bandwidthDeviceToHost(bandwidthDeviceToHost + timingStatistics.bandwidthDeviceToHost) - .build(); - } - - - /** - * Average the results relative to the number of n. - * This method is meant to be used alongside - * {@link #add(TimingStatistics)} - * accumulated a number of times - * @param n n the number of elements - * @return the averaged results - */ - public TimingStatistics average(long n) { - return TimingStatistics.builder() - .ndarrayCreationTimeNanos(ndarrayCreationTimeNanos / n) - .bandwidthDeviceToHost(bandwidthDeviceToHost / n) - .diskReadingTimeNanos(diskReadingTimeNanos / n) - .bandwidthNanosHostToDevice(bandwidthNanosHostToDevice / n) - .build(); - } - -} diff --git a/datavec/datavec-perf/src/test/java/org/datavec/datavec/timing/IOTimingTest.java b/datavec/datavec-perf/src/test/java/org/datavec/datavec/timing/IOTimingTest.java deleted file mode 100644 index 60c206232..000000000 --- a/datavec/datavec-perf/src/test/java/org/datavec/datavec/timing/IOTimingTest.java +++ /dev/null @@ -1,60 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.datavec.datavec.timing; - -import org.datavec.api.records.reader.RecordReader; -import org.datavec.api.writable.NDArrayWritable; -import org.datavec.api.writable.Writable; -import org.datavec.image.loader.NativeImageLoader; -import org.datavec.image.recordreader.ImageRecordReader; -import org.datavec.perf.timing.IOTiming; -import org.datavec.perf.timing.TimingStatistics; -import org.junit.Test; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.io.ClassPathResource; - -import java.util.List; - -public class IOTimingTest { - - @Test - public void testTiming() throws Exception { - final RecordReader image = new ImageRecordReader(28,28); - final NativeImageLoader nativeImageLoader = new NativeImageLoader(28,28); - - TimingStatistics timingStatistics = IOTiming.timeNDArrayCreation(image, new ClassPathResource("datavec-perf/largestblobtest.jpg").getInputStream(), new IOTiming.INDArrayCreationFunction() { - @Override - public INDArray createFromRecord(List record) { - NDArrayWritable imageWritable = (NDArrayWritable) record.get(0); - return imageWritable.get(); - } - }); - - System.out.println(timingStatistics); - - TimingStatistics timingStatistics1 = IOTiming.averageFileRead(1000,image,new ClassPathResource("datavec-perf/largestblobtest.jpg").getFile(), new IOTiming.INDArrayCreationFunction() { - @Override - public INDArray createFromRecord(List record) { - NDArrayWritable imageWritable = (NDArrayWritable) record.get(0); - return imageWritable.get(); - } - }); - - System.out.println(timingStatistics1); - } - -} diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/test/java/org/datavec/spark/transform/CSVSparkTransformTest.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/test/java/org/datavec/spark/transform/CSVSparkTransformTest.java index 64bb2cc89..9bc445a49 100644 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/test/java/org/datavec/spark/transform/CSVSparkTransformTest.java +++ b/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/test/java/org/datavec/spark/transform/CSVSparkTransformTest.java @@ -60,7 +60,7 @@ public class CSVSparkTransformTest { Base64NDArrayBody body = csvSparkTransform.toArray(new SingleCSVRecord(values)); INDArray fromBase64 = Nd4jBase64.fromBase64(body.getNdarray()); assertTrue(fromBase64.isVector()); - System.out.println("Base 64ed array " + fromBase64); +// System.out.println("Base 64ed array " + fromBase64); } @Test @@ -125,7 +125,7 @@ public class CSVSparkTransformTest { SequenceBatchCSVRecord transformed = csvSparkTransform.transformSequence(sequenceBatchCSVRecord); assertNotNull(transformed.getRecords()); - System.out.println(transformed); +// System.out.println(transformed); } @@ -153,7 +153,8 @@ public class CSVSparkTransformTest { new SingleCSVRecord(data2))); final CSVSparkTransform transform = new CSVSparkTransform(transformProcess); - System.out.println(transform.transformSequenceIncremental(batchCsvRecord)); +// System.out.println(transform.transformSequenceIncremental(batchCsvRecord)); + transform.transformSequenceIncremental(batchCsvRecord); assertEquals(3,Nd4jBase64.fromBase64(transform.transformSequenceArrayIncremental(batchCsvRecord).getNdarray()).rank()); } diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/test/java/org/datavec/spark/transform/ImageSparkTransformTest.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/test/java/org/datavec/spark/transform/ImageSparkTransformTest.java index 34075006c..c3474ab85 100644 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/test/java/org/datavec/spark/transform/ImageSparkTransformTest.java +++ b/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/test/java/org/datavec/spark/transform/ImageSparkTransformTest.java @@ -54,7 +54,7 @@ public class ImageSparkTransformTest { Base64NDArrayBody body = imgSparkTransform.toArray(imgRecord); INDArray fromBase64 = Nd4jBase64.fromBase64(body.getNdarray()); - System.out.println("Base 64ed array " + fromBase64); +// System.out.println("Base 64ed array " + fromBase64); assertEquals(1, fromBase64.size(0)); } @@ -78,7 +78,7 @@ public class ImageSparkTransformTest { Base64NDArrayBody body = imgSparkTransform.toArray(batch); INDArray fromBase64 = Nd4jBase64.fromBase64(body.getNdarray()); - System.out.println("Base 64ed array " + fromBase64); +// System.out.println("Base 64ed array " + fromBase64); assertEquals(3, fromBase64.size(0)); } } diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/ImageSparkTransformServerTest.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/ImageSparkTransformServerTest.java index bfae23358..049e08e47 100644 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/ImageSparkTransformServerTest.java +++ b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/ImageSparkTransformServerTest.java @@ -120,7 +120,7 @@ public class ImageSparkTransformServerTest { INDArray batchResult = getNDArray(jsonNodeBatch); assertEquals(3, batchResult.size(0)); - System.out.println(array); +// System.out.println(array); } @Test @@ -136,7 +136,7 @@ public class ImageSparkTransformServerTest { INDArray batchResult = getNDArray(jsonNode); assertEquals(3, batchResult.size(0)); - System.out.println(batchResult); +// System.out.println(batchResult); } @Test @@ -153,7 +153,7 @@ public class ImageSparkTransformServerTest { INDArray result = getNDArray(jsonNode); assertEquals(1, result.size(0)); - System.out.println(result); +// System.out.println(result); } public INDArray getNDArray(JsonNode node) throws IOException { diff --git a/datavec/datavec-spark/src/test/java/org/datavec/spark/transform/analysis/TestAnalysis.java b/datavec/datavec-spark/src/test/java/org/datavec/spark/transform/analysis/TestAnalysis.java index 05058fea8..d251d18d1 100644 --- a/datavec/datavec-spark/src/test/java/org/datavec/spark/transform/analysis/TestAnalysis.java +++ b/datavec/datavec-spark/src/test/java/org/datavec/spark/transform/analysis/TestAnalysis.java @@ -72,7 +72,9 @@ public class TestAnalysis extends BaseSparkTest { DataAnalysis da = AnalyzeSpark.analyze(schema, rdd); String daString = da.toString(); - System.out.println(da); +// System.out.println(da); + da.toJson(); + da.toString(); List ca = da.getColumnAnalysis(); assertEquals(5, ca.size()); @@ -151,7 +153,7 @@ public class TestAnalysis extends BaseSparkTest { assertEquals(1, countD[countD.length - 1]); File f = Files.createTempFile("datavec_spark_analysis_UITest", ".html").toFile(); - System.out.println(f.getAbsolutePath()); +// System.out.println(f.getAbsolutePath()); f.deleteOnExit(); HtmlAnalysis.createHtmlAnalysisFile(da, f); } @@ -210,7 +212,7 @@ public class TestAnalysis extends BaseSparkTest { for( int i=1; i<10; i++ ){ counter.merge(counters.get(i)); sparkCounter.merge(sparkCounters.get(i)); - System.out.println(); +// System.out.println(); } assertEquals(sc1.sampleStdev(), counter.getStddev(false), 1e-6); assertEquals(sparkCounter.sampleStdev(), counter.getStddev(false), 1e-6); @@ -356,7 +358,9 @@ public class TestAnalysis extends BaseSparkTest { JavaRDD> rdd = sc.parallelize(data); DataAnalysis da = AnalyzeSpark.analyze(s, rdd); - System.out.println(da); +// System.out.println(da); + da.toString(); + da.toJson(); } } diff --git a/datavec/pom.xml b/datavec/pom.xml index ff4bfa51b..c706469a9 100644 --- a/datavec/pom.xml +++ b/datavec/pom.xml @@ -59,16 +59,12 @@ datavec-api datavec-data - datavec-geo - datavec-hadoop datavec-spark - datavec-camel datavec-local datavec-spark-inference-parent datavec-jdbc datavec-excel datavec-arrow - datavec-perf datavec-python diff --git a/deeplearning4j/deeplearning4j-util/pom.xml b/deeplearning4j/deeplearning4j-common-tests/pom.xml similarity index 57% rename from deeplearning4j/deeplearning4j-util/pom.xml rename to deeplearning4j/deeplearning4j-common-tests/pom.xml index b49239a9e..825c55ca5 100644 --- a/deeplearning4j/deeplearning4j-util/pom.xml +++ b/deeplearning4j/deeplearning4j-common-tests/pom.xml @@ -1,5 +1,6 @@ + - - deeplearning4j-parent @@ -23,36 +24,45 @@ 4.0.0 - deeplearning4j-util - jar - - deeplearning4j-util - http://maven.apache.org - - - UTF-8 - + deeplearning4j-common-tests - org.nd4j - nd4j-api - ${nd4j.version} + junit + junit + provided org.nd4j - nd4j-common - ${nd4j.version} + nd4j-api + ${project.version} - test-nd4j-native + + + org.nd4j + nd4j-native + ${project.version} + test + + + test-nd4j-cuda-10.2 + + + org.nd4j + nd4j-cuda-10.2 + ${project.version} + test + + - + + \ No newline at end of file diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/BaseDL4JTest.java b/deeplearning4j/deeplearning4j-common-tests/src/main/java/org/deeplearning4j/BaseDL4JTest.java similarity index 67% rename from deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/BaseDL4JTest.java rename to deeplearning4j/deeplearning4j-common-tests/src/main/java/org/deeplearning4j/BaseDL4JTest.java index 786c7ea93..e9b609c45 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/BaseDL4JTest.java +++ b/deeplearning4j/deeplearning4j-common-tests/src/main/java/org/deeplearning4j/BaseDL4JTest.java @@ -1,5 +1,6 @@ -/******************************************************************************* +/* ****************************************************************************** * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2019-2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -22,7 +23,9 @@ import org.junit.After; import org.junit.Before; import org.junit.Rule; import org.junit.rules.TestName; -import org.nd4j.linalg.api.buffer.DataBuffer; +import org.junit.rules.Timeout; +import org.nd4j.base.Preconditions; +import org.nd4j.config.ND4JSystemProperties; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.ops.executioner.OpExecutioner; @@ -30,22 +33,41 @@ import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.profiler.ProfilerConfig; import java.lang.management.ManagementFactory; -import java.lang.management.ThreadMXBean; import java.util.List; import java.util.Map; import java.util.Properties; -import static org.junit.Assert.assertNull; +import static org.junit.Assume.assumeTrue; @Slf4j -public class BaseDL4JTest { +public abstract class BaseDL4JTest { @Rule public TestName name = new TestName(); + @Rule + public Timeout timeout = Timeout.millis(getTimeoutMilliseconds()); protected long startTime; protected int threadCountBefore; + private final int DEFAULT_THREADS = Runtime.getRuntime().availableProcessors(); + + /** + * Override this to specify the number of threads for C++ execution, via + * {@link org.nd4j.linalg.factory.Environment#setMaxMasterThreads(int)} + * @return Number of threads to use for C++ op execution + */ + public int numThreads(){ + return DEFAULT_THREADS; + } + + /** + * Override this method to set the default timeout for methods in the test class + */ + public long getTimeoutMilliseconds(){ + return 30000; + } + /** * Override this to set the profiling mode for the tests defined in the child class */ @@ -64,12 +86,45 @@ public class BaseDL4JTest { return getDataType(); } + protected Boolean integrationTest; + + /** + * @return True if integration tests maven profile is enabled, false otherwise. + */ + public boolean isIntegrationTests(){ + if(integrationTest == null){ + String prop = System.getenv("DL4J_INTEGRATION_TESTS"); + integrationTest = Boolean.parseBoolean(prop); + } + return integrationTest; + } + + /** + * Call this as the first line of a test in order to skip that test, only when the integration tests maven profile is not enabled. + * This can be used to dynamically skip integration tests when the integration test profile is not enabled. + * Note that the integration test profile is not enabled by default - "integration-tests" profile + */ + public void skipUnlessIntegrationTests(){ + assumeTrue("Skipping integration test - integration profile is not enabled", isIntegrationTests()); + } + @Before public void beforeTest(){ log.info("{}.{}", getClass().getSimpleName(), name.getMethodName()); + //Suppress ND4J initialization - don't need this logged for every test... + System.setProperty(ND4JSystemProperties.LOG_INITIALIZATION, "false"); + System.setProperty(ND4JSystemProperties.ND4J_IGNORE_AVX, "true"); Nd4j.getExecutioner().setProfilingMode(getProfilingMode()); Nd4j.getExecutioner().setProfilingConfig(ProfilerConfig.builder().build()); Nd4j.setDefaultDataTypes(getDataType(), getDefaultFPDataType()); + Nd4j.getExecutioner().setProfilingConfig(ProfilerConfig.builder().build()); + Nd4j.getExecutioner().enableDebugMode(false); + Nd4j.getExecutioner().enableVerboseMode(false); + int numThreads = numThreads(); + Preconditions.checkState(numThreads > 0, "Number of threads must be > 0"); + if(numThreads != Nd4j.getEnvironment().maxMasterThreads()) { + Nd4j.getEnvironment().setMaxMasterThreads(numThreads); + } startTime = System.currentTimeMillis(); threadCountBefore = ManagementFactory.getThreadMXBean().getThreadCount(); } diff --git a/deeplearning4j/deeplearning4j-core/pom.xml b/deeplearning4j/deeplearning4j-core/pom.xml index 65830d341..90c88d4c3 100644 --- a/deeplearning4j/deeplearning4j-core/pom.xml +++ b/deeplearning4j/deeplearning4j-core/pom.xml @@ -95,6 +95,12 @@ junit test + + org.deeplearning4j + deeplearning4j-common-tests + ${project.version} + test + org.nd4j @@ -147,6 +153,17 @@ ${jaxb.version} provided + + + com.github.oshi + oshi-json + ${oshi.version} + + + com.github.oshi + oshi-core + ${oshi.version} + diff --git a/deeplearning4j/dl4j-perf/src/main/java/org/deeplearning4j/perf/listener/DeviceMetric.java b/deeplearning4j/deeplearning4j-core/src/main/java/org/deeplearning4j/perf/listener/DeviceMetric.java similarity index 100% rename from deeplearning4j/dl4j-perf/src/main/java/org/deeplearning4j/perf/listener/DeviceMetric.java rename to deeplearning4j/deeplearning4j-core/src/main/java/org/deeplearning4j/perf/listener/DeviceMetric.java diff --git a/deeplearning4j/dl4j-perf/src/main/java/org/deeplearning4j/perf/listener/DiskInfo.java b/deeplearning4j/deeplearning4j-core/src/main/java/org/deeplearning4j/perf/listener/DiskInfo.java similarity index 100% rename from deeplearning4j/dl4j-perf/src/main/java/org/deeplearning4j/perf/listener/DiskInfo.java rename to deeplearning4j/deeplearning4j-core/src/main/java/org/deeplearning4j/perf/listener/DiskInfo.java diff --git a/deeplearning4j/dl4j-perf/src/main/java/org/deeplearning4j/perf/listener/HardwareMetric.java b/deeplearning4j/deeplearning4j-core/src/main/java/org/deeplearning4j/perf/listener/HardwareMetric.java similarity index 100% rename from deeplearning4j/dl4j-perf/src/main/java/org/deeplearning4j/perf/listener/HardwareMetric.java rename to deeplearning4j/deeplearning4j-core/src/main/java/org/deeplearning4j/perf/listener/HardwareMetric.java diff --git a/deeplearning4j/dl4j-perf/src/main/java/org/deeplearning4j/perf/listener/SystemInfoFilePrintListener.java b/deeplearning4j/deeplearning4j-core/src/main/java/org/deeplearning4j/perf/listener/SystemInfoFilePrintListener.java similarity index 100% rename from deeplearning4j/dl4j-perf/src/main/java/org/deeplearning4j/perf/listener/SystemInfoFilePrintListener.java rename to deeplearning4j/deeplearning4j-core/src/main/java/org/deeplearning4j/perf/listener/SystemInfoFilePrintListener.java diff --git a/deeplearning4j/dl4j-perf/src/main/java/org/deeplearning4j/perf/listener/SystemInfoPrintListener.java b/deeplearning4j/deeplearning4j-core/src/main/java/org/deeplearning4j/perf/listener/SystemInfoPrintListener.java similarity index 100% rename from deeplearning4j/dl4j-perf/src/main/java/org/deeplearning4j/perf/listener/SystemInfoPrintListener.java rename to deeplearning4j/deeplearning4j-core/src/main/java/org/deeplearning4j/perf/listener/SystemInfoPrintListener.java diff --git a/deeplearning4j/dl4j-perf/src/main/java/org/deeplearning4j/perf/listener/SystemPolling.java b/deeplearning4j/deeplearning4j-core/src/main/java/org/deeplearning4j/perf/listener/SystemPolling.java similarity index 100% rename from deeplearning4j/dl4j-perf/src/main/java/org/deeplearning4j/perf/listener/SystemPolling.java rename to deeplearning4j/deeplearning4j-core/src/main/java/org/deeplearning4j/perf/listener/SystemPolling.java diff --git a/deeplearning4j/deeplearning4j-util/src/main/java/org/deeplearning4j/util/MovingWindowMatrix.java b/deeplearning4j/deeplearning4j-core/src/main/java/org/deeplearning4j/util/MovingWindowMatrix.java old mode 100755 new mode 100644 similarity index 100% rename from deeplearning4j/deeplearning4j-util/src/main/java/org/deeplearning4j/util/MovingWindowMatrix.java rename to deeplearning4j/deeplearning4j-core/src/main/java/org/deeplearning4j/util/MovingWindowMatrix.java diff --git a/deeplearning4j/deeplearning4j-util/src/main/java/org/deeplearning4j/util/ThreadUtils.java b/deeplearning4j/deeplearning4j-core/src/main/java/org/deeplearning4j/util/ThreadUtils.java similarity index 100% rename from deeplearning4j/deeplearning4j-util/src/main/java/org/deeplearning4j/util/ThreadUtils.java rename to deeplearning4j/deeplearning4j-core/src/main/java/org/deeplearning4j/util/ThreadUtils.java diff --git a/deeplearning4j/deeplearning4j-util/src/main/java/org/deeplearning4j/util/UIDProvider.java b/deeplearning4j/deeplearning4j-core/src/main/java/org/deeplearning4j/util/UIDProvider.java similarity index 100% rename from deeplearning4j/deeplearning4j-util/src/main/java/org/deeplearning4j/util/UIDProvider.java rename to deeplearning4j/deeplearning4j-core/src/main/java/org/deeplearning4j/util/UIDProvider.java diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/LayerHelperValidationUtil.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/LayerHelperValidationUtil.java index e3923c4ff..1b769a32d 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/LayerHelperValidationUtil.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/LayerHelperValidationUtil.java @@ -158,7 +158,7 @@ public class LayerHelperValidationUtil { double d2 = arr2.dup('c').getDouble(idx); System.out.println("Different values at index " + idx + ": " + d1 + ", " + d2 + " - RE = " + maxRE); } - assertTrue(s + layerName + "activations - max RE: " + maxRE, maxRE < t.getMaxRelError()); + assertTrue(s + layerName + " activations - max RE: " + maxRE, maxRE < t.getMaxRelError()); log.info("Forward pass, max relative error: " + layerName + " - " + maxRE); } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/TestUtils.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/TestUtils.java index df072b64f..d90ce628b 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/TestUtils.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/TestUtils.java @@ -75,7 +75,6 @@ public class TestUtils { } public static ComputationGraph testModelSerialization(ComputationGraph net){ - ComputationGraph restored; try { ByteArrayOutputStream baos = new ByteArrayOutputStream(); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/MnistFetcherTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/MnistFetcherTest.java index 7a057afc6..362e099e4 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/MnistFetcherTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/MnistFetcherTest.java @@ -20,11 +20,9 @@ import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.base.MnistFetcher; import org.deeplearning4j.common.resources.DL4JResources; import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; -import org.junit.AfterClass; -import org.junit.BeforeClass; -import org.junit.ClassRule; -import org.junit.Test; +import org.junit.*; import org.junit.rules.TemporaryFolder; +import org.junit.rules.Timeout; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition; import org.nd4j.linalg.dataset.DataSet; @@ -47,6 +45,8 @@ public class MnistFetcherTest extends BaseDL4JTest { @ClassRule public static TemporaryFolder testDir = new TemporaryFolder(); + @Rule + public Timeout timeout = Timeout.seconds(300); @BeforeClass public static void setup() throws Exception { diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderDataSetiteratorTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderDataSetiteratorTest.java index 3bd1bd37f..6b3047aa5 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderDataSetiteratorTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderDataSetiteratorTest.java @@ -16,6 +16,7 @@ package org.deeplearning4j.datasets.datavec; +import org.junit.rules.Timeout; import org.nd4j.shade.guava.io.Files; import lombok.extern.slf4j.Slf4j; import org.apache.commons.io.FileUtils; @@ -70,6 +71,9 @@ import static org.nd4j.linalg.indexing.NDArrayIndex.point; @Slf4j public class RecordReaderDataSetiteratorTest extends BaseDL4JTest { + @Rule + public Timeout timeout = Timeout.seconds(300); + @Override public DataType getDataType(){ return DataType.FLOAT; @@ -1002,7 +1006,9 @@ public class RecordReaderDataSetiteratorTest extends BaseDL4JTest { for (RecordMetaData m : meta) { Record r = csv.loadFromMetaData(m); INDArray row = ds.getFeatures().getRow(i); - System.out.println(m.getLocation() + "\t" + r.getRecord() + "\t" + row); + if(i <= 3) { + System.out.println(m.getLocation() + "\t" + r.getRecord() + "\t" + row); + } for (int j = 0; j < 4; j++) { double exp = r.getRecord().get(j).toDouble(); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderMultiDataSetIteratorTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderMultiDataSetIteratorTest.java index 1e82a4783..897876112 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderMultiDataSetIteratorTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderMultiDataSetIteratorTest.java @@ -17,6 +17,7 @@ package org.deeplearning4j.datasets.datavec; +import org.junit.rules.Timeout; import org.nd4j.shade.guava.io.Files; import org.apache.commons.compress.utils.IOUtils; import org.apache.commons.io.FileUtils; @@ -69,6 +70,9 @@ public class RecordReaderMultiDataSetIteratorTest extends BaseDL4JTest { @Rule public TemporaryFolder temporaryFolder = new TemporaryFolder(); + @Rule + public Timeout timeout = Timeout.seconds(300); + @Test public void testsBasic() throws Exception { //Load details from CSV files; single input/output -> compare to RecordReaderDataSetIterator diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/fetchers/SvhnDataFetcherTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/fetchers/SvhnDataFetcherTest.java index 818bb752f..1815dff73 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/fetchers/SvhnDataFetcherTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/fetchers/SvhnDataFetcherTest.java @@ -17,7 +17,9 @@ package org.deeplearning4j.datasets.fetchers; import org.deeplearning4j.BaseDL4JTest; +import org.junit.Rule; import org.junit.Test; +import org.junit.rules.Timeout; import java.io.File; @@ -28,6 +30,9 @@ import static org.junit.Assert.assertTrue; */ public class SvhnDataFetcherTest extends BaseDL4JTest { + @Rule + public Timeout timeout = Timeout.seconds(600); + @Test public void testSvhnDataFetcher() throws Exception { SvhnDataFetcher fetch = new SvhnDataFetcher(); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/AsyncDataSetIteratorTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/AsyncDataSetIteratorTest.java index 5a73e8ba8..3e5015356 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/AsyncDataSetIteratorTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/AsyncDataSetIteratorTest.java @@ -78,8 +78,16 @@ public class AsyncDataSetIteratorTest extends BaseDL4JTest { @Test public void hasNextWithResetAndLoad() throws Exception { + int[] prefetchSizes; + if(isIntegrationTests()){ + prefetchSizes = new int[]{2, 3, 4, 5, 6, 7, 8}; + } else { + prefetchSizes = new int[]{2, 3, 8}; + } + + for (int iter = 0; iter < ITERATIONS; iter++) { - for (int prefetchSize = 2; prefetchSize <= 8; prefetchSize++) { + for(int prefetchSize : prefetchSizes){ AsyncDataSetIterator iterator = new AsyncDataSetIterator(backIterator, prefetchSize); TestDataSetConsumer consumer = new TestDataSetConsumer(EXECUTION_SMALL); int cnt = 0; @@ -161,8 +169,14 @@ public class AsyncDataSetIteratorTest extends BaseDL4JTest { @Test public void testVariableTimeSeries1() throws Exception { + int numBatches = isIntegrationTests() ? 1000 : 100; + int batchSize = isIntegrationTests() ? 32 : 8; + int timeStepsMin = 10; + int timeStepsMax = isIntegrationTests() ? 500 : 100; + int valuesPerTimestep = isIntegrationTests() ? 128 : 16; + AsyncDataSetIterator adsi = new AsyncDataSetIterator( - new VariableTimeseriesGenerator(1192, 1000, 32, 128, 10, 500, 10), 2, true); + new VariableTimeseriesGenerator(1192, numBatches, batchSize, valuesPerTimestep, timeStepsMin, timeStepsMax, 10), 2, true); for (int e = 0; e < 10; e++) { int cnt = 0; @@ -183,7 +197,7 @@ public class AsyncDataSetIteratorTest extends BaseDL4JTest { } adsi.reset(); - log.info("Epoch {} finished...", e); +// log.info("Epoch {} finished...", e); } } @@ -215,7 +229,7 @@ public class AsyncDataSetIteratorTest extends BaseDL4JTest { } adsi.reset(); - log.info("Epoch {} finished...", e); +// log.info("Epoch {} finished...", e); } } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/AsyncMultiDataSetIteratorTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/AsyncMultiDataSetIteratorTest.java index 68381bec7..d6a0d1fe1 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/AsyncMultiDataSetIteratorTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/AsyncMultiDataSetIteratorTest.java @@ -18,21 +18,10 @@ package org.deeplearning4j.datasets.iterator; import lombok.extern.slf4j.Slf4j; import lombok.val; -import org.datavec.api.records.reader.RecordReader; -import org.datavec.api.records.reader.SequenceRecordReader; -import org.datavec.api.records.reader.impl.csv.CSVRecordReader; -import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader; -import org.datavec.api.split.NumberedFileInputSplit; import org.deeplearning4j.BaseDL4JTest; -import org.deeplearning4j.datasets.datavec.RecordReaderMultiDataSetIterator; import org.deeplearning4j.datasets.iterator.tools.VariableMultiTimeseriesGenerator; import org.junit.Test; import org.nd4j.linalg.dataset.api.MultiDataSet; -import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; -import org.nd4j.linalg.dataset.api.preprocessor.MultiDataNormalization; -import org.nd4j.linalg.dataset.api.preprocessor.MultiNormalizerStandardize; - -import java.util.Arrays; import static org.junit.Assert.assertEquals; @@ -49,7 +38,13 @@ public class AsyncMultiDataSetIteratorTest extends BaseDL4JTest { */ @Test public void testVariableTimeSeries1() throws Exception { - val iterator = new VariableMultiTimeseriesGenerator(1192, 1000, 32, 128, 10, 500, 10); + int numBatches = isIntegrationTests() ? 1000 : 100; + int batchSize = isIntegrationTests() ? 32 : 8; + int timeStepsMin = 10; + int timeStepsMax = isIntegrationTests() ? 500 : 100; + int valuesPerTimestep = isIntegrationTests() ? 128 : 16; + + val iterator = new VariableMultiTimeseriesGenerator(1192, numBatches, batchSize, valuesPerTimestep, timeStepsMin, timeStepsMax, 10); iterator.reset(); iterator.hasNext(); val amdsi = new AsyncMultiDataSetIterator(iterator, 2, true); @@ -81,7 +76,13 @@ public class AsyncMultiDataSetIteratorTest extends BaseDL4JTest { @Test public void testVariableTimeSeries2() throws Exception { - val iterator = new VariableMultiTimeseriesGenerator(1192, 1000, 32, 128, 10, 500, 10); + int numBatches = isIntegrationTests() ? 1000 : 100; + int batchSize = isIntegrationTests() ? 32 : 8; + int timeStepsMin = 10; + int timeStepsMax = isIntegrationTests() ? 500 : 100; + int valuesPerTimestep = isIntegrationTests() ? 128 : 16; + + val iterator = new VariableMultiTimeseriesGenerator(1192, numBatches, batchSize, valuesPerTimestep, timeStepsMin, timeStepsMax, 10); for (int e = 0; e < 10; e++) { iterator.reset(); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/DataSetIteratorTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/DataSetIteratorTest.java index 6ed7819df..5201b3f56 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/DataSetIteratorTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/DataSetIteratorTest.java @@ -57,6 +57,11 @@ import static org.junit.Assert.*; public class DataSetIteratorTest extends BaseDL4JTest { + @Override + public long getTimeoutMilliseconds() { + return 90000; + } + @Test public void testBatchSizeOfOneIris() throws Exception { //Test for (a) iterators returning correct number of examples, and @@ -190,7 +195,7 @@ public class DataSetIteratorTest extends BaseDL4JTest { INDArray output = model.output(dataTest.getFeatures()); Evaluation eval = new Evaluation(outputNum); eval.eval(dataTest.getLabels(), output); - System.out.println(eval.stats()); +// System.out.println(eval.stats()); } @Test @@ -257,7 +262,7 @@ public class DataSetIteratorTest extends BaseDL4JTest { INDArray output = model.output(testDS.getFeatures()); eval.eval(testDS.getLabels(), output); } - System.out.println(eval.stats(true)); +// System.out.println(eval.stats(true)); listener.exportScores(System.out); } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/MultipleEpochsIteratorTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/MultipleEpochsIteratorTest.java index b2b32c715..f37642c24 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/MultipleEpochsIteratorTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/MultipleEpochsIteratorTest.java @@ -22,7 +22,9 @@ import org.datavec.api.split.FileSplit; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; import org.deeplearning4j.nn.util.TestDataSetConsumer; +import org.junit.Rule; import org.junit.Test; +import org.junit.rules.Timeout; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.factory.Nd4j; @@ -37,6 +39,9 @@ import static org.junit.Assert.*; public class MultipleEpochsIteratorTest extends BaseDL4JTest { + @Rule + public Timeout timeout = Timeout.seconds(300); + @Test public void testNextAndReset() throws Exception { int epochs = 3; diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/TestEmnistDataSetIterator.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/TestEmnistDataSetIterator.java index 0c1049845..c5ac04901 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/TestEmnistDataSetIterator.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/TestEmnistDataSetIterator.java @@ -19,7 +19,9 @@ package org.deeplearning4j.datasets.iterator; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.datasets.iterator.impl.EmnistDataSetIterator; +import org.junit.Rule; import org.junit.Test; +import org.junit.rules.Timeout; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; @@ -33,6 +35,9 @@ import static org.junit.Assert.*; @Slf4j public class TestEmnistDataSetIterator extends BaseDL4JTest { + @Rule + public Timeout timeout = Timeout.seconds(600); + @Override public DataType getDataType(){ return DataType.FLOAT; @@ -41,17 +46,17 @@ public class TestEmnistDataSetIterator extends BaseDL4JTest { @Test public void testEmnistDataSetIterator() throws Exception { - // EmnistFetcher fetcher = new EmnistFetcher(EmnistDataSetIterator.Set.COMPLETE); - // File baseEmnistDir = fetcher.getFILE_DIR(); - // if(baseEmnistDir.exists()){ - // FileUtils.deleteDirectory(baseEmnistDir); - // } - // assertFalse(baseEmnistDir.exists()); - int batchSize = 128; - for (EmnistDataSetIterator.Set s : EmnistDataSetIterator.Set.values()) { + EmnistDataSetIterator.Set[] sets; + if(isIntegrationTests()){ + sets = EmnistDataSetIterator.Set.values(); + } else { + sets = new EmnistDataSetIterator.Set[]{EmnistDataSetIterator.Set.MNIST, EmnistDataSetIterator.Set.LETTERS}; + } + + for (EmnistDataSetIterator.Set s : sets) { boolean isBalanced = EmnistDataSetIterator.isBalanced(s); int numLabels = EmnistDataSetIterator.numLabels(s); INDArray labelCounts = null; diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/tools/VariableMultiTimeseriesGenerator.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/tools/VariableMultiTimeseriesGenerator.java index 04ad55bc1..17642b74c 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/tools/VariableMultiTimeseriesGenerator.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/tools/VariableMultiTimeseriesGenerator.java @@ -68,8 +68,8 @@ public class VariableMultiTimeseriesGenerator implements MultiDataSetIterator { int localMaxima = isFirst && firstMaxima > 0 ? firstMaxima : minTS == maxTS ? minTS : rng.nextInt(maxTS - minTS) + minTS; - if (isFirst) - log.info("Local maxima: {}", localMaxima); +// if (isFirst) +// log.info("Local maxima: {}", localMaxima); isFirst = false; diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/tools/VariableTimeseriesGenerator.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/tools/VariableTimeseriesGenerator.java index 9d2eb17ff..46dbbac9c 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/tools/VariableTimeseriesGenerator.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/tools/VariableTimeseriesGenerator.java @@ -69,8 +69,8 @@ public class VariableTimeseriesGenerator implements DataSetIterator { int localMaxima = isFirst && firstMaxima > 0 ? firstMaxima : minTS == maxTS ? minTS : rng.nextInt(maxTS - minTS) + minTS; - if (isFirst) - log.info("Local maxima: {}", localMaxima); +// if (isFirst) +// log.info("Local maxima: {}", localMaxima); isFirst = false; diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/eval/EvalJsonTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/eval/EvalJsonTest.java index 74cecae86..db86ad617 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/eval/EvalJsonTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/eval/EvalJsonTest.java @@ -54,7 +54,7 @@ public class EvalJsonTest extends BaseDL4JTest { @Test public void testSerde() { - boolean print = true; + boolean print = false; Nd4j.getRandom().setSeed(12345); Evaluation evaluation = new Evaluation(); @@ -105,7 +105,7 @@ public class EvalJsonTest extends BaseDL4JTest { @Test public void testSerdeExactRoc() { Nd4j.getRandom().setSeed(12345); - boolean print = true; + boolean print = false; ROC roc = new ROC(0); ROCBinary roc2 = new ROCBinary(0); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/eval/EvalTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/eval/EvalTest.java index 0a9cfea4c..812ea2b08 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/eval/EvalTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/eval/EvalTest.java @@ -131,11 +131,15 @@ public class EvalTest extends BaseDL4JTest { org.nd4j.evaluation.classification.Evaluation evalViaMethod = model.evaluate(new ListDataSetIterator<>(Collections.singletonList(test))); checkEvaluationEquality(eval, evalViaMethod); - System.out.println(eval.getConfusionMatrix().toString()); - System.out.println(eval.getConfusionMatrix().toCSV()); - System.out.println(eval.getConfusionMatrix().toHTML()); +// System.out.println(eval.getConfusionMatrix().toString()); +// System.out.println(eval.getConfusionMatrix().toCSV()); +// System.out.println(eval.getConfusionMatrix().toHTML()); +// System.out.println(eval.confusionToString()); - System.out.println(eval.confusionToString()); + eval.getConfusionMatrix().toString(); + eval.getConfusionMatrix().toCSV(); + eval.getConfusionMatrix().toHTML(); + eval.confusionToString(); } private static void assertMapEquals(Map first, Map second) { @@ -161,7 +165,7 @@ public class EvalTest extends BaseDL4JTest { assertEquals(evalExpected.getConfusionMatrix(), evalActual.getConfusionMatrix()); } - @Test + @Test(timeout = 300000) public void testEvaluationWithMetaData() throws Exception { RecordReader csv = new CSVRecordReader(); @@ -205,9 +209,10 @@ public class EvalTest extends BaseDL4JTest { e.eval(ds.getLabels(), out, meta); //*** New - evaluate and also store metadata *** } - System.out.println(e.stats()); +// System.out.println(e.stats()); + e.stats(); - System.out.println("\n\n*** Prediction Errors: ***"); +// System.out.println("\n\n*** Prediction Errors: ***"); List errors = e.getPredictionErrors(); //*** New - get list of prediction errors from evaluation *** List metaForErrors = new ArrayList<>(); @@ -219,10 +224,11 @@ public class EvalTest extends BaseDL4JTest { int count = 0; for (org.nd4j.evaluation.meta.Prediction t : errors) { - System.out.println(t + "\t\tRaw Data: " - + csv.loadFromMetaData((RecordMetaData) t.getRecordMetaData()).getRecord() //*** New - load subset of data from MetaData object (usually batched for efficiency) *** - + "\tNormalized: " + ds.getFeatures().getRow(count) + "\tLabels: " - + ds.getLabels().getRow(count) + "\tNetwork predictions: " + output.getRow(count)); + String s = t + "\t\tRaw Data: " + + csv.loadFromMetaData((RecordMetaData) t.getRecordMetaData()).getRecord() //*** New - load subset of data from MetaData object (usually batched for efficiency) *** + + "\tNormalized: " + ds.getFeatures().getRow(count) + "\tLabels: " + + ds.getLabels().getRow(count) + "\tNetwork predictions: " + output.getRow(count); +// System.out.println(s); count++; } @@ -322,9 +328,9 @@ public class EvalTest extends BaseDL4JTest { List l = Arrays.asList(new DataSet(in1, out1, null, lMask1), new DataSet(in2, out2, null, lMask2)); DataSetIterator iter = new ExistingDataSetIterator(l); - System.out.println("Net 1 eval"); +// System.out.println("Net 1 eval"); org.nd4j.evaluation.IEvaluation[] e1 = net1.doEvaluation(iter, new org.nd4j.evaluation.classification.Evaluation(), new org.nd4j.evaluation.classification.ROCMultiClass(), new org.nd4j.evaluation.regression.RegressionEvaluation()); - System.out.println("Net 2 eval"); +// System.out.println("Net 2 eval"); org.nd4j.evaluation.IEvaluation[] e2 = net2.doEvaluation(iter, new org.nd4j.evaluation.classification.Evaluation(), new org.nd4j.evaluation.classification.ROCMultiClass(), new org.nd4j.evaluation.regression.RegressionEvaluation()); assertEquals(e1[0], e2[0]); @@ -403,9 +409,9 @@ public class EvalTest extends BaseDL4JTest { List l = Arrays.asList(new DataSet(in1, out1), new DataSet(in2, out2)); DataSetIterator iter = new ExistingDataSetIterator(l); - System.out.println("Eval net 1"); +// System.out.println("Eval net 1"); org.nd4j.evaluation.IEvaluation[] e1 = net1.doEvaluation(iter, new org.nd4j.evaluation.classification.Evaluation(), new org.nd4j.evaluation.classification.ROCMultiClass(), new org.nd4j.evaluation.regression.RegressionEvaluation()); - System.out.println("Eval net 2"); +// System.out.println("Eval net 2"); org.nd4j.evaluation.IEvaluation[] e2 = net2.doEvaluation(iter, new org.nd4j.evaluation.classification.Evaluation(), new org.nd4j.evaluation.classification.ROCMultiClass(), new org.nd4j.evaluation.regression.RegressionEvaluation()); assertEquals(e1[0], e2[0]); @@ -470,7 +476,7 @@ public class EvalTest extends BaseDL4JTest { net.setListeners(new EvaluativeListener(iterTest, 3)); - for( int i=0; i<10; i++ ){ + for( int i=0; i<3; i++ ){ net.fit(iter); } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/eval/EvaluationToolsTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/eval/EvaluationToolsTests.java index e9fcc2fff..032b06ed0 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/eval/EvaluationToolsTests.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/eval/EvaluationToolsTests.java @@ -117,7 +117,7 @@ public class EvaluationToolsTests extends BaseDL4JTest { String str = EvaluationTools.rocChartToHtml(roc, Arrays.asList("setosa", "versicolor", "virginica")); - System.out.println(str); +// System.out.println(str); } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/AttentionLayerTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/AttentionLayerTest.java index e3380337b..8a41b614f 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/AttentionLayerTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/AttentionLayerTest.java @@ -46,12 +46,6 @@ public class AttentionLayerTest extends BaseDL4JTest { @Rule public ExpectedException exceptionRule = ExpectedException.none(); - private static final boolean PRINT_RESULTS = true; - private static final boolean RETURN_ON_FIRST_FAILURE = false; - private static final double DEFAULT_EPS = 1e-6; - private static final double DEFAULT_MAX_REL_ERROR = 1e-3; - private static final double DEFAULT_MIN_ABS_ERROR = 1e-8; - @Test public void testSelfAttentionLayer() { int nIn = 3; @@ -104,8 +98,8 @@ public class AttentionLayerTest extends BaseDL4JTest { MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, in, labels, inMask, null, true, 100); + boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(in) + .labels(labels).inputMask(inMask).subset(true).maxPerParam(100)); assertTrue(name, gradOK); } } @@ -165,8 +159,8 @@ public class AttentionLayerTest extends BaseDL4JTest { MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, in, labels, inMask, null, true, 100); + boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(in) + .labels(labels).inputMask(inMask).subset(true).maxPerParam(100)); assertTrue(name, gradOK); } } @@ -226,8 +220,8 @@ public class AttentionLayerTest extends BaseDL4JTest { String name = "testLearnedSelfAttentionLayer() - mb=" + mb + ", tsLength = " + tsLength + ", maskType=" + maskType + ", projectInput = " + projectInput; System.out.println("Starting test: " + name); - boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, in, labels, inMask, null, true, 100); + boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(in) + .labels(labels).inputMask(inMask).subset(true).maxPerParam(100)); assertTrue(name, gradOK); } } @@ -320,8 +314,8 @@ public class AttentionLayerTest extends BaseDL4JTest { net.init(); //System.out.println("Original"); - boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, in, labels, inMask, null, true, 100, null); + boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(in) + .labels(labels).inputMask(inMask).subset(true).maxPerParam(100)); assertTrue(name, gradOK); } } @@ -383,8 +377,8 @@ public class AttentionLayerTest extends BaseDL4JTest { ComputationGraph net = new ComputationGraph(graph); net.init(); - boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, new INDArray[]{in}, new INDArray[]{labels}, inMask != null ? new INDArray[]{inMask} : null, null); + boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(net).inputs(new INDArray[]{in}) + .labels(new INDArray[]{labels}).inputMask(inMask != null ? new INDArray[]{inMask} : null).subset(true).maxPerParam(100)); assertTrue(name, gradOK); } } @@ -445,9 +439,8 @@ public class AttentionLayerTest extends BaseDL4JTest { ComputationGraph net = new ComputationGraph(graph); net.init(); - boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, new INDArray[]{in}, - new INDArray[]{labels}, inMask != null ? new INDArray[]{inMask} : null, null); + boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(net).inputs(new INDArray[]{in}) + .labels(new INDArray[]{labels}).inputMask(inMask != null ? new INDArray[]{inMask} : null)); assertTrue(name, gradOK); } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/BNGradientCheckTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/BNGradientCheckTest.java index 5bafc81b0..eac917e13 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/BNGradientCheckTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/BNGradientCheckTest.java @@ -56,11 +56,6 @@ import static org.junit.Assert.assertTrue; * */ public class BNGradientCheckTest extends BaseDL4JTest { - private static final boolean PRINT_RESULTS = true; - private static final boolean RETURN_ON_FIRST_FAILURE = false; - private static final double DEFAULT_EPS = 1e-5; - private static final double DEFAULT_MAX_REL_ERROR = 1e-5; - private static final double DEFAULT_MIN_ABS_ERROR = 1e-9; static { Nd4j.setDataType(DataType.DOUBLE); @@ -93,17 +88,15 @@ public class BNGradientCheckTest extends BaseDL4JTest { MultiLayerNetwork mln = new MultiLayerNetwork(builder.build()); mln.init(); - if (PRINT_RESULTS) { - for (int j = 0; j < mln.getnLayers(); j++) - System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); - } +// for (int j = 0; j < mln.getnLayers(); j++) +// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); //Mean and variance vars are not gradient checkable; mean/variance "gradient" is used to implement running mean/variance calc //i.e., runningMean = decay * runningMean + (1-decay) * batchMean //However, numerical gradient will be 0 as forward pass doesn't depend on this "parameter" Set excludeParams = new HashSet<>(Arrays.asList("1_mean", "1_var", "1_log10stdev")); - boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels, excludeParams); + boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(mln).input(input) + .labels(labels).excludeParams(excludeParams)); assertTrue(gradOK); TestUtils.testModelSerialization(mln); @@ -140,17 +133,15 @@ public class BNGradientCheckTest extends BaseDL4JTest { MultiLayerNetwork mln = new MultiLayerNetwork(builder.build()); mln.init(); - if (PRINT_RESULTS) { - for (int j = 0; j < mln.getnLayers(); j++) - System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); - } +// for (int j = 0; j < mln.getnLayers(); j++) +// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); //Mean and variance vars are not gradient checkable; mean/variance "gradient" is used to implement running mean/variance calc //i.e., runningMean = decay * runningMean + (1-decay) * batchMean //However, numerical gradient will be 0 as forward pass doesn't depend on this "parameter" Set excludeParams = new HashSet<>(Arrays.asList("1_mean", "1_var", "1_log10stdev")); - boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels, excludeParams); + boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(mln).input(input) + .labels(labels).excludeParams(excludeParams)); assertTrue(gradOK); TestUtils.testModelSerialization(mln); @@ -220,7 +211,7 @@ public class BNGradientCheckTest extends BaseDL4JTest { String name = new Object() { }.getClass().getEnclosingMethod().getName(); - System.out.println("Num params: " + mln.numParams()); +// System.out.println("Num params: " + mln.numParams()); if (doLearningFirst) { //Run a number of iterations of learning @@ -241,20 +232,18 @@ public class BNGradientCheckTest extends BaseDL4JTest { assertTrue(msg, scoreAfter < 0.9 * scoreBefore); } - if (PRINT_RESULTS) { - System.out.println(name + " - activationFn=" + afn + ", lossFn=" + lf - + ", outputActivation=" + outputActivation + ", doLearningFirst=" - + doLearningFirst + ", l1=" + l1vals[j] + ", l2=" + l2vals[j]); - for (int k = 0; k < mln.getnLayers(); k++) - System.out.println("Layer " + k + " # params: " + mln.getLayer(k).numParams()); - } + System.out.println(name + " - activationFn=" + afn + ", lossFn=" + lf + + ", outputActivation=" + outputActivation + ", doLearningFirst=" + + doLearningFirst + ", l1=" + l1vals[j] + ", l2=" + l2vals[j]); +// for (int k = 0; k < mln.getnLayers(); k++) +// System.out.println("Layer " + k + " # params: " + mln.getLayer(k).numParams()); //Mean and variance vars are not gradient checkable; mean/variance "gradient" is used to implement running mean/variance calc //i.e., runningMean = decay * runningMean + (1-decay) * batchMean //However, numerical gradient will be 0 as forward pass doesn't depend on this "parameter" Set excludeParams = new HashSet<>(Arrays.asList("1_mean", "1_var", "3_mean", "3_var", "1_log10stdev", "3_log10stdev")); - boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels, null, null, true, 25, excludeParams); //Most params are in output layer, only these should be skipped with this threshold + boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(mln).input(input) + .labels(labels).excludeParams(excludeParams).subset(true).maxPerParam(25)); //Most params are in output layer, only these should be skipped with this threshold assertTrue(gradOK); TestUtils.testModelSerialization(mln); @@ -347,20 +336,18 @@ public class BNGradientCheckTest extends BaseDL4JTest { assertTrue(msg, scoreAfter < 0.8 * scoreBefore); } - if (PRINT_RESULTS) { - System.out.println(name + " - activationFn=" + afn + ", lossFn=" + lf - + ", outputActivation=" + outputActivation + ", doLearningFirst=" - + doLearningFirst + ", l1=" + l1vals[j] + ", l2=" + l2vals[j]); - for (int k = 0; k < mln.getnLayers(); k++) - System.out.println("Layer " + k + " # params: " + mln.getLayer(k).numParams()); - } + System.out.println(name + " - activationFn=" + afn + ", lossFn=" + lf + + ", outputActivation=" + outputActivation + ", doLearningFirst=" + + doLearningFirst + ", l1=" + l1vals[j] + ", l2=" + l2vals[j]); +// for (int k = 0; k < mln.getnLayers(); k++) +// System.out.println("Layer " + k + " # params: " + mln.getLayer(k).numParams()); //Mean and variance vars are not gradient checkable; mean/variance "gradient" is used to implement running mean/variance calc //i.e., runningMean = decay * runningMean + (1-decay) * batchMean //However, numerical gradient will be 0 as forward pass doesn't depend on this "parameter" Set excludeParams = new HashSet<>(Arrays.asList("1_mean", "1_var", "3_mean", "3_var", "1_log10stdev", "3_log10stdev")); - boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels, excludeParams); + boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(mln).input(input) + .labels(labels).excludeParams(excludeParams)); assertTrue(gradOK); TestUtils.testModelSerialization(mln); @@ -396,17 +383,15 @@ public class BNGradientCheckTest extends BaseDL4JTest { MultiLayerNetwork mln = new MultiLayerNetwork(builder.build()); mln.init(); - if (PRINT_RESULTS) { - for (int j = 0; j < mln.getnLayers(); j++) - System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); - } +// for (int j = 0; j < mln.getnLayers(); j++) +// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); //Mean and variance vars are not gradient checkable; mean/variance "gradient" is used to implement running mean/variance calc //i.e., runningMean = decay * runningMean + (1-decay) * batchMean //However, numerical gradient will be 0 as forward pass doesn't depend on this "parameter" Set excludeParams = new HashSet<>(Arrays.asList("1_mean", "1_var", "1_log10stdev")); - boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels, excludeParams); + boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(mln).input(input) + .labels(labels).excludeParams(excludeParams)); assertTrue(gradOK); TestUtils.testModelSerialization(mln); @@ -443,17 +428,15 @@ public class BNGradientCheckTest extends BaseDL4JTest { MultiLayerNetwork mln = new MultiLayerNetwork(builder.build()); mln.init(); - if (PRINT_RESULTS) { - for (int j = 0; j < mln.getnLayers(); j++) - System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); - } +// for (int j = 0; j < mln.getnLayers(); j++) +// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); //Mean and variance vars are not gradient checkable; mean/variance "gradient" is used to implement running mean/variance calc //i.e., runningMean = decay * runningMean + (1-decay) * batchMean //However, numerical gradient will be 0 as forward pass doesn't depend on this "parameter" Set excludeParams = new HashSet<>(Arrays.asList("1_mean", "1_var", "1_log10stdev")); - boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels, excludeParams); + boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(mln).input(input) + .labels(labels).excludeParams(excludeParams)); assertTrue(gradOK); TestUtils.testModelSerialization(mln); @@ -496,9 +479,8 @@ public class BNGradientCheckTest extends BaseDL4JTest { //i.e., runningMean = decay * runningMean + (1-decay) * batchMean //However, numerical gradient will be 0 as forward pass doesn't depend on this "parameter" Set excludeParams = new HashSet<>(Arrays.asList("bn_mean", "bn_var")); - boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, new INDArray[]{input}, - new INDArray[]{labels}, null, null, excludeParams); + boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(net).inputs(new INDArray[]{input}) + .labels(new INDArray[]{labels}).excludeParams(excludeParams)); assertTrue(gradOK); TestUtils.testModelSerialization(net); @@ -585,21 +567,18 @@ public class BNGradientCheckTest extends BaseDL4JTest { assertTrue(msg, scoreAfter < 0.9 * scoreBefore); } - if (PRINT_RESULTS) { - System.out.println(name + " - activationFn=" + afn + ", lossFn=" + lf - + ", outputActivation=" + outputActivation + ", doLearningFirst=" - + doLearningFirst + ", l1=" + l1vals[j] + ", l2=" + l2vals[j]); - for (int k = 0; k < net.getNumLayers(); k++) - System.out.println("Layer " + k + " # params: " + net.getLayer(k).numParams()); - } + System.out.println(name + " - activationFn=" + afn + ", lossFn=" + lf + + ", outputActivation=" + outputActivation + ", doLearningFirst=" + + doLearningFirst + ", l1=" + l1vals[j] + ", l2=" + l2vals[j]); +// for (int k = 0; k < net.getNumLayers(); k++) +// System.out.println("Layer " + k + " # params: " + net.getLayer(k).numParams()); //Mean and variance vars are not gradient checkable; mean/variance "gradient" is used to implement running mean/variance calc //i.e., runningMean = decay * runningMean + (1-decay) * batchMean //However, numerical gradient will be 0 as forward pass doesn't depend on this "parameter" Set excludeParams = new HashSet<>(Arrays.asList("1_mean", "1_var", "3_mean", "3_var", "1_log10stdev", "3_log10stdev")); - boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, - new INDArray[]{input}, new INDArray[]{labels}, null, null, excludeParams); + boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(net).inputs(new INDArray[]{input}) + .labels(new INDArray[]{labels}).excludeParams(excludeParams)); assertTrue(gradOK); TestUtils.testModelSerialization(net); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNN1DGradientCheckTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNN1DGradientCheckTest.java index a0a109cb1..a2b29e06c 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNN1DGradientCheckTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNN1DGradientCheckTest.java @@ -108,8 +108,8 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest { if (PRINT_RESULTS) { System.out.println(msg); - for (int j = 0; j < net.getnLayers(); j++) - System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams()); +// for (int j = 0; j < net.getnLayers(); j++) +// System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams()); } boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, @@ -188,8 +188,8 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest { if (PRINT_RESULTS) { System.out.println(msg); - for (int j = 0; j < net.getnLayers(); j++) - System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams()); +// for (int j = 0; j < net.getnLayers(); j++) +// System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams()); } boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, @@ -272,8 +272,8 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest { if (PRINT_RESULTS) { System.out.println(msg); - for (int j = 0; j < net.getnLayers(); j++) - System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams()); +// for (int j = 0; j < net.getnLayers(); j++) +// System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams()); } boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, @@ -349,8 +349,8 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest { if (PRINT_RESULTS) { System.out.println(msg); - for (int j = 0; j < net.getnLayers(); j++) - System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams()); +// for (int j = 0; j < net.getnLayers(); j++) +// System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams()); } boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, @@ -414,8 +414,8 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest { INDArray label = TestUtils.randomOneHot(2, finalNOut); - boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, f, label, fm, null); + boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(f) + .labels(label).inputMask(fm)); assertTrue(s, gradOK); TestUtils.testModelSerialization(net); @@ -509,8 +509,8 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest { INDArray label = TestUtils.randomOneHotTimeSeries(2, finalNOut, (int)outSize2); - boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, f, label, fm, null); + boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(f) + .labels(label).inputMask(fm)); assertTrue(s, gradOK); TestUtils.testModelSerialization(net); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNN3DGradientCheckTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNN3DGradientCheckTest.java index 13cc11e80..9a4994d73 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNN3DGradientCheckTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNN3DGradientCheckTest.java @@ -144,14 +144,13 @@ public class CNN3DGradientCheckTest extends BaseDL4JTest { if (PRINT_RESULTS) { log.info(msg); - for (int j = 0; j < net.getnLayers(); j++) { - log.info("Layer " + j + " # params: " + net.getLayer(j).numParams()); - } +// for (int j = 0; j < net.getnLayers(); j++) { +// log.info("Layer " + j + " # params: " + net.getLayer(j).numParams()); +// } } - boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, - DEFAULT_MAX_REL_ERROR, DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, - RETURN_ON_FIRST_FAILURE, input, labels, null, null, true, 128); + boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(input) + .labels(labels).subset(true).maxPerParam(128)); assertTrue(msg, gradOK); @@ -248,14 +247,13 @@ public class CNN3DGradientCheckTest extends BaseDL4JTest { if (PRINT_RESULTS) { log.info(msg); - for (int j = 0; j < net.getnLayers(); j++) { - log.info("Layer " + j + " # params: " + net.getLayer(j).numParams()); - } +// for (int j = 0; j < net.getnLayers(); j++) { +// log.info("Layer " + j + " # params: " + net.getLayer(j).numParams()); +// } } - boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, - DEFAULT_MAX_REL_ERROR, DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, - RETURN_ON_FIRST_FAILURE, input, labels, null, null, true, 512); + boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(input) + .labels(labels).subset(true).maxPerParam(512)); assertTrue(msg, gradOK); @@ -341,9 +339,6 @@ public class CNN3DGradientCheckTest extends BaseDL4JTest { if (PRINT_RESULTS) { log.info(msg); - for (int j = 0; j < net.getnLayers(); j++) { - log.info("Layer " + j + " # params: " + net.getLayer(j).numParams()); - } } boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, @@ -431,9 +426,9 @@ public class CNN3DGradientCheckTest extends BaseDL4JTest { if (PRINT_RESULTS) { log.info(msg); - for (int j = 0; j < net.getnLayers(); j++) { - log.info("Layer " + j + " # params: " + net.getLayer(j).numParams()); - } +// for (int j = 0; j < net.getnLayers(); j++) { +// log.info("Layer " + j + " # params: " + net.getLayer(j).numParams()); +// } } boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, @@ -530,9 +525,9 @@ public class CNN3DGradientCheckTest extends BaseDL4JTest { if (PRINT_RESULTS) { log.info(msg); - for (int j = 0; j < net.getnLayers(); j++) { - log.info("Layer " + j + " # params: " + net.getLayer(j).numParams()); - } +// for (int j = 0; j < net.getnLayers(); j++) { +// log.info("Layer " + j + " # params: " + net.getLayer(j).numParams()); +// } } boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, @@ -547,4 +542,92 @@ public class CNN3DGradientCheckTest extends BaseDL4JTest { } } } + + @Test + public void testDeconv3d() { + Nd4j.getRandom().setSeed(12345); + // Note: we checked this with a variety of parameters, but it takes a lot of time. + int[] depths = {8, 8, 9}; + int[] heights = {8, 9, 9}; + int[] widths = {8, 8, 9}; + + + int[][] kernels = {{2, 2, 2}, {3, 3, 3}, {2, 3, 2}}; + int[][] strides = {{1, 1, 1}, {1, 1, 1}, {2, 2, 2}}; + + Activation[] activations = {Activation.SIGMOID, Activation.TANH, Activation.IDENTITY}; + + ConvolutionMode[] modes = {ConvolutionMode.Truncate, ConvolutionMode.Same, ConvolutionMode.Same}; + int[] mbs = {1, 3, 2}; + Convolution3D.DataFormat[] dataFormats = new Convolution3D.DataFormat[]{Convolution3D.DataFormat.NCDHW, Convolution3D.DataFormat.NDHWC, Convolution3D.DataFormat.NCDHW}; + + int convNIn = 2; + int finalNOut = 2; + int[] deconvOut = {2, 3, 4}; + + for (int i = 0; i < activations.length; i++) { + Activation afn = activations[i]; + int miniBatchSize = mbs[i]; + int depth = depths[i]; + int height = heights[i]; + int width = widths[i]; + ConvolutionMode mode = modes[i]; + int[] kernel = kernels[i]; + int[] stride = strides[i]; + Convolution3D.DataFormat df = dataFormats[i]; + int dOut = deconvOut[i]; + + INDArray input; + if (df == Convolution3D.DataFormat.NDHWC) { + input = Nd4j.rand(new int[]{miniBatchSize, depth, height, width, convNIn}); + } else { + input = Nd4j.rand(new int[]{miniBatchSize, convNIn, depth, height, width}); + } + INDArray labels = Nd4j.zeros(miniBatchSize, finalNOut); + for (int j = 0; j < miniBatchSize; j++) { + labels.putScalar(new int[]{j, j % finalNOut}, 1.0); + } + + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + .dataType(DataType.DOUBLE) + .updater(new NoOp()) + .weightInit(new NormalDistribution(0, 0.1)) + .list() + .layer(0, new Convolution3D.Builder().activation(afn).kernelSize(kernel) + .stride(stride).nIn(convNIn).nOut(dOut).hasBias(false) + .convolutionMode(mode).dataFormat(df) + .build()) + .layer(1, new Deconvolution3D.Builder().activation(afn).kernelSize(kernel) + .stride(stride).nOut(dOut).hasBias(false) + .convolutionMode(mode).dataFormat(df) + .build()) + .layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) + .activation(Activation.SOFTMAX).nOut(finalNOut).build()) + .setInputType(InputType.convolutional3D(df, depth, height, width, convNIn)).build(); + + String json = conf.toJson(); + MultiLayerConfiguration c2 = MultiLayerConfiguration.fromJson(json); + assertEquals(conf, c2); + + MultiLayerNetwork net = new MultiLayerNetwork(conf); + net.init(); + + String msg = "DataFormat = " + df + ", minibatch size = " + miniBatchSize + ", activationFn=" + afn + + ", kernel = " + Arrays.toString(kernel) + ", stride = " + + Arrays.toString(stride) + ", mode = " + mode.toString() + + ", input depth " + depth + ", input height " + height + + ", input width " + width; + + if (PRINT_RESULTS) { + log.info(msg); + } + + boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(input) + .labels(labels).subset(true).maxPerParam(64)); + + assertTrue(msg, gradOK); + + TestUtils.testModelSerialization(net); + } + } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNNGradientCheckTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNNGradientCheckTest.java index decb81bb0..329adbc9b 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNNGradientCheckTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNNGradientCheckTest.java @@ -122,8 +122,8 @@ public class CNNGradientCheckTest extends BaseDL4JTest { if (PRINT_RESULTS) { System.out.println(name + " - activationFn=" + afn + ", lossFn=" + lf + ", outputActivation=" + outputActivation + ", doLearningFirst=" + doLearningFirst); - for (int j = 0; j < mln.getnLayers(); j++) - System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); +// for (int j = 0; j < mln.getnLayers(); j++) +// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); } boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, @@ -213,8 +213,8 @@ public class CNNGradientCheckTest extends BaseDL4JTest { System.out.println(testName + "- activationFn=" + afn + ", lossFn=" + lf + ", outputActivation=" + outputActivation + ", doLearningFirst=" + doLearningFirst); - for (int j = 0; j < mln.getnLayers(); j++) - System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); +// for (int j = 0; j < mln.getnLayers(); j++) +// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); } boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, @@ -275,8 +275,8 @@ public class CNNGradientCheckTest extends BaseDL4JTest { if (PRINT_RESULTS) { System.out.println(msg); - for (int j = 0; j < net.getnLayers(); j++) - System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams()); +// for (int j = 0; j < net.getnLayers(); j++) +// System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams()); } boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels); @@ -336,8 +336,8 @@ public class CNNGradientCheckTest extends BaseDL4JTest { if (PRINT_RESULTS) { System.out.println(msg); - for (int j = 0; j < net.getnLayers(); j++) - System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams()); +// for (int j = 0; j < net.getnLayers(); j++) +// System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams()); } boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels); @@ -346,8 +346,8 @@ public class CNNGradientCheckTest extends BaseDL4JTest { //Also check compgraph: ComputationGraph cg = net.toComputationGraph(); - gradOK = GradientCheckUtil.checkGradients(cg, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, new INDArray[]{input}, new INDArray[]{labels}); + gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(cg).inputs(new INDArray[]{input}) + .labels(new INDArray[]{labels})); assertTrue(msg + " - compgraph", gradOK); TestUtils.testModelSerialization(net); @@ -399,8 +399,8 @@ public class CNNGradientCheckTest extends BaseDL4JTest { if (PRINT_RESULTS) { System.out.println(msg); - for (int j = 0; j < net.getnLayers(); j++) - System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams()); +// for (int j = 0; j < net.getnLayers(); j++) +// System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams()); } boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, @@ -468,8 +468,8 @@ public class CNNGradientCheckTest extends BaseDL4JTest { if (PRINT_RESULTS) { System.out.println(msg); - for (int j = 0; j < net.getnLayers(); j++) - System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams()); +// for (int j = 0; j < net.getnLayers(); j++) +// System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams()); } boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, @@ -557,60 +557,52 @@ public class CNNGradientCheckTest extends BaseDL4JTest { int[] minibatchSizes = {2}; int width = 5; int height = 5; - int[] inputDepths = {1, 2, 4}; - - Activation[] activations = {Activation.SIGMOID, Activation.TANH}; Nd4j.getRandom().setSeed(12345); - for (int inputDepth : inputDepths) { - for (Activation afn : activations) { - for (int minibatchSize : minibatchSizes) { - INDArray input = Nd4j.rand(minibatchSize, width * height * inputDepth); - INDArray labels = Nd4j.zeros(minibatchSize, nOut); - for (int i = 0; i < minibatchSize; i++) { - labels.putScalar(new int[]{i, i % nOut}, 1.0); - } + int[] inputDepths = new int[]{1, 2, 4}; + Activation[] activations = {Activation.SIGMOID, Activation.TANH, Activation.SOFTPLUS}; + int[] minibatch = {2, 1, 3}; - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).updater(new NoOp()) - .dataType(DataType.DOUBLE) - .activation(afn) - .list() - .layer(0, new ConvolutionLayer.Builder().kernelSize(2, 2).stride(1, 1) - .padding(0, 0).nIn(inputDepth).nOut(2).build())//output: (5-2+0)/1+1 = 4 - .layer(1, new LocallyConnected2D.Builder().nIn(2).nOut(7).kernelSize(2, 2) - .setInputSize(4, 4).convolutionMode(ConvolutionMode.Strict).hasBias(false) - .stride(1, 1).padding(0, 0).build()) //(4-2+0)/1+1 = 3 - .layer(2, new ConvolutionLayer.Builder().nIn(7).nOut(2).kernelSize(2, 2) - .stride(1, 1).padding(0, 0).build()) //(3-2+0)/1+1 = 2 - .layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) - .activation(Activation.SOFTMAX).nIn(2 * 2 * 2).nOut(nOut) - .build()) - .setInputType(InputType.convolutionalFlat(height, width, inputDepth)).build(); + for( int i=0; i() { + @Override + public void accept(ComputationGraph net) { + Nd4j.getRandom().setSeed(12345); + } + })); - assertTrue(ok); + assertTrue(gradOK); } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GlobalPoolingGradientCheckTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GlobalPoolingGradientCheckTests.java index b141323a9..d4e3d3089 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GlobalPoolingGradientCheckTests.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GlobalPoolingGradientCheckTests.java @@ -92,8 +92,8 @@ public class GlobalPoolingGradientCheckTests extends BaseDL4JTest { if (PRINT_RESULTS) { System.out.println("testLSTMGlobalPoolingBasicMultiLayer() - " + pt + ", minibatch = " + miniBatchSize); - for (int j = 0; j < mln.getnLayers(); j++) - System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); +// for (int j = 0; j < mln.getnLayers(); j++) +// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); } boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, @@ -150,8 +150,8 @@ public class GlobalPoolingGradientCheckTests extends BaseDL4JTest { if (PRINT_RESULTS) { System.out.println( "testCnnGlobalPoolingBasicMultiLayer() - " + pt + ", minibatch = " + miniBatchSize); - for (int j = 0; j < mln.getnLayers(); j++) - System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); +// for (int j = 0; j < mln.getnLayers(); j++) +// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); } boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, @@ -209,12 +209,12 @@ public class GlobalPoolingGradientCheckTests extends BaseDL4JTest { if (PRINT_RESULTS) { System.out.println("testLSTMGlobalPoolingBasicMultiLayer() - " + pt + ", minibatch = " + miniBatchSize); - for (int j = 0; j < mln.getnLayers(); j++) - System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); +// for (int j = 0; j < mln.getnLayers(); j++) +// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); } - boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels, featuresMask, null); + boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(mln).input(input) + .labels(labels).inputMask(featuresMask)); assertTrue(gradOK); TestUtils.testModelSerialization(mln); @@ -292,12 +292,12 @@ public class GlobalPoolingGradientCheckTests extends BaseDL4JTest { if (PRINT_RESULTS) { System.out.println("testCnnGlobalPoolingBasicMultiLayer() - " + pt + ", minibatch = " + miniBatchSize); - for (int j = 0; j < mln.getnLayers(); j++) - System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); +// for (int j = 0; j < mln.getnLayers(); j++) +// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); } - boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels, inputMask, null); + boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(mln).input(input) + .labels(labels).inputMask(inputMask)); assertTrue(gradOK); TestUtils.testModelSerialization(mln); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTests.java index d506bb233..cd3e1d2e3 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTests.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTests.java @@ -120,8 +120,8 @@ public class GradientCheckTests extends BaseDL4JTest { System.out.println("testMinibatchApplication() - activationFn=" + afn + ", lossFn=" + lf + ", outputActivation=" + outputActivation + ", doLearningFirst=" + doLearningFirst); - for (int j = 0; j < mln.getnLayers(); j++) - System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); +// for (int j = 0; j < mln.getnLayers(); j++) +// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); } boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, @@ -200,8 +200,8 @@ public class GradientCheckTests extends BaseDL4JTest { System.out.println("testGradientMLP2LayerIrisSimpleRandom() - activationFn=" + afn + ", lossFn=" + lf + ", outputActivation=" + outputActivation + ", doLearningFirst=" + doLearningFirst); - for (int j = 0; j < mln.getnLayers(); j++) - System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); +// for (int j = 0; j < mln.getnLayers(); j++) +// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); } boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, @@ -294,8 +294,8 @@ public class GradientCheckTests extends BaseDL4JTest { System.out.println("testGradientMLP2LayerIrisSimpleRandom() - activationFn=" + afn + ", lossFn=" + lf + ", outputActivation=" + outputActivation + ", doLearningFirst=" + doLearningFirst + ", l2=" + l2 + ", l1=" + l1); - for (int j = 0; j < mln.getnLayers(); j++) - System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); +// for (int j = 0; j < mln.getnLayers(); j++) +// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); } boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, @@ -339,8 +339,8 @@ public class GradientCheckTests extends BaseDL4JTest { if (PRINT_RESULTS) { System.out.println("testEmbeddingLayerSimple"); - for (int j = 0; j < mln.getnLayers(); j++) - System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); +// for (int j = 0; j < mln.getnLayers(); j++) +// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); } boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, @@ -379,8 +379,8 @@ public class GradientCheckTests extends BaseDL4JTest { if (PRINT_RESULTS) { System.out.println("testEmbeddingLayerSimple"); - for (int j = 0; j < mln.getnLayers(); j++) - System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); +// for (int j = 0; j < mln.getnLayers(); j++) +// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); } boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, @@ -469,8 +469,8 @@ public class GradientCheckTests extends BaseDL4JTest { + doLearningFirst + ", l2=" + l2 + ", l1=" + l1; if (PRINT_RESULTS) { System.out.println(msg); - for (int j = 0; j < mln.getnLayers(); j++) - System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); +// for (int j = 0; j < mln.getnLayers(); j++) +// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); } boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, @@ -539,8 +539,8 @@ public class GradientCheckTests extends BaseDL4JTest { // expectation in case linear regression(with only element wise multiplication layer): large weight for the fourth weight log.info("params after learning: " + netGraph.getLayer(1).paramTable()); - boolean gradOK = checkGradients(netGraph, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, new INDArray[]{features}, new INDArray[]{labels}); + boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(netGraph).inputs(new INDArray[]{features}) + .labels(new INDArray[]{labels})); msg = "elementWiseMultiplicationLayerTest() - activationFn=" + "ID" + ", lossFn=" + "Cos-sim" + ", outputActivation=" + "Id" + ", doLearningFirst=" + "true"; @@ -592,8 +592,8 @@ public class GradientCheckTests extends BaseDL4JTest { } String msg = "mask=" + maskArray + ", inputRank=" + inputRank; - boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, in, label, fMask, null); + boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(in) + .labels(label).inputMask(fMask)); assertTrue(msg, gradOK); TestUtils.testModelSerialization(net); @@ -767,8 +767,8 @@ public class GradientCheckTests extends BaseDL4JTest { System.out.println("testGradientMLP2LayerIrisSimpleRandom() - activationFn=" + afn + ", lossFn=" + lf + ", outputActivation=" + outputActivation + ", doLearningFirst=" + doLearningFirst + ", layerNorm=" + layerNorm); - for (int j = 0; j < mln.getnLayers(); j++) - System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); +// for (int j = 0; j < mln.getnLayers(); j++) +// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); } boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTestsComputationGraph.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTestsComputationGraph.java index 623158c68..b702520e4 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTestsComputationGraph.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTestsComputationGraph.java @@ -103,13 +103,12 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { if (PRINT_RESULTS) { System.out.println("testBasicIris()"); - for (int j = 0; j < graph.getNumLayers(); j++) - System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams()); +// for (int j = 0; j < graph.getNumLayers(); j++) +// System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams()); } - boolean gradOK = GradientCheckUtil.checkGradients(graph, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, new INDArray[] {input}, - new INDArray[] {labels}); + boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{input}) + .labels(new INDArray[]{labels})); String msg = "testBasicIris()"; assertTrue(msg, gradOK); @@ -155,13 +154,12 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { if (PRINT_RESULTS) { System.out.println("testBasicIrisWithMerging()"); - for (int j = 0; j < graph.getNumLayers(); j++) - System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams()); +// for (int j = 0; j < graph.getNumLayers(); j++) +// System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams()); } - boolean gradOK = GradientCheckUtil.checkGradients(graph, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, new INDArray[] {input}, - new INDArray[] {labels}); + boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{input}) + .labels(new INDArray[]{labels})); String msg = "testBasicIrisWithMerging()"; assertTrue(msg, gradOK); @@ -213,13 +211,12 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { if (PRINT_RESULTS) { System.out.println("testBasicIrisWithElementWiseVertex(op=" + op + ")"); - for (int j = 0; j < graph.getNumLayers(); j++) - System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams()); +// for (int j = 0; j < graph.getNumLayers(); j++) +// System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams()); } - boolean gradOK = GradientCheckUtil.checkGradients(graph, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, new INDArray[] {input}, - new INDArray[] {labels}); + boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{input}) + .labels(new INDArray[]{labels})); String msg = "testBasicIrisWithElementWiseVertex(op=" + op + ")"; assertTrue(msg, gradOK); @@ -274,13 +271,12 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { if (PRINT_RESULTS) { System.out.println("testBasicIrisWithElementWiseVertex(op=" + op + ")"); - for (int j = 0; j < graph.getNumLayers(); j++) - System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams()); +// for (int j = 0; j < graph.getNumLayers(); j++) +// System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams()); } - boolean gradOK = GradientCheckUtil.checkGradients(graph, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, new INDArray[] {input}, - new INDArray[] {labels}); + boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{input}) + .labels(new INDArray[]{labels})); String msg = "testBasicIrisWithElementWiseVertex(op=" + op + ")"; assertTrue(msg, gradOK); @@ -328,9 +324,8 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { graph.fit(new DataSet(in, labels)); - boolean gradOK = GradientCheckUtil.checkGradients(graph, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, new INDArray[]{in}, - new INDArray[]{labels}); + boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{in}) + .labels(new INDArray[]{labels})); assertTrue(msg, gradOK); TestUtils.testModelSerialization(graph); } @@ -372,13 +367,12 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { if (PRINT_RESULTS) { System.out.println("testCnnDepthMerge()"); - for (int j = 0; j < graph.getNumLayers(); j++) - System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams()); +// for (int j = 0; j < graph.getNumLayers(); j++) +// System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams()); } - boolean gradOK = GradientCheckUtil.checkGradients(graph, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, new INDArray[] {input}, - new INDArray[] {labels}); + boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{input}) + .labels(new INDArray[]{labels})); String msg = "testCnnDepthMerge()"; assertTrue(msg, gradOK); @@ -430,13 +424,12 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { if (PRINT_RESULTS) { System.out.println("testLSTMWithMerging()"); - for (int j = 0; j < graph.getNumLayers(); j++) - System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams()); +// for (int j = 0; j < graph.getNumLayers(); j++) +// System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams()); } - boolean gradOK = GradientCheckUtil.checkGradients(graph, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, new INDArray[] {input}, - new INDArray[] {labels}); + boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{input}) + .labels(new INDArray[]{labels})); String msg = "testLSTMWithMerging()"; assertTrue(msg, gradOK); @@ -466,13 +459,12 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { if (PRINT_RESULTS) { System.out.println("testLSTMWithSubset()"); - for (int j = 0; j < graph.getNumLayers(); j++) - System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams()); +// for (int j = 0; j < graph.getNumLayers(); j++) +// System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams()); } - boolean gradOK = GradientCheckUtil.checkGradients(graph, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, new INDArray[] {input}, - new INDArray[] {labels}); + boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{input}) + .labels(new INDArray[]{labels})); String msg = "testLSTMWithSubset()"; assertTrue(msg, gradOK); @@ -504,26 +496,24 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { if (PRINT_RESULTS) { System.out.println("testLSTMWithLastTimeStepVertex()"); - for (int j = 0; j < graph.getNumLayers(); j++) - System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams()); +// for (int j = 0; j < graph.getNumLayers(); j++) +// System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams()); } //First: test with no input mask array - boolean gradOK = GradientCheckUtil.checkGradients(graph, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, new INDArray[] {input}, - new INDArray[] {labels}); + boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{input}) + .labels(new INDArray[]{labels})); String msg = "testLSTMWithLastTimeStepVertex()"; assertTrue(msg, gradOK); //Second: test with input mask arrays. - INDArray inMask = Nd4j.zeros(3, 5); - inMask.putRow(0, Nd4j.create(new double[] {1, 1, 1, 0, 0})); - inMask.putRow(1, Nd4j.create(new double[] {1, 1, 1, 1, 0})); - inMask.putRow(2, Nd4j.create(new double[] {1, 1, 1, 1, 1})); - graph.setLayerMaskArrays(new INDArray[] {inMask}, null); - gradOK = GradientCheckUtil.checkGradients(graph, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, DEFAULT_MIN_ABS_ERROR, - PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, new INDArray[] {input}, new INDArray[] {labels}); + INDArray inMask = Nd4j.zeros(3, 4); + inMask.putRow(0, Nd4j.create(new double[] {1, 1, 0, 0})); + inMask.putRow(1, Nd4j.create(new double[] {1, 1, 1, 0})); + inMask.putRow(2, Nd4j.create(new double[] {1, 1, 1, 1})); + gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{input}) + .labels(new INDArray[]{labels}).inputMask(new INDArray[]{inMask})); assertTrue(msg, gradOK); TestUtils.testModelSerialization(graph); @@ -566,13 +556,12 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { if (PRINT_RESULTS) { System.out.println("testLSTMWithDuplicateToTimeSeries()"); - for (int j = 0; j < graph.getNumLayers(); j++) - System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams()); +// for (int j = 0; j < graph.getNumLayers(); j++) +// System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams()); } - boolean gradOK = GradientCheckUtil.checkGradients(graph, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, new INDArray[] {input1, input2}, - new INDArray[] {labels}); + boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{input1, input2}) + .labels(new INDArray[]{labels})); String msg = "testLSTMWithDuplicateToTimeSeries()"; assertTrue(msg, gradOK); @@ -615,13 +604,12 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { if (PRINT_RESULTS) { System.out.println("testLSTMWithReverseTimeSeriesVertex()"); - for (int j = 0; j < graph.getNumLayers(); j++) - System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams()); +// for (int j = 0; j < graph.getNumLayers(); j++) +// System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams()); } - boolean gradOK = GradientCheckUtil.checkGradients(graph, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, new INDArray[] {input}, - new INDArray[] {labels}); + boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{input}) + .labels(new INDArray[]{labels})); String msg = "testLSTMWithDuplicateToTimeSeries()"; assertTrue(msg, gradOK); @@ -632,8 +620,8 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { inMask.putRow(1, Nd4j.create(new double[] {1, 1, 0, 1, 0})); inMask.putRow(2, Nd4j.create(new double[] {1, 1, 1, 1, 1})); graph.setLayerMaskArrays(new INDArray[] {inMask}, null); - gradOK = GradientCheckUtil.checkGradients(graph, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, DEFAULT_MIN_ABS_ERROR, - PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, new INDArray[] {input}, new INDArray[] {labels}); + gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{input}) + .labels(new INDArray[]{labels})); assertTrue(msg, gradOK); TestUtils.testModelSerialization(graph); @@ -671,13 +659,12 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { String msg = "testMultipleInputsLayer() - minibatchSize = " + mb; if (PRINT_RESULTS) { System.out.println(msg); - for (int j = 0; j < graph.getNumLayers(); j++) - System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams()); +// for (int j = 0; j < graph.getNumLayers(); j++) +// System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams()); } - boolean gradOK = GradientCheckUtil.checkGradients(graph, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, inputs, - new INDArray[] {out}); + boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(inputs) + .labels(new INDArray[]{out})); assertTrue(msg, gradOK); TestUtils.testModelSerialization(graph); @@ -712,13 +699,12 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { String msg = "testMultipleOutputsLayer() - minibatchSize = " + mb; if (PRINT_RESULTS) { System.out.println(msg); - for (int j = 0; j < graph.getNumLayers(); j++) - System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams()); +// for (int j = 0; j < graph.getNumLayers(); j++) +// System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams()); } - boolean gradOK = GradientCheckUtil.checkGradients(graph, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, new INDArray[] {input}, - new INDArray[] {out}); + boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{input}) + .labels(new INDArray[]{out})); assertTrue(msg, gradOK); TestUtils.testModelSerialization(graph); @@ -759,12 +745,12 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { String msg = "testMultipleOutputsMergeVertex() - minibatchSize = " + mb; if (PRINT_RESULTS) { System.out.println(msg); - for (int j = 0; j < graph.getNumLayers(); j++) - System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams()); +// for (int j = 0; j < graph.getNumLayers(); j++) +// System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams()); } - boolean gradOK = GradientCheckUtil.checkGradients(graph, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, new INDArray[] {out}); + boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(input) + .labels(new INDArray[]{out})); assertTrue(msg, gradOK); TestUtils.testModelSerialization(graph); @@ -810,13 +796,12 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { String msg = "testMultipleOutputsMergeVertex() - minibatchSize = " + mb; if (PRINT_RESULTS) { System.out.println(msg); - for (int j = 0; j < graph.getNumLayers(); j++) - System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams()); +// for (int j = 0; j < graph.getNumLayers(); j++) +// System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams()); } - boolean gradOK = GradientCheckUtil.checkGradients(graph, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, new INDArray[] {input}, - new INDArray[] {out}); + boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{input}) + .labels(new INDArray[]{out})); assertTrue(msg, gradOK); TestUtils.testModelSerialization(graph); @@ -873,19 +858,18 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { Map out = graph.feedForward(new INDArray[] {pos, anc, neg}, true); - for (String s : out.keySet()) { - System.out.println(s + "\t" + Arrays.toString(out.get(s).shape())); - } +// for (String s : out.keySet()) { +// System.out.println(s + "\t" + Arrays.toString(out.get(s).shape())); +// } if (PRINT_RESULTS) { System.out.println("testBasicIrisTripletStackingL2Loss()"); - for (int j = 0; j < graph.getNumLayers(); j++) - System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams()); +// for (int j = 0; j < graph.getNumLayers(); j++) +// System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams()); } - boolean gradOK = GradientCheckUtil.checkGradients(graph, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, new INDArray[] {pos, anc, neg}, - new INDArray[] {labels}); + boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{pos, anc, neg}) + .labels(new INDArray[]{labels})); String msg = "testBasicIrisTripletStackingL2Loss()"; assertTrue(msg, gradOK); @@ -941,13 +925,12 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { String msg = "testBasicCenterLoss() - lambda = " + lambda + ", trainFirst = " + train; if (PRINT_RESULTS) { System.out.println(msg); - for (int j = 0; j < graph.getNumLayers(); j++) - System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams()); +// for (int j = 0; j < graph.getNumLayers(); j++) +// System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams()); } - boolean gradOK = GradientCheckUtil.checkGradients(graph, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, new INDArray[] {example}, - new INDArray[] {labels}); + boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{example}) + .labels(new INDArray[]{labels})); assertTrue(msg, gradOK); TestUtils.testModelSerialization(graph); @@ -1007,8 +990,8 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { String msg = "testBasicCenterLoss() - trainFirst = " + train; if (PRINT_RESULTS) { System.out.println(msg); - for (int j = 0; j < net.getnLayers(); j++) - System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams()); +// for (int j = 0; j < net.getnLayers(); j++) +// System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams()); } boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, @@ -1056,13 +1039,12 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { if (PRINT_RESULTS) { System.out.println(testName); - for (int j = 0; j < graph.getNumLayers(); j++) - System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams()); +// for (int j = 0; j < graph.getNumLayers(); j++) +// System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams()); } - boolean gradOK = GradientCheckUtil.checkGradients(graph, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, new INDArray[] {in1, in2}, - new INDArray[] {labels}); + boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{in1, in2}) + .labels(new INDArray[]{labels})); assertTrue(testName, gradOK); TestUtils.testModelSerialization(graph); @@ -1115,13 +1097,12 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { if (PRINT_RESULTS) { System.out.println(testName); - for (int j = 0; j < graph.getNumLayers(); j++) - System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams()); +// for (int j = 0; j < graph.getNumLayers(); j++) +// System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams()); } - boolean gradOK = GradientCheckUtil.checkGradients(graph, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, new INDArray[] {in1, in2}, - new INDArray[] {labels1, labels2}); + boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{in1, in2}) + .labels(new INDArray[]{labels1, labels2})); assertTrue(testName, gradOK); TestUtils.testModelSerialization(graph); @@ -1174,13 +1155,12 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { if (PRINT_RESULTS) { System.out.println(testName); - for (int j = 0; j < graph.getNumLayers(); j++) - System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams()); +// for (int j = 0; j < graph.getNumLayers(); j++) +// System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams()); } - boolean gradOK = GradientCheckUtil.checkGradients(graph, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, new INDArray[] {in1, in2}, - new INDArray[] {labels1, labels2}); + boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{in1, in2}) + .labels(new INDArray[]{labels1, labels2})); assertTrue(testName, gradOK); TestUtils.testModelSerialization(graph); @@ -1238,15 +1218,14 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { if (PRINT_RESULTS) { System.out.println(testName); - for (int j = 0; j < graph.getNumLayers(); j++) - System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams()); +// for (int j = 0; j < graph.getNumLayers(); j++) +// System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams()); } graph.setLayerMaskArrays(new INDArray[] {inMask1, inMask2}, null); - boolean gradOK = GradientCheckUtil.checkGradients(graph, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, new INDArray[] {in1, in2}, - new INDArray[] {labels1, labels2}, new INDArray[] {inMask1, inMask2}, null); + boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{in1, in2}) + .labels(new INDArray[]{labels1, labels2}).inputMask(new INDArray[]{inMask1, inMask2})); assertTrue(testName, gradOK); TestUtils.testModelSerialization(graph); @@ -1298,13 +1277,12 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { if (PRINT_RESULTS) { System.out.println(testName); - for (int j = 0; j < graph.getNumLayers(); j++) - System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams()); +// for (int j = 0; j < graph.getNumLayers(); j++) +// System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams()); } - boolean gradOK = GradientCheckUtil.checkGradients(graph, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, new INDArray[] {in1, in2}, - new INDArray[] {labels1, labels2}); + boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{in1, in2}) + .labels(new INDArray[]{labels1, labels2})); assertTrue(testName, gradOK); TestUtils.testModelSerialization(graph); } @@ -1341,13 +1319,12 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { if (PRINT_RESULTS) { System.out.println(testName); - for (int j = 0; j < graph.getNumLayers(); j++) - System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams()); +// for (int j = 0; j < graph.getNumLayers(); j++) +// System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams()); } - boolean gradOK = GradientCheckUtil.checkGradients(graph, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, new INDArray[] {in1}, - new INDArray[] {labels1}); + boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{in1}) + .labels(new INDArray[]{labels1})); assertTrue(testName, gradOK); TestUtils.testModelSerialization(graph); @@ -1391,13 +1368,12 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { if (PRINT_RESULTS) { System.out.println(testName); - for (int j = 0; j < graph.getNumLayers(); j++) - System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams()); +// for (int j = 0; j < graph.getNumLayers(); j++) +// System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams()); } - boolean gradOK = GradientCheckUtil.checkGradients(graph, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, new INDArray[] {in1}, - new INDArray[] {labels1}); + boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{in1}) + .labels(new INDArray[]{labels1})); assertTrue(testName, gradOK); TestUtils.testModelSerialization(graph); @@ -1430,12 +1406,12 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { if (PRINT_RESULTS) { System.out.println("testGraphEmbeddingLayerSimple"); - for (int j = 0; j < cg.getNumLayers(); j++) - System.out.println("Layer " + j + " # params: " + cg.getLayer(j).numParams()); +// for (int j = 0; j < cg.getNumLayers(); j++) +// System.out.println("Layer " + j + " # params: " + cg.getLayer(j).numParams()); } - boolean gradOK = GradientCheckUtil.checkGradients(cg, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, DEFAULT_MIN_ABS_ERROR, - PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, new INDArray[] {input}, new INDArray[] {labels}); + boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(cg).inputs(new INDArray[]{input}) + .labels(new INDArray[]{labels})); String msg = "testGraphEmbeddingLayerSimple"; assertTrue(msg, gradOK); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTestsMasking.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTestsMasking.java index c1e97a385..c0f5a2573 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTestsMasking.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTestsMasking.java @@ -51,10 +51,6 @@ import static org.nd4j.linalg.indexing.NDArrayIndex.*; public class GradientCheckTestsMasking extends BaseDL4JTest { private static final boolean PRINT_RESULTS = true; - private static final boolean RETURN_ON_FIRST_FAILURE = false; - private static final double DEFAULT_EPS = 1e-6; - private static final double DEFAULT_MAX_REL_ERROR = 1e-3; - private static final double DEFAULT_MIN_ABS_ERROR = 1e-7; static { Nd4j.setDataType(DataType.DOUBLE); @@ -130,8 +126,8 @@ public class GradientCheckTestsMasking extends BaseDL4JTest { MultiLayerNetwork mln = new MultiLayerNetwork(conf); mln.init(); - boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels, null, maskArr); + boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(mln).input(input) + .labels(labels).labelMask(maskArr)); String msg = "gradientCheckMaskingOutputSimple() - timeSeriesLength=" + timeSeriesLength + ", miniBatchSize=" + 1; @@ -146,9 +142,9 @@ public class GradientCheckTestsMasking extends BaseDL4JTest { Nd4j.getRandom().setSeed(12345L); int timeSeriesLength = 5; - int nIn = 5; + int nIn = 3; int layerSize = 3; - int nOut = 3; + int nOut = 2; int miniBatchSize = 2; @@ -174,24 +170,16 @@ public class GradientCheckTestsMasking extends BaseDL4JTest { MultiLayerNetwork mln = new MultiLayerNetwork(conf); mln.init(); - Random r = new Random(12345L); INDArray input = Nd4j.rand(new int[]{miniBatchSize, nIn, timeSeriesLength}, 'f').subi(0.5); - INDArray labels = Nd4j.zeros(miniBatchSize, nOut, timeSeriesLength); - for (int i = 0; i < miniBatchSize; i++) { - for (int j = 0; j < nIn; j++) { - labels.putScalar(i, r.nextInt(nOut), j, 1.0); - } - } + INDArray labels = TestUtils.randomOneHotTimeSeries(miniBatchSize, nOut, timeSeriesLength); if (PRINT_RESULTS) { System.out.println("testBidirectionalLSTMMasking() - testNum = " + testNum++); - for (int j = 0; j < mln.getnLayers(); j++) - System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); } - boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels, mask, mask, true, 16); + boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(mln).input(input) + .labels(labels).inputMask(mask).labelMask(mask).subset(true).maxPerParam(12)); assertTrue(gradOK); TestUtils.testModelSerialization(mln); @@ -271,8 +259,8 @@ public class GradientCheckTestsMasking extends BaseDL4JTest { System.out.println(msg); - boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, features, labels, null, labelMask); + boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(features) + .labels(labels).labelMask(labelMask)); assertTrue(msg, gradOK); TestUtils.testModelSerialization(net); @@ -366,8 +354,8 @@ public class GradientCheckTestsMasking extends BaseDL4JTest { System.out.println(msg); - boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, features, labels, null, labelMask); + boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(features) + .labels(labels).labelMask(labelMask)); assertTrue(msg, gradOK); @@ -387,9 +375,8 @@ public class GradientCheckTestsMasking extends BaseDL4JTest { ComputationGraph graph = new ComputationGraph(cg); graph.init(); - gradOK = GradientCheckUtil.checkGradients(graph, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, - new INDArray[] {features}, new INDArray[] {labels}, null, new INDArray[]{labelMask}, null); + gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{features}) + .labels(new INDArray[]{labels}).labelMask(new INDArray[]{labelMask})); assertTrue(msg + " (compgraph)", gradOK); TestUtils.testModelSerialization(graph); @@ -425,8 +412,8 @@ public class GradientCheckTestsMasking extends BaseDL4JTest { assertTrue(lm.sumNumber().intValue() > 0); - boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, f, l, null, lm); + boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(f) + .labels(l).labelMask(lm)); assertTrue(gradOK); //Also ensure score doesn't depend on masked feature or label values @@ -478,9 +465,8 @@ public class GradientCheckTestsMasking extends BaseDL4JTest { assertTrue(lm.sumNumber().intValue() > 0); - boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, new INDArray[]{f}, new INDArray[]{l}, - null, new INDArray[]{lm}); + boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(net).inputs(new INDArray[]{f}) + .labels(new INDArray[]{l}).labelMask(new INDArray[]{lm})); assertTrue(gradOK); //Also ensure score doesn't depend on masked feature or label values diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/LRNGradientCheckTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/LRNGradientCheckTests.java index 18fbcce45..a2bb1989c 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/LRNGradientCheckTests.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/LRNGradientCheckTests.java @@ -82,10 +82,10 @@ public class LRNGradientCheckTests extends BaseDL4JTest { MultiLayerNetwork mln = new MultiLayerNetwork(builder.build()); mln.init(); - if (PRINT_RESULTS) { - for (int j = 0; j < mln.getnLayers(); j++) - System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); - } +// if (PRINT_RESULTS) { +// for (int j = 0; j < mln.getnLayers(); j++) +// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); +// } boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/LSTMGradientCheckTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/LSTMGradientCheckTests.java index caa52f4c9..1e673b936 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/LSTMGradientCheckTests.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/LSTMGradientCheckTests.java @@ -124,8 +124,8 @@ public class LSTMGradientCheckTests extends BaseDL4JTest { String testName = "testLSTMBasic(" + (graves ? "GravesLSTM" : "LSTM") + ")"; if (PRINT_RESULTS) { System.out.println(testName); - for (int j = 0; j < mln.getnLayers(); j++) - System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); +// for (int j = 0; j < mln.getnLayers(); j++) +// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); } boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, @@ -213,12 +213,12 @@ public class LSTMGradientCheckTests extends BaseDL4JTest { + outputActivation + ", l2=" + l2 + ", l1=" + l1; if (PRINT_RESULTS) { System.out.println(testName); - for (int j = 0; j < mln.getnLayers(); j++) - System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); +// for (int j = 0; j < mln.getnLayers(); j++) +// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); } - boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels, null, null, true, 128); + boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(mln).input(input) + .labels(labels).subset(true).maxPerParam(128)); assertTrue(testName, gradOK); TestUtils.testModelSerialization(mln); @@ -341,8 +341,8 @@ public class LSTMGradientCheckTests extends BaseDL4JTest { System.out.println("testGradientGravesBidirectionalLSTMFull() - activationFn=" + afn + ", lossFn=" + lf + ", outputActivation=" + outputActivation + ", l2=" + l2 + ", l1=" + l1); - for (int j = 0; j < mln.getnLayers(); j++) - System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); +// for (int j = 0; j < mln.getnLayers(); j++) +// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); } boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, @@ -394,8 +394,8 @@ public class LSTMGradientCheckTests extends BaseDL4JTest { MultiLayerNetwork mln = new MultiLayerNetwork(conf); mln.init(); - boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels, null, null, true, 128); + boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(mln).input(input) + .labels(labels).subset(true).maxPerParam(128)); String msg = "testGradientGravesLSTMEdgeCases() - timeSeriesLength=" + timeSeriesLength[i] + ", miniBatchSize=" + miniBatchSize[i]; @@ -452,8 +452,8 @@ public class LSTMGradientCheckTests extends BaseDL4JTest { System.out.println("layer " + i + "\t" + mln.getLayer(i).numParams()); } - boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels, null, null, true, 32); + boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(mln).input(input) + .labels(labels).subset(true).maxPerParam(32)); assertTrue(gradOK); TestUtils.testModelSerialization(mln); } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/LossFunctionGradientCheck.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/LossFunctionGradientCheck.java index fa06ff8f7..632a85e22 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/LossFunctionGradientCheck.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/LossFunctionGradientCheck.java @@ -206,21 +206,19 @@ public class LossFunctionGradientCheck extends BaseDL4JTest { } else { failed.add(testName); } - - System.out.println("\n\n"); TestUtils.testModelSerialization(net); } } - - System.out.println("---- Passed ----"); - for (String s : passed) { - System.out.println(s); - } - - System.out.println("---- Failed ----"); - for (String s : failed) { - System.out.println(s); + if(failed.size() > 0) { + System.out.println("---- Passed ----"); + for (String s : passed) { + System.out.println(s); + } + System.out.println("---- Failed ----"); + for (String s : failed) { + System.out.println(s); + } } assertEquals("Tests failed", 0, failed.size()); @@ -376,7 +374,6 @@ public class LossFunctionGradientCheck extends BaseDL4JTest { failed.add(testName); } - System.out.println("\n\n"); TestUtils.testModelSerialization(net); } } @@ -684,8 +681,6 @@ public class LossFunctionGradientCheck extends BaseDL4JTest { } else { failed.add(testName); } - - System.out.println("\n\n"); } } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/OutputLayerGradientChecks.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/OutputLayerGradientChecks.java index 32a229101..67fc4c11c 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/OutputLayerGradientChecks.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/OutputLayerGradientChecks.java @@ -136,13 +136,13 @@ public class OutputLayerGradientChecks extends BaseDL4JTest { String testName = "testRnnLossLayer(lf=" + lf + ", maskType=" + mt + ", outputActivation = " + oa + ")"; if (PRINT_RESULTS) { System.out.println(testName); - for (int j = 0; j < mln.getnLayers(); j++) - System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); +// for (int j = 0; j < mln.getnLayers(); j++) +// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); } System.out.println("Starting test: " + testName); - boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels, null, labelMask); + boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(mln).input(input) + .labels(labels).labelMask(labelMask)); assertTrue(testName, gradOK); TestUtils.testModelSerialization(mln); @@ -243,13 +243,13 @@ public class OutputLayerGradientChecks extends BaseDL4JTest { String testName = "testCnnLossLayer(lf=" + lf + ", maskType=" + mt + ", outputActivation = " + oa + ")"; if (PRINT_RESULTS) { System.out.println(testName); - for (int j = 0; j < mln.getnLayers(); j++) - System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); +// for (int j = 0; j < mln.getnLayers(); j++) +// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); } System.out.println("Starting test: " + testName); - boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels, null, labelMask); + boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(mln).input(input) + .labels(labels).labelMask(labelMask)); assertTrue(testName, gradOK); TestUtils.testModelSerialization(mln); @@ -392,13 +392,13 @@ public class OutputLayerGradientChecks extends BaseDL4JTest { String testName = "testCnn3dLossLayer(dataFormat=" + dataFormat + ",lf=" + lf + ", maskType=" + mt + ", outputActivation = " + oa + ")"; if (PRINT_RESULTS) { System.out.println(testName); - for (int j = 0; j < mln.getnLayers(); j++) - System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); +// for (int j = 0; j < mln.getnLayers(); j++) +// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); } System.out.println("Starting test: " + testName); - boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels, null, labelMask); + boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(mln).input(input) + .labels(labels).labelMask(labelMask)); assertTrue(testName, gradOK); TestUtils.testModelSerialization(mln); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/RnnGradientChecks.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/RnnGradientChecks.java index 98385de17..2980cad7c 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/RnnGradientChecks.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/RnnGradientChecks.java @@ -127,8 +127,8 @@ public class RnnGradientChecks extends BaseDL4JTest { net.init(); - boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, in, labels, inMask, null); + boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(in) + .labels(labels).inputMask(inMask)); assertTrue(gradOK); @@ -207,8 +207,8 @@ public class RnnGradientChecks extends BaseDL4JTest { net.init(); - boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, in, labels, inMask, null); + boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(in) + .labels(labels).inputMask(inMask)); assertTrue(gradOK); TestUtils.testModelSerialization(net); } @@ -282,8 +282,8 @@ public class RnnGradientChecks extends BaseDL4JTest { MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, in, labels, inMask, null, true, 16); + boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(in) + .labels(labels).inputMask(inMask).subset(true).maxPerParam(16)); assertTrue(name, gradOK); TestUtils.testModelSerialization(net); } @@ -346,8 +346,8 @@ public class RnnGradientChecks extends BaseDL4JTest { MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, in, labels, inMask, null, true, 16); + boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(in) + .labels(labels).inputMask(inMask).subset(true).maxPerParam(16)); assertTrue(name, gradOK); TestUtils.testModelSerialization(net); } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/UtilLayerGradientChecks.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/UtilLayerGradientChecks.java index 8349b732d..2d889a6a1 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/UtilLayerGradientChecks.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/UtilLayerGradientChecks.java @@ -182,9 +182,9 @@ public class UtilLayerGradientChecks extends BaseDL4JTest { MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - - boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, label, inMask, null); + boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(input) + .minAbsoluteError(1e-7) + .labels(label).inputMask(inMask)); assertTrue(gradOK); TestUtils.testModelSerialization(net); @@ -223,9 +223,8 @@ public class UtilLayerGradientChecks extends BaseDL4JTest { Set excludeParams = new HashSet<>(); excludeParams.addAll(Arrays.asList("1_W", "1_b", "2_W", "2_b")); - boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, in, labels, null, null, - false, -1, excludeParams); + boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(in) + .labels(labels).excludeParams(excludeParams)); assertTrue(gradOK); TestUtils.testModelSerialization(net); @@ -234,9 +233,8 @@ public class UtilLayerGradientChecks extends BaseDL4JTest { //Test ComputationGraph equivalent: ComputationGraph g = net.toComputationGraph(); - boolean gradOKCG = GradientCheckUtil.checkGradients(g, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, new INDArray[]{in}, new INDArray[]{labels}, - null, null, excludeParams); + boolean gradOKCG = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(g).inputs(new INDArray[]{in}) + .labels(new INDArray[]{labels}).excludeParams(excludeParams)); assertTrue(gradOKCG); TestUtils.testModelSerialization(g); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/VaeGradientCheckTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/VaeGradientCheckTests.java index cbf662987..6d1903579 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/VaeGradientCheckTests.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/VaeGradientCheckTests.java @@ -46,7 +46,7 @@ import static org.junit.Assert.assertTrue; */ public class VaeGradientCheckTests extends BaseDL4JTest { - private static final boolean PRINT_RESULTS = true; + private static final boolean PRINT_RESULTS = false; private static final boolean RETURN_ON_FIRST_FAILURE = false; private static final double DEFAULT_EPS = 1e-6; private static final double DEFAULT_MAX_REL_ERROR = 1e-3; @@ -122,8 +122,8 @@ public class VaeGradientCheckTests extends BaseDL4JTest { + Arrays.toString(decoderSizes) + ", l2=" + l2 + ", l1=" + l1; if (PRINT_RESULTS) { System.out.println(msg); - for (int j = 0; j < mln.getnLayers(); j++) - System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); +// for (int j = 0; j < mln.getnLayers(); j++) +// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); } boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, @@ -193,8 +193,8 @@ public class VaeGradientCheckTests extends BaseDL4JTest { + l1; if (PRINT_RESULTS) { System.out.println(msg); - for (int l = 0; l < mln.getnLayers(); l++) - System.out.println("Layer " + l + " # params: " + mln.getLayer(l).numParams()); +// for (int l = 0; l < mln.getnLayers(); l++) +// System.out.println("Layer " + l + " # params: " + mln.getLayer(l).numParams()); } boolean gradOK = GradientCheckUtil.checkGradientsPretrainLayer(layer, DEFAULT_EPS, @@ -281,8 +281,8 @@ public class VaeGradientCheckTests extends BaseDL4JTest { String msg = "testVaePretrainReconstructionDistributions() - " + reconstructionDistributions[i]; if (PRINT_RESULTS) { System.out.println(msg); - for (int j = 0; j < mln.getnLayers(); j++) - System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); +// for (int j = 0; j < mln.getnLayers(); j++) +// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); } boolean gradOK = GradientCheckUtil.checkGradientsPretrainLayer(layer, DEFAULT_EPS, @@ -323,8 +323,8 @@ public class VaeGradientCheckTests extends BaseDL4JTest { String msg = "testVaePretrainMultipleSamples() - numSamples = " + numSamples; if (PRINT_RESULTS) { System.out.println(msg); - for (int j = 0; j < mln.getnLayers(); j++) - System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); +// for (int j = 0; j < mln.getnLayers(); j++) +// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); } boolean gradOK = GradientCheckUtil.checkGradientsPretrainLayer(layer, DEFAULT_EPS, diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/YoloGradientCheckTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/YoloGradientCheckTests.java index 0b95dc3b6..147150aa8 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/YoloGradientCheckTests.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/YoloGradientCheckTests.java @@ -120,8 +120,8 @@ public class YoloGradientCheckTests extends BaseDL4JTest { String msg = "testYoloOutputLayer() - minibatch = " + mb + ", w=" + w + ", h=" + h + ", l1=" + l1[i] + ", l2=" + l2[i]; System.out.println(msg); - boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels, null, null, true, 100); + boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(input) + .labels(labels).subset(true).maxPerParam(100)); assertTrue(msg, gradOK); TestUtils.testModelSerialization(net); @@ -228,8 +228,8 @@ public class YoloGradientCheckTests extends BaseDL4JTest { INDArray f = ds.getFeatures(); INDArray l = ds.getLabels(); - boolean ok = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, f, l, null, null, true, 64); + boolean ok = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(f) + .labels(l).inputMask(null).subset(true).maxPerParam(64)); assertTrue(ok); TestUtils.testModelSerialization(net); diff --git a/deeplearning4j/deeplearning4j-nn/src/test/java/org/deeplearning4j/nn/adapters/ArgmaxAdapterTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/adapters/ArgmaxAdapterTest.java similarity index 100% rename from deeplearning4j/deeplearning4j-nn/src/test/java/org/deeplearning4j/nn/adapters/ArgmaxAdapterTest.java rename to deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/adapters/ArgmaxAdapterTest.java diff --git a/deeplearning4j/deeplearning4j-nn/src/test/java/org/deeplearning4j/nn/adapters/Regression2dAdapterTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/adapters/Regression2dAdapterTest.java similarity index 100% rename from deeplearning4j/deeplearning4j-nn/src/test/java/org/deeplearning4j/nn/adapters/Regression2dAdapterTest.java rename to deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/adapters/Regression2dAdapterTest.java diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/ComputationGraphConfigurationTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/ComputationGraphConfigurationTest.java index 06f4d36b7..cf972bac3 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/ComputationGraphConfigurationTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/ComputationGraphConfigurationTest.java @@ -130,7 +130,7 @@ public class ComputationGraphConfigurationTest extends BaseDL4JTest { .setOutputs("out").build(); String json = conf.toJson(); - System.out.println(json); +// System.out.println(json); ComputationGraphConfiguration conf2 = ComputationGraphConfiguration.fromJson(json); @@ -258,7 +258,7 @@ public class ComputationGraphConfigurationTest extends BaseDL4JTest { .addVertex("test2", new StaticInnerGraphVertex(4, 5), "in").setOutputs("test", "test2").build(); String json = conf.toJson(); - System.out.println(json); +// System.out.println(json); ComputationGraphConfiguration conf2 = ComputationGraphConfiguration.fromJson(json); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/preprocessor/CustomPreprocessorTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/preprocessor/CustomPreprocessorTest.java index 1f6bc9816..8e726b869 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/preprocessor/CustomPreprocessorTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/preprocessor/CustomPreprocessorTest.java @@ -54,7 +54,7 @@ public class CustomPreprocessorTest extends BaseDL4JTest { String json = conf.toJson(); String yaml = conf.toYaml(); - System.out.println(json); +// System.out.println(json); MultiLayerConfiguration confFromJson = MultiLayerConfiguration.fromJson(json); assertEquals(conf, confFromJson); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/dtypes/DTypeTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/dtypes/DTypeTests.java index 013738476..d9da12b62 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/dtypes/DTypeTests.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/dtypes/DTypeTests.java @@ -99,6 +99,11 @@ public class DTypeTests extends BaseDL4JTest { Convolution1D.class //Alias for Convolution1DLayer )); + @Override + public long getTimeoutMilliseconds() { + return 90000L; + } + @AfterClass public static void after() { ImmutableSet info; @@ -545,6 +550,7 @@ public class DTypeTests extends BaseDL4JTest { .layer(new Convolution3D.Builder().kernelSize(2, 2, 2).stride(1, 1, 1).nOut(3).activation(Activation.TANH).build()) .layer(new Convolution3D.Builder().kernelSize(2, 2, 2).stride(1, 1, 1).nOut(3).activation(Activation.TANH).build()) .layer(new Subsampling3DLayer.Builder().poolingType(PoolingType.AVG).kernelSize(2, 2, 2).stride(2, 2, 2).build()) + .layer(new Deconvolution3D.Builder().kernelSize(2,2,2).stride(1,1,1).nIn(3).nOut(3).activation(Activation.TANH).build()) .layer(new Cropping3D.Builder(1, 1, 1, 1, 1, 1).build()) .layer(new ZeroPadding3DLayer.Builder(1, 1, 1, 1, 1, 1).build()) .layer(new ActivationLayer(Activation.LEAKYRELU)) diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/graph/TestComputationGraphNetwork.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/graph/TestComputationGraphNetwork.java index 3e330d248..28bc42983 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/graph/TestComputationGraphNetwork.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/graph/TestComputationGraphNetwork.java @@ -317,7 +317,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { assertEquals(paramsMLN, paramsGraph); } - @Test + @Test(timeout = 300000) public void testIrisFitMultiDataSetIterator() throws Exception { RecordReader rr = new CSVRecordReader(0, ','); @@ -531,28 +531,38 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { ComputationGraph graph = new ComputationGraph(conf1); graph.init(); - System.out.println(graph.summary()); - System.out.println(graph.summary(InputType.feedForward(5))); +// System.out.println(graph.summary()); +// System.out.println(graph.summary(InputType.feedForward(5))); + graph.summary(); + graph.summary(InputType.feedForward(5)); graph = new ComputationGraph(conf2); graph.init(); - System.out.println(graph.summary()); - System.out.println(graph.summary(InputType.recurrent(5))); +// System.out.println(graph.summary()); +// System.out.println(graph.summary(InputType.recurrent(5))); + graph.summary(); + graph.summary(InputType.recurrent(5)); graph = new ComputationGraph(conf3); graph.init(); - System.out.println(graph.summary()); - System.out.println(graph.summary(InputType.convolutional(28, 28, 1))); +// System.out.println(graph.summary()); +// System.out.println(graph.summary(InputType.convolutional(28, 28, 1))); + graph.summary(); + graph.summary(InputType.convolutional(28, 28, 1)); graph = new ComputationGraph(conf4); graph.init(); - System.out.println(graph.summary()); - System.out.println(graph.summary(InputType.convolutional(28, 28, 1), InputType.recurrent(5))); +// System.out.println(graph.summary()); +// System.out.println(graph.summary(InputType.convolutional(28, 28, 1), InputType.recurrent(5))); + graph.summary(); + graph.summary(InputType.convolutional(28, 28, 1), InputType.recurrent(5)); graph = new ComputationGraph(conf5); graph.init(); - System.out.println(graph.summary()); - System.out.println(graph.summary(InputType.convolutional(28, 28, 1))); +// System.out.println(graph.summary()); +// System.out.println(graph.summary(InputType.convolutional(28, 28, 1))); + graph.summary(); + graph.summary(InputType.convolutional(28, 28, 1)); } @Test @@ -753,7 +763,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { int nOut = 3; for(WorkspaceMode ws : WorkspaceMode.values()) { - System.out.println("***** WORKSPACE: " + ws); +// System.out.println("***** WORKSPACE: " + ws); ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() .updater(new Adam(0.01)) @@ -981,7 +991,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { OptimizationAlgorithm.LBFGS}; for (OptimizationAlgorithm oa : oas) { - System.out.println(oa); +// System.out.println(oa); ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().optimizationAlgo(oa).graphBuilder() .addInputs("input") @@ -1065,12 +1075,15 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { ComputationGraph modelToTune = new ComputationGraph(conf); modelToTune.init(); - System.out.println(modelToTune.summary()); +// System.out.println(modelToTune.summary()); + modelToTune.summary(); ComputationGraph modelNow = new TransferLearning.GraphBuilder(modelToTune).setFeatureExtractor("denseCentre2").build(); - System.out.println(modelNow.summary()); - System.out.println(modelNow.summary(InputType.feedForward(10),InputType.feedForward(2))); +// System.out.println(modelNow.summary()); +// System.out.println(modelNow.summary(InputType.feedForward(10),InputType.feedForward(2))); + modelNow.summary(); + modelNow.summary(InputType.feedForward(10),InputType.feedForward(2)); } @Test @@ -1315,9 +1328,12 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { ComputationGraph modelExpectedArch = new ComputationGraph(confForArchitecture); modelExpectedArch.init(); ComputationGraph modelMow = new TransferLearning.GraphBuilder(modelExpectedArch).setFeatureExtractor("layer2").build(); - System.out.println(modelExpectedArch.summary()); - System.out.println(modelMow.summary()); - System.out.println(modelExpectedArch.summary(InputType.recurrent(V_HEIGHT* V_WIDTH* 3))); +// System.out.println(modelExpectedArch.summary()); +// System.out.println(modelMow.summary()); +// System.out.println(modelExpectedArch.summary(InputType.recurrent(V_HEIGHT* V_WIDTH* 3))); + modelExpectedArch.summary(); + modelMow.summary(); + modelExpectedArch.summary(InputType.recurrent(V_HEIGHT* V_WIDTH* 3)); } @Test @@ -2117,8 +2133,8 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { INDArray features = Nd4j.rand(new int[] {dataSize, inputSize}); INDArray labels = Nd4j.rand(new int[] {dataSize, outputSize}); - boolean gradOK = GradientCheckUtil.checkGradients(net, 1e-6, 1e-3, - 1e-8, false, true, new INDArray[]{features}, new INDArray[]{labels}, null, null); + boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(net).inputs(new INDArray[]{features}) + .labels(new INDArray[]{labels})); assertTrue(gradOK); } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/custom/TestCustomActivation.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/custom/TestCustomActivation.java index dee8efac4..69b15951e 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/custom/TestCustomActivation.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/custom/TestCustomActivation.java @@ -53,7 +53,7 @@ public class TestCustomActivation extends BaseDL4JTest { String json = conf.toJson(); String yaml = conf.toYaml(); - System.out.println(json); +// System.out.println(json); MultiLayerConfiguration confFromJson = MultiLayerConfiguration.fromJson(json); assertEquals(conf, confFromJson); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/custom/TestCustomLayers.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/custom/TestCustomLayers.java index f5b131600..5ead0e4b1 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/custom/TestCustomLayers.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/custom/TestCustomLayers.java @@ -64,7 +64,7 @@ public class TestCustomLayers extends BaseDL4JTest { String json = conf.toJson(); String yaml = conf.toYaml(); - System.out.println(json); +// System.out.println(json); MultiLayerConfiguration confFromJson = MultiLayerConfiguration.fromJson(json); assertEquals(conf, confFromJson); @@ -88,7 +88,7 @@ public class TestCustomLayers extends BaseDL4JTest { String json = conf.toJson(); String yaml = conf.toYaml(); - System.out.println(json); +// System.out.println(json); ComputationGraphConfiguration confFromJson = ComputationGraphConfiguration.fromJson(json); assertEquals(conf, confFromJson); @@ -135,7 +135,7 @@ public class TestCustomLayers extends BaseDL4JTest { String json = conf.toJson(); String yaml = conf.toYaml(); - System.out.println(json); +// System.out.println(json); MultiLayerConfiguration confFromJson = MultiLayerConfiguration.fromJson(json); assertEquals(conf, confFromJson); @@ -188,7 +188,7 @@ public class TestCustomLayers extends BaseDL4JTest { String json = conf.toJson(); String yaml = conf.toYaml(); - System.out.println(json); +// System.out.println(json); ComputationGraphConfiguration confFromJson = ComputationGraphConfiguration.fromJson(json); assertEquals(conf, confFromJson); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/feedforward/embedding/EmbeddingLayerTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/feedforward/embedding/EmbeddingLayerTest.java index 20a6b34cf..972302d85 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/feedforward/embedding/EmbeddingLayerTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/feedforward/embedding/EmbeddingLayerTest.java @@ -35,6 +35,7 @@ import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.weights.embeddings.EmbeddingInitializer; import org.junit.Test; import org.nd4j.linalg.activations.Activation; +import org.nd4j.linalg.activations.impl.ActivationIdentity; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -156,7 +157,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest { .layer(new RnnOutputLayer.Builder().nIn(embeddingDim).nOut(nOut).activation(Activation.SOFTMAX).build()) .build(); MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder().activation(Activation.TANH).list() - .layer(0, new DenseLayer.Builder().nIn(nClassesIn).nOut(5).build()) + .layer(0, new DenseLayer.Builder().nIn(nClassesIn).nOut(5).activation(Activation.IDENTITY).build()) .layer(1, new OutputLayer.Builder().nIn(5).nOut(4).activation(Activation.SOFTMAX).build()) .inputPreProcessor(0, new RnnToFeedForwardPreProcessor()) .build(); @@ -204,7 +205,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest { .layer(1, new OutputLayer.Builder().nIn(5).nOut(4).activation(Activation.SOFTMAX).build()) .build(); MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder().activation(Activation.TANH).list() - .layer(0, new DenseLayer.Builder().nIn(nClassesIn).nOut(5).build()) + .layer(0, new DenseLayer.Builder().nIn(nClassesIn).nOut(5).activation(Activation.IDENTITY).build()) .layer(1, new OutputLayer.Builder().nIn(5).nOut(4).activation(Activation.SOFTMAX).build()) .build(); @@ -249,8 +250,8 @@ public class EmbeddingLayerTest extends BaseDL4JTest { .build(); MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder().activation(Activation.TANH) .weightInit(WeightInit.XAVIER).list() - .layer(0, new DenseLayer.Builder().nIn(nClassesIn).nOut(5).build()).layer(1, - new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(5).nOut(4) + .layer(new DenseLayer.Builder().nIn(nClassesIn).nOut(5).activation(Activation.IDENTITY).build()) + .layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(5).nOut(4) .activation(Activation.SOFTMAX).build()) .build(); @@ -309,7 +310,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest { .layer(new RnnOutputLayer.Builder().nIn(embeddingDim).nOut(nOut).activation(Activation.SOFTMAX).build()) .build(); MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder().activation(Activation.TANH).list() - .layer(new DenseLayer.Builder().nIn(nClassesIn).nOut(embeddingDim).build()) + .layer(new DenseLayer.Builder().nIn(nClassesIn).nOut(embeddingDim).activation(Activation.IDENTITY).build()) .layer(new RnnOutputLayer.Builder().nIn(embeddingDim).nOut(nOut).activation(Activation.SOFTMAX).build()) .setInputType(InputType.recurrent(nClassesIn)) .build(); @@ -344,7 +345,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest { net.computeGradientAndScore(); net2.computeGradientAndScore(); - System.out.println(net.score() + "\t" + net2.score()); +// System.out.println(net.score() + "\t" + net2.score()); assertEquals(net2.score(), net.score(), 1e-6); Map gradient = net.gradient().gradientForVariable(); @@ -375,7 +376,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest { .weightInit(WeightInit.XAVIER) .dataType(DataType.DOUBLE) .list() - .layer(0, new DenseLayer.Builder().nIn(nClassesIn).nOut(5).build()) + .layer(0, new DenseLayer.Builder().nIn(nClassesIn).nOut(5).activation(Activation.IDENTITY).build()) .layer(1, new GravesLSTM.Builder().nIn(5).nOut(7).activation(Activation.SOFTSIGN).build()) .layer(2, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(7).nOut(4) .activation(Activation.SOFTMAX).build()) @@ -416,7 +417,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest { net.computeGradientAndScore(); net2.computeGradientAndScore(); - System.out.println(net.score() + "\t" + net2.score()); +// System.out.println(net.score() + "\t" + net2.score()); assertEquals(net2.score(), net.score(), 1e-5); Map gradient = net.gradient().gradientForVariable(); @@ -513,7 +514,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest { net.computeGradientAndScore(); net2.computeGradientAndScore(); - System.out.println(net.score() + "\t" + net2.score()); +// System.out.println(net.score() + "\t" + net2.score()); assertEquals(net2.score(), net.score(), 1e-5); Map gradients = net.gradient().gradientForVariable(); @@ -707,4 +708,21 @@ public class EmbeddingLayerTest extends BaseDL4JTest { return true; } } + + @Test + public void testEmbeddingDefaultActivation(){ + + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + .list() + .layer(new EmbeddingLayer.Builder().nIn(10).nOut(10).build()) + .layer(new EmbeddingSequenceLayer.Builder().nIn(10).nOut(10).build()) + .build(); + + EmbeddingLayer l = (EmbeddingLayer) conf.getConf(0).getLayer(); + assertEquals(new ActivationIdentity(), l.getActivationFn()); + + EmbeddingSequenceLayer l2 = (EmbeddingSequenceLayer) conf.getConf(1).getLayer(); + assertEquals(new ActivationIdentity(), l2.getActivationFn()); + + } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/normalization/BatchNormalizationTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/normalization/BatchNormalizationTest.java index 2acb555a6..2c88cef3c 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/normalization/BatchNormalizationTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/normalization/BatchNormalizationTest.java @@ -90,6 +90,11 @@ public class BatchNormalizationTest extends BaseDL4JTest { public void doBefore() { } + @Override + public long getTimeoutMilliseconds() { + return 90000L; + } + @Test public void testDnnForwardPass() { int nOut = 10; @@ -102,7 +107,7 @@ public class BatchNormalizationTest extends BaseDL4JTest { INDArray mean = output.mean(0); INDArray stdev = output.std(false, 0); - System.out.println(Arrays.toString(mean.data().asFloat())); +// System.out.println(Arrays.toString(mean.data().asFloat())); assertArrayEquals(new float[nOut], mean.data().asFloat(), 1e-6f); assertEquals(Nd4j.ones(nOut), stdev); @@ -161,8 +166,8 @@ public class BatchNormalizationTest extends BaseDL4JTest { INDArray out = l.activate(input, true, LayerWorkspaceMgr.noWorkspaces()); - System.out.println(Arrays.toString(outExpected.data().asDouble())); - System.out.println(Arrays.toString(out.data().asDouble())); +// System.out.println(Arrays.toString(outExpected.data().asDouble())); +// System.out.println(Arrays.toString(out.data().asDouble())); assertEquals(outExpected, out); @@ -190,9 +195,9 @@ public class BatchNormalizationTest extends BaseDL4JTest { assertEquals(dldgammaExp, dldgamma); assertEquals(dldbetaExp, dldbeta); - System.out.println("EPSILONS"); - System.out.println(Arrays.toString(dldinExp.data().asDouble())); - System.out.println(Arrays.toString(p.getSecond().dup().data().asDouble())); +// System.out.println("EPSILONS"); +// System.out.println(Arrays.toString(dldinExp.data().asDouble())); +// System.out.println(Arrays.toString(p.getSecond().dup().data().asDouble())); assertEquals(dldinExp, p.getSecond()); } @@ -303,8 +308,8 @@ public class BatchNormalizationTest extends BaseDL4JTest { INDArray out = l.activate(input, true, LayerWorkspaceMgr.noWorkspaces()); - System.out.println(Arrays.toString(outExpected.data().asDouble())); - System.out.println(Arrays.toString(out.data().asDouble())); +// System.out.println(Arrays.toString(outExpected.data().asDouble())); +// System.out.println(Arrays.toString(out.data().asDouble())); assertEquals(outExpected, out); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/objdetect/TestYolo2OutputLayer.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/objdetect/TestYolo2OutputLayer.java index 0a004bbae..14f3ee6c0 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/objdetect/TestYolo2OutputLayer.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/objdetect/TestYolo2OutputLayer.java @@ -140,7 +140,7 @@ public class TestYolo2OutputLayer extends BaseDL4JTest { y2impl.setLabels(labels); double score = y2impl.computeScore(0.0, true, LayerWorkspaceMgr.noWorkspaces()); - System.out.println("SCORE: " + score); +// System.out.println("SCORE: " + score); assertTrue(score > 0.0); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/GravesLSTMTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/GravesLSTMTest.java index a0fc0f99d..66b4c8eab 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/GravesLSTMTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/GravesLSTMTest.java @@ -220,20 +220,20 @@ public class GravesLSTMTest extends BaseDL4JTest { INDArray out1 = net.output(in1); INDArray out2 = net.output(in2); - System.out.println(Arrays.toString(net.output(in1).data().asFloat())); - System.out.println(Arrays.toString(net.output(in2).data().asFloat())); +// System.out.println(Arrays.toString(net.output(in1).data().asFloat())); +// System.out.println(Arrays.toString(net.output(in2).data().asFloat())); List activations1 = net.feedForward(in1); List activations2 = net.feedForward(in2); - for (int i = 0; i < 3; i++) { - System.out.println("-----\n" + i); - System.out.println(Arrays.toString(activations1.get(i).dup().data().asDouble())); - System.out.println(Arrays.toString(activations2.get(i).dup().data().asDouble())); - - System.out.println(activations1.get(i)); - System.out.println(activations2.get(i)); - } +// for (int i = 0; i < 3; i++) { +// System.out.println("-----\n" + i); +// System.out.println(Arrays.toString(activations1.get(i).dup().data().asDouble())); +// System.out.println(Arrays.toString(activations2.get(i).dup().data().asDouble())); +// +// System.out.println(activations1.get(i)); +// System.out.println(activations2.get(i)); +// } diff --git a/deeplearning4j/deeplearning4j-nn/src/test/java/org/deeplearning4j/nn/layers/recurrent/MaskZeroLayerTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/MaskZeroLayerTest.java similarity index 100% rename from deeplearning4j/deeplearning4j-nn/src/test/java/org/deeplearning4j/nn/layers/recurrent/MaskZeroLayerTest.java rename to deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/MaskZeroLayerTest.java diff --git a/deeplearning4j/deeplearning4j-nn/src/test/java/org/deeplearning4j/nn/layers/samediff/SameDiffCustomLayerTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/SameDiffCustomLayerTests.java similarity index 100% rename from deeplearning4j/deeplearning4j-nn/src/test/java/org/deeplearning4j/nn/layers/samediff/SameDiffCustomLayerTests.java rename to deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/SameDiffCustomLayerTests.java diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffConv.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffConv.java index d45195870..317dca24d 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffConv.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffConv.java @@ -306,8 +306,8 @@ public class TestSameDiffConv extends BaseDL4JTest { INDArray l = TestUtils.randomOneHot(minibatch, nOut); log.info("Starting: " + msg); - boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, f, l, null, null, true, 50); //Most of weights are in output layer + boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(f) + .labels(l).subset(true).maxPerParam(50)); assertTrue(msg, gradOK); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffDenseVertex.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffDenseVertex.java index 7f9a54f8e..4e923bf4a 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffDenseVertex.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffDenseVertex.java @@ -135,7 +135,7 @@ public class TestSameDiffDenseVertex extends BaseDL4JTest { assertEquals(gStd.gradient(), gSD.gradient()); - System.out.println("========================================================================"); +// System.out.println("========================================================================"); //Sanity check: different minibatch size in = Nd4j.rand(2 * minibatch, nIn); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/variational/TestReconstructionDistributions.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/variational/TestReconstructionDistributions.java index 6a88cc550..aa7527841 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/variational/TestReconstructionDistributions.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/variational/TestReconstructionDistributions.java @@ -317,7 +317,7 @@ public class TestReconstructionDistributions extends BaseDL4JTest { INDArray gradient = rd.gradient(x, distributionParams); String testName = "minibatch = " + minibatch + ", size = " + inputSize + ", Distribution = " + rd; - System.out.println("\n\n***** Starting test: " + testName + "*****"); + System.out.println("***** Starting test: " + testName + "*****"); int totalFailureCount = 0; for (int i = 0; i < distributionParams.size(1); i++) { @@ -349,7 +349,7 @@ public class TestReconstructionDistributions extends BaseDL4JTest { totalFailureCount++; } } else { - log.info("Input (" + j + "," + i + ") passed: grad= " + backpropGrad + ", numericalGrad= " + log.trace("Input (" + j + "," + i + ") passed: grad= " + backpropGrad + ", numericalGrad= " + numericalGrad + ", relError= " + relError); } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/misc/WorkspaceTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/misc/WorkspaceTests.java index 077173a5a..d2bf06a56 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/misc/WorkspaceTests.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/misc/WorkspaceTests.java @@ -472,7 +472,7 @@ public class WorkspaceTests extends BaseDL4JTest { final ComputationGraph computationGraph = new ComputationGraph(config); computationGraph.init(); - computationGraph.setListeners(new ScoreIterationListener(1)); + computationGraph.setListeners(new ScoreIterationListener(3)); WSTestDataSetIterator iterator = new WSTestDataSetIterator(); computationGraph.fit(iterator); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/multilayer/BackPropMLPTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/multilayer/BackPropMLPTest.java index af15f3b45..e8236bf01 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/multilayer/BackPropMLPTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/multilayer/BackPropMLPTest.java @@ -66,7 +66,7 @@ public class BackPropMLPTest extends BaseDL4JTest { public void testMLP() { //Simple mini-batch test with multiple hidden layers MultiLayerConfiguration conf = getIrisMLPSimpleConfig(new int[] {5, 4, 3}, Activation.SIGMOID); - System.out.println(conf); +// System.out.println(conf); MultiLayerNetwork network = new MultiLayerNetwork(conf); network.init(); DataSetIterator iter = new IrisDataSetIterator(10, 100); @@ -80,7 +80,7 @@ public class BackPropMLPTest extends BaseDL4JTest { public void testMLP2() { //Simple mini-batch test with multiple hidden layers MultiLayerConfiguration conf = getIrisMLPSimpleConfig(new int[] {5, 15, 3}, Activation.TANH); - System.out.println(conf); +// System.out.println(conf); MultiLayerNetwork network = new MultiLayerNetwork(conf); network.init(); @@ -104,7 +104,7 @@ public class BackPropMLPTest extends BaseDL4JTest { Layer[] layers = network.getLayers(); - final boolean printCalculations = true; + final boolean printCalculations = false; while (iris.hasNext()) { DataSet data = iris.next(); @@ -212,7 +212,7 @@ public class BackPropMLPTest extends BaseDL4JTest { assertEquals(l1BiasFloatAfter,expectedL1BiasAfter,eps); assertArrayEquals(l2BiasFloatAfter,expectedL2BiasAfter,eps); */ - System.out.println("\n\n--------------"); +// System.out.println("\n\n--------------"); } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/multilayer/MultiLayerTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/multilayer/MultiLayerTest.java index 73ebf1ccd..ac1656e92 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/multilayer/MultiLayerTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/multilayer/MultiLayerTest.java @@ -922,9 +922,9 @@ public class MultiLayerTest extends BaseDL4JTest { MultiLayerNetwork modelExpectedArch = new MultiLayerNetwork(confForArchitecture); modelExpectedArch.init(); MultiLayerNetwork modelMow = new TransferLearning.Builder(modelExpectedArch).setFeatureExtractor(2).build(); - System.out.println(modelExpectedArch.summary()); - System.out.println(modelMow.summary()); - System.out.println(modelMow.summary(InputType.recurrent(V_HEIGHT*V_WIDTH*3))); +// System.out.println(modelExpectedArch.summary()); +// System.out.println(modelMow.summary()); +// System.out.println(modelMow.summary(InputType.recurrent(V_HEIGHT*V_WIDTH*3))); } @Test(expected = DL4JException.class) @@ -1149,7 +1149,7 @@ public class MultiLayerTest extends BaseDL4JTest { int nOut = 3; for(WorkspaceMode ws : WorkspaceMode.values()) { - System.out.println("***** WORKSPACE: " + ws); +// System.out.println("***** WORKSPACE: " + ws); MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() .updater(new Adam(0.01)) diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/multilayer/MultiLayerTestRNN.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/multilayer/MultiLayerTestRNN.java index 93e9bb9c7..5da79bc58 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/multilayer/MultiLayerTestRNN.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/multilayer/MultiLayerTestRNN.java @@ -570,8 +570,8 @@ public class MultiLayerTestRNN extends BaseDL4JTest { for (int j = 0; j < expOut.size(); j++) { INDArray exp = expOut.get(j); INDArray act = outSlice.get(j); - System.out.println(j); - System.out.println(exp.sub(act)); +// System.out.println(j); +// System.out.println(exp.sub(act)); assertEquals(exp, act); } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/multilayer/TestVariableLengthTS.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/multilayer/TestVariableLengthTS.java index 959ccbd22..2feb7792c 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/multilayer/TestVariableLengthTS.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/multilayer/TestVariableLengthTS.java @@ -219,10 +219,10 @@ public class TestVariableLengthTS extends BaseDL4JTest { INDArray g1s = g1map.get(s); INDArray g2s = g2map.get(s); - System.out.println("-------"); - System.out.println("Variable: " + s); - System.out.println(Arrays.toString(g1s.dup().data().asFloat())); - System.out.println(Arrays.toString(g2s.dup().data().asFloat())); +// System.out.println("-------"); +// System.out.println("Variable: " + s); +// System.out.println(Arrays.toString(g1s.dup().data().asFloat())); +// System.out.println(Arrays.toString(g2s.dup().data().asFloat())); assertNotEquals(s, g1s, g2s); } @@ -507,7 +507,7 @@ public class TestVariableLengthTS extends BaseDL4JTest { for (boolean bidirectional : isBidirectional) { for (PoolingType pt : poolingTypes) { - System.out.println("Starting test: bidirectional = " + bidirectional + ", poolingType = " + pt); +// System.out.println("Starting test: bidirectional = " + bidirectional + ", poolingType = " + pt); MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER) .activation(Activation.TANH).list().layer(0, bidirectional diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/transferlearning/TestFrozenLayers.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/transferlearning/TestFrozenLayers.java index 3a2153e1e..b9a15ccb2 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/transferlearning/TestFrozenLayers.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/transferlearning/TestFrozenLayers.java @@ -51,7 +51,6 @@ public class TestFrozenLayers extends BaseDL4JTest { for(double l1 : new double[]{0.0, 0.3}){ for( double l2 : new double[]{0.0, 0.4}){ - System.out.println("--------------------"); String msg = "l1=" + l1 + ", l2=" + l2; FineTuneConfiguration ftc = new FineTuneConfiguration.Builder() diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/transferlearning/TransferLearningComplex.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/transferlearning/TransferLearningComplex.java index 8a15197f1..f0fadc968 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/transferlearning/TransferLearningComplex.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/transferlearning/TransferLearningComplex.java @@ -273,8 +273,9 @@ public class TransferLearningComplex extends BaseDL4JTest { MultiDataSet rand = new MultiDataSet(new INDArray[] {Nd4j.rand(2, 2), Nd4j.rand(2, 2)}, new INDArray[] {Nd4j.rand(2, 2), Nd4j.rand(2, 3)}); modelNow.fit(rand); - log.info(modelNow.summary()); - log.info(modelNow.summary(InputType.feedForward(2),InputType.feedForward(2))); - +// log.info(modelNow.summary()); +// log.info(modelNow.summary(InputType.feedForward(2),InputType.feedForward(2))); + modelNow.summary(); + modelNow.summary(InputType.feedForward(2),InputType.feedForward(2)); } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/transferlearning/TransferLearningHelperTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/transferlearning/TransferLearningHelperTest.java index 22521b23b..75b30ffd7 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/transferlearning/TransferLearningHelperTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/transferlearning/TransferLearningHelperTest.java @@ -195,9 +195,10 @@ public class TransferLearningHelperTest extends BaseDL4JTest { assertEquals(modelIdentical.getLayer("denseLeft0").params(), modelToTune.getLayer("denseLeft0").params()); assertEquals(modelIdentical.getLayer("outLeft").params(), modelToTune.getLayer("outLeft").params()); - log.info(modelIdentical.summary()); - log.info(helper.unfrozenGraph().summary()); - +// log.info(modelIdentical.summary()); +// log.info(helper.unfrozenGraph().summary()); + modelIdentical.summary(); + helper.unfrozenGraph().summary(); } @Test diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/util/TestDataSetConsumer.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/util/TestDataSetConsumer.java index 232751bb4..7439c5ad2 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/util/TestDataSetConsumer.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/util/TestDataSetConsumer.java @@ -84,8 +84,8 @@ public class TestDataSetConsumer { count.incrementAndGet(); - if (count.get() % 100 == 0) - logger.info("Passed {} datasets...", count.get()); +// if (count.get() % 100 == 0) +// logger.info("Passed {} datasets...", count.get()); return count.get(); } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimize/solver/BackTrackLineSearchTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimize/solver/BackTrackLineSearchTest.java index 6975de250..e68cf133b 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimize/solver/BackTrackLineSearchTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimize/solver/BackTrackLineSearchTest.java @@ -186,7 +186,7 @@ public class BackTrackLineSearchTest extends BaseDL4JTest { MultiLayerNetwork network = new MultiLayerNetwork(getIrisMultiLayerConfig(Activation.SIGMOID, optimizer)); network.init(); - TrainingListener listener = new ScoreIterationListener(1); + TrainingListener listener = new ScoreIterationListener(10); network.setListeners(Collections.singletonList(listener)); double oldScore = network.score(data); for( int i=0; i<100; i++ ) { @@ -204,7 +204,7 @@ public class BackTrackLineSearchTest extends BaseDL4JTest { data.normalizeZeroMeanZeroUnitVariance(); MultiLayerNetwork network = new MultiLayerNetwork(getIrisMultiLayerConfig(Activation.RELU, optimizer)); network.init(); - TrainingListener listener = new ScoreIterationListener(1); + TrainingListener listener = new ScoreIterationListener(10); network.setListeners(Collections.singletonList(listener)); double firstScore = network.score(data); @@ -223,7 +223,7 @@ public class BackTrackLineSearchTest extends BaseDL4JTest { data.normalizeZeroMeanZeroUnitVariance(); MultiLayerNetwork network = new MultiLayerNetwork(getIrisMultiLayerConfig(Activation.RELU, optimizer)); network.init(); - TrainingListener listener = new ScoreIterationListener(1); + TrainingListener listener = new ScoreIterationListener(10); network.setListeners(Collections.singletonList(listener)); double oldScore = network.score(data); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimize/solver/TestOptimizers.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimize/solver/TestOptimizers.java index a7ce1622f..c2f5cd595 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimize/solver/TestOptimizers.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimize/solver/TestOptimizers.java @@ -66,7 +66,7 @@ import static org.junit.Assert.assertTrue; public class TestOptimizers extends BaseDL4JTest { //For debugging. - private static final boolean PRINT_OPT_RESULTS = true; + private static final boolean PRINT_OPT_RESULTS = false; @Test public void testOptimizersBasicMLPBackprop() { diff --git a/deeplearning4j/deeplearning4j-nn/src/test/java/org/deeplearning4j/optimize/solvers/accumulation/EncodedGradientsAccumulatorTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimize/solver/accumulation/EncodedGradientsAccumulatorTest.java similarity index 80% rename from deeplearning4j/deeplearning4j-nn/src/test/java/org/deeplearning4j/optimize/solvers/accumulation/EncodedGradientsAccumulatorTest.java rename to deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimize/solver/accumulation/EncodedGradientsAccumulatorTest.java index a501d4e1f..bae025caf 100644 --- a/deeplearning4j/deeplearning4j-nn/src/test/java/org/deeplearning4j/optimize/solvers/accumulation/EncodedGradientsAccumulatorTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimize/solver/accumulation/EncodedGradientsAccumulatorTest.java @@ -14,11 +14,13 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -package org.deeplearning4j.optimize.solvers.accumulation; +package org.deeplearning4j.optimize.solver.accumulation; import lombok.extern.slf4j.Slf4j; import lombok.val; import org.deeplearning4j.BaseDL4JTest; +import org.deeplearning4j.optimize.solvers.accumulation.EncodedGradientsAccumulator; +import org.deeplearning4j.optimize.solvers.accumulation.EncodingHandler; import org.deeplearning4j.optimize.solvers.accumulation.encoding.threshold.FixedThresholdAlgorithm; import org.junit.Test; import org.nd4j.linalg.api.ndarray.INDArray; @@ -34,15 +36,26 @@ import static org.junit.Assert.assertTrue; @Slf4j public class EncodedGradientsAccumulatorTest extends BaseDL4JTest { + @Override + public long getTimeoutMilliseconds() { + return 1200000L; + } + /** * This test ensures, that memory amount assigned to buffer is enough for any number of updates * @throws Exception */ @Test public void testStore1() throws Exception { - int numParams = 100000; - - int workers[] = new int[] {2, 4, 8}; + int numParams; + int[] workers; + if(isIntegrationTests()){ + numParams = 100000; + workers = new int[] {2, 4, 8}; + } else { + numParams = 10000; + workers = new int[] {2, 3}; + } for (int numWorkers : workers) { EncodingHandler handler = new EncodingHandler(new FixedThresholdAlgorithm(1e-3),null, null, false); @@ -57,8 +70,8 @@ public class EncodedGradientsAccumulatorTest extends BaseDL4JTest { accumulator.receiveUpdate(encoded); // just purge updates, like they were consumed - for (int i = 0; i < accumulator.messages.size(); i++) { - accumulator.messages.get(i).clear(); + for (int i = 0; i < accumulator.getMessages().size(); i++) { + accumulator.getMessages().get(i).clear(); } } } @@ -72,7 +85,13 @@ public class EncodedGradientsAccumulatorTest extends BaseDL4JTest { */ @Test public void testEncodingLimits1() throws Exception { - int numParams = 100000; + int numParams; + if(isIntegrationTests()){ + numParams = 100000; + } else { + numParams = 10000; + } + EncodingHandler handler = new EncodingHandler(new FixedThresholdAlgorithm(1e-3), null, null, false); for (int e = 10; e < numParams / 5; e++) { diff --git a/deeplearning4j/deeplearning4j-nn/src/test/java/org/deeplearning4j/optimize/solvers/accumulation/IndexedTailTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimize/solver/accumulation/IndexedTailTest.java similarity index 96% rename from deeplearning4j/deeplearning4j-nn/src/test/java/org/deeplearning4j/optimize/solvers/accumulation/IndexedTailTest.java rename to deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimize/solver/accumulation/IndexedTailTest.java index 68de05ce5..512b77c35 100644 --- a/deeplearning4j/deeplearning4j-nn/src/test/java/org/deeplearning4j/optimize/solvers/accumulation/IndexedTailTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimize/solver/accumulation/IndexedTailTest.java @@ -14,12 +14,13 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -package org.deeplearning4j.optimize.solvers.accumulation; +package org.deeplearning4j.optimize.solver.accumulation; import lombok.extern.slf4j.Slf4j; import lombok.val; import org.apache.commons.lang3.RandomUtils; import org.deeplearning4j.BaseDL4JTest; +import org.deeplearning4j.optimize.solvers.accumulation.IndexedTail; import org.junit.Ignore; import org.junit.Test; import org.nd4j.linalg.factory.Nd4j; @@ -97,9 +98,9 @@ public class IndexedTailTest extends BaseDL4JTest { assertEquals(-1, tail.maxAppliedIndexEverywhere()); // 2 consumers consumed 2 elements, and 1 consumer consumed 3 elements - tail.positions.get(11L).set(2); - tail.positions.get(22L).set(2); - tail.positions.get(33L).set(3); + tail.getPositions().get(11L).set(2); + tail.getPositions().get(22L).set(2); + tail.getPositions().get(33L).set(3); // all elements including this index are safe to remove, because they were consumed everywhere assertEquals(2, tail.maxAppliedIndexEverywhere()); @@ -197,10 +198,10 @@ public class IndexedTailTest extends BaseDL4JTest { Nd4j.getExecutioner().commit(); } - assertTrue(tail.collapsedMode.get()); + assertTrue(tail.getCollapsedMode().get()); assertEquals(1, tail.updatesSize()); - val array = tail.updates.get(32L); + val array = tail.getUpdates().get(32L); assertNotNull(array); assertEquals(sum, (int) array.getDouble(0)); } @@ -242,7 +243,7 @@ public class IndexedTailTest extends BaseDL4JTest { final long[] sums = new long[numReaders]; val readers = new ArrayList(); for (int e = 0; e < numReaders; e++) { - val f = e; + final int f = e; val t = new Thread(new Runnable() { @Override public void run() { @@ -297,7 +298,7 @@ public class IndexedTailTest extends BaseDL4JTest { final long[] sums = new long[numReaders]; val readers = new ArrayList(); for (int e = 0; e < numReaders; e++) { - val f = e; + final int f = e; val t = new Thread(new Runnable() { @Override public void run() { @@ -371,7 +372,7 @@ public class IndexedTailTest extends BaseDL4JTest { final long[] sums = new long[numReaders]; val readers = new ArrayList(); for (int e = 0; e < numReaders; e++) { - val f = e; + final int f = e; val t = new Thread(new Runnable() { @Override public void run() { diff --git a/deeplearning4j/deeplearning4j-nn/src/test/java/org/deeplearning4j/optimize/solvers/accumulation/SmartFancyBlockingQueueTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimize/solver/accumulation/SmartFancyBlockingQueueTest.java similarity index 98% rename from deeplearning4j/deeplearning4j-nn/src/test/java/org/deeplearning4j/optimize/solvers/accumulation/SmartFancyBlockingQueueTest.java rename to deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimize/solver/accumulation/SmartFancyBlockingQueueTest.java index 725a9db8a..6496fd9e3 100644 --- a/deeplearning4j/deeplearning4j-nn/src/test/java/org/deeplearning4j/optimize/solvers/accumulation/SmartFancyBlockingQueueTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimize/solver/accumulation/SmartFancyBlockingQueueTest.java @@ -14,12 +14,13 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -package org.deeplearning4j.optimize.solvers.accumulation; +package org.deeplearning4j.optimize.solver.accumulation; import lombok.extern.slf4j.Slf4j; import lombok.val; import org.apache.commons.lang3.RandomUtils; import org.deeplearning4j.BaseDL4JTest; +import org.deeplearning4j.optimize.solvers.accumulation.SmartFancyBlockingQueue; import org.deeplearning4j.util.ThreadUtils; import org.junit.Ignore; import org.junit.Test; @@ -292,10 +293,10 @@ public class SmartFancyBlockingQueueTest extends BaseDL4JTest { } // each reader will read 250 updates. supposedly equal :) - val means = new long[4]; + final long[] means = new long[4]; val readers = new ArrayList(); for (int e = 0; e < 4; e++) { - val f = e; + final int f = e; means[f] = 0; val t = new Thread(new Runnable() { @Override diff --git a/deeplearning4j/deeplearning4j-nn/src/test/java/org/deeplearning4j/optimize/solvers/accumulation/ThresholdAlgorithmTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimize/solver/accumulation/ThresholdAlgorithmTests.java similarity index 98% rename from deeplearning4j/deeplearning4j-nn/src/test/java/org/deeplearning4j/optimize/solvers/accumulation/ThresholdAlgorithmTests.java rename to deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimize/solver/accumulation/ThresholdAlgorithmTests.java index 52744bf12..578674ed6 100644 --- a/deeplearning4j/deeplearning4j-nn/src/test/java/org/deeplearning4j/optimize/solvers/accumulation/ThresholdAlgorithmTests.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimize/solver/accumulation/ThresholdAlgorithmTests.java @@ -14,7 +14,7 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -package org.deeplearning4j.optimize.solvers.accumulation; +package org.deeplearning4j.optimize.solver.accumulation; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.optimize.solvers.accumulation.encoding.ThresholdAlgorithm; diff --git a/deeplearning4j/deeplearning4j-nn/src/test/java/org/deeplearning4j/optimize/listeners/ScoreStatTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimizer/listener/ScoreStatTest.java similarity index 96% rename from deeplearning4j/deeplearning4j-nn/src/test/java/org/deeplearning4j/optimize/listeners/ScoreStatTest.java rename to deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimizer/listener/ScoreStatTest.java index 5a3588e25..fdc86c6cc 100644 --- a/deeplearning4j/deeplearning4j-nn/src/test/java/org/deeplearning4j/optimize/listeners/ScoreStatTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimizer/listener/ScoreStatTest.java @@ -1,5 +1,6 @@ -package org.deeplearning4j.optimize.listeners; +package org.deeplearning4j.optimizer.listener; +import org.deeplearning4j.optimize.listeners.CollectScoresIterationListener; import org.junit.Ignore; import org.junit.Test; diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimizer/listener/TestParamAndGradientIterationListener.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimizer/listener/TestParamAndGradientIterationListener.java deleted file mode 100644 index 797be51cc..000000000 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimizer/listener/TestParamAndGradientIterationListener.java +++ /dev/null @@ -1,79 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.optimizer.listener; - - -import org.deeplearning4j.BaseDL4JTest; -import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator; -import org.deeplearning4j.nn.api.OptimizationAlgorithm; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; -import org.deeplearning4j.nn.conf.NeuralNetConfiguration; -import org.deeplearning4j.nn.conf.layers.DenseLayer; -import org.deeplearning4j.nn.conf.layers.OutputLayer; -import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; -import org.deeplearning4j.optimize.api.TrainingListener; -import org.deeplearning4j.optimize.listeners.ParamAndGradientIterationListener; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; -import org.nd4j.linalg.activations.Activation; -import org.nd4j.linalg.learning.config.Sgd; -import org.nd4j.linalg.lossfunctions.LossFunctions; - -import java.io.File; - -import static org.junit.Assert.assertEquals; - -public class TestParamAndGradientIterationListener extends BaseDL4JTest { - - @Rule - public TemporaryFolder testDir = new TemporaryFolder(); - - @Test - public void test() throws Exception { - - IrisDataSetIterator iter = new IrisDataSetIterator(30, 150); - - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).updater(new Sgd(1e-5)) - .list().layer(0, new DenseLayer.Builder().nIn(4).nOut(20).build()) - .layer(1, new DenseLayer.Builder().nIn(20).nOut(30).build()) - .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) - .activation(Activation.SOFTMAX).nIn(30).nOut(3).build()) - .build(); - - MultiLayerNetwork net = new MultiLayerNetwork(conf); - net.init(); - - File f = testDir.newFile("paramAndGradTest.txt"); - TrainingListener listener = ParamAndGradientIterationListener.builder().outputToFile(true) - .file(f) - .outputToConsole(true).outputToLogger(false).iterations(2).printHeader(true).printMean(false) - .printMinMax(false).printMeanAbsValue(true).delimiter("\t").build(); - net.setListeners(listener); - - for (int i = 0; i < 2; i++) { - net.fit(iter); - } - - - } - - - - -} diff --git a/deeplearning4j/dl4j-perf/src/test/java/org/deeplearning4j/perf/listener/SystemPollingTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/perf/listener/SystemPollingTest.java similarity index 98% rename from deeplearning4j/dl4j-perf/src/test/java/org/deeplearning4j/perf/listener/SystemPollingTest.java rename to deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/perf/listener/SystemPollingTest.java index f781914ec..ae46692fc 100644 --- a/deeplearning4j/dl4j-perf/src/test/java/org/deeplearning4j/perf/listener/SystemPollingTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/perf/listener/SystemPollingTest.java @@ -17,6 +17,7 @@ package org.deeplearning4j.perf.listener; import org.apache.commons.io.FileUtils; +import org.deeplearning4j.BaseDL4JTest; import org.junit.Ignore; import org.junit.Rule; import org.junit.Test; diff --git a/deeplearning4j/dl4j-perf/src/test/java/org/deeplearning4j/perf/listener/TestHardWareMetric.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/perf/listener/TestHardWareMetric.java similarity index 97% rename from deeplearning4j/dl4j-perf/src/test/java/org/deeplearning4j/perf/listener/TestHardWareMetric.java rename to deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/perf/listener/TestHardWareMetric.java index 19436ce9d..b9589398b 100644 --- a/deeplearning4j/dl4j-perf/src/test/java/org/deeplearning4j/perf/listener/TestHardWareMetric.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/perf/listener/TestHardWareMetric.java @@ -16,6 +16,7 @@ package org.deeplearning4j.perf.listener; +import org.deeplearning4j.BaseDL4JTest; import org.junit.Ignore; import org.junit.Test; import oshi.json.SystemInfo; diff --git a/deeplearning4j/dl4j-perf/src/test/java/org/deeplearning4j/perf/listener/TestSystemInfoPrintListener.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/perf/listener/TestSystemInfoPrintListener.java similarity index 98% rename from deeplearning4j/dl4j-perf/src/test/java/org/deeplearning4j/perf/listener/TestSystemInfoPrintListener.java rename to deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/perf/listener/TestSystemInfoPrintListener.java index 69edb363b..6ce531f12 100644 --- a/deeplearning4j/dl4j-perf/src/test/java/org/deeplearning4j/perf/listener/TestSystemInfoPrintListener.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/perf/listener/TestSystemInfoPrintListener.java @@ -16,6 +16,7 @@ package org.deeplearning4j.perf.listener; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/plot/BarnesHutTsneTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/plot/BarnesHutTsneTest.java index f65783bce..7a818810a 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/plot/BarnesHutTsneTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/plot/BarnesHutTsneTest.java @@ -91,7 +91,7 @@ public class BarnesHutTsneTest extends BaseDL4JTest { .useAdaGrad(false).build(); b.fit(data); - log.info("Result: {}", b.getData()); +// log.info("Result: {}", b.getData()); val exp = Nd4j.createFromArray(new double[]{-3.5318212819287327, 35.40331834897696, 3.890809489531651, -1.291195609955519, -42.854099388207466, 7.8761368019456635, 28.798057251442877, 7.1456564000935225, 2.9518396278984786, -42.860181054199636, -34.989343304202, -108.99770355680282, 31.78123839126566, -29.322118879730205, 163.87558311206212, 2.9538984612478396, 31.419519824305546, 13.105400907817279, 25.46987139120746, -43.27317406736858, 32.455151773056144, 25.28067703547214, 0.005442008567682552, 21.005029233370358, -61.71390311950051, 5.218417653362599, 47.15762099517554, 8.834739256343404, 17.845790108867153, -54.31654219224107, -18.71285871476804, -16.446982180909007, -71.22568781913213, -12.339975548387091, 70.49096598213703, 25.022454385237456, -14.572652938207126, -5.320080866729078, 1.5874449933639676, -40.60960510287835, -31.98564381157643, -95.40875746933808, 19.196346639002364, -38.80930682421929, 135.00454225923906, 5.277879540549592, 30.79963767087089, -0.007276462027131683, 31.278796123365815, -38.47381680049993, 10.415728497075905, 36.567265019013085, -7.406587944733211, -18.376174615781114, -45.26976962854271}).reshape(-1, 5); @@ -103,7 +103,7 @@ public class BarnesHutTsneTest extends BaseDL4JTest { assertArrayEquals(exp.data().asDouble(), b.getData().data().asDouble(), eps); } - @Test + @Test(timeout = 300000) public void testTsne() throws Exception { DataTypeUtil.setDTypeForContext(DataType.DOUBLE); Nd4j.getRandom().setSeed(123); @@ -178,7 +178,7 @@ public class BarnesHutTsneTest extends BaseDL4JTest { INDArray data = iter.next().getFeatures(); INDArray perplexityOutput = b.computeGaussianPerplexity(data, 30.0); - System.out.println(perplexityOutput); +// System.out.println(perplexityOutput); } @Test @@ -217,17 +217,17 @@ public class BarnesHutTsneTest extends BaseDL4JTest { StopWatch watch = new StopWatch(); watch.start(); b.fit(data); - System.out.println(b.getData()); +// System.out.println(b.getData()); watch.stop(); File outDir = testDir.newFolder(); ClassPathResource labels = new ClassPathResource("mnist2500_labels.txt"); List labelsList = IOUtils.readLines(labels.getInputStream()); b.saveAsFile(/*labelsList,*/ new File(outDir, "raw.txt").getAbsolutePath()); - System.out.println(b.getData()); +// System.out.println(b.getData()); System.out.println("Fit done in " + watch); assertEquals(2500, b.getData().size(0)); - System.out.println(b.getData()); +// System.out.println(b.getData()); INDArray a1 = b.getData().getRow(0); INDArray a2 = b.getData().getRow(1); @@ -338,7 +338,7 @@ public class BarnesHutTsneTest extends BaseDL4JTest { double[] dC = {-0.0618386320333619, -0.06266654959379839, 0.029998268806149204, 0.10780566335888186, -0.19449543068355346, -0.14763764361792697, 0.17493572758118422, 0.1926109839221966, -0.15176648259935419, 0.10974665709698186, 0.13102419155322598, 0.004941641352409449, 0.19159764518354974, -0.26332838053474944, -0.023631441261541583, 0.09838669432305949, 0.09709129638394683, -0.01605053000727605, 0.06566171635025217, -0.17325078066035252, -0.1090854255505605, 0.023350644966904276, 0.075192354899586, -0.08278373866517603, 0.18431338134579323, 0.2766031655578053, -0.17557907233268688, 0.10616148241800637, -0.09999024423215641, -0.017181932145255287, 0.06711331400576945, -0.01388231800826619, -0.10248189290485302, 0.20786521034824304, 0.11254913977572988, -0.289564646781519, 0.13491805919337516, -0.07504249344962562, 0.004154656287570634, -0.10516715438388784, -0.27984655075804576, 0.09811828071286613, 0.03684521473995052, -0.054645216532387256, -0.18147132772800725, 0.027588750493223044, 0.214734364419479, -0.026729138234415008, -0.28410504978879136, 0.007015481601883835, 0.04427981739424874, -0.059253265830134655, -0.05325479031206952, -0.11319889109674944, 0.1530133971867549}; INDArray actual = gradient.getGradientFor("yIncs"); - System.out.println(actual); +// System.out.println(actual); assertArrayEquals(dC, actual.reshape(1,55).toDoubleVector(), 1e-05); } @@ -482,8 +482,8 @@ public class BarnesHutTsneTest extends BaseDL4JTest { List results = new ArrayList<>(); List distances = new ArrayList<>(); tree.search(target, 11, results, distances); - System.out.println("Results:" + results); - System.out.println("Distances:" + distances); +// System.out.println("Results:" + results); +// System.out.println("Distances:" + distances); } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest050.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest050.java index 4dcbce538..867e96f09 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest050.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest050.java @@ -28,7 +28,9 @@ import org.deeplearning4j.nn.weights.WeightInitDistribution; import org.deeplearning4j.nn.weights.WeightInitRelu; import org.deeplearning4j.nn.weights.WeightInitXavier; import org.deeplearning4j.util.ModelSerializer; +import org.junit.Rule; import org.junit.Test; +import org.junit.rules.Timeout; import org.nd4j.linalg.activations.impl.ActivationLReLU; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.factory.Nd4j; @@ -42,6 +44,7 @@ import org.nd4j.linalg.lossfunctions.impl.LossNegativeLogLikelihood; import org.nd4j.resources.Resources; import java.io.File; +import java.sql.Time; import static org.junit.Assert.*; @@ -55,6 +58,9 @@ import static org.junit.Assert.*; */ public class RegressionTest050 extends BaseDL4JTest { + @Rule + public Timeout timeout = Timeout.seconds(300); + @Override public DataType getDataType(){ return DataType.FLOAT; diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/samediff/CompareTrainingImplementations.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/samediff/CompareTrainingImplementations.java index fa0fc335f..cf75700f8 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/samediff/CompareTrainingImplementations.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/samediff/CompareTrainingImplementations.java @@ -250,7 +250,7 @@ public class CompareTrainingImplementations extends BaseDL4JTest { sd.evaluate(iter, "softmax", rEvalSd); assertEquals(rEvalDl4j, rEvalSd); - System.out.println("---------------------------------"); +// System.out.println("---------------------------------"); } } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/util/CrashReportingUtilTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/util/CrashReportingUtilTest.java index bdb72d4b6..5975d1c1e 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/util/CrashReportingUtilTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/util/CrashReportingUtilTest.java @@ -47,6 +47,11 @@ import static org.junit.Assert.*; public class CrashReportingUtilTest extends BaseDL4JTest { + @Override + public long getTimeoutMilliseconds() { + return 120000; + } + @Rule public TemporaryFolder testDir = new TemporaryFolder(); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/util/ModelGuesserTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/util/ModelGuesserTest.java index fda013533..5973bae71 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/util/ModelGuesserTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/util/ModelGuesserTest.java @@ -29,6 +29,7 @@ import org.deeplearning4j.nn.weights.WeightInit; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; +import org.junit.rules.Timeout; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.api.preprocessor.Normalizer; @@ -54,6 +55,9 @@ public class ModelGuesserTest extends BaseDL4JTest { @Rule public TemporaryFolder testDir = new TemporaryFolder(); + @Rule + public Timeout timeout = Timeout.seconds(300); + @Test public void testModelGuessFile() throws Exception { diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/util/ModelValidatorTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/util/ModelValidatorTests.java index a66207cd2..a704a1899 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/util/ModelValidatorTests.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/util/ModelValidatorTests.java @@ -51,7 +51,7 @@ public class ModelValidatorTests extends BaseDL4JTest { assertEquals("MultiLayerNetwork", vr0.getFormatType()); assertEquals(MultiLayerNetwork.class, vr0.getFormatClass()); assertNull(vr0.getException()); - System.out.println(vr0.toString()); +// System.out.println(vr0.toString()); //Test empty file File f1 = new File(f, "empty.bin"); @@ -63,7 +63,7 @@ public class ModelValidatorTests extends BaseDL4JTest { assertEquals("MultiLayerNetwork", vr1.getFormatType()); assertEquals(MultiLayerNetwork.class, vr1.getFormatClass()); assertNull(vr1.getException()); - System.out.println(vr1.toString()); +// System.out.println(vr1.toString()); //Test invalid zip file File f2 = new File(f, "notReallyZip.zip"); @@ -75,7 +75,7 @@ public class ModelValidatorTests extends BaseDL4JTest { assertEquals("MultiLayerNetwork", vr2.getFormatType()); assertEquals(MultiLayerNetwork.class, vr2.getFormatClass()); assertNotNull(vr2.getException()); - System.out.println(vr2.toString()); +// System.out.println(vr2.toString()); //Test valid zip, but missing configuration File f3 = new File(f, "modelNoConfig.zip"); @@ -92,7 +92,7 @@ public class ModelValidatorTests extends BaseDL4JTest { assertEquals("MultiLayerNetwork", vr3.getFormatType()); assertEquals(MultiLayerNetwork.class, vr3.getFormatClass()); assertNull(vr3.getException()); - System.out.println(vr3.toString()); +// System.out.println(vr3.toString()); //Test valid sip, but missing params @@ -110,7 +110,7 @@ public class ModelValidatorTests extends BaseDL4JTest { assertEquals("MultiLayerNetwork", vr4.getFormatType()); assertEquals(MultiLayerNetwork.class, vr4.getFormatClass()); assertNull(vr4.getException()); - System.out.println(vr4.toString()); +// System.out.println(vr4.toString()); //Test valid model @@ -122,7 +122,7 @@ public class ModelValidatorTests extends BaseDL4JTest { assertEquals("MultiLayerNetwork", vr5.getFormatType()); assertEquals(MultiLayerNetwork.class, vr5.getFormatClass()); assertNull(vr5.getException()); - System.out.println(vr5.toString()); +// System.out.println(vr5.toString()); //Test valid model with corrupted JSON @@ -141,7 +141,7 @@ public class ModelValidatorTests extends BaseDL4JTest { bytes = IOUtils.toByteArray(zis); } zo.write(bytes); - System.out.println("WROTE: " + ze.getName()); +// System.out.println("WROTE: " + ze.getName()); } } } @@ -153,7 +153,7 @@ public class ModelValidatorTests extends BaseDL4JTest { assertEquals("MultiLayerNetwork", vr6.getFormatType()); assertEquals(MultiLayerNetwork.class, vr6.getFormatClass()); assertNotNull(vr6.getException()); - System.out.println(vr6.toString()); +// System.out.println(vr6.toString()); } @@ -169,7 +169,7 @@ public class ModelValidatorTests extends BaseDL4JTest { assertEquals("ComputationGraph", vr0.getFormatType()); assertEquals(ComputationGraph.class, vr0.getFormatClass()); assertNull(vr0.getException()); - System.out.println(vr0.toString()); +// System.out.println(vr0.toString()); //Test empty file File f1 = new File(f, "empty.bin"); @@ -181,7 +181,7 @@ public class ModelValidatorTests extends BaseDL4JTest { assertEquals("ComputationGraph", vr1.getFormatType()); assertEquals(ComputationGraph.class, vr1.getFormatClass()); assertNull(vr1.getException()); - System.out.println(vr1.toString()); +// System.out.println(vr1.toString()); //Test invalid zip file File f2 = new File(f, "notReallyZip.zip"); @@ -193,7 +193,7 @@ public class ModelValidatorTests extends BaseDL4JTest { assertEquals("ComputationGraph", vr2.getFormatType()); assertEquals(ComputationGraph.class, vr2.getFormatClass()); assertNotNull(vr2.getException()); - System.out.println(vr2.toString()); +// System.out.println(vr2.toString()); //Test valid zip, but missing configuration File f3 = new File(f, "modelNoConfig.zip"); @@ -210,7 +210,7 @@ public class ModelValidatorTests extends BaseDL4JTest { assertEquals("ComputationGraph", vr3.getFormatType()); assertEquals(ComputationGraph.class, vr3.getFormatClass()); assertNull(vr3.getException()); - System.out.println(vr3.toString()); +// System.out.println(vr3.toString()); //Test valid sip, but missing params @@ -228,7 +228,7 @@ public class ModelValidatorTests extends BaseDL4JTest { assertEquals("ComputationGraph", vr4.getFormatType()); assertEquals(ComputationGraph.class, vr4.getFormatClass()); assertNull(vr4.getException()); - System.out.println(vr4.toString()); +// System.out.println(vr4.toString()); //Test valid model @@ -240,7 +240,7 @@ public class ModelValidatorTests extends BaseDL4JTest { assertEquals("ComputationGraph", vr5.getFormatType()); assertEquals(ComputationGraph.class, vr5.getFormatClass()); assertNull(vr5.getException()); - System.out.println(vr5.toString()); +// System.out.println(vr5.toString()); //Test valid model with corrupted JSON @@ -259,7 +259,7 @@ public class ModelValidatorTests extends BaseDL4JTest { bytes = IOUtils.toByteArray(zis); } zo.write(bytes); - System.out.println("WROTE: " + ze.getName()); +// System.out.println("WROTE: " + ze.getName()); } } } @@ -271,7 +271,7 @@ public class ModelValidatorTests extends BaseDL4JTest { assertEquals("ComputationGraph", vr6.getFormatType()); assertEquals(ComputationGraph.class, vr6.getFormatClass()); assertNotNull(vr6.getException()); - System.out.println(vr6.toString()); +// System.out.println(vr6.toString()); } diff --git a/deeplearning4j/deeplearning4j-cuda/pom.xml b/deeplearning4j/deeplearning4j-cuda/pom.xml index 95c5f5deb..dfdc76efb 100644 --- a/deeplearning4j/deeplearning4j-cuda/pom.xml +++ b/deeplearning4j/deeplearning4j-cuda/pom.xml @@ -83,6 +83,12 @@ junit test + + org.deeplearning4j + deeplearning4j-common-tests + ${project.version} + test + diff --git a/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/gradientcheck/CNNGradientCheckTest.java b/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/gradientcheck/CNNGradientCheckTest.java index 2f4532823..1b8c42e14 100644 --- a/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/gradientcheck/CNNGradientCheckTest.java +++ b/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/gradientcheck/CNNGradientCheckTest.java @@ -123,8 +123,6 @@ public class CNNGradientCheckTest extends BaseDL4JTest { if (PRINT_RESULTS) { System.out.println(name + " - activationFn=" + afn + ", lossFn=" + lf + ", outputActivation=" + outputActivation + ", doLearningFirst=" + doLearningFirst); - for (int j = 0; j < mln.getnLayers(); j++) - System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); } boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, @@ -214,8 +212,8 @@ public class CNNGradientCheckTest extends BaseDL4JTest { System.out.println(testName + "- activationFn=" + afn + ", lossFn=" + lf + ", outputActivation=" + outputActivation + ", doLearningFirst=" + doLearningFirst); - for (int j = 0; j < mln.getnLayers(); j++) - System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); +// for (int j = 0; j < mln.getnLayers(); j++) +// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); } boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, @@ -277,8 +275,8 @@ public class CNNGradientCheckTest extends BaseDL4JTest { if (PRINT_RESULTS) { System.out.println(msg); - for (int j = 0; j < net.getnLayers(); j++) - System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams()); +// for (int j = 0; j < net.getnLayers(); j++) +// System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams()); } boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels); @@ -340,8 +338,8 @@ public class CNNGradientCheckTest extends BaseDL4JTest { if (PRINT_RESULTS) { System.out.println(msg); - for (int j = 0; j < net.getnLayers(); j++) - System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams()); +// for (int j = 0; j < net.getnLayers(); j++) +// System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams()); } boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels); @@ -397,8 +395,8 @@ public class CNNGradientCheckTest extends BaseDL4JTest { if (PRINT_RESULTS) { System.out.println(msg); - for (int j = 0; j < net.getnLayers(); j++) - System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams()); +// for (int j = 0; j < net.getnLayers(); j++) +// System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams()); } boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, @@ -468,8 +466,8 @@ public class CNNGradientCheckTest extends BaseDL4JTest { if (PRINT_RESULTS) { System.out.println(msg); - for (int j = 0; j < net.getnLayers(); j++) - System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams()); +// for (int j = 0; j < net.getnLayers(); j++) +// System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams()); } boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, @@ -602,9 +600,9 @@ public class CNNGradientCheckTest extends BaseDL4JTest { MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - for (int i = 0; i < 4; i++) { - System.out.println("nParams, layer " + i + ": " + net.getLayer(i).numParams()); - } +// for (int i = 0; i < 4; i++) { +// System.out.println("nParams, layer " + i + ": " + net.getLayer(i).numParams()); +// } String msg = "PoolingType=" + poolingType + ", minibatch=" + minibatchSize + ", activationFn=" + afn; @@ -663,9 +661,9 @@ public class CNNGradientCheckTest extends BaseDL4JTest { MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - for (int j = 0; j < net.getLayers().length; j++) { - System.out.println("nParams, layer " + j + ": " + net.getLayer(j).numParams()); - } +// for (int j = 0; j < net.getLayers().length; j++) { +// System.out.println("nParams, layer " + j + ": " + net.getLayer(j).numParams()); +// } String msg = "Minibatch=" + minibatchSize + ", inDepth=" + inputDepth + ", height=" + height + ", kernelSize=" + k; @@ -726,18 +724,19 @@ public class CNNGradientCheckTest extends BaseDL4JTest { MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - for (int i = 0; i < net.getLayers().length; i++) { - System.out.println("nParams, layer " + i + ": " + net.getLayer(i).numParams()); - } +// for (int i = 0; i < net.getLayers().length; i++) { +// System.out.println("nParams, layer " + i + ": " + net.getLayer(i).numParams()); +// } String msg = "Minibatch=" + minibatchSize + ", inDepth=" + inputDepth + ", height=" + height + ", kernelSize=" + k + ", stride = " + stride + ", convLayer first = " + convFirst; System.out.println(msg); - boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, - labels, null, null, true, 128); + boolean gradOK = GradientCheckUtil.checkGradients( + new GradientCheckUtil.MLNConfig().net(net) + .input(input).labels(labels) + .subset(true).maxPerParam(128)); assertTrue(msg, gradOK); @@ -805,8 +804,8 @@ public class CNNGradientCheckTest extends BaseDL4JTest { if (PRINT_RESULTS) { System.out.println(msg); - for (int j = 0; j < net.getnLayers(); j++) - System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams()); +// for (int j = 0; j < net.getnLayers(); j++) +// System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams()); } boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, @@ -871,16 +870,18 @@ public class CNNGradientCheckTest extends BaseDL4JTest { MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - for (int j = 0; j < net.getLayers().length; j++) { - System.out.println("nParams, layer " + j + ": " + net.getLayer(j).numParams()); - } +// for (int j = 0; j < net.getLayers().length; j++) { +// System.out.println("nParams, layer " + j + ": " + net.getLayer(j).numParams()); +// } String msg = " - mb=" + minibatchSize + ", k=" + k + ", s=" + s + ", d=" + d + ", cm=" + cm; System.out.println(msg); - boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels, null, null, true, 100); + boolean gradOK = GradientCheckUtil.checkGradients( + new GradientCheckUtil.MLNConfig().net(net) + .input(input).labels(labels) + .subset(true).maxPerParam(100)); assertTrue(msg, gradOK); @@ -940,16 +941,18 @@ public class CNNGradientCheckTest extends BaseDL4JTest { MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - for (int i = 0; i < net.getLayers().length; i++) { - System.out.println("nParams, layer " + i + ": " + net.getLayer(i).numParams()); - } +// for (int i = 0; i < net.getLayers().length; i++) { +// System.out.println("nParams, layer " + i + ": " + net.getLayer(i).numParams()); +// } String msg = " - mb=" + minibatchSize + ", k=" + k + ", nIn=" + nIn + ", depthMul=" + depthMultiplier + ", s=" + s + ", cm=" + cm; System.out.println(msg); - boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels, null, null, true, 256); + boolean gradOK = GradientCheckUtil.checkGradients( + new GradientCheckUtil.MLNConfig().net(net) + .input(input).labels(labels) + .subset(true).maxPerParam(256)); assertTrue(msg, gradOK); @@ -1013,16 +1016,18 @@ public class CNNGradientCheckTest extends BaseDL4JTest { MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - for (int i = 0; i < net.getLayers().length; i++) { - System.out.println("nParams, layer " + i + ": " + net.getLayer(i).numParams()); - } +// for (int i = 0; i < net.getLayers().length; i++) { +// System.out.println("nParams, layer " + i + ": " + net.getLayer(i).numParams()); +// } String msg = " - mb=" + minibatchSize + ", k=" + k + ", s=" + s + ", d=" + d + ", cm=" + cm; System.out.println(msg); - boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels, null, null, true, 50); //Most params are in output layer + boolean gradOK = GradientCheckUtil.checkGradients( + new GradientCheckUtil.MLNConfig().net(net) + .input(input).labels(labels) + .subset(true).maxPerParam(50)); assertTrue(msg, gradOK); @@ -1097,9 +1102,9 @@ public class CNNGradientCheckTest extends BaseDL4JTest { MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - for (int i = 0; i < net.getLayers().length; i++) { - System.out.println("nParams, layer " + i + ": " + net.getLayer(i).numParams()); - } +// for (int i = 0; i < net.getLayers().length; i++) { +// System.out.println("nParams, layer " + i + ": " + net.getLayer(i).numParams()); +// } String msg = (subsampling ? "subsampling" : "conv") + " - mb=" + minibatchSize + ", k=" + k + ", s=" + s + ", d=" + d + ", cm=" + cm; @@ -1172,12 +1177,12 @@ public class CNNGradientCheckTest extends BaseDL4JTest { if (PRINT_RESULTS) { System.out.println(msg); - for (int j = 0; j < net.getnLayers(); j++) - System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams()); } - boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels, null, null, true, 160); + boolean gradOK = GradientCheckUtil.checkGradients( + new GradientCheckUtil.MLNConfig().net(net) + .input(input).labels(labels) + .subset(true).maxPerParam(160)); assertTrue(msg, gradOK); diff --git a/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/gradientcheck/CuDNNGradientChecks.java b/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/gradientcheck/CuDNNGradientChecks.java index f1ea9f1e1..9e43f042b 100644 --- a/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/gradientcheck/CuDNNGradientChecks.java +++ b/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/gradientcheck/CuDNNGradientChecks.java @@ -305,7 +305,7 @@ public class CuDNNGradientChecks extends BaseDL4JTest { //However, numerical gradient will be 0 as forward pass doesn't depend on this "parameter" Set excludeParams = new HashSet<>(Arrays.asList("1_mean", "1_var", "1_log10stdev")); boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels, excludeParams); + DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels, null, null, false, -1, excludeParams, null); assertTrue(gradOK); } @@ -417,7 +417,7 @@ public class CuDNNGradientChecks extends BaseDL4JTest { } boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels, null, null, true, 32); + DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels, null, null, true, 32, null, null); assertTrue(gradOK); } @@ -655,9 +655,12 @@ public class CuDNNGradientChecks extends BaseDL4JTest { }; log.info("*** Starting test: " + msg + " ***"); - boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, f, l, null, null, - false, -1, null, c); + boolean gradOK = GradientCheckUtil.checkGradients( + new GradientCheckUtil.MLNConfig().net(mln).epsilon(DEFAULT_EPS) + .maxRelError(DEFAULT_MAX_REL_ERROR).minAbsoluteError(DEFAULT_MIN_ABS_ERROR) + .print(PRINT_RESULTS ? GradientCheckUtil.PrintMode.ZEROS : GradientCheckUtil.PrintMode.FAILURES_ONLY) + .exitOnFirstError(RETURN_ON_FIRST_FAILURE) + .input(f).labels(l).callEachIter(c)); assertTrue(msg, gradOK); TestUtils.testModelSerialization(mln); @@ -691,7 +694,7 @@ public class CuDNNGradientChecks extends BaseDL4JTest { //However, numerical gradient will be 0 as forward pass doesn't depend on this "parameter" Set excludeParams = new HashSet<>(Arrays.asList("1_mean", "1_var", "1_log10stdev")); boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, in, labels, excludeParams); + DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, in, labels, null, null, false, -1, excludeParams, null); assertTrue(gradOK); diff --git a/deeplearning4j/deeplearning4j-graph/pom.xml b/deeplearning4j/deeplearning4j-graph/pom.xml index 9c4b25ac3..ebc6740d9 100644 --- a/deeplearning4j/deeplearning4j-graph/pom.xml +++ b/deeplearning4j/deeplearning4j-graph/pom.xml @@ -51,6 +51,13 @@ test + + org.deeplearning4j + deeplearning4j-common-tests + ${project.version} + test + + diff --git a/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/BaseDL4JTest.java b/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/BaseDL4JTest.java deleted file mode 100644 index b1b6df5dd..000000000 --- a/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/BaseDL4JTest.java +++ /dev/null @@ -1,140 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.graph; - -import lombok.extern.slf4j.Slf4j; -import org.bytedeco.javacpp.Pointer; -import org.junit.After; -import org.junit.Before; -import org.junit.Rule; -import org.junit.rules.TestName; -import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.api.memory.MemoryWorkspace; -import org.nd4j.linalg.api.ops.executioner.OpExecutioner; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.profiler.ProfilerConfig; - -import java.lang.management.ManagementFactory; -import java.util.List; -import java.util.Map; -import java.util.Properties; - -@Slf4j -public class BaseDL4JTest { - - @Rule - public TestName name = new TestName(); - - protected long startTime; - protected int threadCountBefore; - - /** - * Override this to set the profiling mode for the tests defined in the child class - */ - public OpExecutioner.ProfilingMode getProfilingMode(){ - return OpExecutioner.ProfilingMode.SCOPE_PANIC; - } - - /** - * Override this to set the datatype of the tests defined in the child class - */ - public DataType getDataType(){ - return DataType.DOUBLE; - } - - public DataType getDefaultFPDataType(){ - return getDataType(); - } - - @Before - public void beforeTest(){ - log.info("{}.{}", getClass().getSimpleName(), name.getMethodName()); - Nd4j.getExecutioner().setProfilingMode(getProfilingMode()); - Nd4j.getExecutioner().setProfilingConfig(ProfilerConfig.builder().build()); - Nd4j.setDefaultDataTypes(getDataType(), getDefaultFPDataType()); - startTime = System.currentTimeMillis(); - threadCountBefore = ManagementFactory.getThreadMXBean().getThreadCount(); - } - - @After - public void afterTest(){ - //Attempt to keep workspaces isolated between tests - Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread(); - MemoryWorkspace currWS = Nd4j.getMemoryManager().getCurrentWorkspace(); - Nd4j.getMemoryManager().setCurrentWorkspace(null); - if(currWS != null){ - //Not really safe to continue testing under this situation... other tests will likely fail with obscure - // errors that are hard to track back to this - log.error("Open workspace leaked from test! Exiting - {}, isOpen = {} - {}", currWS.getId(), currWS.isScopeActive(), currWS); - System.exit(1); - } - - StringBuilder sb = new StringBuilder(); - long maxPhys = Pointer.maxPhysicalBytes(); - long maxBytes = Pointer.maxBytes(); - long currPhys = Pointer.physicalBytes(); - long currBytes = Pointer.totalBytes(); - - long jvmTotal = Runtime.getRuntime().totalMemory(); - long jvmMax = Runtime.getRuntime().maxMemory(); - - int threadsAfter = ManagementFactory.getThreadMXBean().getThreadCount(); - - long duration = System.currentTimeMillis() - startTime; - sb.append(getClass().getSimpleName()).append(".").append(name.getMethodName()) - .append(": ").append(duration).append(" ms") - .append(", threadCount: (").append(threadCountBefore).append("->").append(threadsAfter).append(")") - .append(", jvmTotal=").append(jvmTotal) - .append(", jvmMax=").append(jvmMax) - .append(", totalBytes=").append(currBytes).append(", maxBytes=").append(maxBytes) - .append(", currPhys=").append(currPhys).append(", maxPhys=").append(maxPhys); - - List ws = Nd4j.getWorkspaceManager().getAllWorkspacesForCurrentThread(); - if(ws != null && ws.size() > 0){ - long currSize = 0; - for(MemoryWorkspace w : ws){ - currSize += w.getCurrentSize(); - } - if(currSize > 0){ - sb.append(", threadWSSize=").append(currSize) - .append(" (").append(ws.size()).append(" WSs)"); - } - } - - - Properties p = Nd4j.getExecutioner().getEnvironmentInformation(); - Object o = p.get("cuda.devicesInformation"); - if(o instanceof List){ - List> l = (List>) o; - if(l.size() > 0) { - - sb.append(" [").append(l.size()) - .append(" GPUs: "); - - for (int i = 0; i < l.size(); i++) { - Map m = l.get(i); - if(i > 0) - sb.append(","); - sb.append("(").append(m.get("cuda.freeMemory")).append(" free, ") - .append(m.get("cuda.totalMemory")).append(" total)"); - } - sb.append("]"); - } - } - log.info(sb.toString()); - } -} diff --git a/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/data/TestGraphLoading.java b/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/data/TestGraphLoading.java index 551750f7c..1a5a27918 100644 --- a/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/data/TestGraphLoading.java +++ b/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/data/TestGraphLoading.java @@ -17,7 +17,7 @@ package org.deeplearning4j.graph.data; import org.apache.commons.lang3.ArrayUtils; -import org.deeplearning4j.graph.BaseDL4JTest; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.graph.api.Edge; import org.deeplearning4j.graph.api.IGraph; import org.deeplearning4j.graph.data.impl.DelimitedEdgeLineProcessor; diff --git a/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/data/TestGraphLoadingWeighted.java b/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/data/TestGraphLoadingWeighted.java index a06f40248..94e1a20bf 100644 --- a/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/data/TestGraphLoadingWeighted.java +++ b/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/data/TestGraphLoadingWeighted.java @@ -17,7 +17,7 @@ package org.deeplearning4j.graph.data; import org.apache.commons.lang3.ArrayUtils; -import org.deeplearning4j.graph.BaseDL4JTest; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.graph.api.Edge; import org.deeplearning4j.graph.api.IGraph; import org.deeplearning4j.graph.data.impl.WeightedEdgeLineProcessor; diff --git a/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/graph/TestGraph.java b/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/graph/TestGraph.java index 74c7f7dc2..0dc456107 100644 --- a/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/graph/TestGraph.java +++ b/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/graph/TestGraph.java @@ -17,7 +17,7 @@ package org.deeplearning4j.graph.graph; import org.apache.commons.lang3.ArrayUtils; -import org.deeplearning4j.graph.BaseDL4JTest; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.graph.api.*; import org.deeplearning4j.graph.data.GraphLoader; import org.deeplearning4j.graph.iterator.RandomWalkIterator; diff --git a/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/models/deepwalk/DeepWalkGradientCheck.java b/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/models/deepwalk/DeepWalkGradientCheck.java index 951a4c50b..39e91921a 100644 --- a/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/models/deepwalk/DeepWalkGradientCheck.java +++ b/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/models/deepwalk/DeepWalkGradientCheck.java @@ -16,7 +16,7 @@ package org.deeplearning4j.graph.models.deepwalk; -import org.deeplearning4j.graph.BaseDL4JTest; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.graph.data.GraphLoader; import org.deeplearning4j.graph.graph.Graph; import org.deeplearning4j.graph.iterator.GraphWalkIterator; diff --git a/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/models/deepwalk/TestDeepWalk.java b/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/models/deepwalk/TestDeepWalk.java index 82d94fc46..d92c3bec1 100644 --- a/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/models/deepwalk/TestDeepWalk.java +++ b/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/models/deepwalk/TestDeepWalk.java @@ -17,7 +17,7 @@ package org.deeplearning4j.graph.models.deepwalk; import org.apache.commons.io.FilenameUtils; -import org.deeplearning4j.graph.BaseDL4JTest; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.graph.api.Edge; import org.deeplearning4j.graph.api.IGraph; import org.deeplearning4j.graph.data.GraphLoader; diff --git a/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/models/deepwalk/TestGraphHuffman.java b/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/models/deepwalk/TestGraphHuffman.java index 5651eec2e..763aae822 100644 --- a/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/models/deepwalk/TestGraphHuffman.java +++ b/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/models/deepwalk/TestGraphHuffman.java @@ -16,7 +16,7 @@ package org.deeplearning4j.graph.models.deepwalk; -import org.deeplearning4j.graph.BaseDL4JTest; +import org.deeplearning4j.BaseDL4JTest; import org.junit.Test; import java.util.Arrays; diff --git a/deeplearning4j/deeplearning4j-manifold/deeplearning4j-tsne/pom.xml b/deeplearning4j/deeplearning4j-manifold/deeplearning4j-tsne/pom.xml index 37002b5e1..7ebb82e75 100644 --- a/deeplearning4j/deeplearning4j-manifold/deeplearning4j-tsne/pom.xml +++ b/deeplearning4j/deeplearning4j-manifold/deeplearning4j-tsne/pom.xml @@ -55,6 +55,13 @@ nd4j-api ${nd4j.version} + + + org.deeplearning4j + deeplearning4j-common-tests + ${project.version} + test + diff --git a/deeplearning4j/deeplearning4j-manifold/deeplearning4j-tsne/src/test/java/org/deeplearning4j/plot/Test6058.java b/deeplearning4j/deeplearning4j-manifold/deeplearning4j-tsne/src/test/java/org/deeplearning4j/plot/Test6058.java index af544b36e..3359e729f 100644 --- a/deeplearning4j/deeplearning4j-manifold/deeplearning4j-tsne/src/test/java/org/deeplearning4j/plot/Test6058.java +++ b/deeplearning4j/deeplearning4j-manifold/deeplearning4j-tsne/src/test/java/org/deeplearning4j/plot/Test6058.java @@ -17,6 +17,7 @@ package org.deeplearning4j.plot; import lombok.val; +import org.deeplearning4j.BaseDL4JTest; import org.junit.Test; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -25,7 +26,7 @@ import java.util.ArrayList; import static org.junit.Assert.assertTrue; -public class Test6058 { +public class Test6058 extends BaseDL4JTest { @Test public void test() throws Exception { diff --git a/deeplearning4j/deeplearning4j-modelimport/pom.xml b/deeplearning4j/deeplearning4j-modelimport/pom.xml index 223aebdaa..566bf6012 100644 --- a/deeplearning4j/deeplearning4j-modelimport/pom.xml +++ b/deeplearning4j/deeplearning4j-modelimport/pom.xml @@ -86,6 +86,12 @@ junit test + + org.deeplearning4j + deeplearning4j-common-tests + ${project.version} + test + ch.qos.logback diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/config/KerasLayerConfiguration.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/config/KerasLayerConfiguration.java index 6d6fc42c9..7841fdf27 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/config/KerasLayerConfiguration.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/config/KerasLayerConfiguration.java @@ -108,6 +108,9 @@ public class KerasLayerConfiguration { private final String LAYER_CLASS_NAME_LEAKY_RELU = "LeakyReLU"; private final String LAYER_CLASS_NAME_PRELU = "PReLU"; private final String LAYER_CLASS_NAME_THRESHOLDED_RELU = "ThresholdedReLU"; + private final String LAYER_CLASS_NAME_RELU = "ReLU"; + private final String LAYER_CLASS_NAME_ELU = "ELU"; + private final String LAYER_CLASS_NAME_SOFTMAX = "Softmax"; private final String LAYER_CLASS_NAME_UPSAMPLING_1D = "UpSampling1D"; private final String LAYER_CLASS_NAME_UPSAMPLING_2D = "UpSampling2D"; private final String LAYER_CLASS_NAME_UPSAMPLING_3D = "UpSampling3D"; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/advanced/activations/KerasELU.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/advanced/activations/KerasELU.java new file mode 100644 index 000000000..2517ae0ac --- /dev/null +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/advanced/activations/KerasELU.java @@ -0,0 +1,95 @@ +/* ****************************************************************************** + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.deeplearning4j.nn.modelimport.keras.layers.advanced.activations; + +import org.deeplearning4j.nn.conf.inputs.InputType; +import org.deeplearning4j.nn.conf.layers.ActivationLayer; +import org.deeplearning4j.nn.modelimport.keras.KerasLayer; +import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; +import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException; +import org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils; +import org.nd4j.linalg.activations.IActivation; +import org.nd4j.linalg.activations.impl.ActivationELU; +import org.nd4j.linalg.activations.impl.ActivationLReLU; + +import java.util.Map; + +/** + * Imports ELU layer from Keras + * + * @author Alex Black + */ +public class KerasELU extends KerasLayer { + + + /** + * Constructor from parsed Keras layer configuration dictionary. + * + * @param layerConfig dictionary containing Keras layer configuration + * @throws InvalidKerasConfigurationException Invalid Keras config + * @throws UnsupportedKerasConfigurationException Unsupported Invalid Keras config + */ + public KerasELU(Map layerConfig) + throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException { + this(layerConfig, true); + } + + /** + * Constructor from parsed Keras layer configuration dictionary. + * + * @param layerConfig dictionary containing Keras layer configuration + * @param enforceTrainingConfig whether to enforce training-related configuration options + * @throws InvalidKerasConfigurationException Invalid Keras config + * @throws UnsupportedKerasConfigurationException Invalid Keras config + */ + public KerasELU(Map layerConfig, boolean enforceTrainingConfig) + throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException { + super(layerConfig, enforceTrainingConfig); + Map innerConfig = KerasLayerUtils.getInnerLayerConfigFromConfig(layerConfig, conf); + double alpha = 1.0; // Set default alpha to default in nd4j + String layerFieldLeakyReluAlpha = "alpha"; + if (innerConfig.containsKey(layerFieldLeakyReluAlpha)) { + alpha = (double) innerConfig.get(layerFieldLeakyReluAlpha); + } + IActivation leakyReLU = new ActivationELU(alpha); + this.layer = new ActivationLayer.Builder().name(this.layerName).activation(leakyReLU).build(); + } + + /** + * Get layer output type. + * + * @param inputType Array of InputTypes + * @return output type as InputType + * @throws InvalidKerasConfigurationException Invalid Keras config + */ + public InputType getOutputType(InputType... inputType) throws InvalidKerasConfigurationException { + if (inputType.length > 1) + throw new InvalidKerasConfigurationException( + "Keras Activation layer accepts only one input (received " + inputType.length + ")"); + return this.getActivationLayer().getOutputType(-1, inputType[0]); + } + + /** + * Get DL4J ActivationLayer. + * + * @return ActivationLayer + */ + public ActivationLayer getActivationLayer() { + return (ActivationLayer) this.layer; + } + +} diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/advanced/activations/KerasReLU.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/advanced/activations/KerasReLU.java new file mode 100644 index 000000000..14c4b3d73 --- /dev/null +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/advanced/activations/KerasReLU.java @@ -0,0 +1,99 @@ +/* ****************************************************************************** + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.deeplearning4j.nn.modelimport.keras.layers.advanced.activations; + +import org.deeplearning4j.nn.conf.inputs.InputType; +import org.deeplearning4j.nn.conf.layers.ActivationLayer; +import org.deeplearning4j.nn.modelimport.keras.KerasLayer; +import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; +import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException; +import org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils; +import org.nd4j.linalg.activations.IActivation; +import org.nd4j.linalg.activations.impl.ActivationLReLU; +import org.nd4j.linalg.activations.impl.ActivationReLU; + +import java.util.Map; + +/** + * Imports ReLU layer from Keras + * + * @author Alex Black + */ +public class KerasReLU extends KerasLayer { + + /** + * Constructor from parsed Keras layer configuration dictionary. + * + * @param layerConfig dictionary containing Keras layer configuration + * @throws InvalidKerasConfigurationException Invalid Keras config + * @throws UnsupportedKerasConfigurationException Unsupported Invalid Keras config + */ + public KerasReLU(Map layerConfig) + throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException { + this(layerConfig, true); + } + + /** + * Constructor from parsed Keras layer configuration dictionary. + * + * @param layerConfig dictionary containing Keras layer configuration + * @param enforceTrainingConfig whether to enforce training-related configuration options + * @throws InvalidKerasConfigurationException Invalid Keras config + * @throws UnsupportedKerasConfigurationException Invalid Keras config + */ + public KerasReLU(Map layerConfig, boolean enforceTrainingConfig) + throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException { + super(layerConfig, enforceTrainingConfig); + Map innerConfig = KerasLayerUtils.getInnerLayerConfigFromConfig(layerConfig, conf); + Double maxValue = (Double) innerConfig.get("max_value"); + double negativeSlope = 0.0; + double threshold = 0.0; + if (innerConfig.containsKey("negative_slope")) { + negativeSlope = (double) innerConfig.get("negative_slope"); + } + if (innerConfig.containsKey("threshold")) { + threshold = (double) innerConfig.get("threshold"); + } + + this.layer = new ActivationLayer.Builder().name(this.layerName) + .activation(new ActivationReLU(maxValue, threshold, negativeSlope)).build(); + } + + /** + * Get layer output type. + * + * @param inputType Array of InputTypes + * @return output type as InputType + * @throws InvalidKerasConfigurationException Invalid Keras config + */ + public InputType getOutputType(InputType... inputType) throws InvalidKerasConfigurationException { + if (inputType.length > 1) + throw new InvalidKerasConfigurationException( + "Keras Activation layer accepts only one input (received " + inputType.length + ")"); + return this.getActivationLayer().getOutputType(-1, inputType[0]); + } + + /** + * Get DL4J ActivationLayer. + * + * @return ActivationLayer + */ + public ActivationLayer getActivationLayer() { + return (ActivationLayer) this.layer; + } + +} diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/advanced/activations/KerasSoftmax.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/advanced/activations/KerasSoftmax.java new file mode 100644 index 000000000..884c55ef1 --- /dev/null +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/advanced/activations/KerasSoftmax.java @@ -0,0 +1,85 @@ +/* ****************************************************************************** + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.deeplearning4j.nn.modelimport.keras.layers.advanced.activations; + +import org.deeplearning4j.nn.conf.inputs.InputType; +import org.deeplearning4j.nn.conf.layers.ActivationLayer; +import org.deeplearning4j.nn.modelimport.keras.KerasLayer; +import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; +import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException; +import org.nd4j.linalg.activations.impl.ActivationSoftmax; + +import java.util.Map; + +/** + * Imports Softmax layer from Keras + * + * @author Alex Black + */ +public class KerasSoftmax extends KerasLayer { + + /** + * Constructor from parsed Keras layer configuration dictionary. + * + * @param layerConfig dictionary containing Keras layer configuration + * @throws InvalidKerasConfigurationException Invalid Keras config + * @throws UnsupportedKerasConfigurationException Unsupported Invalid Keras config + */ + public KerasSoftmax(Map layerConfig) + throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException { + this(layerConfig, true); + } + + /** + * Constructor from parsed Keras layer configuration dictionary. + * + * @param layerConfig dictionary containing Keras layer configuration + * @param enforceTrainingConfig whether to enforce training-related configuration options + * @throws InvalidKerasConfigurationException Invalid Keras config + * @throws UnsupportedKerasConfigurationException Invalid Keras config + */ + public KerasSoftmax(Map layerConfig, boolean enforceTrainingConfig) + throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException { + super(layerConfig, enforceTrainingConfig); + + this.layer = new ActivationLayer.Builder().name(this.layerName).activation(new ActivationSoftmax()).build(); + } + + /** + * Get layer output type. + * + * @param inputType Array of InputTypes + * @return output type as InputType + * @throws InvalidKerasConfigurationException Invalid Keras config + */ + public InputType getOutputType(InputType... inputType) throws InvalidKerasConfigurationException { + if (inputType.length > 1) + throw new InvalidKerasConfigurationException( + "Keras Activation layer accepts only one input (received " + inputType.length + ")"); + return this.getActivationLayer().getOutputType(-1, inputType[0]); + } + + /** + * Get DL4J ActivationLayer. + * + * @return ActivationLayer + */ + public ActivationLayer getActivationLayer() { + return (ActivationLayer) this.layer; + } + +} diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/KerasLayerUtils.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/KerasLayerUtils.java index 3494ecf49..1428b6322 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/KerasLayerUtils.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/KerasLayerUtils.java @@ -25,9 +25,7 @@ import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.layers.KerasInput; -import org.deeplearning4j.nn.modelimport.keras.layers.advanced.activations.KerasLeakyReLU; -import org.deeplearning4j.nn.modelimport.keras.layers.advanced.activations.KerasPReLU; -import org.deeplearning4j.nn.modelimport.keras.layers.advanced.activations.KerasThresholdedReLU; +import org.deeplearning4j.nn.modelimport.keras.layers.advanced.activations.*; import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.*; import org.deeplearning4j.nn.modelimport.keras.layers.core.*; import org.deeplearning4j.nn.modelimport.keras.layers.embeddings.KerasEmbedding; @@ -313,6 +311,12 @@ public class KerasLayerUtils { if (lambdaLayer != null){ layer = new KerasLambda(layerConfig, enforceTrainingConfig, lambdaLayer); } + } else if(layerClassName.equals(conf.getLAYER_CLASS_NAME_RELU())){ + layer = new KerasReLU(layerConfig, enforceTrainingConfig); + } else if(layerClassName.equals(conf.getLAYER_CLASS_NAME_ELU())){ + layer = new KerasELU(layerConfig, enforceTrainingConfig); + } else if(layerClassName.equals(conf.getLAYER_CLASS_NAME_SOFTMAX())){ + layer = new KerasSoftmax(layerConfig, enforceTrainingConfig); } if (layer == null){ Class customConfig = customLayers.get(layerClassName); diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/BaseDL4JTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/BaseDL4JTest.java deleted file mode 100644 index d7ae7e2ca..000000000 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/BaseDL4JTest.java +++ /dev/null @@ -1,140 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.nn.modelimport.keras; - -import lombok.extern.slf4j.Slf4j; -import org.bytedeco.javacpp.Pointer; -import org.junit.After; -import org.junit.Before; -import org.junit.Rule; -import org.junit.rules.TestName; -import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.api.memory.MemoryWorkspace; -import org.nd4j.linalg.api.ops.executioner.OpExecutioner; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.profiler.ProfilerConfig; - -import java.lang.management.ManagementFactory; -import java.util.List; -import java.util.Map; -import java.util.Properties; - -@Slf4j -public class BaseDL4JTest { - - @Rule - public TestName name = new TestName(); - - protected long startTime; - protected int threadCountBefore; - - /** - * Override this to set the profiling mode for the tests defined in the child class - */ - public OpExecutioner.ProfilingMode getProfilingMode(){ - return OpExecutioner.ProfilingMode.SCOPE_PANIC; - } - - /** - * Override this to set the datatype of the tests defined in the child class - */ - public DataType getDataType(){ - return DataType.DOUBLE; - } - - public DataType getDefaultFPDataType(){ - return getDataType(); - } - - @Before - public void beforeTest(){ - log.info("{}.{}", getClass().getSimpleName(), name.getMethodName()); - Nd4j.getExecutioner().setProfilingMode(getProfilingMode()); - Nd4j.getExecutioner().setProfilingConfig(ProfilerConfig.builder().build()); - Nd4j.setDefaultDataTypes(getDataType(), getDefaultFPDataType()); - startTime = System.currentTimeMillis(); - threadCountBefore = ManagementFactory.getThreadMXBean().getThreadCount(); - } - - @After - public void afterTest(){ - //Attempt to keep workspaces isolated between tests - Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread(); - MemoryWorkspace currWS = Nd4j.getMemoryManager().getCurrentWorkspace(); - Nd4j.getMemoryManager().setCurrentWorkspace(null); - if(currWS != null){ - //Not really safe to continue testing under this situation... other tests will likely fail with obscure - // errors that are hard to track back to this - log.error("Open workspace leaked from test! Exiting - {}, isOpen = {} - {}", currWS.getId(), currWS.isScopeActive(), currWS); - System.exit(1); - } - - StringBuilder sb = new StringBuilder(); - long maxPhys = Pointer.maxPhysicalBytes(); - long maxBytes = Pointer.maxBytes(); - long currPhys = Pointer.physicalBytes(); - long currBytes = Pointer.totalBytes(); - - long jvmTotal = Runtime.getRuntime().totalMemory(); - long jvmMax = Runtime.getRuntime().maxMemory(); - - int threadsAfter = ManagementFactory.getThreadMXBean().getThreadCount(); - - long duration = System.currentTimeMillis() - startTime; - sb.append(getClass().getSimpleName()).append(".").append(name.getMethodName()) - .append(": ").append(duration).append(" ms") - .append(", threadCount: (").append(threadCountBefore).append("->").append(threadsAfter).append(")") - .append(", jvmTotal=").append(jvmTotal) - .append(", jvmMax=").append(jvmMax) - .append(", totalBytes=").append(currBytes).append(", maxBytes=").append(maxBytes) - .append(", currPhys=").append(currPhys).append(", maxPhys=").append(maxPhys); - - List ws = Nd4j.getWorkspaceManager().getAllWorkspacesForCurrentThread(); - if(ws != null && ws.size() > 0){ - long currSize = 0; - for(MemoryWorkspace w : ws){ - currSize += w.getCurrentSize(); - } - if(currSize > 0){ - sb.append(", threadWSSize=").append(currSize) - .append(" (").append(ws.size()).append(" WSs)"); - } - } - - - Properties p = Nd4j.getExecutioner().getEnvironmentInformation(); - Object o = p.get("cuda.devicesInformation"); - if(o instanceof List){ - List> l = (List>) o; - if(l.size() > 0) { - - sb.append(" [").append(l.size()) - .append(" GPUs: "); - - for (int i = 0; i < l.size(); i++) { - Map m = l.get(i); - if(i > 0) - sb.append(","); - sb.append("(").append(m.get("cuda.freeMemory")).append(" free, ") - .append(m.get("cuda.totalMemory")).append(" total)"); - } - sb.append("]"); - } - } - log.info(sb.toString()); - } -} diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/MiscTests.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/MiscTests.java index dcfd53518..92c55c891 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/MiscTests.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/MiscTests.java @@ -17,6 +17,7 @@ package org.deeplearning4j.nn.modelimport.keras; import org.apache.commons.io.FileUtils; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.utils.DL4JKerasModelValidator; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.junit.Rule; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/FullModelComparisons.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/FullModelComparisons.java index 6043d7d48..ee5129b84 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/FullModelComparisons.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/FullModelComparisons.java @@ -23,7 +23,7 @@ import org.datavec.api.split.NumberedFileInputSplit; import org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator; import org.deeplearning4j.nn.layers.recurrent.LSTM; import org.deeplearning4j.nn.layers.recurrent.LastTimeStepLayer; -import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.KerasModel; import org.deeplearning4j.nn.modelimport.keras.KerasSequentialModel; import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; @@ -33,6 +33,7 @@ import org.junit.Ignore; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; +import org.junit.rules.Timeout; import org.nd4j.linalg.activations.impl.ActivationHardSigmoid; import org.nd4j.linalg.activations.impl.ActivationTanH; import org.nd4j.linalg.api.ndarray.INDArray; @@ -60,6 +61,9 @@ public class FullModelComparisons extends BaseDL4JTest { @Rule public TemporaryFolder testDir = new TemporaryFolder(); + @Rule + public Timeout timeout = Timeout.seconds(300); + @Test public void lstmTest() throws IOException, UnsupportedKerasConfigurationException, InvalidKerasConfigurationException, InterruptedException { diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/JsonTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/JsonTest.java index 20b80d30f..4aae27af3 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/JsonTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/JsonTest.java @@ -18,7 +18,7 @@ package org.deeplearning4j.nn.modelimport.keras.configurations; import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; -import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.preprocessors.KerasFlattenRnnPreprocessor; import org.deeplearning4j.nn.modelimport.keras.preprocessors.PermutePreprocessor; import org.deeplearning4j.nn.modelimport.keras.preprocessors.ReshapePreprocessor; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/Keras1ModelConfigurationTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/Keras1ModelConfigurationTest.java index 554a2c2d1..d4b8e453a 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/Keras1ModelConfigurationTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/Keras1ModelConfigurationTest.java @@ -20,7 +20,7 @@ import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.graph.ComputationGraph; -import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.KerasModel; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.junit.Test; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/Keras2ModelConfigurationTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/Keras2ModelConfigurationTest.java index d776ed63e..4d4bf067e 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/Keras2ModelConfigurationTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/Keras2ModelConfigurationTest.java @@ -21,12 +21,13 @@ import lombok.val; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.graph.ComputationGraph; -import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.KerasLayer; import org.deeplearning4j.nn.modelimport.keras.KerasModel; import org.deeplearning4j.nn.modelimport.keras.KerasModelImport; import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasSpaceToDepth; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; +import org.junit.Ignore; import org.junit.Test; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -255,10 +256,8 @@ public class Keras2ModelConfigurationTest extends BaseDL4JTest { } } - @Test + @Test @Ignore("AB 2019/11/23 - known issue - see https://github.com/eclipse/deeplearning4j/issues/8373 and https://github.com/eclipse/deeplearning4j/issues/8441") public void ReshapeEmbeddingConcatTest() throws Exception{ - //TODO AB 2019/11/23 - known issue - see https://github.com/eclipse/deeplearning4j/issues/8373 and https://github.com/eclipse/deeplearning4j/issues/8441 - try(InputStream is = Resources.asStream("/modelimport/keras/configs/keras2/reshape_embedding_concat.json")) { ComputationGraphConfiguration config = new KerasModel().modelBuilder().modelJsonInputStream(is) diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/KerasInitilizationTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/KerasInitilizationTest.java index 8ac231e12..583634264 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/KerasInitilizationTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/KerasInitilizationTest.java @@ -18,7 +18,7 @@ package org.deeplearning4j.nn.modelimport.keras.configurations; import org.deeplearning4j.nn.conf.distribution.*; import org.deeplearning4j.nn.conf.layers.DenseLayer; -import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/KerasModelImportTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/KerasModelImportTest.java index b5d3c9ab6..cf51831a2 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/KerasModelImportTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/KerasModelImportTest.java @@ -17,7 +17,7 @@ package org.deeplearning4j.nn.modelimport.keras.configurations; import lombok.extern.slf4j.Slf4j; -import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.KerasModelImport; import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasCustomLayerTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasCustomLayerTest.java index c14377b31..cdf5faca3 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasCustomLayerTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasCustomLayerTest.java @@ -20,7 +20,7 @@ import lombok.extern.slf4j.Slf4j; import org.apache.commons.io.FileUtils; import org.deeplearning4j.common.resources.DL4JResources; import org.deeplearning4j.nn.graph.ComputationGraph; -import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.KerasLayer; import org.deeplearning4j.nn.modelimport.keras.KerasModelImport; import org.deeplearning4j.nn.modelimport.keras.layers.custom.KerasLRN; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasLambdaTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasLambdaTest.java index 97ae4318f..1d55a5d2c 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasLambdaTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasLambdaTest.java @@ -19,7 +19,7 @@ package org.deeplearning4j.nn.modelimport.keras.e2e; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.samediff.SameDiffLambdaLayer; import org.deeplearning4j.nn.graph.ComputationGraph; -import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.KerasLayer; import org.deeplearning4j.nn.modelimport.keras.KerasModel; import org.deeplearning4j.nn.modelimport.keras.KerasSequentialModel; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasModelEndToEndTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasModelEndToEndTest.java index d4f458a39..b17c215cb 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasModelEndToEndTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasModelEndToEndTest.java @@ -30,7 +30,7 @@ import org.deeplearning4j.nn.conf.layers.FeedForwardLayer; import org.deeplearning4j.nn.conf.layers.LossLayer; import org.deeplearning4j.nn.conf.layers.RnnOutputLayer; import org.deeplearning4j.nn.graph.ComputationGraph; -import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.Hdf5Archive; import org.deeplearning4j.nn.modelimport.keras.KerasModel; import org.deeplearning4j.nn.modelimport.keras.KerasSequentialModel; @@ -724,6 +724,29 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { } } + @Test + public void testActivationLayers() throws Exception { + String[] names = new String[]{ + "ELU_0_model.h5", + "LeakyReLU_0_model.h5", + "ReLU_0_model.h5", + "ReLU_1_model.h5", + "ReLU_2_model.h5", + "ReLU_3_model.h5", + "Softmax_0_model.h5", + "ThresholdReLU_0_model.h5", + }; + + for(String name : names ){ + System.out.println("Starting test: " + name); + String modelPath = "modelimport/keras/examples/activations/" + name; + String inputsOutputPath = "modelimport/keras/examples/activations/" + (name.substring(0,name.length()-"model.h5".length()) + "inputs_and_outputs.h5"); + + importEndModelTest(modelPath, inputsOutputPath, true, true, + true, true, false, null, null); + } + } + private ComputationGraph importFunctionalModelH5Test(String modelPath) throws Exception { return importFunctionalModelH5Test(modelPath, null, false); } @@ -991,8 +1014,8 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { } Nd4j.setDataType(DataType.DOUBLE); - boolean passed = GradientCheckUtil.checkGradients(netToTest, eps, max_rel_error, min_abs_error, true, false, - input, labels, null, null, true, 9); + boolean passed = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(netToTest).input(input) + .labels(labels).subset(true).maxPerParam(9)); assertTrue("Gradient check failed", passed); } diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasYolo9000PredictTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasYolo9000PredictTest.java index 8bd6e779d..1da4bf5cc 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasYolo9000PredictTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasYolo9000PredictTest.java @@ -19,7 +19,7 @@ package org.deeplearning4j.nn.modelimport.keras.e2e; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.graph.ComputationGraph; -import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.KerasLayer; import org.deeplearning4j.nn.modelimport.keras.KerasModelImport; import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasSpaceToDepth; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasYolo9000Test.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasYolo9000Test.java index dcfe7bfda..f428a0dbd 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasYolo9000Test.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasYolo9000Test.java @@ -18,7 +18,7 @@ package org.deeplearning4j.nn.modelimport.keras.e2e; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.nn.graph.ComputationGraph; -import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.KerasLayer; import org.deeplearning4j.nn.modelimport.keras.KerasModel; import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasSpaceToDepth; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/advanced/activation/KerasLeakyReLUTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/advanced/activation/KerasLeakyReLUTest.java index 7770e7816..e16da1bfa 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/advanced/activation/KerasLeakyReLUTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/advanced/activation/KerasLeakyReLUTest.java @@ -17,7 +17,7 @@ package org.deeplearning4j.nn.modelimport.keras.layers.advanced.activation; import org.deeplearning4j.nn.conf.layers.ActivationLayer; -import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/advanced/activation/KerasPReLUTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/advanced/activation/KerasPReLUTest.java index b22963e2a..ee7c0ab48 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/advanced/activation/KerasPReLUTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/advanced/activation/KerasPReLUTest.java @@ -18,7 +18,7 @@ package org.deeplearning4j.nn.modelimport.keras.layers.advanced.activation; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.PReLULayer; -import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/advanced/activation/KerasThresholdedReLUTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/advanced/activation/KerasThresholdedReLUTest.java index 02bf24e1d..1b24c1ff2 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/advanced/activation/KerasThresholdedReLUTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/advanced/activation/KerasThresholdedReLUTest.java @@ -17,7 +17,7 @@ package org.deeplearning4j.nn.modelimport.keras.layers.advanced.activation; import org.deeplearning4j.nn.conf.layers.ActivationLayer; -import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasAtrousConvolution1DTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasAtrousConvolution1DTest.java index eccaeb536..1b3c98f60 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasAtrousConvolution1DTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasAtrousConvolution1DTest.java @@ -19,7 +19,7 @@ package org.deeplearning4j.nn.modelimport.keras.layers.convolution; import org.deeplearning4j.nn.conf.ConvolutionMode; import org.deeplearning4j.nn.conf.dropout.Dropout; import org.deeplearning4j.nn.conf.layers.Convolution1DLayer; -import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.KerasTestUtils; import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasAtrousConvolution2DTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasAtrousConvolution2DTest.java index 2a01a1d8b..411127bd7 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasAtrousConvolution2DTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasAtrousConvolution2DTest.java @@ -19,7 +19,7 @@ package org.deeplearning4j.nn.modelimport.keras.layers.convolution; import org.deeplearning4j.nn.conf.ConvolutionMode; import org.deeplearning4j.nn.conf.dropout.Dropout; import org.deeplearning4j.nn.conf.layers.ConvolutionLayer; -import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.KerasTestUtils; import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasConvolution1DTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasConvolution1DTest.java index 449dc10cc..f12d66f56 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasConvolution1DTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasConvolution1DTest.java @@ -19,7 +19,7 @@ package org.deeplearning4j.nn.modelimport.keras.layers.convolution; import org.deeplearning4j.nn.conf.ConvolutionMode; import org.deeplearning4j.nn.conf.dropout.Dropout; import org.deeplearning4j.nn.conf.layers.Convolution1DLayer; -import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.KerasTestUtils; import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasConvolution2DTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasConvolution2DTest.java index bd0d6e012..8f61d7038 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasConvolution2DTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasConvolution2DTest.java @@ -19,7 +19,7 @@ package org.deeplearning4j.nn.modelimport.keras.layers.convolution; import org.deeplearning4j.nn.conf.ConvolutionMode; import org.deeplearning4j.nn.conf.dropout.Dropout; import org.deeplearning4j.nn.conf.layers.ConvolutionLayer; -import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.KerasTestUtils; import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasConvolution3DTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasConvolution3DTest.java index ff0ba8f3d..11177c5dd 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasConvolution3DTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasConvolution3DTest.java @@ -19,7 +19,7 @@ package org.deeplearning4j.nn.modelimport.keras.layers.convolution; import org.deeplearning4j.nn.conf.ConvolutionMode; import org.deeplearning4j.nn.conf.dropout.Dropout; import org.deeplearning4j.nn.conf.layers.ConvolutionLayer; -import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.KerasTestUtils; import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasCropping1DTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasCropping1DTest.java index 1676f6136..86fb5591b 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasCropping1DTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasCropping1DTest.java @@ -17,7 +17,7 @@ package org.deeplearning4j.nn.modelimport.keras.layers.convolution; import org.deeplearning4j.nn.conf.layers.convolutional.Cropping1D; -import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasCropping2DTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasCropping2DTest.java index 95f5d7485..88226704e 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasCropping2DTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasCropping2DTest.java @@ -17,7 +17,7 @@ package org.deeplearning4j.nn.modelimport.keras.layers.convolution; import org.deeplearning4j.nn.conf.layers.convolutional.Cropping2D; -import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasCropping3DTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasCropping3DTest.java index 6ae3065b6..392195e0e 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasCropping3DTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasCropping3DTest.java @@ -17,7 +17,7 @@ package org.deeplearning4j.nn.modelimport.keras.layers.convolution; import org.deeplearning4j.nn.conf.layers.convolutional.Cropping3D; -import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasDeconvolution2DTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasDeconvolution2DTest.java index 3675d46a8..177e2e717 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasDeconvolution2DTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasDeconvolution2DTest.java @@ -19,7 +19,7 @@ package org.deeplearning4j.nn.modelimport.keras.layers.convolution; import org.deeplearning4j.nn.conf.ConvolutionMode; import org.deeplearning4j.nn.conf.dropout.Dropout; import org.deeplearning4j.nn.conf.layers.Deconvolution2D; -import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.KerasTestUtils; import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasDepthwiseConvolution2DTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasDepthwiseConvolution2DTest.java index 364c50e72..6b173ba8e 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasDepthwiseConvolution2DTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasDepthwiseConvolution2DTest.java @@ -19,7 +19,7 @@ package org.deeplearning4j.nn.modelimport.keras.layers.convolution; import org.deeplearning4j.nn.conf.ConvolutionMode; import org.deeplearning4j.nn.conf.dropout.Dropout; import org.deeplearning4j.nn.conf.layers.DepthwiseConvolution2D; -import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.KerasLayer; import org.deeplearning4j.nn.modelimport.keras.KerasTestUtils; import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasSeparableConvolution2DTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasSeparableConvolution2DTest.java index 7d05a1b67..f8ec7e163 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasSeparableConvolution2DTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasSeparableConvolution2DTest.java @@ -19,7 +19,7 @@ package org.deeplearning4j.nn.modelimport.keras.layers.convolution; import org.deeplearning4j.nn.conf.ConvolutionMode; import org.deeplearning4j.nn.conf.dropout.Dropout; import org.deeplearning4j.nn.conf.layers.SeparableConvolution2D; -import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.KerasTestUtils; import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasUpsampling1DTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasUpsampling1DTest.java index aec4278e2..c93a0fe32 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasUpsampling1DTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasUpsampling1DTest.java @@ -17,7 +17,7 @@ package org.deeplearning4j.nn.modelimport.keras.layers.convolution; import org.deeplearning4j.nn.conf.layers.Upsampling1D; -import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasUpsampling2DTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasUpsampling2DTest.java index cea117f8f..8033f24e7 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasUpsampling2DTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasUpsampling2DTest.java @@ -17,7 +17,7 @@ package org.deeplearning4j.nn.modelimport.keras.layers.convolution; import org.deeplearning4j.nn.conf.layers.Upsampling2D; -import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasUpsampling3DTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasUpsampling3DTest.java index a8e564340..578d92276 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasUpsampling3DTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasUpsampling3DTest.java @@ -17,7 +17,7 @@ package org.deeplearning4j.nn.modelimport.keras.layers.convolution; import org.deeplearning4j.nn.conf.layers.Upsampling3D; -import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasZeroPadding1DTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasZeroPadding1DTest.java index 4cc9cc2cb..aa2d96653 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasZeroPadding1DTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasZeroPadding1DTest.java @@ -17,7 +17,7 @@ package org.deeplearning4j.nn.modelimport.keras.layers.convolution; import org.deeplearning4j.nn.conf.layers.ZeroPadding1DLayer; -import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasZeroPadding2DTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasZeroPadding2DTest.java index 5f72cbcf2..08d6e57a9 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasZeroPadding2DTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasZeroPadding2DTest.java @@ -17,7 +17,7 @@ package org.deeplearning4j.nn.modelimport.keras.layers.convolution; import org.deeplearning4j.nn.conf.layers.ZeroPaddingLayer; -import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasZeroPadding3DTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasZeroPadding3DTest.java index c0a60defd..960a0194e 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasZeroPadding3DTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasZeroPadding3DTest.java @@ -17,7 +17,7 @@ package org.deeplearning4j.nn.modelimport.keras.layers.convolution; import org.deeplearning4j.nn.conf.layers.ZeroPadding3DLayer; -import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasActivationLayer.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasActivationLayer.java index d7f4d8ad9..e19178ea3 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasActivationLayer.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasActivationLayer.java @@ -17,7 +17,7 @@ package org.deeplearning4j.nn.modelimport.keras.layers.core; import org.deeplearning4j.nn.conf.layers.ActivationLayer; -import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasDenseTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasDenseTest.java index cca2515a8..f2ad5c242 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasDenseTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasDenseTest.java @@ -18,7 +18,7 @@ package org.deeplearning4j.nn.modelimport.keras.layers.core; import org.deeplearning4j.nn.conf.dropout.Dropout; import org.deeplearning4j.nn.conf.layers.DenseLayer; -import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.KerasTestUtils; import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasDropoutTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasDropoutTest.java index ff5c49cc5..943a76b26 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasDropoutTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasDropoutTest.java @@ -18,7 +18,7 @@ package org.deeplearning4j.nn.modelimport.keras.layers.core; import org.deeplearning4j.nn.conf.dropout.Dropout; import org.deeplearning4j.nn.conf.layers.DropoutLayer; -import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasMaskingTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasMaskingTest.java index 144d24ab8..76e5c6239 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasMaskingTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasMaskingTest.java @@ -17,7 +17,7 @@ package org.deeplearning4j.nn.modelimport.keras.layers.core; import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer; -import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasPermuteTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasPermuteTest.java index 1f2400426..218e2fc7c 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasPermuteTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasPermuteTest.java @@ -17,7 +17,7 @@ package org.deeplearning4j.nn.modelimport.keras.layers.core; import org.deeplearning4j.nn.conf.inputs.InputType; -import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasRepeatVectorTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasRepeatVectorTest.java index 6c439fb95..2a448bf4c 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasRepeatVectorTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasRepeatVectorTest.java @@ -17,7 +17,7 @@ package org.deeplearning4j.nn.modelimport.keras.layers.core; import org.deeplearning4j.nn.conf.layers.misc.RepeatVector; -import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasReshapeTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasReshapeTest.java index 19d5ce623..7adfa09c7 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasReshapeTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasReshapeTest.java @@ -17,7 +17,7 @@ package org.deeplearning4j.nn.modelimport.keras.layers.core; import org.deeplearning4j.nn.conf.inputs.InputType; -import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasSpatialDropout2DTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasSpatialDropout2DTest.java index 71ec2f468..dc17a4946 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasSpatialDropout2DTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasSpatialDropout2DTest.java @@ -18,7 +18,7 @@ package org.deeplearning4j.nn.modelimport.keras.layers.core; import org.deeplearning4j.nn.conf.dropout.SpatialDropout; import org.deeplearning4j.nn.conf.layers.DropoutLayer; -import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/embeddings/KerasEmbeddingTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/embeddings/KerasEmbeddingTest.java index b171e063f..55274209d 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/embeddings/KerasEmbeddingTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/embeddings/KerasEmbeddingTest.java @@ -17,7 +17,7 @@ package org.deeplearning4j.nn.modelimport.keras.layers.embeddings; import org.deeplearning4j.nn.conf.layers.EmbeddingSequenceLayer; -import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/local/KerasLocallyConnected1DTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/local/KerasLocallyConnected1DTest.java index 428d5d99e..cc91c89bb 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/local/KerasLocallyConnected1DTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/local/KerasLocallyConnected1DTest.java @@ -20,7 +20,7 @@ import org.deeplearning4j.nn.conf.ConvolutionMode; import org.deeplearning4j.nn.conf.dropout.Dropout; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.LocallyConnected1D; -import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.KerasTestUtils; import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/local/KerasLocallyConnected2DTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/local/KerasLocallyConnected2DTest.java index 1ea69e06a..cb05d4597 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/local/KerasLocallyConnected2DTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/local/KerasLocallyConnected2DTest.java @@ -20,7 +20,7 @@ import org.deeplearning4j.nn.conf.ConvolutionMode; import org.deeplearning4j.nn.conf.dropout.Dropout; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.LocallyConnected2D; -import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.KerasTestUtils; import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/noise/KerasAlphaDropoutTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/noise/KerasAlphaDropoutTest.java index 6f34e1684..3a632f2b4 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/noise/KerasAlphaDropoutTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/noise/KerasAlphaDropoutTest.java @@ -18,7 +18,7 @@ package org.deeplearning4j.nn.modelimport.keras.layers.noise; import org.deeplearning4j.nn.conf.dropout.AlphaDropout; import org.deeplearning4j.nn.conf.layers.DropoutLayer; -import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/noise/KerasGaussianDropoutTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/noise/KerasGaussianDropoutTest.java index 7ca51b37d..b759dd370 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/noise/KerasGaussianDropoutTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/noise/KerasGaussianDropoutTest.java @@ -18,7 +18,7 @@ package org.deeplearning4j.nn.modelimport.keras.layers.noise; import org.deeplearning4j.nn.conf.dropout.GaussianDropout; import org.deeplearning4j.nn.conf.layers.DropoutLayer; -import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/noise/KerasGaussianNoiseTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/noise/KerasGaussianNoiseTest.java index 58abe77a4..be01a06c5 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/noise/KerasGaussianNoiseTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/noise/KerasGaussianNoiseTest.java @@ -18,7 +18,7 @@ package org.deeplearning4j.nn.modelimport.keras.layers.noise; import org.deeplearning4j.nn.conf.dropout.GaussianNoise; import org.deeplearning4j.nn.conf.layers.DropoutLayer; -import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/normalization/KerasBatchNormalizationTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/normalization/KerasBatchNormalizationTest.java index 6fbf9ec43..e240b19f0 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/normalization/KerasBatchNormalizationTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/normalization/KerasBatchNormalizationTest.java @@ -17,7 +17,7 @@ package org.deeplearning4j.nn.modelimport.keras.layers.normalization; import org.deeplearning4j.nn.conf.layers.BatchNormalization; -import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/pooling/KerasPooling1DTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/pooling/KerasPooling1DTest.java index 3b2716cd8..25acee41a 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/pooling/KerasPooling1DTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/pooling/KerasPooling1DTest.java @@ -19,7 +19,7 @@ package org.deeplearning4j.nn.modelimport.keras.layers.pooling; import org.deeplearning4j.nn.conf.ConvolutionMode; import org.deeplearning4j.nn.conf.layers.PoolingType; import org.deeplearning4j.nn.conf.layers.Subsampling1DLayer; -import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/pooling/KerasPooling2DTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/pooling/KerasPooling2DTest.java index f4852d89a..137f8ca00 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/pooling/KerasPooling2DTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/pooling/KerasPooling2DTest.java @@ -19,7 +19,7 @@ package org.deeplearning4j.nn.modelimport.keras.layers.pooling; import org.deeplearning4j.nn.conf.ConvolutionMode; import org.deeplearning4j.nn.conf.layers.PoolingType; import org.deeplearning4j.nn.conf.layers.SubsamplingLayer; -import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/pooling/KerasPooling3DTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/pooling/KerasPooling3DTest.java index 9026c7308..153e41103 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/pooling/KerasPooling3DTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/pooling/KerasPooling3DTest.java @@ -19,7 +19,7 @@ package org.deeplearning4j.nn.modelimport.keras.layers.pooling; import org.deeplearning4j.nn.conf.ConvolutionMode; import org.deeplearning4j.nn.conf.layers.PoolingType; import org.deeplearning4j.nn.conf.layers.Subsampling3DLayer; -import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasLSTMTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasLSTMTest.java index 3b82f14ae..60b1044d9 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasLSTMTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasLSTMTest.java @@ -21,7 +21,7 @@ import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.LSTM; import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep; import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer; -import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.KerasTestUtils; import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasSimpleRnnTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasSimpleRnnTest.java index b760a90e1..2abcd3e2a 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasSimpleRnnTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasSimpleRnnTest.java @@ -19,7 +19,7 @@ package org.deeplearning4j.nn.modelimport.keras.layers.recurrent; import org.deeplearning4j.nn.conf.dropout.Dropout; import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep; import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn; -import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.KerasTestUtils; import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/wrappers/KerasBidirectionalTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/wrappers/KerasBidirectionalTest.java index 91969bc09..0613ecf67 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/wrappers/KerasBidirectionalTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/wrappers/KerasBidirectionalTest.java @@ -18,7 +18,7 @@ package org.deeplearning4j.nn.modelimport.keras.layers.wrappers; import org.deeplearning4j.nn.conf.layers.LSTM; import org.deeplearning4j.nn.conf.layers.recurrent.Bidirectional; -import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/optimizers/OptimizerImport.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/optimizers/OptimizerImport.java index f2a693d9a..d030158f0 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/optimizers/OptimizerImport.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/optimizers/OptimizerImport.java @@ -16,7 +16,7 @@ package org.deeplearning4j.nn.modelimport.keras.optimizers; -import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.KerasModel; import org.deeplearning4j.nn.modelimport.keras.KerasSequentialModel; import org.deeplearning4j.nn.modelimport.keras.utils.KerasModelBuilder; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/preprocessing/sequence/TimeSeriesGeneratorImportTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/preprocessing/sequence/TimeSeriesGeneratorImportTest.java index 577e089f9..fc0cbfe9b 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/preprocessing/sequence/TimeSeriesGeneratorImportTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/preprocessing/sequence/TimeSeriesGeneratorImportTest.java @@ -16,7 +16,7 @@ package org.deeplearning4j.nn.modelimport.keras.preprocessing.sequence; -import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; import org.junit.Test; import org.nd4j.resources.Resources; @@ -30,7 +30,7 @@ import java.io.IOException; */ public class TimeSeriesGeneratorImportTest extends BaseDL4JTest { - @Test + @Test(timeout=300000) public void importTimeSeriesTest() throws IOException, InvalidKerasConfigurationException { String path = "modelimport/keras/preprocessing/timeseries_generator.json"; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/preprocessing/sequence/TimeSeriesGeneratorTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/preprocessing/sequence/TimeSeriesGeneratorTest.java index 3a2da91a9..b068881e1 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/preprocessing/sequence/TimeSeriesGeneratorTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/preprocessing/sequence/TimeSeriesGeneratorTest.java @@ -16,7 +16,7 @@ package org.deeplearning4j.nn.modelimport.keras.preprocessing.sequence; -import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; import org.junit.Test; import org.nd4j.linalg.api.ndarray.INDArray; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/preprocessing/text/TokenizerImportTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/preprocessing/text/TokenizerImportTest.java index 45114685b..935be4fbe 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/preprocessing/text/TokenizerImportTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/preprocessing/text/TokenizerImportTest.java @@ -16,7 +16,7 @@ package org.deeplearning4j.nn.modelimport.keras.preprocessing.text; -import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; import org.junit.Test; import org.nd4j.resources.Resources; @@ -35,7 +35,7 @@ public class TokenizerImportTest extends BaseDL4JTest { ClassLoader classLoader = getClass().getClassLoader(); - @Test + @Test(timeout=300000) public void importTest() throws IOException, InvalidKerasConfigurationException { String path = "modelimport/keras/preprocessing/tokenizer.json"; @@ -51,7 +51,7 @@ public class TokenizerImportTest extends BaseDL4JTest { } - @Test + @Test(timeout=300000) public void importNumWordsNullTest() throws IOException, InvalidKerasConfigurationException { String path = "modelimport/keras/preprocessing/tokenizer_num_words_null.json"; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/preprocessing/text/TokenizerTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/preprocessing/text/TokenizerTest.java index a4fb6994b..cebb22fb4 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/preprocessing/text/TokenizerTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/preprocessing/text/TokenizerTest.java @@ -16,7 +16,7 @@ package org.deeplearning4j.nn.modelimport.keras.preprocessing.text; -import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest; +import org.deeplearning4j.BaseDL4JTest; import org.junit.Test; import org.nd4j.linalg.api.ndarray.INDArray; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/weights/KerasWeightSettingTests.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/weights/KerasWeightSettingTests.java index 7791e3417..75334bcd0 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/weights/KerasWeightSettingTests.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/weights/KerasWeightSettingTests.java @@ -19,7 +19,7 @@ package org.deeplearning4j.nn.modelimport.keras.weights; import lombok.extern.slf4j.Slf4j; import lombok.val; import org.deeplearning4j.nn.graph.ComputationGraph; -import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.KerasLayer; import org.deeplearning4j.nn.modelimport.keras.KerasModel; import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasSpaceToDepth; diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/pom.xml b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/pom.xml index ab28d78c4..2d4a4da14 100644 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/pom.xml +++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/pom.xml @@ -103,6 +103,12 @@ logback-classic test + + org.deeplearning4j + deeplearning4j-common-tests + ${project.version} + test + diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/test/java/org/deeplearning4j/nearestneighbor/server/BaseDL4JTest.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/test/java/org/deeplearning4j/nearestneighbor/server/BaseDL4JTest.java deleted file mode 100644 index 41107decf..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/test/java/org/deeplearning4j/nearestneighbor/server/BaseDL4JTest.java +++ /dev/null @@ -1,140 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.nearestneighbor.server; - -import lombok.extern.slf4j.Slf4j; -import org.bytedeco.javacpp.Pointer; -import org.junit.After; -import org.junit.Before; -import org.junit.Rule; -import org.junit.rules.TestName; -import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.api.memory.MemoryWorkspace; -import org.nd4j.linalg.api.ops.executioner.OpExecutioner; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.profiler.ProfilerConfig; - -import java.lang.management.ManagementFactory; -import java.util.List; -import java.util.Map; -import java.util.Properties; - -@Slf4j -public class BaseDL4JTest { - - @Rule - public TestName name = new TestName(); - - protected long startTime; - protected int threadCountBefore; - - /** - * Override this to set the profiling mode for the tests defined in the child class - */ - public OpExecutioner.ProfilingMode getProfilingMode(){ - return OpExecutioner.ProfilingMode.SCOPE_PANIC; - } - - /** - * Override this to set the datatype of the tests defined in the child class - */ - public DataType getDataType(){ - return DataType.DOUBLE; - } - - public DataType getDefaultFPDataType(){ - return getDataType(); - } - - @Before - public void beforeTest(){ - log.info("{}.{}", getClass().getSimpleName(), name.getMethodName()); - Nd4j.getExecutioner().setProfilingMode(getProfilingMode()); - Nd4j.getExecutioner().setProfilingConfig(ProfilerConfig.builder().build()); - Nd4j.setDefaultDataTypes(getDataType(), getDefaultFPDataType()); - startTime = System.currentTimeMillis(); - threadCountBefore = ManagementFactory.getThreadMXBean().getThreadCount(); - } - - @After - public void afterTest(){ - //Attempt to keep workspaces isolated between tests - Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread(); - MemoryWorkspace currWS = Nd4j.getMemoryManager().getCurrentWorkspace(); - Nd4j.getMemoryManager().setCurrentWorkspace(null); - if(currWS != null){ - //Not really safe to continue testing under this situation... other tests will likely fail with obscure - // errors that are hard to track back to this - log.error("Open workspace leaked from test! Exiting - {}, isOpen = {} - {}", currWS.getId(), currWS.isScopeActive(), currWS); - System.exit(1); - } - - StringBuilder sb = new StringBuilder(); - long maxPhys = Pointer.maxPhysicalBytes(); - long maxBytes = Pointer.maxBytes(); - long currPhys = Pointer.physicalBytes(); - long currBytes = Pointer.totalBytes(); - - long jvmTotal = Runtime.getRuntime().totalMemory(); - long jvmMax = Runtime.getRuntime().maxMemory(); - - int threadsAfter = ManagementFactory.getThreadMXBean().getThreadCount(); - - long duration = System.currentTimeMillis() - startTime; - sb.append(getClass().getSimpleName()).append(".").append(name.getMethodName()) - .append(": ").append(duration).append(" ms") - .append(", threadCount: (").append(threadCountBefore).append("->").append(threadsAfter).append(")") - .append(", jvmTotal=").append(jvmTotal) - .append(", jvmMax=").append(jvmMax) - .append(", totalBytes=").append(currBytes).append(", maxBytes=").append(maxBytes) - .append(", currPhys=").append(currPhys).append(", maxPhys=").append(maxPhys); - - List ws = Nd4j.getWorkspaceManager().getAllWorkspacesForCurrentThread(); - if(ws != null && ws.size() > 0){ - long currSize = 0; - for(MemoryWorkspace w : ws){ - currSize += w.getCurrentSize(); - } - if(currSize > 0){ - sb.append(", threadWSSize=").append(currSize) - .append(" (").append(ws.size()).append(" WSs)"); - } - } - - - Properties p = Nd4j.getExecutioner().getEnvironmentInformation(); - Object o = p.get("cuda.devicesInformation"); - if(o instanceof List){ - List> l = (List>) o; - if(l.size() > 0) { - - sb.append(" [").append(l.size()) - .append(" GPUs: "); - - for (int i = 0; i < l.size(); i++) { - Map m = l.get(i); - if(i > 0) - sb.append(","); - sb.append("(").append(m.get("cuda.freeMemory")).append(" free, ") - .append(m.get("cuda.totalMemory")).append(" total)"); - } - sb.append("]"); - } - } - log.info(sb.toString()); - } -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/test/java/org/deeplearning4j/nearestneighbor/server/NearestNeighborTest.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/test/java/org/deeplearning4j/nearestneighbor/server/NearestNeighborTest.java index b42c407e5..4555511ce 100644 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/test/java/org/deeplearning4j/nearestneighbor/server/NearestNeighborTest.java +++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/test/java/org/deeplearning4j/nearestneighbor/server/NearestNeighborTest.java @@ -17,6 +17,7 @@ package org.deeplearning4j.nearestneighbor.server; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.clustering.sptree.DataPoint; import org.deeplearning4j.clustering.vptree.VPTree; import org.deeplearning4j.clustering.vptree.VPTreeFillSearch; @@ -24,7 +25,6 @@ import org.deeplearning4j.nearestneighbor.client.NearestNeighborsClient; import org.deeplearning4j.nearestneighbor.model.NearestNeighborRequest; import org.deeplearning4j.nearestneighbor.model.NearestNeighborsResult; import org.deeplearning4j.nearestneighbor.model.NearestNeighborsResults; -import org.junit.Ignore; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; @@ -36,7 +36,6 @@ import java.io.File; import java.io.IOException; import java.net.ServerSocket; import java.util.List; -import java.util.UUID; import java.util.concurrent.Executor; import java.util.concurrent.Executors; diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/pom.xml b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/pom.xml index 87bb7e68e..fbe0ddccf 100644 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/pom.xml +++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/pom.xml @@ -66,6 +66,12 @@ 2.10.3 test + + org.deeplearning4j + deeplearning4j-common-tests + ${project.version} + test + diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/BaseDL4JTest.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/BaseDL4JTest.java deleted file mode 100644 index 8b57f5dc0..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/BaseDL4JTest.java +++ /dev/null @@ -1,140 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.clustering; - -import lombok.extern.slf4j.Slf4j; -import org.bytedeco.javacpp.Pointer; -import org.junit.After; -import org.junit.Before; -import org.junit.Rule; -import org.junit.rules.TestName; -import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.api.memory.MemoryWorkspace; -import org.nd4j.linalg.api.ops.executioner.OpExecutioner; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.profiler.ProfilerConfig; - -import java.lang.management.ManagementFactory; -import java.util.List; -import java.util.Map; -import java.util.Properties; - -@Slf4j -public class BaseDL4JTest { - - @Rule - public TestName name = new TestName(); - - protected long startTime; - protected int threadCountBefore; - - /** - * Override this to set the profiling mode for the tests defined in the child class - */ - public OpExecutioner.ProfilingMode getProfilingMode(){ - return OpExecutioner.ProfilingMode.SCOPE_PANIC; - } - - /** - * Override this to set the datatype of the tests defined in the child class - */ - public DataType getDataType(){ - return DataType.DOUBLE; - } - - public DataType getDefaultFPDataType(){ - return getDataType(); - } - - @Before - public void beforeTest(){ - log.info("{}.{}", getClass().getSimpleName(), name.getMethodName()); - Nd4j.getExecutioner().setProfilingMode(getProfilingMode()); - Nd4j.getExecutioner().setProfilingConfig(ProfilerConfig.builder().build()); - Nd4j.setDefaultDataTypes(getDataType(), getDefaultFPDataType()); - startTime = System.currentTimeMillis(); - threadCountBefore = ManagementFactory.getThreadMXBean().getThreadCount(); - } - - @After - public void afterTest(){ - //Attempt to keep workspaces isolated between tests - Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread(); - MemoryWorkspace currWS = Nd4j.getMemoryManager().getCurrentWorkspace(); - Nd4j.getMemoryManager().setCurrentWorkspace(null); - if(currWS != null){ - //Not really safe to continue testing under this situation... other tests will likely fail with obscure - // errors that are hard to track back to this - log.error("Open workspace leaked from test! Exiting - {}, isOpen = {} - {}", currWS.getId(), currWS.isScopeActive(), currWS); - System.exit(1); - } - - StringBuilder sb = new StringBuilder(); - long maxPhys = Pointer.maxPhysicalBytes(); - long maxBytes = Pointer.maxBytes(); - long currPhys = Pointer.physicalBytes(); - long currBytes = Pointer.totalBytes(); - - long jvmTotal = Runtime.getRuntime().totalMemory(); - long jvmMax = Runtime.getRuntime().maxMemory(); - - int threadsAfter = ManagementFactory.getThreadMXBean().getThreadCount(); - - long duration = System.currentTimeMillis() - startTime; - sb.append(getClass().getSimpleName()).append(".").append(name.getMethodName()) - .append(": ").append(duration).append(" ms") - .append(", threadCount: (").append(threadCountBefore).append("->").append(threadsAfter).append(")") - .append(", jvmTotal=").append(jvmTotal) - .append(", jvmMax=").append(jvmMax) - .append(", totalBytes=").append(currBytes).append(", maxBytes=").append(maxBytes) - .append(", currPhys=").append(currPhys).append(", maxPhys=").append(maxPhys); - - List ws = Nd4j.getWorkspaceManager().getAllWorkspacesForCurrentThread(); - if(ws != null && ws.size() > 0){ - long currSize = 0; - for(MemoryWorkspace w : ws){ - currSize += w.getCurrentSize(); - } - if(currSize > 0){ - sb.append(", threadWSSize=").append(currSize) - .append(" (").append(ws.size()).append(" WSs)"); - } - } - - - Properties p = Nd4j.getExecutioner().getEnvironmentInformation(); - Object o = p.get("cuda.devicesInformation"); - if(o instanceof List){ - List> l = (List>) o; - if(l.size() > 0) { - - sb.append(" [").append(l.size()) - .append(" GPUs: "); - - for (int i = 0; i < l.size(); i++) { - Map m = l.get(i); - if(i > 0) - sb.append(","); - sb.append("(").append(m.get("cuda.freeMemory")).append(" free, ") - .append(m.get("cuda.totalMemory")).append(" total)"); - } - sb.append("]"); - } - } - log.info(sb.toString()); - } -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/kdtree/KDTreeTest.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/kdtree/KDTreeTest.java index 618ee0c94..60552486e 100644 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/kdtree/KDTreeTest.java +++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/kdtree/KDTreeTest.java @@ -16,11 +16,8 @@ package org.deeplearning4j.clustering.kdtree; -import org.joda.time.Instant; -import org.nd4j.shade.guava.base.Stopwatch; -import org.nd4j.shade.guava.primitives.Doubles; import lombok.val; -import org.deeplearning4j.clustering.BaseDL4JTest; +import org.deeplearning4j.BaseDL4JTest; import org.joda.time.Duration; import org.junit.Before; import org.junit.BeforeClass; @@ -30,8 +27,9 @@ import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.primitives.Pair; +import org.nd4j.shade.guava.base.Stopwatch; +import org.nd4j.shade.guava.primitives.Doubles; import org.nd4j.shade.guava.primitives.Floats; -import org.opencv.ml.KNearest; import java.util.ArrayList; import java.util.Arrays; @@ -48,6 +46,11 @@ import static org.junit.Assert.assertTrue; */ public class KDTreeTest extends BaseDL4JTest { + @Override + public long getTimeoutMilliseconds() { + return 120000L; + } + private KDTree kdTree; @BeforeClass @@ -174,7 +177,7 @@ public class KDTreeTest extends BaseDL4JTest { @Test public void testKNN() { int dimensions = 512; - int vectorsNo = 50000; + int vectorsNo = isIntegrationTests() ? 50000 : 1000; // make a KD-tree of dimension {#dimensions} Stopwatch stopwatch = Stopwatch.createStarted(); KDTree kdTree = new KDTree(dimensions); diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/kmeans/KMeansTest.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/kmeans/KMeansTest.java index c9140942d..2f2619e78 100644 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/kmeans/KMeansTest.java +++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/kmeans/KMeansTest.java @@ -18,7 +18,7 @@ package org.deeplearning4j.clustering.kmeans; import lombok.val; import org.apache.commons.lang3.time.StopWatch; -import org.deeplearning4j.clustering.BaseDL4JTest; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.clustering.algorithm.Distance; import org.deeplearning4j.clustering.cluster.*; import org.junit.Ignore; diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/lsh/RandomProjectionLSHTest.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/lsh/RandomProjectionLSHTest.java index be148c699..d9a041f0b 100644 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/lsh/RandomProjectionLSHTest.java +++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/lsh/RandomProjectionLSHTest.java @@ -16,7 +16,7 @@ package org.deeplearning4j.clustering.lsh; -import org.deeplearning4j.clustering.BaseDL4JTest; +import org.deeplearning4j.BaseDL4JTest; import org.junit.After; import org.junit.Before; import org.junit.Ignore; diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/quadtree/QuadTreeTest.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/quadtree/QuadTreeTest.java index aa0e4db40..ec304b0c1 100644 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/quadtree/QuadTreeTest.java +++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/quadtree/QuadTreeTest.java @@ -16,7 +16,7 @@ package org.deeplearning4j.clustering.quadtree; -import org.deeplearning4j.clustering.BaseDL4JTest; +import org.deeplearning4j.BaseDL4JTest; import org.junit.Test; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/randomprojection/RPTreeTest.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/randomprojection/RPTreeTest.java index 30f4a841e..05bbb1cc9 100644 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/randomprojection/RPTreeTest.java +++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/randomprojection/RPTreeTest.java @@ -16,7 +16,7 @@ package org.deeplearning4j.clustering.randomprojection; -import org.deeplearning4j.clustering.BaseDL4JTest; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; import org.junit.Before; import org.junit.Test; diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/randomprojection/RPUtilsTest.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/randomprojection/RPUtilsTest.java index 96fbabd41..cb3af27bc 100644 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/randomprojection/RPUtilsTest.java +++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/randomprojection/RPUtilsTest.java @@ -16,7 +16,7 @@ package org.deeplearning4j.clustering.randomprojection; -import org.deeplearning4j.clustering.BaseDL4JTest; +import org.deeplearning4j.BaseDL4JTest; import org.junit.Test; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/sptree/SPTreeTest.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/sptree/SPTreeTest.java index f5ee19403..5034a124a 100644 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/sptree/SPTreeTest.java +++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/sptree/SPTreeTest.java @@ -16,18 +16,15 @@ package org.deeplearning4j.clustering.sptree; -import org.nd4j.shade.guava.util.concurrent.AtomicDouble; import org.apache.commons.lang3.time.StopWatch; -import org.deeplearning4j.clustering.BaseDL4JTest; +import org.deeplearning4j.BaseDL4JTest; import org.junit.Before; -import org.junit.Ignore; import org.junit.Test; -import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.util.DataTypeUtil; -import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.shade.guava.util.concurrent.AtomicDouble; import static org.junit.Assert.*; @@ -36,6 +33,11 @@ import static org.junit.Assert.*; */ public class SPTreeTest extends BaseDL4JTest { + @Override + public long getTimeoutMilliseconds() { + return 120000L; + } + @Before public void setUp() { DataTypeUtil.setDTypeForContext(DataType.DOUBLE); @@ -90,13 +92,13 @@ public class SPTreeTest extends BaseDL4JTest { @Test //@Ignore public void testLargeTree() { - int num = 100000; + int num = isIntegrationTests() ? 100000 : 1000; StopWatch watch = new StopWatch(); watch.start(); INDArray arr = Nd4j.linspace(1, num, num, Nd4j.dataType()).reshape(num, 1); SpTree tree = new SpTree(arr); watch.stop(); - System.out.println("Tree created in " + watch); + System.out.println("Tree of size " + num + " created in " + watch); } } diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/vptree/VPTreeSerializationTests.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/vptree/VPTreeSerializationTests.java index 12180a978..b67d5ccbf 100644 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/vptree/VPTreeSerializationTests.java +++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/vptree/VPTreeSerializationTests.java @@ -19,7 +19,7 @@ package org.deeplearning4j.clustering.vptree; import lombok.extern.slf4j.Slf4j; import lombok.val; import org.apache.commons.lang3.SerializationUtils; -import org.deeplearning4j.clustering.BaseDL4JTest; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.clustering.sptree.DataPoint; import org.junit.Ignore; import org.junit.Test; diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/vptree/VpTreeNodeTest.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/vptree/VpTreeNodeTest.java index 5edb3926a..b7f254a63 100644 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/vptree/VpTreeNodeTest.java +++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/vptree/VpTreeNodeTest.java @@ -18,7 +18,7 @@ package org.deeplearning4j.clustering.vptree; import lombok.extern.slf4j.Slf4j; import lombok.val; -import org.deeplearning4j.clustering.BaseDL4JTest; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.clustering.sptree.DataPoint; import org.joda.time.Duration; import org.junit.BeforeClass; diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-chinese/pom.xml b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-chinese/pom.xml index 219301c56..23d863cdc 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-chinese/pom.xml +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-chinese/pom.xml @@ -59,6 +59,12 @@ compile + + org.deeplearning4j + deeplearning4j-common-tests + ${project.version} + test + diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-chinese/src/test/java/org/deeplearning4j/text/tokenization/tokenizer/ChineseTokenizerTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-chinese/src/test/java/org/deeplearning4j/text/tokenization/tokenizer/ChineseTokenizerTest.java index 0307e906e..aef6ed348 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-chinese/src/test/java/org/deeplearning4j/text/tokenization/tokenizer/ChineseTokenizerTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-chinese/src/test/java/org/deeplearning4j/text/tokenization/tokenizer/ChineseTokenizerTest.java @@ -17,6 +17,7 @@ package org.deeplearning4j.text.tokenization.tokenizer; import lombok.extern.slf4j.Slf4j; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer; import org.deeplearning4j.models.word2vec.Word2Vec; import org.deeplearning4j.text.sentenceiterator.BasicLineIterator; @@ -37,7 +38,7 @@ import static org.junit.Assert.assertEquals; * */ @Slf4j -public class ChineseTokenizerTest { +public class ChineseTokenizerTest extends BaseDL4JTest { private final String toTokenize = "青山绿水和伟大的科学家让世界更美好和平"; private final String[] expect = {"青山绿水", "和", "伟大", "的", "科学家", "让", "世界", "更", "美好", "和平"}; diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/pom.xml b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/pom.xml index beeb07d34..a4fea6b07 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/pom.xml +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/pom.xml @@ -61,6 +61,13 @@ org.slf4j slf4j-api + + + org.deeplearning4j + deeplearning4j-common-tests + ${project.version} + test + diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/test/java/com/atilika/kuromoji/CommonCornerCasesTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/test/java/com/atilika/kuromoji/CommonCornerCasesTest.java index 43bdf1c7a..1d23baab6 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/test/java/com/atilika/kuromoji/CommonCornerCasesTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/test/java/com/atilika/kuromoji/CommonCornerCasesTest.java @@ -32,11 +32,13 @@ */ package com.atilika.kuromoji; +import org.deeplearning4j.BaseDL4JTest; + import java.util.Arrays; import static com.atilika.kuromoji.TestUtils.assertTokenSurfacesEquals; -public class CommonCornerCasesTest { +public class CommonCornerCasesTest extends BaseDL4JTest { public static void testPunctuation(TokenizerBase tokenizer) { String gerryNoHanaNoHanashi = "僕の鼻はちょっと\r\n長いよ。"; diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/test/java/com/atilika/kuromoji/buffer/StringValueMapBufferTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/test/java/com/atilika/kuromoji/buffer/StringValueMapBufferTest.java index 7e60f4544..265f4faf4 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/test/java/com/atilika/kuromoji/buffer/StringValueMapBufferTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/test/java/com/atilika/kuromoji/buffer/StringValueMapBufferTest.java @@ -32,13 +32,14 @@ */ package com.atilika.kuromoji.buffer; +import org.deeplearning4j.BaseDL4JTest; import org.junit.Test; import java.util.TreeMap; import static org.junit.Assert.assertEquals; -public class StringValueMapBufferTest { +public class StringValueMapBufferTest extends BaseDL4JTest { @Test public void testInsertIntoMap() throws Exception { diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/test/java/com/atilika/kuromoji/compile/CharacterDefinitionsCompilerTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/test/java/com/atilika/kuromoji/compile/CharacterDefinitionsCompilerTest.java index 947007eeb..2e077eef6 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/test/java/com/atilika/kuromoji/compile/CharacterDefinitionsCompilerTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/test/java/com/atilika/kuromoji/compile/CharacterDefinitionsCompilerTest.java @@ -35,6 +35,7 @@ package com.atilika.kuromoji.compile; import com.atilika.kuromoji.dict.CharacterDefinitions; import com.atilika.kuromoji.io.IntegerArrayIO; import com.atilika.kuromoji.io.StringArrayIO; +import org.deeplearning4j.BaseDL4JTest; import org.junit.Before; import org.junit.Test; @@ -46,7 +47,7 @@ import java.util.Map; import static org.junit.Assert.*; -public class CharacterDefinitionsCompilerTest { +public class CharacterDefinitionsCompilerTest extends BaseDL4JTest { private File charDef; diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/test/java/com/atilika/kuromoji/compile/ConnectionCostsCompilerTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/test/java/com/atilika/kuromoji/compile/ConnectionCostsCompilerTest.java index 8bac040d4..516535e42 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/test/java/com/atilika/kuromoji/compile/ConnectionCostsCompilerTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/test/java/com/atilika/kuromoji/compile/ConnectionCostsCompilerTest.java @@ -34,6 +34,7 @@ package com.atilika.kuromoji.compile; import com.atilika.kuromoji.dict.ConnectionCosts; import com.atilika.kuromoji.io.ByteBufferIO; +import org.deeplearning4j.BaseDL4JTest; import org.junit.BeforeClass; import org.junit.Test; @@ -43,7 +44,7 @@ import java.nio.charset.StandardCharsets; import static org.junit.Assert.assertEquals; -public class ConnectionCostsCompilerTest { +public class ConnectionCostsCompilerTest extends BaseDL4JTest { private static ConnectionCosts connectionCosts; diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/test/java/com/atilika/kuromoji/compile/TokenInfoBufferCompilerTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/test/java/com/atilika/kuromoji/compile/TokenInfoBufferCompilerTest.java index 69f51d9c3..abda8710d 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/test/java/com/atilika/kuromoji/compile/TokenInfoBufferCompilerTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/test/java/com/atilika/kuromoji/compile/TokenInfoBufferCompilerTest.java @@ -34,6 +34,7 @@ package com.atilika.kuromoji.compile; import com.atilika.kuromoji.buffer.BufferEntry; import com.atilika.kuromoji.buffer.TokenInfoBuffer; +import org.deeplearning4j.BaseDL4JTest; import org.junit.Test; import java.io.File; @@ -47,7 +48,7 @@ import java.util.Map; import static org.junit.Assert.assertEquals; -public class TokenInfoBufferCompilerTest { +public class TokenInfoBufferCompilerTest extends BaseDL4JTest { @Test public void testReadAndWriteFromBuffer() throws Exception { diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/test/java/com/atilika/kuromoji/compile/UnknownDictionaryCompilerTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/test/java/com/atilika/kuromoji/compile/UnknownDictionaryCompilerTest.java index 14855bd7a..3156e47a3 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/test/java/com/atilika/kuromoji/compile/UnknownDictionaryCompilerTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/test/java/com/atilika/kuromoji/compile/UnknownDictionaryCompilerTest.java @@ -36,6 +36,7 @@ import com.atilika.kuromoji.dict.CharacterDefinitions; import com.atilika.kuromoji.dict.UnknownDictionary; import com.atilika.kuromoji.io.IntegerArrayIO; import com.atilika.kuromoji.io.StringArrayIO; +import org.deeplearning4j.BaseDL4JTest; import org.junit.BeforeClass; import org.junit.Test; @@ -45,7 +46,7 @@ import java.util.Map; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; -public class UnknownDictionaryCompilerTest { +public class UnknownDictionaryCompilerTest extends BaseDL4JTest { private static UnknownDictionary unknownDictionary; diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/test/java/com/atilika/kuromoji/compile/WordIdMapCompilerTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/test/java/com/atilika/kuromoji/compile/WordIdMapCompilerTest.java index 5b374d338..6edb8f541 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/test/java/com/atilika/kuromoji/compile/WordIdMapCompilerTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/test/java/com/atilika/kuromoji/compile/WordIdMapCompilerTest.java @@ -33,6 +33,7 @@ package com.atilika.kuromoji.compile; import com.atilika.kuromoji.buffer.WordIdMap; +import org.deeplearning4j.BaseDL4JTest; import org.junit.Test; import java.io.*; @@ -40,7 +41,7 @@ import java.util.Arrays; import static org.junit.Assert.assertEquals; -public class WordIdMapCompilerTest { +public class WordIdMapCompilerTest extends BaseDL4JTest { @Test public void testGrowableArray() { diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/test/java/com/atilika/kuromoji/dict/InsertedDictionaryTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/test/java/com/atilika/kuromoji/dict/InsertedDictionaryTest.java index 0144a3f3a..eae973831 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/test/java/com/atilika/kuromoji/dict/InsertedDictionaryTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/test/java/com/atilika/kuromoji/dict/InsertedDictionaryTest.java @@ -32,12 +32,13 @@ */ package com.atilika.kuromoji.dict; +import org.deeplearning4j.BaseDL4JTest; import org.junit.Test; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; -public class InsertedDictionaryTest { +public class InsertedDictionaryTest extends BaseDL4JTest { @Test public void testFeatureSize() { diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/test/java/com/atilika/kuromoji/dict/UserDictionaryTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/test/java/com/atilika/kuromoji/dict/UserDictionaryTest.java index 6da0ddf20..cb6b503c0 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/test/java/com/atilika/kuromoji/dict/UserDictionaryTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/test/java/com/atilika/kuromoji/dict/UserDictionaryTest.java @@ -32,6 +32,7 @@ */ package com.atilika.kuromoji.dict; +import org.deeplearning4j.BaseDL4JTest; import org.junit.Test; import org.nd4j.linalg.io.ClassPathResource; @@ -43,7 +44,7 @@ import java.util.List; import static org.junit.Assert.assertEquals; -public class UserDictionaryTest { +public class UserDictionaryTest extends BaseDL4JTest { @Test public void testLookup() throws IOException { diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/test/java/com/atilika/kuromoji/ipadic/MultiThreadedTokenizerTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/test/java/com/atilika/kuromoji/ipadic/MultiThreadedTokenizerTest.java index 4f766b4a9..05f0f9f8f 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/test/java/com/atilika/kuromoji/ipadic/MultiThreadedTokenizerTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/test/java/com/atilika/kuromoji/ipadic/MultiThreadedTokenizerTest.java @@ -32,6 +32,7 @@ */ package com.atilika.kuromoji.ipadic; +import org.deeplearning4j.BaseDL4JTest; import org.junit.Test; import org.nd4j.linalg.io.ClassPathResource; @@ -39,7 +40,7 @@ import java.io.IOException; import static com.atilika.kuromoji.TestUtils.assertMultiThreadedTokenizedStreamEquals; -public class MultiThreadedTokenizerTest { +public class MultiThreadedTokenizerTest extends BaseDL4JTest { @Test public void testMultiThreadedBocchan() throws IOException, InterruptedException { diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/test/java/com/atilika/kuromoji/ipadic/RandomizedInputTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/test/java/com/atilika/kuromoji/ipadic/RandomizedInputTest.java index 926824a1f..c30d6ae40 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/test/java/com/atilika/kuromoji/ipadic/RandomizedInputTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/test/java/com/atilika/kuromoji/ipadic/RandomizedInputTest.java @@ -45,19 +45,19 @@ public class RandomizedInputTest extends RandomizedTest { private Tokenizer tokenizer = new Tokenizer(); @Test - @Repeat(iterations = 50) + @Repeat(iterations = 10) public void testRandomizedUnicodeInput() { assertCanTokenizeString(randomUnicodeOfLength(LENGTH), tokenizer); } @Test - @Repeat(iterations = 50) + @Repeat(iterations = 10) public void testRandomizedRealisticUnicodeInput() { assertCanTokenizeString(randomRealisticUnicodeOfLength(LENGTH), tokenizer); } @Test - @Repeat(iterations = 50) + @Repeat(iterations = 10) public void testRandomizedAsciiInput() { assertCanTokenizeString(randomAsciiOfLength(LENGTH), tokenizer); } diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/test/java/com/atilika/kuromoji/ipadic/SearchTokenizerTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/test/java/com/atilika/kuromoji/ipadic/SearchTokenizerTest.java index e5f57811e..4883ba0b3 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/test/java/com/atilika/kuromoji/ipadic/SearchTokenizerTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/test/java/com/atilika/kuromoji/ipadic/SearchTokenizerTest.java @@ -33,6 +33,7 @@ package com.atilika.kuromoji.ipadic; import com.atilika.kuromoji.TokenizerBase.Mode; +import org.deeplearning4j.BaseDL4JTest; import org.junit.BeforeClass; import org.junit.Test; import org.nd4j.linalg.io.ClassPathResource; @@ -47,7 +48,7 @@ import java.util.List; import static org.junit.Assert.assertEquals; -public class SearchTokenizerTest { +public class SearchTokenizerTest extends BaseDL4JTest { private static Tokenizer tokenizer; diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/test/java/com/atilika/kuromoji/ipadic/TokenizerTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/test/java/com/atilika/kuromoji/ipadic/TokenizerTest.java index d4fc66849..ce8bcb206 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/test/java/com/atilika/kuromoji/ipadic/TokenizerTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/test/java/com/atilika/kuromoji/ipadic/TokenizerTest.java @@ -33,6 +33,7 @@ package com.atilika.kuromoji.ipadic; import com.atilika.kuromoji.CommonCornerCasesTest; +import org.deeplearning4j.BaseDL4JTest; import org.junit.BeforeClass; import org.junit.Test; import org.nd4j.linalg.io.ClassPathResource; @@ -48,7 +49,7 @@ import java.util.List; import static com.atilika.kuromoji.TestUtils.*; import static org.junit.Assert.*; -public class TokenizerTest { +public class TokenizerTest extends BaseDL4JTest { private static Tokenizer tokenizer; diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/test/java/com/atilika/kuromoji/ipadic/UserDictionaryTokenizerTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/test/java/com/atilika/kuromoji/ipadic/UserDictionaryTokenizerTest.java index 586453aa3..204693e31 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/test/java/com/atilika/kuromoji/ipadic/UserDictionaryTokenizerTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/test/java/com/atilika/kuromoji/ipadic/UserDictionaryTokenizerTest.java @@ -32,6 +32,7 @@ */ package com.atilika.kuromoji.ipadic; +import org.deeplearning4j.BaseDL4JTest; import org.junit.Ignore; import org.junit.Test; @@ -45,7 +46,7 @@ import java.util.List; import static com.atilika.kuromoji.TestUtils.assertTokenSurfacesEquals; import static org.junit.Assert.assertEquals; -public class UserDictionaryTokenizerTest { +public class UserDictionaryTokenizerTest extends BaseDL4JTest { private String userDictionary = "" + "クロ,クロ,クロ,カスタム名詞\n" + "真救世主,真救世主,シンキュウセイシュ,カスタム名詞\n" + "真救世主伝説,真救世主伝説,シンキュウセイシュデンセツ,カスタム名詞\n" + "北斗の拳,北斗の拳,ホクトノケン,カスタム名詞"; diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/test/java/com/atilika/kuromoji/trie/DoubleArrayTrieTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/test/java/com/atilika/kuromoji/trie/DoubleArrayTrieTest.java index cac7d2116..3f4e5763e 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/test/java/com/atilika/kuromoji/trie/DoubleArrayTrieTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/test/java/com/atilika/kuromoji/trie/DoubleArrayTrieTest.java @@ -32,6 +32,7 @@ */ package com.atilika.kuromoji.trie; +import org.deeplearning4j.BaseDL4JTest; import org.junit.Test; import java.io.*; @@ -39,7 +40,7 @@ import java.io.*; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; -public class DoubleArrayTrieTest { +public class DoubleArrayTrieTest extends BaseDL4JTest { @Test public void testSparseTrie() throws IOException { diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/test/java/com/atilika/kuromoji/trie/NodeTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/test/java/com/atilika/kuromoji/trie/NodeTest.java index d36f5381e..474c073f2 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/test/java/com/atilika/kuromoji/trie/NodeTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/test/java/com/atilika/kuromoji/trie/NodeTest.java @@ -32,12 +32,13 @@ */ package com.atilika.kuromoji.trie; +import org.deeplearning4j.BaseDL4JTest; import org.junit.BeforeClass; import org.junit.Test; import static org.junit.Assert.assertEquals; -public class NodeTest { +public class NodeTest extends BaseDL4JTest { @BeforeClass public static void setUpBeforeClass() throws Exception {} diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/test/java/com/atilika/kuromoji/trie/PatriciaTrieTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/test/java/com/atilika/kuromoji/trie/PatriciaTrieTest.java index 7bfe62ac0..1f4bec6ad 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/test/java/com/atilika/kuromoji/trie/PatriciaTrieTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/test/java/com/atilika/kuromoji/trie/PatriciaTrieTest.java @@ -32,13 +32,14 @@ */ package com.atilika.kuromoji.trie; +import org.deeplearning4j.BaseDL4JTest; import org.junit.Test; import java.util.*; import static org.junit.Assert.*; -public class PatriciaTrieTest { +public class PatriciaTrieTest extends BaseDL4JTest { @Test public void testRomaji() { diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/test/java/com/atilika/kuromoji/trie/TrieTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/test/java/com/atilika/kuromoji/trie/TrieTest.java index 976903a9d..a27131eb8 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/test/java/com/atilika/kuromoji/trie/TrieTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/test/java/com/atilika/kuromoji/trie/TrieTest.java @@ -33,11 +33,12 @@ package com.atilika.kuromoji.trie; import com.atilika.kuromoji.trie.Trie.Node; +import org.deeplearning4j.BaseDL4JTest; import org.junit.Test; import static org.junit.Assert.*; -public class TrieTest { +public class TrieTest extends BaseDL4JTest { @Test public void testGetRoot() { diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/test/java/com/atilika/kuromoji/util/DictionaryEntryLineParserTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/test/java/com/atilika/kuromoji/util/DictionaryEntryLineParserTest.java index 41bb0bc52..f9ff1c060 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/test/java/com/atilika/kuromoji/util/DictionaryEntryLineParserTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/test/java/com/atilika/kuromoji/util/DictionaryEntryLineParserTest.java @@ -32,6 +32,7 @@ */ package com.atilika.kuromoji.util; +import org.deeplearning4j.BaseDL4JTest; import org.junit.Test; import java.util.Arrays; @@ -39,7 +40,7 @@ import java.util.Arrays; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; -public class DictionaryEntryLineParserTest { +public class DictionaryEntryLineParserTest extends BaseDL4JTest { private DictionaryEntryLineParser parser = new DictionaryEntryLineParser(); diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/test/java/org/deeplearning4j/text/tokenization/tokenizer/JapaneseTokenizerTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/test/java/org/deeplearning4j/text/tokenization/tokenizer/JapaneseTokenizerTest.java index 5e5daf383..849009bdd 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/test/java/org/deeplearning4j/text/tokenization/tokenizer/JapaneseTokenizerTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/test/java/org/deeplearning4j/text/tokenization/tokenizer/JapaneseTokenizerTest.java @@ -16,6 +16,7 @@ package org.deeplearning4j.text.tokenization.tokenizer; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.text.tokenization.tokenizerfactory.JapaneseTokenizerFactory; import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory; import org.junit.Test; @@ -25,7 +26,7 @@ import java.util.List; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; -public class JapaneseTokenizerTest { +public class JapaneseTokenizerTest extends BaseDL4JTest { private String toTokenize = "黒い瞳の綺麗な女の子"; private String[] expect = {"黒い", "瞳", "の", "綺麗", "な", "女の子"}; diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-korean/pom.xml b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-korean/pom.xml index e11e9044f..c0fcdb84a 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-korean/pom.xml +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-korean/pom.xml @@ -54,6 +54,12 @@ deeplearning4j-nlp ${project.version} + + org.deeplearning4j + deeplearning4j-common-tests + ${project.version} + test + diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-korean/src/test/java/org/deeplearning4j/text/tokenization/tokenizer/KoreanTokenizerTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-korean/src/test/java/org/deeplearning4j/text/tokenization/tokenizer/KoreanTokenizerTest.java index c60243b75..275515968 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-korean/src/test/java/org/deeplearning4j/text/tokenization/tokenizer/KoreanTokenizerTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-korean/src/test/java/org/deeplearning4j/text/tokenization/tokenizer/KoreanTokenizerTest.java @@ -16,6 +16,7 @@ package org.deeplearning4j.text.tokenization.tokenizer; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.text.tokenization.tokenizerfactory.KoreanTokenizerFactory; import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory; import org.junit.Test; @@ -25,7 +26,7 @@ import static org.junit.Assert.assertEquals; /** * Created by kepricon on 16. 10. 24. */ -public class KoreanTokenizerTest { +public class KoreanTokenizerTest extends BaseDL4JTest { @Test public void testKoreanTokenizer() throws Exception { String toTokenize = "세계 최초의 상용 수준 오픈소스 딥러닝 라이브러리입니다"; diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-korean/src/test/java/org/deeplearning4j/text/tokenization/tokenizer/PerformanceTests.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-korean/src/test/java/org/deeplearning4j/text/tokenization/tokenizer/PerformanceTests.java index b9e0c35dc..c9fece977 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-korean/src/test/java/org/deeplearning4j/text/tokenization/tokenizer/PerformanceTests.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-korean/src/test/java/org/deeplearning4j/text/tokenization/tokenizer/PerformanceTests.java @@ -17,6 +17,7 @@ package org.deeplearning4j.text.tokenization.tokenizer; import lombok.extern.slf4j.Slf4j; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.models.embeddings.learning.impl.elements.CBOW; import org.deeplearning4j.models.embeddings.reader.impl.BasicModelUtils; import org.deeplearning4j.models.word2vec.VocabWord; @@ -32,7 +33,7 @@ import org.junit.Test; * @author raver119@gmail.com */ @Slf4j -public class PerformanceTests { +public class PerformanceTests extends BaseDL4JTest { @Ignore diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/pom.xml b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/pom.xml index 39eda5e50..7aa6090e1 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/pom.xml +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/pom.xml @@ -32,6 +32,23 @@ UTF-8 + + + + + org.apache.maven.plugins + maven-compiler-plugin + + 1.8 + 1.8 + 1.8 + 1.8 + + + + + + org.cleartk @@ -72,6 +89,13 @@ ${project.version} test + + + org.deeplearning4j + deeplearning4j-common-tests + ${project.version} + test + diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/BaseDL4JTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/BaseDL4JTest.java deleted file mode 100644 index 05d0957fb..000000000 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/BaseDL4JTest.java +++ /dev/null @@ -1,140 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j; - -import lombok.extern.slf4j.Slf4j; -import org.bytedeco.javacpp.Pointer; -import org.junit.After; -import org.junit.Before; -import org.junit.Rule; -import org.junit.rules.TestName; -import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.api.memory.MemoryWorkspace; -import org.nd4j.linalg.api.ops.executioner.OpExecutioner; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.profiler.ProfilerConfig; - -import java.lang.management.ManagementFactory; -import java.util.List; -import java.util.Map; -import java.util.Properties; - -@Slf4j -public class BaseDL4JTest { - - @Rule - public TestName name = new TestName(); - - protected long startTime; - protected int threadCountBefore; - - /** - * Override this to set the profiling mode for the tests defined in the child class - */ - public OpExecutioner.ProfilingMode getProfilingMode(){ - return OpExecutioner.ProfilingMode.SCOPE_PANIC; - } - - /** - * Override this to set the datatype of the tests defined in the child class - */ - public DataType getDataType(){ - return DataType.DOUBLE; - } - - public DataType getDefaultFPDataType(){ - return getDataType(); - } - - @Before - public void beforeTest(){ - log.info("{}.{}", getClass().getSimpleName(), name.getMethodName()); - Nd4j.getExecutioner().setProfilingMode(getProfilingMode()); - Nd4j.getExecutioner().setProfilingConfig(ProfilerConfig.builder().build()); - Nd4j.setDefaultDataTypes(getDataType(), getDefaultFPDataType()); - startTime = System.currentTimeMillis(); - threadCountBefore = ManagementFactory.getThreadMXBean().getThreadCount(); - } - - @After - public void afterTest(){ - //Attempt to keep workspaces isolated between tests - Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread(); - MemoryWorkspace currWS = Nd4j.getMemoryManager().getCurrentWorkspace(); - Nd4j.getMemoryManager().setCurrentWorkspace(null); - if(currWS != null){ - //Not really safe to continue testing under this situation... other tests will likely fail with obscure - // errors that are hard to track back to this - log.error("Open workspace leaked from test! Exiting - {}, isOpen = {} - {}", currWS.getId(), currWS.isScopeActive(), currWS); - System.exit(1); - } - - StringBuilder sb = new StringBuilder(); - long maxPhys = Pointer.maxPhysicalBytes(); - long maxBytes = Pointer.maxBytes(); - long currPhys = Pointer.physicalBytes(); - long currBytes = Pointer.totalBytes(); - - long jvmTotal = Runtime.getRuntime().totalMemory(); - long jvmMax = Runtime.getRuntime().maxMemory(); - - int threadsAfter = ManagementFactory.getThreadMXBean().getThreadCount(); - - long duration = System.currentTimeMillis() - startTime; - sb.append(getClass().getSimpleName()).append(".").append(name.getMethodName()) - .append(": ").append(duration).append(" ms") - .append(", threadCount: (").append(threadCountBefore).append("->").append(threadsAfter).append(")") - .append(", jvmTotal=").append(jvmTotal) - .append(", jvmMax=").append(jvmMax) - .append(", totalBytes=").append(currBytes).append(", maxBytes=").append(maxBytes) - .append(", currPhys=").append(currPhys).append(", maxPhys=").append(maxPhys); - - List ws = Nd4j.getWorkspaceManager().getAllWorkspacesForCurrentThread(); - if(ws != null && ws.size() > 0){ - long currSize = 0; - for(MemoryWorkspace w : ws){ - currSize += w.getCurrentSize(); - } - if(currSize > 0){ - sb.append(", threadWSSize=").append(currSize) - .append(" (").append(ws.size()).append(" WSs)"); - } - } - - - Properties p = Nd4j.getExecutioner().getEnvironmentInformation(); - Object o = p.get("cuda.devicesInformation"); - if(o instanceof List){ - List> l = (List>) o; - if(l.size() > 0) { - - sb.append(" [").append(l.size()) - .append(" GPUs: "); - - for (int i = 0; i < l.size(); i++) { - Map m = l.get(i); - if(i > 0) - sb.append(","); - sb.append("(").append(m.get("cuda.freeMemory")).append(" free, ") - .append(m.get("cuda.totalMemory")).append(" total)"); - } - sb.append("]"); - } - } - log.info(sb.toString()); - } -} diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/WordVectorSerializerTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/WordVectorSerializerTest.java index 69eae7307..7807ff711 100755 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/WordVectorSerializerTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/WordVectorSerializerTest.java @@ -16,6 +16,7 @@ package org.deeplearning4j.models; +import org.junit.rules.Timeout; import org.nd4j.shade.guava.io.Files; import org.nd4j.shade.guava.primitives.Doubles; import lombok.val; @@ -75,6 +76,9 @@ public class WordVectorSerializerTest extends BaseDL4JTest { @Rule public TemporaryFolder testDir = new TemporaryFolder(); + @Rule + public Timeout timeout = Timeout.seconds(300); + private File textFile, binaryFile, textFile2; private File fastTextRaw, fastTextZip, fastTextGzip; String pathToWriteto; @@ -402,11 +406,11 @@ public class WordVectorSerializerTest extends BaseDL4JTest { double simD = arraysSimilarity(day1, day2); double simN = arraysSimilarity(night1, night2); - logger.info("Vec1 day: " + day1); - logger.info("Vec2 day: " + day2); +// logger.info("Vec1 day: " + day1); +// logger.info("Vec2 day: " + day2); - logger.info("Vec1 night: " + night1); - logger.info("Vec2 night: " + night2); +// logger.info("Vec1 night: " + night1); +// logger.info("Vec2 night: " + night2); logger.info("Day/day cross-model similarity: " + simD); logger.info("Night/night cross-model similarity: " + simN); diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/embeddings/loader/VectorsConfigurationTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/embeddings/loader/VectorsConfigurationTest.java index 1546ead8f..8f13ae3fd 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/embeddings/loader/VectorsConfigurationTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/embeddings/loader/VectorsConfigurationTest.java @@ -62,7 +62,7 @@ public class VectorsConfigurationTest extends BaseDL4JTest { assertEquals(configuration, configuration2); } - @Test + @Test(timeout = 300000) public void testFromW2V() throws Exception { VectorsConfiguration configuration = new VectorsConfiguration(); configuration.setHugeModelExpected(true); diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/word2vec/Word2VecTests.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/word2vec/Word2VecTests.java index 736998484..e50a95443 100755 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/word2vec/Word2VecTests.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/word2vec/Word2VecTests.java @@ -16,6 +16,11 @@ package org.deeplearning4j.models.word2vec; +import org.apache.commons.io.IOUtils; +import org.apache.commons.io.LineIterator; +import org.deeplearning4j.text.sentenceiterator.CollectionSentenceIterator; +import org.junit.Rule; +import org.junit.rules.Timeout; import org.nd4j.shade.guava.primitives.Doubles; import org.nd4j.shade.guava.primitives.Ints; import lombok.val; @@ -49,8 +54,8 @@ import org.nd4j.resources.Resources; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.io.File; -import java.io.IOException; +import java.io.*; +import java.nio.charset.StandardCharsets; import java.util.*; import static org.junit.Assert.*; @@ -68,6 +73,9 @@ public class Word2VecTests extends BaseDL4JTest { private String pathToWriteto; private WordVectors googleModel; + @Rule + public Timeout timeout = Timeout.seconds(300); + @Before public void before() throws Exception { File googleModelTextFile = new ClassPathResource("word2vecserialization/google_news_30.txt").getFile(); @@ -180,7 +188,12 @@ public class Word2VecTests extends BaseDL4JTest { @Test public void testWord2VecMultiEpoch() throws Exception { - SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath()); + SentenceIterator iter; + if(isIntegrationTests()){ + iter = new BasicLineIterator(inputFile.getAbsolutePath()); + } else { + iter = new CollectionSentenceIterator(firstNLines(inputFile, 50000)); + } TokenizerFactory t = new DefaultTokenizerFactory(); t.setTokenPreProcessor(new CommonPreprocessor()); @@ -384,7 +397,12 @@ public class Word2VecTests extends BaseDL4JTest { @Test public void testW2VnegativeOnRestore() throws Exception { // Strip white space before and after for each line - SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath()); + SentenceIterator iter; + if(isIntegrationTests()){ + iter = new BasicLineIterator(inputFile.getAbsolutePath()); + } else { + iter = new CollectionSentenceIterator(firstNLines(inputFile, 300)); + } // Split on white spaces in the line to get words TokenizerFactory t = new DefaultTokenizerFactory(); t.setTokenPreProcessor(new CommonPreprocessor()); @@ -486,7 +504,12 @@ public class Word2VecTests extends BaseDL4JTest { @Test public void orderIsCorrect_WhenParallelized() throws Exception { // Strip white space before and after for each line - SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath()); + SentenceIterator iter; + if(isIntegrationTests()){ + iter = new BasicLineIterator(inputFile.getAbsolutePath()); + } else { + iter = new CollectionSentenceIterator(firstNLines(inputFile, 300)); + } // Split on white spaces in the line to get words TokenizerFactory t = new DefaultTokenizerFactory(); t.setTokenPreProcessor(new CommonPreprocessor()); @@ -505,9 +528,10 @@ public class Word2VecTests extends BaseDL4JTest { System.out.println(vec.getVocab().numWords()); val words = vec.getVocab().words(); - for (val word : words) { - System.out.println(word); - } + assertTrue(words.size() > 0); +// for (val word : words) { +// System.out.println(word); +// } } @Test @@ -750,7 +774,16 @@ public class Word2VecTests extends BaseDL4JTest { @Test public void weightsNotUpdated_WhenLocked() throws Exception { - SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath()); + boolean isIntegration = isIntegrationTests(); + SentenceIterator iter; + SentenceIterator iter2; + if(isIntegration){ + iter = new BasicLineIterator(inputFile); + iter2 = new BasicLineIterator(inputFile2.getAbsolutePath()); + } else { + iter = new CollectionSentenceIterator(firstNLines(inputFile, 300)); + iter2 = new CollectionSentenceIterator(firstNLines(inputFile2, 300)); + } Word2Vec vec1 = new Word2Vec.Builder().minWordFrequency(1).iterations(3).batchSize(64).layerSize(100) .stopWords(new ArrayList()).seed(42).learningRate(0.025).minLearningRate(0.001) @@ -762,13 +795,12 @@ public class Word2VecTests extends BaseDL4JTest { vec1.fit(); - iter = new BasicLineIterator(inputFile2.getAbsolutePath()); Word2Vec vec2 = new Word2Vec.Builder().minWordFrequency(1).iterations(3).batchSize(32).layerSize(100) .stopWords(new ArrayList()).seed(32).learningRate(0.021).minLearningRate(0.001) .sampling(0).elementsLearningAlgorithm(new SkipGram()) .epochs(1).windowSize(5).allowParallelTokenization(true) .workers(1) - .iterate(iter) + .iterate(iter2) .intersectModel(vec1, true) .modelUtils(new BasicModelUtils()).build(); @@ -856,6 +888,22 @@ public class Word2VecTests extends BaseDL4JTest { } System.out.print("\n"); } - // + + public static List firstNLines(File f, int n){ + List lines = new ArrayList<>(); + try(InputStream is = new BufferedInputStream(new FileInputStream(f))){ + LineIterator lineIter = IOUtils.lineIterator(is, StandardCharsets.UTF_8); + try{ + for( int i=0; i0.4 + + org.deeplearning4j + deeplearning4j-common-tests + ${project.version} + test + diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/BaseDL4JTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/BaseDL4JTest.java deleted file mode 100644 index 05d0957fb..000000000 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/BaseDL4JTest.java +++ /dev/null @@ -1,140 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j; - -import lombok.extern.slf4j.Slf4j; -import org.bytedeco.javacpp.Pointer; -import org.junit.After; -import org.junit.Before; -import org.junit.Rule; -import org.junit.rules.TestName; -import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.api.memory.MemoryWorkspace; -import org.nd4j.linalg.api.ops.executioner.OpExecutioner; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.profiler.ProfilerConfig; - -import java.lang.management.ManagementFactory; -import java.util.List; -import java.util.Map; -import java.util.Properties; - -@Slf4j -public class BaseDL4JTest { - - @Rule - public TestName name = new TestName(); - - protected long startTime; - protected int threadCountBefore; - - /** - * Override this to set the profiling mode for the tests defined in the child class - */ - public OpExecutioner.ProfilingMode getProfilingMode(){ - return OpExecutioner.ProfilingMode.SCOPE_PANIC; - } - - /** - * Override this to set the datatype of the tests defined in the child class - */ - public DataType getDataType(){ - return DataType.DOUBLE; - } - - public DataType getDefaultFPDataType(){ - return getDataType(); - } - - @Before - public void beforeTest(){ - log.info("{}.{}", getClass().getSimpleName(), name.getMethodName()); - Nd4j.getExecutioner().setProfilingMode(getProfilingMode()); - Nd4j.getExecutioner().setProfilingConfig(ProfilerConfig.builder().build()); - Nd4j.setDefaultDataTypes(getDataType(), getDefaultFPDataType()); - startTime = System.currentTimeMillis(); - threadCountBefore = ManagementFactory.getThreadMXBean().getThreadCount(); - } - - @After - public void afterTest(){ - //Attempt to keep workspaces isolated between tests - Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread(); - MemoryWorkspace currWS = Nd4j.getMemoryManager().getCurrentWorkspace(); - Nd4j.getMemoryManager().setCurrentWorkspace(null); - if(currWS != null){ - //Not really safe to continue testing under this situation... other tests will likely fail with obscure - // errors that are hard to track back to this - log.error("Open workspace leaked from test! Exiting - {}, isOpen = {} - {}", currWS.getId(), currWS.isScopeActive(), currWS); - System.exit(1); - } - - StringBuilder sb = new StringBuilder(); - long maxPhys = Pointer.maxPhysicalBytes(); - long maxBytes = Pointer.maxBytes(); - long currPhys = Pointer.physicalBytes(); - long currBytes = Pointer.totalBytes(); - - long jvmTotal = Runtime.getRuntime().totalMemory(); - long jvmMax = Runtime.getRuntime().maxMemory(); - - int threadsAfter = ManagementFactory.getThreadMXBean().getThreadCount(); - - long duration = System.currentTimeMillis() - startTime; - sb.append(getClass().getSimpleName()).append(".").append(name.getMethodName()) - .append(": ").append(duration).append(" ms") - .append(", threadCount: (").append(threadCountBefore).append("->").append(threadsAfter).append(")") - .append(", jvmTotal=").append(jvmTotal) - .append(", jvmMax=").append(jvmMax) - .append(", totalBytes=").append(currBytes).append(", maxBytes=").append(maxBytes) - .append(", currPhys=").append(currPhys).append(", maxPhys=").append(maxPhys); - - List ws = Nd4j.getWorkspaceManager().getAllWorkspacesForCurrentThread(); - if(ws != null && ws.size() > 0){ - long currSize = 0; - for(MemoryWorkspace w : ws){ - currSize += w.getCurrentSize(); - } - if(currSize > 0){ - sb.append(", threadWSSize=").append(currSize) - .append(" (").append(ws.size()).append(" WSs)"); - } - } - - - Properties p = Nd4j.getExecutioner().getEnvironmentInformation(); - Object o = p.get("cuda.devicesInformation"); - if(o instanceof List){ - List> l = (List>) o; - if(l.size() > 0) { - - sb.append(" [").append(l.size()) - .append(" GPUs: "); - - for (int i = 0; i < l.size(); i++) { - Map m = l.get(i); - if(i > 0) - sb.append(","); - sb.append("(").append(m.get("cuda.freeMemory")).append(" free, ") - .append(m.get("cuda.totalMemory")).append(" total)"); - } - sb.append("]"); - } - } - log.info(sb.toString()); - } -} diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/TsneTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/TsneTest.java index 613090d8a..c99cb3b9a 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/TsneTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/TsneTest.java @@ -46,6 +46,11 @@ import static org.junit.Assert.assertEquals; @Slf4j public class TsneTest extends BaseDL4JTest { + @Override + public long getTimeoutMilliseconds() { + return 60000L; + } + @Rule public TemporaryFolder testDir = new TemporaryFolder(); @@ -53,103 +58,102 @@ public class TsneTest extends BaseDL4JTest { public void testSimple() throws Exception { //Simple sanity check - for (boolean syntheticData : new boolean[]{false, true}) { - for (WorkspaceMode wsm : new WorkspaceMode[]{WorkspaceMode.NONE, WorkspaceMode.ENABLED}) { - log.info("Starting test: WSM={}, syntheticData={}", wsm, syntheticData); + for( int test=0; test <=1; test++){ + boolean syntheticData = test == 1; + WorkspaceMode wsm = test == 0 ? WorkspaceMode.NONE : WorkspaceMode.ENABLED; + log.info("Starting test: WSM={}, syntheticData={}", wsm, syntheticData); - //STEP 1: Initialization - int iterations = 300; - //create an n-dimensional array of doubles - Nd4j.setDataType(DataType.DOUBLE); - List cacheList = new ArrayList<>(); //cacheList is a dynamic array of strings used to hold all words + //STEP 1: Initialization + int iterations = 50; + //create an n-dimensional array of doubles + Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT); + List cacheList = new ArrayList<>(); //cacheList is a dynamic array of strings used to hold all words - //STEP 2: Turn text input into a list of words - INDArray weights; - if(syntheticData){ - weights = Nd4j.rand(5000, 200); - } else { - log.info("Load & Vectorize data...."); - File wordFile = new ClassPathResource("deeplearning4j-tsne/words.txt").getFile(); //Open the file - //Get the data of all unique word vectors - Pair vectors = WordVectorSerializer.loadTxt(wordFile); - VocabCache cache = vectors.getSecond(); - weights = vectors.getFirst().getSyn0(); //seperate weights of unique words into their own list + //STEP 2: Turn text input into a list of words + INDArray weights; + if(syntheticData){ + weights = Nd4j.rand(250, 200); + } else { + log.info("Load & Vectorize data...."); + File wordFile = new ClassPathResource("deeplearning4j-tsne/words.txt").getFile(); //Open the file + //Get the data of all unique word vectors + Pair vectors = WordVectorSerializer.loadTxt(wordFile); + VocabCache cache = vectors.getSecond(); + weights = vectors.getFirst().getSyn0(); //seperate weights of unique words into their own list - for (int i = 0; i < cache.numWords(); i++) //seperate strings of words into their own list - cacheList.add(cache.wordAtIndex(i)); - } - - //STEP 3: build a dual-tree tsne to use later - log.info("Build model...."); - BarnesHutTsne tsne = new BarnesHutTsne.Builder() - .setMaxIter(iterations) - .theta(0.5) - .normalize(false) - .learningRate(500) - .useAdaGrad(false) - .workspaceMode(wsm) - .build(); - - - //STEP 4: establish the tsne values and save them to a file - log.info("Store TSNE Coordinates for Plotting...."); - File outDir = testDir.newFolder(); - tsne.fit(weights); - tsne.saveAsFile(cacheList, new File(outDir, "out.txt").getAbsolutePath()); + for (int i = 0; i < cache.numWords(); i++) //seperate strings of words into their own list + cacheList.add(cache.wordAtIndex(i)); } + + //STEP 3: build a dual-tree tsne to use later + log.info("Build model...."); + BarnesHutTsne tsne = new BarnesHutTsne.Builder() + .setMaxIter(iterations) + .theta(0.5) + .normalize(false) + .learningRate(500) + .useAdaGrad(false) + .workspaceMode(wsm) + .build(); + + + //STEP 4: establish the tsne values and save them to a file + log.info("Store TSNE Coordinates for Plotting...."); + File outDir = testDir.newFolder(); + tsne.fit(weights); + tsne.saveAsFile(cacheList, new File(outDir, "out.txt").getAbsolutePath()); } } - //Elapsed time : 01:01:57.988 @Test public void testPerformance() throws Exception { StopWatch watch = new StopWatch(); watch.start(); - for (boolean syntheticData : new boolean[]{false, true}) { - for (WorkspaceMode wsm : new WorkspaceMode[]{WorkspaceMode.NONE, WorkspaceMode.ENABLED}) { - log.info("Starting test: WSM={}, syntheticData={}", wsm, syntheticData); + for( int test=0; test <=1; test++){ + boolean syntheticData = test == 1; + WorkspaceMode wsm = test == 0 ? WorkspaceMode.NONE : WorkspaceMode.ENABLED; + log.info("Starting test: WSM={}, syntheticData={}", wsm, syntheticData); - //STEP 1: Initialization - int iterations = 100; - //create an n-dimensional array of doubles - Nd4j.setDataType(DataType.DOUBLE); - List cacheList = new ArrayList<>(); //cacheList is a dynamic array of strings used to hold all words + //STEP 1: Initialization + int iterations = 50; + //create an n-dimensional array of doubles + Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT); + List cacheList = new ArrayList<>(); //cacheList is a dynamic array of strings used to hold all words - //STEP 2: Turn text input into a list of words - INDArray weights; - if(syntheticData){ - weights = Nd4j.rand(5000, 20); - } else { - log.info("Load & Vectorize data...."); - File wordFile = new ClassPathResource("deeplearning4j-tsne/words.txt").getFile(); //Open the file - //Get the data of all unique word vectors - Pair vectors = WordVectorSerializer.loadTxt(wordFile); - VocabCache cache = vectors.getSecond(); - weights = vectors.getFirst().getSyn0(); //seperate weights of unique words into their own list + //STEP 2: Turn text input into a list of words + INDArray weights; + if(syntheticData){ + weights = Nd4j.rand(DataType.FLOAT, 250, 20); + } else { + log.info("Load & Vectorize data...."); + File wordFile = new ClassPathResource("deeplearning4j-tsne/words.txt").getFile(); //Open the file + //Get the data of all unique word vectors + Pair vectors = WordVectorSerializer.loadTxt(wordFile); + VocabCache cache = vectors.getSecond(); + weights = vectors.getFirst().getSyn0(); //seperate weights of unique words into their own list - for (int i = 0; i < cache.numWords(); i++) //seperate strings of words into their own list - cacheList.add(cache.wordAtIndex(i)); - } - - //STEP 3: build a dual-tree tsne to use later - log.info("Build model...."); - BarnesHutTsne tsne = new BarnesHutTsne.Builder() - .setMaxIter(iterations) - .theta(0.5) - .normalize(false) - .learningRate(500) - .useAdaGrad(false) - .workspaceMode(wsm) - .build(); - - - //STEP 4: establish the tsne values and save them to a file - log.info("Store TSNE Coordinates for Plotting...."); - File outDir = testDir.newFolder(); - tsne.fit(weights); - tsne.saveAsFile(cacheList, new File(outDir, "out.txt").getAbsolutePath()); + for (int i = 0; i < cache.numWords(); i++) //seperate strings of words into their own list + cacheList.add(cache.wordAtIndex(i)); } + + //STEP 3: build a dual-tree tsne to use later + log.info("Build model...."); + BarnesHutTsne tsne = new BarnesHutTsne.Builder() + .setMaxIter(iterations) + .theta(0.5) + .normalize(false) + .learningRate(500) + .useAdaGrad(false) + .workspaceMode(wsm) + .build(); + + + //STEP 4: establish the tsne values and save them to a file + log.info("Store TSNE Coordinates for Plotting...."); + File outDir = testDir.newFolder(); + tsne.fit(weights); + tsne.saveAsFile(cacheList, new File(outDir, "out.txt").getAbsolutePath()); } watch.stop(); System.out.println("Elapsed time : " + watch); diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/embeddings/inmemory/InMemoryLookupTableTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/embeddings/inmemory/InMemoryLookupTableTest.java index d7a0b7934..97a8821c6 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/embeddings/inmemory/InMemoryLookupTableTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/embeddings/inmemory/InMemoryLookupTableTest.java @@ -52,7 +52,7 @@ public class InMemoryLookupTableTest extends BaseDL4JTest { } - @Test + @Test(timeout = 300000) public void testConsumeOnEqualVocabs() throws Exception { TokenizerFactory t = new DefaultTokenizerFactory(); t.setTokenPreProcessor(new CommonPreprocessor()); @@ -99,7 +99,7 @@ public class InMemoryLookupTableTest extends BaseDL4JTest { } - @Test + @Test(timeout = 300000) public void testConsumeOnNonEqualVocabs() throws Exception { TokenizerFactory t = new DefaultTokenizerFactory(); t.setTokenPreProcessor(new CommonPreprocessor()); diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/fasttext/FastTextTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/fasttext/FastTextTest.java index 731a9cd60..8d53e8c1a 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/fasttext/FastTextTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/fasttext/FastTextTest.java @@ -12,6 +12,7 @@ import org.junit.Ignore; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; +import org.junit.rules.Timeout; import org.nd4j.linalg.primitives.Pair; import org.nd4j.resources.Resources; @@ -27,6 +28,9 @@ import static org.junit.Assert.assertEquals; @Slf4j public class FastTextTest extends BaseDL4JTest { + @Rule + public Timeout timeout = Timeout.seconds(300); + private File inputFile = Resources.asFile("models/fasttext/data/labeled_data.txt"); private File supModelFile = Resources.asFile("models/fasttext/supervised.model.bin"); private File cbowModelFile = Resources.asFile("models/fasttext/cbow.model.bin"); diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/paragraphvectors/ParagraphVectorsTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/paragraphvectors/ParagraphVectorsTest.java index 4750e8bfb..95cd4e9a6 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/paragraphvectors/ParagraphVectorsTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/paragraphvectors/ParagraphVectorsTest.java @@ -20,6 +20,8 @@ package org.deeplearning4j.models.paragraphvectors; import lombok.NonNull; import lombok.extern.slf4j.Slf4j; import lombok.val; +import org.apache.commons.io.IOUtils; +import org.apache.commons.io.LineIterator; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.models.embeddings.learning.impl.elements.CBOW; import org.deeplearning4j.models.embeddings.reader.impl.FlatModelUtils; @@ -27,6 +29,7 @@ import org.deeplearning4j.models.sequencevectors.sequence.Sequence; import org.deeplearning4j.models.sequencevectors.transformers.impl.SentenceTransformer; import org.deeplearning4j.models.sequencevectors.transformers.impl.iterables.BasicTransformerIterator; import org.deeplearning4j.models.sequencevectors.transformers.impl.iterables.ParallelTransformerIterator; +import org.deeplearning4j.text.sentenceiterator.*; import org.junit.Rule; import org.junit.rules.TemporaryFolder; import org.nd4j.linalg.io.ClassPathResource; @@ -46,10 +49,6 @@ import org.deeplearning4j.text.documentiterator.FileLabelAwareIterator; import org.deeplearning4j.text.documentiterator.LabelAwareIterator; import org.deeplearning4j.text.documentiterator.LabelledDocument; import org.deeplearning4j.text.documentiterator.LabelsSource; -import org.deeplearning4j.text.sentenceiterator.AggregatingSentenceIterator; -import org.deeplearning4j.text.sentenceiterator.BasicLineIterator; -import org.deeplearning4j.text.sentenceiterator.FileSentenceIterator; -import org.deeplearning4j.text.sentenceiterator.SentenceIterator; import org.deeplearning4j.text.sentenceiterator.interoperability.SentenceIteratorConverter; import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor; import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory; @@ -66,8 +65,8 @@ import org.nd4j.resources.Resources; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.io.File; -import java.io.IOException; +import java.io.*; +import java.nio.charset.StandardCharsets; import java.util.*; import java.util.concurrent.atomic.AtomicLong; @@ -79,6 +78,11 @@ import static org.junit.Assert.*; @Slf4j public class ParagraphVectorsTest extends BaseDL4JTest { + @Override + public long getTimeoutMilliseconds() { + return 240000; + } + @Rule public TemporaryFolder testDir = new TemporaryFolder(); @@ -367,7 +371,7 @@ public class ParagraphVectorsTest extends BaseDL4JTest { LabelsSource source = new LabelsSource("DOC_"); - ParagraphVectors vec = new ParagraphVectors.Builder().minWordFrequency(1).iterations(2).seed(119).epochs(3) + ParagraphVectors vec = new ParagraphVectors.Builder().minWordFrequency(1).iterations(2).seed(119).epochs(1) .layerSize(100).learningRate(0.025).labelsSource(source).windowSize(5).iterate(iter) .trainWordVectors(true).vocabCache(cache).tokenizerFactory(t).negativeSample(0) .useHierarchicSoftmax(true).sampling(0).workers(1).usePreciseWeightInit(true) @@ -420,6 +424,8 @@ public class ParagraphVectorsTest extends BaseDL4JTest { @Test(timeout = 300000) public void testParagraphVectorsDBOW() throws Exception { + skipUnlessIntegrationTests(); + File file = Resources.asFile("/big/raw_sentences.txt"); SentenceIterator iter = new BasicLineIterator(file); @@ -652,7 +658,7 @@ public class ParagraphVectorsTest extends BaseDL4JTest { } } - @Test(timeout = 300000) + @Test public void testIterator() throws IOException { val folder_labeled = testDir.newFolder(); val folder_unlabeled = testDir.newFolder(); @@ -667,7 +673,7 @@ public class ParagraphVectorsTest extends BaseDL4JTest { SentenceIterator iter = new BasicLineIterator(resource_sentences); int i = 0; - for (; i < 10000; ++i) { + for (; i < 10; ++i) { int j = 0; int labels = 0; int words = 0; @@ -716,7 +722,7 @@ public class ParagraphVectorsTest extends BaseDL4JTest { TokenizerFactory t = new DefaultTokenizerFactory(); t.setTokenPreProcessor(new CommonPreprocessor()); - Word2Vec wordVectors = new Word2Vec.Builder().seed(119).minWordFrequency(1).batchSize(250).iterations(1).epochs(3) + Word2Vec wordVectors = new Word2Vec.Builder().seed(119).minWordFrequency(1).batchSize(250).iterations(1).epochs(1) .learningRate(0.025).layerSize(150).minLearningRate(0.001) .elementsLearningAlgorithm(new SkipGram()).useHierarchicSoftmax(true).windowSize(5) .allowParallelTokenization(true) @@ -1004,7 +1010,7 @@ public class ParagraphVectorsTest extends BaseDL4JTest { TokenizerFactory t = new DefaultTokenizerFactory(); t.setTokenPreProcessor(new CommonPreprocessor()); - Word2Vec wordVectors = new Word2Vec.Builder().minWordFrequency(1).batchSize(250).iterations(1).epochs(3) + Word2Vec wordVectors = new Word2Vec.Builder().minWordFrequency(1).batchSize(250).iterations(1).epochs(1) .learningRate(0.025).layerSize(150).minLearningRate(0.001) .elementsLearningAlgorithm(new SkipGram()).useHierarchicSoftmax(true).windowSize(5) .iterate(iter).tokenizerFactory(t).build(); @@ -1146,8 +1152,27 @@ public class ParagraphVectorsTest extends BaseDL4JTest { @Test(timeout = 300000) public void testDoubleFit() throws Exception { + boolean isIntegration = isIntegrationTests(); File resource = Resources.asFile("/big/raw_sentences.txt"); - SentenceIterator iter = new BasicLineIterator(resource); + SentenceIterator iter; + if(isIntegration){ + iter = new BasicLineIterator(resource); + } else { + List lines = new ArrayList<>(); + try(InputStream is = new BufferedInputStream(new FileInputStream(resource))){ + LineIterator lineIter = IOUtils.lineIterator(is, StandardCharsets.UTF_8); + try{ + for( int i=0; i<500 && lineIter.hasNext(); i++ ){ + lines.add(lineIter.next()); + } + } finally { + lineIter.close(); + } + } + + iter = new CollectionSentenceIterator(lines); + } + TokenizerFactory t = new DefaultTokenizerFactory(); t.setTokenPreProcessor(new CommonPreprocessor()); diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/sequencevectors/transformers/impl/iterables/ParallelTransformerIteratorTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/sequencevectors/transformers/impl/iterables/ParallelTransformerIteratorTest.java index c8674e630..c031d99ea 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/sequencevectors/transformers/impl/iterables/ParallelTransformerIteratorTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/sequencevectors/transformers/impl/iterables/ParallelTransformerIteratorTest.java @@ -55,7 +55,7 @@ public class ParallelTransformerIteratorTest extends BaseDL4JTest { } - @Test + @Test(timeout = 300000) public void hasNext() throws Exception { SentenceIterator iterator = new BasicLineIterator(Resources.asFile("big/raw_sentences.txt")); @@ -77,7 +77,7 @@ public class ParallelTransformerIteratorTest extends BaseDL4JTest { assertEquals(97162, cnt); } - @Test + @Test(timeout = 300000) public void testSpeedComparison1() throws Exception { SentenceIterator iterator = new MutipleEpochsSentenceIterator( new BasicLineIterator(Resources.asFile("big/raw_sentences.txt")), 25); diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/word2vec/Word2VecTestsSmall.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/word2vec/Word2VecTestsSmall.java index fc10a592a..ae2bf83c7 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/word2vec/Word2VecTestsSmall.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/word2vec/Word2VecTestsSmall.java @@ -82,7 +82,7 @@ public class Word2VecTestsSmall extends BaseDL4JTest { assertEquals(neighbours, nearestWords.size()); } - @Test + @Test(timeout = 300000) public void testUnkSerialization_1() throws Exception { val inputFile = Resources.asFile("big/raw_sentences.txt"); @@ -142,7 +142,7 @@ public class Word2VecTestsSmall extends BaseDL4JTest { } - @Test + @Test(timeout = 300000) public void testW2VEmbeddingLayerInit() throws Exception { Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT); diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/word2vec/iterator/Word2VecDataSetIteratorTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/word2vec/iterator/Word2VecDataSetIteratorTest.java index 4a77d6807..9e8e89d39 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/word2vec/iterator/Word2VecDataSetIteratorTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/word2vec/iterator/Word2VecDataSetIteratorTest.java @@ -47,11 +47,17 @@ import static org.junit.Assert.assertArrayEquals; */ public class Word2VecDataSetIteratorTest extends BaseDL4JTest { + @Override + public long getTimeoutMilliseconds() { + return 60000L; + } + /** * Basically all we want from this test - being able to finish without exceptions. */ @Test public void testIterator1() throws Exception { + File inputFile = Resources.asFile("big/raw_sentences.txt"); SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath()); @@ -72,10 +78,14 @@ public class Word2VecDataSetIteratorTest extends BaseDL4JTest { Word2VecDataSetIterator iterator = new Word2VecDataSetIterator(vec, getLASI(iter, labels), labels, 1); INDArray array = iterator.next().getFeatures(); + int count = 0; while (iterator.hasNext()) { DataSet ds = iterator.next(); assertArrayEquals(array.shape(), ds.getFeatures().shape()); + + if(!isIntegrationTests() && count++ > 20) + break; //raw_sentences.txt is 2.81 MB, takes quite some time to process. We'll only first 20 minibatches when doing unit tests } } diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/word2vec/wordstore/VocabConstructorTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/word2vec/wordstore/VocabConstructorTest.java index 84b8d3c38..430c21492 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/word2vec/wordstore/VocabConstructorTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/word2vec/wordstore/VocabConstructorTest.java @@ -22,6 +22,7 @@ import lombok.val; import org.deeplearning4j.BaseDL4JTest; import org.junit.Rule; import org.junit.rules.TemporaryFolder; +import org.junit.rules.Timeout; import org.nd4j.linalg.io.ClassPathResource; import org.deeplearning4j.models.sequencevectors.interfaces.SequenceIterator; import org.deeplearning4j.models.sequencevectors.iterators.AbstractSequenceIterator; @@ -43,6 +44,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.File; +import java.sql.Time; import java.util.*; import java.util.concurrent.atomic.AtomicBoolean; @@ -53,6 +55,9 @@ import static org.junit.Assert.*; */ public class VocabConstructorTest extends BaseDL4JTest { + @Rule + public Timeout timeout = Timeout.seconds(300); + protected static final Logger log = LoggerFactory.getLogger(VocabConstructorTest.class); TokenizerFactory t = new DefaultTokenizerFactory(); diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/documentiterator/AsyncLabelAwareIteratorTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/documentiterator/AsyncLabelAwareIteratorTest.java index 7c8ea3ba7..dbf6c7aa4 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/documentiterator/AsyncLabelAwareIteratorTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/documentiterator/AsyncLabelAwareIteratorTest.java @@ -29,7 +29,7 @@ import static org.junit.Assert.assertEquals; * @author raver119@gmail.com */ public class AsyncLabelAwareIteratorTest extends BaseDL4JTest { - @Test + @Test(timeout = 300000) public void nextDocument() throws Exception { SentenceIterator sentence = new BasicLineIterator(Resources.asFile("big/raw_sentences.txt")); BasicLabelAwareIterator backed = new BasicLabelAwareIterator.Builder(sentence).build(); diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/documentiterator/BasicLabelAwareIteratorTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/documentiterator/BasicLabelAwareIteratorTest.java index 1696226d3..984a098cc 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/documentiterator/BasicLabelAwareIteratorTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/documentiterator/BasicLabelAwareIteratorTest.java @@ -17,6 +17,8 @@ package org.deeplearning4j.text.documentiterator; import org.deeplearning4j.BaseDL4JTest; +import org.junit.Rule; +import org.junit.rules.Timeout; import org.nd4j.linalg.io.ClassPathResource; import org.deeplearning4j.text.sentenceiterator.BasicLineIterator; import org.deeplearning4j.text.sentenceiterator.SentenceIterator; @@ -33,6 +35,9 @@ import static org.junit.Assert.assertEquals; */ public class BasicLabelAwareIteratorTest extends BaseDL4JTest { + @Rule + public Timeout timeout = Timeout.seconds(300); + @Before public void setUp() throws Exception { diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/sentenceiterator/AggregatingSentenceIteratorTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/sentenceiterator/AggregatingSentenceIteratorTest.java index cd6fe3449..775e18d3f 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/sentenceiterator/AggregatingSentenceIteratorTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/sentenceiterator/AggregatingSentenceIteratorTest.java @@ -30,7 +30,7 @@ import static org.junit.Assert.assertEquals; */ public class AggregatingSentenceIteratorTest extends BaseDL4JTest { - @Test + @Test(timeout = 300000) public void testHasNext() throws Exception { File file = Resources.asFile("/big/raw_sentences.txt"); BasicLineIterator iterator = new BasicLineIterator(file); diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/sentenceiterator/BasicLineIteratorTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/sentenceiterator/BasicLineIteratorTest.java index e6aaa338e..0f01937ed 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/sentenceiterator/BasicLineIteratorTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/sentenceiterator/BasicLineIteratorTest.java @@ -17,6 +17,8 @@ package org.deeplearning4j.text.sentenceiterator; import org.deeplearning4j.BaseDL4JTest; +import org.junit.Rule; +import org.junit.rules.Timeout; import org.nd4j.linalg.io.ClassPathResource; import org.junit.Before; import org.junit.Test; @@ -32,6 +34,9 @@ import static org.junit.Assert.assertEquals; */ public class BasicLineIteratorTest extends BaseDL4JTest { + @Rule + public Timeout timeout = Timeout.seconds(300); + @Before public void setUp() throws Exception { diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/sentenceiterator/MutipleEpochsSentenceIteratorTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/sentenceiterator/MutipleEpochsSentenceIteratorTest.java index b187caf8d..1a3c215aa 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/sentenceiterator/MutipleEpochsSentenceIteratorTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/sentenceiterator/MutipleEpochsSentenceIteratorTest.java @@ -27,7 +27,7 @@ import static org.junit.Assert.assertEquals; * @author raver119@gmail.com */ public class MutipleEpochsSentenceIteratorTest extends BaseDL4JTest { - @Test + @Test(timeout = 300000) public void hasNext() throws Exception { SentenceIterator iterator = new MutipleEpochsSentenceIterator( new BasicLineIterator(Resources.asFile("big/raw_sentences.txt")), 100); diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/sentenceiterator/PrefetchingSentenceIteratorTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/sentenceiterator/PrefetchingSentenceIteratorTest.java index 414f1454b..f0f6f1c54 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/sentenceiterator/PrefetchingSentenceIteratorTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/sentenceiterator/PrefetchingSentenceIteratorTest.java @@ -17,6 +17,8 @@ package org.deeplearning4j.text.sentenceiterator; import org.deeplearning4j.BaseDL4JTest; +import org.junit.Rule; +import org.junit.rules.Timeout; import org.nd4j.linalg.io.ClassPathResource; import org.junit.Test; import org.nd4j.resources.Resources; @@ -33,6 +35,9 @@ import static org.junit.Assert.assertTrue; */ public class PrefetchingSentenceIteratorTest extends BaseDL4JTest { + @Rule + public Timeout timeout = Timeout.seconds(300); + protected static final Logger log = LoggerFactory.getLogger(PrefetchingSentenceIteratorTest.class); @Test diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/tokenization/tokenizer/BertWordPieceTokenizerTests.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/tokenization/tokenizer/BertWordPieceTokenizerTests.java index 4b78e51a2..80570ae54 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/tokenization/tokenizer/BertWordPieceTokenizerTests.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/tokenization/tokenizer/BertWordPieceTokenizerTests.java @@ -205,7 +205,7 @@ public class BertWordPieceTokenizerTests extends BaseDL4JTest { } - @Test + @Test(timeout = 300000) public void testBertWordPieceTokenizer10() throws Exception { File f = Resources.asFile("deeplearning4j-nlp/bert/uncased_L-12_H-768_A-12/vocab.txt"); BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(f, true, true, StandardCharsets.UTF_8); diff --git a/deeplearning4j/deeplearning4j-nn/pom.xml b/deeplearning4j/deeplearning4j-nn/pom.xml index c1ff45a61..e92372fc8 100644 --- a/deeplearning4j/deeplearning4j-nn/pom.xml +++ b/deeplearning4j/deeplearning4j-nn/pom.xml @@ -117,6 +117,12 @@ test + + org.deeplearning4j + deeplearning4j-common-tests + ${project.version} + test + diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/gradientcheck/GradientCheckUtil.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/gradientcheck/GradientCheckUtil.java index d15e961b7..3cd169c59 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/gradientcheck/GradientCheckUtil.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/gradientcheck/GradientCheckUtil.java @@ -16,8 +16,9 @@ package org.deeplearning4j.gradientcheck; +import lombok.*; +import lombok.experimental.Accessors; import lombok.extern.slf4j.Slf4j; -import lombok.val; import org.deeplearning4j.nn.api.Model; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.exception.ND4JArraySizeException; @@ -113,6 +114,52 @@ public class GradientCheckUtil { } } + public enum PrintMode { + ALL, + ZEROS, + FAILURES_ONLY + } + + @Accessors(fluent = true) + @Data + @NoArgsConstructor + public static class MLNConfig { + private MultiLayerNetwork net; + private INDArray input; + private INDArray labels; + private INDArray inputMask; + private INDArray labelMask; + private double epsilon = 1e-6; + private double maxRelError = 1e-3; + private double minAbsoluteError = 1e-8; + private PrintMode print = PrintMode.ZEROS; + private boolean exitOnFirstError = false; + private boolean subset; + private int maxPerParam; + private Set excludeParams; + private Consumer callEachIter; + } + + @Accessors(fluent = true) + @Data + @NoArgsConstructor + public static class GraphConfig { + private ComputationGraph net; + private INDArray[] inputs; + private INDArray[] labels; + private INDArray[] inputMask; + private INDArray[] labelMask; + private double epsilon = 1e-6; + private double maxRelError = 1e-3; + private double minAbsoluteError = 1e-8; + private PrintMode print = PrintMode.ZEROS; + private boolean exitOnFirstError = false; + private boolean subset; + private int maxPerParam; + private Set excludeParams; + private Consumer callEachIter; + } + /** * Check backprop gradients for a MultiLayerNetwork. * @param mln MultiLayerNetwork to test. This must be initialized. @@ -127,46 +174,18 @@ public class GradientCheckUtil { * @param labels Labels/targets to use to calculate backprop gradient. May be mini-batch data. * @return true if gradients are passed, false otherwise. */ + @Deprecated public static boolean checkGradients(MultiLayerNetwork mln, double epsilon, double maxRelError, - double minAbsoluteError, boolean print, boolean exitOnFirstError, INDArray input, INDArray labels) { - return checkGradients(mln, epsilon, maxRelError, minAbsoluteError, print, exitOnFirstError, input, labels, null, null); - } - - public static boolean checkGradients(MultiLayerNetwork mln, double epsilon, double maxRelError, - double minAbsoluteError, boolean print, boolean exitOnFirstError, INDArray input, - INDArray labels, Set excludeParams) { - return checkGradients(mln, epsilon, maxRelError, minAbsoluteError, print, exitOnFirstError, input, labels, null, null, - false, -1, excludeParams, (Integer)null); - } - - public static boolean checkGradients(MultiLayerNetwork mln, double epsilon, double maxRelError, - double minAbsoluteError, boolean print, boolean exitOnFirstError, - INDArray input, INDArray labels, INDArray inputMask, INDArray labelMask) { - return checkGradients(mln, epsilon, maxRelError, minAbsoluteError, print, exitOnFirstError, - input, labels, inputMask, labelMask, false, -1); - } - - public static boolean checkGradients(MultiLayerNetwork mln, double epsilon, double maxRelError, - double minAbsoluteError, boolean print, boolean exitOnFirstError, - INDArray input, INDArray labels, INDArray inputMask, INDArray labelMask, - boolean subset, int maxPerParam) { - return checkGradients(mln, epsilon, maxRelError, minAbsoluteError, print, exitOnFirstError, input, - labels, inputMask, labelMask, subset, maxPerParam, null); - } - - public static boolean checkGradients(MultiLayerNetwork mln, double epsilon, double maxRelError, - double minAbsoluteError, boolean print, boolean exitOnFirstError, - INDArray input, INDArray labels, INDArray inputMask, INDArray labelMask, - boolean subset, int maxPerParam, Set excludeParams) { - return checkGradients(mln, epsilon, maxRelError, minAbsoluteError, print, exitOnFirstError, input, - labels, inputMask, labelMask, subset, maxPerParam, excludeParams, (Consumer)null); + double minAbsoluteError, boolean print, boolean exitOnFirstError, INDArray input, INDArray labels) { + return checkGradients(new MLNConfig().net(mln).epsilon(epsilon).maxRelError(maxRelError).minAbsoluteError(minAbsoluteError).print(PrintMode.FAILURES_ONLY) + .exitOnFirstError(exitOnFirstError).input(input).labels(labels)); } + @Deprecated public static boolean checkGradients(MultiLayerNetwork mln, double epsilon, double maxRelError, double minAbsoluteError, boolean print, boolean exitOnFirstError, INDArray input, INDArray labels, INDArray inputMask, INDArray labelMask, boolean subset, int maxPerParam, Set excludeParams, final Integer rngSeedResetEachIter) { - Consumer c = null; if(rngSeedResetEachIter != null){ c = new Consumer() { @@ -177,21 +196,18 @@ public class GradientCheckUtil { }; } - return checkGradients(mln, epsilon, maxRelError, minAbsoluteError, print, exitOnFirstError, input, - labels, inputMask, labelMask, subset, maxPerParam, excludeParams, c); + return checkGradients(new MLNConfig().net(mln).epsilon(epsilon).maxRelError(maxRelError).minAbsoluteError(minAbsoluteError).print(PrintMode.FAILURES_ONLY) + .exitOnFirstError(exitOnFirstError).input(input).labels(labels).inputMask(inputMask).labelMask(labelMask).subset(subset).maxPerParam(maxPerParam).excludeParams(excludeParams).callEachIter(c)); } - public static boolean checkGradients(MultiLayerNetwork mln, double epsilon, double maxRelError, - double minAbsoluteError, boolean print, boolean exitOnFirstError, - INDArray input, INDArray labels, INDArray inputMask, INDArray labelMask, - boolean subset, int maxPerParam, Set excludeParams, Consumer callEachIter) { + public static boolean checkGradients(MLNConfig c){ //Basic sanity checks on input: - if (epsilon <= 0.0 || epsilon > 0.1) + if (c.epsilon <= 0.0 || c.epsilon > 0.1) throw new IllegalArgumentException("Invalid epsilon: expect epsilon in range (0,0.1], usually 1e-4 or so"); - if (maxRelError <= 0.0 || maxRelError > 0.25) - throw new IllegalArgumentException("Invalid maxRelativeError: " + maxRelError); - if (!(mln.getOutputLayer() instanceof IOutputLayer)) + if (c.maxRelError <= 0.0 || c.maxRelError > 0.25) + throw new IllegalArgumentException("Invalid maxRelativeError: " + c.maxRelError); + if (!(c.net.getOutputLayer() instanceof IOutputLayer)) throw new IllegalArgumentException("Cannot check backprop gradients without OutputLayer"); DataType dataType = DataTypeUtil.getDtypeFromContext(); @@ -201,21 +217,21 @@ public class GradientCheckUtil { + "DataTypeUtil.setDTypeForContext(DataType.DOUBLE); before using GradientCheckUtil"); } - DataType netDataType = mln.getLayerWiseConfigurations().getDataType(); + DataType netDataType = c.net.getLayerWiseConfigurations().getDataType(); if (netDataType != DataType.DOUBLE) { throw new IllegalStateException("Cannot perform gradient check: Network datatype is not set to double precision (" + "is: " + netDataType + "). Double precision must be used for gradient checks. Create network with .dataType(DataType.DOUBLE) before using GradientCheckUtil"); } - if(netDataType != mln.params().dataType()){ + if(netDataType != c.net.params().dataType()){ throw new IllegalStateException("Parameters datatype does not match network configuration datatype (" - + "is: " + mln.params().dataType() + "). If network datatype is set to DOUBLE, parameters must also be DOUBLE."); + + "is: " + c.net.params().dataType() + "). If network datatype is set to DOUBLE, parameters must also be DOUBLE."); } //Check network configuration: int layerCount = 0; - for (NeuralNetConfiguration n : mln.getLayerWiseConfigurations().getConfs()) { + for (NeuralNetConfiguration n : c.net.getLayerWiseConfigurations().getConfs()) { if (n.getLayer() instanceof BaseLayer) { BaseLayer bl = (BaseLayer) n.getLayer(); IUpdater u = bl.getIUpdater(); @@ -243,7 +259,7 @@ public class GradientCheckUtil { } } - if (n.getLayer().getIDropout() != null && callEachIter == null) { + if (n.getLayer().getIDropout() != null && c.callEachIter == null) { throw new IllegalStateException("When gradient checking dropout, need to reset RNG seed each iter, or no" + " dropout should be present during gradient checks - got dropout = " + n.getLayer().getIDropout() + " for layer " + layerCount); @@ -251,45 +267,45 @@ public class GradientCheckUtil { } //Set softmax clipping to 0 if necessary, to avoid spurious failures due to clipping - for(Layer l : mln.getLayers()){ + for(Layer l : c.net.getLayers()){ if(l instanceof IOutputLayer){ configureLossFnClippingIfPresent((IOutputLayer) l); } } - mln.setInput(input); - mln.setLabels(labels); - mln.setLayerMaskArrays(inputMask, labelMask); - if(callEachIter != null){ - callEachIter.accept(mln); + c.net.setInput(c.input); + c.net.setLabels(c.labels); + c.net.setLayerMaskArrays(c.inputMask, c.labelMask); + if(c.callEachIter != null){ + c.callEachIter.accept(c.net); } - mln.computeGradientAndScore(); - Pair gradAndScore = mln.gradientAndScore(); + c.net.computeGradientAndScore(); + Pair gradAndScore = c.net.gradientAndScore(); - Updater updater = UpdaterCreator.getUpdater(mln); - updater.update(mln, gradAndScore.getFirst(), 0, 0, mln.batchSize(), LayerWorkspaceMgr.noWorkspaces()); + Updater updater = UpdaterCreator.getUpdater(c.net); + updater.update(c.net, gradAndScore.getFirst(), 0, 0, c.net.batchSize(), LayerWorkspaceMgr.noWorkspaces()); INDArray gradientToCheck = gradAndScore.getFirst().gradient().dup(); //need dup: gradients are a *view* of the full gradient array (which will change every time backprop is done) - INDArray originalParams = mln.params().dup(); //need dup: params are a *view* of full parameters + INDArray originalParams = c.net.params().dup(); //need dup: params are a *view* of full parameters val nParams = originalParams.length(); - Map paramTable = mln.paramTable(); + Map paramTable = c.net.paramTable(); List paramNames = new ArrayList<>(paramTable.keySet()); val paramEnds = new long[paramNames.size()]; paramEnds[0] = paramTable.get(paramNames.get(0)).length(); Map stepSizeForParam; - if(subset){ + if(c.subset){ stepSizeForParam = new HashMap<>(); - stepSizeForParam.put(paramNames.get(0), (int) Math.max(1, paramTable.get(paramNames.get(0)).length() / maxPerParam)); + stepSizeForParam.put(paramNames.get(0), (int) Math.max(1, paramTable.get(paramNames.get(0)).length() / c.maxPerParam)); } else { stepSizeForParam = null; } for (int i = 1; i < paramEnds.length; i++) { val n = paramTable.get(paramNames.get(i)).length(); paramEnds[i] = paramEnds[i - 1] + n; - if(subset){ - long ss = n / maxPerParam; + if(c.subset){ + long ss = n / c.maxPerParam; if(ss == 0){ ss = n; } @@ -300,9 +316,9 @@ public class GradientCheckUtil { } } - if(print) { + if(c.print == PrintMode.ALL) { int i=0; - for (Layer l : mln.getLayers()) { + for (Layer l : c.net.getLayers()) { Set s = l.paramTable().keySet(); log.info("Layer " + i + ": " + l.getClass().getSimpleName() + " - params " + s); i++; @@ -312,36 +328,40 @@ public class GradientCheckUtil { int totalNFailures = 0; double maxError = 0.0; - DataSet ds = new DataSet(input, labels, inputMask, labelMask); + DataSet ds = new DataSet(c.input, c.labels, c.inputMask, c.labelMask); int currParamNameIdx = 0; - INDArray params = mln.params(); //Assumption here: params is a view that we can modify in-place + if(c.excludeParams != null && !c.excludeParams.isEmpty()){ + log.info("NOTE: parameters will be skipped due to config: {}", c.excludeParams); + } + + INDArray params = c.net.params(); //Assumption here: params is a view that we can modify in-place for (long i = 0; i < nParams; ) { //Get param name if (i >= paramEnds[currParamNameIdx]) { currParamNameIdx++; } String paramName = paramNames.get(currParamNameIdx); - if(excludeParams != null && excludeParams.contains(paramName)){ - log.info("Skipping parameters for parameter name: {}", paramName); + if(c.excludeParams != null && c.excludeParams.contains(paramName)){ +// log.info("Skipping parameters for parameter name: {}", paramName); i = paramEnds[currParamNameIdx++]; continue; } //(w+epsilon): Do forward pass and score double origValue = params.getDouble(i); - params.putScalar(i, origValue + epsilon); - if(callEachIter != null){ - callEachIter.accept(mln); + params.putScalar(i, origValue + c.epsilon); + if(c.callEachIter != null){ + c.callEachIter.accept(c.net); } - double scorePlus = mln.score(ds, true); + double scorePlus = c.net.score(ds, true); //(w-epsilon): Do forward pass and score - params.putScalar(i, origValue - epsilon); - if(callEachIter != null){ - callEachIter.accept(mln); + params.putScalar(i, origValue - c.epsilon); + if(c.callEachIter != null){ + c.callEachIter.accept(c.net); } - double scoreMinus = mln.score(ds, true); + double scoreMinus = c.net.score(ds, true); //Reset original param value params.putScalar(i, origValue); @@ -349,7 +369,7 @@ public class GradientCheckUtil { //Calculate numerical parameter gradient: double scoreDelta = scorePlus - scoreMinus; - double numericalGradient = scoreDelta / (2 * epsilon); + double numericalGradient = scoreDelta / (2 * c.epsilon); if (Double.isNaN(numericalGradient)) throw new IllegalStateException("Numerical gradient was NaN for parameter " + i + " of " + nParams); @@ -363,30 +383,29 @@ public class GradientCheckUtil { if (relError > maxError) maxError = relError; - if (relError > maxRelError || Double.isNaN(relError)) { + if (relError > c.maxRelError || Double.isNaN(relError)) { double absError = Math.abs(backpropGradient - numericalGradient); - if (absError < minAbsoluteError) { - if(print) { + if (absError < c.minAbsoluteError) { + if(c.print == PrintMode.ALL || c.print == PrintMode.ZEROS && absError == 0.0) { log.info("Param " + i + " (" + paramName + ") passed: grad= " + backpropGradient + ", numericalGrad= " + numericalGradient + ", relError= " + relError - + "; absolute error = " + absError + " < minAbsoluteError = " + minAbsoluteError); + + "; absolute error = " + absError + " < minAbsoluteError = " + c.minAbsoluteError); } } else { - if (print) - log.info("Param " + i + " (" + paramName + ") FAILED: grad= " + backpropGradient - + ", numericalGrad= " + numericalGradient + ", relError= " + relError - + ", scorePlus=" + scorePlus + ", scoreMinus= " + scoreMinus + ", paramValue = " + origValue); - if (exitOnFirstError) + log.info("Param " + i + " (" + paramName + ") FAILED: grad= " + backpropGradient + + ", numericalGrad= " + numericalGradient + ", relError= " + relError + + ", scorePlus=" + scorePlus + ", scoreMinus= " + scoreMinus + ", paramValue = " + origValue); + if (c.exitOnFirstError) return false; totalNFailures++; } - } else if (print) { + } else if (c.print == PrintMode.ALL) { log.info("Param " + i + " (" + paramName + ") passed: grad= " + backpropGradient + ", numericalGrad= " + numericalGradient + ", relError= " + relError); } long step; - if(subset){ + if(c.subset){ step = stepSizeForParam.get(paramName); if(i + step > paramEnds[currParamNameIdx]+1){ step = paramEnds[currParamNameIdx]+1 - i; @@ -398,83 +417,25 @@ public class GradientCheckUtil { i += step; } - if (print) { - val nPass = nParams - totalNFailures; - log.info("GradientCheckUtil.checkGradients(): " + nParams + " params checked, " + nPass + " passed, " + val nPass = nParams - totalNFailures; + log.info("GradientCheckUtil.checkGradients(): " + nParams + " params checked, " + nPass + " passed, " + totalNFailures + " failed. Largest relative error = " + maxError); - } return totalNFailures == 0; } - - - /**Check backprop gradients for a ComputationGraph - * @param graph ComputationGraph to test. This must be initialized. - * @param epsilon Usually on the order of 1e-4 or so. - * @param maxRelError Maximum relative error. Usually < 0.01, though maybe more for deep networks - * @param minAbsoluteError Minimum absolute error to cause a failure. Numerical gradients can be non-zero due to precision issues. - * For example, 0.0 vs. 1e-18: relative error is 1.0, but not really a failure - * @param print Whether to print full pass/failure details for each parameter gradient - * @param exitOnFirstError If true: return upon first failure. If false: continue checking even if - * one parameter gradient has failed. Typically use false for debugging, true for unit tests. - * @param inputs Input arrays to use for forward pass. May be mini-batch data. - * @param labels Labels/targets (output) arrays to use to calculate backprop gradient. May be mini-batch data. - * @return true if gradients are passed, false otherwise. - */ - public static boolean checkGradients(ComputationGraph graph, double epsilon, double maxRelError, - double minAbsoluteError, boolean print, boolean exitOnFirstError, INDArray[] inputs, - INDArray[] labels) { - return checkGradients(graph, epsilon, maxRelError, minAbsoluteError, print, exitOnFirstError, inputs, labels, null, null, null); - } - - public static boolean checkGradients(ComputationGraph graph, double epsilon, double maxRelError, - double minAbsoluteError, boolean print, boolean exitOnFirstError, INDArray[] inputs, - INDArray[] labels, INDArray[] fMask, INDArray[] lMask) { - return checkGradients(graph, epsilon, maxRelError, minAbsoluteError, print, exitOnFirstError, inputs, - labels, fMask, lMask, null); - } - - public static boolean checkGradients(ComputationGraph graph, double epsilon, double maxRelError, - double minAbsoluteError, boolean print, boolean exitOnFirstError, INDArray[] inputs, - INDArray[] labels, INDArray[] fMask, INDArray[] lMask, Set excludeParams) { - return checkGradients(graph, epsilon, maxRelError, minAbsoluteError, print, exitOnFirstError, inputs, - labels, fMask, lMask, excludeParams, (Consumer)null); - } - - public static boolean checkGradients(ComputationGraph graph, double epsilon, double maxRelError, - double minAbsoluteError, boolean print, boolean exitOnFirstError, INDArray[] inputs, - INDArray[] labels, INDArray[] fMask, INDArray[] lMask, Set excludeParams, - final Integer rngSeedResetEachIter) { - Consumer c = null; - if(rngSeedResetEachIter != null){ - c = new Consumer() { - @Override - public void accept(ComputationGraph computationGraph) { - Nd4j.getRandom().setSeed(rngSeedResetEachIter); - } - }; - } - - return checkGradients(graph, epsilon, maxRelError, minAbsoluteError, print, exitOnFirstError, inputs, - labels, fMask, lMask, excludeParams, c); - } - - public static boolean checkGradients(ComputationGraph graph, double epsilon, double maxRelError, - double minAbsoluteError, boolean print, boolean exitOnFirstError, INDArray[] inputs, - INDArray[] labels, INDArray[] fMask, INDArray[] lMask, Set excludeParams, - Consumer callEachIter) { + public static boolean checkGradients(GraphConfig c){ //Basic sanity checks on input: - if (epsilon <= 0.0 || epsilon > 0.1) + if (c.epsilon <= 0.0 || c.epsilon > 0.1) throw new IllegalArgumentException("Invalid epsilon: expect epsilon in range (0,0.1], usually 1e-4 or so"); - if (maxRelError <= 0.0 || maxRelError > 0.25) - throw new IllegalArgumentException("Invalid maxRelativeError: " + maxRelError); + if (c.maxRelError <= 0.0 || c.maxRelError > 0.25) + throw new IllegalArgumentException("Invalid maxRelativeError: " + c.maxRelError); - if (graph.getNumInputArrays() != inputs.length) - throw new IllegalArgumentException("Invalid input arrays: expect " + graph.getNumInputArrays() + " inputs"); - if (graph.getNumOutputArrays() != labels.length) + if (c.net.getNumInputArrays() != c.inputs.length) + throw new IllegalArgumentException("Invalid input arrays: expect " + c.net.getNumInputArrays() + " inputs"); + if (c.net.getNumOutputArrays() != c.labels.length) throw new IllegalArgumentException( - "Invalid labels arrays: expect " + graph.getNumOutputArrays() + " outputs"); + "Invalid labels arrays: expect " + c.net.getNumOutputArrays() + " outputs"); DataType dataType = DataTypeUtil.getDtypeFromContext(); if (dataType != DataType.DOUBLE) { @@ -483,21 +444,21 @@ public class GradientCheckUtil { + "DataTypeUtil.setDTypeForContext(DataType.DOUBLE); before using GradientCheckUtil"); } - DataType netDataType = graph.getConfiguration().getDataType(); + DataType netDataType = c.net.getConfiguration().getDataType(); if (netDataType != DataType.DOUBLE) { throw new IllegalStateException("Cannot perform gradient check: Network datatype is not set to double precision (" + "is: " + netDataType + "). Double precision must be used for gradient checks. Create network with .dataType(DataType.DOUBLE) before using GradientCheckUtil"); } - if(netDataType != graph.params().dataType()){ + if(netDataType != c.net.params().dataType()){ throw new IllegalStateException("Parameters datatype does not match network configuration datatype (" - + "is: " + graph.params().dataType() + "). If network datatype is set to DOUBLE, parameters must also be DOUBLE."); + + "is: " + c.net.params().dataType() + "). If network datatype is set to DOUBLE, parameters must also be DOUBLE."); } //Check configuration int layerCount = 0; - for (String vertexName : graph.getConfiguration().getVertices().keySet()) { - GraphVertex gv = graph.getConfiguration().getVertices().get(vertexName); + for (String vertexName : c.net.getConfiguration().getVertices().keySet()) { + GraphVertex gv = c.net.getConfiguration().getVertices().get(vertexName); if (!(gv instanceof LayerVertex)) continue; LayerVertex lv = (LayerVertex) gv; @@ -529,7 +490,7 @@ public class GradientCheckUtil { } } - if (lv.getLayerConf().getLayer().getIDropout() != null && callEachIter == null) { + if (lv.getLayerConf().getLayer().getIDropout() != null && c.callEachIter == null) { throw new IllegalStateException("When gradient checking dropout, rng seed must be reset each iteration, or no" + " dropout should be present during gradient checks - got dropout = " + lv.getLayerConf().getLayer().getIDropout() + " for layer " + layerCount); @@ -537,34 +498,34 @@ public class GradientCheckUtil { } //Set softmax clipping to 0 if necessary, to avoid spurious failures due to clipping - for(Layer l : graph.getLayers()){ + for(Layer l : c.net.getLayers()){ if(l instanceof IOutputLayer){ configureLossFnClippingIfPresent((IOutputLayer) l); } } - for (int i = 0; i < inputs.length; i++) - graph.setInput(i, inputs[i]); - for (int i = 0; i < labels.length; i++) - graph.setLabel(i, labels[i]); + for (int i = 0; i < c.inputs.length; i++) + c.net.setInput(i, c.inputs[i]); + for (int i = 0; i < c.labels.length; i++) + c.net.setLabel(i, c.labels[i]); - graph.setLayerMaskArrays(fMask, lMask); + c.net.setLayerMaskArrays(c.inputMask, c.labelMask); - if(callEachIter != null){ - callEachIter.accept(graph); + if(c.callEachIter != null){ + c.callEachIter.accept(c.net); } - graph.computeGradientAndScore(); - Pair gradAndScore = graph.gradientAndScore(); + c.net.computeGradientAndScore(); + Pair gradAndScore = c.net.gradientAndScore(); - ComputationGraphUpdater updater = new ComputationGraphUpdater(graph); - updater.update(gradAndScore.getFirst(), 0, 0, graph.batchSize(), LayerWorkspaceMgr.noWorkspaces()); + ComputationGraphUpdater updater = new ComputationGraphUpdater(c.net); + updater.update(gradAndScore.getFirst(), 0, 0, c.net.batchSize(), LayerWorkspaceMgr.noWorkspaces()); INDArray gradientToCheck = gradAndScore.getFirst().gradient().dup(); //need dup: gradients are a *view* of the full gradient array (which will change every time backprop is done) - INDArray originalParams = graph.params().dup(); //need dup: params are a *view* of full parameters + INDArray originalParams = c.net.params().dup(); //need dup: params are a *view* of full parameters val nParams = originalParams.length(); - Map paramTable = graph.paramTable(); + Map paramTable = c.net.paramTable(); List paramNames = new ArrayList<>(paramTable.keySet()); val paramEnds = new long[paramNames.size()]; paramEnds[0] = paramTable.get(paramNames.get(0)).length(); @@ -572,19 +533,23 @@ public class GradientCheckUtil { paramEnds[i] = paramEnds[i - 1] + paramTable.get(paramNames.get(i)).length(); } + if(c.excludeParams != null && !c.excludeParams.isEmpty()){ + log.info("NOTE: parameters will be skipped due to config: {}", c.excludeParams); + } + int currParamNameIdx = 0; int totalNFailures = 0; double maxError = 0.0; - MultiDataSet mds = new MultiDataSet(inputs, labels, fMask, lMask); - INDArray params = graph.params(); //Assumption here: params is a view that we can modify in-place + MultiDataSet mds = new MultiDataSet(c.inputs, c.labels, c.inputMask, c.labelMask); + INDArray params = c.net.params(); //Assumption here: params is a view that we can modify in-place for (long i = 0; i < nParams; i++) { //Get param name if (i >= paramEnds[currParamNameIdx]) { currParamNameIdx++; } String paramName = paramNames.get(currParamNameIdx); - if(excludeParams != null && excludeParams.contains(paramName)){ - log.info("Skipping parameters for parameter name: {}", paramName); + if(c.excludeParams != null && c.excludeParams.contains(paramName)){ + //log.info("Skipping parameters for parameter name: {}", paramName); i = paramEnds[currParamNameIdx++]; continue; } @@ -592,18 +557,18 @@ public class GradientCheckUtil { //(w+epsilon): Do forward pass and score double origValue = params.getDouble(i); - params.putScalar(i, origValue + epsilon); - if(callEachIter != null){ - callEachIter.accept(graph); + params.putScalar(i, origValue + c.epsilon); + if(c.callEachIter != null){ + c.callEachIter.accept(c.net); } - double scorePlus = graph.score(mds, true); //training == true for batch norm, etc (scores and gradients need to be calculated on same thing) + double scorePlus = c.net.score(mds, true); //training == true for batch norm, etc (scores and gradients need to be calculated on same thing) //(w-epsilon): Do forward pass and score - params.putScalar(i, origValue - epsilon); - if(callEachIter != null){ - callEachIter.accept(graph); + params.putScalar(i, origValue - c.epsilon); + if(c.callEachIter != null){ + c.callEachIter.accept(c.net); } - double scoreMinus = graph.score(mds, true); + double scoreMinus = c.net.score(mds, true); //Reset original param value params.putScalar(i, origValue); @@ -611,7 +576,7 @@ public class GradientCheckUtil { //Calculate numerical parameter gradient: double scoreDelta = scorePlus - scoreMinus; - double numericalGradient = scoreDelta / (2 * epsilon); + double numericalGradient = scoreDelta / (2 * c.epsilon); if (Double.isNaN(numericalGradient)) throw new IllegalStateException("Numerical gradient was NaN for parameter " + i + " of " + nParams); @@ -625,32 +590,31 @@ public class GradientCheckUtil { if (relError > maxError) maxError = relError; - if (relError > maxRelError || Double.isNaN(relError)) { + if (relError > c.maxRelError || Double.isNaN(relError)) { double absError = Math.abs(backpropGradient - numericalGradient); - if (absError < minAbsoluteError) { - log.info("Param " + i + " (" + paramName + ") passed: grad= " + backpropGradient - + ", numericalGrad= " + numericalGradient + ", relError= " + relError - + "; absolute error = " + absError + " < minAbsoluteError = " + minAbsoluteError); + if (absError < c.minAbsoluteError) { + if(c.print == PrintMode.ALL || c.print == PrintMode.ZEROS && absError == 0.0) { + log.info("Param " + i + " (" + paramName + ") passed: grad= " + backpropGradient + + ", numericalGrad= " + numericalGradient + ", relError= " + relError + + "; absolute error = " + absError + " < minAbsoluteError = " + c.minAbsoluteError); + } } else { - if (print) - log.info("Param " + i + " (" + paramName + ") FAILED: grad= " + backpropGradient - + ", numericalGrad= " + numericalGradient + ", relError= " + relError - + ", scorePlus=" + scorePlus + ", scoreMinus= " + scoreMinus + ", paramValue = " + origValue); - if (exitOnFirstError) + log.info("Param " + i + " (" + paramName + ") FAILED: grad= " + backpropGradient + + ", numericalGrad= " + numericalGradient + ", relError= " + relError + + ", scorePlus=" + scorePlus + ", scoreMinus= " + scoreMinus + ", paramValue = " + origValue); + if (c.exitOnFirstError) return false; totalNFailures++; } - } else if (print) { + } else if (c.print == PrintMode.ALL) { log.info("Param " + i + " (" + paramName + ") passed: grad= " + backpropGradient + ", numericalGrad= " + numericalGradient + ", relError= " + relError); } } - if (print) { - val nPass = nParams - totalNFailures; - log.info("GradientCheckUtil.checkGradients(): " + nParams + " params checked, " + nPass + " passed, " - + totalNFailures + " failed. Largest relative error = " + maxError); - } + val nPass = nParams - totalNFailures; + log.info("GradientCheckUtil.checkGradients(): " + nParams + " params checked, " + nPass + " passed, " + + totalNFailures + " failed. Largest relative error = " + maxError); return totalNFailures == 0; } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Convolution3D.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Convolution3D.java index cc26169cf..dc88116e5 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Convolution3D.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Convolution3D.java @@ -119,7 +119,7 @@ public class Convolution3D extends ConvolutionLayer { throw new IllegalStateException("Invalid input for Convolution3D layer (layer name=\"" + getLayerName() + "\"): Expected CNN3D input, got " + inputType); } - return InputTypeUtil.getOutputTypeCnn3DLayers(inputType, kernelSize, stride, padding, dilation, convolutionMode, + return InputTypeUtil.getOutputTypeCnn3DLayers(inputType, dataFormat, kernelSize, stride, padding, dilation, convolutionMode, nOut, layerIndex, getLayerName(), Convolution3DLayer.class); } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ConvolutionLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ConvolutionLayer.java index b0c5bb3d4..9e52981e2 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ConvolutionLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ConvolutionLayer.java @@ -34,6 +34,7 @@ import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; +import java.util.Arrays; import java.util.Collection; import java.util.HashMap; import java.util.Map; @@ -118,19 +119,19 @@ public class ConvolutionLayer extends FeedForwardLayer { this.convolutionMode = builder.convolutionMode; this.dilation = builder.dilation; if (builder.kernelSize.length != dim) { - throw new IllegalArgumentException("Kernel argument should be a " + dim + "d array"); + throw new IllegalArgumentException("Kernel argument should be a " + dim + "d array, got " + Arrays.toString(builder.kernelSize)); } this.kernelSize = builder.kernelSize; if (builder.stride.length != dim) { - throw new IllegalArgumentException("Strides argument should be a " + dim + "d array"); + throw new IllegalArgumentException("Strides argument should be a " + dim + "d array, got " + Arrays.toString(builder.stride)); } this.stride = builder.stride; if (builder.padding.length != dim) { - throw new IllegalArgumentException("Padding argument should be a " + dim + "d array"); + throw new IllegalArgumentException("Padding argument should be a " + dim + "d array, got " + Arrays.toString(builder.padding)); } this.padding = builder.padding; if (builder.dilation.length != dim) { - throw new IllegalArgumentException("Dilation argument should be a " + dim + "d array"); + throw new IllegalArgumentException("Dilation argument should be a " + dim + "d array, got " + Arrays.toString(builder.dilation)); } this.dilation = builder.dilation; this.cudnnAlgoMode = builder.cudnnAlgoMode; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Deconvolution3D.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Deconvolution3D.java new file mode 100644 index 000000000..01bd3ca83 --- /dev/null +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Deconvolution3D.java @@ -0,0 +1,219 @@ +/* ****************************************************************************** + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.deeplearning4j.nn.conf.layers; + +import lombok.Data; +import lombok.EqualsAndHashCode; +import lombok.NoArgsConstructor; +import lombok.ToString; +import org.deeplearning4j.nn.api.Layer; +import org.deeplearning4j.nn.api.ParamInitializer; +import org.deeplearning4j.nn.conf.ConvolutionMode; +import org.deeplearning4j.nn.conf.InputPreProcessor; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.inputs.InputType; +import org.deeplearning4j.nn.layers.convolution.Deconvolution2DLayer; +import org.deeplearning4j.nn.layers.convolution.Deconvolution3DLayer; +import org.deeplearning4j.nn.params.Deconvolution3DParamInitializer; +import org.deeplearning4j.nn.params.DeconvolutionParamInitializer; +import org.deeplearning4j.optimize.api.TrainingListener; +import org.deeplearning4j.util.ValidationUtils; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; + +import java.util.Collection; +import java.util.Map; + +/** + * 3D deconvolution layer configuration
+ * + * Deconvolutions are also known as transpose convolutions or fractionally strided convolutions. In essence, + * deconvolutions swap forward and backward pass with regular 3D convolutions. + * + * See the paper by Matt Zeiler for details: http://www.matthewzeiler.com/wp-content/uploads/2017/07/cvpr2010.pdf + * + * For an intuitive guide to convolution arithmetic and shapes, see: + * https://arxiv.org/abs/1603.07285v1 + * + * @author Alex Black + */ +@Data +@NoArgsConstructor +@ToString(callSuper = true) +@EqualsAndHashCode(callSuper = true) +public class Deconvolution3D extends ConvolutionLayer { + + private Convolution3D.DataFormat dataFormat = Convolution3D.DataFormat.NCDHW; // in libnd4j: 1 - NCDHW, 0 - NDHWC + + /** + * Deconvolution3D layer nIn in the input layer is the number of channels nOut is the number of filters to be used + * in the net or in other words the channels The builder specifies the filter/kernel size, the stride and padding + * The pooling layer takes the kernel size + */ + protected Deconvolution3D(Builder builder) { + super(builder); + this.dataFormat = builder.dataFormat; + initializeConstraints(builder); + } + + public boolean hasBias() { + return hasBias; + } + + @Override + public Deconvolution3D clone() { + Deconvolution3D clone = (Deconvolution3D) super.clone(); + if (clone.kernelSize != null) { + clone.kernelSize = clone.kernelSize.clone(); + } + if (clone.stride != null) { + clone.stride = clone.stride.clone(); + } + if (clone.padding != null) { + clone.padding = clone.padding.clone(); + } + return clone; + } + + @Override + public Layer instantiate(NeuralNetConfiguration conf, Collection trainingListeners, + int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { + LayerValidation.assertNInNOutSet("Deconvolution2D", getLayerName(), layerIndex, getNIn(), getNOut()); + + Deconvolution3DLayer ret = + new Deconvolution3DLayer(conf, networkDataType); + ret.setListeners(trainingListeners); + ret.setIndex(layerIndex); + ret.setParamsViewArray(layerParamsView); + Map paramTable = initializer().init(conf, layerParamsView, initializeParams); + ret.setParamTable(paramTable); + ret.setConf(conf); + return ret; + } + + @Override + public ParamInitializer initializer() { + return Deconvolution3DParamInitializer.getInstance(); + } + + @Override + public InputPreProcessor getPreProcessorForInputType(InputType inputType) { + if (inputType == null) { + throw new IllegalStateException("Invalid input for Deconvolution3D layer (layer name=\"" + getLayerName() + "\"): input is null"); + } + + return InputTypeUtil.getPreProcessorForInputTypeCnn3DLayers(inputType, getLayerName()); + } + + @Override + public void setNIn(InputType inputType, boolean override) { + if (inputType == null || inputType.getType() != InputType.Type.CNN3D) { + throw new IllegalStateException("Invalid input for Deconvolution 3D layer (layer name=\"" + getLayerName() + "\"): Expected CNN3D input, got " + inputType); + } + + if (nIn <= 0 || override) { + InputType.InputTypeConvolutional3D c = (InputType.InputTypeConvolutional3D) inputType; + this.nIn = c.getChannels(); + } + } + + @Override + public InputType getOutputType(int layerIndex, InputType inputType) { + if (inputType == null || inputType.getType() != InputType.Type.CNN3D) { + throw new IllegalStateException("Invalid input for Deconvolution layer (layer name=\"" + getLayerName() + + "\"): Expected CNN input, got " + inputType); + } + + return InputTypeUtil.getOutputTypeDeconv3dLayer(inputType, kernelSize, stride, padding, dilation, convolutionMode, + dataFormat, nOut, layerIndex, getLayerName(), Deconvolution3DLayer.class); + } + + public static class Builder extends BaseConvBuilder { + + private Convolution3D.DataFormat dataFormat = Convolution3D.DataFormat.NCDHW; // in libnd4j: 1 - NCDHW, 0 - NDHWC + + public Builder() { + super(new int[] {2, 2, 2}, new int[] {1, 1, 1}, new int[] {0, 0, 0}, new int[] {1, 1, 1}, 3); + } + + @Override + protected boolean allowCausal() { + //Causal convolution - allowed for 1D only + return false; + } + + /** + * Set the convolution mode for the Convolution layer. See {@link ConvolutionMode} for more details + * + * @param convolutionMode Convolution mode for layer + */ + public Builder convolutionMode(ConvolutionMode convolutionMode) { + return super.convolutionMode(convolutionMode); + } + + /** + * Size of the convolution rows/columns + * + * @param kernelSize the height and width of the kernel + */ + public Builder kernelSize(int... kernelSize) { + this.setKernelSize(kernelSize); + return this; + } + + public Builder stride(int... stride) { + this.setStride(stride); + return this; + } + + public Builder padding(int... padding) { + this.setPadding(padding); + return this; + } + + @Override + public void setKernelSize(int... kernelSize) { + this.kernelSize = ValidationUtils.validate3NonNegative(kernelSize, "kernelSize"); + } + + @Override + public void setStride(int... stride) { + this.stride = ValidationUtils.validate3NonNegative(stride, "stride"); + } + + @Override + public void setPadding(int... padding) { + this.padding = ValidationUtils.validate3NonNegative(padding, "padding"); + } + + @Override + public void setDilation(int... dilation) { + this.dilation = ValidationUtils.validate3NonNegative(dilation, "dilation"); + } + + public Builder dataFormat(Convolution3D.DataFormat dataFormat){ + this.dataFormat = dataFormat; + return this; + } + + @Override + public Deconvolution3D build() { + return new Deconvolution3D(this); + } + } + +} diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/EmbeddingLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/EmbeddingLayer.java index 0227fee23..6cd8630d0 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/EmbeddingLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/EmbeddingLayer.java @@ -29,6 +29,7 @@ import org.deeplearning4j.nn.weights.embeddings.ArrayEmbeddingInitializer; import org.deeplearning4j.nn.weights.embeddings.EmbeddingInitializer; import org.deeplearning4j.nn.weights.embeddings.WeightInitEmbedding; import org.deeplearning4j.optimize.api.TrainingListener; +import org.nd4j.linalg.activations.impl.ActivationIdentity; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -104,7 +105,6 @@ public class EmbeddingLayer extends FeedForwardLayer { return hasBias; } - @NoArgsConstructor @Getter @Setter public static class Builder extends FeedForwardLayer.Builder { @@ -115,6 +115,13 @@ public class EmbeddingLayer extends FeedForwardLayer { */ private boolean hasBias = false; + public Builder(){ + //Default to Identity activation - i.e., don't inherit. + //For example, if user sets ReLU as global default, they very likely don't intend to use it for Embedding layer also + this.activationFn = new ActivationIdentity(); + } + + /** * If true: include bias parameters in the layer. False (default): no bias. * diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/EmbeddingSequenceLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/EmbeddingSequenceLayer.java index 3e5766af2..93585a1d0 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/EmbeddingSequenceLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/EmbeddingSequenceLayer.java @@ -30,6 +30,7 @@ import org.deeplearning4j.nn.weights.embeddings.ArrayEmbeddingInitializer; import org.deeplearning4j.nn.weights.embeddings.EmbeddingInitializer; import org.deeplearning4j.nn.weights.embeddings.WeightInitEmbedding; import org.deeplearning4j.optimize.api.TrainingListener; +import org.nd4j.linalg.activations.impl.ActivationIdentity; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -138,11 +139,16 @@ public class EmbeddingSequenceLayer extends FeedForwardLayer { } - @NoArgsConstructor @Getter @Setter public static class Builder extends FeedForwardLayer.Builder { + public Builder(){ + //Default to Identity activation - i.e., don't inherit. + //For example, if user sets ReLU as global default, they very likely don't intend to use it for Embedding layer also + this.activationFn = new ActivationIdentity(); + } + /** * If true: include bias parameters in the layer. False (default): no bias. * diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/FeedForwardLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/FeedForwardLayer.java index 026f0d350..206071e38 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/FeedForwardLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/FeedForwardLayer.java @@ -95,9 +95,8 @@ public abstract class FeedForwardLayer extends BaseLayer { case CNN3D: //CNN3D -> FF InputType.InputTypeConvolutional3D c3d = (InputType.InputTypeConvolutional3D) inputType; - //TODO don't hardcode NCDHW return new Cnn3DToFeedForwardPreProcessor(c3d.getDepth(), c3d.getHeight(), c3d.getWidth(), - c3d.getChannels(), true); + c3d.getChannels(), c3d.getDataFormat() == Convolution3D.DataFormat.NCDHW); default: throw new RuntimeException("Unknown input type: " + inputType); } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/InputTypeUtil.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/InputTypeUtil.java index 7c97930ae..eb78323b6 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/InputTypeUtil.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/InputTypeUtil.java @@ -36,6 +36,8 @@ import java.util.Arrays; @Slf4j public class InputTypeUtil { + private InputTypeUtil(){ } + public static InputType getOutputTypeDeconvLayer(InputType inputType, int[] kernelSize, int[] stride, int[] padding, int[] dilation, ConvolutionMode convolutionMode, long outputDepth, long layerIdx, String layerName, Class layerClass) { @@ -77,9 +79,60 @@ public class InputTypeUtil { return InputType.convolutional(hOut, wOut, outputDepth); } - public static InputType getOutputTypeCnn3DLayers(InputType inputType, int[] kernelSize, int[] stride, int[] padding, - int[] dilation, ConvolutionMode convolutionMode, long outputChannels, long layerIdx, - String layerName, Class layerClass) { + public static InputType getOutputTypeDeconv3dLayer(InputType inputType, int[] kernelSize, int[] stride, int[] padding, + int[] dilation, ConvolutionMode convolutionMode, Convolution3D.DataFormat dataFormat, + long outputDepth, long layerIdx, String layerName, Class layerClass) { + InputType.InputTypeConvolutional3D i = (InputType.InputTypeConvolutional3D) inputType; + + long hIn = i.getHeight(); + long wIn = i.getWidth(); + long dIn = i.getDepth(); + + + int padH = (padding == null ? 0 : padding[0]); //May be null for ConvolutionMode.Same + int padW = (padding == null ? 0 : padding[1]); + int padD = (padding == null ? 0 : padding[2]); + int kH = kernelSize[0]; + int kW = kernelSize[1]; + int kD = kernelSize[2]; + if (dilation[0] != 1) { + kH = kH + (kH - 1) * (dilation[0] - 1); + } + if (dilation[1] != 1) { + kW = kW + (kW - 1) * (dilation[1] - 1); + } + if (dilation[2] != 1) { + kD = kD + (kD - 1) * (dilation[2] - 1); + } + + int sH = stride[0]; + int sW = stride[1]; + int sD = stride[2]; + + if (sH <= 0 || sW <= 0 || sD <= 0) { + throw new DL4JInvalidConfigException(getConfigErrorCommonLine(layerIdx, layerName, layerClass, sH <= 0) + + " Invalid strides: strides must be > 0 (strideH = " + sH + ", strideW = " + sW + ", stride = " + sD + ")" + + "\n" + getConfigErrorCommonLastLine(inputType, kernelSize, stride, padding, outputDepth, + convolutionMode)); + } + + if (convolutionMode == ConvolutionMode.Same) { + long hOut = stride[0] * hIn; + long wOut = stride[1] * wIn; + long dOut = stride[2] * dIn; + return InputType.convolutional3D(dataFormat, dOut, hOut, wOut, outputDepth); + } + + long hOut = sH * (hIn - 1) + kH - 2 * padH; + long wOut = sW * (wIn - 1) + kW - 2 * padW; + long dOut = sD * (dIn - 1) + kD - 2 * padD; + + return InputType.convolutional3D(dataFormat, dOut, hOut, wOut, outputDepth); + } + + public static InputType getOutputTypeCnn3DLayers(InputType inputType, Convolution3D.DataFormat dataFormat, int[] kernelSize, int[] stride, int[] padding, + int[] dilation, ConvolutionMode convolutionMode, long outputChannels, long layerIdx, + String layerName, Class layerClass) { if (convolutionMode == null) { String name = layerName == null ? "(not named)" : layerName; throw new DL4JInvalidConfigException("Invalid configuration: convolution mode is null for layer (idx=" @@ -204,7 +257,7 @@ public class InputTypeUtil { int outH = (int) Math.ceil(inHeight / ((double) sH)); int outW = (int) Math.ceil(inWidth / ((double) sW)); - return InputType.convolutional3D(outD, outH, outW, outputChannels); + return InputType.convolutional3D(dataFormat, outD, outH, outW, outputChannels); } long dOut = (inDepth - kD + 2 * padD) / sD + 1; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/PReLULayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/PReLULayer.java index db9a19ecc..428a4c7a6 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/PReLULayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/PReLULayer.java @@ -25,6 +25,7 @@ import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; import org.deeplearning4j.nn.conf.memory.MemoryReport; import org.deeplearning4j.nn.params.PReLUParamInitializer; +import org.deeplearning4j.nn.weights.WeightInitConstant; import org.deeplearning4j.optimize.api.TrainingListener; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; @@ -115,11 +116,15 @@ public class PReLULayer extends BaseLayer { .cacheMemory(MemoryReport.CACHE_MODE_ALL_ZEROS, MemoryReport.CACHE_MODE_ALL_ZEROS).build(); } - @NoArgsConstructor @Getter @Setter public static class Builder extends FeedForwardLayer.Builder { + public Builder(){ + //Default to 0s, and don't inherit global default + this.weightInitFn = new WeightInitConstant(0); + } + /** * Explicitly set input shape of incoming activations so that parameters can be initialized properly. This * explicitly excludes the mini-batch dimension. diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Subsampling3DLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Subsampling3DLayer.java index 550e29e4f..b5d73bceb 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Subsampling3DLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Subsampling3DLayer.java @@ -142,7 +142,7 @@ public class Subsampling3DLayer extends NoParamLayer { long inChannels = ((InputType.InputTypeConvolutional3D) inputType).getChannels(); if (inChannels > Integer.MAX_VALUE) throw new ND4JArraySizeException(); - return InputTypeUtil.getOutputTypeCnn3DLayers(inputType, kernelSize, stride, padding, new int[] {1, 1, 1}, // no dilation + return InputTypeUtil.getOutputTypeCnn3DLayers(inputType, dataFormat, kernelSize, stride, padding, new int[] {1, 1, 1}, // no dilation convolutionMode, (int) inChannels, layerIndex, getLayerName(), Subsampling3DLayer.class); } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/Cnn3DToFeedForwardPreProcessor.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/Cnn3DToFeedForwardPreProcessor.java index dba78df4a..185fd18d3 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/Cnn3DToFeedForwardPreProcessor.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/Cnn3DToFeedForwardPreProcessor.java @@ -101,7 +101,7 @@ public class Cnn3DToFeedForwardPreProcessor implements InputPreProcessor { throw new IllegalStateException("Invalid input array: expected shape in format " + "[minibatch, channels, channels, height, width] or " + "[minibatch, channels, height, width, channels]" - + "for numChannels: " + numChannels + ", inputDepth " + inputDepth + ", inputHeight " + inputHeight + + " for numChannels: " + numChannels + ", inputDepth " + inputDepth + ", inputHeight " + inputHeight + " and inputWidth " + inputWidth + ", but got " + Arrays.toString(input.shape())); } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayer.java index 8be034735..8ae1a8531 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayer.java @@ -17,6 +17,7 @@ package org.deeplearning4j.nn.layers.convolution; +import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.exception.DL4JInvalidInputException; import org.deeplearning4j.nn.api.MaskState; import org.deeplearning4j.nn.conf.CacheMode; @@ -53,8 +54,8 @@ import java.util.Arrays; * * @author Adam Gibson (original impl), Alex Black (current version) */ +@Slf4j public class ConvolutionLayer extends BaseLayer { - protected static final Logger log = LoggerFactory.getLogger(ConvolutionLayer.class); protected INDArray i2d; protected ConvolutionHelper helper = null; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Deconvolution2DLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Deconvolution2DLayer.java index b0db7a5ed..3cb34b6ab 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Deconvolution2DLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Deconvolution2DLayer.java @@ -70,7 +70,7 @@ public class Deconvolution2DLayer extends ConvolutionLayer { assertInputSet(true); if (input.rank() != 4) { throw new DL4JInvalidInputException("Got rank " + input.rank() - + " array as input to SubsamplingLayer with shape " + Arrays.toString(input.shape()) + + " array as input to Deconvolution2DLayer with shape " + Arrays.toString(input.shape()) + ". Expected rank 4 array with shape [minibatchSize, channels, inputHeight, inputWidth]. " + layerId()); } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Deconvolution3DLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Deconvolution3DLayer.java new file mode 100644 index 000000000..b9d9339ea --- /dev/null +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Deconvolution3DLayer.java @@ -0,0 +1,231 @@ +/* ****************************************************************************** + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.deeplearning4j.nn.layers.convolution; + +import lombok.val; +import org.deeplearning4j.exception.DL4JInvalidInputException; +import org.deeplearning4j.nn.conf.CacheMode; +import org.deeplearning4j.nn.conf.ConvolutionMode; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.layers.Convolution3D; +import org.deeplearning4j.nn.conf.layers.Deconvolution3D; +import org.deeplearning4j.nn.gradient.DefaultGradient; +import org.deeplearning4j.nn.gradient.Gradient; +import org.deeplearning4j.nn.layers.BaseLayer; +import org.deeplearning4j.nn.params.DeconvolutionParamInitializer; +import org.deeplearning4j.nn.workspace.ArrayType; +import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; +import org.deeplearning4j.util.ConvolutionUtils; +import org.nd4j.linalg.activations.IActivation; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.CustomOp; +import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.shape.Shape; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.primitives.Pair; +import org.nd4j.linalg.util.ArrayUtil; + +import java.util.Arrays; + +/** + * 3D deconvolution layer implementation. + * + * Deconvolutions are also known as transpose convolutions or fractionally strided convolutions. + * In essence, deconvolutions swap forward and backward pass with regular 3D convolutions. + * + * See the paper by Matt Zeiler for details: + * http://www.matthewzeiler.com/wp-content/uploads/2017/07/cvpr2010.pdf + * + * For an intuitive guide to convolution arithmetic and shapes, see: + * https://arxiv.org/abs/1603.07285v1 + * + * + * @author Alex Black + */ +public class Deconvolution3DLayer extends BaseLayer { + + public Deconvolution3DLayer(NeuralNetConfiguration conf, DataType dataType) { + super(conf, dataType); + } + + @Override + public Pair backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) { + assertInputSet(true); + if (input.rank() != 5) { + throw new DL4JInvalidInputException("Got rank " + input.rank() + + " array as input to Deconvolution3DLayer with shape " + Arrays.toString(input.shape()) + + ". Expected rank 5 array with shape [minibatchSize, channels, inputHeight, inputWidth, inputDepth] or" + + " [minibatchSize, inputHeight, inputWidth, inputDepth, channels]. " + layerId()); + } + + INDArray weights = getParamWithNoise(DeconvolutionParamInitializer.WEIGHT_KEY, true, workspaceMgr); + + Convolution3D.DataFormat df = layerConf().getDataFormat(); + ConvolutionMode cm = layerConf().getConvolutionMode(); + + int[] dilation = layerConf().getDilation(); + int[] kernel = layerConf().getKernelSize(); + int[] strides = layerConf().getStride(); + int[] pad = layerConf().getPadding(); + + INDArray biasGradView = gradientViews.get(DeconvolutionParamInitializer.BIAS_KEY); + INDArray weightGradView = gradientViews.get(DeconvolutionParamInitializer.WEIGHT_KEY); + + INDArray outEps = workspaceMgr.create(ArrayType.ACTIVATION_GRAD, weights.dataType(), input.shape(), 'c'); + + Integer sameMode = (layerConf().getConvolutionMode() == ConvolutionMode.Same) ? 1 : 0; + + int[] args = new int[] { + kernel[0], kernel[1], kernel[2], strides[0], strides[1], strides[2], + pad[0], pad[1], pad[2], dilation[0], dilation[1], dilation[2], sameMode, + df == Convolution3D.DataFormat.NCDHW ? 0 : 1 + }; + + INDArray delta; + IActivation afn = layerConf().getActivationFn(); + INDArray preOutput = preOutput(true, workspaceMgr); + delta = afn.backprop(preOutput, epsilon).getFirst(); + + INDArray[] opInputs; + INDArray[] opOutputs; + if(layerConf().hasBias()){ + INDArray bias = getParamWithNoise(DeconvolutionParamInitializer.BIAS_KEY, true, workspaceMgr); + opInputs = new INDArray[]{input, weights, bias, delta}; + opOutputs = new INDArray[]{outEps, weightGradView, biasGradView}; + } else { + opInputs = new INDArray[]{input, weights, delta}; + opOutputs = new INDArray[]{outEps, weightGradView}; + } + CustomOp op = DynamicCustomOp.builder("deconv3d_bp") + .addInputs(opInputs) + .addIntegerArguments(args) + .addOutputs(opOutputs) + .callInplace(false) + .build(); + Nd4j.getExecutioner().exec(op); + + + Gradient retGradient = new DefaultGradient(); + if(layerConf().hasBias()){ + retGradient.setGradientFor(DeconvolutionParamInitializer.BIAS_KEY, biasGradView); + } + retGradient.setGradientFor(DeconvolutionParamInitializer.WEIGHT_KEY, weightGradView, 'c'); + weightNoiseParams.clear(); + + return new Pair<>(retGradient, outEps); + } + + protected INDArray preOutput(boolean training , LayerWorkspaceMgr workspaceMgr) { + + INDArray bias = getParamWithNoise(DeconvolutionParamInitializer.BIAS_KEY, training, workspaceMgr); + INDArray weights = getParamWithNoise(DeconvolutionParamInitializer.WEIGHT_KEY, training, workspaceMgr); + + //Input validation: expect rank 5 matrix + if (input.rank() != 5) { + throw new DL4JInvalidInputException("Got rank " + input.rank() + + " array as input to Deconvolution3DLayer with shape " + Arrays.toString(input.shape()) + + ". Expected rank 5 array with shape [minibatchSize, channels, inputHeight, inputWidth, inputDepth] or" + + " [minibatchSize, inputHeight, inputWidth, inputDepth, channels]. " + layerId()); + } + + Convolution3D.DataFormat df = layerConf().getDataFormat(); + boolean ncdhw = layerConf().getDataFormat() == Convolution3D.DataFormat.NCDHW; + int chDim = ncdhw ? 1 : 4; + if (input.size(chDim) != layerConf().getNIn() ) { + String layerName = conf.getLayer().getLayerName(); + if (layerName == null) + layerName = "(not named)"; + throw new DL4JInvalidInputException("Cannot do forward pass in Deconvolution3D layer (layer name = " + layerName + + ", layer index = " + index + "): input array channels does not match CNN layer configuration" + + " (data input channels = " + input.size(chDim) + ", " + (ncdhw ? "[minibatch,channels,height,width,depth]=" : "[minibatch,height,width,depth,channels]=") + + Arrays.toString(input.shape()) + "; expected" + " input channels = " + layerConf().getNIn() + ") " + + layerId()); + } + + int[] dilation = layerConf().getDilation(); + int[] kernel = layerConf().getKernelSize(); + int[] strides = layerConf().getStride(); + + int[] pad; + ConvolutionMode cm = layerConf().getConvolutionMode(); + long[] outSize; + int[] inSize = df == Convolution3D.DataFormat.NCDHW ? new int[]{(int)input.size(2), (int)input.size(3), (int)input.size(4)} : new int[]{(int)input.size(1), (int)input.size(2), (int)input.size(3)}; + if (cm == ConvolutionMode.Same) { + outSize = ConvolutionUtils.getDeconvolution3DOutputSize(input, kernel, strides, null, dilation, cm, layerConf().getDataFormat()); //Also performs validation + pad = ConvolutionUtils.getSameModeTopLeftPadding(ArrayUtil.toInts(outSize), inSize, kernel, strides, dilation ); + } else { + pad = layerConf().getPadding(); + outSize = ConvolutionUtils.getDeconvolution3DOutputSize(input, kernel, strides, pad, dilation, cm, layerConf().getDataFormat()); //Also performs validation + } + + long outH = outSize[0]; + long outW = outSize[1]; + long outD = outSize[2]; + + + val miniBatch = input.size(0); + long[] outShape = df == Convolution3D.DataFormat.NCDHW ? new long[]{miniBatch, layerConf().getNOut(), outH, outW, outD} : new long[]{miniBatch, outH, outW, outD, layerConf().getNOut()}; + INDArray output = workspaceMgr.create(ArrayType.ACTIVATIONS, input.dataType(), outShape, 'c'); + + int sameMode = (cm == ConvolutionMode.Same) ? 1 : 0; + + int[] args = new int[] { + kernel[0], kernel[1], kernel[2], strides[0], strides[1], strides[2], + pad[0], pad[1], pad[2], dilation[0], dilation[1], dilation[2], sameMode, + df == Convolution3D.DataFormat.NCDHW ? 0 : 1 + }; + + INDArray[] opInputs; + if (layerConf().hasBias()) { + opInputs = new INDArray[]{input, weights, bias}; + } else { + opInputs = new INDArray[]{input, weights}; + } + CustomOp op = DynamicCustomOp.builder("deconv3d") + .addInputs(opInputs) + .addIntegerArguments(args) + .addOutputs(output) + .callInplace(false) + .build(); + Nd4j.getExecutioner().exec(op); + + return output; + } + + @Override + public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr) { + assertInputSet(false); + + if (cacheMode == null) + cacheMode = CacheMode.NONE; + + applyDropOutIfNecessary(training, workspaceMgr); + + INDArray z = preOutput(training, workspaceMgr); + + IActivation afn = layerConf().getActivationFn(); + + INDArray activation = afn.getActivation(z, training); + return activation; + } + + @Override + public boolean isPretrainLayer() { + return false; + } +} \ No newline at end of file diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/feedforward/PReLU.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/feedforward/PReLU.java index b35d946aa..4a0fc6aa0 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/feedforward/PReLU.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/feedforward/PReLU.java @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2019-2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -82,7 +83,10 @@ public class PReLU extends BaseLayer deltas = prelu.backprop(layerInput, epsilon); INDArray delta = deltas.getFirst(); - INDArray weightGradView = deltas.getSecond(); + INDArray weightGrad = deltas.getSecond(); + INDArray weightGradView = gradientViews.get(PReLUParamInitializer.WEIGHT_KEY); + weightGradView.assign(weightGrad); + delta = workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, delta); //Usually a no-op (except for perhaps identity) delta = backpropDropOutIfPresent(delta); @@ -98,4 +102,4 @@ public class PReLU extends BaseLayer init(NeuralNetConfiguration conf, INDArray paramsView, boolean initializeParams) { + Deconvolution3D layer = (Deconvolution3D) conf.getLayer(); + if (layer.getKernelSize().length != 3) throw new IllegalArgumentException("Filter size must be == 3"); + + Map params = Collections.synchronizedMap(new LinkedHashMap()); + + Deconvolution3D layerConf = (Deconvolution3D) conf.getLayer(); + val nOut = layerConf.getNOut(); + + if (layer.hasBias()) { + INDArray biasView = paramsView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(0, nOut)); + INDArray weightView = paramsView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(nOut, numParams(conf))); + params.put(BIAS_KEY, createBias(conf, biasView, initializeParams)); + params.put(WEIGHT_KEY, createWeightMatrix(conf, weightView, initializeParams)); + conf.addVariable(WEIGHT_KEY); + conf.addVariable(BIAS_KEY); + } else { + INDArray weightView = paramsView; + params.put(WEIGHT_KEY, createWeightMatrix(conf, weightView, initializeParams)); + conf.addVariable(WEIGHT_KEY); + } + + return params; + } + + @Override + public Map getGradientsFromFlattened(NeuralNetConfiguration conf, INDArray gradientView) { + + Deconvolution3D layerConf = (Deconvolution3D) conf.getLayer(); + + int[] kernel = layerConf.getKernelSize(); + val nIn = layerConf.getNIn(); + val nOut = layerConf.getNOut(); + + Map out = new LinkedHashMap<>(); + if (layerConf.hasBias()) { + INDArray biasGradientView = gradientView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(0, nOut)); + INDArray weightGradientView = + gradientView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(nOut, numParams(conf))) + .reshape('c', kernel[0], kernel[1], kernel[2], nOut, nIn); + out.put(BIAS_KEY, biasGradientView); + out.put(WEIGHT_KEY, weightGradientView); + } else { + INDArray weightGradientView = gradientView.reshape('c', kernel[0], kernel[1], kernel[2], nOut, nIn); + out.put(WEIGHT_KEY, weightGradientView); + } + return out; + } + + + protected INDArray createWeightMatrix(NeuralNetConfiguration conf, INDArray weightView, boolean initializeParams) { + /* + Create a 5d weight matrix of: + (number of kernels, num input channels, kernel depth, kernel height, kernel width) + Note c order is used specifically for the CNN weights, as opposed to f order elsewhere + Inputs to the convolution layer are: + (batch size, num input feature maps, image depth, image height, image width) + */ + Deconvolution3D layerConf = (Deconvolution3D) conf.getLayer(); + + if (initializeParams) { + int[] kernel = layerConf.getKernelSize(); + int[] stride = layerConf.getStride(); + + val inputDepth = layerConf.getNIn(); + val outputDepth = layerConf.getNOut(); + + double fanIn = inputDepth * kernel[0] * kernel[1] * kernel[2]; + double fanOut = outputDepth * kernel[0] * kernel[1] * kernel[2] / + ((double) stride[0] * stride[1] * stride[2]); + + //libnd4j: [kD, kH, kW, oC, iC] + val weightsShape = new long[]{kernel[0], kernel[1], kernel[2], outputDepth, inputDepth}; + + return layerConf.getWeightInitFn().init(fanIn, fanOut, weightsShape, 'c', weightView); + } else { + int[] kernel = layerConf.getKernelSize(); + return WeightInitUtil.reshapeWeights( + new long[]{kernel[0], kernel[1], kernel[2], layerConf.getNOut(), layerConf.getNIn()}, weightView, 'c'); + } + } +} diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/optimize/listeners/ParamAndGradientIterationListener.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/optimize/listeners/ParamAndGradientIterationListener.java deleted file mode 100644 index fe4d01b1a..000000000 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/optimize/listeners/ParamAndGradientIterationListener.java +++ /dev/null @@ -1,235 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.optimize.listeners; - -import lombok.Builder; -import org.deeplearning4j.nn.api.Model; -import org.deeplearning4j.optimize.api.BaseTrainingListener; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.ops.transforms.Transforms; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.io.File; -import java.io.IOException; -import java.nio.file.Files; -import java.nio.file.Path; -import java.nio.file.StandardOpenOption; -import java.util.Map; - -/** - * An iteration listener that provides details on parameters and gradients at each iteration during traning. - * Attempts to provide much of the same information as the UI histogram iteration listener, but in a text-based - * format (for example, when learning on a system accessed via SSH etc). - * i.e., is intended to aid network tuning and debugging
- * This iteration listener is set up to calculate mean, min, max, and mean absolute value - * of each type of parameter and gradient in the network at each iteration.
- * - * @author Alex Black - * @deprecated StatsListener can be used instead, storing data using FileStatsStorage - UI is not required - */ -public class ParamAndGradientIterationListener extends BaseTrainingListener { - private static final int MAX_WRITE_FAILURE_MESSAGES = 10; - private static final Logger logger = LoggerFactory.getLogger(ParamAndGradientIterationListener.class); - - private int iterations; - private long totalIterationCount = 0; - private boolean printMean = true; - private boolean printHeader = true; - private boolean printMinMax = true; - private boolean printMeanAbsValue = true; - private File file; - private Path filePath; - private boolean outputToConsole; - private boolean outputToFile; - private boolean outputToLogger; - private String delimiter = "\t"; - - - private int writeFailureCount = 0; - - - /** Default constructor for output to console only every iteration, tab delimited */ - public ParamAndGradientIterationListener() { - this(1, true, true, true, true, true, false, false, null, "\t"); - } - - /**Full constructor with all options. - * Note also: ParamAndGradientIterationListener.builder() can be used instead of this constructor. - * @param iterations calculate and report values every 'iterations' iterations - * @param printHeader Whether to output a header row (i.e., names for each column) - * @param printMean Calculate and display the mean of parameters and gradients - * @param printMinMax Calculate and display the min/max of the parameters and gradients - * @param printMeanAbsValue Calculate and display the mean absolute value - * @param outputToConsole If true, display the values to the console (System.out.println()) - * @param outputToFile If true, write the values to a file, one per line - * @param outputToLogger If true, log the values - * @param file File to write values to. May be null, not used if outputToFile == false - * @param delimiter delimiter (for example, "\t" or "," etc) - */ - @Builder - public ParamAndGradientIterationListener(int iterations, boolean printHeader, boolean printMean, - boolean printMinMax, boolean printMeanAbsValue, boolean outputToConsole, boolean outputToFile, - boolean outputToLogger, File file, String delimiter) { - this.printHeader = printHeader; - this.printMean = printMean; - this.printMinMax = printMinMax; - this.printMeanAbsValue = printMeanAbsValue; - this.iterations = iterations; - this.file = file; - if (this.file != null) { - this.filePath = file.toPath(); - } - this.outputToConsole = outputToConsole; - this.outputToFile = outputToFile; - this.outputToLogger = outputToLogger; - this.delimiter = delimiter; - } - - @Override - public void iterationDone(Model model, int iteration, int epoch) { - totalIterationCount++; - - if (totalIterationCount == 1 && printHeader) { - Map params = model.paramTable(); - model.conf().getVariables(); - - StringBuilder sb = new StringBuilder(); - - sb.append("n"); - sb.append(delimiter); - sb.append("score"); - - for (String s : params.keySet()) { - //Parameters: - if (printMean) - sb.append(delimiter).append(s).append("_mean"); - //Min, max - if (printMinMax) { - sb.append(delimiter).append(s).append("_min").append(delimiter).append(s).append("_max"); - } - if (printMeanAbsValue) - sb.append(delimiter).append(s).append("_meanAbsValue"); - - //Gradients: - if (printMean) - sb.append(delimiter).append(s).append("_meanG"); - //Min, max - if (printMinMax) { - sb.append(delimiter).append(s).append("_minG").append(delimiter).append(s).append("_maxG"); - } - if (printMeanAbsValue) - sb.append(delimiter).append(s).append("_meanAbsValueG"); - } - sb.append("\n"); - - if (outputToFile) { - try { - Files.write(filePath, sb.toString().getBytes(), StandardOpenOption.CREATE, - StandardOpenOption.TRUNCATE_EXISTING); - } catch (IOException e) { - if (writeFailureCount++ < MAX_WRITE_FAILURE_MESSAGES) { - //Print error message - logger.warn("Error writing to file: {}", e); - } - if (writeFailureCount == MAX_WRITE_FAILURE_MESSAGES) { - logger.warn("Max file write messages displayed. No more failure messages will be printed"); - } - } - } - - if (outputToLogger) - logger.info(sb.toString()); - if (outputToConsole) - System.out.println(sb.toString()); - } - - if (totalIterationCount % iterations != 0) - return; //No op this iteration - - Map params = model.paramTable(); - Map grads = model.gradient().gradientForVariable(); - - StringBuilder sb = new StringBuilder(); - sb.append(totalIterationCount); - sb.append(delimiter); - sb.append(model.score()); - - - //Calculate actual values for parameters and gradients - for (Map.Entry entry : params.entrySet()) { - INDArray currParams = entry.getValue(); - INDArray currGrad = grads.get(entry.getKey()); - - //Parameters: - if (printMean) { - sb.append(delimiter); - sb.append(currParams.meanNumber().doubleValue()); - } - if (printMinMax) { - sb.append(delimiter); - sb.append(currParams.minNumber().doubleValue()); - sb.append(delimiter); - sb.append(currParams.maxNumber().doubleValue()); - } - if (printMeanAbsValue) { - sb.append(delimiter); - INDArray abs = Transforms.abs(currParams.dup()); - sb.append(abs.meanNumber().doubleValue()); - } - - //Gradients: - if (printMean) { - sb.append(delimiter); - sb.append(currGrad.meanNumber().doubleValue()); - } - if (printMinMax) { - sb.append(delimiter); - sb.append(currGrad.minNumber().doubleValue()); - sb.append(delimiter); - sb.append(currGrad.maxNumber().doubleValue()); - } - if (printMeanAbsValue) { - sb.append(delimiter); - INDArray abs = Transforms.abs(currGrad.dup()); - sb.append(abs.meanNumber().doubleValue()); - } - } - sb.append("\n"); - - String out = sb.toString(); - if (outputToLogger) - logger.info(out); - if (outputToConsole) - System.out.print(out); - - if (outputToFile) { - try { - Files.write(filePath, out.getBytes(), StandardOpenOption.CREATE, StandardOpenOption.APPEND); - } catch (IOException e) { - if (writeFailureCount++ < MAX_WRITE_FAILURE_MESSAGES) { - //Print error message - logger.warn("Error writing to file: {}", e); - } - if (writeFailureCount == MAX_WRITE_FAILURE_MESSAGES) { - logger.warn("Max file write messages displayed. No more failure messages will be printed"); - } - } - } - - } -} diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/optimize/solvers/accumulation/EncodedGradientsAccumulator.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/optimize/solvers/accumulation/EncodedGradientsAccumulator.java index 760d2076c..7448581c5 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/optimize/solvers/accumulation/EncodedGradientsAccumulator.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/optimize/solvers/accumulation/EncodedGradientsAccumulator.java @@ -63,6 +63,7 @@ public class EncodedGradientsAccumulator implements GradientsAccumulator, Regist protected int parties; @Getter protected MessageHandler handler; + @Getter protected List> messages = new ArrayList<>(); protected List workspaces = new ArrayList<>(); protected List locks = new ArrayList<>(); @@ -106,7 +107,7 @@ public class EncodedGradientsAccumulator implements GradientsAccumulator, Regist this(parties, new EncodingHandler(thresholdAlgorithm, residualPostProcessor, 1.0, encodingDebugMode), DEFAULT_INITIAL_MEMORY, 10, 1.0, encodingDebugMode); } - protected EncodedGradientsAccumulator(int parties, @NonNull MessageHandler handler, long initialMemory, + public EncodedGradientsAccumulator(int parties, @NonNull MessageHandler handler, long initialMemory, int queueSize, Double boundary, boolean encodingDebugMode) { this.parties = parties; this.handler = handler; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/optimize/solvers/accumulation/IndexedTail.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/optimize/solvers/accumulation/IndexedTail.java index b53986102..c04440cd9 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/optimize/solvers/accumulation/IndexedTail.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/optimize/solvers/accumulation/IndexedTail.java @@ -16,6 +16,7 @@ package org.deeplearning4j.optimize.solvers.accumulation; +import lombok.Getter; import lombok.NonNull; import lombok.extern.slf4j.Slf4j; import lombok.val; @@ -44,9 +45,11 @@ import java.util.concurrent.locks.ReentrantReadWriteLock; @Slf4j public class IndexedTail { // here we store positions of individual consumers + @Getter protected ConcurrentHashMap positions = new ConcurrentHashMap<>(); // here we store individual updates + @Getter protected Map updates = new ConcurrentHashMap<>(); // simple counter for new updates @@ -67,6 +70,7 @@ public class IndexedTail { protected final boolean allowCollapse; protected final long[] shape; protected final int collapseThreshold = 32; + @Getter protected AtomicBoolean collapsedMode = new AtomicBoolean(false); protected AtomicLong collapsedIndex = new AtomicLong(-1); @@ -148,7 +152,7 @@ public class IndexedTail { } } - protected long firstNotAppliedIndexEverywhere() { + public long firstNotAppliedIndexEverywhere() { long maxIdx = -1; // if there's no updates posted yet - just return negative value @@ -163,7 +167,7 @@ public class IndexedTail { return maxIdx + 1; } - protected long maxAppliedIndexEverywhere() { + public long maxAppliedIndexEverywhere() { long maxIdx = Long.MAX_VALUE; for (val v:positions.values()) { if (v.get() < maxIdx) @@ -212,7 +216,7 @@ public class IndexedTail { return getDelta(Thread.currentThread().getId()); } - protected long getDelta(long threadId) { + public long getDelta(long threadId) { return getGlobalPosition() - getLocalPosition(threadId); } @@ -293,7 +297,7 @@ public class IndexedTail { /** * This method does maintenance of updates within */ - protected synchronized void maintenance() { + public synchronized void maintenance() { // first of all we're checking, if all consumers were already registered. if not - just no-op. if (positions.size() < expectedConsumers) { log.trace("Skipping maintanance due to not all expected consumers shown up: [{}] vs [{}]", positions.size(), expectedConsumers); @@ -326,7 +330,7 @@ public class IndexedTail { * This method returns actual number of updates stored within tail * @return */ - protected int updatesSize() { + public int updatesSize() { return updates.size(); } @@ -348,11 +352,11 @@ public class IndexedTail { return result; } - protected boolean isDead() { + public boolean isDead() { return dead.get(); } - protected void notifyDead() { + public void notifyDead() { dead.set(true); } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/ConvolutionUtils.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/ConvolutionUtils.java index 3a447c361..399af4b2d 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/ConvolutionUtils.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/ConvolutionUtils.java @@ -94,6 +94,49 @@ public class ConvolutionUtils { return new int[]{hOut, wOut}; } + /** + * Get the output size of a deconvolution operation for given input data. In deconvolution, we compute the inverse + * of the shape computation of a convolution. + * + * @param inputData Input data + * @param kernel Kernel size (height/width) + * @param strides Strides (height/width) + * @param padding Padding (height/width) + * @param convolutionMode Convolution mode (Same, Strict, Truncate) + * @param dilation Kernel dilation (height/width) + * @return Output size: int[2] with output height/width + */ + public static long[] getDeconvolution3DOutputSize(INDArray inputData, int[] kernel, int[] strides, int[] padding, int[] dilation, + ConvolutionMode convolutionMode, Convolution3D.DataFormat dataFormat) { + + long hIn, wIn, dIn; + if(dataFormat == Convolution3D.DataFormat.NCDHW){ + hIn = inputData.size(2); + wIn = inputData.size(3); + dIn = inputData.size(4); + } else { + hIn = inputData.size(1); + wIn = inputData.size(2); + dIn = inputData.size(3); + } + + + int[] eKernel = effectiveKernelSize(kernel, dilation); + + if (convolutionMode == ConvolutionMode.Same) { + long hOut = strides[0] * hIn; + long wOut = strides[1] * wIn; + long dOut = strides[2] * dIn; + return new long[]{hOut, wOut, dOut}; + } + + long hOut = strides[0] * (hIn - 1) + eKernel[0] - 2 * padding[0]; + long wOut = strides[1] * (wIn - 1) + eKernel[1] - 2 * padding[1]; + long dOut = strides[2] * (dIn - 1) + eKernel[2] - 2 * padding[2]; + + return new long[]{hOut, wOut, dOut}; + } + /** * Get the output size (height/width) for the given input data and CNN configuration @@ -307,11 +350,15 @@ public class ConvolutionUtils { */ public static int[] getSameModeTopLeftPadding(int[] outSize, int[] inSize, int[] kernel, int[] strides, int[] dilation) { int[] eKernel = effectiveKernelSize(kernel, dilation); - int[] outPad = new int[2]; - outPad[0] = ((outSize[0] - 1) * strides[0] + eKernel[0] - inSize[0]) / 2; //Note that padBottom is 1 bigger than this if bracketed term is not divisible by 2 - outPad[1] = ((outSize[1] - 1) * strides[1] + eKernel[1] - inSize[1]) / 2; //As above - Preconditions.checkState(outPad[0] >= 0 && outPad[1] >= 0, "Invalid padding values calculated: %s - layer configuration is invalid? Input size %s, output size %s, kernel %s, strides %s, dilation %s", + int[] outPad = new int[kernel.length]; + boolean allGt0 = true; + for( int i=0; i= 0; + } + Preconditions.checkState(allGt0, "Invalid padding values calculated: %s - layer configuration is invalid? Input size %s, output size %s, kernel %s, strides %s, dilation %s", outPad, inSize, outSize, kernel, strides, dilation); + return outPad; } diff --git a/deeplearning4j/deeplearning4j-nn/src/test/java/org/deeplearning4j/BaseDL4JTest.java b/deeplearning4j/deeplearning4j-nn/src/test/java/org/deeplearning4j/BaseDL4JTest.java deleted file mode 100644 index 05d0957fb..000000000 --- a/deeplearning4j/deeplearning4j-nn/src/test/java/org/deeplearning4j/BaseDL4JTest.java +++ /dev/null @@ -1,140 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j; - -import lombok.extern.slf4j.Slf4j; -import org.bytedeco.javacpp.Pointer; -import org.junit.After; -import org.junit.Before; -import org.junit.Rule; -import org.junit.rules.TestName; -import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.api.memory.MemoryWorkspace; -import org.nd4j.linalg.api.ops.executioner.OpExecutioner; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.profiler.ProfilerConfig; - -import java.lang.management.ManagementFactory; -import java.util.List; -import java.util.Map; -import java.util.Properties; - -@Slf4j -public class BaseDL4JTest { - - @Rule - public TestName name = new TestName(); - - protected long startTime; - protected int threadCountBefore; - - /** - * Override this to set the profiling mode for the tests defined in the child class - */ - public OpExecutioner.ProfilingMode getProfilingMode(){ - return OpExecutioner.ProfilingMode.SCOPE_PANIC; - } - - /** - * Override this to set the datatype of the tests defined in the child class - */ - public DataType getDataType(){ - return DataType.DOUBLE; - } - - public DataType getDefaultFPDataType(){ - return getDataType(); - } - - @Before - public void beforeTest(){ - log.info("{}.{}", getClass().getSimpleName(), name.getMethodName()); - Nd4j.getExecutioner().setProfilingMode(getProfilingMode()); - Nd4j.getExecutioner().setProfilingConfig(ProfilerConfig.builder().build()); - Nd4j.setDefaultDataTypes(getDataType(), getDefaultFPDataType()); - startTime = System.currentTimeMillis(); - threadCountBefore = ManagementFactory.getThreadMXBean().getThreadCount(); - } - - @After - public void afterTest(){ - //Attempt to keep workspaces isolated between tests - Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread(); - MemoryWorkspace currWS = Nd4j.getMemoryManager().getCurrentWorkspace(); - Nd4j.getMemoryManager().setCurrentWorkspace(null); - if(currWS != null){ - //Not really safe to continue testing under this situation... other tests will likely fail with obscure - // errors that are hard to track back to this - log.error("Open workspace leaked from test! Exiting - {}, isOpen = {} - {}", currWS.getId(), currWS.isScopeActive(), currWS); - System.exit(1); - } - - StringBuilder sb = new StringBuilder(); - long maxPhys = Pointer.maxPhysicalBytes(); - long maxBytes = Pointer.maxBytes(); - long currPhys = Pointer.physicalBytes(); - long currBytes = Pointer.totalBytes(); - - long jvmTotal = Runtime.getRuntime().totalMemory(); - long jvmMax = Runtime.getRuntime().maxMemory(); - - int threadsAfter = ManagementFactory.getThreadMXBean().getThreadCount(); - - long duration = System.currentTimeMillis() - startTime; - sb.append(getClass().getSimpleName()).append(".").append(name.getMethodName()) - .append(": ").append(duration).append(" ms") - .append(", threadCount: (").append(threadCountBefore).append("->").append(threadsAfter).append(")") - .append(", jvmTotal=").append(jvmTotal) - .append(", jvmMax=").append(jvmMax) - .append(", totalBytes=").append(currBytes).append(", maxBytes=").append(maxBytes) - .append(", currPhys=").append(currPhys).append(", maxPhys=").append(maxPhys); - - List ws = Nd4j.getWorkspaceManager().getAllWorkspacesForCurrentThread(); - if(ws != null && ws.size() > 0){ - long currSize = 0; - for(MemoryWorkspace w : ws){ - currSize += w.getCurrentSize(); - } - if(currSize > 0){ - sb.append(", threadWSSize=").append(currSize) - .append(" (").append(ws.size()).append(" WSs)"); - } - } - - - Properties p = Nd4j.getExecutioner().getEnvironmentInformation(); - Object o = p.get("cuda.devicesInformation"); - if(o instanceof List){ - List> l = (List>) o; - if(l.size() > 0) { - - sb.append(" [").append(l.size()) - .append(" GPUs: "); - - for (int i = 0; i < l.size(); i++) { - Map m = l.get(i); - if(i > 0) - sb.append(","); - sb.append("(").append(m.get("cuda.freeMemory")).append(" free, ") - .append(m.get("cuda.totalMemory")).append(" total)"); - } - sb.append("]"); - } - } - log.info(sb.toString()); - } -} diff --git a/deeplearning4j/deeplearning4j-nn/src/test/java/org/deeplearning4j/TestUtils.java b/deeplearning4j/deeplearning4j-nn/src/test/java/org/deeplearning4j/TestUtils.java deleted file mode 100644 index c60822ef7..000000000 --- a/deeplearning4j/deeplearning4j-nn/src/test/java/org/deeplearning4j/TestUtils.java +++ /dev/null @@ -1,158 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j; - -import org.apache.commons.compress.utils.IOUtils; -import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; -import org.deeplearning4j.nn.graph.ComputationGraph; -import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; -import org.deeplearning4j.util.ModelSerializer; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution; -import org.nd4j.linalg.factory.Nd4j; - -import java.io.*; -import java.util.Random; - -import static org.junit.Assert.assertEquals; - -public class TestUtils { - - public static MultiLayerNetwork testModelSerialization(MultiLayerNetwork net){ - - MultiLayerNetwork restored; - try { - ByteArrayOutputStream baos = new ByteArrayOutputStream(); - ModelSerializer.writeModel(net, baos, true); - byte[] bytes = baos.toByteArray(); - - ByteArrayInputStream bais = new ByteArrayInputStream(bytes); - restored = ModelSerializer.restoreMultiLayerNetwork(bais, true); - - assertEquals(net.getLayerWiseConfigurations(), restored.getLayerWiseConfigurations()); - assertEquals(net.params(), restored.params()); - } catch (IOException e){ - //Should never happen - throw new RuntimeException(e); - } - - //Also check the MultiLayerConfiguration is serializable (required by Spark etc) - MultiLayerConfiguration conf = net.getLayerWiseConfigurations(); - serializeDeserializeJava(conf); - - return restored; - } - - public static ComputationGraph testModelSerialization(ComputationGraph net){ - - ComputationGraph restored; - try { - ByteArrayOutputStream baos = new ByteArrayOutputStream(); - ModelSerializer.writeModel(net, baos, true); - byte[] bytes = baos.toByteArray(); - - ByteArrayInputStream bais = new ByteArrayInputStream(bytes); - restored = ModelSerializer.restoreComputationGraph(bais, true); - - assertEquals(net.getConfiguration(), restored.getConfiguration()); - assertEquals(net.params(), restored.params()); - } catch (IOException e){ - //Should never happen - throw new RuntimeException(e); - } - - //Also check the ComputationGraphConfiguration is serializable (required by Spark etc) - ComputationGraphConfiguration conf = net.getConfiguration(); - serializeDeserializeJava(conf); - - return restored; - } - - private static T serializeDeserializeJava(T object){ - byte[] bytes; - try(ByteArrayOutputStream baos = new ByteArrayOutputStream(); ObjectOutputStream oos = new ObjectOutputStream(baos)){ - oos.writeObject(object); - oos.close(); - bytes = baos.toByteArray(); - } catch (IOException e){ - //Should never happen - throw new RuntimeException(e); - } - - T out; - try(ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(bytes))){ - out = (T)ois.readObject(); - } catch (IOException | ClassNotFoundException e){ - throw new RuntimeException(e); - } - - assertEquals(object, out); - return out; - } - - public static INDArray randomOneHot(long examples, long nOut){ - return randomOneHot(examples, nOut, new Random(12345)); - } - - public static INDArray randomOneHot(long examples, long nOut, long rngSeed){ - return randomOneHot(examples, nOut, new Random(rngSeed)); - } - - public static INDArray randomOneHot(long examples, long nOut, Random rng){ - INDArray arr = Nd4j.create(examples, nOut); - for( int i=0; i${logback.version} test + + + org.deeplearning4j + deeplearning4j-common-tests + ${project.version} + test + diff --git a/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/BinaryModelServerTest.java b/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/BinaryModelServerTest.java index be2633b87..c57b0fa30 100644 --- a/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/BinaryModelServerTest.java +++ b/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/BinaryModelServerTest.java @@ -2,6 +2,7 @@ package org.deeplearning4j.remote; import lombok.val; import org.datavec.image.loader.Java2DNativeImageLoader; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.remote.helpers.ImageConversionUtils; @@ -30,7 +31,7 @@ import java.util.concurrent.TimeUnit; import static org.deeplearning4j.parallelism.inference.InferenceMode.SEQUENTIAL; import static org.junit.Assert.*; -public class BinaryModelServerTest { +public class BinaryModelServerTest extends BaseDL4JTest { private final int PORT = 18080; @After diff --git a/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/JsonModelServerTest.java b/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/JsonModelServerTest.java index aa353f307..8b060a77c 100644 --- a/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/JsonModelServerTest.java +++ b/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/JsonModelServerTest.java @@ -21,6 +21,7 @@ import lombok.Data; import lombok.NoArgsConstructor; import lombok.extern.slf4j.Slf4j; import lombok.val; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; @@ -34,6 +35,7 @@ import org.deeplearning4j.remote.helpers.House; import org.deeplearning4j.remote.helpers.HouseToPredictedPriceAdapter; import org.deeplearning4j.remote.helpers.PredictedPrice; import org.junit.After; +import org.junit.Before; import org.junit.Test; import org.nd4j.adapters.InferenceAdapter; import org.nd4j.autodiff.samediff.SDVariable; @@ -57,15 +59,15 @@ import java.util.Collections; import java.util.concurrent.ExecutionException; import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; import static org.deeplearning4j.parallelism.inference.InferenceMode.INPLACE; import static org.deeplearning4j.parallelism.inference.InferenceMode.SEQUENTIAL; import static org.junit.Assert.*; @Slf4j -public class JsonModelServerTest { +public class JsonModelServerTest extends BaseDL4JTest { private static final MultiLayerNetwork model; - private final int PORT = 18080; static { val conf = new NeuralNetConfiguration.Builder() @@ -83,10 +85,18 @@ public class JsonModelServerTest { @After public void pause() throws Exception { - // TODO: the same port was used in previous test and not accessible immediately. Might be better solution. + // Need to wait for server shutdown; without sleep, tests will fail if starting immediately after shutdown TimeUnit.SECONDS.sleep(2); } + private AtomicInteger portCount = new AtomicInteger(18080); + private int PORT; + + @Before + public void setPort(){ + PORT = portCount.getAndIncrement(); + } + @Test public void testStartStopParallel() throws Exception { @@ -342,7 +352,7 @@ public class JsonModelServerTest { val server = new JsonModelServer.Builder(model) .outputSerializer(new PredictedPrice.PredictedPriceSerializer()) .inputDeserializer(null) - .port(18080) + .port(PORT) .build(); } @@ -381,7 +391,7 @@ public class JsonModelServerTest { return null; } }) - .endpointAddress("http://localhost:18080/v1/serving") + .endpointAddress("http://localhost:" + PORT + "/v1/serving") .build(); int district = 2; diff --git a/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/ServletTest.java b/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/ServletTest.java index 1b347d112..ede253efa 100644 --- a/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/ServletTest.java +++ b/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/ServletTest.java @@ -20,6 +20,7 @@ import lombok.val; import org.apache.http.client.methods.HttpGet; import org.apache.http.client.methods.HttpPost; import org.apache.http.impl.client.HttpClientBuilder; +import org.deeplearning4j.BaseDL4JTest; import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -34,7 +35,7 @@ import java.io.IOException; import static org.junit.Assert.assertEquals; -public class ServletTest { +public class ServletTest extends BaseDL4JTest { private JsonModelServer server; diff --git a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/pom.xml b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/pom.xml deleted file mode 100644 index 969025e50..000000000 --- a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/pom.xml +++ /dev/null @@ -1,99 +0,0 @@ - - - - 4.0.0 - - org.deeplearning4j - deeplearning4j-scaleout - 1.0.0-SNAPSHOT - - - deeplearning4j-aws_2.11 - DeepLearning4j-AWS - 1.0.0-SNAPSHOT - - - - - maven-compiler-plugin - - 1.8 - 1.8 - - - - - - - 2.11.12 - 2.11 - - - - - com.amazonaws - aws-java-sdk - 1.11.24 - - - args4j - args4j - 2.32 - - - org.slf4j - slf4j-api - - - org.nd4j - nd4j-api - ${nd4j.version} - - - org.deeplearning4j - deeplearning4j-util - ${project.version} - - - - com.jcraft - jsch - ${jsch.version} - - - - org.threadly - threadly - ${threadly.version} - - - - org.apache.commons - commons-lang3 - ${commons-lang3.version} - - - - - - test-nd4j-native - - - test-nd4j-cuda-10.2 - - - diff --git a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/src/main/java/org/deeplearning4j/aws/ec2/Ec2BoxCreator.java b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/src/main/java/org/deeplearning4j/aws/ec2/Ec2BoxCreator.java deleted file mode 100755 index 2d0a58685..000000000 --- a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/src/main/java/org/deeplearning4j/aws/ec2/Ec2BoxCreator.java +++ /dev/null @@ -1,222 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.aws.ec2; - -import com.amazonaws.regions.Regions; -import com.amazonaws.services.ec2.AmazonEC2; -import com.amazonaws.services.ec2.model.*; -import org.deeplearning4j.aws.s3.BaseS3; -import org.deeplearning4j.util.ThreadUtils; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; - -/** - * Creates Ec2Boxes - * @author Adam Gibson - * - */ -public class Ec2BoxCreator extends BaseS3 { - - - private String amiId; - private int numBoxes; - private String size; - private List boxesCreated; - private String securityGroupId; - private String keyPair; - private Regions regions = Regions.DEFAULT_REGION; - private static final Logger log = LoggerFactory.getLogger(Ec2BoxCreator.class); - - //centos - public final static String DEFAULT_AMI = "ami-8997afe0"; - - /** - * - * @param numBoxes number of boxes - * @param size the size of the instances - */ - public Ec2BoxCreator(int numBoxes, String size, String securityGroupId, String keyPair) { - this(DEFAULT_AMI, numBoxes, size, securityGroupId, keyPair); - } - - - /** - * - * @param amiId amazon image id - * @param numBoxes number of boxes - * @param size the size of the instances - * @param securityGroupId - */ - public Ec2BoxCreator(String amiId, int numBoxes, String size, String securityGroupId, String keyPair) { - super(); - this.amiId = amiId; - this.numBoxes = numBoxes; - this.size = size; - this.keyPair = keyPair; - this.securityGroupId = securityGroupId; - } - - public void createSpot() { - // Initializes a Spot Instance Request - RequestSpotInstancesRequest requestRequest = new RequestSpotInstancesRequest(); - - // Request 1 x t1.micro instance with a bid price of $0.03. - requestRequest.setSpotPrice("0.03"); - requestRequest.setInstanceCount(Integer.valueOf(1)); - - // Setup the specifications of the launch. This includes the - // instance type (e.g. t1.micro) and the latest Amazon Linux - // AMI id available. Note, you should always use the latest - // Amazon Linux AMI id or another of your choosing. - LaunchSpecification launchSpecification = new LaunchSpecification(); - launchSpecification.setImageId("ami-8c1fece5"); - launchSpecification.setInstanceType("t1.micro"); - - // Add the security group to the request. - List securityGroups = new ArrayList<>(); - securityGroups.add("GettingStartedGroup"); - launchSpecification.setSecurityGroups(securityGroups); - - // Add the launch specifications to the request. - requestRequest.setLaunchSpecification(launchSpecification); - - // Call the RequestSpotInstance API. - RequestSpotInstancesResult requestResult = getEc2().requestSpotInstances(requestRequest); - - - List requestResponses = requestResult.getSpotInstanceRequests(); - - // Setup an arraylist to collect all of the request ids we want to - // watch hit the running state. - List spotInstanceRequestIds = new ArrayList<>(); - - // Add all of the request ids to the hashset, so we can determine when they hit the - // active state. - for (SpotInstanceRequest requestResponse : requestResponses) { - System.out.println("Created Spot Request: " + requestResponse.getSpotInstanceRequestId()); - spotInstanceRequestIds.add(requestResponse.getSpotInstanceRequestId()); - } - - } - - public void setRegion(Regions regions) { - this.regions = regions; - } - - - /** - * Create the instances - */ - public void create() { - RunInstancesRequest runInstancesRequest = - new RunInstancesRequest().withImageId(amiId).withInstanceType(size).withKeyName(keyPair) - .withMinCount(1).withSecurityGroupIds(securityGroupId).withMaxCount(numBoxes); - AmazonEC2 ec2 = getEc2(); - ec2.setRegion(com.amazonaws.regions.Region.getRegion(regions)); - List boxes = ec2.runInstances(runInstancesRequest).getReservation().getInstances(); - if (boxesCreated == null) { - boxesCreated = new ArrayList<>(); - for (Instance i : boxes) - boxesCreated.add(i.getInstanceId()); - - - - log.info("Boxes created " + boxesCreated); - } else { - blowupBoxes(); - boxesCreated.clear(); - for (Instance i : boxes) - boxesCreated.add(i.getInstanceId()); - - } - } - - - - public List blowupBoxes() { - TerminateInstancesRequest request = new TerminateInstancesRequest().withInstanceIds(boxesCreated); - - if (boxesCreated != null) { - TerminateInstancesResult result = getEc2().terminateInstances(request); - List change = result.getTerminatingInstances(); - log.info("Boxes destroyed " + boxesCreated); - return change; - } - - return Collections.emptyList(); - } - - - public void blockTillAllRunning() { - while (!allRunning()) { - ThreadUtils.uncheckedSleep(1000); - log.info("Not all created..."); - } - } - - public boolean allRunning() { - if (boxesCreated == null) - return false; - else { - DescribeInstancesResult result = getEc2().describeInstances(); - List reservations = result.getReservations(); - for (Reservation r : reservations) { - List instances = r.getInstances(); - for (Instance instance : instances) { - InstanceState state = instance.getState(); - if (state.getCode() == 48) - continue; - if (state.getCode() != 16) - return false; - } - } - - return true; - } - - - } - - public List getHosts() { - DescribeInstancesResult result = getEc2().describeInstances(); - List hosts = new ArrayList<>(); - List reservations = result.getReservations(); - for (Reservation r : reservations) { - List instances = r.getInstances(); - for (Instance instance : instances) { - InstanceState state = instance.getState(); - if (state.getCode() == 48) - continue; - hosts.add(instance.getPublicDnsName()); - - } - } - - return hosts; - } - - public List getBoxesCreated() { - return boxesCreated; - } - - - -} diff --git a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/src/main/java/org/deeplearning4j/aws/ec2/provision/ClusterSetup.java b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/src/main/java/org/deeplearning4j/aws/ec2/provision/ClusterSetup.java deleted file mode 100755 index 94ac12c4b..000000000 --- a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/src/main/java/org/deeplearning4j/aws/ec2/provision/ClusterSetup.java +++ /dev/null @@ -1,122 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.aws.ec2.provision; - -import com.amazonaws.regions.Regions; -import org.deeplearning4j.aws.ec2.Ec2BoxCreator; -import org.kohsuke.args4j.CmdLineException; -import org.kohsuke.args4j.CmdLineParser; -import org.kohsuke.args4j.Option; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import org.threadly.concurrent.PriorityScheduler; - -import java.util.List; - -/** - * Sets up a DL4J cluster - * @author Adam Gibson - * - */ -public class ClusterSetup { - - @Option(name = "-w", usage = "Number of workers") - private int numWorkers = 1; - @Option(name = "-ami", usage = "Amazon machine image: default, amazon linux (only works with RHEL right now") - private String ami = "ami-fb8e9292"; - @Option(name = "-s", usage = "size of instance: default m1.medium") - private String size = "m3.xlarge"; - @Option(name = "-sg", usage = "security group, this needs to be applyTransformToDestination") - private String securityGroupName; - @Option(name = "-kp", usage = "key pair name, also needs to be applyTransformToDestination.") - private String keyPairName; - @Option(name = "-kpath", - usage = "path to private key - needs to be applyTransformToDestination, this is used to login to amazon.") - private String pathToPrivateKey; - @Option(name = "-wscript", usage = "path to worker script to run, this will allow customization of dependencies") - private String workerSetupScriptPath; - @Option(name = "-mscript", usage = "path to master script to run this will allow customization of the dependencies") - private String masterSetupScriptPath; - @Option(name = "-region", usage = "specify a region") - private String region = Regions.US_EAST_1.getName(); - - private PriorityScheduler as; - - private static final Logger log = LoggerFactory.getLogger(ClusterSetup.class); - - - public ClusterSetup(String[] args) { - CmdLineParser parser = new CmdLineParser(this); - try { - parser.parseArgument(args); - } catch (CmdLineException e) { - parser.printUsage(System.err); - log.error("Unable to parse args", e); - } - - - } - - public void exec() { - //master + workers - Ec2BoxCreator boxCreator = new Ec2BoxCreator(ami, numWorkers, size, securityGroupName, keyPairName); - boxCreator.setRegion(Regions.fromName(region)); - boxCreator.create(); - boxCreator.blockTillAllRunning(); - List hosts = boxCreator.getHosts(); - //provisionMaster(hosts.get(0)); - provisionWorkers(hosts); - - - } - - - - private void provisionWorkers(List workers) { - as = new PriorityScheduler(Runtime.getRuntime().availableProcessors()); - for (final String workerHost : workers) { - try { - as.execute(new Runnable() { - @Override - public void run() { - HostProvisioner uploader = new HostProvisioner(workerHost, "ec2-user"); - try { - uploader.addKeyFile(pathToPrivateKey); - //uploader.runRemoteCommand("sudo hostname " + workerHost); - uploader.uploadAndRun(workerSetupScriptPath, ""); - } catch (Exception e) { - e.printStackTrace(); - } - - } - }); - - } catch (Exception e) { - log.error("Error ", e); - } - } - } - - - /** - * @param args - */ - public static void main(String[] args) { - new ClusterSetup(args).exec(); - } - -} diff --git a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/src/main/java/org/deeplearning4j/aws/ec2/provision/HostProvisioner.java b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/src/main/java/org/deeplearning4j/aws/ec2/provision/HostProvisioner.java deleted file mode 100755 index c8bbd386a..000000000 --- a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/src/main/java/org/deeplearning4j/aws/ec2/provision/HostProvisioner.java +++ /dev/null @@ -1,311 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.aws.ec2.provision; - -import com.jcraft.jsch.*; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.io.BufferedInputStream; -import java.io.File; -import java.io.FileInputStream; -import java.io.InputStream; -import java.util.Arrays; -import java.util.Collection; - -/** - * Meant for uploading files to remote servers - * @author Adam Gibson - * - */ -public class HostProvisioner implements UserInfo { - - private String host; - private JSch jsch; - private String user; - private int port = 22; - private String password; - private static final Logger log = LoggerFactory.getLogger(HostProvisioner.class); - - /** - * - * @param host host to connect to (public facing dns) - * @param user the user to connect with (default root otherwise) - * @param password the password to use if any - * @param port the port to connect to(default 22) - */ - public HostProvisioner(String host, String user, String password, int port) { - super(); - this.host = host; - this.user = user; - this.port = port; - this.password = password; - jsch = new JSch(); - - } - - /** - * Connects to port 22 - * @param host host to connect to (public facing dns) - * @param user the user to connect with (default root otherwise) - * @param password the password to use if any - */ - public HostProvisioner(String host, String user, String password) { - this(host, user, password, 22); - } - - /** - * Connects to port 22 - * @param host host to connect to (public facing dns) - * @param user the user to connect with (default root otherwise) - */ - public HostProvisioner(String host, String user) { - this(host, user, "", 22); - } - - /** - * Connects to port 22, user root, with no password - * @param host host to connect to (public facing dns) - */ - public HostProvisioner(String host) { - this(host, "root", "", 22); - } - - - - public void uploadAndRun(String script, String rootDir) throws Exception { - String remoteName = rootDir.isEmpty() ? new File(script).getName() : rootDir + "/" + new File(script).getName(); - upload(new File(script), remoteName); - - String remoteCommand = remoteName.charAt(0) != '/' ? "./" + remoteName : remoteName; - remoteCommand = "chmod +x " + remoteCommand + " && " + remoteCommand; - runRemoteCommand(remoteCommand); - } - - public void runRemoteCommand(String remoteCommand) throws Exception { - Session session = getSession(); - session.connect(); - ChannelExec channel = (ChannelExec) session.openChannel("exec"); - - - channel.setCommand(remoteCommand); - channel.setErrStream(System.err); - channel.setPty(true); - - channel.setOutputStream(System.out); - channel.connect(); - channel.start(); - InputStream input = channel.getInputStream(); - - //start reading the input from the executed commands on the shell - byte[] tmp = new byte[60000]; - while (true) { - while (input.available() > 0) { - int i = input.read(tmp, 0, tmp.length); - if (i < 0) - break; - log.info(new String(tmp, 0, i)); - } - if (channel.isClosed()) { - log.info("exit-status: " + channel.getExitStatus()); - break; - } - - } - - channel.disconnect(); - session.disconnect(); - - - } - - - private Session getSession() throws Exception { - Session session = jsch.getSession(user, host, port); - session.setUserInfo(this); - return session; - } - - /** - * Creates the directory for the file if necessary - * and uploads the file - * @param from the directory to upload from - * @param to the destination directory on the remote server - * @throws Exception - */ - public void uploadForDeployment(String from, String to) throws Exception { - File fromFile = new File(from); - if (!to.isEmpty() && fromFile.isDirectory()) - mkDir(to); - else - upload(from, to); - - - } - - public void addKeyFile(String keyFile) throws Exception { - jsch.addIdentity(keyFile); - } - - //creates the directory to upload to - private void mkDir(String dir) throws Exception { - Session session = getSession(); - session.connect(); - Channel channel = session.openChannel("sftp"); - channel.connect(); - - ChannelSftp c = (ChannelSftp) channel; - if (!fileExists(dir, c)) - c.mkdir(dir); - c.exit(); - session.disconnect(); - } - - private boolean fileExists(String dir, ChannelSftp channel) { - try { - channel.stat(dir); - return true; - } catch (Exception e) { - return false; - } - } - - - //uploads the file or listed files in a directory - private void upload(String fileOrDir, String uploadRootDir) throws Exception { - if (uploadRootDir.isEmpty()) - uploadRootDir = "."; - File origin = new File(fileOrDir); - - if (fileOrDir.endsWith(".tar") || fileOrDir.endsWith(".tar.gz")) { - upload(new File(fileOrDir), uploadRootDir); - untar(uploadRootDir); - } else if (origin.isFile()) { - upload(new File(fileOrDir), uploadRootDir); - } else { - File[] childFiles = origin.listFiles(); - if (childFiles != null) - upload(Arrays.asList(childFiles), uploadRootDir); - - } - } - - private void untar(String targetRemoteFile) throws Exception { - this.runRemoteCommand("tar xvf " + targetRemoteFile); - } - - private void upload(Collection files, String rootDir) throws Exception { - Session session = getSession(); - session.connect(); - Channel channel = session.openChannel("sftp"); - channel.connect(); - - ChannelSftp c = (ChannelSftp) channel; - for (File f : files) { - if (f.isDirectory()) { - log.warn("Skipping " + f.getName()); - continue; - } - - log.info("Uploading " + f.getName()); - BufferedInputStream bis = new BufferedInputStream(new FileInputStream(f)); - c.put(bis, rootDir + "/" + f.getName()); - bis.close(); - - } - - channel.disconnect(); - session.disconnect(); - - - } - - private void upload(File f, String remoteFile) throws Exception { - Session session = getSession(); - int numRetries = 0; - while (numRetries < 3 && !session.isConnected()) { - try { - session.connect(); - } catch (Exception e) { - numRetries++; - } - } - - try { - Channel channel = session.openChannel("sftp"); - - - channel.connect(); - - ChannelSftp c = (ChannelSftp) channel; - - BufferedInputStream bis = new BufferedInputStream(new FileInputStream(f)); - if (this.fileExists(remoteFile, c)) - if (f.isDirectory()) - c.rmdir(remoteFile); - else - c.rm(remoteFile); - c.put(bis, remoteFile); - bis.close(); - c.exit(); - session.disconnect(); - } catch (Exception e) { - log.info("Session was down...trying again", e); - upload(f, remoteFile); - } - } - - - - @Override - public String getPassphrase() { - return this.password; - } - - - @Override - public String getPassword() { - return this.password; - } - - - @Override - public boolean promptPassphrase(String arg0) { - return true; - } - - - @Override - public boolean promptPassword(String arg0) { - return true; - - } - - - @Override - public boolean promptYesNo(String arg0) { - return true; - - } - - - @Override - public void showMessage(String arg0) { - log.info(arg0); - } - -} diff --git a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/src/main/java/org/deeplearning4j/aws/emr/EmrConfig.java b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/src/main/java/org/deeplearning4j/aws/emr/EmrConfig.java deleted file mode 100644 index 949738bbd..000000000 --- a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/src/main/java/org/deeplearning4j/aws/emr/EmrConfig.java +++ /dev/null @@ -1,47 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.aws.emr; - -import com.amazonaws.services.elasticmapreduce.model.Configuration; - -import lombok.*; - -import java.util.ArrayList; -import java.util.List; -import java.util.Map; - - -@Data -@AllArgsConstructor(access = AccessLevel.PRIVATE) -@NoArgsConstructor -@Builder -public class EmrConfig { - - protected String classification; - protected Map properties; - protected List configs; - - Configuration toAwsConfig() { - Configuration config = new Configuration().withClassification(classification).withProperties(properties); - List subConfigs = new ArrayList<>(); - for (EmrConfig conf : configs){ - subConfigs.add(conf.toAwsConfig()); - } - return config.withConfigurations(subConfigs); - } - -} diff --git a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/src/main/java/org/deeplearning4j/aws/emr/SparkEMRClient.java b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/src/main/java/org/deeplearning4j/aws/emr/SparkEMRClient.java deleted file mode 100644 index d179cca09..000000000 --- a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/src/main/java/org/deeplearning4j/aws/emr/SparkEMRClient.java +++ /dev/null @@ -1,531 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.aws.emr; - -import com.amazonaws.services.elasticmapreduce.AmazonElasticMapReduce; -import com.amazonaws.services.elasticmapreduce.AmazonElasticMapReduceClientBuilder; -import com.amazonaws.services.elasticmapreduce.model.*; -import com.amazonaws.services.s3.AmazonS3ClientBuilder; -import com.amazonaws.services.s3.AmazonS3URI; -import com.amazonaws.services.s3.model.PutObjectRequest; -import lombok.AccessLevel; -import lombok.AllArgsConstructor; -import lombok.Data; -import lombok.NoArgsConstructor; -import lombok.extern.slf4j.Slf4j; -import org.apache.commons.lang3.RandomStringUtils; -import org.nd4j.linalg.function.Function; - -import java.io.File; -import java.util.*; - -/** - * Configuration for a Spark EMR cluster - */ -@Data -@AllArgsConstructor(access = AccessLevel.PRIVATE) -@NoArgsConstructor -@Slf4j -public class SparkEMRClient { - - protected String sparkClusterName = RandomStringUtils.randomAlphanumeric(12); - protected String sparkAwsRegion = "us-east-1"; - protected String sparkEmrRelease = "emr-5.9.0"; - protected String sparkEmrServiceRole = "EMR_DefaultRole"; - protected List sparkEmrConfigs = Collections.emptyList(); - protected String sparkSubnetId = null; - protected List sparkSecurityGroupIds = Collections.emptyList(); - protected int sparkInstanceCount = 1; - protected String sparkInstanceType = "m3.xlarge"; - protected Optional sparkInstanceBidPrice = Optional.empty(); - protected String sparkInstanceRole = "EMR_EC2_DefaultRole"; - protected String sparkS3JarFolder = "changeme"; - protected int sparkTimeoutDurationMinutes = 90; - - //underlying configs - protected AmazonElasticMapReduceClientBuilder sparkEmrClientBuilder; - protected AmazonS3ClientBuilder sparkS3ClientBuilder; - protected JobFlowInstancesConfig sparkJobFlowInstancesConfig; - protected RunJobFlowRequest sparkRunJobFlowRequest; - protected Function sparkS3PutObjectDecorator; - protected Map sparkSubmitConfs; - - - private static ClusterState[] activeClusterStates = new ClusterState[]{ - ClusterState.RUNNING, - ClusterState.STARTING, - ClusterState.WAITING, - ClusterState.BOOTSTRAPPING}; - - private Optional findClusterWithName(AmazonElasticMapReduce emr, String name) { - List csrl = emr.listClusters((new ListClustersRequest()).withClusterStates(activeClusterStates)).getClusters(); - for (ClusterSummary csr : csrl) { - if (csr.getName().equals(name)) return Optional.of(csr); - } - return Optional.empty(); - } - - /** - * Creates the current cluster - */ - public void createCluster() { - AmazonElasticMapReduce emr = sparkEmrClientBuilder.build(); - Optional csr = findClusterWithName(emr, sparkClusterName); - if (csr.isPresent()) { - String msg = String.format("A cluster with name %s and id %s is already deployed", sparkClusterName, csr.get().getId()); - log.error(msg); - throw new IllegalStateException(msg); - } else { - RunJobFlowResult res = emr.runJobFlow(sparkRunJobFlowRequest); - String msg = String.format("Your cluster is launched with name %s and id %s.", sparkClusterName, res.getJobFlowId()); - log.info(msg); - } - - } - - private void logClusters(List csrl) { - if (csrl.isEmpty()) log.info("No cluster found."); - else { - log.info(String.format("%d clusters found.", csrl.size())); - for (ClusterSummary csr : csrl) { - log.info(String.format("Name: %s | Id: %s", csr.getName(), csr.getId())); - } - } - } - - /** - * Lists existing active clusters Names - * - * @return cluster names - */ - public List listActiveClusterNames() { - AmazonElasticMapReduce emr = sparkEmrClientBuilder.build(); - List csrl = - emr.listClusters(new ListClustersRequest().withClusterStates(activeClusterStates)).getClusters(); - logClusters(csrl); - List res = new ArrayList<>(csrl.size()); - for (ClusterSummary csr : csrl) res.add(csr.getName()); - return res; - } - - /** - * List existing active cluster IDs - * - * @return cluster IDs - */ - public List listActiveClusterIds() { - AmazonElasticMapReduce emr = sparkEmrClientBuilder.build(); - List csrl = - emr.listClusters(new ListClustersRequest().withClusterStates(activeClusterStates)).getClusters(); - logClusters(csrl); - List res = new ArrayList<>(csrl.size()); - for (ClusterSummary csr : csrl) res.add(csr.getId()); - return res; - } - - - /** - * Terminates a cluster - */ - public void terminateCluster() { - AmazonElasticMapReduce emr = sparkEmrClientBuilder.build(); - Optional optClusterSum = findClusterWithName(emr, sparkClusterName); - if (!optClusterSum.isPresent()) { - log.error(String.format("The cluster with name %s , requested for deletion, does not exist.", sparkClusterName)); - } else { - String id = optClusterSum.get().getId(); - emr.terminateJobFlows((new TerminateJobFlowsRequest()).withJobFlowIds(id)); - log.info(String.format("The cluster with id %s is terminating.", id)); - } - } - - // The actual job-sumission logic - private void submitJob(AmazonElasticMapReduce emr, String mainClass, List args, Map sparkConfs, File uberJar) throws Exception { - AmazonS3URI s3Jar = new AmazonS3URI(sparkS3JarFolder + "/" + uberJar.getName()); - log.info(String.format("Placing uberJar %s to %s", uberJar.getPath(), s3Jar.toString())); - PutObjectRequest putRequest = sparkS3PutObjectDecorator.apply( - new PutObjectRequest(s3Jar.getBucket(), s3Jar.getKey(), uberJar) - ); - sparkS3ClientBuilder.build().putObject(putRequest); - // The order of these matters - List sparkSubmitArgs = Arrays.asList( - "spark-submit", - "--deploy-mode", - "cluster", - "--class", - mainClass - ); - for (Map.Entry e : sparkConfs.entrySet()) { - sparkSubmitArgs.add(String.format("--conf %s = %s ", e.getKey(), e.getValue())); - } - sparkSubmitArgs.add(s3Jar.toString()); - sparkSubmitArgs.addAll(args); - StepConfig step = new StepConfig() - .withActionOnFailure(ActionOnFailure.CONTINUE) - .withName("Spark step") - .withHadoopJarStep( - new HadoopJarStepConfig() - .withJar("command-runner.jar") - .withArgs(sparkSubmitArgs) - ); - - Optional optCsr = findClusterWithName(emr, sparkClusterName); - if (optCsr.isPresent()) { - ClusterSummary csr = optCsr.get(); - emr.addJobFlowSteps( - new AddJobFlowStepsRequest() - .withJobFlowId(csr.getId()) - .withSteps(step)); - log.info( - String.format("Your job is added to the cluster with id %s.", csr.getId()) - ); - } else { - // If the cluster wasn't started, it's assumed ot be throwaway - List steps = sparkRunJobFlowRequest.getSteps(); - steps.add(step); - RunJobFlowRequest jobFlowRequest = sparkRunJobFlowRequest - .withSteps(steps) - .withInstances(sparkJobFlowInstancesConfig.withKeepJobFlowAliveWhenNoSteps(false)); - - RunJobFlowResult res = emr.runJobFlow(jobFlowRequest); - log.info("Your new cluster's id is %s.", res.getJobFlowId()); - } - - } - - /** - * Submit a Spark Job with a specified main class - */ - public void sparkSubmitJobWithMain(String[] args, String mainClass, File uberJar) throws Exception { - AmazonElasticMapReduce emr = sparkEmrClientBuilder.build(); - submitJob(emr, mainClass, Arrays.asList(args), sparkSubmitConfs, uberJar); - } - - private void checkStatus(AmazonElasticMapReduce emr, String clusterId) throws InterruptedException { - log.info("."); - com.amazonaws.services.elasticmapreduce.model.Cluster dcr = - emr.describeCluster((new DescribeClusterRequest()).withClusterId(clusterId)).getCluster(); - String state = dcr.getStatus().getState(); - long timeOutTime = System.currentTimeMillis() + ((long) sparkTimeoutDurationMinutes * 60 * 1000); - - Boolean activated = Arrays.asList(activeClusterStates).contains(ClusterState.fromValue(state)); - Boolean timedOut = System.currentTimeMillis() > timeOutTime; - if (activated && timedOut) { - emr.terminateJobFlows( - new TerminateJobFlowsRequest().withJobFlowIds(clusterId) - ); - log.error("Timeout. Cluster terminated."); - } else if (!activated) { - Boolean hasAbnormalStep = false; - StepSummary stepS = null; - List steps = emr.listSteps(new ListStepsRequest().withClusterId(clusterId)).getSteps(); - for (StepSummary step : steps) { - if (step.getStatus().getState() != StepState.COMPLETED.toString()) { - hasAbnormalStep = true; - stepS = step; - } - } - if (hasAbnormalStep && stepS != null) - log.error(String.format("Cluster %s terminated with an abnormal step, name %s, id %s", clusterId, stepS.getName(), stepS.getId())); - else - log.info("Cluster %s terminated without error.", clusterId); - } else { - Thread.sleep(5000); - checkStatus(emr, clusterId); - } - } - - /** - * Monitor the cluster and terminates when it times out - */ - public void sparkMonitor() throws InterruptedException { - AmazonElasticMapReduce emr = sparkEmrClientBuilder.build(); - Optional optCsr = findClusterWithName(emr, sparkClusterName); - if (!optCsr.isPresent()) { - log.error(String.format("The cluster with name %s does not exist.", sparkClusterName)); - } else { - ClusterSummary csr = optCsr.get(); - log.info(String.format("found cluster with id %s, starting monitoring", csr.getId())); - checkStatus(emr, csr.getId()); - } - } - - @Data - public static class Builder { - - protected String sparkClusterName = RandomStringUtils.randomAlphanumeric(12); - protected String sparkAwsRegion = "us-east-1"; - protected String sparkEmrRelease = "emr-5.9.0"; - protected String sparkEmrServiceRole = "EMR_DefaultRole"; - protected List sparkEmrConfigs = Collections.emptyList(); - protected String sparkSubNetid = null; - protected List sparkSecurityGroupIds = Collections.emptyList(); - protected int sparkInstanceCount = 1; - protected String sparkInstanceType = "m3.xlarge"; - protected Optional sparkInstanceBidPrice = Optional.empty(); - protected String sparkInstanceRole = "EMR_EC2_DefaultRole"; - protected String sparkS3JarFolder = "changeme"; - protected int sparkTimeoutDurationMinutes = 90; - - protected AmazonElasticMapReduceClientBuilder sparkEmrClientBuilder; - protected AmazonS3ClientBuilder sparkS3ClientBuilder; - protected JobFlowInstancesConfig sparkJobFlowInstancesConfig; - protected RunJobFlowRequest sparkRunJobFlowRequest; - - // This should allow the user to decorate the put call to add metadata to the jar put command, such as security groups, - protected Function sparkS3PutObjectDecorator = new Function() { - @Override - public PutObjectRequest apply(PutObjectRequest putObjectRequest) { - return putObjectRequest; - } - }; - protected Map sparkSubmitConfs; - - - /** - * Defines the EMR cluster's name - * - * @param clusterName the EMR cluster's name - * @return an EMR cluster builder - */ - public Builder clusterName(String clusterName) { - this.sparkClusterName = clusterName; - return this; - } - - /** - * Defines the EMR cluster's region - * See https://docs.aws.amazon.com/general/latest/gr/rande.html - * - * @param region the EMR cluster's region - * @return an EMR cluster builder - */ - public Builder awsRegion(String region) { - this.sparkAwsRegion = region; - return this; - } - - /** - * Defines the EMR release version to be used in this cluster - * uses a release label - * See https://docs.aws.amazon.com/emr/latest/ReleaseGuide/emr-4.2.0/emr-release-differences.html#emr-release-label - * - * @param releaseLabel the EMR release label - * @return an EM cluster Builder - */ - public Builder emrRelease(String releaseLabel) { - this.sparkEmrRelease = releaseLabel; - return this; - } - - /** - * Defines the IAM role to be assumed by the EMR service - *

- * https://docs.aws.amazon.com/IAM/latest/UserGuide/id_roles_create_for-service.html - * - * @param serviceRole the service role - * @return an EM cluster Builder - */ - public Builder emrServiceRole(String serviceRole) { - this.sparkEmrServiceRole = serviceRole; - return this; - } - - /** - * A list of configuration parameters to apply to EMR instances. - * - * @param configs the EMR configurations to apply to this cluster - * @return an EMR cluster builder - */ - public Builder emrConfigs(List configs) { - this.sparkEmrConfigs = configs; - return this; - } - - /** - * The id of the EC2 subnet to be used for this Spark EMR service - * see https://docs.aws.amazon.com/AmazonVPC/latest/UserGuide/VPC_Subnets.html - * - * @param id the subnet ID - * @return an EMR cluster builder - */ - public Builder subnetId(String id) { - this.sparkSubNetid = id; - return this; - } - - /** - * The id of additional security groups this deployment should adopt for both master and slaves - * - * @param securityGroups - * @return an EMR cluster builder - */ - public Builder securityGroupIDs(List securityGroups) { - this.sparkSecurityGroupIds = securityGroups; - return this; - } - - /** - * The number of instances this deployment should comprise of - * - * @param count the number of instances for this cluster - * @rturn an EMR cluster buidler - */ - public Builder instanceCount(int count) { - this.sparkInstanceCount = count; - return this; - } - - /** - * The type of instance this cluster should comprise of - * See https://aws.amazon.com/ec2/instance-types/ - * - * @param instanceType the type of instance for this cluster - * @return an EMR cluster builder - */ - public Builder instanceType(String instanceType) { - this.sparkInstanceType = instanceType; - return this; - } - - /** - * The optional bid value for this cluster's spot instances - * see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/how-spot-instances-work.html - * Uses the on-demand market if empty. - * - * @param optBid the Optional bid price for this cluster's instnces - * @return an EMR cluster Builder - */ - public Builder instanceBidPrice(Optional optBid) { - this.sparkInstanceBidPrice = optBid; - return this; - } - - /** - * The EC2 instance role that this cluster's instances should assume - * see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/iam-roles-for-amazon-ec2.html - * - * @param role the intended instance role - * @return an EMR cluster builder - */ - public Builder instanceRole(String role) { - this.sparkInstanceRole = role; - return this; - } - - /** - * the S3 folder in which to find the application jar - * - * @param jarfolder the S3 folder in which to find a jar - * @return an EMR cluster builder - */ - public Builder s3JarFolder(String jarfolder) { - this.sparkS3JarFolder = jarfolder; - return this; - } - - /** - * The timeout duration for this Spark EMR cluster, in minutes - * - * @param timeoutMinutes - * @return an EMR cluster builder - */ - public Builder sparkTimeOutDurationMinutes(int timeoutMinutes) { - this.sparkTimeoutDurationMinutes = timeoutMinutes; - return this; - } - - /** - * Creates an EMR Spark cluster deployment - * - * @return a SparkEMRClient - */ - public SparkEMRClient build() { - this.sparkEmrClientBuilder = AmazonElasticMapReduceClientBuilder.standard().withRegion(sparkAwsRegion); - this.sparkS3ClientBuilder = AmazonS3ClientBuilder.standard().withRegion(sparkAwsRegion); - // note this will be kept alive without steps, an arbitrary choice to avoid rapid test-teardown-restart cycles - this.sparkJobFlowInstancesConfig = (new JobFlowInstancesConfig()).withKeepJobFlowAliveWhenNoSteps(true); - if (this.sparkSubNetid != null) - this.sparkJobFlowInstancesConfig = this.sparkJobFlowInstancesConfig.withEc2SubnetId(this.sparkSubNetid); - if (!this.sparkSecurityGroupIds.isEmpty()) { - this.sparkJobFlowInstancesConfig = this.sparkJobFlowInstancesConfig.withAdditionalMasterSecurityGroups(this.sparkSecurityGroupIds); - this.sparkJobFlowInstancesConfig = this.sparkJobFlowInstancesConfig.withAdditionalSlaveSecurityGroups(this.sparkSecurityGroupIds); - } - - InstanceGroupConfig masterConfig = - (new InstanceGroupConfig()).withInstanceCount(1).withInstanceRole(InstanceRoleType.MASTER).withInstanceType(sparkInstanceType); - if (sparkInstanceBidPrice.isPresent()) { - masterConfig = masterConfig.withMarket(MarketType.SPOT).withBidPrice(sparkInstanceBidPrice.get().toString()); - } else { - masterConfig = masterConfig.withMarket(MarketType.ON_DEMAND); - } - - int slaveCount = sparkInstanceCount - 1; - InstanceGroupConfig slaveConfig = - (new InstanceGroupConfig()).withInstanceCount(slaveCount).withInstanceRole(InstanceRoleType.CORE).withInstanceRole(sparkInstanceType); - if (sparkInstanceBidPrice.isPresent()) { - slaveConfig = slaveConfig.withMarket(MarketType.SPOT).withBidPrice(sparkInstanceBidPrice.get().toString()); - } else { - slaveConfig = slaveConfig.withMarket(MarketType.ON_DEMAND); - } - if (slaveCount > 0) { - this.sparkJobFlowInstancesConfig = this.sparkJobFlowInstancesConfig.withInstanceGroups(Arrays.asList(masterConfig, slaveConfig)); - } else { - this.sparkJobFlowInstancesConfig = this.sparkJobFlowInstancesConfig.withInstanceGroups(slaveConfig); - } - - this.sparkRunJobFlowRequest = new RunJobFlowRequest(); - if (!sparkEmrConfigs.isEmpty()) { - List emrConfigs = new ArrayList<>(); - for (EmrConfig config : sparkEmrConfigs) { - emrConfigs.add(config.toAwsConfig()); - } - this.sparkRunJobFlowRequest = this.sparkRunJobFlowRequest.withConfigurations(emrConfigs); - } - this.sparkRunJobFlowRequest = - this.sparkRunJobFlowRequest.withName(sparkClusterName).withApplications((new Application()).withName("Spark")) - .withReleaseLabel(sparkEmrRelease) - .withServiceRole(sparkEmrServiceRole) - .withJobFlowRole(sparkInstanceRole) - .withInstances(this.sparkJobFlowInstancesConfig); - - return new SparkEMRClient( - this.sparkClusterName, - this.sparkAwsRegion, - this.sparkEmrRelease, - this.sparkEmrServiceRole, - this.sparkEmrConfigs, - this.sparkSubNetid, - this.sparkSecurityGroupIds, - this.sparkInstanceCount, - this.sparkInstanceType, - this.sparkInstanceBidPrice, - this.sparkInstanceRole, - this.sparkS3JarFolder, - this.sparkTimeoutDurationMinutes, - this.sparkEmrClientBuilder, - this.sparkS3ClientBuilder, - this.sparkJobFlowInstancesConfig, - this.sparkRunJobFlowRequest, - this.sparkS3PutObjectDecorator, - this.sparkSubmitConfs - ); - } - - } - -} diff --git a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/src/main/java/org/deeplearning4j/aws/s3/BaseS3.java b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/src/main/java/org/deeplearning4j/aws/s3/BaseS3.java deleted file mode 100755 index 13175505b..000000000 --- a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/src/main/java/org/deeplearning4j/aws/s3/BaseS3.java +++ /dev/null @@ -1,117 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.aws.s3; - -import com.amazonaws.auth.AWSCredentials; -import com.amazonaws.auth.BasicAWSCredentials; -import com.amazonaws.auth.PropertiesCredentials; -import com.amazonaws.services.ec2.AmazonEC2; -import com.amazonaws.services.ec2.AmazonEC2Client; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.AmazonS3Client; - -import java.io.File; -import java.io.InputStream; - - -/** - * The S3 Credentials works via discovering the credentials - * from the system properties (passed in via -D or System wide) - * If you invoke the JVM with -Dorg.deeplearning4j.aws.accessKey=YOUR_ACCESS_KEY - * and -Dorg.deeplearning4j.aws.accessSecret=YOUR_SECRET_KEY - * this will pick up the credentials from there, otherwise it will also attempt to look in - * the system environment for the following variables: - * - * - * AWS_ACCESS_KEY_ID - * AWS_SECRET_ACCESS_KEY - * @author Adam Gibson - * - */ -public abstract class BaseS3 { - - - /** - * - */ - protected static final long serialVersionUID = -2280107690193651289L; - protected String accessKey; - protected String secretKey; - protected AWSCredentials creds; - public final static String ACCESS_KEY = "org.deeplearning4j.aws.accessKey"; - public final static String ACCESS_SECRET = "org.deeplearning4j.aws.accessSecret"; - public final static String AWS_ACCESS_KEY = "AWS_ACCESS_KEY"; //"AWS_ACCESS_KEY_ID"; - public final static String AWS_SECRET_KEY = "AWS_SECRET_KEY"; //"AWS_SECRET_ACCESS_KEY"; - - - protected void findCreds() { - if (System.getProperty(ACCESS_KEY) != null && System.getProperty(ACCESS_SECRET) != null) { - accessKey = System.getProperty(ACCESS_KEY); - secretKey = System.getProperty(ACCESS_SECRET); - } - - else if (System.getenv(AWS_ACCESS_KEY) != null && System.getenv(AWS_SECRET_KEY) != null) { - accessKey = System.getenv(AWS_ACCESS_KEY); - secretKey = System.getenv(AWS_SECRET_KEY); - } - } - - public BaseS3() { - findCreds(); - if (accessKey != null && secretKey != null) - creds = new BasicAWSCredentials(accessKey, secretKey); - if (creds == null) - throw new IllegalStateException("Unable to find ec2 credentials"); - } - - public BaseS3(File file) throws Exception { - if (accessKey != null && secretKey != null) - creds = new BasicAWSCredentials(accessKey, secretKey); - else - creds = new PropertiesCredentials(file); - - - } - - public BaseS3(InputStream is) throws Exception { - if (accessKey != null && secretKey != null) - creds = new BasicAWSCredentials(accessKey, secretKey); - else - creds = new PropertiesCredentials(is); - - - } - - public AWSCredentials getCreds() { - return creds; - } - - public void setCreds(AWSCredentials creds) { - this.creds = creds; - } - - public AmazonS3 getClient() { - return new AmazonS3Client(creds); - } - - public AmazonEC2 getEc2() { - - return new AmazonEC2Client(creds); - } - - -} diff --git a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/src/main/java/org/deeplearning4j/aws/s3/reader/BaseS3DataSetIterator.java b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/src/main/java/org/deeplearning4j/aws/s3/reader/BaseS3DataSetIterator.java deleted file mode 100755 index 5532f79d2..000000000 --- a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/src/main/java/org/deeplearning4j/aws/s3/reader/BaseS3DataSetIterator.java +++ /dev/null @@ -1,83 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.aws.s3.reader; - -import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; - -import java.io.InputStream; -import java.util.Iterator; -import java.util.List; - -/** - * baseline data applyTransformToDestination iterator for - * @author Adam Gibson - * - */ -public abstract class BaseS3DataSetIterator implements DataSetIterator { - - /** - * - */ - private static final long serialVersionUID = 885205002006822431L; - private S3Downloader downloader; - private List buckets; - private int currBucket; - private Iterator currIterator; - - public BaseS3DataSetIterator() { - downloader = new S3Downloader(); - buckets = downloader.buckets(); - currBucket = 0; - currIterator = downloader.iterateBucket(buckets.get(currBucket)); - } - - - - public InputStream nextObject() { - if (currIterator.hasNext()) - return currIterator.next(); - else if (currBucket < buckets.size()) { - currBucket++; - currIterator = downloader.iterateBucket(buckets.get(currBucket)); - return currIterator.next(); - } - - return null; - } - - - - @Override - public boolean hasNext() { - return currBucket < buckets.size() && currIterator.hasNext(); - } - - - - public String currBucket() { - return buckets.get(currBucket); - } - - - - public void nextBucket() { - currBucket++; - } - - - -} diff --git a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/src/main/java/org/deeplearning4j/aws/s3/reader/BucketIterator.java b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/src/main/java/org/deeplearning4j/aws/s3/reader/BucketIterator.java deleted file mode 100755 index 228fc5cfc..000000000 --- a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/src/main/java/org/deeplearning4j/aws/s3/reader/BucketIterator.java +++ /dev/null @@ -1,93 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.aws.s3.reader; - -import com.amazonaws.services.s3.model.ObjectListing; -import com.amazonaws.services.s3.model.S3ObjectSummary; - -import java.io.InputStream; -import java.util.Iterator; -import java.util.List; - -/** - * Iterator over individual S3 bucket - * @author Adam Gibson - * - */ -public class BucketIterator implements Iterator { - - private S3Downloader s3; - private String bucket; - private ObjectListing currList; - private List currObjects; - private int currObject; - - - - public BucketIterator(String bucket) { - this(bucket, null); - - } - - - public BucketIterator(String bucket, S3Downloader s3) { - this.bucket = bucket; - - if (s3 == null) - this.s3 = new S3Downloader(); - else - this.s3 = s3; - currList = this.s3.listObjects(bucket); - currObjects = currList.getObjectSummaries(); - - } - - - - @Override - public boolean hasNext() { - return currObject < currObjects.size(); - } - - @Override - public InputStream next() { - if (currObject < currObjects.size()) { - InputStream ret = s3.objectForKey(bucket, currObjects.get(currObject).getKey()); - currObject++; - return ret; - } else if (currList.isTruncated()) { - currList = s3.nextList(currList); - currObjects = currList.getObjectSummaries(); - currObject = 0; - - InputStream ret = s3.objectForKey(bucket, currObjects.get(currObject).getKey()); - - currObject++; - return ret; - } - - - throw new IllegalStateException("Indeterminate state"); - } - - @Override - public void remove() { - throw new UnsupportedOperationException(); - } - - -} diff --git a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/src/main/java/org/deeplearning4j/aws/s3/reader/S3Downloader.java b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/src/main/java/org/deeplearning4j/aws/s3/reader/S3Downloader.java deleted file mode 100755 index 980a3f3e9..000000000 --- a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/src/main/java/org/deeplearning4j/aws/s3/reader/S3Downloader.java +++ /dev/null @@ -1,178 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.aws.s3.reader; - -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.model.*; -import com.amazonaws.services.s3.transfer.MultipleFileDownload; -import com.amazonaws.services.s3.transfer.TransferManager; -import org.apache.commons.io.IOUtils; -import org.deeplearning4j.aws.s3.BaseS3; - -import java.io.*; -import java.util.ArrayList; -import java.util.Iterator; -import java.util.List; - -/** - * Downloads files from S3 - * @author Adam Gibson - * - */ -public class S3Downloader extends BaseS3 { - - - /** - * Return the keys for a bucket - * @param bucket the bucket to get the keys for - * @return the bucket's keys - */ - public List keysForBucket(String bucket) { - AmazonS3 s3 = getClient(); - List ret = new ArrayList<>(); - ListObjectsRequest listObjectsRequest = new ListObjectsRequest().withBucketName(bucket); - ObjectListing objectListing; - - do { - objectListing = s3.listObjects(listObjectsRequest); - for (S3ObjectSummary objectSummary : objectListing.getObjectSummaries()) { - ret.add(objectSummary.getKey()); - } - listObjectsRequest.setMarker(objectListing.getNextMarker()); - } while (objectListing.isTruncated()); - - return ret; - } - - /** - * Returns the list of buckets in s3 - * @return the list of buckets - */ - public List buckets() { - List ret = new ArrayList<>(); - AmazonS3 s3 = getClient(); - List buckets = s3.listBuckets(); - for (Bucket b : buckets) - ret.add(b.getName()); - return ret; - } - - /** - * Iterate over individual buckets. - * Returns input streams to each object. - * It is your responsibility to close the input streams - * @param bucket the bucket to iterate over - * @return an iterator over the objects in an s3 bucket - */ - public Iterator iterateBucket(String bucket) { - return new BucketIterator(bucket, this); - } - - /** - * Iterator style one list at a time - * @param list the list to get the next batch for - * @return the next batch of objects or null if - * none are left - */ - public ObjectListing nextList(ObjectListing list) { - AmazonS3 s3 = getClient(); - if (list.isTruncated()) - return s3.listNextBatchOfObjects(list); - return null; - } - - /** - * Simple way of retrieving the listings for a bucket - * @param bucket the bucket to retrieve listings for - * @return the object listing for this bucket - */ - public ObjectListing listObjects(String bucket) { - AmazonS3 s3 = getClient(); - ObjectListing list = s3.listObjects(bucket); - return list; - } - - /** - * Paginates through a bucket's keys invoking the listener - * at each key - * @param bucket the bucket to iterate - * @param listener the listener - */ - public void paginate(String bucket, BucketKeyListener listener) { - AmazonS3 s3 = getClient(); - ObjectListing list = s3.listObjects(bucket); - for (S3ObjectSummary summary : list.getObjectSummaries()) { - if (listener != null) - listener.onKey(s3, bucket, summary.getKey()); - } - - while (list.isTruncated()) { - list = s3.listNextBatchOfObjects(list); - for (S3ObjectSummary summary : list.getObjectSummaries()) { - if (listener != null) - listener.onKey(s3, bucket, summary.getKey()); - } - } - - - } - - - /** - * Returns an input stream for the given bucket and key - * @param bucket the bucket to retrieve from - * @param key the key of the objec t - * @return an input stream to the object - */ - public InputStream objectForKey(String bucket, String key) { - AmazonS3 s3 = getClient(); - S3Object obj = s3.getObject(bucket, key); - InputStream is = obj.getObjectContent(); - return is; - } - - - public void download(String bucket, String key, File to) throws IOException { - AmazonS3 s3 = getClient(); - S3Object obj = s3.getObject(bucket, key); - InputStream is = obj.getObjectContent(); - BufferedOutputStream bos = new BufferedOutputStream(new FileOutputStream(to)); - IOUtils.copy(is, bos); - bos.close(); - is.close(); - obj.close(); - } - - public void download(String bucket, String key, OutputStream to) throws IOException { - AmazonS3 s3 = getClient(); - S3Object obj = s3.getObject(bucket, key); - InputStream is = obj.getObjectContent(); - BufferedOutputStream bos = new BufferedOutputStream(to); - - IOUtils.copy(is, bos); - bos.close(); - is.close(); - obj.close(); - } - - public MultipleFileDownload downloadFolder(String bucketName, String keyPrefix, File folderPath) { - TransferManager transfer = new TransferManager(getClient()); - return transfer.downloadDirectory(bucketName, keyPrefix, folderPath); - } - - -} diff --git a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/src/main/java/org/deeplearning4j/aws/s3/uploader/S3Uploader.java b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/src/main/java/org/deeplearning4j/aws/s3/uploader/S3Uploader.java deleted file mode 100755 index eacc71386..000000000 --- a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/src/main/java/org/deeplearning4j/aws/s3/uploader/S3Uploader.java +++ /dev/null @@ -1,172 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.aws.s3.uploader; - -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.AmazonS3Client; -import com.amazonaws.services.s3.model.*; -import com.amazonaws.services.s3.transfer.MultipleFileUpload; -import com.amazonaws.services.s3.transfer.TransferManager; -import org.deeplearning4j.aws.s3.BaseS3; - -import java.io.File; -import java.util.ArrayList; -import java.util.List; - -/** - * Uploads files to S3 - * - * @see {@link BaseS3} - * @author Adam Gibson - * - */ -public class S3Uploader extends BaseS3 { - - - /** - * Multi part upload for big files - * @param file the file to upload - * @param bucketName the bucket name to upload - */ - public void multiPartUpload(File file, String bucketName) { - AmazonS3 client = new AmazonS3Client(creds); - bucketName = ensureValidBucketName(bucketName); - - List buckets = client.listBuckets(); - for (Bucket b : buckets) - if (b.getName().equals(bucketName)) { - doMultiPart(client, bucketName, file); - return; - } - - //bucket didn't exist: create it - client.createBucket(bucketName); - doMultiPart(client, bucketName, file); - } - - /** - * Upload the file to the bucket. - * Will create the bucket if it hasn't already been created - * @param file the file to upload - * @param bucketName the name of the bucket - */ - public void upload(File file, String bucketName) { - AmazonS3 client = new AmazonS3Client(creds); - bucketName = ensureValidBucketName(bucketName); - - List buckets = client.listBuckets(); - for (Bucket b : buckets) - if (b.getName().equals(bucketName)) { - client.putObject(bucketName, file.getName(), file); - return; - } - - //bucket didn't exist: create it - client.createBucket(bucketName); - client.putObject(bucketName, file.getName(), file); - - } - - private void doMultiPart(AmazonS3 s3Client, String bucketName, File file) { - // Create a list of UploadPartResponse objects. You get one of these - // for each part upload. - List partETags = new ArrayList<>(); - - // Step 1: Initialize. - InitiateMultipartUploadRequest initRequest = new InitiateMultipartUploadRequest(bucketName, file.getName()); - InitiateMultipartUploadResult initResponse = s3Client.initiateMultipartUpload(initRequest); - - long contentLength = file.length(); - long partSize = 5242880; // Set part size to 5 MB. - - try { - // Step 2: Upload parts. - long filePosition = 0; - for (int i = 1; filePosition < contentLength; i++) { - // Last part can be less than 5 MB. Adjust part size. - partSize = Math.min(partSize, (contentLength - filePosition)); - - // Create request to upload a part. - UploadPartRequest uploadRequest = new UploadPartRequest().withBucketName(bucketName) - .withKey(file.getName()).withUploadId(initResponse.getUploadId()).withPartNumber(i) - .withFileOffset(filePosition).withFile(file).withPartSize(partSize); - - // Upload part and add response to our list. - partETags.add(s3Client.uploadPart(uploadRequest).getPartETag()); - - filePosition += partSize; - } - - // Step 3: Complete. - CompleteMultipartUploadRequest compRequest = new CompleteMultipartUploadRequest(bucketName, file.getName(), - initResponse.getUploadId(), partETags); - - s3Client.completeMultipartUpload(compRequest); - } catch (Exception e) { - s3Client.abortMultipartUpload( - new AbortMultipartUploadRequest(bucketName, file.getName(), initResponse.getUploadId())); - } - } - - private String ensureValidBucketName(String bucketName) { - String formatted = bucketName.replaceAll("\\s+", "_"); - int length = bucketName.length(); - if (length >= 62) - length = 62; - formatted = formatted.substring(0, length); - formatted = formatted.replace(".", "d"); - formatted = formatted.toLowerCase(); - if (formatted.endsWith("-")) - formatted = formatted.substring(0, length - 1); - - return formatted; - } - - public void upload(File file, String name, String bucketName) { - AmazonS3 client = getClient(); - bucketName = ensureValidBucketName(bucketName); - List buckets = client.listBuckets(); - // ObjectMetadata med = new ObjectMetadata(); - // med.setContentLength(fileLength); - for (Bucket b : buckets) - if (b.getName().equals(bucketName)) { - //client.putObject(bucketName, name, is, med); - client.putObject(new PutObjectRequest(bucketName, name, file)); - return; - } - - //bucket didn't exist: createComplex it - client.createBucket(bucketName); - //client.putObject(bucketName, name, is, med); - client.putObject(new PutObjectRequest(bucketName, name, file)); - } - - - public MultipleFileUpload uploadFolder(String bucketName, String keyPrefix, File folderPath, - boolean includeSubDir) { - TransferManager transfer = new TransferManager(getClient()); - return transfer.uploadDirectory(bucketName, keyPrefix, folderPath, includeSubDir); - } - - public MultipleFileUpload uploadFileList(String bucketName, File folderPath, List fileList, - String keyPrefix) { - TransferManager transfer = new TransferManager(getClient()); - return transfer.uploadFileList(bucketName, keyPrefix, folderPath, fileList); - } - - -} diff --git a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/src/main/resources/provision-master.sh b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/src/main/resources/provision-master.sh deleted file mode 100755 index 7f7285bb5..000000000 --- a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/src/main/resources/provision-master.sh +++ /dev/null @@ -1,45 +0,0 @@ -#!/bin/bash - -################################################################################ -# Copyright (c) 2015-2018 Skymind, Inc. -# -# This program and the accompanying materials are made available under the -# terms of the Apache License, Version 2.0 which is available at -# https://www.apache.org/licenses/LICENSE-2.0. -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. -# -# SPDX-License-Identifier: Apache-2.0 -################################################################################ - -sudo yum -y install blas java-1.7.0-openjdk-devel - -if [ ! -f "/opt/dl4j" ]; -then - sudo mkdir /opt/dl4j - sudo yum -y install git - git clone https://github.com/agibsonccc/java-deeplearning - - wget http://www.trieuvan.com/apache/maven/maven-3/3.1.1/binaries/apache-maven-3.1.1-bin.tar.gz - sudo mv apache-maven-3.1.1-bin.tar.gz /opt - cd /opt && sudo tar xvf apache-maven-3.1.1-bin.tar.gz && sudo mv apache-maven-3.1.1 /opt/mvn - cd /home/ec2-user/java-deeplearning/ && /opt/mvn/bin/mvn -DskipTests clean install - echo "Printing distribution" - ls /home/ec2-user/java-deeplearning/deeplearning4j-distribution/target - echo "Before moving distribution" - sudo mv deeplearning4j-distribution/target/deeplearning4j-dist.tar.gz /opt - echo "Moving distribution to opt directory..." - echo "Moving in to opt directory" - cd /opt - - sudo tar xzvf deeplearning4j-dist.tar.gz - #sudo mkdir /opt/dl4j - echo "Moving jars in to /opt/dl4j/..." - sudo mv *.jar /opt/dl4j -fi - - diff --git a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper-parameter-server/pom.xml b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper-parameter-server/pom.xml index 97515cf5e..2c7a94de8 100644 --- a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper-parameter-server/pom.xml +++ b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper-parameter-server/pom.xml @@ -86,6 +86,13 @@ logback-classic test + + + org.deeplearning4j + deeplearning4j-common-tests + ${project.version} + test + diff --git a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper-parameter-server/src/test/java/org/deeplearning4j/parallelism/parameterserver/BaseDL4JTest.java b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper-parameter-server/src/test/java/org/deeplearning4j/parallelism/parameterserver/BaseDL4JTest.java deleted file mode 100644 index 8e087cc2f..000000000 --- a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper-parameter-server/src/test/java/org/deeplearning4j/parallelism/parameterserver/BaseDL4JTest.java +++ /dev/null @@ -1,140 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.parallelism.parameterserver; - -import lombok.extern.slf4j.Slf4j; -import org.bytedeco.javacpp.Pointer; -import org.junit.After; -import org.junit.Before; -import org.junit.Rule; -import org.junit.rules.TestName; -import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.api.memory.MemoryWorkspace; -import org.nd4j.linalg.api.ops.executioner.OpExecutioner; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.profiler.ProfilerConfig; - -import java.lang.management.ManagementFactory; -import java.util.List; -import java.util.Map; -import java.util.Properties; - -@Slf4j -public class BaseDL4JTest { - - @Rule - public TestName name = new TestName(); - - protected long startTime; - protected int threadCountBefore; - - /** - * Override this to set the profiling mode for the tests defined in the child class - */ - public OpExecutioner.ProfilingMode getProfilingMode(){ - return OpExecutioner.ProfilingMode.SCOPE_PANIC; - } - - /** - * Override this to set the datatype of the tests defined in the child class - */ - public DataType getDataType(){ - return DataType.DOUBLE; - } - - public DataType getDefaultFPDataType(){ - return getDataType(); - } - - @Before - public void beforeTest(){ - log.info("{}.{}", getClass().getSimpleName(), name.getMethodName()); - Nd4j.getExecutioner().setProfilingMode(getProfilingMode()); - Nd4j.getExecutioner().setProfilingConfig(ProfilerConfig.builder().build()); - Nd4j.setDefaultDataTypes(getDataType(), getDefaultFPDataType()); - startTime = System.currentTimeMillis(); - threadCountBefore = ManagementFactory.getThreadMXBean().getThreadCount(); - } - - @After - public void afterTest(){ - //Attempt to keep workspaces isolated between tests - Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread(); - MemoryWorkspace currWS = Nd4j.getMemoryManager().getCurrentWorkspace(); - Nd4j.getMemoryManager().setCurrentWorkspace(null); - if(currWS != null){ - //Not really safe to continue testing under this situation... other tests will likely fail with obscure - // errors that are hard to track back to this - log.error("Open workspace leaked from test! Exiting - {}, isOpen = {} - {}", currWS.getId(), currWS.isScopeActive(), currWS); - System.exit(1); - } - - StringBuilder sb = new StringBuilder(); - long maxPhys = Pointer.maxPhysicalBytes(); - long maxBytes = Pointer.maxBytes(); - long currPhys = Pointer.physicalBytes(); - long currBytes = Pointer.totalBytes(); - - long jvmTotal = Runtime.getRuntime().totalMemory(); - long jvmMax = Runtime.getRuntime().maxMemory(); - - int threadsAfter = ManagementFactory.getThreadMXBean().getThreadCount(); - - long duration = System.currentTimeMillis() - startTime; - sb.append(getClass().getSimpleName()).append(".").append(name.getMethodName()) - .append(": ").append(duration).append(" ms") - .append(", threadCount: (").append(threadCountBefore).append("->").append(threadsAfter).append(")") - .append(", jvmTotal=").append(jvmTotal) - .append(", jvmMax=").append(jvmMax) - .append(", totalBytes=").append(currBytes).append(", maxBytes=").append(maxBytes) - .append(", currPhys=").append(currPhys).append(", maxPhys=").append(maxPhys); - - List ws = Nd4j.getWorkspaceManager().getAllWorkspacesForCurrentThread(); - if(ws != null && ws.size() > 0){ - long currSize = 0; - for(MemoryWorkspace w : ws){ - currSize += w.getCurrentSize(); - } - if(currSize > 0){ - sb.append(", threadWSSize=").append(currSize) - .append(" (").append(ws.size()).append(" WSs)"); - } - } - - - Properties p = Nd4j.getExecutioner().getEnvironmentInformation(); - Object o = p.get("cuda.devicesInformation"); - if(o instanceof List){ - List> l = (List>) o; - if(l.size() > 0) { - - sb.append(" [").append(l.size()) - .append(" GPUs: "); - - for (int i = 0; i < l.size(); i++) { - Map m = l.get(i); - if(i > 0) - sb.append(","); - sb.append("(").append(m.get("cuda.freeMemory")).append(" free, ") - .append(m.get("cuda.totalMemory")).append(" total)"); - } - sb.append("]"); - } - } - log.info(sb.toString()); - } -} diff --git a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper-parameter-server/src/test/java/org/deeplearning4j/parallelism/parameterserver/ParameterServerParallelWrapperTest.java b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper-parameter-server/src/test/java/org/deeplearning4j/parallelism/parameterserver/ParameterServerParallelWrapperTest.java index 07b5b41a7..beb9af5b4 100644 --- a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper-parameter-server/src/test/java/org/deeplearning4j/parallelism/parameterserver/ParameterServerParallelWrapperTest.java +++ b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper-parameter-server/src/test/java/org/deeplearning4j/parallelism/parameterserver/ParameterServerParallelWrapperTest.java @@ -17,6 +17,7 @@ package org.deeplearning4j.parallelism.parameterserver; import lombok.extern.slf4j.Slf4j; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; diff --git a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/pom.xml b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/pom.xml index 08eed7f15..3c083d40d 100644 --- a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/pom.xml +++ b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/pom.xml @@ -90,6 +90,13 @@ ${project.version} test + + + org.deeplearning4j + deeplearning4j-common-tests + ${project.version} + test + diff --git a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/BaseDL4JTest.java b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/BaseDL4JTest.java deleted file mode 100644 index f97073042..000000000 --- a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/BaseDL4JTest.java +++ /dev/null @@ -1,140 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.parallelism; - -import lombok.extern.slf4j.Slf4j; -import org.bytedeco.javacpp.Pointer; -import org.junit.After; -import org.junit.Before; -import org.junit.Rule; -import org.junit.rules.TestName; -import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.api.memory.MemoryWorkspace; -import org.nd4j.linalg.api.ops.executioner.OpExecutioner; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.profiler.ProfilerConfig; - -import java.lang.management.ManagementFactory; -import java.util.List; -import java.util.Map; -import java.util.Properties; - -@Slf4j -public class BaseDL4JTest { - - @Rule - public TestName name = new TestName(); - - protected long startTime; - protected int threadCountBefore; - - /** - * Override this to set the profiling mode for the tests defined in the child class - */ - public OpExecutioner.ProfilingMode getProfilingMode(){ - return OpExecutioner.ProfilingMode.SCOPE_PANIC; - } - - /** - * Override this to set the datatype of the tests defined in the child class - */ - public DataType getDataType(){ - return DataType.DOUBLE; - } - - public DataType getDefaultFPDataType(){ - return getDataType(); - } - - @Before - public void beforeTest(){ - log.info("{}.{}", getClass().getSimpleName(), name.getMethodName()); - Nd4j.getExecutioner().setProfilingMode(getProfilingMode()); - Nd4j.getExecutioner().setProfilingConfig(ProfilerConfig.builder().build()); - Nd4j.setDefaultDataTypes(getDataType(), getDefaultFPDataType()); - startTime = System.currentTimeMillis(); - threadCountBefore = ManagementFactory.getThreadMXBean().getThreadCount(); - } - - @After - public void afterTest(){ - //Attempt to keep workspaces isolated between tests - Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread(); - MemoryWorkspace currWS = Nd4j.getMemoryManager().getCurrentWorkspace(); - Nd4j.getMemoryManager().setCurrentWorkspace(null); - if(currWS != null){ - //Not really safe to continue testing under this situation... other tests will likely fail with obscure - // errors that are hard to track back to this - log.error("Open workspace leaked from test! Exiting - {}, isOpen = {} - {}", currWS.getId(), currWS.isScopeActive(), currWS); - System.exit(1); - } - - StringBuilder sb = new StringBuilder(); - long maxPhys = Pointer.maxPhysicalBytes(); - long maxBytes = Pointer.maxBytes(); - long currPhys = Pointer.physicalBytes(); - long currBytes = Pointer.totalBytes(); - - long jvmTotal = Runtime.getRuntime().totalMemory(); - long jvmMax = Runtime.getRuntime().maxMemory(); - - int threadsAfter = ManagementFactory.getThreadMXBean().getThreadCount(); - - long duration = System.currentTimeMillis() - startTime; - sb.append(getClass().getSimpleName()).append(".").append(name.getMethodName()) - .append(": ").append(duration).append(" ms") - .append(", threadCount: (").append(threadCountBefore).append("->").append(threadsAfter).append(")") - .append(", jvmTotal=").append(jvmTotal) - .append(", jvmMax=").append(jvmMax) - .append(", totalBytes=").append(currBytes).append(", maxBytes=").append(maxBytes) - .append(", currPhys=").append(currPhys).append(", maxPhys=").append(maxPhys); - - List ws = Nd4j.getWorkspaceManager().getAllWorkspacesForCurrentThread(); - if(ws != null && ws.size() > 0){ - long currSize = 0; - for(MemoryWorkspace w : ws){ - currSize += w.getCurrentSize(); - } - if(currSize > 0){ - sb.append(", threadWSSize=").append(currSize) - .append(" (").append(ws.size()).append(" WSs)"); - } - } - - - Properties p = Nd4j.getExecutioner().getEnvironmentInformation(); - Object o = p.get("cuda.devicesInformation"); - if(o instanceof List){ - List> l = (List>) o; - if(l.size() > 0) { - - sb.append(" [").append(l.size()) - .append(" GPUs: "); - - for (int i = 0; i < l.size(); i++) { - Map m = l.get(i); - if(i > 0) - sb.append(","); - sb.append("(").append(m.get("cuda.freeMemory")).append(" free, ") - .append(m.get("cuda.totalMemory")).append(" total)"); - } - sb.append("]"); - } - } - log.info(sb.toString()); - } -} diff --git a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/InplaceParallelInferenceTest.java b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/InplaceParallelInferenceTest.java index 0f3e95930..d089781f1 100644 --- a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/InplaceParallelInferenceTest.java +++ b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/InplaceParallelInferenceTest.java @@ -17,6 +17,7 @@ package org.deeplearning4j.parallelism; import lombok.val; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.graph.ComputationGraph; diff --git a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/ParallelInferenceTest.java b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/ParallelInferenceTest.java index 4cec0eed4..219573f0a 100644 --- a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/ParallelInferenceTest.java +++ b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/ParallelInferenceTest.java @@ -19,6 +19,7 @@ package org.deeplearning4j.parallelism; import lombok.NonNull; import lombok.extern.slf4j.Slf4j; import lombok.val; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.exception.DL4JInvalidInputException; import org.deeplearning4j.nn.api.Model; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; @@ -27,7 +28,8 @@ import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.*; import org.deeplearning4j.nn.graph.ComputationGraph; -import org.junit.Ignore; +import org.junit.*; +import org.junit.rules.Timeout; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.io.ClassPathResource; import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; @@ -38,9 +40,6 @@ import org.deeplearning4j.parallelism.inference.InferenceObservable; import org.deeplearning4j.parallelism.inference.observers.BasicInferenceObserver; import org.deeplearning4j.parallelism.inference.observers.BatchedInferenceObservable; import org.deeplearning4j.util.ModelSerializer; -import org.junit.After; -import org.junit.Before; -import org.junit.Test; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; @@ -66,6 +65,9 @@ public class ParallelInferenceTest extends BaseDL4JTest { private static MultiLayerNetwork model; private static DataSetIterator iterator; + @Rule + public Timeout timeout = Timeout.seconds(300); + @Before public void setUp() throws Exception { if (model == null) { @@ -483,7 +485,8 @@ public class ParallelInferenceTest extends BaseDL4JTest { List exp = new ArrayList<>(); Random r = new Random(); - for (int i = 0; i < 500; i++) { + int runs = isIntegrationTests() ? 500 : 30; + for (int i = 0; i < runs; i++) { int[] shape = defaultSize; if (r.nextDouble() < 0.4) { shape = new int[]{r.nextInt(5) + 1, 10, r.nextInt(10) + 1}; @@ -595,7 +598,8 @@ public class ParallelInferenceTest extends BaseDL4JTest { List arrs = new ArrayList<>(); List exp = new ArrayList<>(); Random r = new Random(); - for( int i=0; i<500; i++ ){ + int runs = isIntegrationTests() ? 500 : 20; + for( int i=0; i in = new ArrayList<>(); List inMasks = new ArrayList<>(); List exp = new ArrayList<>(); - for (int i = 0; i < 100; i++) { + int nRuns = isIntegrationTests() ? 100 : 10; + for (int i = 0; i < nRuns; i++) { int currTSLength = (randomTSLength ? 1 + r.nextInt(tsLength) : tsLength); int currNumEx = 1 + r.nextInt(3); INDArray inArr = Nd4j.rand(new int[]{currNumEx, nIn, currTSLength}); @@ -845,6 +852,7 @@ public class ParallelInferenceTest extends BaseDL4JTest { List in = new ArrayList<>(); List exp = new ArrayList<>(); + int runs = isIntegrationTests() ? 100 : 20; for (int i = 0; i < 100; i++) { int currNumEx = 1 + r.nextInt(3); INDArray inArr = Nd4j.rand(new int[]{currNumEx, nIn}); diff --git a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/ParallelWrapperTest.java b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/ParallelWrapperTest.java index 52a0c0109..cea06265a 100644 --- a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/ParallelWrapperTest.java +++ b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/ParallelWrapperTest.java @@ -17,6 +17,7 @@ package org.deeplearning4j.parallelism; import lombok.val; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.datasets.iterator.EarlyTerminationDataSetIterator; import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; import org.deeplearning4j.eval.Evaluation; @@ -61,8 +62,8 @@ public class ParallelWrapperTest extends BaseDL4JTest { int seed = 123; log.info("Load data...."); - DataSetIterator mnistTrain = new EarlyTerminationDataSetIterator(new MnistDataSetIterator(batchSize, true, 12345), 100); - DataSetIterator mnistTest = new EarlyTerminationDataSetIterator(new MnistDataSetIterator(batchSize, false, 12345), 10); + DataSetIterator mnistTrain = new EarlyTerminationDataSetIterator(new MnistDataSetIterator(batchSize, true, 12345), 15); + DataSetIterator mnistTest = new EarlyTerminationDataSetIterator(new MnistDataSetIterator(batchSize, false, 12345), 4); assertTrue(mnistTrain.hasNext()); val t0 = mnistTrain.next(); diff --git a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/TestListeners.java b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/TestListeners.java index ea48549ba..9593a0799 100644 --- a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/TestListeners.java +++ b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/TestListeners.java @@ -16,6 +16,7 @@ package org.deeplearning4j.parallelism; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.api.storage.StatsStorage; import org.deeplearning4j.api.storage.StatsStorageRouter; import org.deeplearning4j.api.storage.listener.RoutingIterationListener; diff --git a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/TestParallelEarlyStopping.java b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/TestParallelEarlyStopping.java index 4aeb85acd..ac2b018e2 100644 --- a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/TestParallelEarlyStopping.java +++ b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/TestParallelEarlyStopping.java @@ -16,6 +16,7 @@ package org.deeplearning4j.parallelism; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator; import org.deeplearning4j.earlystopping.EarlyStoppingConfiguration; import org.deeplearning4j.earlystopping.EarlyStoppingModelSaver; diff --git a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/TestParallelEarlyStoppingUI.java b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/TestParallelEarlyStoppingUI.java index cce4f490a..160d4df58 100644 --- a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/TestParallelEarlyStoppingUI.java +++ b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/TestParallelEarlyStoppingUI.java @@ -16,6 +16,7 @@ package org.deeplearning4j.parallelism; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.api.storage.StatsStorage; import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator; import org.deeplearning4j.earlystopping.EarlyStoppingConfiguration; diff --git a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/factory/DefaultTrainerContextTest.java b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/factory/DefaultTrainerContextTest.java index a8eca6a56..c96ca4a19 100644 --- a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/factory/DefaultTrainerContextTest.java +++ b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/factory/DefaultTrainerContextTest.java @@ -27,7 +27,7 @@ import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.conf.layers.SubsamplingLayer; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; -import org.deeplearning4j.parallelism.BaseDL4JTest; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.parallelism.ParallelWrapper; import org.deeplearning4j.parallelism.trainer.SymmetricTrainer; import org.junit.Test; diff --git a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/factory/SymmetricTrainerContextTest.java b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/factory/SymmetricTrainerContextTest.java index cf6fe92de..0258caac9 100644 --- a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/factory/SymmetricTrainerContextTest.java +++ b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/factory/SymmetricTrainerContextTest.java @@ -27,7 +27,7 @@ import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.conf.layers.SubsamplingLayer; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; -import org.deeplearning4j.parallelism.BaseDL4JTest; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.parallelism.ParallelWrapper; import org.deeplearning4j.parallelism.trainer.SymmetricTrainer; import org.junit.Test; diff --git a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/inference/observers/BatchedInferenceObservableTest.java b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/inference/observers/BatchedInferenceObservableTest.java index 5b5173b20..facf506d6 100644 --- a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/inference/observers/BatchedInferenceObservableTest.java +++ b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/inference/observers/BatchedInferenceObservableTest.java @@ -17,7 +17,7 @@ package org.deeplearning4j.parallelism.inference.observers; import lombok.extern.slf4j.Slf4j; -import org.deeplearning4j.parallelism.BaseDL4JTest; +import org.deeplearning4j.BaseDL4JTest; import org.junit.After; import org.junit.Before; import org.junit.Test; diff --git a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/main/ParallelWrapperMainTest.java b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/main/ParallelWrapperMainTest.java index 54bc7f4f6..ae6672b47 100644 --- a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/main/ParallelWrapperMainTest.java +++ b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/main/ParallelWrapperMainTest.java @@ -27,7 +27,7 @@ import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.conf.layers.SubsamplingLayer; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; -import org.deeplearning4j.parallelism.BaseDL4JTest; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.util.ModelSerializer; import org.junit.Rule; import org.junit.Test; diff --git a/deeplearning4j/deeplearning4j-scaleout/pom.xml b/deeplearning4j/deeplearning4j-scaleout/pom.xml index 2c192cc10..539aa3ef7 100644 --- a/deeplearning4j/deeplearning4j-scaleout/pom.xml +++ b/deeplearning4j/deeplearning4j-scaleout/pom.xml @@ -28,7 +28,6 @@ DeepLearning4j-scaleout-parent - deeplearning4j-aws spark deeplearning4j-scaleout-parallelwrapper deeplearning4j-scaleout-parallelwrapper-parameter-server diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/pom.xml b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/pom.xml index 42e799b69..3eafbb9e2 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/pom.xml +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/pom.xml @@ -47,6 +47,12 @@ org.nd4j nd4j-parameter-server-node_2.11 ${nd4j.version} + + + net.jpountz.lz4 + lz4 + + junit @@ -60,6 +66,13 @@ ${spark.version} provided + + + org.deeplearning4j + deeplearning4j-common-tests + ${project.version} + test + diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/test/java/org/deeplearning4j/spark/models/sequencevectors/SparkSequenceVectorsTest.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/test/java/org/deeplearning4j/spark/models/sequencevectors/SparkSequenceVectorsTest.java index 5f71ce497..363b4e293 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/test/java/org/deeplearning4j/spark/models/sequencevectors/SparkSequenceVectorsTest.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/test/java/org/deeplearning4j/spark/models/sequencevectors/SparkSequenceVectorsTest.java @@ -19,6 +19,7 @@ package org.deeplearning4j.spark.models.sequencevectors; import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.models.sequencevectors.sequence.Sequence; import org.deeplearning4j.models.sequencevectors.sequence.ShallowSequenceElement; import org.deeplearning4j.models.word2vec.VocabWord; @@ -41,7 +42,7 @@ import static org.junit.Assert.assertNotEquals; /** * @author raver119@gmail.com */ -public class SparkSequenceVectorsTest { +public class SparkSequenceVectorsTest extends BaseDL4JTest { protected static List> sequencesCyclic; private JavaSparkContext sc; diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/test/java/org/deeplearning4j/spark/models/sequencevectors/export/ExportContainerTest.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/test/java/org/deeplearning4j/spark/models/sequencevectors/export/ExportContainerTest.java index 70ecc0dbe..604181109 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/test/java/org/deeplearning4j/spark/models/sequencevectors/export/ExportContainerTest.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/test/java/org/deeplearning4j/spark/models/sequencevectors/export/ExportContainerTest.java @@ -16,6 +16,7 @@ package org.deeplearning4j.spark.models.sequencevectors.export; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.models.word2vec.VocabWord; import org.junit.Before; import org.junit.Test; @@ -26,7 +27,7 @@ import static org.junit.Assert.assertEquals; /** * @author raver119@gmail.com */ -public class ExportContainerTest { +public class ExportContainerTest extends BaseDL4JTest { @Before public void setUp() throws Exception { diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/test/java/org/deeplearning4j/spark/models/word2vec/SparkWord2VecTest.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/test/java/org/deeplearning4j/spark/models/word2vec/SparkWord2VecTest.java index 809d74138..82a04eab8 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/test/java/org/deeplearning4j/spark/models/word2vec/SparkWord2VecTest.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/test/java/org/deeplearning4j/spark/models/word2vec/SparkWord2VecTest.java @@ -20,6 +20,7 @@ import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.VoidFunction; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration; import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement; import org.deeplearning4j.models.sequencevectors.sequence.ShallowSequenceElement; @@ -46,7 +47,7 @@ import static org.junit.Assert.*; * * @author raver119@gmail.com */ -public class SparkWord2VecTest { +public class SparkWord2VecTest extends BaseDL4JTest { private static List sentences; private JavaSparkContext sc; diff --git a/datavec/datavec-camel/src/test/resources/log4j.properties b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/test/resources/log4j.properties old mode 100644 new mode 100755 similarity index 58% rename from datavec/datavec-camel/src/test/resources/log4j.properties rename to deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/test/resources/log4j.properties index 09f815522..5d1edb39f --- a/datavec/datavec-camel/src/test/resources/log4j.properties +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/test/resources/log4j.properties @@ -1,5 +1,5 @@ ################################################################################ -# Copyright (c) 2015-2018 Skymind, Inc. +# Copyright (c) 2015-2019 Skymind, Inc. # # This program and the accompanying materials are made available under the # terms of the Apache License, Version 2.0 which is available at @@ -14,17 +14,18 @@ # SPDX-License-Identifier: Apache-2.0 ################################################################################ -# -# The logging properties used -# -log4j.rootLogger=INFO, out +log4j.rootLogger=ERROR, Console +log4j.appender.Console=org.apache.log4j.ConsoleAppender +log4j.appender.Console.layout=org.apache.log4j.PatternLayout +log4j.appender.Console.layout.ConversionPattern=%d{ABSOLUTE} %-5p ~ %m%n -# uncomment the following line to turn on Camel debugging -#log4j.logger.org.apache.camel=DEBUG +log4j.appender.org.springframework=DEBUG +log4j.appender.org.deeplearning4j=DEBUG +log4j.appender.org.nd4j=DEBUG + +log4j.logger.org.springframework=INFO +log4j.logger.org.deeplearning4j=DEBUG +log4j.logger.org.nd4j=DEBUG +log4j.logger.org.apache.spark=WARN -# CONSOLE appender not used by default -log4j.appender.out=org.apache.log4j.ConsoleAppender -log4j.appender.out.layout=org.apache.log4j.PatternLayout -log4j.appender.out.layout.ConversionPattern=[%30.30t] %-30.30c{1} %-5p %m%n -#log4j.appender.out.layout.ConversionPattern=%d [%-15.15t] %-5p %-30.30c{1} - %m%n diff --git a/deeplearning4j/dl4j-perf/src/test/resources/logback-test.xml b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/test/resources/logback.xml similarity index 93% rename from deeplearning4j/dl4j-perf/src/test/resources/logback-test.xml rename to deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/test/resources/logback.xml index c6f89b60a..9dec22fae 100644 --- a/deeplearning4j/dl4j-perf/src/test/resources/logback-test.xml +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/test/resources/logback.xml @@ -16,6 +16,8 @@ + + logs/application.log @@ -33,12 +35,13 @@ - + + @@ -47,4 +50,4 @@ - \ No newline at end of file + diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/pom.xml b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/pom.xml index 0a92d19ab..c4e8dc7ab 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/pom.xml +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/pom.xml @@ -65,6 +65,13 @@ jackson-module-scala_2.11 2.6.7.1 + + + org.deeplearning4j + deeplearning4j-common-tests + ${project.version} + test + diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/NegativeHolder.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/NegativeHolder.java index 3c311eed3..7f1b682ab 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/NegativeHolder.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/NegativeHolder.java @@ -20,7 +20,7 @@ import lombok.Getter; import lombok.NonNull; import org.deeplearning4j.models.word2vec.VocabWord; import org.deeplearning4j.models.word2vec.wordstore.VocabCache; -import org.nd4j.linalg.api.buffer.FloatBuffer; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -63,7 +63,7 @@ public class NegativeHolder implements Serializable { protected void makeTable(int tableSize, double power) { int vocabSize = vocab.numWords(); - table = Nd4j.create(new FloatBuffer(tableSize)); + table = Nd4j.create(DataType.FLOAT, tableSize); double trainWordsPow = 0.0; for (String word : vocab.words()) { trainWordsPow += Math.pow(vocab.wordFrequency(word), power); diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/java/org/deeplearning4j/spark/text/BaseSparkTest.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/java/org/deeplearning4j/spark/text/BaseSparkTest.java index 152ef4db5..475572edd 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/java/org/deeplearning4j/spark/text/BaseSparkTest.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/java/org/deeplearning4j/spark/text/BaseSparkTest.java @@ -18,6 +18,7 @@ package org.deeplearning4j.spark.text; import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaSparkContext; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.spark.models.embeddings.word2vec.Word2VecVariables; import org.junit.After; import org.junit.Before; @@ -30,7 +31,7 @@ import java.util.Map; /** * Created by agibsonccc on 1/23/15. */ -public abstract class BaseSparkTest implements Serializable { +public abstract class BaseSparkTest extends BaseDL4JTest implements Serializable { protected transient JavaSparkContext sc; @Before diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/resources/log4j.properties b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/resources/log4j.properties new file mode 100755 index 000000000..5d1edb39f --- /dev/null +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/resources/log4j.properties @@ -0,0 +1,31 @@ +################################################################################ +# Copyright (c) 2015-2019 Skymind, Inc. +# +# This program and the accompanying materials are made available under the +# terms of the Apache License, Version 2.0 which is available at +# https://www.apache.org/licenses/LICENSE-2.0. +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +# +# SPDX-License-Identifier: Apache-2.0 +################################################################################ + +log4j.rootLogger=ERROR, Console +log4j.appender.Console=org.apache.log4j.ConsoleAppender +log4j.appender.Console.layout=org.apache.log4j.PatternLayout +log4j.appender.Console.layout.ConversionPattern=%d{ABSOLUTE} %-5p ~ %m%n + +log4j.appender.org.springframework=DEBUG +log4j.appender.org.deeplearning4j=DEBUG +log4j.appender.org.nd4j=DEBUG + +log4j.logger.org.springframework=INFO +log4j.logger.org.deeplearning4j=DEBUG +log4j.logger.org.nd4j=DEBUG +log4j.logger.org.apache.spark=WARN + + diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/resources/logback.xml b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/resources/logback.xml new file mode 100644 index 000000000..9dec22fae --- /dev/null +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/resources/logback.xml @@ -0,0 +1,53 @@ + + + + + + + + logs/application.log + + %date - [%level] - from %logger in %thread + %n%message%n%xException%n + + + + + + %logger{15} - %message%n%xException{5} + + + + + + + + + + + + + + + + + + + + + + diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/pom.xml b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/pom.xml index fc1e96ec0..daf0dd9b7 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/pom.xml +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/pom.xml @@ -66,6 +66,13 @@ ${spark.version} provided + + + org.deeplearning4j + deeplearning4j-common-tests + ${project.version} + test + diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/BaseSparkTest.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/BaseSparkTest.java index a90ce4b8c..ccab68e9e 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/BaseSparkTest.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/BaseSparkTest.java @@ -19,6 +19,7 @@ package org.deeplearning4j.spark.parameterserver; import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.spark.impl.multilayer.SparkDl4jMultiLayer; @@ -41,7 +42,7 @@ import java.util.Random; /** * Created by agibsonccc on 1/23/15. */ -public abstract class BaseSparkTest implements Serializable { +public abstract class BaseSparkTest extends BaseDL4JTest implements Serializable { protected transient JavaSparkContext sc; protected transient INDArray labels; protected transient INDArray input; diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/resources/log4j.properties b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/resources/log4j.properties index 29de0de02..5d1edb39f 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/resources/log4j.properties +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/resources/log4j.properties @@ -1,4 +1,3 @@ - ################################################################################ # Copyright (c) 2015-2019 Skymind, Inc. # @@ -21,10 +20,12 @@ log4j.appender.Console.layout=org.apache.log4j.PatternLayout log4j.appender.Console.layout.ConversionPattern=%d{ABSOLUTE} %-5p ~ %m%n log4j.appender.org.springframework=DEBUG -log4j.appender.org.deeplearning4j=INFO +log4j.appender.org.deeplearning4j=DEBUG log4j.appender.org.nd4j=DEBUG log4j.logger.org.springframework=INFO log4j.logger.org.deeplearning4j=DEBUG log4j.logger.org.nd4j=DEBUG +log4j.logger.org.apache.spark=WARN + diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/resources/logback.xml b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/resources/logback.xml index 47c108b71..9dec22fae 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/resources/logback.xml +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/resources/logback.xml @@ -1,5 +1,5 @@ org.webjars.npm @@ -417,6 +424,16 @@ weaverjs 1.2.0 + + org.webjars + explorercanvas + r3-1 + + + org.webjars + bootstrap + 2.2.2-1 + diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/java/org/deeplearning4j/ui/TestRemoteReceiver.java b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/java/org/deeplearning4j/ui/TestRemoteReceiver.java index a5469bd1d..0b81e45b4 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/java/org/deeplearning4j/ui/TestRemoteReceiver.java +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/java/org/deeplearning4j/ui/TestRemoteReceiver.java @@ -17,6 +17,7 @@ package org.deeplearning4j.ui; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.api.storage.Persistable; import org.deeplearning4j.api.storage.StorageMetaData; import org.deeplearning4j.api.storage.impl.CollectionStatsStorageRouter; @@ -51,7 +52,7 @@ import static org.junit.Assert.assertEquals; * Created by Alex on 10/11/2016. */ @Ignore -public class TestRemoteReceiver { +public class TestRemoteReceiver extends BaseDL4JTest { @Test @Ignore diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/java/org/deeplearning4j/ui/TestSameDiffUI.java b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/java/org/deeplearning4j/ui/TestSameDiffUI.java index a9da39dbc..4ba24eafa 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/java/org/deeplearning4j/ui/TestSameDiffUI.java +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/java/org/deeplearning4j/ui/TestSameDiffUI.java @@ -18,6 +18,7 @@ package org.deeplearning4j.ui; import lombok.extern.slf4j.Slf4j; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.ui.api.UIServer; import org.junit.Ignore; import org.junit.Rule; @@ -35,7 +36,7 @@ import java.util.Arrays; @Ignore @Slf4j -public class TestSameDiffUI { +public class TestSameDiffUI extends BaseDL4JTest { @Rule public TemporaryFolder testDir = new TemporaryFolder(); diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/java/org/deeplearning4j/ui/TestVertxUI.java b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/java/org/deeplearning4j/ui/TestVertxUI.java index fc5c3c4ac..43e6c76df 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/java/org/deeplearning4j/ui/TestVertxUI.java +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/java/org/deeplearning4j/ui/TestVertxUI.java @@ -18,6 +18,7 @@ package org.deeplearning4j.ui; import org.apache.commons.io.IOUtils; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.api.storage.StatsStorage; import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator; import org.deeplearning4j.nn.api.OptimizationAlgorithm; @@ -56,7 +57,7 @@ import static org.junit.Assert.*; * Created by Alex on 08/10/2016. */ @Ignore -public class TestVertxUI { +public class TestVertxUI extends BaseDL4JTest { @Before public void setUp() throws Exception { UIServer.stopInstance(); diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/java/org/deeplearning4j/ui/TestVertxUIMultiSession.java b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/java/org/deeplearning4j/ui/TestVertxUIMultiSession.java index b640be8c7..1eeceb2aa 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/java/org/deeplearning4j/ui/TestVertxUIMultiSession.java +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/java/org/deeplearning4j/ui/TestVertxUIMultiSession.java @@ -18,6 +18,7 @@ package org.deeplearning4j.ui; import io.netty.handler.codec.http.HttpResponseStatus; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.api.storage.StatsStorage; import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator; import org.deeplearning4j.nn.api.OptimizationAlgorithm; @@ -52,7 +53,7 @@ import static org.junit.Assert.*; * @author Tamas Fenyvesi */ @Ignore -public class TestVertxUIMultiSession { +public class TestVertxUIMultiSession extends BaseDL4JTest { @Before public void setUp() throws Exception { UIServer.stopInstance(); diff --git a/deeplearning4j/deeplearning4j-util/src/main/java/org/deeplearning4j/util/DiskBasedQueue.java b/deeplearning4j/deeplearning4j-util/src/main/java/org/deeplearning4j/util/DiskBasedQueue.java deleted file mode 100644 index 69999abf7..000000000 --- a/deeplearning4j/deeplearning4j-util/src/main/java/org/deeplearning4j/util/DiskBasedQueue.java +++ /dev/null @@ -1,193 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.util; - -import org.apache.commons.io.FileUtils; -import org.nd4j.linalg.util.SerializationUtils; - -import java.io.File; -import java.io.IOException; -import java.io.Serializable; -import java.util.Collection; -import java.util.Iterator; -import java.util.Queue; -import java.util.UUID; -import java.util.concurrent.ConcurrentLinkedDeque; -import java.util.concurrent.Executors; -import java.util.concurrent.atomic.AtomicBoolean; - -/** - * Naive disk based queue for storing items on disk. - * Only meant for poll and adding items. - * @author Adam Gibson - */ -public class DiskBasedQueue implements Queue, Serializable { - - private File dir; - private Queue paths = new ConcurrentLinkedDeque<>(); - private AtomicBoolean running = new AtomicBoolean(true); - private Queue save = new ConcurrentLinkedDeque<>(); - - public DiskBasedQueue() { - this(".queue"); - } - - public DiskBasedQueue(String path) { - this(new File(path)); - - } - - public DiskBasedQueue(File dir) { - this.dir = dir; - if (!dir.exists() && dir.isDirectory()) { - throw new IllegalArgumentException("Illegal queue: must be a directory"); - } - - if (!dir.exists()) - dir.mkdirs(); - if (dir.listFiles() != null && dir.listFiles().length > 1) - try { - FileUtils.deleteDirectory(dir); - } catch (IOException e) { - e.printStackTrace(); - } - - - dir.mkdir(); - - Thread t = Executors.defaultThreadFactory().newThread(new Runnable() { - @Override - public void run() { - while (running.get()) { - while (!save.isEmpty()) - addAndSave(save.poll()); - - ThreadUtils.uncheckedSleep(1000); - } - } - }); - t.setName("DiskBasedQueueSaver"); - t.setDaemon(true); - t.start(); - } - - @Override - public int size() { - return paths.size(); - } - - @Override - public boolean isEmpty() { - return paths.isEmpty(); - } - - @Override - public boolean contains(Object o) { - throw new UnsupportedOperationException(); - - } - - @Override - public Iterator iterator() { - throw new UnsupportedOperationException(); - - } - - @Override - public Object[] toArray() { - throw new UnsupportedOperationException(); - - } - - @Override - public T[] toArray(T[] a) { - throw new UnsupportedOperationException(); - } - - @Override - public boolean add(E e) { - save.add(e); - return true; - } - - @Override - public boolean remove(Object o) { - throw new UnsupportedOperationException(); - } - - @Override - public boolean containsAll(Collection c) { - throw new UnsupportedOperationException(); - } - - @Override - public boolean addAll(Collection c) { - for (E e : c) - addAndSave(e); - return true; - } - - @Override - public boolean removeAll(Collection c) { - throw new UnsupportedOperationException(); - } - - @Override - public boolean retainAll(Collection c) { - throw new UnsupportedOperationException(); - } - - @Override - public void clear() { - throw new UnsupportedOperationException(); - } - - @Override - public boolean offer(E e) { - throw new UnsupportedOperationException(); - } - - @Override - public E remove() { - throw new UnsupportedOperationException(); - } - - @Override - public E poll() { - String path = paths.poll(); - E ret = SerializationUtils.readObject(new File(path)); - File item = new File(path); - item.delete(); - return ret; - } - - @Override - public E element() { - throw new UnsupportedOperationException(); - } - - @Override - public E peek() { - throw new UnsupportedOperationException(); - } - - private void addAndSave(E e) { - File path = new File(dir, UUID.randomUUID().toString()); - SerializationUtils.saveObject(e, path); - paths.add(path.getAbsolutePath()); - } -} diff --git a/deeplearning4j/deeplearning4j-util/src/main/java/org/deeplearning4j/util/Dl4jReflection.java b/deeplearning4j/deeplearning4j-util/src/main/java/org/deeplearning4j/util/Dl4jReflection.java deleted file mode 100755 index b620a8c21..000000000 --- a/deeplearning4j/deeplearning4j-util/src/main/java/org/deeplearning4j/util/Dl4jReflection.java +++ /dev/null @@ -1,132 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.util; - - -import java.lang.reflect.Constructor; -import java.lang.reflect.Field; -import java.lang.reflect.Modifier; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; -import java.util.Properties; - -/** - * @author Adam Gibson - */ -public class Dl4jReflection { - private Dl4jReflection() {} - - /** - * Gets the empty constructor from a class - * @param clazz the class to get the constructor from - * @return the empty constructor for the class - */ - public static Constructor getEmptyConstructor(Class clazz) { - Constructor c = clazz.getDeclaredConstructors()[0]; - for (int i = 0; i < clazz.getDeclaredConstructors().length; i++) { - if (clazz.getDeclaredConstructors()[i].getParameterTypes().length < 1) { - c = clazz.getDeclaredConstructors()[i]; - break; - } - } - - return c; - } - - - public static Field[] getAllFields(Class clazz) { - // Keep backing up the inheritance hierarchy. - Class targetClass = clazz; - List fields = new ArrayList<>(); - - do { - fields.addAll(Arrays.asList(targetClass.getDeclaredFields())); - targetClass = targetClass.getSuperclass(); - } while (targetClass != null && targetClass != Object.class); - - return fields.toArray(new Field[fields.size()]); - } - - /** - * Sets the properties of the given object - * @param obj the object o set - * @param props the properties to set - */ - public static void setProperties(Object obj, Properties props) throws Exception { - for (Field field : obj.getClass().getDeclaredFields()) { - field.setAccessible(true); - if (props.containsKey(field.getName())) { - set(field, obj, props.getProperty(field.getName())); - } - - } - } - - /* sets a field with a fairly basic strategy */ - private static void set(Field field, Object obj, String value) throws Exception { - Class clazz = field.getType(); - field.setAccessible(true); - if (clazz.equals(Double.class) || clazz.equals(double.class)) { - double val = Double.valueOf(value); - field.set(obj, val); - } else if (clazz.equals(String.class)) { - field.set(obj, value); - } else if (clazz.equals(Integer.class) || clazz.equals(int.class)) { - int val = Integer.parseInt(value); - field.set(obj, val); - } else if (clazz.equals(Float.class) || clazz.equals(float.class)) { - float f = Float.parseFloat(value); - field.set(obj, f); - } - } - - - /** - * Get fields as properties - * @param obj the object to get fields for - * @param clazzes the classes to use for reflection and properties. - * T - * @return the fields as properties - */ - public static Properties getFieldsAsProperties(Object obj, Class[] clazzes) throws Exception { - Properties props = new Properties(); - for (Field field : obj.getClass().getDeclaredFields()) { - if (Modifier.isStatic(field.getModifiers())) - continue; - field.setAccessible(true); - Class type = field.getType(); - if (clazzes == null || contains(type, clazzes)) { - Object val = field.get(obj); - if (val != null) - props.put(field.getName(), val.toString()); - - } - } - - return props; - } - - - private static boolean contains(Class test, Class[] arr) { - for (Class c : arr) - if (c.equals(test)) - return true; - return false; - } - -} diff --git a/deeplearning4j/deeplearning4j-zoo/pom.xml b/deeplearning4j/deeplearning4j-zoo/pom.xml index 976d7500b..bec71ec04 100644 --- a/deeplearning4j/deeplearning4j-zoo/pom.xml +++ b/deeplearning4j/deeplearning4j-zoo/pom.xml @@ -71,6 +71,13 @@ ${deeplearning4j.version} test + + + org.deeplearning4j + deeplearning4j-common-tests + ${project.version} + test + diff --git a/deeplearning4j/deeplearning4j-zoo/src/test/java/org/deeplearning4j/zoo/BaseDL4JTest.java b/deeplearning4j/deeplearning4j-zoo/src/test/java/org/deeplearning4j/zoo/BaseDL4JTest.java deleted file mode 100644 index 8c2b9bb07..000000000 --- a/deeplearning4j/deeplearning4j-zoo/src/test/java/org/deeplearning4j/zoo/BaseDL4JTest.java +++ /dev/null @@ -1,141 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.zoo; - -import lombok.extern.slf4j.Slf4j; -import org.bytedeco.javacpp.Pointer; -import org.junit.After; -import org.junit.Before; -import org.junit.Rule; -import org.junit.rules.TestName; -import org.nd4j.linalg.api.buffer.DataBuffer; -import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.api.memory.MemoryWorkspace; -import org.nd4j.linalg.api.ops.executioner.OpExecutioner; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.profiler.ProfilerConfig; - -import java.lang.management.ManagementFactory; -import java.util.List; -import java.util.Map; -import java.util.Properties; - -@Slf4j -public class BaseDL4JTest { - - @Rule - public TestName name = new TestName(); - - protected long startTime; - protected int threadCountBefore; - - /** - * Override this to set the profiling mode for the tests defined in the child class - */ - public OpExecutioner.ProfilingMode getProfilingMode(){ - return OpExecutioner.ProfilingMode.SCOPE_PANIC; - } - - /** - * Override this to set the datatype of the tests defined in the child class - */ - public DataType getDataType(){ - return DataType.DOUBLE; - } - - public DataType getDefaultFPDataType(){ - return getDataType(); - } - - @Before - public void beforeTest(){ - log.info("{}.{}", getClass().getSimpleName(), name.getMethodName()); - Nd4j.getExecutioner().setProfilingMode(getProfilingMode()); - Nd4j.getExecutioner().setProfilingConfig(ProfilerConfig.builder().build()); - Nd4j.setDefaultDataTypes(getDataType(), getDefaultFPDataType()); - startTime = System.currentTimeMillis(); - threadCountBefore = ManagementFactory.getThreadMXBean().getThreadCount(); - } - - @After - public void afterTest(){ - //Attempt to keep workspaces isolated between tests - Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread(); - MemoryWorkspace currWS = Nd4j.getMemoryManager().getCurrentWorkspace(); - Nd4j.getMemoryManager().setCurrentWorkspace(null); - if(currWS != null){ - //Not really safe to continue testing under this situation... other tests will likely fail with obscure - // errors that are hard to track back to this - log.error("Open workspace leaked from test! Exiting - {}, isOpen = {} - {}", currWS.getId(), currWS.isScopeActive(), currWS); - System.exit(1); - } - - StringBuilder sb = new StringBuilder(); - long maxPhys = Pointer.maxPhysicalBytes(); - long maxBytes = Pointer.maxBytes(); - long currPhys = Pointer.physicalBytes(); - long currBytes = Pointer.totalBytes(); - - long jvmTotal = Runtime.getRuntime().totalMemory(); - long jvmMax = Runtime.getRuntime().maxMemory(); - - int threadsAfter = ManagementFactory.getThreadMXBean().getThreadCount(); - - long duration = System.currentTimeMillis() - startTime; - sb.append(getClass().getSimpleName()).append(".").append(name.getMethodName()) - .append(": ").append(duration).append(" ms") - .append(", threadCount: (").append(threadCountBefore).append("->").append(threadsAfter).append(")") - .append(", jvmTotal=").append(jvmTotal) - .append(", jvmMax=").append(jvmMax) - .append(", totalBytes=").append(currBytes).append(", maxBytes=").append(maxBytes) - .append(", currPhys=").append(currPhys).append(", maxPhys=").append(maxPhys); - - List ws = Nd4j.getWorkspaceManager().getAllWorkspacesForCurrentThread(); - if(ws != null && ws.size() > 0){ - long currSize = 0; - for(MemoryWorkspace w : ws){ - currSize += w.getCurrentSize(); - } - if(currSize > 0){ - sb.append(", threadWSSize=").append(currSize) - .append(" (").append(ws.size()).append(" WSs)"); - } - } - - - Properties p = Nd4j.getExecutioner().getEnvironmentInformation(); - Object o = p.get("cuda.devicesInformation"); - if(o instanceof List){ - List> l = (List>) o; - if(l.size() > 0) { - - sb.append(" [").append(l.size()) - .append(" GPUs: "); - - for (int i = 0; i < l.size(); i++) { - Map m = l.get(i); - if(i > 0) - sb.append(","); - sb.append("(").append(m.get("cuda.freeMemory")).append(" free, ") - .append(m.get("cuda.totalMemory")).append(" total)"); - } - sb.append("]"); - } - } - log.info(sb.toString()); - } -} diff --git a/deeplearning4j/deeplearning4j-zoo/src/test/java/org/deeplearning4j/zoo/MiscTests.java b/deeplearning4j/deeplearning4j-zoo/src/test/java/org/deeplearning4j/zoo/MiscTests.java index 9349f05a5..c92f5acdc 100644 --- a/deeplearning4j/deeplearning4j-zoo/src/test/java/org/deeplearning4j/zoo/MiscTests.java +++ b/deeplearning4j/deeplearning4j-zoo/src/test/java/org/deeplearning4j/zoo/MiscTests.java @@ -16,6 +16,7 @@ package org.deeplearning4j.zoo; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.transferlearning.TransferLearning; @@ -31,6 +32,11 @@ import java.io.File; public class MiscTests extends BaseDL4JTest { + @Override + public long getTimeoutMilliseconds() { + return 120000L; + } + @Test public void testTransferVGG() throws Exception { //https://github.com/deeplearning4j/deeplearning4j/issues/5167 diff --git a/deeplearning4j/deeplearning4j-zoo/src/test/java/org/deeplearning4j/zoo/TestDownload.java b/deeplearning4j/deeplearning4j-zoo/src/test/java/org/deeplearning4j/zoo/TestDownload.java index 1b9853b58..b45afe47a 100644 --- a/deeplearning4j/deeplearning4j-zoo/src/test/java/org/deeplearning4j/zoo/TestDownload.java +++ b/deeplearning4j/deeplearning4j-zoo/src/test/java/org/deeplearning4j/zoo/TestDownload.java @@ -17,6 +17,7 @@ package org.deeplearning4j.zoo; import lombok.extern.slf4j.Slf4j; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.common.resources.DL4JResources; import org.deeplearning4j.nn.conf.WorkspaceMode; import org.deeplearning4j.zoo.model.LeNet; @@ -47,6 +48,11 @@ import static org.junit.Assert.assertEquals; @Slf4j public class TestDownload extends BaseDL4JTest { + @Override + public long getTimeoutMilliseconds() { + return isIntegrationTests() ? 480000L : 60000L; + } + @ClassRule public static TemporaryFolder testDir = new TemporaryFolder(); private static File f; @@ -66,12 +72,20 @@ public class TestDownload extends BaseDL4JTest { public void testDownloadAllModels() throws Exception { // iterate through each available model - ZooModel[] models = new ZooModel[]{ - LeNet.builder().build(), - SimpleCNN.builder().build(), - UNet.builder().build(), - NASNet.builder().build() - }; + ZooModel[] models; + + if(isIntegrationTests()){ + models = new ZooModel[]{ + LeNet.builder().build(), + SimpleCNN.builder().build(), + UNet.builder().build(), + NASNet.builder().build()}; + } else { + models = new ZooModel[]{ + LeNet.builder().build(), + SimpleCNN.builder().build()}; + } + for (int i = 0; i < models.length; i++) { diff --git a/deeplearning4j/deeplearning4j-zoo/src/test/java/org/deeplearning4j/zoo/TestImageNet.java b/deeplearning4j/deeplearning4j-zoo/src/test/java/org/deeplearning4j/zoo/TestImageNet.java index 877c7699a..9106bede3 100644 --- a/deeplearning4j/deeplearning4j-zoo/src/test/java/org/deeplearning4j/zoo/TestImageNet.java +++ b/deeplearning4j/deeplearning4j-zoo/src/test/java/org/deeplearning4j/zoo/TestImageNet.java @@ -19,6 +19,7 @@ package org.deeplearning4j.zoo; import lombok.extern.slf4j.Slf4j; import org.datavec.image.loader.NativeImageLoader; import org.datavec.image.transform.ColorConversionTransform; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.layers.objdetect.DetectedObject; import org.deeplearning4j.nn.layers.objdetect.YoloUtils; @@ -56,6 +57,11 @@ import static org.junit.Assert.assertTrue; @Slf4j public class TestImageNet extends BaseDL4JTest { + @Override + public long getTimeoutMilliseconds() { + return 90000L; + } + @Override public DataType getDataType(){ return DataType.FLOAT; diff --git a/deeplearning4j/deeplearning4j-zoo/src/test/java/org/deeplearning4j/zoo/TestInstantiation.java b/deeplearning4j/deeplearning4j-zoo/src/test/java/org/deeplearning4j/zoo/TestInstantiation.java index 4c8e66191..d70137775 100644 --- a/deeplearning4j/deeplearning4j-zoo/src/test/java/org/deeplearning4j/zoo/TestInstantiation.java +++ b/deeplearning4j/deeplearning4j-zoo/src/test/java/org/deeplearning4j/zoo/TestInstantiation.java @@ -17,6 +17,7 @@ package org.deeplearning4j.zoo; import lombok.extern.slf4j.Slf4j; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.datasets.iterator.AsyncDataSetIterator; import org.deeplearning4j.datasets.iterator.impl.BenchmarkDataSetIterator; import org.deeplearning4j.nn.api.Model; diff --git a/deeplearning4j/dl4j-integration-tests/pom.xml b/deeplearning4j/dl4j-integration-tests/pom.xml index 27461c923..43e6bfa60 100644 --- a/deeplearning4j/dl4j-integration-tests/pom.xml +++ b/deeplearning4j/dl4j-integration-tests/pom.xml @@ -68,6 +68,13 @@ test + + + org.deeplearning4j + deeplearning4j-common-tests + ${project.version} + test + diff --git a/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/BaseDL4JTest.java b/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/BaseDL4JTest.java deleted file mode 100644 index f6294b9cf..000000000 --- a/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/BaseDL4JTest.java +++ /dev/null @@ -1,141 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.integration; - -import lombok.extern.slf4j.Slf4j; -import org.bytedeco.javacpp.Pointer; -import org.junit.After; -import org.junit.Before; -import org.junit.Rule; -import org.junit.rules.TestName; -import org.nd4j.linalg.api.buffer.DataBuffer; -import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.api.memory.MemoryWorkspace; -import org.nd4j.linalg.api.ops.executioner.OpExecutioner; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.profiler.ProfilerConfig; - -import java.lang.management.ManagementFactory; -import java.util.List; -import java.util.Map; -import java.util.Properties; - -@Slf4j -public class BaseDL4JTest { - - @Rule - public TestName name = new TestName(); - - protected long startTime; - protected int threadCountBefore; - - /** - * Override this to set the profiling mode for the tests defined in the child class - */ - public OpExecutioner.ProfilingMode getProfilingMode(){ - return OpExecutioner.ProfilingMode.SCOPE_PANIC; - } - - /** - * Override this to set the datatype of the tests defined in the child class - */ - public DataType getDataType(){ - return DataType.DOUBLE; - } - - public DataType getDefaultFPDataType(){ - return getDataType(); - } - - @Before - public void beforeTest(){ - log.info("{}.{}", getClass().getSimpleName(), name.getMethodName()); - Nd4j.getExecutioner().setProfilingMode(getProfilingMode()); - Nd4j.getExecutioner().setProfilingConfig(ProfilerConfig.builder().build()); - Nd4j.setDefaultDataTypes(getDataType(), getDefaultFPDataType()); - startTime = System.currentTimeMillis(); - threadCountBefore = ManagementFactory.getThreadMXBean().getThreadCount(); - } - - @After - public void afterTest(){ - //Attempt to keep workspaces isolated between tests - Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread(); - MemoryWorkspace currWS = Nd4j.getMemoryManager().getCurrentWorkspace(); - Nd4j.getMemoryManager().setCurrentWorkspace(null); - if(currWS != null){ - //Not really safe to continue testing under this situation... other tests will likely fail with obscure - // errors that are hard to track back to this - log.error("Open workspace leaked from test! Exiting - {}, isOpen = {} - {}", currWS.getId(), currWS.isScopeActive(), currWS); - System.exit(1); - } - - StringBuilder sb = new StringBuilder(); - long maxPhys = Pointer.maxPhysicalBytes(); - long maxBytes = Pointer.maxBytes(); - long currPhys = Pointer.physicalBytes(); - long currBytes = Pointer.totalBytes(); - - long jvmTotal = Runtime.getRuntime().totalMemory(); - long jvmMax = Runtime.getRuntime().maxMemory(); - - int threadsAfter = ManagementFactory.getThreadMXBean().getThreadCount(); - - long duration = System.currentTimeMillis() - startTime; - sb.append(getClass().getSimpleName()).append(".").append(name.getMethodName()) - .append(": ").append(duration).append(" ms") - .append(", threadCount: (").append(threadCountBefore).append("->").append(threadsAfter).append(")") - .append(", jvmTotal=").append(jvmTotal) - .append(", jvmMax=").append(jvmMax) - .append(", totalBytes=").append(currBytes).append(", maxBytes=").append(maxBytes) - .append(", currPhys=").append(currPhys).append(", maxPhys=").append(maxPhys); - - List ws = Nd4j.getWorkspaceManager().getAllWorkspacesForCurrentThread(); - if(ws != null && ws.size() > 0){ - long currSize = 0; - for(MemoryWorkspace w : ws){ - currSize += w.getCurrentSize(); - } - if(currSize > 0){ - sb.append(", threadWSSize=").append(currSize) - .append(" (").append(ws.size()).append(" WSs)"); - } - } - - - Properties p = Nd4j.getExecutioner().getEnvironmentInformation(); - Object o = p.get("cuda.devicesInformation"); - if(o instanceof List){ - List> l = (List>) o; - if(l.size() > 0) { - - sb.append(" [").append(l.size()) - .append(" GPUs: "); - - for (int i = 0; i < l.size(); i++) { - Map m = l.get(i); - if(i > 0) - sb.append(","); - sb.append("(").append(m.get("cuda.freeMemory")).append(" free, ") - .append(m.get("cuda.totalMemory")).append(" total)"); - } - sb.append("]"); - } - } - log.info(sb.toString()); - } -} diff --git a/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/IntegrationTests.java b/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/IntegrationTests.java index 4e1cb95f5..8e2ceef79 100644 --- a/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/IntegrationTests.java +++ b/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/IntegrationTests.java @@ -16,6 +16,7 @@ package org.deeplearning4j.integration; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.integration.testcases.*; import org.junit.AfterClass; import org.junit.Ignore; diff --git a/deeplearning4j/dl4j-perf/pom.xml b/deeplearning4j/dl4j-perf/pom.xml deleted file mode 100644 index 239eead6b..000000000 --- a/deeplearning4j/dl4j-perf/pom.xml +++ /dev/null @@ -1,128 +0,0 @@ - - - - - - - deeplearning4j-parent - org.deeplearning4j - 1.0.0-SNAPSHOT - - 4.0.0 - - dl4j-perf - - dl4j-perf - - - UTF-8 - 1.7 - 1.7 - - - - - org.slf4j - slf4j-api - - - com.github.oshi - oshi-json - ${oshi.version} - - - org.deeplearning4j - deeplearning4j-nn - ${project.version} - - - com.github.oshi - oshi-core - ${oshi.version} - - - - junit - junit - - - - org.projectlombok - lombok - ${lombok.version} - provided - - - - ch.qos.logback - logback-classic - test - - - org.deeplearning4j - deeplearning4j-datasets - ${project.version} - test - - - - - - - - maven-clean-plugin - 3.0.0 - - - - maven-resources-plugin - 3.0.2 - - - maven-compiler-plugin - 3.7.0 - - - maven-surefire-plugin - 2.20.1 - - - maven-jar-plugin - 3.0.2 - - - maven-install-plugin - 2.5.2 - - - maven-deploy-plugin - 2.8.2 - - - - - - - - test-nd4j-native - - - test-nd4j-cuda-10.2 - - - diff --git a/deeplearning4j/dl4j-perf/src/test/java/org/deeplearning4j/perf/listener/BaseDL4JTest.java b/deeplearning4j/dl4j-perf/src/test/java/org/deeplearning4j/perf/listener/BaseDL4JTest.java deleted file mode 100644 index 9ead56a7e..000000000 --- a/deeplearning4j/dl4j-perf/src/test/java/org/deeplearning4j/perf/listener/BaseDL4JTest.java +++ /dev/null @@ -1,140 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.perf.listener; - -import lombok.extern.slf4j.Slf4j; -import org.bytedeco.javacpp.Pointer; -import org.junit.After; -import org.junit.Before; -import org.junit.Rule; -import org.junit.rules.TestName; -import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.api.memory.MemoryWorkspace; -import org.nd4j.linalg.api.ops.executioner.OpExecutioner; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.profiler.ProfilerConfig; - -import java.lang.management.ManagementFactory; -import java.util.List; -import java.util.Map; -import java.util.Properties; - -@Slf4j -public class BaseDL4JTest { - - @Rule - public TestName name = new TestName(); - - protected long startTime; - protected int threadCountBefore; - - /** - * Override this to set the profiling mode for the tests defined in the child class - */ - public OpExecutioner.ProfilingMode getProfilingMode(){ - return OpExecutioner.ProfilingMode.SCOPE_PANIC; - } - - /** - * Override this to set the datatype of the tests defined in the child class - */ - public DataType getDataType(){ - return DataType.DOUBLE; - } - - public DataType getDefaultFPDataType(){ - return getDataType(); - } - - @Before - public void beforeTest(){ - log.info("{}.{}", getClass().getSimpleName(), name.getMethodName()); - Nd4j.getExecutioner().setProfilingMode(getProfilingMode()); - Nd4j.getExecutioner().setProfilingConfig(ProfilerConfig.builder().build()); - Nd4j.setDefaultDataTypes(getDataType(), getDefaultFPDataType()); - startTime = System.currentTimeMillis(); - threadCountBefore = ManagementFactory.getThreadMXBean().getThreadCount(); - } - - @After - public void afterTest(){ - //Attempt to keep workspaces isolated between tests - Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread(); - MemoryWorkspace currWS = Nd4j.getMemoryManager().getCurrentWorkspace(); - Nd4j.getMemoryManager().setCurrentWorkspace(null); - if(currWS != null){ - //Not really safe to continue testing under this situation... other tests will likely fail with obscure - // errors that are hard to track back to this - log.error("Open workspace leaked from test! Exiting - {}, isOpen = {} - {}", currWS.getId(), currWS.isScopeActive(), currWS); - System.exit(1); - } - - StringBuilder sb = new StringBuilder(); - long maxPhys = Pointer.maxPhysicalBytes(); - long maxBytes = Pointer.maxBytes(); - long currPhys = Pointer.physicalBytes(); - long currBytes = Pointer.totalBytes(); - - long jvmTotal = Runtime.getRuntime().totalMemory(); - long jvmMax = Runtime.getRuntime().maxMemory(); - - int threadsAfter = ManagementFactory.getThreadMXBean().getThreadCount(); - - long duration = System.currentTimeMillis() - startTime; - sb.append(getClass().getSimpleName()).append(".").append(name.getMethodName()) - .append(": ").append(duration).append(" ms") - .append(", threadCount: (").append(threadCountBefore).append("->").append(threadsAfter).append(")") - .append(", jvmTotal=").append(jvmTotal) - .append(", jvmMax=").append(jvmMax) - .append(", totalBytes=").append(currBytes).append(", maxBytes=").append(maxBytes) - .append(", currPhys=").append(currPhys).append(", maxPhys=").append(maxPhys); - - List ws = Nd4j.getWorkspaceManager().getAllWorkspacesForCurrentThread(); - if(ws != null && ws.size() > 0){ - long currSize = 0; - for(MemoryWorkspace w : ws){ - currSize += w.getCurrentSize(); - } - if(currSize > 0){ - sb.append(", threadWSSize=").append(currSize) - .append(" (").append(ws.size()).append(" WSs)"); - } - } - - - Properties p = Nd4j.getExecutioner().getEnvironmentInformation(); - Object o = p.get("cuda.devicesInformation"); - if(o instanceof List){ - List> l = (List>) o; - if(l.size() > 0) { - - sb.append(" [").append(l.size()) - .append(" GPUs: "); - - for (int i = 0; i < l.size(); i++) { - Map m = l.get(i); - if(i > 0) - sb.append(","); - sb.append("(").append(m.get("cuda.freeMemory")).append(" free, ") - .append(m.get("cuda.totalMemory")).append(" total)"); - } - sb.append("]"); - } - } - log.info(sb.toString()); - } -} diff --git a/deeplearning4j/pom.xml b/deeplearning4j/pom.xml index f9b1eecce..bcf633e3a 100644 --- a/deeplearning4j/pom.xml +++ b/deeplearning4j/pom.xml @@ -140,11 +140,10 @@ deeplearning4j-nearestneighbors-parent deeplearning4j-data deeplearning4j-manifold - deeplearning4j-util - dl4j-perf dl4j-integration-tests deeplearning4j-common deeplearning4j-remote + deeplearning4j-common-tests diff --git a/libnd4j/CMakeLists.txt b/libnd4j/CMakeLists.txt index d8b0439b4..c82b0b217 100755 --- a/libnd4j/CMakeLists.txt +++ b/libnd4j/CMakeLists.txt @@ -1,4 +1,4 @@ -cmake_minimum_required(VERSION 3.6) +cmake_minimum_required(VERSION 3.15) project(libnd4j) set(CMAKE_VERBOSE_MAKEFILE OFF) option(NATIVE "Optimize for build machine (might not work on others)" OFF) @@ -7,6 +7,21 @@ set(CMAKE_MODULE_PATH "${CMAKE_SOURCE_DIR}/cmake" ${CMAKE_MODULE_PATH}) set(CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS OFF) option(BUILD_TESTS "Build tests" OFF) +option(FLATBUFFERS_BUILD_FLATC "Enable the build of the flatbuffers compiler" OFF) +set(FLATBUFFERS_BUILD_FLATC "OFF" CACHE STRING "Hack to disable flatc build" FORCE) + +set(CMAKE_CXX_STANDARD 11) +if (CUDA_BLAS) + enable_language(CUDA) + set(CMAKE_CUDA_STANDARD 11) + + set(DEFAULT_ENGINE "samediff::ENGINE_CUDA") +else() + set(DEFAULT_ENGINE "samediff::ENGINE_CPU") +endif() + +# MSVC runtime lib can be either "MultiThreaded" or "MultiThreadedDLL", /MT and /MD respectively +set(MSVC_RT_LIB "MultiThreadedDLL") set(X86_BUILD false) @@ -17,23 +32,23 @@ endif() # -fsanitize=address # -fsanitize=leak if (ANDROID_BUILD) - set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -O3 -fPIC -std=c++11 -Wno-braced-scalar-init -Wno-delete-non-virtual-dtor -Wno-unused-command-line-argument -Wno-dangling-else -D_RELEASE=true") - set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -O0 -g -fPIC -std=c++11 -Wno-braced-scalar-init -Wno-delete-non-virtual-dtor -Wno-unused-command-line-argument -Wno-dangling-else") + set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -O3 -fPIC -Wno-braced-scalar-init -Wno-delete-non-virtual-dtor -Wno-unused-command-line-argument -Wno-dangling-else -D_RELEASE=true") + set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -O0 -g -fPIC -Wno-braced-scalar-init -Wno-delete-non-virtual-dtor -Wno-unused-command-line-argument -Wno-dangling-else") elseif (APPLE) - set(CMAKE_CXX_FLAGS_RELEASE "-O3 -fPIC -std=c++11 -Wno-braced-scalar-init -Wno-delete-non-virtual-dtor -Wno-unused-command-line-argument -Wno-dangling-else -D__APPLE_OS__=true -D_RELEASE=true") - set(CMAKE_CXX_FLAGS_DEBUG " -O0 -g -fPIC -std=c++11 -Wno-braced-scalar-init -Wno-delete-non-virtual-dtor -Wno-unused-command-line-argument -Wno-dangling-else -D__APPLE_OS__=true") + set(CMAKE_CXX_FLAGS_RELEASE "-O3 -fPIC -Wno-braced-scalar-init -Wno-delete-non-virtual-dtor -Wno-unused-command-line-argument -Wno-dangling-else -D__APPLE_OS__=true -D_RELEASE=true") + set(CMAKE_CXX_FLAGS_DEBUG " -O0 -g -fPIC -Wno-braced-scalar-init -Wno-delete-non-virtual-dtor -Wno-unused-command-line-argument -Wno-dangling-else -D__APPLE_OS__=true") elseif(WIN32) set(X86_BUILD true) if (CUDA_BLAS) set(CMAKE_CXX_FLAGS_RELEASE "-D_RELEASE=true") set(CMAKE_CXX_FLAGS_DEBUG " /FS /EHsc") else() - set(CMAKE_CXX_FLAGS_RELEASE "-O3 -fPIC -std=c++11 -fmax-errors=2 -D_RELEASE=true") - set(CMAKE_CXX_FLAGS_DEBUG " -g -O2 -fPIC -std=c++11 -fmax-errors=2") + set(CMAKE_CXX_FLAGS_RELEASE "-O3 -fPIC -fmax-errors=2 -D_RELEASE=true") + set(CMAKE_CXX_FLAGS_DEBUG " -g -O2 -fPIC -fmax-errors=2") endif() else() - set(CMAKE_CXX_FLAGS_RELEASE "-O3 -fPIC -std=c++11 -fmax-errors=2 -D_RELEASE=true") - set(CMAKE_CXX_FLAGS_DEBUG " -g -O0 -fPIC -std=c++11 -fmax-errors=2") + set(CMAKE_CXX_FLAGS_RELEASE "-O3 -fPIC -fmax-errors=2 -D_RELEASE=true") + set(CMAKE_CXX_FLAGS_DEBUG " -g -O0 -fPIC -fmax-errors=2") if (CPU_BLAS) set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -fsanitize=address") @@ -49,6 +64,7 @@ if(NATIVE) ENDIF() endif() + if(NOT CUDA_BLAS) # we need this definition to avoid global memory use within mkldnn add_definitions(-DDNNL_ENABLE_CONCURRENT_EXEC=true) @@ -117,36 +133,70 @@ if(NOT CUDA_BLAS) include_directories(${CPUF_SOURCE_DIR}/include) set(CPU_FEATURES cpu_features) endif() +endif() - # new mkl-dnn entry - if (${HELPERS_mkldnn}) - message("Going to pull & build mkldnn") - set(HAVE_MKLDNN 1) - set(DNNL_LIBRARY_TYPE "STATIC" CACHE STRING "Hack to enforce static mode" FORCE) - configure_file(./CMakeLists.txt.mkldnn.in mkldnn-download/CMakeLists.txt) - execute_process(COMMAND ${CMAKE_COMMAND} -G "${CMAKE_GENERATOR}" . - RESULT_VARIABLE result - WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/mkldnn-download ) - if(result) - message(FATAL_ERROR "CMake step for mkldnn failed: ${result}") - endif() - execute_process(COMMAND ${CMAKE_COMMAND} --build . - RESULT_VARIABLE result - WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/mkldnn-download ) - if(result) - message(FATAL_ERROR "Build step for mkldnn failed: ${result}") - endif() +# new mkl-dnn entry +if (${HELPERS_mkldnn}) + message("Going to pull & build mkldnn") + set(HAVE_MKLDNN 1) + set(DNNL_LIBRARY_TYPE "STATIC" CACHE STRING "Hack to enforce static mode" FORCE) - add_subdirectory(${CMAKE_CURRENT_BINARY_DIR}/mkldnn-src - ${CMAKE_CURRENT_BINARY_DIR}/mkldnn-build - EXCLUDE_FROM_ALL) + configure_file(./CMakeLists.txt.mkldnn.in mkldnn-download/CMakeLists.txt) + execute_process(COMMAND ${CMAKE_COMMAND} -G "${CMAKE_GENERATOR}" . + RESULT_VARIABLE result + WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/mkldnn-download ) + if(result) + message(FATAL_ERROR "CMake step for mkldnn failed: ${result}") + endif() + execute_process(COMMAND ${CMAKE_COMMAND} --build . + RESULT_VARIABLE result + WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/mkldnn-download ) + if(result) + message(FATAL_ERROR "Build step for mkldnn failed: ${result}") + endif() - set(mkldnn_SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/mkldnn-build) - set(mkldnn_EXT_DIR ${CMAKE_CURRENT_BINARY_DIR}/mkldnn-src) - set(MKLDNN_PATH "${mkldnn_SOURCE_DIR}") - include_directories(${mkldnn_SOURCE_DIR}/include ${mkldnn_EXT_DIR}/include ${mkldnn_SOURCE_DIR}) - set(MKLDNN dnnl) + add_subdirectory(${CMAKE_CURRENT_BINARY_DIR}/mkldnn-src + ${CMAKE_CURRENT_BINARY_DIR}/mkldnn-build + EXCLUDE_FROM_ALL) + + set(mkldnn_SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/mkldnn-build) + set(mkldnn_EXT_DIR ${CMAKE_CURRENT_BINARY_DIR}/mkldnn-src) + set(MKLDNN_PATH "${mkldnn_SOURCE_DIR}") + include_directories(${mkldnn_SOURCE_DIR}/include ${mkldnn_EXT_DIR}/include ${mkldnn_SOURCE_DIR}) + set(MKLDNN dnnl) +endif() + + +if (${HELPERS_cudnn}) + if (NOT CUDA_BLAS) + message(FATAL_ERROR "Can't build cuDNN on non-CUDA platform") + endif() + + set(CUDNN_ROOT_DIR "" CACHE PATH "Folder contains NVIDIA cuDNN") + + # FIXME: we don't want static library in master + SET(CUDNN_LIBNAME "cudnn") + SET(CULIBOS_LIBNAME "culibos") + + find_path(CUDNN_INCLUDE_DIR cudnn.h + HINTS ${CUDNN_ROOT_DIR} ${CUDA_TOOLKIT_ROOT_DIR} + PATH_SUFFIXES cuda/include include) + + find_library(CUDNN_LIBRARY ${CUDNN_LIBNAME} + HINTS ${CUDNN_ROOT_DIR} ${CUDA_TOOLKIT_ROOT_DIR} + PATH_SUFFIXES lib lib64 cuda/lib cuda/lib64 lib/x64) + + find_library(CULIBOS_LIBRARY ${CULIBOS_LIBNAME} + HINTS ${CUDNN_ROOT_DIR} ${CUDA_TOOLKIT_ROOT_DIR} + PATH_SUFFIXES lib lib64 cuda/lib cuda/lib64 lib/x64) + + + if (CUDNN_LIBRARY) + set(HAVE_CUDNN true) + set(CUDNN ${CUDNN_LIBRARY} ${CULIBOS_LIBRARY}) + else() + message(FATAL_ERROR "Unable to find cuDNN") endif() endif() @@ -174,6 +224,8 @@ set(HAVE_FLATBUFFERS 1) set(FLATBUFFERS_PATH ${CMAKE_CURRENT_BINARY_DIR}/flatbuffers-src) include_directories(${FLATBUFFERS_PATH}/include) + + configure_file(include/config.h.in include/config.h) include_directories(${CMAKE_CURRENT_BINARY_DIR}/include) diff --git a/libnd4j/CMakeLists.txt.in b/libnd4j/CMakeLists.txt.in index 33946a014..8e8741c86 100644 --- a/libnd4j/CMakeLists.txt.in +++ b/libnd4j/CMakeLists.txt.in @@ -9,6 +9,7 @@ ExternalProject_Add(flatbuffers SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/flatbuffers-src" BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/flatbuffers-build" CONFIGURE_COMMAND "" + CMAKE_ARGS "-DFLATBUFFERS_BUILD_FLATC=OFF" BUILD_COMMAND "" INSTALL_COMMAND "" TEST_COMMAND "" diff --git a/libnd4j/CMakeLists.txt.mkldnn.in b/libnd4j/CMakeLists.txt.mkldnn.in index 2b773abea..3069d9efe 100644 --- a/libnd4j/CMakeLists.txt.mkldnn.in +++ b/libnd4j/CMakeLists.txt.mkldnn.in @@ -5,7 +5,7 @@ project(mkldnn-download NONE) include(ExternalProject) ExternalProject_Add(mkldnn GIT_REPOSITORY https://github.com/intel/mkl-dnn.git - GIT_TAG v1.1.1 + GIT_TAG v1.1.2 SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/mkldnn-src" BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/mkldnn-build" CONFIGURE_COMMAND "" diff --git a/libnd4j/CMakeSettings.json b/libnd4j/CMakeSettings.json index 2bb5bddbc..afda69260 100644 --- a/libnd4j/CMakeSettings.json +++ b/libnd4j/CMakeSettings.json @@ -12,6 +12,21 @@ "cmakeCommandArgs": " -DCUDA_BLAS=true -DLIBND4J_NAME=nd4jcuda -DMSVC_DEV=true -DCOMPUTE=61 -DBUILD_TESTS=true", "buildCommandArgs": "-v", "ctestCommandArgs": "" + }, + { + "name": "WSL-GCC-Debug", + "generator": "Unix Makefiles", + "configurationType": "Debug", + "buildRoot": "${projectDir}\\out\\build\\${name}", + "installRoot": "${projectDir}\\out\\install\\${name}", + "cmakeExecutable": "/usr/bin/cmake", + "cmakeCommandArgs": "-DLIBND4J_ALL_OPS=true -DCMAKE_BUILD_TYPE=Debug -DCPU_BLAS=true -DLIBND4J_NAME=nd4jcpu -DBUILD_TESTS=ON -DCMAKE_BUILD_TYPE=Debug -DOPENBLAS_PATH=/usr/lib/openblas-base/ -DEXTENSION=avx2 ", + "buildCommandArgs": "-j 4", + "ctestCommandArgs": "", + "inheritEnvironments": [ "linux_x64" ], + "wslPath": "${defaultWSLPath}", + "addressSanitizerRuntimeFlags": "detect_leaks=0", + "variables": [] } ] } \ No newline at end of file diff --git a/libnd4j/blas/CMakeLists.txt b/libnd4j/blas/CMakeLists.txt index c86bdc13a..c1c5de399 100755 --- a/libnd4j/blas/CMakeLists.txt +++ b/libnd4j/blas/CMakeLists.txt @@ -101,16 +101,17 @@ ELSE() endif() ENDIF() -if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang") +if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "AppleClang" AND X86_BUILD) + # apple clang but not ios-arm + SET( CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${ARCH_TUNE}") +elseif ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang") # using Clang SET( CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${ARCH_TUNE}") - elseif ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Intel") # using Intel C++ SET( CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${ARCH_TUNE} -O3 -fp-model fast") elseif ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "MSVC") # using Visual Studio C++ - set( CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${ARCH_TUNE}") elseif ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") # using GCC @@ -130,122 +131,72 @@ if(!CUDA_BLAS) endif() endif() +#if MKLDNN is enabled - we're building mkldnn-powered helpers +if (HAVE_MKLDNN) + file(GLOB_RECURSE CUSTOMOPS_MKLDNN_SOURCES false ../include/ops/declarable/platform/mkldnn/*.cpp ../include/ops/declarable/platform/mkldnn/mkldnnUtils.h) +endif() + if(CUDA_BLAS) message("Build cublas") find_package(CUDA) add_definitions(-D__CUDABLAS__=true) if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "MSVC") set (CMAKE_CXX_FLAGS "") - elseif ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") - if ("${CMAKE_C_COMPILER_VERSION}" VERSION_GREATER 4.9 AND "$ENV{TRICK_NVCC}" STREQUAL "YES" AND CUDA_VERSION VERSION_LESS "8.0") - set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__GNUC__=4 -D__GNUC_MINOR__=9 -D_FORCE_INLINES -D_MWAITXINTRIN_H_INCLUDED") - set (CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -Xcompiler -std=c++11 -Dnullptr=NULL") - message("TRICKING CUDA INTO SUPPORTING GCC > 4.9 YOU ARE PROCEEDING AT YOUR OWN RISK") - endif() - endif() - - # we want OpenMP to be available for hybrid operations, at least for GCC - if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") - find_package(OpenMP) - if (OPENMP_FOUND) - set (CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}") - set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}") - endif() endif() if (CUDA_FOUND) message("CUDA include directory: ${CUDA_INCLUDE_DIRS}") include_directories(${CUDA_INCLUDE_DIRS}) message("CUDA found!") - set( CUDA_ARCHITECTURE_MINIMUM "3.0" CACHE STRING "Minimum required CUDA compute capability" ) - SET(CUDA_VERBOSE_BUILD OFF) - SET(CUDA_SEPARABLE_COMPILATION OFF) - #set(CUDA_COMPUTE_CAPABILITY "61") - set(CUDA_COMPUTE_CAPABILITY "35") - # make NVCC more verbose to prevent timeouts on CI servers - #list(APPEND CUDA_NVCC_FLAGS -v) + if ("${EXPERIMENTAL}" STREQUAL "yes") message("Experimental mode ENABLED") - list(APPEND CUDA_NVCC_FLAGS -D__ND4J_EXPERIMENTAL__=true) - set (CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -D__ND4J_EXPERIMENTAL__=true") - set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__ND4J_EXPERIMENTAL__=true") - set (EXPM " -D__ND4J_EXPERIMENTAL__=true") + set(CMAKE_CUDA_FLAGS " ${CMAKE_CUDA_FLAGS} -D__ND4J_EXPERIMENTAL__=true") + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -D__ND4J_EXPERIMENTAL__=true") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__ND4J_EXPERIMENTAL__=true") + set(EXPM " -D__ND4J_EXPERIMENTAL__=true") endif() - if (CMAKE_BUILD_TYPE STREQUAL "Release") - if(CUDA_VERSION VERSION_GREATER "9.2") # cuda 10 - if ("${COMPUTE}" STREQUAL "all") - if (APPLE) - list(APPEND CUDA_NVCC_FLAGS -DCUDA_10 ${EXPM} -w --cudart=static -O3 --expt-extended-lambda -Xfatbin -compress-all -gencode arch=compute_30,code=sm_30 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_60,code=sm_60) - else() - list(APPEND CUDA_NVCC_FLAGS -DCUDA_10 ${EXPM} -w --cudart=static -O3 --expt-extended-lambda -Xfatbin -compress-all -gencode arch=compute_30,code=sm_30 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_70,code=sm_70) - endif() + + # the only difference for debug mode here is host/device debug symbols + set(CMAKE_CUDA_FLAGS_DEBUG " -G -g") + + # we need -fPIC on Linux/GCC + if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") + message("Enabling fPIC...") + set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler=-fPIC") + endif() + + if(CUDA_VERSION VERSION_GREATER "9.2") # cuda 10 + if ("${COMPUTE}" STREQUAL "all") + if (APPLE) + set(CMAKE_CUDA_FLAGS " ${CMAKE_CUDA_FLAGS} -DCUDA_10 ${EXPM} -w --cudart=static --expt-extended-lambda -Xfatbin -compress-all -gencode arch=compute_30,code=sm_30 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_60,code=sm_60") else() - list(APPEND CUDA_NVCC_FLAGS -DCUDA_10 ${EXPM} -w --cudart=static --expt-extended-lambda -O3 -Xfatbin -compress-all -arch=compute_${COMPUTE} -code=sm_${COMPUTE}) - endif() - elseif(CUDA_VERSION VERSION_GREATER "8.0") # cuda 9 - if ("${COMPUTE}" STREQUAL "all") - if (APPLE) - list(APPEND CUDA_NVCC_FLAGS -DCUDA_9 ${EXPM} -w --cudart=static -O3 --expt-extended-lambda -Xfatbin -compress-all -gencode arch=compute_30,code=sm_30 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_60,code=sm_60) - else() - list(APPEND CUDA_NVCC_FLAGS -DCUDA_9 ${EXPM} -w --cudart=static -O3 --expt-extended-lambda -Xfatbin -compress-all -gencode arch=compute_30,code=sm_30 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_70,code=sm_70) - endif() - else() - list(APPEND CUDA_NVCC_FLAGS -DCUDA_9 ${EXPM} -w --cudart=static --expt-extended-lambda -O3 -Xfatbin -compress-all -arch=compute_${COMPUTE} -code=sm_${COMPUTE}) - endif() - elseif (CUDA_VERSION VERSION_GREATER "7.5") # cuda 8.0 - if ("${COMPUTE}" STREQUAL "all") - list(APPEND CUDA_NVCC_FLAGS -DCUDA_8 ${EXPM} -w --cudart=static -O3 --expt-extended-lambda -Xfatbin -compress-all -gencode arch=compute_30,code=sm_30 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_60,code=sm_60) - else() - list(APPEND CUDA_NVCC_FLAGS -DCUDA_8 ${EXPM} -w --cudart=static --expt-extended-lambda -O3 -Xfatbin -compress-all -arch=compute_${COMPUTE} -code=sm_${COMPUTE}) + set(CMAKE_CUDA_FLAGS " ${CMAKE_CUDA_FLAGS} -DCUDA_10 ${EXPM} -w --cudart=static --expt-extended-lambda -Xfatbin -compress-all -gencode arch=compute_30,code=sm_30 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_70,code=sm_70") endif() else() - if ("${COMPUTE}" STREQUAL "all") - list(APPEND CUDA_NVCC_FLAGS -DCUDA_75 ${EXPM} --cudart=static --expt-extended-lambda -O3 -Xfatbin -compress-all -gencode arch=compute_30,code=sm_30 -gencode arch=compute_52,code=sm_52 ) - else() - list(APPEND CUDA_NVCC_FLAGS -DCUDA_75 ${EXPM} --cudart=static --expt-extended-lambda -O3 -Xfatbin -compress-all -arch=compute_${COMPUTE} -code=sm_${COMPUTE}) - endif() + set(CMAKE_CUDA_FLAGS " ${CMAKE_CUDA_FLAGS} -DCUDA_10 ${EXPM} -w --cudart=static --expt-extended-lambda -arch=compute_${COMPUTE} -code=sm_${COMPUTE}") + endif() + elseif(CUDA_VERSION VERSION_GREATER "8.0") # cuda 9 + if ("${COMPUTE}" STREQUAL "all") + if (APPLE) + set(CMAKE_CUDA_FLAGS " ${CMAKE_CUDA_FLAGS} -DCUDA_9 ${EXPM} -w --cudart=static --expt-extended-lambda -Xfatbin -compress-all -gencode arch=compute_30,code=sm_30 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_60,code=sm_60") + else() + set(CMAKE_CUDA_FLAGS " ${CMAKE_CUDA_FLAGS} -DCUDA_9 ${EXPM} -w --cudart=static --expt-extended-lambda -Xfatbin -compress-all -gencode arch=compute_30,code=sm_30 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_70,code=sm_70") + endif() + else() + set(CMAKE_CUDA_FLAGS " ${CMAKE_CUDA_FLAGS} -DCUDA_9 ${EXPM} -w --cudart=static --expt-extended-lambda -Xfatbin -compress-all -arch=compute_${COMPUTE} -code=sm_${COMPUTE}") + endif() + elseif (CUDA_VERSION VERSION_GREATER "7.5") # cuda 8.0 + if ("${COMPUTE}" STREQUAL "all") + set(CMAKE_CUDA_FLAGS " ${CMAKE_CUDA_FLAGS} -DCUDA_8 ${EXPM} -w --cudart=static --expt-extended-lambda -Xfatbin -compress-all -gencode arch=compute_30,code=sm_30 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_60,code=sm_60") + else() + set(CMAKE_CUDA_FLAGS " ${CMAKE_CUDA_FLAGS} -DCUDA_8 ${EXPM} -w --cudart=static --expt-extended-lambda --Xfatbin -compress-all -arch=compute_${COMPUTE} -code=sm_${COMPUTE}") endif() - else() - # debug only - if (LINUX) - SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,--export-dynamic -rdynamic") - SET(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} --export-dynamic") - endif() - - if(CUDA_VERSION VERSION_GREATER "9.2") # cuda 9 - message("CUDA 10 Debug build") - if ("${COMPUTE}" STREQUAL "all") - if (APPLE) - list(APPEND CUDA_NVCC_FLAGS -DCUDA_10 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -Xfatbin -compress-all -gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_37,code=sm_37 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_53,code=sm_53 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_62,code=sm_62) - elseif() - list(APPEND CUDA_NVCC_FLAGS -DCUDA_10 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -Xfatbin -compress-all -gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_37,code=sm_37 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_53,code=sm_53 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_62,code=sm_62 -gencode arch=compute_70,code=sm_70) - endif() - else() - list(APPEND CUDA_NVCC_FLAGS -DCUDA_10 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -Xfatbin -compress-all -arch=compute_${COMPUTE} -code=sm_${COMPUTE}) - endif() - elseif(CUDA_VERSION VERSION_GREATER "8.0") # cuda 9 - if ("${COMPUTE}" STREQUAL "all") - if (APPLE) - list(APPEND CUDA_NVCC_FLAGS -DCUDA_9 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -Xfatbin -compress-all -gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_37,code=sm_37 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_53,code=sm_53 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_62,code=sm_62) - elseif() - list(APPEND CUDA_NVCC_FLAGS -DCUDA_9 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -Xfatbin -compress-all -gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_37,code=sm_37 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_53,code=sm_53 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_62,code=sm_62 -gencode arch=compute_70,code=sm_70) - endif() - else() - list(APPEND CUDA_NVCC_FLAGS -DCUDA_9 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -Xfatbin -compress-all -arch=compute_${COMPUTE} -code=sm_${COMPUTE}) - endif() - elseif (CUDA_VERSION VERSION_GREATER "7.5") # cuda 8 - if ("${COMPUTE}" STREQUAL "all") - list(APPEND CUDA_NVCC_FLAGS -DCUDA_8 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -Xfatbin -compress-all -gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_37,code=sm_37 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_53,code=sm_53 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_62,code=sm_62) - else() - list(APPEND CUDA_NVCC_FLAGS -DCUDA_8 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -Xfatbin -compress-all -arch=compute_${COMPUTE} -code=sm_${COMPUTE}) - endif() + if ("${COMPUTE}" STREQUAL "all") + set(CMAKE_CUDA_FLAGS " ${CMAKE_CUDA_FLAGS} -DCUDA_75 ${EXPM} -w --cudart=static --expt-extended-lambda -Xfatbin -compress-all -gencode arch=compute_30,code=sm_30 -gencode arch=compute_52,code=sm_52") else() - if ("${COMPUTE}" STREQUAL "all") - list(APPEND CUDA_NVCC_FLAGS -DCUDA_75 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -Xfatbin -compress-all -gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_37,code=sm_37 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_53,code=sm_53) - else() - list(APPEND CUDA_NVCC_FLAGS -DCUDA_75 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -Xfatbin -compress-all -arch=compute_${COMPUTE} -code=sm_${COMPUTE}) - endif() + set(CMAKE_CUDA_FLAGS " ${CMAKE_CUDA_FLAGS} -DCUDA_75 ${EXPM} -w --cudart=static --expt-extended-lambda -Xfatbin -compress-all -arch=compute_${COMPUTE} -code=sm_${COMPUTE}") endif() endif() @@ -264,30 +215,44 @@ if(CUDA_BLAS) file(GLOB_RECURSE LOOPS_SOURCES false ../include/loops/impl/*.cpp ../include/loops/*.h) file(GLOB_RECURSE LOOPS_SOURCES_CUDA false ../include/loops/*.cu) + if (HAVE_CUDNN) + message("cuDNN included") + file(GLOB_RECURSE CUSTOMOPS_CUDNN_SOURCES false ../include/ops/declarable/platform/cudnn/*.cu) + endif() - CUDA_ADD_LIBRARY(${LIBND4J_NAME} SHARED cuda/NativeOps.cu cuda/NativeOpExecutioner.cu cuda/BlasVersionHelper.cu Environment.cpp ${LOOPS_SOURCES_CUDA} + add_library(nd4jobj OBJECT cuda/NativeOps.cu cuda/NativeOpExecutioner.cu cuda/BlasVersionHelper.cu Environment.cpp ${LOOPS_SOURCES_CUDA} ${CUSTOMOPS_HELPERS_SOURCES} ${HELPERS_SOURCES} ${EXEC_SOURCES} ../include/cnpy/cnpy.cpp ../include/nd4jmemset.h ../include/nd4jmalloc.h cpu/GraphExecutioner.cpp cuda/NDArray.cu cpu/NDArrayFactory.cpp Environment.h ${LOOPS_SOURCES} ${ARRAY_SOURCES} ${TYPES_SOURCES} - ${MEMORY_SOURCES} ${GRAPH_SOURCES} ${CUSTOMOPS_SOURCES} ${INDEXING_SOURCES} ${EXCEPTIONS_SOURCES} ${OPS_SOURCES} ${PERF_SOURCES}) + ${MEMORY_SOURCES} ${GRAPH_SOURCES} ${CUSTOMOPS_SOURCES} ${INDEXING_SOURCES} ${EXCEPTIONS_SOURCES} ${OPS_SOURCES} ${PERF_SOURCES} ${CUSTOMOPS_CUDNN_SOURCES} ${CUSTOMOPS_MKLDNN_SOURCES}) + + add_library(${LIBND4J_NAME} SHARED $) + + if (WIN32) + message("MSVC runtime for library: ${MSVC_RT_LIB}") + endif() + + # static library is built only if we're going to build tests, skip otherwise + if (BUILD_TESTS) + add_library(${LIBND4J_NAME}static STATIC $) + set_property(TARGET ${LIBND4J_NAME}static PROPERTY MSVC_RUNTIME_LIBRARY "${MSVC_RT_LIB}$<$:Debug>") + install(TARGETS ${LIBND4J_NAME}static DESTINATION .) + endif() + + # on windows we want to make sure we use MT or MD, but since we use it in one lib, we must use it everywhere to avoid conflicts + set_property(TARGET nd4jobj PROPERTY MSVC_RUNTIME_LIBRARY "${MSVC_RT_LIB}$<$:Debug>") + set_property(TARGET ${LIBND4J_NAME} PROPERTY MSVC_RUNTIME_LIBRARY "${MSVC_RT_LIB}$<$:Debug>") if(WIN32) message("CUDA on Windows: enabling /EHsc") SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /EHsc /bigobj /std:c++14") - SET_TARGET_PROPERTIES(${LIBND4J_NAME} PROPERTIES COMPILER_FLAGS "/EHsc /bigobj /std:c++14") endif() - target_link_libraries(${LIBND4J_NAME} ${CUDA_LIBRARIES} ${CUDA_CUBLAS_LIBRARIES} ${CUDA_cusolver_LIBRARY}) + target_link_libraries(${LIBND4J_NAME} ${CUDA_LIBRARIES} ${CUDA_CUBLAS_LIBRARIES} ${CUDA_cusolver_LIBRARY} ${CUDNN} ${MKLDNN}) set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${PROJECT_BINARY_DIR}/cuda) install(TARGETS ${LIBND4J_NAME} DESTINATION .) - - add_custom_command( - TARGET ${LIBND4J_NAME} POST_BUILD - COMMAND ${CMAKE_COMMAND} -E copy - $ - ${PROJECT_BINARY_DIR}/../../tests_cpu/) endif(CUDA_FOUND) elseif(CPU_BLAS) @@ -311,12 +276,6 @@ elseif(CPU_BLAS) file(GLOB_RECURSE HELPERS_SOURCES false ../include/helpers/*.cpp ../include/helpers/*.h) file(GLOB_RECURSE LOOPS_SOURCES false ../include/loops/*.cpp ../include/loops/*.h) - - #if MKLDNN is enabled - we're building mkldnn-powered helpers - if (HAVE_MKLDNN) - file(GLOB_RECURSE CUSTOMOPS_PLATFORM_SOURCES false ../include/ops/declarable/platform/mkldnn/*.cpp ../include/ops/declarable/platform/mkldnn/mkldnnUtils.h) - endif() - if (X86_BUILD) # we disable platform optimizations for certains files for linux/macos set_source_files_properties(cpu/NativeOps.cpp PROPERTIES COMPILE_FLAGS "-march=x86-64 -mtune=generic") @@ -329,21 +288,19 @@ elseif(CPU_BLAS) cpu/NativeOpExecutioner.cpp cpu/NDArray.cpp cpu/NDArrayFactory.cpp ../include/cnpy/cnpy.cpp ../include/nd4jmemset.h ../include/nd4jmalloc.h Environment.cpp Environment.h ${LOOPS_SOURCES} ${HELPERS_SOURCES} ${EXEC_SOURCES} ${ARRAY_SOURCES} ${TYPES_SOURCES} - ${MEMORY_SOURCES} ${GRAPH_SOURCES} ${CUSTOMOPS_SOURCES} ${EXCEPTIONS_SOURCES} ${INDEXING_SOURCES} ${CUSTOMOPS_PLATFORM_SOURCES} ${CUSTOMOPS_GENERIC_SOURCES} + ${MEMORY_SOURCES} ${GRAPH_SOURCES} ${CUSTOMOPS_SOURCES} ${EXCEPTIONS_SOURCES} ${INDEXING_SOURCES} ${CUSTOMOPS_MKLDNN_SOURCES} ${CUSTOMOPS_GENERIC_SOURCES} ${OPS_SOURCES} ${PERF_SOURCES}) if(IOS) add_library(${LIBND4J_NAME} STATIC $) else() - add_library(${LIBND4J_NAME}static STATIC $) + # static library is built only if we're going to build tests, skip otherwise + if (BUILD_TESTS) + add_library(${LIBND4J_NAME}static STATIC $) + endif() + add_library(${LIBND4J_NAME} SHARED $) endif() - #if(WIN32) - # message("CPU on Windows: enabling /EHsc") - # SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /EHsc /bigobj /std:c++14") - # SET_TARGET_PROPERTIES(${LIBND4J_NAME} PROPERTIES COMPILER_FLAGS "/EHsc /bigobj /std:c++14") - #endif() - # we're including {MKLDNN} here in case of building from sources. in future that'll replace {MKLDNN_LIBRARIES}. same applies to BLAS if (NOT BLAS_LIBRARIES) set(BLAS_LIBRARIES "") @@ -374,7 +331,6 @@ elseif(CPU_BLAS) SET(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -export-dynamic") endif() - #install(TARGETS mySharedLib DESTINATION /some/full/path) install(TARGETS ${LIBND4J_NAME} DESTINATION .) set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${PROJECT_BINARY_DIR}/cpu) endif() diff --git a/libnd4j/blas/Environment.cpp b/libnd4j/blas/Environment.cpp index de0ac925b..3b9502534 100644 --- a/libnd4j/blas/Environment.cpp +++ b/libnd4j/blas/Environment.cpp @@ -26,6 +26,7 @@ #include #include #include +#include #ifdef _OPENMP @@ -61,6 +62,7 @@ namespace nd4j { std::string omp(omp_threads); int val = std::stoi(omp); _maxThreads.store(val); + _maxMasterThreads.store(val); } catch (std::invalid_argument &e) { // just do nothing } catch (std::out_of_range &e) { @@ -100,6 +102,11 @@ namespace nd4j { } } + if (_maxMasterThreads.load() > _maxThreads.load()) { + nd4j_printf("Warning! MAX_MASTER_THREADS > MAX_THREADS, tuning them down to match each other\n",""); + _maxMasterThreads.store(_maxThreads.load()); + } + /** * If this env var is defined - we'll disallow use of platform-specific helpers (mkldnn, cudnn, etc) */ @@ -285,11 +292,19 @@ namespace nd4j { } void Environment::setMaxThreads(int max) { + // FIXME: not possible at this moment, since maxThreads is limited by number of threads in pool. however we can allocate more threads if we want //_maxThreads.store(max); } void Environment::setMaxMasterThreads(int max) { - //_maxMasterThreads = max; + if (max > maxThreads()) { + max = maxThreads(); + } + + if (max < 1) + return; + + _maxMasterThreads = max; } bool Environment::precisionBoostAllowed() { @@ -328,6 +343,38 @@ namespace nd4j { _allowHelpers.store(reallyAllow); } + void Environment::setGroupLimit(int group, Nd4jLong numBytes) { + nd4j::memory::MemoryCounter::getInstance()->setGroupLimit((nd4j::memory::MemoryType) group, numBytes); + } + + void Environment::setDeviceLimit(int deviceId, Nd4jLong numBytes) { + nd4j::memory::MemoryCounter::getInstance()->setDeviceLimit(deviceId, numBytes); + } + + Nd4jLong Environment::getGroupLimit(int group) { + return nd4j::memory::MemoryCounter::getInstance()->groupLimit((nd4j::memory::MemoryType) group); + } + + Nd4jLong Environment::getDeviceLimit(int deviceId) { + return nd4j::memory::MemoryCounter::getInstance()->deviceLimit(deviceId); + } + + Nd4jLong Environment::getGroupCounter(int group) { + return nd4j::memory::MemoryCounter::getInstance()->allocatedGroup((nd4j::memory::MemoryType) group); + } + + Nd4jLong Environment::getDeviceCounter(int deviceId) { + return nd4j::memory::MemoryCounter::getInstance()->allocatedDevice(deviceId); + } + + uint64_t Environment::maxPrimaryMemory() { + return _maxTotalPrimaryMemory.load(); + } + + uint64_t Environment::maxSpecialMemory() { + return _maxTotalSpecialMemory.load(); + } + nd4j::Environment *nd4j::Environment::_instance = 0; } diff --git a/libnd4j/blas/Environment.h b/libnd4j/blas/Environment.h index 54982471f..5bef3f1e4 100644 --- a/libnd4j/blas/Environment.h +++ b/libnd4j/blas/Environment.h @@ -27,6 +27,7 @@ #include #include #include +#include namespace nd4j{ class ND4J_EXPORT Environment { @@ -97,10 +98,30 @@ namespace nd4j{ int maxMasterThreads(); void setMaxMasterThreads(int max); + /* + * Legacy memory limits API, still used in new API as simplified version + */ void setMaxPrimaryMemory(uint64_t maxBytes); void setMaxSpecialyMemory(uint64_t maxBytes); void setMaxDeviceMemory(uint64_t maxBytes); + uint64_t maxPrimaryMemory(); + uint64_t maxSpecialMemory(); + //////////////////////// + + /* + * Methods for memory limits/counters + */ + void setGroupLimit(int group, Nd4jLong numBytes); + void setDeviceLimit(int deviceId, Nd4jLong numBytes); + + Nd4jLong getGroupLimit(int group); + Nd4jLong getDeviceLimit(int deviceId); + + Nd4jLong getGroupCounter(int group); + Nd4jLong getDeviceCounter(int deviceId); + //////////////////////// + bool isUseMKLDNN() { return _useMKLDNN.load(); } void setUseMKLDNN(bool useMKLDNN) { _useMKLDNN.store(useMKLDNN); } diff --git a/libnd4j/blas/NDArray.h b/libnd4j/blas/NDArray.h index d89ef8c72..671f72a57 100644 --- a/libnd4j/blas/NDArray.h +++ b/libnd4j/blas/NDArray.h @@ -42,30 +42,60 @@ #include #include #include +#include +#include +#include namespace nd4j { + template ::value>::type> + ND4J_EXPORT NDArray operator+(const NDArray& arr, const T& scalar); + template ::value>::type> + ND4J_EXPORT NDArray operator+(NDArray&& arr, const T& scalar); + template ::value>::type> + ND4J_EXPORT NDArray operator+(const T& scalar, const NDArray& arr); + template ::value>::type> + ND4J_EXPORT NDArray operator+(const T& scalar, NDArray&& arr); - ND4J_EXPORT NDArray operator-(const float&, const NDArray&); - ND4J_EXPORT NDArray operator-(const float16&, const NDArray&); - ND4J_EXPORT NDArray operator-(const double&, const NDArray&); - ND4J_EXPORT NDArray operator-(const int&, const NDArray&); + template ::value>::type> + ND4J_EXPORT NDArray operator-(const NDArray& arr, const T& scalar); + template ::value>::type> + ND4J_EXPORT NDArray operator-(NDArray&& arr, const T& scalar); + template ::value>::type> + ND4J_EXPORT NDArray operator-(const T& scalar, const NDArray& arr); + template ::value>::type> + ND4J_EXPORT NDArray operator-(const T& scalar, NDArray&& arr); + + template ::value>::type> + ND4J_EXPORT NDArray operator*(const NDArray& arr, const T& scalar); + template ::value>::type> + ND4J_EXPORT NDArray operator*(NDArray&& arr, const T& scalar); + template ::value>::type> + ND4J_EXPORT NDArray operator*(const T& scalar, const NDArray& arr); + template ::value>::type> + ND4J_EXPORT NDArray operator*(const T& scalar, NDArray&& arr); + + template ::value>::type> + ND4J_EXPORT NDArray operator/(const NDArray& arr, const T& scalar); + template ::value>::type> + ND4J_EXPORT NDArray operator/(NDArray&& arr, const T& scalar); + template ::value>::type> + ND4J_EXPORT NDArray operator/(const T& scalar, const NDArray& arr); + template ::value>::type> + ND4J_EXPORT NDArray operator/(const T& scalar, NDArray&& arr); + + template ::type>::value && std::is_same::type>::value>::type> + ND4J_EXPORT NDArray operator+(T1&& arr1, T2&& arr2); + template ::type>::value && std::is_same::type>::value>::type> + ND4J_EXPORT NDArray operator-(T1&& arr1, T2&& arr2); + template ::type>::value && std::is_same::type>::value>::type> + ND4J_EXPORT NDArray operator*(T1&& arr1, T2&& arr2); + template ::type>::value && std::is_same::type>::value>::type> + ND4J_EXPORT NDArray operator/(T1&& arr1, T2&& arr2); - ND4J_EXPORT NDArray operator+(const float&, const NDArray&); - ND4J_EXPORT NDArray operator+(const float16&, const NDArray&); - ND4J_EXPORT NDArray operator+(const double&, const NDArray&); - ND4J_EXPORT NDArray operator+(const int&, const NDArray&); - ND4J_EXPORT NDArray operator*(const float&, const NDArray&); - ND4J_EXPORT NDArray operator*(const float16&, const NDArray&); - ND4J_EXPORT NDArray operator*(const double&, const NDArray&); - ND4J_EXPORT NDArray operator*(const int&, const NDArray&); - ND4J_EXPORT NDArray operator/(const float&, const NDArray&); - ND4J_EXPORT NDArray operator/(const float16&, const NDArray&); - ND4J_EXPORT NDArray operator/(const double&, const NDArray&); - ND4J_EXPORT NDArray operator/(const int&, const NDArray&); ND4J_EXPORT NDArray mmul(const NDArray&, const NDArray&); @@ -274,14 +304,11 @@ namespace nd4j { * @param writeList * @param readList */ - // TODO: it would be nice to have NDArray::registerSpecialUse signature that accepts something else beyond initializer_list - static void registerSpecialUse(const std::initializer_list& writeList, const std::initializer_list& readList); - static void prepareSpecialUse(const std::initializer_list& writeList, const std::initializer_list& readList, bool synchronizeWritables = false); - - // TODO: it would be nice to have NDArray::registerSpecialUse signature that accepts something else beyond initializer_list - static void registerPrimaryUse(const std::initializer_list& writeList, const std::initializer_list& readList); - static void preparePrimaryUse(const std::initializer_list& writeList, const std::initializer_list& readList, bool synchronizeWritables = false); + static void registerSpecialUse(const std::vector& writeList, const std::vector& readList); + static void prepareSpecialUse(const std::vector& writeList, const std::vector& readList, bool synchronizeWritables = false); + static void registerPrimaryUse(const std::vector& writeList, const std::vector& readList); + static void preparePrimaryUse(const std::vector& writeList, const std::vector& readList, bool synchronizeWritables = false); /** * This method returns buffer pointer offset by given number of elements, wrt own data type @@ -323,7 +350,7 @@ namespace nd4j { * axis - axis along which to repeat elements * repeats - number of repetitions */ - NDArray* repeat(const int axis, const std::vector& repeats) const; + NDArray repeat(const int axis, const std::vector& repeats) const; /** * This method fills this array with zeros @@ -336,15 +363,7 @@ namespace nd4j { * @param array * @return */ - static NDArray quantize(NDArray &array); - - /** - * This method returns quantized copy of given array - * - * @param array - * @return - */ - static NDArray* quantize(NDArray *array); + static NDArray quantize(const NDArray &array); /** * fill target array by repeating current array @@ -356,19 +375,16 @@ namespace nd4j { /** * creates array which points on certain sub-range of this array, sub-range is defined by given indices */ - NDArray* subarray(IndicesList& indices) const; - NDArray* subarray(const std::initializer_list& idx) const; - NDArray* subarray(const Intervals& idx) const; + NDArray subarray(IndicesList& indices) const; + NDArray subarray(const std::initializer_list& idx) const; + NDArray subarray(const Intervals& idx) const; /** * cast array elements to given dtype */ - template - NDArray* cast(); + NDArray cast(DataType dtype) const; - NDArray* cast(DataType dtype) const; - - void cast(NDArray* target, DataType dtype); + void cast(NDArray& target, DataType dtype); /** * returns _context @@ -455,16 +471,22 @@ namespace nd4j { /** * permutes the dimensions in array according to "dimensions" array, new array points on _buffer of this array */ - NDArray permute(const std::initializer_list& dimensions) const; - NDArray permute(const std::vector& dimensions) const; - NDArray permute(const int* dimensions, const int rank) const; + NDArray permute(const std::initializer_list& dimensions) const &; + NDArray permute(const std::vector& dimensions) const &; + NDArray permute(const int* dimensions, const int rank) const &; + NDArray permute(const std::initializer_list& dimensions) &&; + NDArray permute(const std::vector& dimensions) &&; + NDArray permute(const int* dimensions, const int rank) &&; void permute(const int* dimensions, const int rank, NDArray& target) const; void permute(const std::vector& dimensions, NDArray& target) const; - NDArray permute(const std::initializer_list& dimensions) const; - NDArray permute(const std::vector& dimensions) const; - NDArray permute(const Nd4jLong* dimensions, const int rank) const; + NDArray permute(const std::initializer_list& dimensions) const &; + NDArray permute(const std::vector& dimensions) const &; + NDArray permute(const Nd4jLong* dimensions, const int rank) const &; + NDArray permute(const std::initializer_list& dimensions) &&; + NDArray permute(const std::vector& dimensions) &&; + NDArray permute(const Nd4jLong* dimensions, const int rank) &&; void permute(const Nd4jLong* dimensions, const int rank, NDArray& target) const; void permute(const std::vector& dimensions, NDArray& target) const; @@ -522,24 +544,13 @@ namespace nd4j { /** * this method assigns given value to all elements in array */ - void assign(const double value, bool allowParallelism = true); - void assign(const float value, bool allowParallelism = true); - void assign(const float16 value, bool allowParallelism = true); - void assign(const bfloat16& value, bool allowParallelism = true); - void assign(const Nd4jLong value, bool allowParallelism = true); - void assign(const int value, bool allowParallelism = true); - void assign(const int16_t value, bool allowParallelism = true); - void assign(const uint8_t value, bool allowParallelism = true); - void assign(const uint16_t value, bool allowParallelism = true); - void assign(const uint32_t value, bool allowParallelism = true); - void assign(const uint64_t value, bool allowParallelism = true); - void assign(const int8_t value, bool allowParallelism = true); - void assign(const bool value, bool allowParallelism = true); + template ::value>::type> + void assign(const T& value, bool allowParallelism = true); /** * returns new copy of this array, optionally in different order */ - NDArray *dup(const char newOrder = 'a') const; + NDArray dup(const char newOrder = 'a') const; /** * returns sum of all elements of array @@ -566,21 +577,17 @@ namespace nd4j { * keepDims - if true then put unities in place of reduced dimensions */ - NDArray* reduceAlongDimension(nd4j::reduce::FloatOps op, const std::vector& dimensions, const bool keepDims = false, const bool supportOldShapes = false) const; - NDArray* reduceAlongDimension(nd4j::reduce::FloatOps op, const std::initializer_list& dimensions, const bool keepDims = false, const bool supportOldShapes = false) const; - NDArray reduceAlongDims(nd4j::reduce::FloatOps op, const std::vector& dimensions, const bool keepDims = false, const bool supportOldShapes = false) const; + NDArray reduceAlongDimension(nd4j::reduce::FloatOps op, const std::vector& dimensions, const bool keepDims = false, const bool supportOldShapes = false) const; + NDArray reduceAlongDimension(nd4j::reduce::FloatOps op, const std::initializer_list& dimensions, const bool keepDims = false, const bool supportOldShapes = false) const; - NDArray* reduceAlongDimension(nd4j::reduce::SameOps op, const std::vector& dimensions, const bool keepDims = false, const bool supportOldShapes = false) const; - NDArray* reduceAlongDimension(nd4j::reduce::SameOps op, const std::initializer_list& dimensions, const bool keepDims = false, const bool supportOldShapes = false) const; - NDArray reduceAlongDims(nd4j::reduce::SameOps op, const std::vector& dimensions, const bool keepDims = false, const bool supportOldShapes = false) const; + NDArray reduceAlongDimension(nd4j::reduce::SameOps op, const std::vector& dimensions, const bool keepDims = false, const bool supportOldShapes = false) const; + NDArray reduceAlongDimension(nd4j::reduce::SameOps op, const std::initializer_list& dimensions, const bool keepDims = false, const bool supportOldShapes = false) const; - NDArray* reduceAlongDimension(nd4j::reduce::BoolOps op, const std::vector& dimensions, const bool keepDims = false, const bool supportOldShapes = false) const; - NDArray* reduceAlongDimension(nd4j::reduce::BoolOps op, const std::initializer_list& dimensions, const bool keepDims = false, const bool supportOldShapes = false) const; - NDArray reduceAlongDims(nd4j::reduce::BoolOps op, const std::vector& dimensions, const bool keepDims = false, const bool supportOldShapes = false) const; + NDArray reduceAlongDimension(nd4j::reduce::BoolOps op, const std::vector& dimensions, const bool keepDims = false, const bool supportOldShapes = false) const; + NDArray reduceAlongDimension(nd4j::reduce::BoolOps op, const std::initializer_list& dimensions, const bool keepDims = false, const bool supportOldShapes = false) const; - NDArray* reduceAlongDimension(nd4j::reduce::LongOps op, const std::vector& dimensions, const bool keepDims = false, const bool supportOldShapes = false) const; - NDArray* reduceAlongDimension(nd4j::reduce::LongOps op, const std::initializer_list& dimensions, const bool keepDims = false, const bool supportOldShapes = false) const; - NDArray reduceAlongDims(nd4j::reduce::LongOps op, const std::vector& dimensions, const bool keepDims = false, const bool supportOldShapes = false) const; + NDArray reduceAlongDimension(nd4j::reduce::LongOps op, const std::vector& dimensions, const bool keepDims = false, const bool supportOldShapes = false) const; + NDArray reduceAlongDimension(nd4j::reduce::LongOps op, const std::initializer_list& dimensions, const bool keepDims = false, const bool supportOldShapes = false) const; /** * method reduces array by excluding its shapes along dimensions present in given dimensions vector @@ -589,10 +596,10 @@ namespace nd4j { * keepDims - if true then put unities in place of reduced dimensions * extras - extra parameters */ - void reduceAlongDimension(nd4j::reduce::FloatOps op, NDArray* target, const std::vector& dimensions, const bool keepDims = false, const bool supportOldShapes = false, const bool checkTargetShape = true) const; - void reduceAlongDimension(nd4j::reduce::SameOps op, NDArray* target, const std::vector& dimensions, const bool keepDims = false, const bool supportOldShapes = false, const bool checkTargetShape = true) const; - void reduceAlongDimension(nd4j::reduce::BoolOps op, NDArray* target, const std::vector& dimensions, const bool keepDims = false, const bool supportOldShapes = false, const bool checkTargetShape = true) const; - void reduceAlongDimension(nd4j::reduce::LongOps op, NDArray* target, const std::vector& dimensions, const bool keepDims = false, const bool supportOldShapes = false, const bool checkTargetShape = true) const; + void reduceAlongDimension(nd4j::reduce::FloatOps op, NDArray& target, const std::vector& dimensions, const bool keepDims = false, const bool supportOldShapes = false, const bool checkTargetShape = true) const; + void reduceAlongDimension(nd4j::reduce::SameOps op, NDArray& target, const std::vector& dimensions, const bool keepDims = false, const bool supportOldShapes = false, const bool checkTargetShape = true) const; + void reduceAlongDimension(nd4j::reduce::BoolOps op, NDArray& target, const std::vector& dimensions, const bool keepDims = false, const bool supportOldShapes = false, const bool checkTargetShape = true) const; + void reduceAlongDimension(nd4j::reduce::LongOps op, NDArray& target, const std::vector& dimensions, const bool keepDims = false, const bool supportOldShapes = false, const bool checkTargetShape = true) const; /** * return variance of array elements set @@ -631,20 +638,24 @@ namespace nd4j { void makeBothActual() const { syncToDevice(); syncToHost(); } - void applyTransform(nd4j::transform::FloatOps op, NDArray *target = nullptr, ExtraArguments *extraParams = nullptr); - void applyTransform(nd4j::transform::SameOps op, NDArray *target = nullptr, ExtraArguments *extraParams = nullptr); - void applyTransform(nd4j::transform::AnyOps op, NDArray *target = nullptr, ExtraArguments *extraParams = nullptr); - void applyTransform(nd4j::transform::BoolOps op, NDArray *target = nullptr, ExtraArguments *extraParams = nullptr); - void applyTransform(nd4j::transform::StrictOps op, NDArray *target = nullptr, ExtraArguments *extraParams = nullptr); + void applyTransform(nd4j::transform::FloatOps op, NDArray& target, ExtraArguments *extraParams = nullptr); + void applyTransform(nd4j::transform::SameOps op, NDArray& target, ExtraArguments *extraParams = nullptr); + void applyTransform(nd4j::transform::AnyOps op, NDArray& target, ExtraArguments *extraParams = nullptr); + void applyTransform(nd4j::transform::BoolOps op, NDArray& target, ExtraArguments *extraParams = nullptr); + void applyTransform(nd4j::transform::StrictOps op, NDArray& target, ExtraArguments *extraParams = nullptr); /** * apply OpName transformation to this array and store result in new array to be returned * extraParams - extra parameters for operation */ - NDArray transform(nd4j::transform::FloatOps op, void *extraParams = nullptr) const; - NDArray transform(nd4j::transform::SameOps op, void *extraParams = nullptr) const; - NDArray transform(nd4j::transform::BoolOps op, void *extraParams = nullptr) const; - NDArray transform(nd4j::transform::StrictOps op, void *extraParams = nullptr) const; + NDArray transform(nd4j::transform::FloatOps op, void *extraParams = nullptr) const &; + NDArray transform(nd4j::transform::SameOps op, void *extraParams = nullptr) const &; + NDArray transform(nd4j::transform::BoolOps op, void *extraParams = nullptr) const &; + NDArray transform(nd4j::transform::StrictOps op, void *extraParams = nullptr) const &; + NDArray transform(nd4j::transform::FloatOps op, void *extraParams = nullptr) &&; + NDArray transform(nd4j::transform::SameOps op, void *extraParams = nullptr) &&; + NDArray transform(nd4j::transform::BoolOps op, void *extraParams = nullptr) &&; + NDArray transform(nd4j::transform::StrictOps op, void *extraParams = nullptr) &&; /** * apply pairwise OpName transformation based on "this" and "other" arras elements, store result in this array @@ -659,11 +670,11 @@ namespace nd4j { * target - where to store result * extraParams - extra parameters for operation */ - void applyPairwiseTransform(nd4j::pairwise::Ops op, const NDArray *other, NDArray *target, ExtraArguments *extraParams = nullptr) const; + void applyPairwiseTransform(nd4j::pairwise::Ops op, const NDArray& other, NDArray& target, ExtraArguments *extraParams = nullptr) const; - void applyPairwiseTransform(nd4j::pairwise::BoolOps op, const NDArray *other, NDArray *target, ExtraArguments *extraParams = nullptr) const; + void applyPairwiseTransform(nd4j::pairwise::BoolOps op, const NDArray& other, NDArray& target, ExtraArguments *extraParams = nullptr) const; - void applyPairwiseTransform(nd4j::pairwise::IntOps op, const NDArray *other, NDArray *target, ExtraArguments *extraParams = nullptr) const; + void applyPairwiseTransform(nd4j::pairwise::IntOps op, const NDArray& other, NDArray&target, ExtraArguments *extraParams = nullptr) const; /** * apply operation which requires broadcasting, broadcast a smaller array (tad) along bigger one (this) @@ -672,23 +683,23 @@ namespace nd4j { * target - where to store result * extraParams - extra parameters for operation */ - void applyBroadcast(nd4j::broadcast::Ops op, const std::initializer_list dimensions, const NDArray* tad, NDArray* target = nullptr, ExtraArguments* extraArgs = nullptr); + void applyBroadcast(nd4j::broadcast::Ops op, const std::initializer_list dimensions, const NDArray& tad, NDArray& target, ExtraArguments* extraArgs = nullptr); - void applyBroadcast(nd4j::broadcast::Ops op, const std::vector &dimensions, const NDArray *tad, NDArray *target = nullptr, ExtraArguments *extraArgs = nullptr); + void applyBroadcast(nd4j::broadcast::Ops op, const std::vector &dimensions, const NDArray &tad, NDArray &target, ExtraArguments *extraArgs = nullptr); - void applyBroadcast(nd4j::broadcast::BoolOps op, const std::vector &dimensions, const NDArray *tad, NDArray *target = nullptr, ExtraArguments *extraArgs = nullptr); - - void applyBroadcast(nd4j::broadcast::IntOps op, const std::vector &dimensions, const NDArray *tad, NDArray *target = nullptr, ExtraArguments *extraArgs = nullptr); + void applyBroadcast(nd4j::broadcast::BoolOps op, const std::vector &dimensions, const NDArray &tad, NDArray &target, ExtraArguments *extraArgs = nullptr); + void applyBroadcast(nd4j::broadcast::IntOps op, const std::vector &dimensions, const NDArray& tad, NDArray &target, ExtraArguments *extraArgs = nullptr); /** * apply operation which requires broadcasting, broadcast one tensor along another, also this method checks the possibility of broadcasting * other - input array * extraParams - extra parameters for operation */ - NDArray applyTrueBroadcast(nd4j::BroadcastOpsTuple op, const NDArray& other, ExtraArguments *extraArgs = nullptr) const; - - NDArray* applyTrueBroadcast(nd4j::BroadcastOpsTuple op, const NDArray* other, ExtraArguments *extraArgs = nullptr) const; + NDArray applyTrueBroadcast(nd4j::BroadcastOpsTuple op, const NDArray& other, ExtraArguments *extraArgs = nullptr) const &; + NDArray applyTrueBroadcast(nd4j::BroadcastOpsTuple op, NDArray&& other, ExtraArguments *extraArgs = nullptr) const &; + NDArray applyTrueBroadcast(nd4j::BroadcastOpsTuple op, NDArray&& other, ExtraArguments *extraArgs = nullptr) &&; + NDArray applyTrueBroadcast(nd4j::BroadcastOpsTuple op, const NDArray& other, ExtraArguments *extraArgs = nullptr) &&; /** * apply operation which requires broadcasting, broadcast one tensor along another, also this method checks the possibility of broadcasting @@ -697,11 +708,11 @@ namespace nd4j { * checkTargetShape - if true check whether target shape is suitable for broadcasting * extraParams - extra parameters for operation */ - void applyTrueBroadcast(nd4j::BroadcastOpsTuple op, const NDArray* other, NDArray* target, const bool checkTargetShape = true, ExtraArguments *extraArgs = nullptr) const; + void applyTrueBroadcast(nd4j::BroadcastOpsTuple op, const NDArray& other, NDArray& target, const bool checkTargetShape = true, ExtraArguments *extraArgs = nullptr) const; - void applyTrueBroadcast(nd4j::BroadcastBoolOpsTuple op, const NDArray* other, NDArray* target, const bool checkTargetShape = true, ExtraArguments *extraArgs = nullptr) const; + void applyTrueBroadcast(nd4j::BroadcastBoolOpsTuple op, const NDArray& other, NDArray& target, const bool checkTargetShape = true, ExtraArguments *extraArgs = nullptr) const; - void applyTrueBroadcast(nd4j::BroadcastIntOpsTuple op, const NDArray* other, NDArray* target, const bool checkTargetShape = true, ExtraArguments *extraArgs = nullptr) const; + void applyTrueBroadcast(nd4j::BroadcastIntOpsTuple op, const NDArray& other, NDArray& target, const bool checkTargetShape = true, ExtraArguments *extraArgs = nullptr) const; /** @@ -711,13 +722,13 @@ namespace nd4j { * extraParams - extra parameters for operation */ template - void applyScalar(nd4j::scalar::Ops op, const T scalar, NDArray* target = nullptr, ExtraArguments *extraParams = nullptr); + void applyScalar(nd4j::scalar::Ops op, const T scalar, NDArray& target, ExtraArguments *extraParams = nullptr); template - void applyScalar(nd4j::scalar::BoolOps op, const T scalar, NDArray* target, ExtraArguments *extraParams = nullptr) const; + void applyScalar(nd4j::scalar::BoolOps op, const T scalar, NDArray& target, ExtraArguments *extraParams = nullptr) const; template - void applyScalar(nd4j::scalar::IntOps op, const T scalar, NDArray* target, ExtraArguments *extraParams = nullptr) const; + void applyScalar(nd4j::scalar::IntOps op, const T scalar, NDArray& target, ExtraArguments *extraParams = nullptr) const; /** * apply a scalar operation to an array @@ -725,27 +736,27 @@ namespace nd4j { * target - where to store result * extraParams - extra parameters for operation */ - void applyScalarArr(nd4j::scalar::Ops op, const NDArray* scalar, NDArray* target = nullptr, ExtraArguments *extraParams = nullptr); + void applyScalarArr(nd4j::scalar::Ops op, const NDArray& scalar, NDArray& target, ExtraArguments *extraParams = nullptr); - void applyScalarArr(nd4j::scalar::BoolOps op, const NDArray* scalar, NDArray* target, ExtraArguments *extraParams = nullptr) const; + void applyScalarArr(nd4j::scalar::BoolOps op, const NDArray& scalar, NDArray& target, ExtraArguments *extraParams = nullptr) const; - void applyScalarArr(nd4j::scalar::IntOps op, const NDArray* scalar, NDArray* target, ExtraArguments *extraParams = nullptr) const; + void applyScalarArr(nd4j::scalar::IntOps op, const NDArray& scalar, NDArray& target, ExtraArguments *extraParams = nullptr) const; #if defined(__CUDABLAS__) //&& defined(BUILD_TESTS) template - FORCEINLINE void applyLambda(Lambda func, NDArray* target = nullptr); + FORCEINLINE void applyLambda(Lambda func, NDArray& target); template - FORCEINLINE void applyPairwiseLambda(const NDArray* other, Lambda func, NDArray* target = nullptr); + FORCEINLINE void applyPairwiseLambda(const NDArray& other, Lambda func, NDArray& target); template - FORCEINLINE void applyIndexedLambda(Lambda func, NDArray* target = nullptr); + FORCEINLINE void applyIndexedLambda(Lambda func, NDArray& target); template - FORCEINLINE void applyIndexedPairwiseLambda(NDArray* other, Lambda func, NDArray* target = nullptr); + FORCEINLINE void applyIndexedPairwiseLambda(NDArray& other, Lambda func, NDArray& target); template - FORCEINLINE void applyTriplewiseLambda(NDArray* second, NDArray *third, Lambda func, NDArray* target = nullptr); + FORCEINLINE void applyTriplewiseLambda(NDArray& second, NDArray& third, Lambda func, NDArray& target); #else /** @@ -754,7 +765,7 @@ namespace nd4j { * target - where to store result */ template - void applyLambda(const std::function& func, NDArray* target = nullptr); + void applyLambda(const std::function& func, NDArray& target); /** * apply pairwise operation "func" to an array @@ -763,16 +774,16 @@ namespace nd4j { * target - where to store result */ template - void applyPairwiseLambda(const NDArray* other, const std::function& func, NDArray* target = nullptr); + void applyPairwiseLambda(const NDArray& other, const std::function& func, NDArray& target); template - void applyIndexedLambda(const std::function& func, NDArray* target = nullptr); + void applyIndexedLambda(const std::function& func, NDArray& target); template - void applyIndexedPairwiseLambda(NDArray* other, const std::function& func, NDArray* target = nullptr); + void applyIndexedPairwiseLambda(NDArray& other, const std::function& func, NDArray& target); template - void applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function& func, NDArray* target = nullptr); + void applyTriplewiseLambda(NDArray& second, NDArray& third, const std::function& func, NDArray& target); #endif /** @@ -780,7 +791,7 @@ namespace nd4j { * dimensions - vector of dimensions to reduce along * extraArgs - extra parameters for operation */ - NDArray* applyIndexReduce(nd4j::indexreduce::Ops op, const std::vector& dimensions, const ExtraArguments *extraParams = nullptr) const; + NDArray applyIndexReduce(nd4j::indexreduce::Ops op, const std::vector& dimensions, const ExtraArguments *extraParams = nullptr) const; /** * reduces dimensions in array relying on index operation OpName @@ -788,14 +799,14 @@ namespace nd4j { * dimensions - vector of dimensions to reduce along * extraArgs - extra parameters for operation */ - void applyIndexReduce(nd4j::indexreduce::Ops op, NDArray* target, const std::vector& dimensions, const ExtraArguments *extraParams = nullptr) const; + void applyIndexReduce(nd4j::indexreduce::Ops op, NDArray& target, const std::vector& dimensions, const ExtraArguments *extraParams = nullptr) const; /** * apply reduce3 operation OpName to this and other array, return result in new output array * other - input array * extraArgs - extra parameters for operation */ - NDArray* applyReduce3(nd4j::reduce3::Ops op, const NDArray* other, const ExtraArguments* extraParams = nullptr) const; + NDArray applyReduce3(nd4j::reduce3::Ops op, const NDArray& other, const ExtraArguments* extraParams = nullptr) const; /** * apply reduce3 operation OpName to this and other array, return result in new output array @@ -803,7 +814,7 @@ namespace nd4j { * dimensions - vector of dimensions to reduce along (tads not axis) * extraArgs - extra parameters for operation */ - NDArray* applyAllReduce3(nd4j::reduce3::Ops op, const NDArray* other, const std::vector& dimensions, const ExtraArguments* extraParams = nullptr) const; + NDArray applyAllReduce3(nd4j::reduce3::Ops op, const NDArray& other, const std::vector& dimensions, const ExtraArguments* extraParams = nullptr) const; /** * apply reduce3 (exec) operation OpName to this and other array, return result in new output array @@ -811,30 +822,26 @@ namespace nd4j { * dimensions - vector of dimensions to reduce along (same as reduceAlongDimension) * extraArgs - extra parameters for operation */ - NDArray* applyReduce3(nd4j::reduce3::Ops op, const NDArray* other, const std::vector& dimensions, const ExtraArguments* extraParams = nullptr) const; - + NDArray applyReduce3(nd4j::reduce3::Ops op, const NDArray& other, const std::vector& dimensions, const ExtraArguments* extraParams = nullptr) const; /** * returns variance along given dimensions * biasCorrected - if true bias correction will be applied * dimensions - vector of dimensions to calculate variance along */ - NDArray* varianceAlongDimension(nd4j::variance::Ops op, const bool biasCorrected, const std::vector& dimensions) const; - NDArray* varianceAlongDimension(nd4j::variance::Ops op, const bool biasCorrected, const std::initializer_list& dimensions) const; + NDArray varianceAlongDimension(nd4j::variance::Ops op, const bool biasCorrected, const std::vector& dimensions) const; + NDArray varianceAlongDimension(nd4j::variance::Ops op, const bool biasCorrected, const std::initializer_list& dimensions) const; - NDArray varianceAlongDims(nd4j::variance::Ops op, const bool biasCorrected, const std::vector& dimensions) const; - NDArray varianceAlongDims(nd4j::variance::Ops op, const bool biasCorrected, const std::initializer_list& dimensions) const; - - void varianceAlongDimension(nd4j::variance::Ops op, NDArray* target, const bool biasCorrected, const std::vector& dimensions) const; - - void varianceAlongDimension(nd4j::variance::Ops op, NDArray* target, const bool biasCorrected, const std::initializer_list& dimensions) const; + void varianceAlongDimension(nd4j::variance::Ops op, NDArray& target, const bool biasCorrected, const std::vector& dimensions) const; + void varianceAlongDimension(nd4j::variance::Ops op, NDArray& target, const bool biasCorrected, const std::initializer_list& dimensions) const; #endif /** * apply transpose operation to the copy of this array, that is this array remains unaffected */ - NDArray transpose() const; + NDArray transpose() const &; + NDArray transpose() &&; /** * perform transpose operation and store result in target, this array remains unaffected @@ -852,8 +859,8 @@ namespace nd4j { * index - the number of array to be returned among set of possible arrays * dimensions - array of dimensions to point on */ - NDArray* tensorAlongDimension(Nd4jLong index, const std::initializer_list& dimensions) const; - NDArray* tensorAlongDimension(Nd4jLong index, const std::vector& dimensions) const; + NDArray tensorAlongDimension(Nd4jLong index, const std::initializer_list& dimensions) const; + NDArray tensorAlongDimension(Nd4jLong index, const std::vector& dimensions) const; /** * returns the number of arrays pointing on specified dimension(s) @@ -874,54 +881,54 @@ namespace nd4j { * add given row vector to all rows of this array * row - row vector to add */ - void addiRowVector(const NDArray *row); + void addiRowVector(const NDArray& row); /** * add given row vector to all rows of this array, store result in target * row - row vector to add * target - where to store result */ - void addRowVector(const NDArray *row, NDArray* target) const; + void addRowVector(const NDArray& row, NDArray& target) const; /** * subtract given row vector from all rows of this array, store result in target * row - row vector to subtract * target - where to store result */ - void subRowVector(const NDArray *row, NDArray* target) const; + void subRowVector(const NDArray& row, NDArray& target) const; /** * multiply all rows of this array on given row vector, store result in target * row - row vector to multiply on * target - where to store result */ - void mulRowVector(const NDArray *row, NDArray* target) const; + void mulRowVector(const NDArray &row, NDArray& target) const; /** * divide all rows of this array on given row vector, store result in target * row - row vector to divide on * target - where to store result */ - void divRowVector(const NDArray *row, NDArray* target) const; + void divRowVector(const NDArray &row, NDArray& target) const; /** * add given column vector to all columns of this array, store result in target * column - column vector to add * target - where to store result */ - void addColumnVector(const NDArray *column, NDArray* target) const; + void addColumnVector(const NDArray &column, NDArray& target) const; /** * add given column vector to all columns of this array, this array becomes affected (in-place operation) * column - column vector to add */ - void addiColumnVector(const NDArray *column); + void addiColumnVector(const NDArray &column); /** * multiply all columns of this array on given column vector, this array becomes affected (in-place operation) * column - column vector to multiply on */ - void muliColumnVector(const NDArray *column); + void muliColumnVector(const NDArray &column); /** * returns number of bytes used by _buffer & _shapeInfo @@ -934,6 +941,7 @@ namespace nd4j { template std::vector getBufferAsVector(); std::vector getShapeAsVector() const; + std::vector getShapeAsVectorInt() const; std::vector getShapeInfoAsVector(); std::vector getShapeInfoAsFlatVector(); std::vector getShapeAsFlatVector(); @@ -958,7 +966,8 @@ namespace nd4j { * * if permute have been applied before or there are weird strides, then new buffer is allocated for new array */ - NDArray reshape(const char order, const std::vector& shape) const; + NDArray reshape(const char order, const std::vector& shape) const &; + NDArray reshape(const char order, const std::vector& shape) &&; /** * calculate strides and set given order @@ -991,12 +1000,6 @@ namespace nd4j { */ void tile(NDArray& target) const; - /** - * returns an array which is result of broadcasting of this and other arrays - * other - input array - */ - NDArray* broadcast(const NDArray& other); - /** * check whether array is identity matrix */ @@ -1007,7 +1010,6 @@ namespace nd4j { */ bool isUnitary(); - /** * operator returns subarray with buffer pointing at this->_buffer with offset defined by given intervals * idx - intervals of indexes which define the subarrays to point on, idx has form {dim0Start,dim0End, dim1Start,dim1End, ....} and length (2 * this->rankOf()) @@ -1038,27 +1040,6 @@ namespace nd4j { */ void getSubArrShapeAndOffsets(const std::vector& dimsToExclude, Nd4jLong* &subArrShapeInfo, Nd4jLong* &subArrOffsets, bool keepUnitiesInShape = false) const; - /** - * addition operator: array + other - * other - input array to add - */ - NDArray operator+(const NDArray& other) const; - - /** - * addition operator: array + scalar - * scalar - input scalar to add - */ - template - NDArray operator+(const T& scalar) const; - - /** - * friend functions which implement addition operator: scalar + array - * scalar - input scalar to add - */ - //template - //friend NDArray nd4j::operator+(const T scalar, const NDArray& arr); - - /** * addition unary operator array += other * other - input array to add @@ -1077,42 +1058,11 @@ namespace nd4j { template void operator-=(const T other); - /** - * subtraction operator: array - other - * other - input array to subtract - */ - NDArray operator-(const NDArray& other) const; - - /** - * subtraction operator: array - scalar - * scalar - input scalar to subtract - */ - template - NDArray operator-(const T& scalar) const; - /** * negative operator, it changes sign of all array elements on opposite */ - NDArray operator-() const; - - /** - * friend functions which implement subtraction operator: scalar - array - * scalar - input scalar to subtract - */ - //friend NDArray nd4j::operator-(const float scalar, const NDArray& arr); - - /** - * pairwise multiplication operator: array * other - * other - input array to multiply on - */ - NDArray operator*(const NDArray& other) const; - - /** - * multiplication operator: array * scalar - * scalar - input scalar to multiply on - */ - template - NDArray operator*(const T& scalar) const; + NDArray operator-() const &; + NDArray operator-() &&; /** * pairwise multiplication unary operator array *= other @@ -1127,19 +1077,6 @@ namespace nd4j { template void operator*=(const T scalar); - /** - * pairwise division operator: array / other - * other - input array to divide on - */ - NDArray operator/(const NDArray& other) const; - - /** - * division operator: array / scalar - * scalar - input scalar to divide each array element on - */ - template - NDArray operator/(const T& scalar) const; - /** * pairwise division unary operator: array /= other * other - input array to divide on @@ -1180,7 +1117,7 @@ namespace nd4j { * return vector with buffer which points on corresponding diagonal elements of array * type - means of vector to be returned: column ('c') or row ('r') */ - NDArray* diagonal(const char type ) const; + NDArray diagonal(const char type ) const; /** * fill target matrix with given value in one or two directions from main diagonal: @@ -1194,7 +1131,7 @@ namespace nd4j { * target and this array should have same shapes, except when this_rank = 1 (in that case should be target_rank = 2) */ template - void fillAsTriangular(const float value, int lower, int upper, const char direction = 'b', NDArray* target = nullptr); + void fillAsTriangular(const float value, int lower, int upper, NDArray& target, const char direction = 'b'); /** * change an array by repeating it the number of times in order to acquire new shape equal to the input shape @@ -1203,15 +1140,15 @@ namespace nd4j { * target - optional argument, if target != nullptr the resulting array will be placed in target, in opposite case tile operation is done in place */ NDArray tileToShape(const Nd4jLong* shapeInfo); - void tileToShape(const std::vector& shape, NDArray* target = nullptr); + void tileToShape(const std::vector& shape, NDArray& target); #ifndef __JAVACPP_HACK__ - void tileToShape(const std::initializer_list& shape, NDArray* target = nullptr); + void tileToShape(const std::initializer_list& shape, NDArray& target); #endif template - NDArray* asT() const; + NDArray asT() const; - NDArray* asT(DataType dtype) const; + NDArray asT(DataType dtype) const; void linspace(const double start); @@ -1223,15 +1160,13 @@ namespace nd4j { */ double getTrace() const; - ResultSet* multipleTensorsAlongDimension(const std::vector& indices, const std::vector& dimensions) const; + ResultSet multipleTensorsAlongDimension(const std::vector& indices, const std::vector& dimensions) const; - ResultSet* allTensorsAlongDimension(const std::initializer_list& dimensions) const; + ResultSet allTensorsAlongDimension(const std::initializer_list& dimensions) const; - ResultSet* allTensorsAlongDimension(const std::vector& dimensions) const; + ResultSet allTensorsAlongDimension(const std::vector& dimensions) const; - //ResultSet allTensorsAlongDims(const std::vector& dimensions) const; - - ResultSet* allExamples()const ; + ResultSet allExamples()const ; /** * set _shapeInfo @@ -1356,7 +1291,7 @@ namespace nd4j { /** * returns true if these two NDArrays have same rank, dimensions, strides, ews and order */ - FORCEINLINE bool isSameShapeStrict(const NDArray *other) const; + FORCEINLINE bool isSameShapeStrict(const NDArray& other) const; /** * returns true if buffer && shapeInfo were defined (non nullptr) @@ -1439,11 +1374,6 @@ namespace nd4j { template void pIdx(const Nd4jLong* indices, const T value); - /** - * creates array which points on certain sub-range of this array, sub-range is defined by given indices - */ - NDArray* subarray(IndicesList& indices, std::vector& strides) const; - /** * returns true if array is 2D */ @@ -1512,64 +1442,9 @@ namespace nd4j { */ bool isS() const; - /** - * inline accessing operator for matrix, i - absolute index - */ - //FORCEINLINE NDArray operator()(const Nd4jLong i) const; - - /** - * inline modifying operator for matrix, i - absolute index - */ - //FORCEINLINE NDArray& operator()(const Nd4jLong i); - - /** - * inline accessing operator for 2D array, i - row, j - column - */ - //FORCEINLINE NDArray operator()(const Nd4jLong i, const Nd4jLong j) const; - - /** - * inline modifying operator for 2D array, i - row, j - column - */ - //FORCEINLINE NDArray& operator()(const Nd4jLong i, const Nd4jLong j); - - /** - * inline accessing operator for 3D array, i - height, j - width, k - depth - */ - //FORCEINLINE NDArray operator()(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k) const; - - /** - * inline modifying operator for 3D array, i - height, j - width, k - depth - */ - //FORCEINLINE NDArray& operator()(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k); - - /** - * inline modifying operator for 4D array, i - height, j - width, k - depth - */ - //FORCEINLINE NDArray& operator()(const Nd4jLong t, const Nd4jLong u, const Nd4jLong v, const Nd4jLong w); - - /** - * inline accessing operator for 4D array, i - height, j - width, k - depth - */ - //FORCEINLINE NDArray operator()(const Nd4jLong t, const Nd4jLong u, const Nd4jLong v, const Nd4jLong w) const; - - /** - * inline modifying operator for ND array - * idx - array with corresponding indexes, for example {2,10,0,5,...,8}, number of indexes should be equal to array rank - */ - //FORCEINLINE NDArray& operator()(const Nd4jLong* idx); - - /** - * inline accessing operator for ND array - * idx - array with corresponding indexes, for example {2,10,0,5,...,8}, number of indexes should be equal to array rank - */ - //FORCEINLINE NDArray operator()(const Nd4jLong* idx) const; - - - template std::vector asVectorT(); - FORCEINLINE bool isAttached(); NDArray* detach(); @@ -1585,394 +1460,201 @@ namespace nd4j { ////////////////////////////////////////////////////////////////////////// ///// IMLEMENTATION OF INLINE METHODS ///// ////////////////////////////////////////////////////////////////////////// - bool NDArray::isAttached() { - return this->_context->getWorkspace() != nullptr; - } +bool NDArray::isAttached() { + return this->_context->getWorkspace() != nullptr; +} - template - FORCEINLINE R NDArray::templatedGet(void *buffer, Nd4jLong index) const { - auto b = reinterpret_cast(buffer); - auto v = static_cast(b[index]); - return v; - } +template +FORCEINLINE R NDArray::templatedGet(void *buffer, Nd4jLong index) const { + auto b = reinterpret_cast(buffer); + auto v = static_cast(b[index]); + return v; +} - ////////////////////////////////////////////////////////////////////////// - void NDArray::setShapeInfo(Nd4jLong *shapeInfo) { - auto buffer = ConstantShapeHelper::getInstance()->bufferForShapeInfo(shapeInfo); - _shapeInfo = buffer.primaryAsT(); - _shapeInfoD = buffer.specialAsT(); +////////////////////////////////////////////////////////////////////////// +void NDArray::setShapeInfo(Nd4jLong *shapeInfo) { + auto buffer = ConstantShapeHelper::getInstance()->bufferForShapeInfo(shapeInfo); + _shapeInfo = buffer.primaryAsT(); + _shapeInfoD = buffer.specialAsT(); - if (shapeInfo != nullptr) { - _dataType = ArrayOptions::dataType(_shapeInfo); - if(ArrayOptions::arrayType(_shapeInfo) == ArrayType::EMPTY) - _length = 0; - else - _length = shape::length(_shapeInfo); - } - else { - _dataType = nd4j::DataType::INHERIT; + if (shapeInfo != nullptr) { + _dataType = ArrayOptions::dataType(_shapeInfo); + if(ArrayOptions::arrayType(_shapeInfo) == ArrayType::EMPTY) _length = 0; - } + else + _length = shape::length(_shapeInfo); } + else { + _dataType = nd4j::DataType::INHERIT; + _length = 0; + } +} - ////////////////////////////////////////////////////////////////////////// - void NDArray::setShapeInfo(Nd4jLong *shapeInfo, const nd4j::DataType dtype) { - auto buffer = ConstantShapeHelper::getInstance()->bufferForShapeInfo(shapeInfo); - _shapeInfo = buffer.primaryAsT(); - _shapeInfoD = buffer.specialAsT(); +////////////////////////////////////////////////////////////////////////// +void NDArray::setShapeInfo(Nd4jLong *shapeInfo, const nd4j::DataType dtype) { + auto buffer = ConstantShapeHelper::getInstance()->bufferForShapeInfo(shapeInfo); + _shapeInfo = buffer.primaryAsT(); + _shapeInfoD = buffer.specialAsT(); - if (shapeInfo != nullptr) { - _dataType = dtype; - if(ArrayOptions::arrayType(_shapeInfo) == ArrayType::EMPTY) - _length = 0; - else - _length = shape::length(_shapeInfo); - } - else { - _dataType = nd4j::DataType::INHERIT; + if (shapeInfo != nullptr) { + _dataType = dtype; + if(ArrayOptions::arrayType(_shapeInfo) == ArrayType::EMPTY) _length = 0; - } + else + _length = shape::length(_shapeInfo); } - - ////////////////////////////////////////////////////////////////////////// - char NDArray::ordering() const { - return shape::order(_shapeInfo); - } - - ////////////////////////////////////////////////////////////////////////// - bool NDArray::isView() const { - return _isView; - } - - ////////////////////////////////////////////////////////////////////////// - Nd4jLong* NDArray::shapeOf() const { - return shape::shapeOf(_shapeInfo); - } - - ////////////////////////////////////////////////////////////////////////// - Nd4jLong* NDArray::stridesOf() const { - return shape::stride(_shapeInfo); - } - - ////////////////////////////////////////////////////////////////////////// - int NDArray::rankOf() const { - return shape::rank(_shapeInfo); - } - - ////////////////////////////////////////////////////////////////////////// - Nd4jLong NDArray::lengthOf() const { - return _length; - } - - ////////////////////////////////////////////////////////////////////////// - Nd4jLong NDArray::rows() const { - if (this->rankOf() == 1) - return 1; - - if (this->rankOf() > 2) - throw std::runtime_error("Array with rank > 2 can't have rows"); - - return shapeOf()[0]; - } - - ////////////////////////////////////////////////////////////////////////// - Nd4jLong NDArray::columns() const { - if (this->rankOf() == 1) - return this->lengthOf(); - - if (this->rankOf() > 2) - throw std::runtime_error("Array with rank > 2 can't have columns"); - - return shapeOf()[1]; - } - - ////////////////////////////////////////////////////////////////////////// - - size_t NDArray::sizeOfT() const { - return DataTypeUtils::sizeOfElement(_dataType); - } - - ////////////////////////////////////////////////////////////////////////// - Nd4jLong NDArray::ews() const { - if (this->isEmpty() || this->rankOf() == 0) - return 1; - - return shape::elementWiseStride(_shapeInfo); - } - - ////////////////////////////////////////////////////////////////////////// - bool NDArray::nonNull() const { - if (isEmpty()) - return true; - - if(!Environment::getInstance()->isCPU()) - return getDataBuffer()->special() != nullptr && getSpecialShapeInfo() != nullptr; - - return getDataBuffer()->primary() != nullptr && getShapeInfo() != nullptr; - } - - ////////////////////////////////////////////////////////////////////////// - bool NDArray::isMatrix() const { - if (isEmpty()) - return false; - - return 0 != shape::isMatrix(this->_shapeInfo); - } - - ////////////////////////////////////////////////////////////////////////// - bool NDArray::isVector() const { - if (isEmpty()) - return false; - if (rankOf() == 1) - return true; - return !isScalar() && shape::isVector(this->_shapeInfo); - } - - ////////////////////////////////////////////////////////////////////////// - bool NDArray::isColumnVector() const { - if (isEmpty()) - return false; - - return !isScalar() && shape::isColumnVector(this->_shapeInfo); - } - - ////////////////////////////////////////////////////////////////////////// - bool NDArray::isRowVector() const { - if (isEmpty()) - return false; - - // 1D edge case - if (shape::rank(this->_shapeInfo) == 1) - return true; - - return !isScalar() && shape::isRowVector(this->_shapeInfo); - } - - ////////////////////////////////////////////////////////////////////////// - bool NDArray::isCommonVector(int& posOfNonUnityDim) const { - - return shape::isCommonVector(_shapeInfo, posOfNonUnityDim); - } - - ////////////////////////////////////////////////////////////////////////// - bool NDArray::isScalar() const { - return 0 != shape::isScalar(this->_shapeInfo); + else { + _dataType = nd4j::DataType::INHERIT; + _length = 0; } +} ////////////////////////////////////////////////////////////////////////// -// accessing operator for matrix, i - absolute index -/* -NDArray NDArray::operator()(const Nd4jLong i) const { - - if (i >= shape::length(_shapeInfo)) - throw std::invalid_argument("NDArray::operator(i): input index is out of array length !"); - - auto ews = shape::elementWiseStride(_shapeInfo); - char order = ordering(); - - if(ews == 1 && order == 'c') { - auto cast = reinterpret_cast(_buffer) + (i * this->sizeOfT()); - NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace())); - return result; - } else if(ews > 1 && order == 'c') { - auto cast = reinterpret_cast(_buffer) + (i * ews * this->sizeOfT()); - NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace())); - return result; - } else { - Nd4jLong idx[MAX_RANK]; - shape::ind2subC(rankOf(), shapeOf(), i, idx); - auto xOffset = shape::getOffset(getShapeInfo(), idx); - - auto cast = reinterpret_cast(_buffer) + (xOffset * this->sizeOfT()); - NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace())); - return result; - } +char NDArray::ordering() const { + return shape::order(_shapeInfo); } -*/ -////////////////////////////////////////////////////////////////////////// -// modifying operator for matrix, i - absolute index -/* -NDArray& NDArray::operator()(const Nd4jLong i) { - if (i >= shape::length(_shapeInfo)) - throw std::invalid_argument("NDArray::operator(i): input index is out of array length !"); - - auto ews = shape::elementWiseStride(_shapeInfo); - auto order = ordering(); - - if(ews == 1 && order == 'c') { - auto cast = reinterpret_cast(_buffer) + (i * this->sizeOfT()); - NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace())); - - // FIXME: bad - return result; - } else if(ews > 1 && order == 'c') { - auto cast = reinterpret_cast(_buffer) + (i * ews * this->sizeOfT()); - NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace())); - return result; - } else { - Nd4jLong idx[MAX_RANK]; - shape::ind2subC(rankOf(), shapeOf(), i, idx); - auto xOffset = shape::getOffset(getShapeInfo(), idx); - - auto cast = reinterpret_cast(_buffer) + (xOffset * this->sizeOfT()); - NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace())); - return result; - } -}*/ ////////////////////////////////////////////////////////////////////////// -// accessing operator for 2D matrix, i - row, j - column -/* -NDArray NDArray::operator()(const Nd4jLong i, const Nd4jLong j) const { - - if (rankOf() != 2 || i >= shapeOf()[0] || j >= shapeOf()[1]) - throw std::invalid_argument("NDArray::operator(i,j): one of input indexes is out of array length or rank!=2 !"); - - Nd4jLong coords[2] = {i, j}; - auto xOffset = shape::getOffset(getShapeInfo(), coords); - - // TODO: do we really want a view here? - auto cast = reinterpret_cast(_buffer) + (xOffset * this->sizeOfT()); - NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace())); - return result; +bool NDArray::isView() const { + return _isView; } -*/ -////////////////////////////////////////////////////////////////////////// -// modifying operator for 2D matrix, i - row, j - column -/* -NDArray& NDArray::operator()(const Nd4jLong i, const Nd4jLong j) { - if (rankOf() != 2 || i >= shapeOf()[0] || j >= shapeOf()[1]) - throw std::invalid_argument("NDArray::operator(i,j): one of input indexes is out of array length or rank!=2 !"); - - Nd4jLong coords[2] = {i, j}; - auto xOffset = shape::getOffset(getShapeInfo(), coords); - - auto cast = reinterpret_cast(_buffer) + (xOffset * this->sizeOfT()); - NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace())); - - //FIXME: bad, will crash! - return result; -} -*/ ////////////////////////////////////////////////////////////////////////// -// accessing operator for 3D array, i - row, j - column -/* -NDArray NDArray::operator()(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k) const { - - if (rankOf() != 3 || i >= shapeOf()[0] || j >= shapeOf()[1] || j >= shapeOf()[2]) - throw std::invalid_argument("NDArray::operator(i,j,k): one of input indexes is out of array length or rank!=3 !"); - - Nd4jLong coords[3] = {i, j, k}; - auto xOffset = shape::getOffset(getShapeInfo(), coords); - - auto cast = reinterpret_cast(_buffer) + (xOffset * this->sizeOfT()); - NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace())); - return result; +Nd4jLong* NDArray::shapeOf() const { + return shape::shapeOf(_shapeInfo); } -*/ ////////////////////////////////////////////////////////////////////////// -// modifying operator for 3D array -/* -NDArray& NDArray::operator()(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k) { - - if (rankOf() != 3 || i >= shapeOf()[0] || j >= shapeOf()[1] || k >= shapeOf()[2]) - throw std::invalid_argument("NDArray::operator(i,j,k): one of input indexes is out of array length or rank!=3 !"); - - Nd4jLong coords[3] = {i, j, k}; - auto xOffset = shape::getOffset(getShapeInfo(), coords); - - auto cast = reinterpret_cast(_buffer) + (xOffset * this->sizeOfT()); - NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace())); - - //FIXME: bad, will crash! - return result; +Nd4jLong* NDArray::stridesOf() const { + return shape::stride(_shapeInfo); } -*/ -/* -NDArray NDArray::operator()(const Nd4jLong t, const Nd4jLong u, const Nd4jLong v, const Nd4jLong w) const { - if (rankOf() != 4 || t >= shapeOf()[0] || u >= shapeOf()[1] || v >= shapeOf()[2] || w >= shapeOf()[3]) - throw std::invalid_argument("NDArray::operator(t,u,v,w): one of input indexes is out of array length or rank!=4 !"); - - Nd4jLong coords[4] = {t, u, v, w}; - auto xOffset = shape::getOffset(getShapeInfo(), coords); - - auto cast = reinterpret_cast(_buffer) + (xOffset * this->sizeOfT()); - NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace())); - return result; -} -*/ -/* -NDArray& NDArray::operator()(const Nd4jLong t, const Nd4jLong u, const Nd4jLong v, const Nd4jLong w) { - - if (rankOf() != 4 || t >= shapeOf()[0] || u >= shapeOf()[1] || v >= shapeOf()[2] || w >= shapeOf()[3]) - throw std::invalid_argument("NDArray::operator(t,u,v,w): one of input indexes is out of array length or rank!=4 !"); - - Nd4jLong coords[4] = {t, u, v, w}; - auto xOffset = shape::getOffset(getShapeInfo(), coords); - - // FIXME - auto cast = reinterpret_cast(_buffer) + (xOffset * this->sizeOfT()); - NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace())); - return result; -} -*/ ////////////////////////////////////////////////////////////////////////// -/* -NDArray NDArray::operator()(const Nd4jLong* idx) const { - - for(int i = 0; i < rankOf(); ++i) - if (idx[i] >= sizeAt(i)) - throw std::invalid_argument("NDArray::operator(const Nd4jLong* idx): input index is out of dimension length !"); - - auto xOffset = shape::getOffset(getShapeInfo(), idx); - - auto cast = reinterpret_cast(_buffer) + (xOffset * this->sizeOfT()); - NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace())); - return result; +int NDArray::rankOf() const { + return shape::rank(_shapeInfo); } -*/ + ////////////////////////////////////////////////////////////////////////// -/* -NDArray& NDArray::operator()(const Nd4jLong* idx) { - - for(int i = 0; i < rankOf(); ++i) - if (idx[i] >= sizeAt(i)) - throw std::invalid_argument("NDArray::operator(const Nd4jLong* idx): input index is out of dimension length !"); - - auto xOffset = shape::getOffset(getShapeInfo(), idx); - - auto cast = reinterpret_cast(_buffer) + (xOffset * this->sizeOfT()); - NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace())); - - // FIXME - return result; +Nd4jLong NDArray::lengthOf() const { + return _length; } -*/ +////////////////////////////////////////////////////////////////////////// +Nd4jLong NDArray::rows() const { + if (this->rankOf() == 1) + return 1; - ////////////////////////////////////////////////////////////////////////// - Nd4jLong FORCEINLINE NDArray::memoryFootprint() { - Nd4jLong size = this->lengthOf() * this->sizeOfT(); - size += shape::shapeInfoByteLength(this->rankOf()); - return size; - } + if (this->rankOf() > 2) + throw std::runtime_error("Array with rank > 2 can't have rows"); - ////////////////////////////////////////////////////////////////////////// - // still the definition of inline function must be in header file - bool NDArray::isSameShape(const std::vector& shape) const{ - if (this->isScalar() && shape.size() == 1 && shape[0] == 0) - return true; - if (this->rankOf() != (int) shape.size()) - return false; - for (int e = 0; e < this->rankOf(); e++) { - if (this->shapeOf()[e] != shape.at(e) && shape.at(e) != -1) - return false; - } + return shapeOf()[0]; +} + +////////////////////////////////////////////////////////////////////////// +Nd4jLong NDArray::columns() const { + if (this->rankOf() == 1) + return this->lengthOf(); + + if (this->rankOf() > 2) + throw std::runtime_error("Array with rank > 2 can't have columns"); + + return shapeOf()[1]; +} + +////////////////////////////////////////////////////////////////////////// + +size_t NDArray::sizeOfT() const { + return DataTypeUtils::sizeOfElement(_dataType); +} + +////////////////////////////////////////////////////////////////////////// +Nd4jLong NDArray::ews() const { + if (this->isEmpty() || this->rankOf() == 0) + return 1; + + return shape::elementWiseStride(_shapeInfo); +} + +////////////////////////////////////////////////////////////////////////// +bool NDArray::nonNull() const { + if (isEmpty()) return true; + + if(!Environment::getInstance()->isCPU()) + return getDataBuffer()->special() != nullptr && getSpecialShapeInfo() != nullptr; + + return getDataBuffer()->primary() != nullptr && getShapeInfo() != nullptr; +} + +////////////////////////////////////////////////////////////////////////// +bool NDArray::isMatrix() const { + if (isEmpty()) + return false; + + return 0 != shape::isMatrix(this->_shapeInfo); +} + +////////////////////////////////////////////////////////////////////////// +bool NDArray::isVector() const { + if (isEmpty()) + return false; + if (rankOf() == 1) + return true; + return !isScalar() && shape::isVector(this->_shapeInfo); +} + +////////////////////////////////////////////////////////////////////////// +bool NDArray::isColumnVector() const { + if (isEmpty()) + return false; + + return !isScalar() && shape::isColumnVector(this->_shapeInfo); +} + +////////////////////////////////////////////////////////////////////////// +bool NDArray::isRowVector() const { + if (isEmpty()) + return false; + + // 1D edge case + if (shape::rank(this->_shapeInfo) == 1) + return true; + + return !isScalar() && shape::isRowVector(this->_shapeInfo); +} + +////////////////////////////////////////////////////////////////////////// +bool NDArray::isCommonVector(int& posOfNonUnityDim) const { + + return shape::isCommonVector(_shapeInfo, posOfNonUnityDim); +} + +////////////////////////////////////////////////////////////////////////// +bool NDArray::isScalar() const { + return 0 != shape::isScalar(this->_shapeInfo); +} + + +////////////////////////////////////////////////////////////////////////// +Nd4jLong FORCEINLINE NDArray::memoryFootprint() { + Nd4jLong size = this->lengthOf() * this->sizeOfT(); + size += shape::shapeInfoByteLength(this->rankOf()); + return size; +} + +////////////////////////////////////////////////////////////////////////// +// still the definition of inline function must be in header file +bool NDArray::isSameShape(const std::vector& shape) const{ + if (this->isScalar() && shape.size() == 1 && shape[0] == 0) + return true; + if (this->rankOf() != (int) shape.size()) + return false; + for (int e = 0; e < this->rankOf(); e++) { + if (this->shapeOf()[e] != shape.at(e) && shape.at(e) != -1) + return false; } + return true; +} ////////////////////////////////////////////////////////////////////////// bool NDArray::isSameShape(const NDArray *other) const { @@ -2009,8 +1691,8 @@ bool NDArray::areSameShapeAndType(const NDArray& other) const { // returns true if these two NDArrays have same _shapeInfo // still the definition of inline function must be in header file -bool NDArray::isSameShapeStrict(const NDArray *other) const { - return shape::equalsStrict(_shapeInfo, other->_shapeInfo); +bool NDArray::isSameShapeStrict(const NDArray& other) const { + return shape::equalsStrict(_shapeInfo, other._shapeInfo); } ////////////////////////////////////////////////////////////////////////// @@ -2103,7 +1785,7 @@ T& NDArray::t(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k) { template T& NDArray::t(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong w) { - if (rankOf() != 4 || i >= sizeAt(0) || j >= sizeAt(1) || k >= sizeAt(2), w >= sizeAt(3)) + if (rankOf() != 4 || i >= sizeAt(0) || j >= sizeAt(1) || k >= sizeAt(2) || w >= sizeAt(3)) throw std::invalid_argument("NDArray::t(i,j,k,w): one of input indexes is out of array length or rank!=4 !"); if (DataTypeUtils::fromT() != _dataType) throw std::invalid_argument("NDArray::t(i,j,k,w): type of array is not equal to template type T!"); diff --git a/libnd4j/blas/NDArray.hpp b/libnd4j/blas/NDArray.hpp index 5adff5853..42b29cf78 100644 --- a/libnd4j/blas/NDArray.hpp +++ b/libnd4j/blas/NDArray.hpp @@ -35,21 +35,6 @@ ND4J_EXPORT utf8string NDArray::e(const Nd4jLong i) const; template <> ND4J_EXPORT std::string NDArray::e(const Nd4jLong i) const; -////////////////////////////////////////////////////////////////////////// -template -NDArray* NDArray::asT() const{ - - auto result = isScalar() ? new NDArray('c', {}, {0.}, DataTypeUtils::fromT(), this->getContext()) : new NDArray(ordering(), getShapeAsVector(), DataTypeUtils::fromT(), this->getContext()); - auto l = this->lengthOf(); - - NDArray::prepareSpecialUse({result}, {this}); - NativeOpExecutioner::execTransformAny(getContext(), transform::AnyOps::Assign, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), result->getBuffer(), result->getShapeInfo(), result->getSpecialBuffer(), result->getSpecialShapeInfo(), nullptr, nullptr, nullptr); - NDArray::registerSpecialUse({result}, {this}); - - return result; -} -BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT NDArray* NDArray::asT, () const, LIBND4J_TYPES); - //////////////////////////////////////////////////////////////////////// // copy constructor NDArray::NDArray(const NDArray& other) { @@ -238,6 +223,8 @@ NDArray::NDArray(std::shared_ptr buffer, const ShapeDescriptor& desc setShapeInfo(descriptor); _buffer = buffer; + + _isView = offset > 0 || _length * DataTypeUtils::sizeOf(_dataType) < buffer->getLenInBytes(); } //////////////////////////////////////////////////////////////////////// @@ -303,6 +290,8 @@ NDArray::NDArray(std::shared_ptr buffer, const char order, const std setShapeInfo(ShapeDescriptor(buffer->getDataType(), order, shape)); _buffer = buffer; + + _isView = _length * DataTypeUtils::sizeOf(_dataType) < buffer->getLenInBytes(); } //////////////////////////////////////////////////////////////////////// @@ -455,6 +444,16 @@ std::vector NDArray::getShapeAsVector() const { return vector; } +//////////////////////////////////////////////////////////////////////// +std::vector NDArray::getShapeAsVectorInt() const { + + std::vector vector(this->rankOf()); + for (int e = 0; e < this->rankOf(); e++) + vector[e] = static_cast(this->sizeAt(e)); + + return vector; +} + //////////////////////////////////////////////////////////////////////// std::vector NDArray::getShapeInfoAsFlatVector() { int magicNumber = shape::shapeInfoLength(this->rankOf()); @@ -499,9 +498,7 @@ std::vector NDArray::asByteVector() { if (this->isView()) { auto tmp = this->dup(this->ordering()); syncToHost(); - memcpy(result.data(), tmp->getBuffer(), (unsigned long long) lengthOf() * sizeOfT()); - - delete tmp; + memcpy(result.data(), tmp.getBuffer(), (unsigned long long) lengthOf() * sizeOfT()); } else { syncToHost(); memcpy(result.data(), getBuffer(), (unsigned long long) lengthOf() * sizeOfT()); @@ -590,26 +587,78 @@ void NDArray::copyBuffersContinuouslyFrom(const NDArray& other, size_t sizeToCop dataBuffer()->copyBufferFrom(*other.getDataBuffer(), sizeToCopyInBytes, offsetThis, offsetOther); } +//////////////////////////////////////////////////////////////////// +// This method assigns values of given NDArray to this one +void NDArray::assign(const NDArray& other, bool allowParallelism) { + + if (this == &other) + return; + + if (other.isEmpty()) { + if (!isEmpty()) { + ArrayOptions::setPropertyBit(shapeInfo(), ARRAY_EMPTY); + syncShape(); + _buffer = std::make_shared(); + _offset = 0; + } + return; + } + + if(isEmpty()) { + *this = other; + return; + } + + if (other.lengthOf() == 1) { + + if(lengthOf() == 1) { + NDArray::preparePrimaryUse({this}, {&other}); + BUILD_DOUBLE_SELECTOR(dataType(), other.dataType(), templatedDoubleAssign, (buffer(), 0, other.getBuffer(), 0), LIBND4J_TYPES, LIBND4J_TYPES); + NDArray::registerPrimaryUse({this}, {&other}); + this->syncToDevice(); + } + else { + if (dataType() != other.dataType()) { + auto tmp = other.cast(dataType()); + NDArray::prepareSpecialUse({this}, {&tmp}); + NativeOpExecutioner::execScalar(getContext(), scalar::CopyPws, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), tmp.getBuffer(), tmp.getShapeInfo(), tmp.getSpecialBuffer(), tmp.getSpecialShapeInfo(), nullptr, allowParallelism); + NDArray::registerSpecialUse({this}, {}); + } + else { + NDArray::prepareSpecialUse({this}, {&other}); + NativeOpExecutioner::execScalar(getContext(), scalar::CopyPws, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), nullptr, allowParallelism); + NDArray::registerSpecialUse({this}, {&other}); + } + } + } + else { + if (other.lengthOf() != lengthOf()) { + auto shapeThis = ShapeUtils::shapeAsString(this); + auto shapeThat = ShapeUtils::shapeAsString(&other); + nd4j_printf("Can't assign array: this shape %s; other shape: %s\n", shapeThis.c_str(), shapeThat.c_str()); + throw std::runtime_error("NDArray::assign: lengths of arrays are mismatched"); + } + + // memcpy is allowed only for same order && same ews (being equal to 1) + if (ordering() == other.ordering() && dataType() == other.dataType() && ews() == 1 && other.ews() == 1) + copyBuffersContinuouslyFrom(other, other.lengthOf() * other.sizeOfT()); + else { + NDArray::prepareSpecialUse({this}, {&other}); + NativeOpExecutioner::execTransformAny(getContext(), transform::Assign, other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, nullptr, nullptr, allowParallelism); + NDArray::registerSpecialUse({this}, {&other}); + } + } +} ////////////////////////////////////////////////////////////////////////// // This method assigns values of given NDArray to this one, wrt order - void NDArray::assign(const NDArray *other, bool allowParallelism) { - assign(*other, allowParallelism); - } - -////////////////////////////////////////////////////////////////////////// -// This method assigns given value to all elements in this NDArray -void NDArray::assign(const double value, bool allowParallelism) { - // just fire scalar - auto temp = NDArrayFactory::create(dataType(), value, this->getContext()); - - NDArray::prepareSpecialUse({this}, {&temp}); - NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr); - NDArray::registerSpecialUse({this}, {&temp}); +void NDArray::assign(const NDArray *other, bool allowParallelism) { + assign(*other, allowParallelism); } ////////////////////////////////////////////////////////////////////////// -void NDArray::assign(const float value, bool allowParallelism) { +template +void NDArray::assign(const T& value, bool allowParallelism) { // just fire scalar auto temp = NDArrayFactory::create(dataType(), value, this->getContext()); @@ -617,116 +666,19 @@ void NDArray::assign(const float value, bool allowParallelism) { NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr, allowParallelism); NDArray::registerSpecialUse({this}, {&temp}); } - -////////////////////////////////////////////////////////////////////////// -void NDArray::assign(const float16 value, bool allowParallelism) { - // just fire scalar - auto temp = NDArrayFactory::create(value, this->getContext()); - - NDArray::prepareSpecialUse({this}, {&temp}); - NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr, allowParallelism); - NDArray::registerSpecialUse({this}, {&temp}); -} - -////////////////////////////////////////////////////////////////////////// -void NDArray::assign(const bfloat16& value, bool allowParallelism) { - // just fire scalar - auto temp = NDArrayFactory::create(value, this->getContext()); - - NDArray::prepareSpecialUse({this}, {&temp}); - NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr, allowParallelism); - NDArray::registerSpecialUse({this}, {&temp}); -} - -////////////////////////////////////////////////////////////////////////// -void NDArray::assign(const Nd4jLong value, bool allowParallelism) { - // just fire scalar - auto temp = NDArrayFactory::create(value, this->getContext()); - - NDArray::prepareSpecialUse({this}, {&temp}); - NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr, allowParallelism); - NDArray::registerSpecialUse({this}, {&temp}); -} - -////////////////////////////////////////////////////////////////////////// -void NDArray::assign(const int value, bool allowParallelism) { - // just fire scalar - auto temp = NDArrayFactory::create(this->dataType(), value, this->getContext()); - - NDArray::prepareSpecialUse({this}, {&temp}); - NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), _shapeInfo, specialBuffer(), _shapeInfoD, buffer(), _shapeInfo, specialBuffer(), _shapeInfoD, temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp._shapeInfoD, nullptr, allowParallelism); - NDArray::registerSpecialUse({this}, {&temp}); -} - -////////////////////////////////////////////////////////////////////////// -void NDArray::assign(const int16_t value, bool allowParallelism) { - // just fire scalar - auto temp = NDArrayFactory::create(this->dataType(), value, this->getContext()); - - NDArray::prepareSpecialUse({this}, {&temp}); - NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), _shapeInfo, specialBuffer(), _shapeInfoD, buffer(), _shapeInfo, specialBuffer(), _shapeInfoD, temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp._shapeInfoD, nullptr, allowParallelism); - NDArray::registerSpecialUse({this}, {&temp}); -} - -////////////////////////////////////////////////////////////////////////// -void NDArray::assign(const uint8_t value, bool allowParallelism) { - // just fire scalar - auto temp = NDArrayFactory::create(this->dataType(), value, this->getContext()); - - NDArray::prepareSpecialUse({this}, {&temp}); - NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr, allowParallelism); - NDArray::registerSpecialUse({this}, {&temp}); -} - -////////////////////////////////////////////////////////////////////////// -void NDArray::assign(const uint16_t value, bool allowParallelism) { - // just fire scalar - auto temp = NDArrayFactory::create(this->dataType(), value, this->getContext()); - - NDArray::prepareSpecialUse({this}, {&temp}); - NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr, allowParallelism); - NDArray::registerSpecialUse({this}, {&temp}); -} - -////////////////////////////////////////////////////////////////////////// -void NDArray::assign(const uint32_t value, bool allowParallelism) { - // just fire scalar - auto temp = NDArrayFactory::create(this->dataType(), value, this->getContext()); - - NDArray::prepareSpecialUse({this}, {&temp}); - NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr, allowParallelism); - NDArray::registerSpecialUse({this}, {&temp}); -} - -////////////////////////////////////////////////////////////////////////// -void NDArray::assign(const uint64_t value, bool allowParallelism) { - // just fire scalar - auto temp = NDArrayFactory::create(this->dataType(), value, this->getContext()); - - NDArray::prepareSpecialUse({this}, {&temp}); - NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr, allowParallelism); - NDArray::registerSpecialUse({this}, {&temp}); -} - -////////////////////////////////////////////////////////////////////////// -void NDArray::assign(const int8_t value, bool allowParallelism) { - // just fire scalar - auto temp = NDArrayFactory::create(this->dataType(), value, this->getContext()); - - NDArray::prepareSpecialUse({this}, {&temp}); - NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr, allowParallelism); - NDArray::registerSpecialUse({this}, {&temp}); -} - -////////////////////////////////////////////////////////////////////////// -void NDArray::assign(const bool value, bool allowParallelism) { - // just fire scalar - auto temp = NDArrayFactory::create(this->dataType(), value, this->getContext()); - - NDArray::prepareSpecialUse({this}, {&temp}); - NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr, allowParallelism); - NDArray::registerSpecialUse({this}, {&temp}); -} +template ND4J_EXPORT void NDArray::assign(const double& value, bool allowParallelism); +template ND4J_EXPORT void NDArray::assign(const float& value, bool allowParallelism); +template ND4J_EXPORT void NDArray::assign(const float16& value, bool allowParallelism); +template ND4J_EXPORT void NDArray::assign(const bfloat16& value, bool allowParallelism); +template ND4J_EXPORT void NDArray::assign(const Nd4jLong& value, bool allowParallelism); +template ND4J_EXPORT void NDArray::assign(const int& value, bool allowParallelism); +template ND4J_EXPORT void NDArray::assign(const int8_t& value, bool allowParallelism); +template ND4J_EXPORT void NDArray::assign(const int16_t& value, bool allowParallelism); +template ND4J_EXPORT void NDArray::assign(const uint8_t& value, bool allowParallelism); +template ND4J_EXPORT void NDArray::assign(const uint16_t& value, bool allowParallelism); +template ND4J_EXPORT void NDArray::assign(const uint32_t& value, bool allowParallelism); +template ND4J_EXPORT void NDArray::assign(const uint64_t& value, bool allowParallelism); +template ND4J_EXPORT void NDArray::assign(const bool& value, bool allowParallelism); ////////////////////////////////////////////////////////////////////////// NDArray* NDArray::detach() { @@ -841,32 +793,7 @@ void* NDArray::bufferWithOffset(Nd4jLong offset) const { ////////////////////////////////////////////////////////////////////////// // eventually method reduces array by excluding its shapes along axes present in dimensions vector -NDArray* NDArray::reduceAlongDimension(nd4j::reduce::FloatOps op, const std::vector& dimensions, const bool keepDims, const bool supportOldShapes) const { - - return new NDArray(reduceAlongDims(op, dimensions, keepDims, supportOldShapes)); -} - -////////////////////////////////////////////////////////////////////////// -NDArray* NDArray::reduceAlongDimension(nd4j::reduce::SameOps op, const std::vector& dimensions, const bool keepDims, const bool supportOldShapes) const { - - return new NDArray(reduceAlongDims(op, dimensions, keepDims, supportOldShapes)); -} - -////////////////////////////////////////////////////////////////////////// -NDArray* NDArray::reduceAlongDimension(nd4j::reduce::BoolOps op, const std::vector& dimensions, const bool keepDims, const bool supportOldShapes) const { - - return new NDArray(reduceAlongDims(op, dimensions, keepDims, supportOldShapes)); -} - -////////////////////////////////////////////////////////////////////////// -NDArray* NDArray::reduceAlongDimension(nd4j::reduce::LongOps op, const std::vector& dimensions, const bool keepDims, const bool supportOldShapes) const { - - return new NDArray(reduceAlongDims(op, dimensions, keepDims, supportOldShapes)); -} - -////////////////////////////////////////////////////////////////////////// -// eventually method reduces array by excluding its shapes along axes present in dimensions vector -NDArray NDArray::reduceAlongDims(nd4j::reduce::FloatOps op, const std::vector& dimensions, const bool keepDims, const bool supportOldShapes) const { +NDArray NDArray::reduceAlongDimension(nd4j::reduce::FloatOps op, const std::vector& dimensions, const bool keepDims, const bool supportOldShapes) const { std::vector copy(dimensions); @@ -874,13 +801,13 @@ NDArray NDArray::reduceAlongDims(nd4j::reduce::FloatOps op, const std::vectorreduceAlongDimension(op, result, copy, keepDims, supportOldShapes, false); return result; } ////////////////////////////////////////////////////////////////////////// -NDArray NDArray::reduceAlongDims(nd4j::reduce::SameOps op, const std::vector& dimensions, const bool keepDims, const bool supportOldShapes) const { +NDArray NDArray::reduceAlongDimension(nd4j::reduce::SameOps op, const std::vector& dimensions, const bool keepDims, const bool supportOldShapes) const { std::vector copy(dimensions); @@ -888,13 +815,13 @@ NDArray NDArray::reduceAlongDims(nd4j::reduce::SameOps op, const std::vector& dimensions, const bool keepDims, const bool supportOldShapes) const { +NDArray NDArray::reduceAlongDimension(nd4j::reduce::BoolOps op, const std::vector& dimensions, const bool keepDims, const bool supportOldShapes) const { std::vector copy(dimensions); @@ -902,13 +829,13 @@ NDArray NDArray::reduceAlongDims(nd4j::reduce::BoolOps op, const std::vector& dimensions, const bool keepDims, const bool supportOldShapes) const { +NDArray NDArray::reduceAlongDimension(nd4j::reduce::LongOps op, const std::vector& dimensions, const bool keepDims, const bool supportOldShapes) const { std::vector copy(dimensions); @@ -916,29 +843,29 @@ NDArray NDArray::reduceAlongDims(nd4j::reduce::LongOps op, const std::vector& dimensions, const bool keepDims, const bool supportOldShapes) const { +NDArray NDArray::reduceAlongDimension(nd4j::reduce::FloatOps op, const std::initializer_list& dimensions, const bool keepDims, const bool supportOldShapes) const { return reduceAlongDimension(op, std::vector(dimensions), keepDims, supportOldShapes); } ////////////////////////////////////////////////////////////////////////// -NDArray *NDArray::reduceAlongDimension(nd4j::reduce::SameOps op, const std::initializer_list& dimensions, const bool keepDims, const bool supportOldShapes) const { +NDArray NDArray::reduceAlongDimension(nd4j::reduce::SameOps op, const std::initializer_list& dimensions, const bool keepDims, const bool supportOldShapes) const { return reduceAlongDimension(op, std::vector(dimensions), keepDims, supportOldShapes); } ////////////////////////////////////////////////////////////////////////// -NDArray *NDArray::reduceAlongDimension(nd4j::reduce::BoolOps op, const std::initializer_list& dimensions, const bool keepDims, const bool supportOldShapes) const { +NDArray NDArray::reduceAlongDimension(nd4j::reduce::BoolOps op, const std::initializer_list& dimensions, const bool keepDims, const bool supportOldShapes) const { return reduceAlongDimension(op, std::vector(dimensions), keepDims, supportOldShapes); } ////////////////////////////////////////////////////////////////////////// -NDArray *NDArray::reduceAlongDimension(nd4j::reduce::LongOps op, const std::initializer_list& dimensions, const bool keepDims, const bool supportOldShapes) const { +NDArray NDArray::reduceAlongDimension(nd4j::reduce::LongOps op, const std::initializer_list& dimensions, const bool keepDims, const bool supportOldShapes) const { return reduceAlongDimension(op, std::vector(dimensions), keepDims, supportOldShapes); } @@ -1082,11 +1009,6 @@ Nd4jLong NDArray::tensorsAlongDimension(const std::vector& dimensions) cons return numTads; } -////////////////////////////////////////////////////////////////////////// -NDArray* NDArray::tensorAlongDimension(Nd4jLong index, const std::initializer_list& dimensions) const { - return tensorAlongDimension(index, std::vector(dimensions)); -} - ////////////////////////////////////////////////////////////////////////// void NDArray::printShapeInfo(const char * msg) const { //shape::printShapeInfo(_shapeInfo); @@ -1305,13 +1227,20 @@ BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT void* NDArray::templatedPointerShift, ////////////////////////////////////////////////////////////////////////// // method makes copy of this array and applies to the copy transpose operation, this array remains unaffected -NDArray NDArray::transpose() const { +NDArray NDArray::transpose() const &{ NDArray newArr(getDataBuffer(), ShapeDescriptor(getShapeInfo()), getContext(), getBufferOffset()); newArr.transposei(); return newArr; } +////////////////////////////////////////////////////////////////////////// +// method makes copy of this array and applies to the copy transpose operation, this array remains unaffected +NDArray NDArray::transpose() && { + + this->transposei(); + return std::move(*this); +} //////////////////////////////////////////////////////////////////////// // method performs transpose operation based on this array and store result in target, this array remains unaffected @@ -1418,7 +1347,7 @@ Nd4jLong NDArray::argMax(std::initializer_list dimensions) { ////////////////////////////////////////////////////////////////////////// // create new array with corresponding order and shape, new array will point to the same _buffer as this array -NDArray NDArray::reshape(const char order, const std::vector& shape) const { +NDArray NDArray::reshape(const char order, const std::vector& shape) const & { NDArray newArr(getDataBuffer(), ShapeDescriptor(getShapeInfo()), getContext(), getBufferOffset()); newArr.reshapei(order, shape); @@ -1426,6 +1355,13 @@ NDArray NDArray::reshape(const char order, const std::vector& shape) c return newArr; } +////////////////////////////////////////////////////////////////////////// +NDArray NDArray::reshape(const char order, const std::vector& shape) && { + + this->reshapei(order, shape); + return std::move(*this); +} + ////////////////////////////////////////////////////////////////////////// // change an array by repeating it the number of times given by reps. void NDArray::tilei(const std::vector& reps) { @@ -1490,7 +1426,7 @@ bool NDArray::permutei(const std::vector& dimensions) { } ////////////////////////////////////////////////////////////////////////// -NDArray NDArray::permute(const int* dimensions, const int rank) const { +NDArray NDArray::permute(const int* dimensions, const int rank) const & { // evaluate shapeInfo for output (permuted) array ret auto shapeInfoPermuted = ShapeUtils::evalPermShapeInfo(dimensions, rank, *this, getContext()->getWorkspace()); @@ -1499,38 +1435,80 @@ NDArray NDArray::permute(const int* dimensions, const int rank) const { return ret; } +////////////////////////////////////////////////////////////////////////// +NDArray NDArray::permute(const int* dimensions, const int rank) && { + + this->permutei(dimensions, rank); + return std::move(*this); +} + ///////////////////////////////////////////////////////////////////////// -NDArray NDArray::permute(const Nd4jLong* dimensions, const int rank) const { +NDArray NDArray::permute(const Nd4jLong* dimensions, const int rank) const &{ int tempDims[MAX_RANK]; shape::convertT(const_cast(dimensions), tempDims, rank); return permute(tempDims, rank); } +///////////////////////////////////////////////////////////////////////// +NDArray NDArray::permute(const Nd4jLong* dimensions, const int rank) && { + + this->permutei(dimensions, rank); + return std::move(*this); +} + ////////////////////////////////////////////////////////////////////////// -NDArray NDArray::permute(const std::vector& dimensions) const { +NDArray NDArray::permute(const std::vector& dimensions) const &{ auto data = dimensions.data(); auto size = dimensions.size(); return permute(data, size); } ////////////////////////////////////////////////////////////////////////// -NDArray NDArray::permute(const std::vector& dimensions) const { +NDArray NDArray::permute(const std::vector& dimensions) && { + + this->permutei(dimensions); + return std::move(*this); +} + +////////////////////////////////////////////////////////////////////////// +NDArray NDArray::permute(const std::vector& dimensions) const & { return permute(dimensions.data(), dimensions.size()); } +////////////////////////////////////////////////////////////////////////// +NDArray NDArray::permute(const std::vector& dimensions) && { + + this->permutei(dimensions); + return std::move(*this); +} ////////////////////////////////////////////////////////////////////////// -NDArray NDArray::permute(const std::initializer_list& dimensions) const { +NDArray NDArray::permute(const std::initializer_list& dimensions) const &{ + std::vector vec(dimensions); return permute(vec); } ////////////////////////////////////////////////////////////////////////// -NDArray NDArray::permute(const std::initializer_list& dimensions) const { +NDArray NDArray::permute(const std::initializer_list& dimensions) && { + + this->permutei(dimensions); + return std::move(*this); +} + +////////////////////////////////////////////////////////////////////////// +NDArray NDArray::permute(const std::initializer_list& dimensions) const & { std::vector vec(dimensions); return permute(vec); } +////////////////////////////////////////////////////////////////////////// +NDArray NDArray::permute(const std::initializer_list& dimensions) && { + + this->permutei(dimensions); + return std::move(*this); +} + ////////////////////////////////////////////////////////////////////////// void NDArray::permute(const int* dimensions, const int rank, NDArray& target) const { if (!nonNull() || !target.nonNull() || rank != rankOf() || rank != target.rankOf() ) @@ -1623,7 +1601,7 @@ T* NDArray::bufferAsT() const { BUILD_SINGLE_UNCHAINED_TEMPLATE(template ND4J_EXPORT , * NDArray::bufferAsT() const, LIBND4J_TYPES); //////////////////////////////////////////////////////////////////////// -NDArray* NDArray::subarray(IndicesList& idx) const { +NDArray NDArray::subarray(IndicesList& idx) const { const int idxSize = idx.size(); if (idxSize != this->rankOf()) @@ -1655,11 +1633,11 @@ NDArray* NDArray::subarray(IndicesList& idx) const { indexes[3 * d + 2] = idx.at(d)->getIndices().at(2); // stride } } - return new NDArray((*this)(indexes, true, true)); + return NDArray((*this)(indexes, true, true)); } //////////////////////////////////////////////////////////////////////// -NDArray* NDArray::subarray(const std::initializer_list& idx) const { +NDArray NDArray::subarray(const std::initializer_list& idx) const { const int idxSize = idx.size(); if (idxSize != this->rankOf()) @@ -1698,11 +1676,11 @@ NDArray* NDArray::subarray(const std::initializer_list& idx) const { for (auto i: idx) delete i; - return new NDArray((*this)(indexes, true, true)); + return NDArray((*this)(indexes, true, true)); } //////////////////////////////////////////////////////////////////////// -NDArray* NDArray::subarray(const Intervals& idx) const { +NDArray NDArray::subarray(const Intervals& idx) const { const int idxSize = idx.size(); if (idxSize != this->rankOf()) @@ -1723,390 +1701,47 @@ NDArray* NDArray::subarray(const Intervals& idx) const { } } - return new NDArray((*this)(indexes, true)); + return NDArray((*this)(indexes, true)); } +////////////////////////////////////////////////////////////////////////// +template +NDArray NDArray::asT() const{ + + auto result = isScalar() ? NDArray('c', {}, {0.}, DataTypeUtils::fromT(), this->getContext()) : NDArray(ordering(), getShapeAsVector(), DataTypeUtils::fromT(), this->getContext()); + auto l = this->lengthOf(); + + NDArray::prepareSpecialUse({&result}, {this}); + NativeOpExecutioner::execTransformAny(getContext(), transform::AnyOps::Assign, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), result.getBuffer(), result.getShapeInfo(), result.getSpecialBuffer(), result.getSpecialShapeInfo(), nullptr, nullptr, nullptr); + NDArray::registerSpecialUse({&result}, {this}); + + return result; +} +BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT NDArray NDArray::asT, () const, LIBND4J_TYPES); + //////////////////////////////////////////////////////////////////////// -NDArray* NDArray::asT(DataType dtype) const { +NDArray NDArray::asT(DataType dtype) const { if (isS()) throw std::runtime_error("NDArray::asT: you can't use this method on String array!"); BUILD_SINGLE_SELECTOR(dtype, return asT, (), LIBND4J_TYPES); - return nullptr; + return NDArray(); } //////////////////////////////////////////////////////////////////////// -template -NDArray* NDArray::cast() { - if (isS()) - throw std::runtime_error("NDArray::cast: you can't use this method on String array!"); - return this->asT(); -} - -//////////////////////////////////////////////////////////////////////// -NDArray* NDArray::cast(DataType dtype) const { +NDArray NDArray::cast(DataType dtype) const { if (isS()) throw std::runtime_error("NDArray::cast: you can't use this method on String array!"); return this->asT(dtype); } //////////////////////////////////////////////////////////////////////// -void NDArray::cast(NDArray* target, DataType dtype) { +void NDArray::cast(NDArray& target, DataType dtype) { if (isS()) throw std::runtime_error("NDArray::cast: you can't use this method on String array!"); // TODO: to be implemented properly - target->assign(this); -} - -//////////////////////////////////////////////////////////////////////// -// addition operator array + array -NDArray NDArray::operator+(const NDArray& other) const { - if (isS()) - throw std::runtime_error("NDArray::operator+: you can't use this method on String array!"); - - if (!Environment::getInstance()->isExperimentalBuild() && this->dataType() != other.dataType() && (other.dataType() != DataType::BOOL) ) { - throw datatype_exception::build("NDArray::operator+: cannot add different types.", dataType(), other.dataType()); - } - if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) { - NDArray result(getShapeInfo(), DataTypeUtils::pickPairwiseResultType(getShapeInfo(), other.getShapeInfo()), false, getContext()); - - NDArray::prepareSpecialUse({&result}, {this, &other}); - NativeOpExecutioner::execPairwiseTransform(getContext(), nd4j::pairwise::Add, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), result.buffer(), result.getShapeInfo(), result.specialBuffer(), result.getSpecialShapeInfo(), nullptr); - NDArray::registerSpecialUse({&result}, {this, &other}); - - return result; - } - - return this->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Add(), other); -} - -//////////////////////////////////////////////////////////////////////// -// addition operator array + scalar -template -NDArray NDArray::operator+(const T& scalar) const { - if (isS()) - throw std::runtime_error("NDArray::operator+: you can't use this method on String array!"); - - auto tmp = NDArrayFactory::create(dataType(), scalar, getContext()); - NDArray result(getShapeInfo(), DataTypeUtils::pickPairwiseResultType(dataType(), DataTypeUtils::fromT()), false, getContext()); - - NDArray::prepareSpecialUse({&result}, {this, &tmp}); - NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::Add, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), result.buffer(), result.getShapeInfo(), result.specialBuffer(), result.getSpecialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({&result}, {this, &tmp}); - - return result; -} -template ND4J_EXPORT NDArray NDArray::operator+(const double& scalar) const; -template ND4J_EXPORT NDArray NDArray::operator+(const float& scalar) const; -template ND4J_EXPORT NDArray NDArray::operator+(const float16& scalar) const; -template ND4J_EXPORT NDArray NDArray::operator+(const bfloat16& scalar) const; -template ND4J_EXPORT NDArray NDArray::operator+(const Nd4jLong& scalar) const; -template ND4J_EXPORT NDArray NDArray::operator+(const int& scalar) const; -template ND4J_EXPORT NDArray NDArray::operator+(const int16_t& scalar) const; -template ND4J_EXPORT NDArray NDArray::operator+(const int8_t& scalar) const; -template ND4J_EXPORT NDArray NDArray::operator+(const uint8_t& scalar) const; -template ND4J_EXPORT NDArray NDArray::operator+(const bool& scalar) const; - -//////////////////////////////////////////////////////////////////////// -// subtraction operator array - scalar -template -NDArray NDArray::operator-(const T& scalar) const { - if (isS()) - throw std::runtime_error("NDArray::operator-: you can't use this method on String array!"); - - auto tmp = NDArrayFactory::create(dataType(), scalar, getContext()); - NDArray result(_shapeInfo, DataTypeUtils::pickPairwiseResultType(dataType(), DataTypeUtils::fromT()), false, getContext()); - - NDArray::prepareSpecialUse({&result}, {this, &tmp}); - NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::Subtract, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), result.buffer(), result.getShapeInfo(), result.specialBuffer(), result.getSpecialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({&result}, {this, &tmp}); - - return result; -} -template ND4J_EXPORT NDArray NDArray::operator-(const double& scalar) const; -template ND4J_EXPORT NDArray NDArray::operator-(const float& scalar) const; -template ND4J_EXPORT NDArray NDArray::operator-(const float16& scalar) const; -template ND4J_EXPORT NDArray NDArray::operator-(const bfloat16& scalar) const; -template ND4J_EXPORT NDArray NDArray::operator-(const Nd4jLong& scalar) const; -template ND4J_EXPORT NDArray NDArray::operator-(const int& scalar) const; -template ND4J_EXPORT NDArray NDArray::operator-(const int16_t& scalar) const; -template ND4J_EXPORT NDArray NDArray::operator-(const int8_t& scalar) const; -template ND4J_EXPORT NDArray NDArray::operator-(const uint8_t& scalar) const; -template ND4J_EXPORT NDArray NDArray::operator-(const bool& scalar) const; - -//////////////////////////////////////////////////////////////////////// -// multiplication operator array*scalar -template -NDArray NDArray::operator*(const T& scalar) const { - if (isS()) - throw std::runtime_error("NDArray::operator*: you can't use this method on String array!"); - - auto tmp = NDArrayFactory::create(dataType(), scalar, getContext()); - NDArray result(_shapeInfo, DataTypeUtils::pickPairwiseResultType(dataType(), DataTypeUtils::fromT()), false, getContext()); - - NDArray::prepareSpecialUse({&result}, {this, &tmp}); - NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::Multiply, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), result.buffer(), result.getShapeInfo(), result.specialBuffer(), result.getSpecialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({&result}, {this, &tmp}); - - return result; -} -template ND4J_EXPORT NDArray NDArray::operator*(const double& scalar) const; -template ND4J_EXPORT NDArray NDArray::operator*(const float& scalar) const; -template ND4J_EXPORT NDArray NDArray::operator*(const float16& scalar) const; -template ND4J_EXPORT NDArray NDArray::operator*(const bfloat16& scalar) const; -template ND4J_EXPORT NDArray NDArray::operator*(const Nd4jLong& scalar) const; -template ND4J_EXPORT NDArray NDArray::operator*(const int& scalar) const; -template ND4J_EXPORT NDArray NDArray::operator*(const int16_t& scalar) const; -template ND4J_EXPORT NDArray NDArray::operator*(const int8_t& scalar) const; -template ND4J_EXPORT NDArray NDArray::operator*(const uint8_t& scalar) const; -template ND4J_EXPORT NDArray NDArray::operator*(const bool& scalar) const; - -//////////////////////////////////////////////////////////////////////// -// division operator array / scalar -template -NDArray NDArray::operator/(const T& scalar) const { - if (isS()) - throw std::runtime_error("NDArray::operator/: you can't use this method on String array!"); - - if(scalar == (T)0.) - throw std::runtime_error("NDArray::operator/ (division operator) : division by zero !"); - - auto tmp = NDArrayFactory::create(dataType(), scalar, getContext()); - NDArray result(_shapeInfo, DataTypeUtils::pickPairwiseResultType(dataType(), DataTypeUtils::fromT()), false, getContext()); - - NDArray::prepareSpecialUse({&result}, {this, &tmp}); - NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::Divide, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), result.buffer(), result.getShapeInfo(), result.specialBuffer(), result.getSpecialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({&result}, {this, &tmp}); - - return result; -} -template ND4J_EXPORT NDArray NDArray::operator/(const double& scalar) const; -template ND4J_EXPORT NDArray NDArray::operator/(const float& scalar) const; -template ND4J_EXPORT NDArray NDArray::operator/(const float16& scalar) const; -template ND4J_EXPORT NDArray NDArray::operator/(const bfloat16& scalar) const; -template ND4J_EXPORT NDArray NDArray::operator/(const Nd4jLong& scalar) const; -template ND4J_EXPORT NDArray NDArray::operator/(const int& scalar) const; -template ND4J_EXPORT NDArray NDArray::operator/(const int16_t& scalar) const; -template ND4J_EXPORT NDArray NDArray::operator/(const int8_t& scalar) const; -template ND4J_EXPORT NDArray NDArray::operator/(const uint8_t& scalar) const; -template ND4J_EXPORT NDArray NDArray::operator/(const bool& scalar) const; - -//////////////////////////////////////////////////////////////////////// -// addition operator scalar + array -ND4J_EXPORT NDArray operator+(const float16& scalar, const NDArray& arr) { - return arr + scalar; -} -ND4J_EXPORT NDArray operator+(const bfloat16& scalar, const NDArray& arr) { - return arr + scalar; -} -ND4J_EXPORT NDArray operator+(const float& scalar, const NDArray& arr) { - return arr + scalar; -} -ND4J_EXPORT NDArray operator+(const double& scalar, const NDArray& arr) { - return arr + scalar; -} -ND4J_EXPORT NDArray operator+(const Nd4jLong& scalar, const NDArray& arr) { - return arr + scalar; -} -ND4J_EXPORT NDArray operator+(const int& scalar, const NDArray& arr) { - return arr + scalar; -} - -//////////////////////////////////////////////////////////////////////// -// addition operator scalar + array -ND4J_EXPORT NDArray operator*(const float16& scalar, const NDArray& arr) { - return arr * scalar; -} -ND4J_EXPORT NDArray operator*(const bfloat16& scalar, const NDArray& arr) { - return arr * scalar; -} - -ND4J_EXPORT NDArray operator*(const float& scalar, const NDArray& arr) { - return arr * scalar; -} -ND4J_EXPORT NDArray operator*(const double& scalar, const NDArray& arr) { - return arr * scalar; -} -ND4J_EXPORT NDArray operator*(const Nd4jLong& scalar, const NDArray& arr) { - return arr * scalar; -} -ND4J_EXPORT NDArray operator*(const int& scalar, const NDArray& arr) { - return arr * scalar; -} - -//////////////////////////////////////////////////////////////////////// -ND4J_EXPORT NDArray operator-(const float16& scalar, const NDArray & arr) { - if (arr.isS()) - throw std::runtime_error("NDArray::operator-: you can't use this method on String array!"); - - auto tmp = NDArrayFactory::create(scalar, arr.getContext()); - NDArray result(arr.getShapeInfo(), DataTypeUtils::pickPairwiseResultType(arr.dataType(), DataTypeUtils::fromT()), false, arr.getContext()); - - NDArray::prepareSpecialUse({&result}, {&arr, &tmp}); - NativeOpExecutioner::execScalar(arr.getContext(), nd4j::scalar::ReverseSubtract, arr.getBuffer(), arr.getShapeInfo(), arr.getSpecialBuffer(), arr.getSpecialShapeInfo(), result.getBuffer(), result.getShapeInfo(), result.specialBuffer(), result.specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({&result}, {&arr, &tmp}); - - return result; -} - -//////////////////////////////////////////////////////////////////////// -ND4J_EXPORT NDArray operator-(const bfloat16& scalar, const NDArray & arr) { - if (arr.isS()) - throw std::runtime_error("NDArray::operator-: you can't use this method on String array!"); - - auto tmp = NDArrayFactory::create(scalar, arr.getContext()); - NDArray result(arr.getShapeInfo(), DataTypeUtils::pickPairwiseResultType(arr.dataType(), DataTypeUtils::fromT()), false, arr.getContext()); - - NDArray::prepareSpecialUse({&result}, {&arr, &tmp}); - NativeOpExecutioner::execScalar(arr.getContext(), nd4j::scalar::ReverseSubtract, arr.getBuffer(), arr.getShapeInfo(), arr.getSpecialBuffer(), arr.getSpecialShapeInfo(), result.getBuffer(), result.getShapeInfo(), result.specialBuffer(), result.specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({&result}, {&arr, &tmp}); - - return result; -} - -//////////////////////////////////////////////////////////////////////// -ND4J_EXPORT NDArray operator-(const float& scalar, const NDArray& arr) { - if (arr.isS()) - throw std::runtime_error("NDArray::operator-: you can't use this method on String array!"); - - auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); - NDArray result(arr.getShapeInfo(), DataTypeUtils::pickPairwiseResultType(arr.dataType(), DataTypeUtils::fromT()), false, arr.getContext()); - - NDArray::prepareSpecialUse({&result}, {&arr, &tmp}); - NativeOpExecutioner::execScalar(arr.getContext(), nd4j::scalar::ReverseSubtract, arr.getBuffer(), arr.getShapeInfo(), arr.getSpecialBuffer(), arr.getSpecialShapeInfo(), result.getBuffer(), result.getShapeInfo(), result.specialBuffer(), result.specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({&result}, {&arr, &tmp}); - - return result; -} - -//////////////////////////////////////////////////////////////////////// -ND4J_EXPORT NDArray operator-(const double& scalar, const NDArray& arr) { - if (arr.isS()) - throw std::runtime_error("NDArray::operator-: you can't use this method on String array!"); - - auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); - NDArray result(arr.getShapeInfo(), DataTypeUtils::pickPairwiseResultType(arr.dataType(), DataTypeUtils::fromT()), false, arr.getContext()); - - NDArray::prepareSpecialUse({&result}, {&arr, &tmp}); - NativeOpExecutioner::execScalar(arr.getContext(), nd4j::scalar::ReverseSubtract, arr.getBuffer(), arr.getShapeInfo(), arr.getSpecialBuffer(), arr.getSpecialShapeInfo(), result.getBuffer(), result.getShapeInfo(), result.specialBuffer(), result.specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({&result}, {&arr, &tmp}); - - return result; -} - -//////////////////////////////////////////////////////////////////////// -ND4J_EXPORT NDArray operator-(const Nd4jLong& scalar, const NDArray& arr) { - if (arr.isS()) - throw std::runtime_error("NDArray::operator-: you can't use this method on String array!"); - - auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); - NDArray result(arr.getShapeInfo(), DataTypeUtils::pickPairwiseResultType(arr.dataType(), DataTypeUtils::fromT()), false, arr.getContext()); - - NDArray::prepareSpecialUse({&result}, {&arr, &tmp}); - NativeOpExecutioner::execScalar(arr.getContext(), nd4j::scalar::ReverseSubtract, arr.getBuffer(), arr.getShapeInfo(), arr.getSpecialBuffer(), arr.getSpecialShapeInfo(), result.getBuffer(), result.getShapeInfo(), result.specialBuffer(), result.specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({&result}, {&arr, &tmp}); - - return result; -} - -//////////////////////////////////////////////////////////////////////// -ND4J_EXPORT NDArray operator-(const int& scalar, const NDArray& arr) { - if (arr.isS()) - throw std::runtime_error("NDArray::operator-: you can't use this method on String array!"); - - auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); - NDArray result(arr.getShapeInfo(), DataTypeUtils::pickPairwiseResultType(arr.dataType(), DataTypeUtils::fromT()), false, arr.getContext()); - - NDArray::registerSpecialUse({&result}, {&arr, &tmp}); - NativeOpExecutioner::execScalar(arr.getContext(), nd4j::scalar::ReverseSubtract, arr.getBuffer(), arr.getShapeInfo(), arr.getSpecialBuffer(), arr.getSpecialShapeInfo(), result.getBuffer(), result.getShapeInfo(), result.specialBuffer(), result.specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({&result}, {&arr, &tmp}); - - return result; -} - -//////////////////////////////////////////////////////////////////////// -ND4J_EXPORT NDArray operator/(const bfloat16& scalar, const NDArray& arr) { - if (arr.isS()) - throw std::runtime_error("NDArray::operator/: you can't use this method on String array!"); - if (arr.isB()) - throw std::runtime_error("NDArray::operator/: you can't divide scalar by bool array!"); - - auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); - NDArray result(arr.getShapeInfo(), DataTypeUtils::pickPairwiseResultType(arr.dataType(), DataTypeUtils::fromT()), false, arr.getContext()); - - NDArray::prepareSpecialUse({&result}, {&arr, &tmp}); - NativeOpExecutioner::execScalar(arr.getContext(), nd4j::scalar::ReverseDivide, arr.getBuffer(), arr.getShapeInfo(), arr.getSpecialBuffer(), arr.getSpecialShapeInfo(), result.getBuffer(), result.getShapeInfo(), result.specialBuffer(), result.specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({&result}, {&arr, &tmp}); - - return result; -} -//////////////////////////////////////////////////////////////////////// -ND4J_EXPORT NDArray operator/(const float16& scalar, const NDArray& arr) { - if (arr.isS()) - throw std::runtime_error("NDArray::operator/: you can't use this method on String array!"); - if (arr.isB()) - throw std::runtime_error("NDArray::operator/: you can't divide scalar by bool array!"); - - auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); - NDArray result(arr.getShapeInfo(), DataTypeUtils::pickPairwiseResultType(arr.dataType(), DataTypeUtils::fromT()), false, arr.getContext()); - - NDArray::prepareSpecialUse({&result}, {&arr, &tmp}); - NativeOpExecutioner::execScalar(arr.getContext(), nd4j::scalar::ReverseDivide, arr.getBuffer(), arr.getShapeInfo(), arr.getSpecialBuffer(), arr.getSpecialShapeInfo(), result.getBuffer(), result.getShapeInfo(), result.specialBuffer(), result.specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({&result}, {&arr, &tmp}); - - return result; -} - -//////////////////////////////////////////////////////////////////////// -ND4J_EXPORT NDArray operator/(const float& scalar, const NDArray & arr) { - if (arr.isS()) - throw std::runtime_error("NDArray::operator/: you can't use this method on String array!"); - if (arr.isB()) - throw std::runtime_error("NDArray::operator/: you can't divide scalar by bool array!"); - - auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); - NDArray result(arr.getShapeInfo(), DataTypeUtils::pickPairwiseResultType(arr.dataType(), DataTypeUtils::fromT()), false, arr.getContext()); - - NDArray::prepareSpecialUse({&result}, {&arr, &tmp}); - NativeOpExecutioner::execScalar(arr.getContext(), nd4j::scalar::ReverseDivide, arr.getBuffer(), arr.getShapeInfo(), arr.getSpecialBuffer(), arr.getSpecialShapeInfo(), result.getBuffer(), result.getShapeInfo(), result.specialBuffer(), result.specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({&result}, {&arr, &tmp}); - - return result; -} - -//////////////////////////////////////////////////////////////////////// -ND4J_EXPORT NDArray operator/(const double& scalar, const NDArray & arr) { - if (arr.isS()) - throw std::runtime_error("NDArray::operator/: you can't use this method on String array!"); - if (arr.isB()) - throw std::runtime_error("NDArray::operator/: you can't divide scalar by bool array!"); - - auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); - NDArray result(arr.getShapeInfo(), DataTypeUtils::pickPairwiseResultType(arr.dataType(), DataTypeUtils::fromT()), false, arr.getContext()); - - NDArray::prepareSpecialUse({&result}, {&arr, &tmp}); - NativeOpExecutioner::execScalar(arr.getContext(), nd4j::scalar::ReverseDivide, arr.getBuffer(), arr.getShapeInfo(), arr.getSpecialBuffer(), arr.getSpecialShapeInfo(), result.getBuffer(), result.getShapeInfo(), result.specialBuffer(), result.specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({&result}, {&arr, &tmp}); - - return result; -} - -//////////////////////////////////////////////////////////////////////// -ND4J_EXPORT NDArray operator/(const int& scalar, const NDArray & arr) { - if (arr.isS()) - throw std::runtime_error("NDArray::operator/: you can't use this method on String array!"); - if (arr.isB()) - throw std::runtime_error("NDArray::operator/: you can't divide scalar by bool array!"); - - auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); - NDArray result(arr.getShapeInfo(), DataTypeUtils::pickPairwiseResultType(arr.dataType(), DataTypeUtils::fromT()), false, arr.getContext()); - - NDArray::prepareSpecialUse({&result}, {&arr, &tmp}); - NativeOpExecutioner::execScalar(arr.getContext(), nd4j::scalar::ReverseDivide, arr.getBuffer(), arr.getShapeInfo(), arr.getSpecialBuffer(), arr.getSpecialShapeInfo(), result.getBuffer(), result.getShapeInfo(), result.specialBuffer(), result.specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({&result}, {&arr, &tmp}); - - return result; + target.assign(this); } //////////////////////////////////////////////////////////////////////// @@ -2133,11 +1768,11 @@ void NDArray::operator+=(const NDArray& other) { throw std::invalid_argument("NDArray::operator+=: the shapes of this and other arrays are not suitable for broadcast operation !"); if(shape::equalsTypesAndShapesSoft(getShapeInfo(), bShape)) { - this->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Add(), &other, this, false); + this->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Add(), other, *this, false); } else { NDArray result(bShape, true, getContext()); - this->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Add(), &other, &result, false); + this->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Add(), other, result, false); *this = std::move(result); // move assignment operator, zero cost copy } } @@ -2167,11 +1802,11 @@ void NDArray::operator-=(const NDArray& other) { throw std::invalid_argument("NDArray::operator-=: the shapes of this and other arrays are not suitable for broadcast operation !"); if(shape::equalsTypesAndShapesSoft(getShapeInfo(), bShape)) { - this->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Subtract(), &other, this, false); + this->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Subtract(), other, *this, false); } else { NDArray result(bShape, true, getContext()); - this->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Subtract(), &other, &result, false); + this->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Subtract(), other, result, false); *this = std::move(result); // move assignment operator, zero cost copy } } @@ -2200,11 +1835,11 @@ void NDArray::operator*=(const NDArray& other) { throw std::invalid_argument("NDArray::operator*=: the shapes of this and other arrays are not suitable for broadcast operation !"); if(shape::equalsTypesAndShapesSoft(_shapeInfo, bShape)) { - this->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Multiply(), &other, this, false); + this->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Multiply(), other, *this, false); } else { NDArray result(bShape, true, getContext()); - this->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Multiply(), &other, &result, false); + this->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Multiply(), other, result, false); *this = std::move(result); // move assignment operator, zero cost copy } } @@ -2237,15 +1872,16 @@ void NDArray::operator/=(const NDArray& other) { throw std::invalid_argument("NDArray::operator/=: the shapes of this and other arrays are not suitable for broadcast operation !"); if(shape::equalsTypesAndShapesSoft(_shapeInfo, bShape)) { - this->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Divide(), &other, this, false); + this->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Divide(), other, *this, false); } else { NDArray result(bShape, true, getContext()); - this->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Divide(), &other, &result, false); + this->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Divide(), other, result, false); *this = std::move(result); // move assignment operator, zero cost copy } } } + //////////////////////////////////////////////////////////////////////// template void NDArray::operator+=(const T value) { @@ -2335,77 +1971,9 @@ template ND4J_EXPORT void NDArray::operator/=(const int8_t scalar); template ND4J_EXPORT void NDArray::operator/=(const uint8_t scalar); template ND4J_EXPORT void NDArray::operator/=(const bool scalar); -//////////////////////////////////////////////////////////////////////// -// subtraction operator array - array -NDArray NDArray::operator-(const NDArray& other) const { - if (isS()) - throw std::runtime_error("NDArray::operator-: you can't use this method on String array!"); - - if (!Environment::getInstance()->isExperimentalBuild() && this->dataType() != other.dataType() && (this->dataType() != DataType::BOOL || other.dataType() != BOOL)) - throw nd4j::datatype_exception::build("NDArray operator-: Cannot subtract different types", this->dataType(), other.dataType()); - - if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) { - NDArray result(getShapeInfo(), DataTypeUtils::pickPairwiseResultType(getShapeInfo(), other.getShapeInfo()), false, getContext()); - - NDArray::prepareSpecialUse({&result}, {this, &other}); - NativeOpExecutioner::execPairwiseTransform(getContext(), nd4j::pairwise::Subtract, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), result.buffer(), result.getShapeInfo(), result.specialBuffer(), result.getSpecialShapeInfo(), nullptr); - NDArray::registerSpecialUse({&result}, {this, &other}); - - return result; - } - - return this->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Subtract(), other); -} - -//////////////////////////////////////////////////////////////////////// -// multiplication operator array*array -NDArray NDArray::operator*(const NDArray& other) const { - if (isS()) - throw std::runtime_error("NDArray::operator*: you can't use this method on String array!"); - if (!Environment::getInstance()->isExperimentalBuild() && this->dataType() != other.dataType() && (this->dataType() != DataType::BOOL || other.dataType() != BOOL)) - throw nd4j::datatype_exception::build("NDArray operator*: Cannot multiply different types", this->dataType(), other.dataType()); - - PointersManager pointersManager(getContext(), "operator *"); - if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) { - NDArray result(getShapeInfo(), DataTypeUtils::pickPairwiseResultType(getShapeInfo(), other.getShapeInfo()), false, this->getContext()); - - NDArray::prepareSpecialUse({&result}, {this, &other}); - - NativeOpExecutioner::execPairwiseTransform(getContext(), nd4j::pairwise::Multiply, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), result.buffer(), result.getShapeInfo(), result.specialBuffer(), result.getSpecialShapeInfo(), nullptr); - NDArray::registerSpecialUse({&result}, {this, &other}); - - return result; - } - - return this->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Multiply(), other); -} - -//////////////////////////////////////////////////////////////////////// -// division operator array/array -NDArray NDArray::operator/(const NDArray& other) const { - if (isS()) - throw std::runtime_error("NDArray::operator/: you can't use this method on String array!"); - if (other.isB()) - throw std::runtime_error("NDArray::operator/: you can't divide by bool array!"); - if (!Environment::getInstance()->isExperimentalBuild() && this->dataType() != other.dataType()) - throw nd4j::datatype_exception::build("NDArray operator/: Cannot divide different types", this->dataType(), other.dataType()); - - if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) { - NDArray result(getShapeInfo(), DataTypeUtils::pickPairwiseResultType(getShapeInfo(), other.getShapeInfo()), false, getContext()); - - NDArray::prepareSpecialUse({&result}, {this, &other}); - NativeOpExecutioner::execPairwiseTransform(getContext(), nd4j::pairwise::Divide, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), result.buffer(), result.getShapeInfo(), result.specialBuffer(), result.getSpecialShapeInfo(), nullptr); - NDArray::registerSpecialUse({&result}, {this, &other}); - - return result; - } - - return this->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Divide(), other); -} - //////////////////////////////////////////////////////////////////////// // negative operator, it makes all array elements = -elements -NDArray NDArray::operator-() const { +NDArray NDArray::operator-() const & { if (isS()) throw std::runtime_error("NDArray::negative-: you can't use this method on String array!"); @@ -2418,6 +1986,18 @@ NDArray NDArray::operator-() const { return result; } +//////////////////////////////////////////////////////////////////////// +NDArray NDArray::operator-() && { + if (isS()) + throw std::runtime_error("NDArray::negative-: you can't use this method on String array!"); + + NDArray::prepareSpecialUse({this}, {this}); + NativeOpExecutioner::execTransformSame(getContext(), nd4j::transform::Neg, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), nullptr, nullptr, nullptr); + NDArray::registerSpecialUse({this}, {this}); + + return std::move(*this); +} + //////////////////////////////////////////////////////////////////////// // mathematical multiplication of two arrays NDArray mmul(const NDArray& left, const NDArray& right) { @@ -2430,9 +2010,9 @@ NDArray mmul(const NDArray& left, const NDArray& right) { } //////////////////////////////////////////////////////////////////////// -void NDArray::tileToShape(const std::vector& shape, NDArray* target) { - if(target != nullptr) { - this->tile(*target); +void NDArray::tileToShape(const std::vector& shape, NDArray& target) { + if(&target != this) { + this->tile(target); return; } @@ -2457,7 +2037,7 @@ void NDArray::tileToShape(const std::vector& shape, NDArray* target) { } //////////////////////////////////////////////////////////////////////// -void NDArray::tileToShape(const std::initializer_list& shape, NDArray* target) { +void NDArray::tileToShape(const std::initializer_list& shape, NDArray& target) { tileToShape(std::vector(shape), target); } @@ -2496,152 +2076,143 @@ double NDArray::getTrace() const { return sum; } - //////////////////////////////////////////////////////////////////////// -NDArray NDArray::quantize(NDArray &array) { - return *(quantize(&array)); -} +NDArray NDArray::quantize(const NDArray& array) { -//////////////////////////////////////////////////////////////////////// -NDArray* NDArray::quantize(NDArray *array) { - - if(array->isR()) + if(!array.isR()) throw std::invalid_argument("NDArray::quantize: type of array should be from real space!"); - auto ws = array->getContext()->getWorkspace(); + auto ws = array.getContext()->getWorkspace(); - Nd4jLong* shapeInfo = ShapeBuilders::copyShapeInfo(array->getShapeInfo(), true, ws); + Nd4jLong* shapeInfo = ShapeBuilders::copyShapeInfo(array.getShapeInfo(), true, ws); ArrayOptions::setPropertyBit(shapeInfo, ARRAY_QUANTIZED); - std::shared_ptr buffer = std::make_shared(TypeCast::estimateQuantizedSize(array->lengthOf()), ArrayOptions::dataType(shapeInfo), ws); + std::shared_ptr buffer = std::make_shared(TypeCast::estimateQuantizedSize(array.lengthOf()), ArrayOptions::dataType(shapeInfo), ws); - auto result = new NDArray(buffer, ShapeDescriptor(shapeInfo), array->getContext()); + NDArray result(buffer, ShapeDescriptor(shapeInfo), array.getContext()); return result; } ////////////////////////////////////////////////////////////////////////// -void NDArray::applyTrueBroadcast(nd4j::BroadcastOpsTuple op, const NDArray* other, NDArray* target, const bool checkTargetShape, ExtraArguments *extraArgs) const { +void NDArray::applyTrueBroadcast(nd4j::BroadcastOpsTuple op, const NDArray& other, NDArray& target, const bool checkTargetShape, ExtraArguments *extraArgs) const { + if (isS()) throw std::runtime_error("NDArray::applyTrueBroadcast: you can't use this method on String array!"); - if(target == nullptr || other == nullptr) - throw std::runtime_error("NDArray::applyTrueBroadcast method: target or other = nullptr !"); - if(((op.s == scalar::Divide || op.s == scalar::FloorDiv || op.s == scalar::FloorMod) && other->isB()) || (op.s == scalar::ReverseDivide && this->isB())) + + if(((op.s == scalar::Divide || op.s == scalar::FloorDiv || op.s == scalar::FloorMod) && other.isB()) || (op.s == scalar::ReverseDivide && this->isB())) throw std::runtime_error("NDArray::applyTrueBroadcast method: you can't divide by bool array !"); - if (isEmpty() || other->isEmpty()) + if (isEmpty() || other.isEmpty()) return; if (lengthOf() == 1) { - target->assign(this); - target->applyPairwiseTransform(op.p, *other, extraArgs); + target.assign(this); + target.applyPairwiseTransform(op.p, other, extraArgs); return; } - if (other->lengthOf() == 1) { + if (other.lengthOf() == 1) { const_cast(this)->applyScalarArr(op.s, other, target, extraArgs); return; } if(checkTargetShape) { Nd4jLong* newShapeInfo = nullptr; - if(!ShapeUtils::evalBroadcastShapeInfo(*this, *other, true, newShapeInfo, getContext()->getWorkspace())) // the rank of target array must be equal to max->rankOf)() + if(!ShapeUtils::evalBroadcastShapeInfo(*this, other, true, newShapeInfo, getContext()->getWorkspace())) // the rank of target array must be equal to max->rankOf)() throw std::runtime_error("NDArray::applyTrueBroadcast method: the shapes of this and other arrays are not suitable for broadcast operation !"); - if(!shape::equalsTypesAndShapesSoft(target->getShapeInfo(), newShapeInfo)) + if(!shape::equalsTypesAndShapesSoft(target.getShapeInfo(), newShapeInfo)) throw std::runtime_error("NDArray::applyTrueBroadcast method: the shape or type of target array is wrong !"); } - if(target->isSameShape(this) || target->isSameShape(other)) { - const_cast(this)->applyBroadcast(op.b, ShapeUtils::getDimsWithSameShape(*this, *other), other, target, extraArgs); + if(target.isSameShape(this) || target.isSameShape(other)) { + const_cast(this)->applyBroadcast(op.b, ShapeUtils::getDimsWithSameShape(*this, other), other, target, extraArgs); return; } #ifdef __ND4J_EXPERIMENTAL__ - BUILD_PAIRWISE_SELECTOR(dataType(), other->dataType(), target->dataType(), helpers::TrueBroadcastHelper, ::exec(op.b, *this, *other, *target), LIBND4J_TYPES, LIBND4J_TYPES); + BUILD_PAIRWISE_SELECTOR(dataType(), other.dataType(), target.dataType(), helpers::TrueBroadcastHelper, ::exec(op.b, *this, other, target), LIBND4J_TYPES, LIBND4J_TYPES); #else - BUILD_SINGLE_SELECTOR_THRICE(dataType(), helpers::TrueBroadcastHelper, ::exec(op.b, *this, *other, *target), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR_THRICE(dataType(), helpers::TrueBroadcastHelper, ::exec(op.b, *this, other, target), LIBND4J_TYPES); #endif } ////////////////////////////////////////////////////////////////////////// -void NDArray::applyTrueBroadcast(nd4j::BroadcastBoolOpsTuple op, const NDArray* other, NDArray* target, const bool checkTargetShape, ExtraArguments *extraArgs) const { +void NDArray::applyTrueBroadcast(nd4j::BroadcastBoolOpsTuple op, const NDArray& other, NDArray& target, const bool checkTargetShape, ExtraArguments *extraArgs) const { if (isS()) throw std::runtime_error("NDArray::applyTrueBroadcast bool: you can't use this method on String array!"); - if(target == nullptr || other == nullptr) - throw std::runtime_error("NDArray::applyTrueBroadcast bool method: target or other = nullptr !"); - if (isEmpty() || other->isEmpty()) + if (isEmpty() || other.isEmpty()) return; if (lengthOf() == 1) { - NDArray temp(target->_shapeInfo, dataType(), false, getContext()); + NDArray temp(target._shapeInfo, dataType(), false, getContext()); temp.assign(this); temp.applyPairwiseTransform(op.p, other, target, extraArgs); return; } - if (other->lengthOf() == 1) { + if (other.lengthOf() == 1) { this->applyScalarArr(op.s, other, target, extraArgs); return; } if(checkTargetShape) { Nd4jLong* newShapeInfo = nullptr; - if(!ShapeUtils::evalBroadcastShapeInfo(*this, *other, true, newShapeInfo, getContext()->getWorkspace())) // the rank of target array must be equal to max->rankOf)() + if(!ShapeUtils::evalBroadcastShapeInfo(*this, other, true, newShapeInfo, getContext()->getWorkspace())) // the rank of target array must be equal to max->rankOf)() throw std::runtime_error("NDArray::applyTrueBroadcast method: the shapes of this and other arrays are not suitable for broadcast operation !"); - if(!shape::equalsSoft(target->_shapeInfo, newShapeInfo) || target->dataType() != DataType::BOOL) + if(!shape::equalsSoft(target._shapeInfo, newShapeInfo) || target.dataType() != DataType::BOOL) throw std::runtime_error("NDArray::applyTrueBroadcast bool method: the shape or type of target array is wrong !"); - if(dataType() != other->dataType()) + if(dataType() != other.dataType()) throw std::invalid_argument("NDArray::applyTrueBroadcast bool method: this and other arrays must have the same type !"); } - if(target->isSameShape(this) || target->isSameShape(other)) { - const_cast(this)->applyBroadcast(op.b, ShapeUtils::getDimsWithSameShape(*this, *other), other, target, extraArgs); + if(target.isSameShape(this) || target.isSameShape(other)) { + const_cast(this)->applyBroadcast(op.b, ShapeUtils::getDimsWithSameShape(*this, other), other, target, extraArgs); return; } - BUILD_DOUBLE_SELECTOR(dataType(), target->dataType(), helpers::TrueBroadcastBoolHelper, ::exec(op.b, *this, *other, *target), LIBND4J_TYPES, BOOL_TYPES); + BUILD_DOUBLE_SELECTOR(dataType(), target.dataType(), helpers::TrueBroadcastBoolHelper, ::exec(op.b, *this, other, target), LIBND4J_TYPES, BOOL_TYPES); } ////////////////////////////////////////////////////////////////////////// -void NDArray::applyTrueBroadcast(nd4j::BroadcastIntOpsTuple op, const NDArray* other, NDArray* target, const bool checkTargetShape, ExtraArguments *extraArgs) const { +void NDArray::applyTrueBroadcast(nd4j::BroadcastIntOpsTuple op, const NDArray& other, NDArray& target, const bool checkTargetShape, ExtraArguments *extraArgs) const { + if (isS()) throw std::runtime_error("NDArray::applyTrueBroadcast bool: you can't use this method on String array!"); - if(target == nullptr || other == nullptr) - throw std::runtime_error("NDArray::applyTrueBroadcast int method: target or other = nullptr !"); - if (isEmpty() || other->isEmpty()) + if (isEmpty() || other.isEmpty()) return; if (lengthOf() == 1) { - NDArray temp(target->_shapeInfo, dataType(), false, getContext()); + NDArray temp(target._shapeInfo, dataType(), false, getContext()); temp.assign(this); temp.applyPairwiseTransform(op.p, other, target, extraArgs); return; } - if (other->lengthOf() == 1) { + if (other.lengthOf() == 1) { this->applyScalarArr(op.s, other, target, extraArgs); return; } if(checkTargetShape) { Nd4jLong* newShapeInfo = nullptr; - if(!ShapeUtils::evalBroadcastShapeInfo(*this, *other, false, newShapeInfo, getContext()->getWorkspace())) // the rank of target array must be equal to max->rankOf)() + if(!ShapeUtils::evalBroadcastShapeInfo(*this, other, false, newShapeInfo, getContext()->getWorkspace())) // the rank of target array must be equal to max->rankOf)() throw std::runtime_error("NDArray::applyTrueBroadcast method: the shapes of this and other arrays are not suitable for broadcast operation !"); - if(!shape::equalsSoft(target->_shapeInfo, newShapeInfo) || target->dataType() != this->dataType()) + if(!shape::equalsSoft(target._shapeInfo, newShapeInfo) || target.dataType() != this->dataType()) throw std::runtime_error("NDArray::applyTrueBroadcast int method: the shape or type of target array is wrong !"); - if(dataType() != other->dataType()) + if(dataType() != other.dataType()) throw std::invalid_argument("NDArray::applyTrueBroadcast int method: this and other arrays must have the same type !"); } - if(target->isSameShape(this) || target->isSameShape(other)) { - const_cast(this)->applyBroadcast(op.b, ShapeUtils::getDimsWithSameShape(*this, *other), other, target, extraArgs); + if(target.isSameShape(this) || target.isSameShape(other)) { + const_cast(this)->applyBroadcast(op.b, ShapeUtils::getDimsWithSameShape(*this, other), other, target, extraArgs); return; } - BUILD_SINGLE_SELECTOR(dataType(), helpers::TrueBroadcastIntHelper, ::exec(op.b, *this, *other, *target), INTEGER_TYPES); + BUILD_SINGLE_SELECTOR(dataType(), helpers::TrueBroadcastIntHelper, ::exec(op.b, *this, other, target), INTEGER_TYPES); } ////////////////////////////////////////////////////////////////////////// -NDArray NDArray::applyTrueBroadcast(nd4j::BroadcastOpsTuple op, const NDArray& other, ExtraArguments *extraArgs) const { +NDArray NDArray::applyTrueBroadcast(nd4j::BroadcastOpsTuple op, const NDArray& other, ExtraArguments *extraArgs) const & { if (isEmpty() || other.isEmpty()) { if (isEmpty()) return NDArray(*this); @@ -2654,19 +2225,100 @@ NDArray NDArray::applyTrueBroadcast(nd4j::BroadcastOpsTuple op, const NDArray& o throw std::runtime_error("NDArray::applyTrueBroadcast method: the shapes of this and other arrays are not suitable for broadcast operation !"); NDArray result(newShapeInfo, true, getContext()); - this->applyTrueBroadcast(op, &other, &result, false, extraArgs); + this->applyTrueBroadcast(op, other, result, false, extraArgs); return result; } ////////////////////////////////////////////////////////////////////////// -void NDArray::applyBroadcast(nd4j::broadcast::Ops op, const std::vector& dimensions, const NDArray* other, NDArray* target, ExtraArguments* extraArgs) { +NDArray NDArray::applyTrueBroadcast(nd4j::BroadcastOpsTuple op, NDArray&& other, ExtraArguments *extraArgs) const & { + if (isEmpty() || other.isEmpty()) { + if (isEmpty()) + return NDArray(*this); + else + return NDArray(other); + } + + Nd4jLong* newShapeInfo = nullptr; + if(!ShapeUtils::evalBroadcastShapeInfo(*this, other, true, newShapeInfo, getContext()->getWorkspace())) // the rank of new array = max->rankOf)() + throw std::runtime_error("NDArray::applyTrueBroadcast method: the shapes of this and other arrays are not suitable for broadcast operation !"); + + if(!shape::shapeEquals(newShapeInfo, other.getShapeInfo())) { + + NDArray result(newShapeInfo, true, getContext()); + this->applyTrueBroadcast(op, other, result, false, extraArgs); + return std::move(result); + } + + this->applyTrueBroadcast(op, other, other, false, extraArgs); + return std::move(other); +} + +////////////////////////////////////////////////////////////////////////// +NDArray NDArray::applyTrueBroadcast(nd4j::BroadcastOpsTuple op, const NDArray& other, ExtraArguments *extraArgs) && { + if (isEmpty() || other.isEmpty()) { + if (isEmpty()) + return NDArray(*this); + else + return NDArray(other); + } + + Nd4jLong* newShapeInfo = nullptr; + if(!ShapeUtils::evalBroadcastShapeInfo(*this, other, true, newShapeInfo, getContext()->getWorkspace())) // the rank of new array = max->rankOf)() + throw std::runtime_error("NDArray::applyTrueBroadcast method: the shapes of this and other arrays are not suitable for broadcast operation !"); + + if(!shape::shapeEquals(newShapeInfo, getShapeInfo())) { + + NDArray result(newShapeInfo, true, getContext()); + this->applyTrueBroadcast(op, other, result, false, extraArgs); + return std::move(result); + } + + this->applyTrueBroadcast(op, other, *this, false, extraArgs); + return std::move(*this); +} + +////////////////////////////////////////////////////////////////////////// +NDArray NDArray::applyTrueBroadcast(nd4j::BroadcastOpsTuple op, NDArray&& other, ExtraArguments *extraArgs) && { + if (isEmpty() || other.isEmpty()) { + if (isEmpty()) + return NDArray(*this); + else + return NDArray(other); + } + + Nd4jLong* newShapeInfo = nullptr; + if(!ShapeUtils::evalBroadcastShapeInfo(*this, other, true, newShapeInfo, getContext()->getWorkspace())) // the rank of new array = max->rankOf)() + throw std::runtime_error("NDArray::applyTrueBroadcast method: the shapes of this and other arrays are not suitable for broadcast operation !"); + + const bool thisMove = shape::shapeEquals(newShapeInfo, getShapeInfo()); + const bool otherMove = shape::shapeEquals(newShapeInfo, other.getShapeInfo()); + + if(!thisMove && !otherMove) { + + NDArray result(newShapeInfo, true, getContext()); + this->applyTrueBroadcast(op, other, result, false, extraArgs); + return std::move(result); + } + + if(thisMove) { + this->applyTrueBroadcast(op, other, *this, false, extraArgs); + return std::move(*this); + } + + // otherMove + this->applyTrueBroadcast(op, other, other, false, extraArgs); + return std::move(other); +} + +////////////////////////////////////////////////////////////////////////// +void NDArray::applyBroadcast(nd4j::broadcast::Ops op, const std::vector& dimensions, const NDArray& other, NDArray& target, ExtraArguments* extraArgs) { if (isS()) throw std::runtime_error("NDArray::applyBroadcast: you can't use this method on String array!"); - if(((op == broadcast::Divide || op == broadcast::FloorDiv || op == broadcast::FloorMod) && other->isB()) || (op == broadcast::ReverseDivide && this->isB())) + if(((op == broadcast::Divide || op == broadcast::FloorDiv || op == broadcast::FloorMod) && other.isB()) || (op == broadcast::ReverseDivide && this->isB())) throw std::runtime_error("NDArray::applyBroadcast: you can't divide by array!"); - if(isEmpty() || other->isEmpty()) { - if(!target->isEmpty()) + if(isEmpty() || other.isEmpty()) { + if(!target.isEmpty()) throw std::runtime_error("NDArray::applyBroadcast method: when some of input arrays (or both) is empty, target array must be empty as well !"); return; } @@ -2674,28 +2326,26 @@ void NDArray::applyBroadcast(nd4j::broadcast::Ops op, const std::vector& di if (dimensions.size() == 0) return; - auto result = target == nullptr ? this : target; - - if (other->lengthOf() == lengthOf() && this->rankOf() == other->rankOf()) { - NDArray::prepareSpecialUse({result}, {this, other}); - NativeOpExecutioner::execPairwiseTransform(getContext(), fromBroadcastToPairwise(op), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({result}, {this, other}); + if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) { + NDArray::prepareSpecialUse({&target}, {this, &other}); + NativeOpExecutioner::execPairwiseTransform(getContext(), fromBroadcastToPairwise(op), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), nullptr); + NDArray::registerSpecialUse({&target}, {this, &other}); return; } NDArray *min(nullptr), *max(nullptr); - if((lengthOf() > other->lengthOf()) || (lengthOf() == other->lengthOf() && rankOf() >= other->rankOf())) { + if((lengthOf() > other.lengthOf()) || (lengthOf() == other.lengthOf() && rankOf() >= other.rankOf())) { max = this; - min = const_cast(other); + min = const_cast(&other); } else { - max = const_cast(other); + max = const_cast(&other); min = this; } - if(result->dataType() != DataTypeUtils::pickPairwiseResultType(shapeInfo(), other->getShapeInfo())) + if(target.dataType() != DataTypeUtils::pickPairwiseResultType(shapeInfo(), other.getShapeInfo())) throw std::invalid_argument("NDArray::applyBroadcast method: wrong type of target array !"); - if(!result->isSameShape(max)) + if(!target.isSameShape(max)) throw std::invalid_argument("NDArray::applyBroadcast method: max and target arrays must have the same shape !"); std::vector copy(dimensions); @@ -2708,22 +2358,22 @@ void NDArray::applyBroadcast(nd4j::broadcast::Ops op, const std::vector& di throw std::runtime_error("NDArray::applyBroadcast method: tad length mismatch !"); auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(max->shapeInfo(), copy); - auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(result->shapeInfo(), copy); + auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(target.shapeInfo(), copy); - NDArray::prepareSpecialUse({result}, {this, other}); + NDArray::prepareSpecialUse({&target}, {this, &other}); if(max == this) - NativeOpExecutioner::execBroadcast( getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo(), copy.data(), (int)copy.size(), packX.platformShapeInfo(), packX.platformOffsets(), packZ.platformShapeInfo(), packZ.platformOffsets()); + NativeOpExecutioner::execBroadcast( getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), copy.data(), (int)copy.size(), packX.platformShapeInfo(), packX.platformOffsets(), packZ.platformShapeInfo(), packZ.platformOffsets()); else - NativeOpExecutioner::execInverseBroadcast(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo(), copy.data(), (int)copy.size(), packX.platformShapeInfo(), packX.platformOffsets(), packZ.platformShapeInfo(), packZ.platformOffsets()); - registerSpecialUse({result}, {this, other}); + NativeOpExecutioner::execInverseBroadcast(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), copy.data(), (int)copy.size(), packX.platformShapeInfo(), packX.platformOffsets(), packZ.platformShapeInfo(), packZ.platformOffsets()); + registerSpecialUse({&target}, {this, &other}); } ////////////////////////////////////////////////////////////////////////// -void NDArray::applyBroadcast(nd4j::broadcast::BoolOps op, const std::vector& dimensions, const NDArray* other, NDArray* target, ExtraArguments* extraArgs) { +void NDArray::applyBroadcast(nd4j::broadcast::BoolOps op, const std::vector& dimensions, const NDArray& other, NDArray& target, ExtraArguments* extraArgs) { if (isS()) throw std::runtime_error("NDArray::applyBroadcast BoolOps: you can't use this method on String array!"); - if(isEmpty() || other->isEmpty()) { - if(!target->isEmpty()) + if(isEmpty() || other.isEmpty()) { + if(!target.isEmpty()) throw std::runtime_error("NDArray::applyBroadcast BoolOps: when some of input arrays (or both) is empty, target array must be empty as well !"); return; } @@ -2731,30 +2381,28 @@ void NDArray::applyBroadcast(nd4j::broadcast::BoolOps op, const std::vector if (dimensions.size() == 0) return; - auto result = target == nullptr ? this : target; - - if (other->lengthOf() == lengthOf() && this->rankOf() == other->rankOf()) { - NDArray::prepareSpecialUse({result}, {this, other}); - NativeOpExecutioner::execPairwiseBoolTransform(getContext(), fromBroadcastToPairwiseBool(op), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({result}, {this, other}); + if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) { + NDArray::prepareSpecialUse({&target}, {this, &other}); + NativeOpExecutioner::execPairwiseBoolTransform(getContext(), fromBroadcastToPairwiseBool(op), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), nullptr); + NDArray::registerSpecialUse({&target}, {this, &other}); return; } NDArray *min(nullptr), *max(nullptr); - if((lengthOf() > other->lengthOf()) || (lengthOf() == other->lengthOf() && rankOf() >= other->rankOf())) { + if((lengthOf() > other.lengthOf()) || (lengthOf() == other.lengthOf() && rankOf() >= other.rankOf())) { max = this; - min = const_cast(other); + min = const_cast(&other); } else { - max = const_cast(other); + max = const_cast(&other); min = this; } - if(result->dataType() != DataType::BOOL) + if(target.dataType() != DataType::BOOL) throw std::invalid_argument("NDArray::applyBroadcast bool method: type of target array must be BOOL!"); - if(!result->isSameShape(max)) + if(!target.isSameShape(max)) throw std::invalid_argument("NDArray::applyBroadcast bool method: max and target arrays must have the same shape !"); - if(_dataType != other->_dataType) + if(_dataType != other._dataType) throw std::invalid_argument("NDArray::applyBroadcast bool method: this and other arrays must have the same type !"); std::vector copy(dimensions); @@ -2767,24 +2415,24 @@ void NDArray::applyBroadcast(nd4j::broadcast::BoolOps op, const std::vector throw std::runtime_error("Tad length mismatch"); auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(max->shapeInfo(), copy); - auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(result->shapeInfo(), copy); + auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(target.shapeInfo(), copy); // TODO: eventually we want separate tads here - NDArray::prepareSpecialUse({result}, {this, other}); + NDArray::prepareSpecialUse({&target}, {this, &other}); if(max == this) - NativeOpExecutioner::execBroadcastBool( getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo(), nullptr, copy.data(), (int)copy.size(), packX.platformShapeInfo(), packX.platformOffsets(), packZ.platformShapeInfo(), packZ.platformOffsets()); + NativeOpExecutioner::execBroadcastBool( getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), nullptr, copy.data(), (int)copy.size(), packX.platformShapeInfo(), packX.platformOffsets(), packZ.platformShapeInfo(), packZ.platformOffsets()); else - NativeOpExecutioner::execInverseBroadcastBool(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo(), nullptr, copy.data(), (int)copy.size(), packX.platformShapeInfo(), packX.platformOffsets(), packZ.platformShapeInfo(), packZ.platformOffsets()); - registerSpecialUse({result}, {this, other}); + NativeOpExecutioner::execInverseBroadcastBool(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), nullptr, copy.data(), (int)copy.size(), packX.platformShapeInfo(), packX.platformOffsets(), packZ.platformShapeInfo(), packZ.platformOffsets()); + registerSpecialUse({&target}, {this, &other}); } ////////////////////////////////////////////////////////////////////////// -void NDArray::applyBroadcast(nd4j::broadcast::IntOps op, const std::vector& dimensions, const NDArray* other, NDArray* target, ExtraArguments* extraArgs) { +void NDArray::applyBroadcast(nd4j::broadcast::IntOps op, const std::vector& dimensions, const NDArray& other, NDArray& target, ExtraArguments* extraArgs) { if (!isZ()) throw std::runtime_error("NDArray::applyBroadcast IntOps: you can't use this method on non-Integer array!"); - if(isEmpty() || other->isEmpty()) { - if(!target->isEmpty()) + if(isEmpty() || other.isEmpty()) { + if(!target.isEmpty()) throw std::runtime_error("NDArray::applyBroadcast IntOps: when some of input arrays (or both) is empty, target array must be empty as well !"); return; } @@ -2792,30 +2440,28 @@ void NDArray::applyBroadcast(nd4j::broadcast::IntOps op, const std::vector& if (dimensions.empty()) return; - auto result = target == nullptr ? this : target; - - if (other->lengthOf() == lengthOf() && this->rankOf() == other->rankOf()) { - NDArray::prepareSpecialUse({result}, {this, other}); - NativeOpExecutioner::execPairwiseIntTransform(getContext(), fromBroadcastToPairwiseInt(op), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({result}, {this, other}); + if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) { + NDArray::prepareSpecialUse({&target}, {this, &other}); + NativeOpExecutioner::execPairwiseIntTransform(getContext(), fromBroadcastToPairwiseInt(op), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), nullptr); + NDArray::registerSpecialUse({&target}, {this, &other}); return; } NDArray *min(nullptr), *max(nullptr); - if((lengthOf() > other->lengthOf()) || (lengthOf() == other->lengthOf() && rankOf() >= other->rankOf())) { + if((lengthOf() > other.lengthOf()) || (lengthOf() == other.lengthOf() && rankOf() >= other.rankOf())) { max = this; - min = const_cast(other); + min = const_cast(&other); } else { - max = const_cast(other); + max = const_cast(&other); min = this; } - if(result->dataType() != dataType()) + if(target.dataType() != dataType()) throw std::invalid_argument("NDArray::applyBroadcast int method: type of target array must be the same as input!"); - if(!result->isSameShape(max)) + if(!target.isSameShape(max)) throw std::invalid_argument("NDArray::applyBroadcast int method: max and target arrays must have the same shape !"); - if(_dataType != other->_dataType) + if(_dataType != other._dataType) throw std::invalid_argument("NDArray::applyBroadcast int method: this and other arrays must have the same type !"); std::vector copy(dimensions); @@ -2828,76 +2474,23 @@ void NDArray::applyBroadcast(nd4j::broadcast::IntOps op, const std::vector& throw std::runtime_error("Tad length mismatch"); auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(max->shapeInfo(), copy); - auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(result->shapeInfo(), copy); + auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(target.shapeInfo(), copy); // TODO: eventually we want separate tads here - NDArray::prepareSpecialUse({result}, {this, other}); + NDArray::prepareSpecialUse({&target}, {this, &other}); if(max == this) - NativeOpExecutioner::execBroadcastInt( getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo(), copy.data(), (int)copy.size(), packX.platformShapeInfo(), packX.platformOffsets(), packZ.platformShapeInfo(), packZ.platformOffsets()); + NativeOpExecutioner::execBroadcastInt( getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), copy.data(), (int)copy.size(), packX.platformShapeInfo(), packX.platformOffsets(), packZ.platformShapeInfo(), packZ.platformOffsets()); else - NativeOpExecutioner::execInverseBroadcastInt(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo(), copy.data(), (int)copy.size(), packX.platformShapeInfo(), packX.platformOffsets(), packZ.platformShapeInfo(), packZ.platformOffsets()); - registerSpecialUse({result}, {this, other}); + NativeOpExecutioner::execInverseBroadcastInt(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), copy.data(), (int)copy.size(), packX.platformShapeInfo(), packX.platformOffsets(), packZ.platformShapeInfo(), packZ.platformOffsets()); + registerSpecialUse({&target}, {this, &other}); } ////////////////////////////////////////////////////////////////////////// -void NDArray::applyBroadcast(nd4j::broadcast::Ops op, const std::initializer_list dimensions, const NDArray* tadArray, NDArray* target, ExtraArguments* extraArgs) { +void NDArray::applyBroadcast(nd4j::broadcast::Ops op, const std::initializer_list dimensions, const NDArray& tadArray, NDArray& target, ExtraArguments* extraArgs) { std::vector vec(dimensions); applyBroadcast(op, vec, tadArray, target, extraArgs); } -////////////////////////////////////////////////////////////////////////// -NDArray* NDArray::applyTrueBroadcast(nd4j::BroadcastOpsTuple op, const NDArray* other, ExtraArguments *extraArgs) const { - return new NDArray(this->applyTrueBroadcast(op, *other, extraArgs)); -} - -////////////////////////////////////////////////////////////////////////// -// return array which is broadcasted from this and argument array -NDArray* NDArray::broadcast(const NDArray& other) { - // the orders must be the same - char order = ordering(); - if(order != other.ordering()) - throw std::runtime_error("NDArray::broadcast method: arrays have different orders!"); - - // recognize shapes with smaller and bigger rank - Nd4jLong* biggerShapeInfo = nullptr; - Nd4jLong* smallerShapeInfo = nullptr; - int smallerRank, biggerRank; - if (rankOf() > other.rankOf()) { - biggerShapeInfo = _shapeInfo; - biggerRank = shape::rank(_shapeInfo); - smallerShapeInfo = other._shapeInfo; - smallerRank = shape::rank(other._shapeInfo); - } - else { - biggerShapeInfo = other._shapeInfo; - biggerRank = shape::rank(other._shapeInfo); - smallerShapeInfo = _shapeInfo; - smallerRank = shape::rank(_shapeInfo); - } - - // check shapes on consistency - int diff = biggerRank - smallerRank; - for (int i = smallerRank; i<=1; --i) - if(biggerShapeInfo[diff+i] != smallerShapeInfo[i] && biggerShapeInfo[i] != 1 && smallerShapeInfo[i] != 1) - throw std::runtime_error("Broadcast method: arrays have incompatible shapes !"); - - // create and fill ret shapeInfo - Nd4jLong *shapeInfoNew; - ALLOCATE(shapeInfoNew, getContext()->getWorkspace(), shape::shapeInfoLength(biggerRank), Nd4jLong); - memcpy(shapeInfoNew, biggerShapeInfo, shape::shapeInfoByteLength(biggerRank)); - for (int i = smallerRank; i>=1; --i) - if(shapeInfoNew[diff+i] == 1 || smallerShapeInfo[i] == 1) - shapeInfoNew[diff+i] *= smallerShapeInfo[i]; - - ShapeUtils::updateStridesAndType(shapeInfoNew, DataTypeUtils::pickPairwiseResultType(dataType(), other.dataType()), order); - - auto ret = new NDArray(shapeInfoNew, true, getContext()); - - RELEASE(shapeInfoNew, getContext()->getWorkspace()); - - return ret; -} - //////////////////////////////////////////////////////////////////////// void* NDArray::operator new(size_t i) { if (nd4j::memory::MemoryRegistrator::getInstance()->hasWorkspaceAttached()) { @@ -3017,7 +2610,7 @@ bool NDArray::reshapei(const char order, const std::vector& cshape) { } else { NDArray temp(order, shape, dataType(), getContext()); - this->applyTransform(transform::Assign, &temp, nullptr); + this->applyTransform(transform::Assign, temp, nullptr); *this = std::move(temp); } @@ -3049,57 +2642,57 @@ void NDArray::templatedSet(void *buffer, const Nd4jLong xOfsset, nd4j::DataType BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT void NDArray::templatedSet, (void *buffer, const Nd4jLong xOfsset, nd4j::DataType dtype, const void *value), LIBND4J_TYPES); //////////////////////////////////////////////////////////////////////// -void NDArray::applyPairwiseTransform(nd4j::pairwise::Ops op, const NDArray* other, NDArray *target, ExtraArguments *extraParams) const{ +void NDArray::applyPairwiseTransform(nd4j::pairwise::Ops op, const NDArray& other, NDArray& target, ExtraArguments *extraParams) const{ if (isS()) throw std::runtime_error("NDArray::applyPairwiseTransform: you can't use this method on String array!"); - if (other->lengthOf() != target->lengthOf()) + if (other.lengthOf() != target.lengthOf()) throw std::invalid_argument("NDArray::applyPairwiseTransform method - lengths of arrays are mismatched"); - if (target->dataType() != this->dataType() && target->dataType() != other->dataType()) + if (target.dataType() != this->dataType() && target.dataType() != other.dataType()) throw std::invalid_argument("NDArray::applyPairwiseTransform method - type of target array must be the same as type of this or other array !"); - NDArray::prepareSpecialUse({target}, {this, other}); - NativeOpExecutioner::execPairwiseTransform(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target->dataType()) : nullptr); - NDArray::registerSpecialUse({target}, {this, other}); + NDArray::prepareSpecialUse({&target}, {this, &other}); + NativeOpExecutioner::execPairwiseTransform(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target.dataType()) : nullptr); + NDArray::registerSpecialUse({&target}, {this, &other}); if (extraParams != nullptr) synchronize("NDArray::applyPairwiseTransform"); } //////////////////////////////////////////////////////////////////////// -void NDArray::applyPairwiseTransform(nd4j::pairwise::BoolOps op, const NDArray *other, NDArray *target, ExtraArguments *extraParams) const{ +void NDArray::applyPairwiseTransform(nd4j::pairwise::BoolOps op, const NDArray& other, NDArray& target, ExtraArguments *extraParams) const{ if (isS()) throw std::runtime_error("NDArray::applyPairwiseTransform BoolOps: you can't use this method on String array!"); - if (other->lengthOf() != target->lengthOf()) + if (other.lengthOf() != target.lengthOf()) throw std::invalid_argument("NDArray::applyPairwiseTransform BoolOps method - lengths of arrays are mismatched"); - if (!target->isB()) + if (!target.isB()) throw std::invalid_argument("NDArray::applyPairwiseTransform BoolOps method - result must have bool type"); - if (dataType() != other->dataType()) + if (dataType() != other.dataType()) throw std::invalid_argument("NDArray::applyPairwiseTransform BoolOps method - this and other arrays must have the same type !"); - NDArray::prepareSpecialUse({target}, {this, other}); - NativeOpExecutioner::execPairwiseBoolTransform(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target->dataType()) : nullptr); - NDArray::registerSpecialUse({target}, {this, other}); + NDArray::prepareSpecialUse({&target}, {this, &other}); + NativeOpExecutioner::execPairwiseBoolTransform(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target.dataType()) : nullptr); + NDArray::registerSpecialUse({&target}, {this, &other}); } //////////////////////////////////////////////////////////////////////// - void NDArray::applyPairwiseTransform(nd4j::pairwise::IntOps op, const NDArray *other, NDArray *target, ExtraArguments *extraParams) const{ - if (isS()) - throw std::runtime_error("NDArray::applyPairwiseTransform IntOps: you can't use this method on String array!"); - if (other->lengthOf() != target->lengthOf()) - throw std::invalid_argument("NDArray::applyPairwiseTransform IntOps method - lengths of arrays are mismatched"); - if (!target->isZ()) - throw std::invalid_argument("NDArray::applyPairwiseTransform IntOps method - result must have bool type"); - if (dataType() != other->dataType()) - throw std::invalid_argument("NDArray::applyPairwiseTransform IntOps method - this and other arrays must have the same type !"); +void NDArray::applyPairwiseTransform(nd4j::pairwise::IntOps op, const NDArray& other, NDArray& target, ExtraArguments *extraParams) const{ + if (isS()) + throw std::runtime_error("NDArray::applyPairwiseTransform IntOps: you can't use this method on String array!"); + if (other.lengthOf() != target.lengthOf()) + throw std::invalid_argument("NDArray::applyPairwiseTransform IntOps method - lengths of arrays are mismatched"); + if (!target.isZ()) + throw std::invalid_argument("NDArray::applyPairwiseTransform IntOps method - result must have bool type"); + if (dataType() != other.dataType()) + throw std::invalid_argument("NDArray::applyPairwiseTransform IntOps method - this and other arrays must have the same type !"); - NDArray::prepareSpecialUse({target}, {this, other}); - NativeOpExecutioner::execPairwiseIntTransform(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target->dataType()) : nullptr); - NDArray::registerSpecialUse({target}, {this, other}); - } + NDArray::prepareSpecialUse({&target}, {this, &other}); + NativeOpExecutioner::execPairwiseIntTransform(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target.dataType()) : nullptr); + NDArray::registerSpecialUse({&target}, {this, &other}); +} ////////////////////////////////////////////////////////////////////////// void NDArray::applyPairwiseTransform(nd4j::pairwise::Ops op, const NDArray& other, ExtraArguments *extraParams) { - applyPairwiseTransform(op, &other, this, extraParams); + applyPairwiseTransform(op, other, *this, extraParams); } //////////////////////////////////////////////////////////////////////// @@ -3112,41 +2705,31 @@ void NDArray::templatedDoubleAssign(void *xBuffer, const Nd4jLong xOffset, const BUILD_DOUBLE_TEMPLATE(template ND4J_EXPORT void NDArray::templatedDoubleAssign, (void *xBuffer, const Nd4jLong xOffset, const void *yBuffer, const Nd4jLong yOffset) const, LIBND4J_TYPES, LIBND4J_TYPES); //////////////////////////////////////////////////////////////////////// -void NDArray::varianceAlongDimension(nd4j::variance::Ops op, NDArray *target, const bool biasCorrected, const std::vector& dimensions) const { +void NDArray::varianceAlongDimension(nd4j::variance::Ops op, NDArray& target, const bool biasCorrected, const std::vector& dimensions) const { if (isS()) throw std::runtime_error("NDArray::varianceAlongDimension: you can't use this method on String array!"); - if (!target->isR()) + if (!target.isR()) throw std::runtime_error("NDArray::varianceAlongDimension: target array must have FLOAT type"); - NDArray::prepareSpecialUse({target}, {this}); + NDArray::prepareSpecialUse({&target}, {this}); if(rankOf() == dimensions.size() || dimensions.empty()) - NativeOpExecutioner::execSummaryStatsScalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), nullptr, target->getBuffer(), target->getShapeInfo(), target->getSpecialBuffer(), target->getSpecialShapeInfo(), biasCorrected); + NativeOpExecutioner::execSummaryStatsScalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), nullptr, target.getBuffer(), target.getShapeInfo(), target.getSpecialBuffer(), target.getSpecialShapeInfo(), biasCorrected); else { std::vector copy(dimensions); auto pDims = nd4j::Environment::getInstance()->isCPU() ? copy.data() : nullptr; auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(this->getShapeInfo(), dimensions); - NativeOpExecutioner::execSummaryStats(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), nullptr, target->buffer(), target->shapeInfo(), target->getSpecialBuffer(), target->specialShapeInfo(), pDims, dimensions.size(), packX.platformShapeInfo(), packX.platformOffsets(), biasCorrected); + NativeOpExecutioner::execSummaryStats(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), nullptr, target.buffer(), target.shapeInfo(), target.getSpecialBuffer(), target.specialShapeInfo(), pDims, dimensions.size(), packX.platformShapeInfo(), packX.platformOffsets(), biasCorrected); synchronize("NDArray::varianceAlongDimension"); } - NDArray::registerSpecialUse({target}, {this}); + NDArray::registerSpecialUse({&target}, {this}); } //////////////////////////////////////////////////////////////////////// -NDArray* NDArray::varianceAlongDimension(nd4j::variance::Ops op, const bool biasCorrected, const std::initializer_list& dimensions) const { - return varianceAlongDimension(op, biasCorrected, std::vector(dimensions)); -} - -//////////////////////////////////////////////////////////////////////// -void NDArray::varianceAlongDimension(nd4j::variance::Ops op, NDArray *target, const bool biasCorrected, const std::initializer_list& dimensions) const { - varianceAlongDimension(op, target, biasCorrected, std::vector(dimensions)); -} - -//////////////////////////////////////////////////////////////////////// -NDArray NDArray::varianceAlongDims(nd4j::variance::Ops op, const bool biasCorrected, const std::vector& dimensions) const { +NDArray NDArray::varianceAlongDimension(nd4j::variance::Ops op, const bool biasCorrected, const std::vector& dimensions) const { if (isS()) throw std::runtime_error("NDArray::varianceAlongDimension: you can't use this method on String array!"); @@ -3157,87 +2740,27 @@ NDArray NDArray::varianceAlongDims(nd4j::variance::Ops op, const bool biasCorrec auto newShape = ShapeUtils::evalReduceShapeInfo('c', copy, *this, DataTypeUtils::pickFloatingType(dataType()), false, false, getContext()->getWorkspace()); NDArray result(newShape, true, getContext()); - this->varianceAlongDimension(op, &result, biasCorrected, dimensions); + this->varianceAlongDimension(op, result, biasCorrected, dimensions); return result; } //////////////////////////////////////////////////////////////////////// -NDArray* NDArray::varianceAlongDimension(nd4j::variance::Ops op, const bool biasCorrected, const std::vector& dimensions) const { - - return new NDArray(this->varianceAlongDims(op, biasCorrected, dimensions)); +NDArray NDArray::varianceAlongDimension(nd4j::variance::Ops op, const bool biasCorrected, const std::initializer_list& dimensions) const { + return varianceAlongDimension(op, biasCorrected, std::vector(dimensions)); } -//////////////////////////////////////////////////////////////////// -// This method assigns values of given NDArray to this one -void NDArray::assign(const NDArray& other, bool allowParallelism) { - - if (this == &other) - return; - - if (other.isEmpty()) { - if (!isEmpty()) { - ArrayOptions::setPropertyBit(shapeInfo(), ARRAY_EMPTY); - syncShape(); - _buffer = std::make_shared(); - _offset = 0; - } - return; - } - - if(isEmpty()) { - *this = other; - return; - } - - if (other.lengthOf() == 1) { - - if(lengthOf() == 1) { - NDArray::preparePrimaryUse({this}, {&other}); - BUILD_DOUBLE_SELECTOR(dataType(), other.dataType(), templatedDoubleAssign, (buffer(), 0, other.getBuffer(), 0), LIBND4J_TYPES, LIBND4J_TYPES); - NDArray::registerPrimaryUse({this}, {&other}); - this->syncToDevice(); - } - else { - if (dataType() != other.dataType()) { - auto tmp = other.cast(dataType()); - NDArray::prepareSpecialUse({this}, {tmp}); - NativeOpExecutioner::execScalar(getContext(), scalar::CopyPws, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), tmp->getBuffer(), tmp->getShapeInfo(), tmp->getSpecialBuffer(), tmp->getSpecialShapeInfo(), nullptr, allowParallelism); - NDArray::registerSpecialUse({this}, {}); - delete tmp; - } - else { - NDArray::prepareSpecialUse({this}, {&other}); - NativeOpExecutioner::execScalar(getContext(), scalar::CopyPws, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), nullptr, allowParallelism); - NDArray::registerSpecialUse({this}, {&other}); - } - } - } - else { - if (other.lengthOf() != lengthOf()) { - auto shapeThis = ShapeUtils::shapeAsString(this); - auto shapeThat = ShapeUtils::shapeAsString(&other); - nd4j_printf("Can't assign new value to the array: this shape %s; other shape: %s\n", shapeThis.c_str(), shapeThat.c_str()); - throw std::runtime_error("NDArray::assign: lengths of arrays are mismatched"); - } - - // memcpy is allowed only for same order && same ews (being equal to 1) - if (ordering() == other.ordering() && dataType() == other.dataType() && ews() == 1 && other.ews() == 1) - copyBuffersContinuouslyFrom(other, other.lengthOf() * other.sizeOfT()); - else { - NDArray::prepareSpecialUse({this}, {&other}); - NativeOpExecutioner::execTransformAny(getContext(), transform::Assign, other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, nullptr, nullptr, allowParallelism); - NDArray::registerSpecialUse({this}, {&other}); - } - } +//////////////////////////////////////////////////////////////////////// +void NDArray::varianceAlongDimension(nd4j::variance::Ops op, NDArray &target, const bool biasCorrected, const std::initializer_list& dimensions) const { + varianceAlongDimension(op, target, biasCorrected, std::vector(dimensions)); } //////////////////////////////////////////////////////////////////////// // This method returns new copy of this NDArray, optionally in different order -NDArray* NDArray::dup(const char newOrder) const { +NDArray NDArray::dup(const char newOrder) const { if (isEmpty()) - return NDArrayFactory::empty_(dataType(), getContext()); + return NDArrayFactory::empty(dataType(), getContext()); char order = newOrder == 'a' ? ordering() : newOrder; @@ -3248,12 +2771,12 @@ NDArray* NDArray::dup(const char newOrder) const { for (int e = 0; e < lengthOf(); e++) strings[e] = this->e(e); - auto result = NDArrayFactory::string_(order, getShapeAsVector(), strings, getContext()); + auto result = NDArrayFactory::string(order, getShapeAsVector(), strings, getContext()); return result; } - auto result = new NDArray(order, isScalar() ? std::vector({0}) : getShapeAsVector(), dataType(), getContext()); - result->assign(*this); + NDArray result(order, isScalar() ? std::vector({0}) : getShapeAsVector(), dataType(), getContext()); + result.assign(*this); return result; } @@ -3432,87 +2955,72 @@ NDArray NDArray::e(const Nd4jLong i) const { ////////////////////////////////////////////////////////////////////////// // perform array transformation -void NDArray::applyTransform(nd4j::transform::FloatOps op, NDArray *target, ExtraArguments *extraParams) { +void NDArray::applyTransform(nd4j::transform::FloatOps op, NDArray& target, ExtraArguments *extraParams) { if (isS()) throw std::runtime_error("NDArray::applyTransform FloatOps: you can't use this method on String array!"); - if (target == nullptr) - target = this; - - if (!target->isR()) + if (!target.isR()) throw std::runtime_error("NDArray::applyTransform FloatOps: target array must have one of FLOAT types"); - NDArray::prepareSpecialUse({target}, {this}); - NativeOpExecutioner::execTransformFloat(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target->dataType()) : nullptr, nullptr, nullptr); - NDArray::registerSpecialUse({target}, {this}); + NDArray::prepareSpecialUse({&target}, {this}); + NativeOpExecutioner::execTransformFloat(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target.dataType()) : nullptr, nullptr, nullptr); + NDArray::registerSpecialUse({&target}, {this}); } //////////////////////////////////////////////////////////////////////// -void NDArray::applyTransform(nd4j::transform::AnyOps op, NDArray *target, ExtraArguments *extraParams) { +void NDArray::applyTransform(nd4j::transform::AnyOps op, NDArray& target, ExtraArguments *extraParams) { if (isS()) throw std::runtime_error("NDArray::applyTransform AnyOps: you can't use this method on String array!"); - if (target == nullptr) - target = this; - - NDArray::prepareSpecialUse({target}, {this}); - NativeOpExecutioner::execTransformAny(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target->dataType()) : nullptr, nullptr, nullptr); - NDArray::registerSpecialUse({target}, {this}); + NDArray::prepareSpecialUse({&target}, {this}); + NativeOpExecutioner::execTransformAny(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target.dataType()) : nullptr, nullptr, nullptr); + NDArray::registerSpecialUse({&target}, {this}); } //////////////////////////////////////////////////////////////////////// -void NDArray::applyTransform(nd4j::transform::SameOps op, NDArray *target, ExtraArguments *extraParams) { +void NDArray::applyTransform(nd4j::transform::SameOps op, NDArray& target, ExtraArguments *extraParams) { if (isS()) throw std::runtime_error("NDArray::applyTransform SameOps: you can't use this method on String array!"); - if (target == nullptr) - target = this; - - if (target->dataType() != dataType()) + if (target.dataType() != dataType()) throw std::runtime_error("NDArray::applyTransform SameOps: target array must have the same data type as original array"); - NDArray::prepareSpecialUse({target}, {this}); - NativeOpExecutioner::execTransformSame(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target->dataType()) : nullptr, nullptr, nullptr); - NDArray::registerSpecialUse({target}, {this}); + NDArray::prepareSpecialUse({&target}, {this}); + NativeOpExecutioner::execTransformSame(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target.dataType()) : nullptr, nullptr, nullptr); + NDArray::registerSpecialUse({&target}, {this}); } //////////////////////////////////////////////////////////////////////// -void NDArray::applyTransform(nd4j::transform::StrictOps op, NDArray *target, ExtraArguments *extraParams) { +void NDArray::applyTransform(nd4j::transform::StrictOps op, NDArray& target, ExtraArguments *extraParams) { if (isS()) throw std::runtime_error("NDArray::applyTransform StrictOps: you can't use this method on String array!"); - if (target == nullptr) - target = this; - - if (!this->isR() || !target->isR() || (this->dataType() != target->dataType())) + if (!this->isR() || !target.isR() || (this->dataType() != target.dataType())) throw std::runtime_error("NDArray::applyTransform StrictOps: both Source and Target array must have same FLOAT type !"); - NDArray::prepareSpecialUse({target}, {this}); - NativeOpExecutioner::execTransformStrict(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target->dataType()) : nullptr, nullptr, nullptr); - NDArray::registerSpecialUse({target}, {this}); + NDArray::prepareSpecialUse({&target}, {this}); + NativeOpExecutioner::execTransformStrict(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target.dataType()) : nullptr, nullptr, nullptr); + NDArray::registerSpecialUse({&target}, {this}); } //////////////////////////////////////////////////////////////////////// -void NDArray::applyTransform(nd4j::transform::BoolOps op, NDArray *target, ExtraArguments *extraParams) { +void NDArray::applyTransform(nd4j::transform::BoolOps op, NDArray& target, ExtraArguments *extraParams) { if (isS()) throw std::runtime_error("NDArray::applyTransform BoolOps: you can't use this method on String array!"); - if (target == nullptr) - target = this; - - if (!target->isB()) + if (!target.isB()) throw std::runtime_error("NDArray::applyTransform BoolOps: target array must have one of BOOL types"); - NDArray::prepareSpecialUse({target}, {this}); - NativeOpExecutioner::execTransformBool(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target->dataType()) : nullptr, nullptr, nullptr); - NDArray::registerSpecialUse({target}, {this}); + NDArray::prepareSpecialUse({&target}, {this}); + NativeOpExecutioner::execTransformBool(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target.dataType()) : nullptr, nullptr, nullptr); + NDArray::registerSpecialUse({&target}, {this}); } //////////////////////////////////////////////////////////////////////// -NDArray NDArray::transform(nd4j::transform::FloatOps op, void *extraParams) const { +NDArray NDArray::transform(nd4j::transform::FloatOps op, void *extraParams) const & { if (isS()) throw std::runtime_error("NDArray::transform FloatOps: you can't use this method on String array!"); @@ -3526,7 +3034,19 @@ NDArray NDArray::transform(nd4j::transform::FloatOps op, void *extraParams) cons } //////////////////////////////////////////////////////////////////////// -NDArray NDArray::transform(nd4j::transform::SameOps op, void *extraParams) const { +NDArray NDArray::transform(nd4j::transform::FloatOps op, void *extraParams) && { + if (isS()) + throw std::runtime_error("NDArray::transform SameOps: you can't use this method on String array!"); + + NDArray::prepareSpecialUse({this}, {this}); + NativeOpExecutioner::execTransformFloat(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), extraParams, nullptr, nullptr); + NDArray::registerSpecialUse({this}, {this}); + + return std::move(*this); +} + +//////////////////////////////////////////////////////////////////////// +NDArray NDArray::transform(nd4j::transform::SameOps op, void *extraParams) const & { if (isS()) throw std::runtime_error("NDArray::transform SameOps: you can't use this method on String array!"); @@ -3540,7 +3060,19 @@ NDArray NDArray::transform(nd4j::transform::SameOps op, void *extraParams) const } //////////////////////////////////////////////////////////////////////// -NDArray NDArray::transform(nd4j::transform::StrictOps op, void *extraParams) const { +NDArray NDArray::transform(nd4j::transform::SameOps op, void *extraParams) && { + if (isS()) + throw std::runtime_error("NDArray::transform SameOps: you can't use this method on String array!"); + + NDArray::prepareSpecialUse({this}, {this}); + NativeOpExecutioner::execTransformSame(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), extraParams, nullptr, nullptr); + NDArray::registerSpecialUse({this}, {this}); + + return std::move(*this); +} + +//////////////////////////////////////////////////////////////////////// +NDArray NDArray::transform(nd4j::transform::StrictOps op, void *extraParams) const & { if (!this->isR()) throw std::runtime_error("Source array must have one of FLOAT types"); @@ -3554,7 +3086,19 @@ NDArray NDArray::transform(nd4j::transform::StrictOps op, void *extraParams) con } //////////////////////////////////////////////////////////////////////// -NDArray NDArray::transform(nd4j::transform::BoolOps op, void *extraParams) const { +NDArray NDArray::transform(nd4j::transform::StrictOps op, void *extraParams) && { + if (!this->isR()) + throw std::runtime_error("Source array must have one of FLOAT types"); + + NDArray::prepareSpecialUse({this}, {this}); + NativeOpExecutioner::execTransformStrict(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), extraParams, nullptr, nullptr); + NDArray::registerSpecialUse({this}, {this}); + + return std::move(*this); +} + +//////////////////////////////////////////////////////////////////////// +NDArray NDArray::transform(nd4j::transform::BoolOps op, void *extraParams) const & { if (isS()) throw std::runtime_error("NDArray::transform BoolOps: you can't use this method on String array!"); @@ -3567,151 +3111,159 @@ NDArray NDArray::transform(nd4j::transform::BoolOps op, void *extraParams) const return result; } +//////////////////////////////////////////////////////////////////////// +NDArray NDArray::transform(nd4j::transform::BoolOps op, void *extraParams) && { + if (isS()) + throw std::runtime_error("NDArray::transform BoolOps: you can't use this method on String array!"); + + NDArray::prepareSpecialUse({this}, {this}); + NativeOpExecutioner::execTransformBool(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), extraParams, nullptr, nullptr); + NDArray::registerSpecialUse({this}, {this}); + + return std::move(*this); +} + ////////////////////////////////////////////////////////////////////////// -void NDArray::applyScalarArr(nd4j::scalar::Ops op, const NDArray* scalar, NDArray* target, ExtraArguments *extraParams) { +void NDArray::applyScalarArr(nd4j::scalar::Ops op, const NDArray& scalar, NDArray& target, ExtraArguments *extraParams) { if (isS()) throw std::runtime_error("NDArray::applyScalarArr: you can't use this method on String array!"); - if (scalar->lengthOf() != 1) + if (scalar.lengthOf() != 1) throw std::invalid_argument("NDArray::applyScalarArr method: operand is not a scalar!"); - if(target == nullptr) - target = this; - if(target->dataType() != DataTypeUtils::pickPairwiseResultType(shapeInfo(), scalar->getShapeInfo()) && !(target->dataType() == dataType() || target->dataType() == scalar->dataType())) + + if(target.dataType() != DataTypeUtils::pickPairwiseResultType(shapeInfo(), scalar.getShapeInfo()) && !(target.dataType() == dataType() || target.dataType() == scalar.dataType())) throw std::invalid_argument("NDArray::applyScalarArr method: wrong type of target array!"); - NDArray::prepareSpecialUse({target}, {this, scalar}); - NativeOpExecutioner::execScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo(), scalar->getBuffer(), scalar->getShapeInfo(), scalar->getSpecialBuffer(), scalar->getSpecialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target->dataType()): nullptr); - NDArray::registerSpecialUse({target}, {this, scalar}); + NDArray::prepareSpecialUse({&target}, {this, &scalar}); + NativeOpExecutioner::execScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), scalar.getBuffer(), scalar.getShapeInfo(), scalar.getSpecialBuffer(), scalar.getSpecialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target.dataType()): nullptr); + NDArray::registerSpecialUse({&target}, {this, &scalar}); } -//////////////////////////////////////////////////////////////////////// -template -void NDArray::applyScalar(nd4j::scalar::Ops op, const T scalar, NDArray *target, ExtraArguments *extraParams) { - - auto scalarArr = NDArrayFactory::create(dataType(), scalar, this->getContext()); - applyScalarArr(op, &scalarArr, target, extraParams); -} - -template <> ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::Ops op, const NDArray* scalar, NDArray *target, ExtraArguments *extraParams) { throw std::runtime_error("NDArray::applyScalar method: do not use me!");} -template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::Ops op, const double scalar, NDArray *target, ExtraArguments *extraParams); -template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::Ops op, const float scalar, NDArray *target, ExtraArguments *extraParams); -template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::Ops op, const float16 scalar, NDArray *target, ExtraArguments *extraParams); -template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::Ops op, const bfloat16 scalar, NDArray *target, ExtraArguments *extraParams); -template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::Ops op, const Nd4jLong scalar, NDArray *target, ExtraArguments *extraParams); -template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::Ops op, const int scalar, NDArray *target, ExtraArguments *extraParams); -template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::Ops op, const int16_t scalar, NDArray *target, ExtraArguments *extraParams); -template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::Ops op, const int8_t scalar, NDArray *target, ExtraArguments *extraParams); -template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::Ops op, const uint8_t scalar, NDArray *target, ExtraArguments *extraParams); -template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::Ops op, const bool scalar, NDArray *target, ExtraArguments *extraParams); - ////////////////////////////////////////////////////////////////////////// -void NDArray::applyScalarArr(nd4j::scalar::BoolOps op, const NDArray* scalar, NDArray *target, ExtraArguments *extraParams) const { +void NDArray::applyScalarArr(nd4j::scalar::BoolOps op, const NDArray& scalar, NDArray &target, ExtraArguments *extraParams) const { if (isS()) throw std::runtime_error("NDArray::applyScalarArr BoolOps: you can't use this method on String array!"); - if (target == nullptr || !target->isB()) - throw std::invalid_argument("NDArray::applyScalarArr bool method: target is nullptr or has not bool type!"); - if (dataType() != scalar->dataType()) { - nd4j_printf("NDArray::applyScalarArr BoolOps: this dtype: [%i]; scalar dtype: [%i]\n", this->dataType(), scalar->dataType()); + if (!target.isB()) + throw std::invalid_argument("NDArray::applyScalarArr bool method: target has not bool type!"); + if (dataType() != scalar.dataType()) { + nd4j_printf("NDArray::applyScalarArr BoolOps: this dtype: [%i]; scalar dtype: [%i]\n", this->dataType(), scalar.dataType()); throw std::invalid_argument("NDArray::applyScalarArr bool method: this and scalar arrays must have the same type!"); } - NDArray::prepareSpecialUse({target}, {this, scalar}); - NativeOpExecutioner::execScalarBool(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo(), scalar->getBuffer(), scalar->getShapeInfo(), scalar->getSpecialBuffer(), scalar->getSpecialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target->dataType()): nullptr); - NDArray::registerSpecialUse({target}, {this, scalar}); + NDArray::prepareSpecialUse({&target}, {this, &scalar}); + NativeOpExecutioner::execScalarBool(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), scalar.getBuffer(), scalar.getShapeInfo(), scalar.getSpecialBuffer(), scalar.getSpecialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target.dataType()): nullptr); + NDArray::registerSpecialUse({&target}, {this, &scalar}); +} + +////////////////////////////////////////////////////////////////////////// +void NDArray::applyScalarArr(nd4j::scalar::IntOps op, const NDArray& scalar, NDArray &target, ExtraArguments *extraParams) const { + if (isS()) + throw std::runtime_error("NDArray::applyScalarArr IntOps: you can't use this method on String array!"); + + if (target.dataType() != this->dataType()) + throw std::invalid_argument("NDArray::applyScalarArr int method: target has not bool type!"); + if (dataType() != scalar.dataType()) { + nd4j_printf("NDArray::applyScalarArr IntOps: this dtype: [%i]; scalar dtype: [%i]\n", this->dataType(), scalar.dataType()); + throw std::invalid_argument("NDArray::applyScalarArr int method: this and scalar arrays must have the same type!"); + } + + NDArray::prepareSpecialUse({&target}, {this, &scalar}); + NativeOpExecutioner::execScalarInt(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), scalar.getBuffer(), scalar.getShapeInfo(), scalar.getSpecialBuffer(), scalar.getSpecialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target.dataType()): nullptr); + NDArray::registerSpecialUse({&target}, {this, &scalar}); } //////////////////////////////////////////////////////////////////////// template -void NDArray::applyScalar(nd4j::scalar::BoolOps op, const T scalar, NDArray *target, ExtraArguments *extraParams) const { +void NDArray::applyScalar(nd4j::scalar::IntOps op, const T scalar, NDArray& target, ExtraArguments *extraParams) const { - NDArray scalarArr = NDArrayFactory::create(scalar, getContext()); - applyScalarArr(op, &scalarArr, target, extraParams); + NDArray scalarArr = NDArrayFactory::create(this->dataType(), scalar, getContext()); + applyScalarArr(op, scalarArr, target, extraParams); } -template <> ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::BoolOps op, const NDArray* scalar, NDArray *target, ExtraArguments *extraParams) const { throw std::runtime_error("NDArray::applyScalar method: do not use me!");} -template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::BoolOps op, const double scalar, NDArray *target, ExtraArguments *extraParams) const; -template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::BoolOps op, const float scalar, NDArray *target, ExtraArguments *extraParams) const; -template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::BoolOps op, const float16 scalar, NDArray *target, ExtraArguments *extraParams) const; -template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::BoolOps op, const bfloat16 scalar, NDArray *target, ExtraArguments *extraParams) const; -template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::BoolOps op, const Nd4jLong scalar, NDArray *target, ExtraArguments *extraParams) const; -template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::BoolOps op, const int scalar, NDArray *target, ExtraArguments *extraParams) const; -template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::BoolOps op, const int16_t scalar, NDArray *target, ExtraArguments *extraParams) const; -template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::BoolOps op, const int8_t scalar, NDArray *target, ExtraArguments *extraParams) const; -template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::BoolOps op, const uint8_t scalar, NDArray *target, ExtraArguments *extraParams) const; -template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::BoolOps op, const bool scalar, NDArray *target, ExtraArguments *extraParams) const; - - -////////////////////////////////////////////////////////////////////////// - void NDArray::applyScalarArr(nd4j::scalar::IntOps op, const NDArray* scalar, NDArray *target, ExtraArguments *extraParams) const { - if (isS()) - throw std::runtime_error("NDArray::applyScalarArr IntOps: you can't use this method on String array!"); - - if (target == nullptr || target->dataType() != this->dataType()) - throw std::invalid_argument("NDArray::applyScalarArr int method: target is nullptr or has not bool type!"); - if (dataType() != scalar->dataType()) { - nd4j_printf("NDArray::applyScalarArr IntOps: this dtype: [%i]; scalar dtype: [%i]\n", this->dataType(), scalar->dataType()); - throw std::invalid_argument("NDArray::applyScalarArr int method: this and scalar arrays must have the same type!"); - } - - NDArray::prepareSpecialUse({target}, {this, scalar}); - NativeOpExecutioner::execScalarInt(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo(), scalar->getBuffer(), scalar->getShapeInfo(), scalar->getSpecialBuffer(), scalar->getSpecialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target->dataType()): nullptr); - NDArray::registerSpecialUse({target}, {this, scalar}); - } +template <> ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::IntOps op, const NDArray& scalar, NDArray &target, ExtraArguments *extraParams) const { throw std::runtime_error("NDArray::applyScalar method: do not use me!");} +template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::IntOps op, const double scalar, NDArray &target, ExtraArguments *extraParams) const; +template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::IntOps op, const float scalar, NDArray &target, ExtraArguments *extraParams) const; +template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::IntOps op, const float16 scalar, NDArray &target, ExtraArguments *extraParams) const; +template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::IntOps op, const bfloat16 scalar, NDArray &target, ExtraArguments *extraParams) const; +template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::IntOps op, const Nd4jLong scalar, NDArray &target, ExtraArguments *extraParams) const; +template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::IntOps op, const int scalar, NDArray &target, ExtraArguments *extraParams) const; +template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::IntOps op, const int16_t scalar, NDArray &target, ExtraArguments *extraParams) const; +template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::IntOps op, const int8_t scalar, NDArray &target, ExtraArguments *extraParams) const; +template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::IntOps op, const uint8_t scalar, NDArray &target, ExtraArguments *extraParams) const; +template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::IntOps op, const bool scalar, NDArray &target, ExtraArguments *extraParams) const; //////////////////////////////////////////////////////////////////////// - template - void NDArray::applyScalar(nd4j::scalar::IntOps op, const T scalar, NDArray *target, ExtraArguments *extraParams) const { - - NDArray scalarArr = NDArrayFactory::create(this->dataType(), scalar, getContext()); - applyScalarArr(op, &scalarArr, target, extraParams); - } - - template <> ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::IntOps op, const NDArray* scalar, NDArray *target, ExtraArguments *extraParams) const { throw std::runtime_error("NDArray::applyScalar method: do not use me!");} - template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::IntOps op, const double scalar, NDArray *target, ExtraArguments *extraParams) const; - template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::IntOps op, const float scalar, NDArray *target, ExtraArguments *extraParams) const; - template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::IntOps op, const float16 scalar, NDArray *target, ExtraArguments *extraParams) const; - template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::IntOps op, const bfloat16 scalar, NDArray *target, ExtraArguments *extraParams) const; - template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::IntOps op, const Nd4jLong scalar, NDArray *target, ExtraArguments *extraParams) const; - template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::IntOps op, const int scalar, NDArray *target, ExtraArguments *extraParams) const; - template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::IntOps op, const int16_t scalar, NDArray *target, ExtraArguments *extraParams) const; - template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::IntOps op, const int8_t scalar, NDArray *target, ExtraArguments *extraParams) const; - template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::IntOps op, const uint8_t scalar, NDArray *target, ExtraArguments *extraParams) const; - template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::IntOps op, const bool scalar, NDArray *target, ExtraArguments *extraParams) const; +template +void NDArray::applyScalar(nd4j::scalar::Ops op, const T scalar, NDArray& target, ExtraArguments *extraParams) { + auto scalarArr = NDArrayFactory::create(dataType(), scalar, this->getContext()); + applyScalarArr(op, scalarArr, target, extraParams); +} +template <> ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::Ops op, const NDArray& scalar, NDArray &target, ExtraArguments *extraParams) { throw std::runtime_error("NDArray::applyScalar method: do not use me!");} +template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::Ops op, const double scalar, NDArray &target, ExtraArguments *extraParams); +template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::Ops op, const float scalar, NDArray &target, ExtraArguments *extraParams); +template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::Ops op, const float16 scalar, NDArray &target, ExtraArguments *extraParams); +template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::Ops op, const bfloat16 scalar, NDArray &target, ExtraArguments *extraParams); +template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::Ops op, const Nd4jLong scalar, NDArray &target, ExtraArguments *extraParams); +template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::Ops op, const int scalar, NDArray &target, ExtraArguments *extraParams); +template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::Ops op, const int16_t scalar, NDArray &target, ExtraArguments *extraParams); +template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::Ops op, const int8_t scalar, NDArray &target, ExtraArguments *extraParams); +template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::Ops op, const uint8_t scalar, NDArray &target, ExtraArguments *extraParams); +template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::Ops op, const bool scalar, NDArray &target, ExtraArguments *extraParams); //////////////////////////////////////////////////////////////////////// -void NDArray::applyIndexReduce(nd4j::indexreduce::Ops op, NDArray* target, const std::vector& dimensions, const ExtraArguments *extraParams) const { +template +void NDArray::applyScalar(nd4j::scalar::BoolOps op, const T scalar, NDArray &target, ExtraArguments *extraParams) const { + + NDArray scalarArr = NDArrayFactory::create(scalar, getContext()); + applyScalarArr(op, scalarArr, target, extraParams); +} + +template <> ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::BoolOps op, const NDArray& scalar, NDArray &target, ExtraArguments *extraParams) const { throw std::runtime_error("NDArray::applyScalar method: do not use me!");} +template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::BoolOps op, const double scalar, NDArray &target, ExtraArguments *extraParams) const; +template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::BoolOps op, const float scalar, NDArray &target, ExtraArguments *extraParams) const; +template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::BoolOps op, const float16 scalar, NDArray &target, ExtraArguments *extraParams) const; +template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::BoolOps op, const bfloat16 scalar, NDArray &target, ExtraArguments *extraParams) const; +template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::BoolOps op, const Nd4jLong scalar, NDArray &target, ExtraArguments *extraParams) const; +template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::BoolOps op, const int scalar, NDArray &target, ExtraArguments *extraParams) const; +template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::BoolOps op, const int16_t scalar, NDArray &target, ExtraArguments *extraParams) const; +template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::BoolOps op, const int8_t scalar, NDArray &target, ExtraArguments *extraParams) const; +template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::BoolOps op, const uint8_t scalar, NDArray &target, ExtraArguments *extraParams) const; +template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::BoolOps op, const bool scalar, NDArray &target, ExtraArguments *extraParams) const; + +//////////////////////////////////////////////////////////////////////// +void NDArray::applyIndexReduce(nd4j::indexreduce::Ops op, NDArray& target, const std::vector& dimensions, const ExtraArguments *extraParams) const { if (isS()) throw std::runtime_error("NDArray::applyIndexReduce: you can't use this method on String array!"); - if (target->dataType() != nd4j::DataType::INT64 && target->dataType() != nd4j::DataType::INT32) + if (target.dataType() != nd4j::DataType::INT64 && target.dataType() != nd4j::DataType::INT32) throw std::runtime_error("NDArray::applyIndexReduce operations return INT32/INT64"); void* params = extraParams != nullptr ? const_cast(extraParams)->argumentsAsT(this->dataType()) : nullptr; - NDArray::prepareSpecialUse({target}, {this}); + NDArray::prepareSpecialUse({&target}, {this}); - if (target->lengthOf() == 1) { - NativeOpExecutioner::execIndexReduceScalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), params, target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo()); + if (target.lengthOf() == 1) { + NativeOpExecutioner::execIndexReduceScalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), params, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo()); } else { std::vector copy = dimensions; shape::checkDimensions(rankOf(), copy); auto pDims = nd4j::Environment::getInstance()->isCPU() ? copy.data() : nullptr; auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(getShapeInfo(), copy); - NativeOpExecutioner::execIndexReduce(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), params, target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo(), pDims, copy.size(), packX.platformShapeInfo(), packX.platformOffsets()); + NativeOpExecutioner::execIndexReduce(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), params, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), pDims, copy.size(), packX.platformShapeInfo(), packX.platformOffsets()); synchronize("NDArray::applyIndexReduce"); } - registerSpecialUse({target}, {this}); + registerSpecialUse({&target}, {this}); } //////////////////////////////////////////////////////////////////////// // reduce dimensions in this array relying on index operations -NDArray* NDArray::applyIndexReduce(nd4j::indexreduce::Ops op, const std::vector& dimensions, const ExtraArguments* extraParams ) const { +NDArray NDArray::applyIndexReduce(nd4j::indexreduce::Ops op, const std::vector& dimensions, const ExtraArguments* extraParams ) const { std::vector copy = dimensions; auto newShape = ShapeUtils::evalReduceShapeInfo('c', copy, *this, DataType::INT64, false, false, getContext()->getWorkspace()); - auto result = new NDArray(newShape, true, getContext()); + NDArray result(newShape, true, getContext()); applyIndexReduce(op, result, copy, extraParams); @@ -3720,10 +3272,11 @@ NDArray* NDArray::applyIndexReduce(nd4j::indexreduce::Ops op, const std::vector< //////////////////////////////////////////////////////////////////////// // apply reduce3 operations to this and other array, return result in new output array -NDArray* NDArray::applyReduce3(nd4j::reduce3::Ops op, const NDArray* other, const ExtraArguments* extraParams) const { +NDArray NDArray::applyReduce3(nd4j::reduce3::Ops op, const NDArray& other, const ExtraArguments* extraParams) const { + if (isS()) throw std::runtime_error("NDArray::applyReduce3 method: you can't use this method on String array!"); - if(dataType() != other->dataType()) + if(dataType() != other.dataType()) throw std::runtime_error("NDArray::applyReduce3 method: the types of this and other arrays must be the same !"); // check shapes consistency if(!isSameShape(other)) @@ -3731,75 +3284,75 @@ NDArray* NDArray::applyReduce3(nd4j::reduce3::Ops op, const NDArray* other, cons // create shapeInfo for scalar auto newShape = ShapeBuilders::createScalarShapeInfo(DataTypeUtils::pickFloatingType(dataType()), getContext()->getWorkspace()); // create output array (scalar) - auto result = new NDArray(newShape, true, getContext()); + NDArray result(newShape, true, getContext()); RELEASE(newShape, getContext()->getWorkspace()); // create dynamic array of extra parameters if array extraParams is empty (==nullptr) void* params = extraParams != nullptr ? const_cast(extraParams)->argumentsAsT(dataType()) : nullptr; - NDArray::prepareSpecialUse({result}, {this, other}); - NativeOpExecutioner::execReduce3Scalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), params, other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo()); - NDArray::registerSpecialUse({result}, {this, other}); + NDArray::prepareSpecialUse({&result}, {this, &other}); + NativeOpExecutioner::execReduce3Scalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), params, other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo()); + NDArray::registerSpecialUse({&result}, {this, &other}); return result; } //////////////////////////////////////////////////////////////////////// // apply reduce3 (exec) operations to this and other array, return result in new output array -NDArray* NDArray::applyReduce3(nd4j::reduce3::Ops op, const NDArray* other, const std::vector& dimensions, const ExtraArguments* extraParams) const { +NDArray NDArray::applyReduce3(nd4j::reduce3::Ops op, const NDArray& other, const std::vector& dimensions, const ExtraArguments* extraParams) const { if (isS()) throw std::runtime_error("NDArray::applyReduce3: you can't use this method on String array!"); - if(dataType() != other->dataType()) + if(dataType() != other.dataType()) throw std::runtime_error("NDArray::applyReduce3 method: the types of this and other arrays must be the same !"); std::vector copy(dimensions); shape::checkDimensions(rankOf(), copy); - shape::checkDimensions(other->rankOf(), copy); + shape::checkDimensions(other.rankOf(), copy); auto newShape = ShapeUtils::evalReduceShapeInfo('c', copy, *this, DataTypeUtils::pickFloatingType(dataType()), false, false, getContext()->getWorkspace()); - auto result = new NDArray(newShape, true, getContext()); + NDArray result(newShape, true, getContext()); // create temporary dynamic array of extra parameters if array extraParams is empty (==nullptr) void* params = extraParams != nullptr ? const_cast(extraParams)->argumentsAsT(dataType()) : nullptr; - NDArray::prepareSpecialUse({result}, {this, other}); + NDArray::prepareSpecialUse({&result}, {this, &other}); // perform calculations - if(rankOf() == copy.size() && other->rankOf() == copy.size()) { - NativeOpExecutioner::execReduce3Scalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), params, other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo()); + if(rankOf() == copy.size() && other.rankOf() == copy.size()) { + NativeOpExecutioner::execReduce3Scalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), params, other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo()); } else { auto pDims = nd4j::Environment::getInstance()->isCPU() ? copy.data() : nullptr; auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(getShapeInfo(), copy); - auto packY = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(other->getShapeInfo(), copy); + auto packY = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(other.getShapeInfo(), copy); if(!shape::equalsSoft(packX.primaryShapeInfo(), packY.primaryShapeInfo()) || (packX.numberOfTads() != packY.numberOfTads() && packX.numberOfTads() != 1 && packY.numberOfTads() != 1)) throw std::runtime_error("NDArray::applyReduce3 cuda method: arrays tads are inconsistent !"); - NativeOpExecutioner::execReduce3(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), params, other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo(), pDims, copy.size(), packX.platformShapeInfo(), packX.platformOffsets(), packY.platformShapeInfo(), packY.platformOffsets()); + NativeOpExecutioner::execReduce3(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), params, other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), pDims, copy.size(), packX.platformShapeInfo(), packX.platformOffsets(), packY.platformShapeInfo(), packY.platformOffsets()); } - registerSpecialUse({result}, {this, other}); + registerSpecialUse({&result}, {this, &other}); return result; } //////////////////////////////////////////////////////////////////////// // apply reduce3 (execAll) operations to this and other array, return result in new output array -NDArray* NDArray::applyAllReduce3(nd4j::reduce3::Ops op, const NDArray *other, const std::vector& dimensions, const ExtraArguments* extraParams) const { +NDArray NDArray::applyAllReduce3(nd4j::reduce3::Ops op, const NDArray& other, const std::vector& dimensions, const ExtraArguments* extraParams) const { if (isS()) throw std::runtime_error("NDArray::applyAllReduce3: you can't use this method on String array!"); - if(dataType() != other->dataType()) + if(dataType() != other.dataType()) throw std::runtime_error("NDArray::applyAllReduce3 method: the types of this and other arrays must be the same !"); // be careful, copy array may undergo changes (sort, transformation of negative dimensions to positive, duplicates removing ) std::vector copy(dimensions); shape::checkDimensions(rankOf(), copy); - shape::checkDimensions(other->rankOf(), copy); + shape::checkDimensions(other.rankOf(), copy); auto packX = ConstantTadHelper::getInstance()->tadForDimensions(getShapeInfo(), copy); - auto packY = ConstantTadHelper::getInstance()->tadForDimensions(other->getShapeInfo(), copy); + auto packY = ConstantTadHelper::getInstance()->tadForDimensions(other.getShapeInfo(), copy); // check tads shapes if(!shape::equalsSoft(packX.primaryShapeInfo(), packY.primaryShapeInfo())) @@ -3809,145 +3362,145 @@ NDArray* NDArray::applyAllReduce3(nd4j::reduce3::Ops op, const NDArray *other, c auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(DataTypeUtils::pickFloatingType(dataType()), 'c', {packX.numberOfTads(), packY.numberOfTads()}); // create output array - auto result = new NDArray(newShape, true, getContext()); + NDArray result(newShape, true, getContext()); // create dynamic array of extra parameters if array extraParams is empty (==nullptr) void* params = extraParams != nullptr ? const_cast(extraParams)->argumentsAsT(dataType()) : nullptr; auto pDims = nd4j::Environment::getInstance()->isCPU() ? copy.data() : nullptr; - NDArray::prepareSpecialUse({result}, {this, other}); - NativeOpExecutioner::execReduce3All(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), params, other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo(), pDims, copy.size(), packX.platformShapeInfo(), packX.platformOffsets(), packY.platformShapeInfo(), packY.platformOffsets()); - NDArray::registerSpecialUse({result}, {this, other}); + NDArray::prepareSpecialUse({&result}, {this, &other}); + NativeOpExecutioner::execReduce3All(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), params, other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), pDims, copy.size(), packX.platformShapeInfo(), packX.platformOffsets(), packY.platformShapeInfo(), packY.platformOffsets()); + NDArray::registerSpecialUse({&result}, {this, &other}); return result; } ////////////////////////////////////////////////////////////////////////// // method reduces array by excluding its shapes along axes present in dimensions vector -void NDArray::reduceAlongDimension(nd4j::reduce::FloatOps op, NDArray* target, const std::vector& dimensions, const bool keepDims, const bool supportOldShapes, const bool checkTargetShape) const { +void NDArray::reduceAlongDimension(nd4j::reduce::FloatOps op, NDArray& target, const std::vector& dimensions, const bool keepDims, const bool supportOldShapes, const bool checkTargetShape) const { if (isS()) throw std::runtime_error("NDArray::reduceAlongDimension FloatOps: you can't use this method on String array!"); - if (target == nullptr || !target->isR()) + if (!target.isR()) throw std::invalid_argument("NDArray::reduceAlongDimension FloatOps: requires target array to be present and have type form real space!"); std::vector copy(dimensions); if(checkTargetShape) { - auto newShape = ShapeUtils::evalReduceShapeInfo(target->ordering(), copy, *this, keepDims, supportOldShapes, getContext()->getWorkspace()); - if(!shape::shapeEquals(newShape, target->getShapeInfo())) + auto newShape = ShapeUtils::evalReduceShapeInfo(target.ordering(), copy, *this, keepDims, supportOldShapes, getContext()->getWorkspace()); + if(!shape::shapeEquals(newShape, target.getShapeInfo())) throw std::runtime_error("NDArray::reduceAlongDimension FloatOps: wrong target shape!"); } - NDArray::prepareSpecialUse({target}, {this}); + NDArray::prepareSpecialUse({&target}, {this}); if(rankOf() == copy.size() || copy.empty()) { - NativeOpExecutioner::execReduceFloatScalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(),nullptr, target->getBuffer(), target->getShapeInfo(), target->getSpecialBuffer(), target->getSpecialShapeInfo()); + NativeOpExecutioner::execReduceFloatScalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(),nullptr, target.getBuffer(), target.getShapeInfo(), target.getSpecialBuffer(), target.getSpecialShapeInfo()); } else { auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(getShapeInfo(), copy); - NativeOpExecutioner::execReduceFloat(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), nullptr, target->getBuffer(), target->getShapeInfo(), target->getSpecialBuffer(), target->getSpecialShapeInfo(), copy.data(), copy.size(), packX.platformShapeInfo(), packX.platformOffsets()); + NativeOpExecutioner::execReduceFloat(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), nullptr, target.getBuffer(), target.getShapeInfo(), target.getSpecialBuffer(), target.getSpecialShapeInfo(), copy.data(), copy.size(), packX.platformShapeInfo(), packX.platformOffsets()); } synchronize("NDArray::reduceAlongDimension FloatOps"); - NDArray::registerSpecialUse({target}, {this}); + NDArray::registerSpecialUse({&target}, {this}); } ////////////////////////////////////////////////////////////////////////// // method reduces array by excluding its shapes along axes present in dimensions vector -void NDArray::reduceAlongDimension(nd4j::reduce::SameOps op, NDArray* target, const std::vector& dimensions, const bool keepDims, const bool supportOldShapes, const bool checkTargetShape) const { +void NDArray::reduceAlongDimension(nd4j::reduce::SameOps op, NDArray& target, const std::vector& dimensions, const bool keepDims, const bool supportOldShapes, const bool checkTargetShape) const { if (isS()) throw std::runtime_error("NDArray::reduceAlongDimension SameOps: you can't use this method on String array!"); - if (target == nullptr || target->dataType() != dataType()) + if (target.dataType() != dataType()) throw std::runtime_error("NDArray::reduceAlongDimension SameOps: requires target array to be present and have same dtype as input"); std::vector copy(dimensions); if(checkTargetShape) { - auto newShape = ShapeUtils::evalReduceShapeInfo(target->ordering(), copy, *this, keepDims, supportOldShapes, getContext()->getWorkspace()); - if(!shape::shapeEquals(newShape, target->getShapeInfo())) + auto newShape = ShapeUtils::evalReduceShapeInfo(target.ordering(), copy, *this, keepDims, supportOldShapes, getContext()->getWorkspace()); + if(!shape::shapeEquals(newShape, target.getShapeInfo())) throw std::runtime_error("NDArray::reduceAlongDimension SameOps: wrong target shape!"); } - NDArray::prepareSpecialUse({target}, {this}); + NDArray::prepareSpecialUse({&target}, {this}); if(rankOf() == copy.size() || copy.empty()) { - NativeOpExecutioner::execReduceSameScalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), nullptr, target->getBuffer(), target->getShapeInfo(), target->getSpecialBuffer(), target->getSpecialShapeInfo()); + NativeOpExecutioner::execReduceSameScalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), nullptr, target.getBuffer(), target.getShapeInfo(), target.getSpecialBuffer(), target.getSpecialShapeInfo()); } else { //if (!isEmpty()) { auto pDims = nd4j::Environment::getInstance()->isCPU() ? copy.data() : nullptr; auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(this->getShapeInfo(), copy); - NativeOpExecutioner::execReduceSame(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), nullptr, target->getBuffer(), target->getShapeInfo(), target->getSpecialBuffer(), target->getSpecialShapeInfo(), pDims, copy.size(), packX.platformShapeInfo(), packX.platformOffsets()); + NativeOpExecutioner::execReduceSame(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), nullptr, target.getBuffer(), target.getShapeInfo(), target.getSpecialBuffer(), target.getSpecialShapeInfo(), pDims, copy.size(), packX.platformShapeInfo(), packX.platformOffsets()); } synchronize("NDArray::reduceAlongDimension SameOps"); - NDArray::registerSpecialUse({target}, {this}); + NDArray::registerSpecialUse({&target}, {this}); } ////////////////////////////////////////////////////////////////////////// // method reduces array by excluding its shapes along axes present in dimensions vector -void NDArray::reduceAlongDimension(nd4j::reduce::LongOps op, NDArray* target, const std::vector& dimensions, const bool keepDims, const bool supportOldShapes, const bool checkTargetShape) const { +void NDArray::reduceAlongDimension(nd4j::reduce::LongOps op, NDArray& target, const std::vector& dimensions, const bool keepDims, const bool supportOldShapes, const bool checkTargetShape) const { if (isS()) throw std::runtime_error("NDArray::reduceAlongDimension LongOps: you can't use this method on String array!"); - if (target == nullptr || target->dataType() != DataType::INT64) + if (target.dataType() != DataType::INT64) throw std::runtime_error("NDArray::reduceAlongDimension LongOps: requires target array to be present and have type of INT64"); std::vector copy(dimensions); if(checkTargetShape) { - auto newShape = ShapeUtils::evalReduceShapeInfo(target->ordering(), copy, *this, keepDims, supportOldShapes, getContext()->getWorkspace()); - if(!shape::shapeEquals(newShape, target->getShapeInfo())) + auto newShape = ShapeUtils::evalReduceShapeInfo(target.ordering(), copy, *this, keepDims, supportOldShapes, getContext()->getWorkspace()); + if(!shape::shapeEquals(newShape, target.getShapeInfo())) throw std::runtime_error("NDArray::reduceAlongDimension LongOps: wrong target shape!"); } - NDArray::prepareSpecialUse({target}, {this}); + NDArray::prepareSpecialUse({&target}, {this}); if(rankOf() == copy.size() || copy.empty()) { - NativeOpExecutioner::execReduceLongScalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), nullptr, target->getBuffer(), target->getShapeInfo(), target->getSpecialBuffer(), target->getSpecialShapeInfo()); + NativeOpExecutioner::execReduceLongScalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), nullptr, target.getBuffer(), target.getShapeInfo(), target.getSpecialBuffer(), target.getSpecialShapeInfo()); } else { auto pDims = nd4j::Environment::getInstance()->isCPU() ? copy.data() : nullptr; auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(this->getShapeInfo(), copy); - NativeOpExecutioner::execReduceLong(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), nullptr, target->getBuffer(), target->getShapeInfo(), target->getSpecialBuffer(), target->getSpecialShapeInfo(), pDims, copy.size(), packX.platformShapeInfo(), packX.platformOffsets()); + NativeOpExecutioner::execReduceLong(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), nullptr, target.getBuffer(), target.getShapeInfo(), target.getSpecialBuffer(), target.getSpecialShapeInfo(), pDims, copy.size(), packX.platformShapeInfo(), packX.platformOffsets()); } synchronize("NDArray::reduceAlongDimension LongOps"); - NDArray::registerSpecialUse({target}, {this}); + NDArray::registerSpecialUse({&target}, {this}); } ////////////////////////////////////////////////////////////////////////// // method reduces array by excluding its shapes along axes present in dimensions vector -void NDArray::reduceAlongDimension(nd4j::reduce::BoolOps op, NDArray* target, const std::vector& dimensions, const bool keepDims, const bool supportOldShapes, const bool checkTargetShape) const { +void NDArray::reduceAlongDimension(nd4j::reduce::BoolOps op, NDArray& target, const std::vector& dimensions, const bool keepDims, const bool supportOldShapes, const bool checkTargetShape) const { if (isS()) throw std::runtime_error("NDArray::reduceAlongDimension BoolOps cuda: you can't use this method on String array!"); - if (target == nullptr || !target->isB()) + if (!target.isB()) throw std::invalid_argument("NDArray::reduceAlongDimension BoolOps cuda: requires target array to be present and have BOOL type!"); std::vector copy(dimensions); if(checkTargetShape) { - auto newShape = ShapeUtils::evalReduceShapeInfo(target->ordering(), copy, *this, keepDims, supportOldShapes, getContext()->getWorkspace()); - if(!shape::shapeEquals(newShape, target->getShapeInfo())) + auto newShape = ShapeUtils::evalReduceShapeInfo(target.ordering(), copy, *this, keepDims, supportOldShapes, getContext()->getWorkspace()); + if(!shape::shapeEquals(newShape, target.getShapeInfo())) throw std::runtime_error("NDArray::reduceAlongDimension BoolOps cuda: wrong target shape!"); } - NDArray::prepareSpecialUse({target}, {this}); + NDArray::prepareSpecialUse({&target}, {this}); if(rankOf() == copy.size() || copy.empty()) { - NativeOpExecutioner::execReduceBoolScalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), nullptr, target->getBuffer(), target->getShapeInfo(), target->getSpecialBuffer(), target->getSpecialShapeInfo()); + NativeOpExecutioner::execReduceBoolScalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), nullptr, target.getBuffer(), target.getShapeInfo(), target.getSpecialBuffer(), target.getSpecialShapeInfo()); } else { auto pDims = nd4j::Environment::getInstance()->isCPU() ? copy.data() : nullptr; auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(this->getShapeInfo(), copy); - NativeOpExecutioner::execReduceBool(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), nullptr, target->getBuffer(), target->getShapeInfo(), target->getSpecialBuffer(), target->getSpecialShapeInfo(), pDims, copy.size(), packX.platformShapeInfo(), packX.platformOffsets()); + NativeOpExecutioner::execReduceBool(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), nullptr, target.getBuffer(), target.getShapeInfo(), target.getSpecialBuffer(), target.getSpecialShapeInfo(), pDims, copy.size(), packX.platformShapeInfo(), packX.platformOffsets()); } synchronize("NDArray::reduceAlongDimension LongOps"); - NDArray::registerSpecialUse({target}, {this}); + NDArray::registerSpecialUse({&target}, {this}); } ////////////////////////////////////////////////////////////////////////// @@ -4102,152 +3655,152 @@ void NDArray::p(const Nd4jLong i, const NDArray& scalar) { } ////////////////////////////////////////////////////////////////////////// -void NDArray::addRowVector(const NDArray *row, NDArray *target) const { +void NDArray::addRowVector(const NDArray& row, NDArray& target) const { if (isS()) throw std::runtime_error("NDArray::addRowVector: you can't use this method on String array!"); - if (rankOf() != 2 || target->rankOf() != 2 || rows() != target->rows() || columns() != target->columns() || !row->isRowVector() || columns() != row->lengthOf()) + if (rankOf() != 2 || target.rankOf() != 2 || rows() != target.rows() || columns() != target.columns() || !row.isRowVector() || columns() != row.lengthOf()) throw std::invalid_argument("NDArray::addRowVector: wrong arguments !"); - if(target->dataType() != DataTypeUtils::pickPairwiseResultType(dataType(), row->dataType()) && !(isR() && row->isR() && target->isR())) + if(target.dataType() != DataTypeUtils::pickPairwiseResultType(dataType(), row.dataType()) && !(isR() && row.isR() && target.isR())) throw std::invalid_argument("NDArray::addRowVector: wrong type of target array !"); int dimension = 1; auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(this->getShapeInfo(), dimension); - NDArray::prepareSpecialUse({target}, {this, row}); - NativeOpExecutioner::execBroadcast(getContext(), nd4j::broadcast::Ops::Add, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), row->getBuffer(), row->getShapeInfo(), row->getSpecialBuffer(), row->getSpecialShapeInfo(), target->getBuffer(), target->getShapeInfo(), target->getSpecialBuffer(), target->getSpecialShapeInfo(), nullptr, 1, packX.platformShapeInfo(), packX.platformOffsets(), nullptr, nullptr); - NDArray::registerSpecialUse({target}, {this, row}); + NDArray::prepareSpecialUse({&target}, {this, &row}); + NativeOpExecutioner::execBroadcast(getContext(), nd4j::broadcast::Ops::Add, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), row.getBuffer(), row.getShapeInfo(), row.getSpecialBuffer(), row.getSpecialShapeInfo(), target.getBuffer(), target.getShapeInfo(), target.getSpecialBuffer(), target.getSpecialShapeInfo(), nullptr, 1, packX.platformShapeInfo(), packX.platformOffsets(), nullptr, nullptr); + NDArray::registerSpecialUse({&target}, {this, &row}); } ////////////////////////////////////////////////////////////////////////// -void NDArray::subRowVector(const NDArray *row, NDArray *target) const { +void NDArray::subRowVector(const NDArray& row, NDArray& target) const { if (isS()) throw std::runtime_error("NDArray::addRowVector: you can't use this method on String array!"); - if (rankOf() != 2 || target->rankOf() != 2 || rows() != target->rows() || columns() != target->columns() || !row->isRowVector() || columns() != row->lengthOf()) + if (rankOf() != 2 || target.rankOf() != 2 || rows() != target.rows() || columns() != target.columns() || !row.isRowVector() || columns() != row.lengthOf()) throw std::invalid_argument("NDArray::addRowVector: wrong arguments !"); - if(target->dataType() != DataTypeUtils::pickPairwiseResultType(dataType(), row->dataType()) && !(isR() && row->isR() && target->isR())) + if(target.dataType() != DataTypeUtils::pickPairwiseResultType(dataType(), row.dataType()) && !(isR() && row.isR() && target.isR())) throw std::invalid_argument("NDArray::addRowVector: wrong type of target array !"); int dimension = 1; auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(this->getShapeInfo(), dimension); - NDArray::prepareSpecialUse({target}, {this, row}); - NativeOpExecutioner::execBroadcast(getContext(), nd4j::broadcast::Ops::Subtract, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), row->getBuffer(), row->getShapeInfo(), row->getSpecialBuffer(), row->getSpecialShapeInfo(), target->getBuffer(), target->getShapeInfo(), target->getSpecialBuffer(), target->getSpecialShapeInfo(), &dimension, 1, packX.platformShapeInfo(), packX.platformOffsets(), nullptr, nullptr); - NDArray::registerSpecialUse({target}, {this, row}); + NDArray::prepareSpecialUse({&target}, {this, &row}); + NativeOpExecutioner::execBroadcast(getContext(), nd4j::broadcast::Ops::Subtract, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), row.getBuffer(), row.getShapeInfo(), row.getSpecialBuffer(), row.getSpecialShapeInfo(), target.getBuffer(), target.getShapeInfo(), target.getSpecialBuffer(), target.getSpecialShapeInfo(), &dimension, 1, packX.platformShapeInfo(), packX.platformOffsets(), nullptr, nullptr); + NDArray::registerSpecialUse({&target}, {this, &row}); } ////////////////////////////////////////////////////////////////////////// -void NDArray::mulRowVector(const NDArray *row, NDArray *target) const { +void NDArray::mulRowVector(const NDArray &row, NDArray &target) const { if (isS()) throw std::runtime_error("NDArray::mulRowVector: you can't use this method on String array!"); - if (rankOf() != 2 || target->rankOf() != 2 || rows() != target->rows() || columns() != target->columns() || !row->isRowVector() || columns() != row->columns()) + if (rankOf() != 2 || target.rankOf() != 2 || rows() != target.rows() || columns() != target.columns() || !row.isRowVector() || columns() != row.columns()) throw std::invalid_argument("NDArray::divRowVector: wrong arguments !"); - if(target->dataType() != DataTypeUtils::pickPairwiseResultType(dataType(), row->dataType())) + if(target.dataType() != DataTypeUtils::pickPairwiseResultType(dataType(), row.dataType())) throw std::invalid_argument("NDArray::mulRowVector: wrong type of target array !"); int dimension = 1; auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(this->getShapeInfo(), dimension); - NDArray::prepareSpecialUse({target}, {this, row}); - NativeOpExecutioner::execBroadcast(getContext(), nd4j::broadcast::Ops::Multiply, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), row->getBuffer(), row->getShapeInfo(), row->getSpecialBuffer(), row->getSpecialShapeInfo(), target->getBuffer(), target->getShapeInfo(), target->getSpecialBuffer(), target->getSpecialShapeInfo(), nullptr, 1, packX.platformShapeInfo(), packX.platformOffsets(), nullptr, nullptr); - NDArray::registerSpecialUse({target}, {this, row}); + NDArray::prepareSpecialUse({&target}, {this, &row}); + NativeOpExecutioner::execBroadcast(getContext(), nd4j::broadcast::Ops::Multiply, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), row.getBuffer(), row.getShapeInfo(), row.getSpecialBuffer(), row.getSpecialShapeInfo(), target.getBuffer(), target.getShapeInfo(), target.getSpecialBuffer(), target.getSpecialShapeInfo(), nullptr, 1, packX.platformShapeInfo(), packX.platformOffsets(), nullptr, nullptr); + NDArray::registerSpecialUse({&target}, {this, &row}); } ////////////////////////////////////////////////////////////////////////// -void NDArray::divRowVector(const NDArray *row, NDArray *target) const { +void NDArray::divRowVector(const NDArray &row, NDArray &target) const { if (isS()) throw std::runtime_error("NDArray::divRowVector: you can't use this method on String array!"); - if (row->isB()) + if (row.isB()) throw std::runtime_error("NDArray::divRowVector: you can't divide by bool row!"); - if (rankOf() != 2 || target->rankOf() != 2 || rows() != target->rows() || columns() != target->columns() || !row->isRowVector() || columns() != row->columns()) + if (rankOf() != 2 || target.rankOf() != 2 || rows() != target.rows() || columns() != target.columns() || !row.isRowVector() || columns() != row.columns()) throw std::invalid_argument("NDArray::divRowVector: wrong arguments !"); - if(target->dataType() != DataTypeUtils::pickPairwiseResultType(dataType(), row->dataType())) + if(target.dataType() != DataTypeUtils::pickPairwiseResultType(dataType(), row.dataType())) throw std::invalid_argument("NDArray::divRowVector: wrong type of target array !"); int dimension = 1; auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(this->getShapeInfo(), dimension); - NDArray::prepareSpecialUse({target}, {this, row}); - NativeOpExecutioner::execBroadcast(getContext(), nd4j::broadcast::Divide, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), row->getBuffer(), row->getShapeInfo(), row->getSpecialBuffer(), row->getSpecialShapeInfo(), target->getBuffer(), target->getShapeInfo(), target->getSpecialBuffer(), target->getSpecialShapeInfo(), nullptr, 1, packX.platformShapeInfo(), packX.platformOffsets(), nullptr, nullptr); - NDArray::registerSpecialUse({target}, {this, row}); + NDArray::prepareSpecialUse({&target}, {this, &row}); + NativeOpExecutioner::execBroadcast(getContext(), nd4j::broadcast::Divide, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), row.getBuffer(), row.getShapeInfo(), row.getSpecialBuffer(), row.getSpecialShapeInfo(), target.getBuffer(), target.getShapeInfo(), target.getSpecialBuffer(), target.getSpecialShapeInfo(), nullptr, 1, packX.platformShapeInfo(), packX.platformOffsets(), nullptr, nullptr); + NDArray::registerSpecialUse({&target}, {this, &row}); } ////////////////////////////////////////////////////////////////////////// // This method adds given row to all rows in this NDArray, this array becomes affected -void NDArray::addiRowVector(const NDArray *row) { +void NDArray::addiRowVector(const NDArray& row) { if (isS()) throw std::runtime_error("NDArray::addiRowVector: you can't use this method on String array!"); - if (rankOf() != 2 || !row->isRowVector() || columns() != row->lengthOf()) + if (rankOf() != 2 || !row.isRowVector() || columns() != row.lengthOf()) throw std::invalid_argument("NDArray::addiRowVector: wrong arguments !"); int dimension = 1; auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(this->getShapeInfo(), dimension); - NDArray::prepareSpecialUse({this}, {row}); - NativeOpExecutioner::execBroadcast(getContext(), nd4j::broadcast::Ops::Add, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), row->getBuffer(), row->getShapeInfo(), row->getSpecialBuffer(), row->getSpecialShapeInfo(), this->buffer(), this->shapeInfo(), this->specialBuffer(), this->specialShapeInfo(), nullptr, 1, packX.platformShapeInfo(), packX.platformOffsets(), nullptr, nullptr); - NDArray::registerSpecialUse({this}, {row}); + NDArray::prepareSpecialUse({this}, {&row}); + NativeOpExecutioner::execBroadcast(getContext(), nd4j::broadcast::Ops::Add, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), row.getBuffer(), row.getShapeInfo(), row.getSpecialBuffer(), row.getSpecialShapeInfo(), this->buffer(), this->shapeInfo(), this->specialBuffer(), this->specialShapeInfo(), nullptr, 1, packX.platformShapeInfo(), packX.platformOffsets(), nullptr, nullptr); + NDArray::registerSpecialUse({this}, {&row}); } ////////////////////////////////////////////////////////////////////////// -void NDArray::addColumnVector(const NDArray *column, NDArray *target) const { +void NDArray::addColumnVector(const NDArray &column, NDArray &target) const { if (isS()) throw std::runtime_error("NDArray::addColumnVector: you can't use this method on String array!"); - if (rankOf() != 2 || target->rankOf() != 2 || rows() != target->rows() || columns() != target->columns() || !column->isColumnVector() || rows() != column->lengthOf()) + if (rankOf() != 2 || target.rankOf() != 2 || rows() != target.rows() || columns() != target.columns() || !column.isColumnVector() || rows() != column.lengthOf()) throw std::invalid_argument("NDArray::addColumnVector: wrong arguments !"); - if(target->dataType() != DataTypeUtils::pickPairwiseResultType(dataType(), column->dataType())) + if(target.dataType() != DataTypeUtils::pickPairwiseResultType(dataType(), column.dataType())) throw std::invalid_argument("NDArray::addColumnVector: wrong type of target array !"); int dimension = 0; auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(this->getShapeInfo(), dimension); - NDArray::prepareSpecialUse({target}, {this, column}); - NativeOpExecutioner::execBroadcast(getContext(), nd4j::broadcast::Ops::Add, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), column->getBuffer(), column->getShapeInfo(), column->getSpecialBuffer(), column->getSpecialShapeInfo(), target->getBuffer(), target->getShapeInfo(), target->getSpecialBuffer(), target->getSpecialShapeInfo(), nullptr, 1, packX.platformShapeInfo(), packX.platformOffsets(), nullptr, nullptr); - NDArray::registerSpecialUse({target}, {this, column}); + NDArray::prepareSpecialUse({&target}, {this, &column}); + NativeOpExecutioner::execBroadcast(getContext(), nd4j::broadcast::Ops::Add, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), column.getBuffer(), column.getShapeInfo(), column.getSpecialBuffer(), column.getSpecialShapeInfo(), target.getBuffer(), target.getShapeInfo(), target.getSpecialBuffer(), target.getSpecialShapeInfo(), nullptr, 1, packX.platformShapeInfo(), packX.platformOffsets(), nullptr, nullptr); + NDArray::registerSpecialUse({&target}, {this, &column}); } ////////////////////////////////////////////////////////////////////////// // This method adds given column to all columns in this NDArray, this array becomes affected -void NDArray::addiColumnVector(const NDArray *column) { +void NDArray::addiColumnVector(const NDArray &column) { if (isS()) throw std::runtime_error("NDArray::addiColumnVector: you can't use this method on String array!"); - if (rankOf() != 2 || !column->isColumnVector() || rows() != column->lengthOf()) + if (rankOf() != 2 || !column.isColumnVector() || rows() != column.lengthOf()) throw std::invalid_argument("NDArray::addiColumnVector: wrong arguments !"); int dimension = 0; auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(this->getShapeInfo(), dimension); - NDArray::prepareSpecialUse({this}, {column}); - NativeOpExecutioner::execBroadcast(getContext(), nd4j::broadcast::Ops::Add, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), column->getBuffer(), column->getShapeInfo(), column->getSpecialBuffer(), column->getSpecialShapeInfo(), this->buffer(), this->shapeInfo(), this->specialBuffer(), this->specialShapeInfo(), nullptr, 1, packX.platformShapeInfo(), packX.platformOffsets(), nullptr, nullptr); - NDArray::registerSpecialUse({this}, {column}); + NDArray::prepareSpecialUse({this}, {&column}); + NativeOpExecutioner::execBroadcast(getContext(), nd4j::broadcast::Ops::Add, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), column.getBuffer(), column.getShapeInfo(), column.getSpecialBuffer(), column.getSpecialShapeInfo(), this->buffer(), this->shapeInfo(), this->specialBuffer(), this->specialShapeInfo(), nullptr, 1, packX.platformShapeInfo(), packX.platformOffsets(), nullptr, nullptr); + NDArray::registerSpecialUse({this}, {&column}); } ////////////////////////////////////////////////////////////////////////// // This method multiplies each column of this array by given argument-column, this array becomes affected -void NDArray::muliColumnVector(const NDArray *column) { +void NDArray::muliColumnVector(const NDArray& column) { if (isS()) throw std::runtime_error("NDArray::muliColumnVector: you can't use this method on String array!"); - if (rankOf() != 2 || !column->isColumnVector() || rows() != column->lengthOf()) + if (rankOf() != 2 || !column.isColumnVector() || rows() != column.lengthOf()) throw std::invalid_argument("NDArray::muliColumnVector: wrong arguments !"); int dimension = 0; auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(this->getShapeInfo(), dimension); - NDArray::prepareSpecialUse({this}, {column}); - NativeOpExecutioner::execBroadcast(getContext(), nd4j::broadcast::Ops::Multiply, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), column->getBuffer(), column->getShapeInfo(), column->getSpecialBuffer(), column->getSpecialShapeInfo(), this->buffer(), this->shapeInfo(), this->specialBuffer(), this->specialShapeInfo(), nullptr, 1, packX.platformShapeInfo(), packX.platformOffsets(), nullptr, nullptr); - NDArray::registerSpecialUse({this}, {column}); + NDArray::prepareSpecialUse({this}, {&column}); + NativeOpExecutioner::execBroadcast(getContext(), nd4j::broadcast::Ops::Multiply, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), column.getBuffer(), column.getShapeInfo(), column.getSpecialBuffer(), column.getSpecialShapeInfo(), this->buffer(), this->shapeInfo(), this->specialBuffer(), this->specialShapeInfo(), nullptr, 1, packX.platformShapeInfo(), packX.platformOffsets(), nullptr, nullptr); + NDArray::registerSpecialUse({this}, {&column}); } ////////////////////////////////////////////////////////////////////////// @@ -4278,8 +3831,8 @@ bool NDArray::permutei(const Nd4jLong* dimensions, const int rank) { } //////////////////////////////////////////////////////////////////////// -ResultSet* NDArray::multipleTensorsAlongDimension(const std::vector &indices, const std::vector &dimensions) const { - auto result = new ResultSet(); +ResultSet NDArray::multipleTensorsAlongDimension(const std::vector &indices, const std::vector &dimensions) const { + ResultSet result; if (indices.size() == 0) return result; @@ -4296,19 +3849,19 @@ ResultSet* NDArray::multipleTensorsAlongDimension(const std::vector &indice } auto array = new NDArray(getDataBuffer(), ShapeDescriptor(pack.primaryShapeInfo()), getContext(), pack.primaryOffsets()[idx] + getBufferOffset()); - result->push_back(array); + result.push_back(array); } return result; } //////////////////////////////////////////////////////////////////////// -ResultSet* NDArray::allTensorsAlongDimension(const std::initializer_list& dimensions) const { +ResultSet NDArray::allTensorsAlongDimension(const std::initializer_list& dimensions) const { return allTensorsAlongDimension(std::vector(dimensions)); } //////////////////////////////////////////////////////////////////////// -ResultSet* NDArray::allExamples() const { +ResultSet NDArray::allExamples() const { std::vector dimensions(rankOf() - 1); for (int e = 1; e < rankOf(); e++) dimensions[e-1] = e; @@ -4338,7 +3891,7 @@ NDArray NDArray::ulike() { } //////////////////////////////////////////////////////////////////////// -NDArray* NDArray::diagonal(const char type) const { +NDArray NDArray::diagonal(const char type) const { if (isS()) throw std::runtime_error("NDArray::diagonal: you can't use this method on String array!"); @@ -4386,7 +3939,7 @@ NDArray* NDArray::diagonal(const char type) const { ArrayOptions::setDataType(outShapeInfo, this->dataType()); - auto result = new NDArray(_buffer, ShapeDescriptor(outShapeInfo), getContext(), getBufferOffset()); + NDArray result(_buffer, ShapeDescriptor(outShapeInfo), getContext(), getBufferOffset()); RELEASE(outShapeInfo, getContext()->getWorkspace()); @@ -4394,9 +3947,9 @@ NDArray* NDArray::diagonal(const char type) const { } //////////////////////////////////////////////////////////////////////// -ResultSet* NDArray::allTensorsAlongDimension(const std::vector &dimensions) const { +ResultSet NDArray::allTensorsAlongDimension(const std::vector &dimensions) const { - auto result = new ResultSet(); + ResultSet result; if(dimensions.size() == 0) return result; @@ -4411,14 +3964,14 @@ ResultSet* NDArray::allTensorsAlongDimension(const std::vector &dimensions) for (int idx = 0; idx < numTads; idx++ ) { auto array = new NDArray(_buffer, ShapeDescriptor(pack.primaryShapeInfo()), getContext(), pack.primaryOffsets()[idx] + getBufferOffset()); array->_isView = true; - result->push_back(array); + result.push_back(array); } return result; } ////////////////////////////////////////////////////////////////////////// -NDArray* NDArray::tensorAlongDimension(Nd4jLong index, const std::vector& dimensions) const { +NDArray NDArray::tensorAlongDimension(Nd4jLong index, const std::vector& dimensions) const { std::vector copy(dimensions); shape::checkDimensions(rankOf(), copy); @@ -4430,12 +3983,17 @@ NDArray* NDArray::tensorAlongDimension(Nd4jLong index, const std::vector& d auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(this->getShapeInfo(), copy); - auto array = new NDArray(_buffer, ShapeDescriptor(packX.primaryShapeInfo()), getContext(), packX.primaryOffsets()[index] + getBufferOffset()); - array->_isView = true; + NDArray array(_buffer, ShapeDescriptor(packX.primaryShapeInfo()), getContext(), packX.primaryOffsets()[index] + getBufferOffset()); + array._isView = true; return array; } +////////////////////////////////////////////////////////////////////////// +NDArray NDArray::tensorAlongDimension(Nd4jLong index, const std::initializer_list& dimensions) const { + return tensorAlongDimension(index, std::vector(dimensions)); +} + //////////////////////////////////////////////////////////////////////// // operator returns sub-array with buffer pointing at this->_buffer + certain offset NDArray NDArray::operator()(const std::vector& idx, const bool keepUnitiesInShape, const bool isStrided) const { @@ -4606,6 +4164,539 @@ void NDArray::setShapeInfo(const ConstantDataBuffer& shapeBuffer) { _dataType = ArrayOptions::dataType(_shapeInfo); } +/////////////////////////////////////////////////////////////////////// +// addition operator array + scalar +template +NDArray operator+(NDArray&& arr, const T& scalar) { + + if(arr.isView()) // do not use resources of arrays which use buffers of other original arrays + return std::move(arr + scalar); // arr is lvalue inside function body + + if (arr.isS()) + throw std::runtime_error("operator+(NDArray&& arr, const T& scalar): you can't use this method on String array!"); + if (arr.dataType() != DataTypeUtils::pickPairwiseResultType(arr.dataType(), DataTypeUtils::fromT())) + throw std::runtime_error("operator+(NDArray&& arr, const T& scalar): you can't use this method on String array!"); + + auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); + + NDArray::prepareSpecialUse({&arr}, {&arr, &tmp}); + NativeOpExecutioner::execScalar(arr.getContext(), nd4j::scalar::Add, arr.getBuffer(), arr.getShapeInfo(), arr.getSpecialBuffer(), arr.getSpecialShapeInfo(), arr.buffer(), arr.getShapeInfo(), arr.specialBuffer(), arr.getSpecialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); + NDArray::registerSpecialUse({&arr}, {&arr, &tmp}); + + return std::move(arr); +} +template ND4J_EXPORT NDArray operator+(NDArray&& arr, const double& scalar); +template ND4J_EXPORT NDArray operator+(NDArray&& arr, const float& scalar); +template ND4J_EXPORT NDArray operator+(NDArray&& arr, const float16& scalar); +template ND4J_EXPORT NDArray operator+(NDArray&& arr, const bfloat16& scalar); +template ND4J_EXPORT NDArray operator+(NDArray&& arr, const int& scalar); + +//////////////////////////////////////////////////////////////////////// +template +NDArray operator+(const NDArray& arr, const T& scalar) { + + if (arr.isS()) + throw std::runtime_error("operator+(const NDArray& arr, const T& scalar): you can't use this method on String array!"); + + auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); + NDArray result(arr.getShapeInfo(), DataTypeUtils::pickPairwiseResultType(arr.dataType(), DataTypeUtils::fromT()), false, arr.getContext()); + + NDArray::prepareSpecialUse({&result}, {&arr, &tmp}); + NativeOpExecutioner::execScalar(arr.getContext(), nd4j::scalar::Add, arr.getBuffer(), arr.getShapeInfo(), arr.getSpecialBuffer(), arr.getSpecialShapeInfo(), result.buffer(), result.getShapeInfo(), result.specialBuffer(), result.getSpecialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); + NDArray::registerSpecialUse({&result}, {&arr, &tmp}); + + return result; +} +template ND4J_EXPORT NDArray operator+(const NDArray& arr, const double& scalar); +template ND4J_EXPORT NDArray operator+(const NDArray& arr, const float& scalar); +template ND4J_EXPORT NDArray operator+(const NDArray& arr, const float16& scalar); +template ND4J_EXPORT NDArray operator+(const NDArray& arr, const bfloat16& scalar); +template ND4J_EXPORT NDArray operator+(const NDArray& arr, const int& scalar); + +//////////////////////////////////////////////////////////////////////// +template +NDArray operator+(const T& scalar, NDArray&& arr) { + return std::move(arr) + scalar; +} +template ND4J_EXPORT NDArray operator+(const double& scalar, NDArray&& arr); +template ND4J_EXPORT NDArray operator+(const float& scalar, NDArray&& arr); +template ND4J_EXPORT NDArray operator+(const float16& scalar, NDArray&& arr); +template ND4J_EXPORT NDArray operator+(const bfloat16& scalar, NDArray&& arr); +template ND4J_EXPORT NDArray operator+(const int& scalar, NDArray&& arr); + + +//////////////////////////////////////////////////////////////////////// +template +NDArray operator+(const T& scalar, const NDArray& arr) { + return arr + scalar; +} +template ND4J_EXPORT NDArray operator+(const double& scalar, const NDArray& arr); +template ND4J_EXPORT NDArray operator+(const float& scalar, const NDArray& arr); +template ND4J_EXPORT NDArray operator+(const int& scalar, const NDArray& arr); + +/////////////////////////////////////////////////////////////////////// +// addition operator array - scalar +template +NDArray operator-(NDArray&& arr, const T& scalar) { + + if(arr.isView()) // do not use resources of arrays which use buffers of other original arrays + return std::move(arr - scalar); // arr is lvalue inside function body + + if (arr.isS()) + throw std::runtime_error("operator-(NDArray&& arr, const T& scalar): you can't use this method on String array!"); + if (arr.dataType() != DataTypeUtils::pickPairwiseResultType(arr.dataType(), DataTypeUtils::fromT())) + throw std::runtime_error("operator-(NDArray&& arr, const T& scalar): you can't use this method on String array!"); + + auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); + + NDArray::prepareSpecialUse({&arr}, {&arr, &tmp}); + NativeOpExecutioner::execScalar(arr.getContext(), nd4j::scalar::Subtract, arr.getBuffer(), arr.getShapeInfo(), arr.getSpecialBuffer(), arr.getSpecialShapeInfo(), arr.buffer(), arr.getShapeInfo(), arr.specialBuffer(), arr.getSpecialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); + NDArray::registerSpecialUse({&arr}, {&arr, &tmp}); + + return std::move(arr); +} +template ND4J_EXPORT NDArray operator-(NDArray&& arr, const double& scalar); +template ND4J_EXPORT NDArray operator-(NDArray&& arr, const float& scalar); + +//////////////////////////////////////////////////////////////////////// +template +NDArray operator-(const NDArray& arr, const T& scalar) { + + if (arr.isS()) + throw std::runtime_error("operator-(const NDArray& arr, const T& scalar): you can't use this method on String array!"); + + auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); + NDArray result(arr.getShapeInfo(), DataTypeUtils::pickPairwiseResultType(arr.dataType(), DataTypeUtils::fromT()), false, arr.getContext()); + + NDArray::prepareSpecialUse({&result}, {&arr, &tmp}); + NativeOpExecutioner::execScalar(arr.getContext(), nd4j::scalar::Subtract, arr.getBuffer(), arr.getShapeInfo(), arr.getSpecialBuffer(), arr.getSpecialShapeInfo(), result.buffer(), result.getShapeInfo(), result.specialBuffer(), result.getSpecialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); + NDArray::registerSpecialUse({&result}, {&arr, &tmp}); + + return result; +} +template ND4J_EXPORT NDArray operator-(const NDArray& arr, const double& scalar); +template ND4J_EXPORT NDArray operator-(const NDArray& arr, const float& scalar); +template ND4J_EXPORT NDArray operator-(const NDArray& arr, const float16& scalar); +template ND4J_EXPORT NDArray operator-(const NDArray& arr, const bfloat16& scalar); +template ND4J_EXPORT NDArray operator-(const NDArray& arr, const int& scalar); + +//////////////////////////////////////////////////////////////////////// +template +NDArray operator-(const T& scalar, NDArray&& arr) { + + if(arr.isView()) // do not use resources of arrays which use buffers of other original arrays + return std::move(scalar - arr); // arr is lvalue inside function body + + if (arr.isS()) + throw std::runtime_error("operator-(const T& scalar, NDArray&& arr): you can't use this method on String array!"); + + auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); + + NDArray::prepareSpecialUse({&arr}, {&arr, &tmp}); + NativeOpExecutioner::execScalar(arr.getContext(), nd4j::scalar::ReverseSubtract, arr.getBuffer(), arr.getShapeInfo(), arr.getSpecialBuffer(), arr.getSpecialShapeInfo(), arr.getBuffer(), arr.getShapeInfo(), arr.specialBuffer(), arr.specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); + NDArray::registerSpecialUse({&arr}, {&arr, &tmp}); + + return std::move(arr); + +} +template ND4J_EXPORT NDArray operator-(const double& scalar, NDArray&& arr); +template ND4J_EXPORT NDArray operator-(const float& scalar, NDArray&& arr); +template ND4J_EXPORT NDArray operator-(const float16& scalar, NDArray&& arr); +template ND4J_EXPORT NDArray operator-(const bfloat16& scalar, NDArray&& arr); +template ND4J_EXPORT NDArray operator-(const int& scalar, NDArray&& arr); + +//////////////////////////////////////////////////////////////////////// +template +NDArray operator-(const T& scalar, const NDArray& arr) { + + if (arr.isS()) + throw std::runtime_error("operator-(const T& scalar, const NDArray& arr): you can't use this method on String array!"); + + auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); + NDArray result(arr.getShapeInfo(), DataTypeUtils::pickPairwiseResultType(arr.dataType(), DataTypeUtils::fromT()), false, arr.getContext()); + + NDArray::prepareSpecialUse({&result}, {&arr, &tmp}); + NativeOpExecutioner::execScalar(arr.getContext(), nd4j::scalar::ReverseSubtract, arr.getBuffer(), arr.getShapeInfo(), arr.getSpecialBuffer(), arr.getSpecialShapeInfo(), result.getBuffer(), result.getShapeInfo(), result.specialBuffer(), result.specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); + NDArray::registerSpecialUse({&result}, {&arr, &tmp}); + + return result; +} +template ND4J_EXPORT NDArray operator-(const double& scalar, const NDArray& arr); +template ND4J_EXPORT NDArray operator-(const float& scalar, const NDArray& arr); +template ND4J_EXPORT NDArray operator-(const int& scalar, const NDArray& arr); + +/////////////////////////////////////////////////////////////////////// +// addition operator array + scalar +template +NDArray operator*(NDArray&& arr, const T& scalar) { + + if(arr.isView()) // do not use resources of arrays which use buffers of other original arrays + return std::move(arr * scalar); // arr is lvalue inside function body + + if (arr.isS()) + throw std::runtime_error("operator*(NDArray&& arr, const T& scalar): you can't use this method on String array!"); + if (arr.dataType() != DataTypeUtils::pickPairwiseResultType(arr.dataType(), DataTypeUtils::fromT())) + throw std::runtime_error("operator*(NDArray&& arr, const T& scalar): you can't use this method on String array!"); + + auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); + + NDArray::prepareSpecialUse({&arr}, {&arr, &tmp}); + NativeOpExecutioner::execScalar(arr.getContext(), nd4j::scalar::Multiply, arr.getBuffer(), arr.getShapeInfo(), arr.getSpecialBuffer(), arr.getSpecialShapeInfo(), arr.buffer(), arr.getShapeInfo(), arr.specialBuffer(), arr.getSpecialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); + NDArray::registerSpecialUse({&arr}, {&arr, &tmp}); + + return std::move(arr); +} +template ND4J_EXPORT NDArray operator*(NDArray&& arr, const double& scalar); +template ND4J_EXPORT NDArray operator*(NDArray&& arr, const float& scalar); +template ND4J_EXPORT NDArray operator*(NDArray&& arr, const float16& scalar); +template ND4J_EXPORT NDArray operator*(NDArray&& arr, const bfloat16& scalar); +template ND4J_EXPORT NDArray operator*(NDArray&& arr, const int& scalar); +template ND4J_EXPORT NDArray operator*(NDArray&& arr, const long long& scalar); + +//////////////////////////////////////////////////////////////////////// +template +NDArray operator*(const NDArray& arr, const T& scalar) { + + if (arr.isS()) + throw std::runtime_error("operator*(const NDArray& arr, const T& scalar): you can't use this method on String array!"); + + auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); + NDArray result(arr.getShapeInfo(), DataTypeUtils::pickPairwiseResultType(arr.dataType(), DataTypeUtils::fromT()), false, arr.getContext()); + + NDArray::prepareSpecialUse({&result}, {&arr, &tmp}); + NativeOpExecutioner::execScalar(arr.getContext(), nd4j::scalar::Multiply, arr.getBuffer(), arr.getShapeInfo(), arr.getSpecialBuffer(), arr.getSpecialShapeInfo(), result.buffer(), result.getShapeInfo(), result.specialBuffer(), result.getSpecialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); + NDArray::registerSpecialUse({&result}, {&arr, &tmp}); + + return result; +} + +template ND4J_EXPORT NDArray operator*(const NDArray& arr, const double& scalar); +template ND4J_EXPORT NDArray operator*(const NDArray& arr, const float& scalar); +template ND4J_EXPORT NDArray operator*(const NDArray& arr, const float16& scalar); +template ND4J_EXPORT NDArray operator*(const NDArray& arr, const bfloat16& scalar); +template ND4J_EXPORT NDArray operator*(const NDArray& arr, const int& scalar); +template ND4J_EXPORT NDArray operator*(const NDArray& arr, const long long& scalar); + +//////////////////////////////////////////////////////////////////////// +template +NDArray operator*(const T& scalar, NDArray&& arr) { + return std::move(arr) * scalar; +} +template ND4J_EXPORT NDArray operator*(const double& scalar, NDArray&& arr); +template ND4J_EXPORT NDArray operator*(const float& scalar, NDArray&& arr); +template ND4J_EXPORT NDArray operator*(const float16& scalar, NDArray&& arr); +template ND4J_EXPORT NDArray operator*(const bfloat16& scalar, NDArray&& arr); +template ND4J_EXPORT NDArray operator*(const int& scalar, NDArray&& arr); +template ND4J_EXPORT NDArray operator*(const long long& scalar, NDArray&& arr); + + +//////////////////////////////////////////////////////////////////////// +template +NDArray operator*(const T& scalar, const NDArray& arr) { + return arr * scalar; +} +template ND4J_EXPORT NDArray operator*(const double& scalar, const NDArray& arr); +template ND4J_EXPORT NDArray operator*(const float& scalar, const NDArray& arr); +template ND4J_EXPORT NDArray operator*(const float16& scalar, const NDArray& arr); +template ND4J_EXPORT NDArray operator*(const bfloat16& scalar, const NDArray& arr); +template ND4J_EXPORT NDArray operator*(const int& scalar, const NDArray& arr); +template ND4J_EXPORT NDArray operator*(const long long& scalar, const NDArray& arr); + +/////////////////////////////////////////////////////////////////////// +template +NDArray operator/(NDArray&& arr, const T& scalar) { + + if(arr.isView()) // do not use resources of arrays which use buffers of other original arrays + return std::move(arr / scalar); // arr is lvalue inside function body + + if (arr.isS()) + throw std::runtime_error("operator/(NDArray&& arr, const T& scalar): you can't use this method on String array!"); + if (arr.dataType() != DataTypeUtils::pickPairwiseResultType(arr.dataType(), DataTypeUtils::fromT())) + throw std::runtime_error("operator/(NDArray&& arr, const T& scalar): you can't use this method on String array!"); + + auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); + + NDArray::prepareSpecialUse({&arr}, {&arr, &tmp}); + NativeOpExecutioner::execScalar(arr.getContext(), nd4j::scalar::Divide, arr.getBuffer(), arr.getShapeInfo(), arr.getSpecialBuffer(), arr.getSpecialShapeInfo(), arr.buffer(), arr.getShapeInfo(), arr.specialBuffer(), arr.getSpecialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); + NDArray::registerSpecialUse({&arr}, {&arr, &tmp}); + + return std::move(arr); +} +template ND4J_EXPORT NDArray operator/(NDArray&& arr, const double& scalar); +template ND4J_EXPORT NDArray operator/(NDArray&& arr, const float& scalar); +template ND4J_EXPORT NDArray operator/(NDArray&& arr, const float16& scalar); +template ND4J_EXPORT NDArray operator/(NDArray&& arr, const bfloat16& scalar); +template ND4J_EXPORT NDArray operator/(NDArray&& arr, const long long& scalar); + +//////////////////////////////////////////////////////////////////////// +template +NDArray operator/(const NDArray& arr, const T& scalar) { + + if (arr.isS()) + throw std::runtime_error("operator/(const NDArray& arr, const T& scalar): you can't use this method on String array!"); + + auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); + NDArray result(arr.getShapeInfo(), DataTypeUtils::pickPairwiseResultType(arr.dataType(), DataTypeUtils::fromT()), false, arr.getContext()); + + NDArray::prepareSpecialUse({&result}, {&arr, &tmp}); + NativeOpExecutioner::execScalar(arr.getContext(), nd4j::scalar::Divide, arr.getBuffer(), arr.getShapeInfo(), arr.getSpecialBuffer(), arr.getSpecialShapeInfo(), result.buffer(), result.getShapeInfo(), result.specialBuffer(), result.getSpecialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); + NDArray::registerSpecialUse({&result}, {&arr, &tmp}); + + return result; +} +template ND4J_EXPORT NDArray operator/(const NDArray& arr, const double& scalar); +template ND4J_EXPORT NDArray operator/(const NDArray& arr, const float& scalar); +template ND4J_EXPORT NDArray operator/(const NDArray& arr, const float16& scalar); +template ND4J_EXPORT NDArray operator/(const NDArray& arr, const bfloat16& scalar); +template ND4J_EXPORT NDArray operator/(const NDArray& arr, const int& scalar); +template ND4J_EXPORT NDArray operator/(const NDArray& arr, const long long& scalar); + +//////////////////////////////////////////////////////////////////////// +template +NDArray operator/(const T& scalar, NDArray&& arr) { + + if(arr.isView()) // do not use resources of arrays which use buffers of other original arrays + return std::move(scalar / arr); // arr is lvalue inside function body + + if (arr.isS()) + throw std::runtime_error("operator/(const T& scalar, NDArray&& arr): you can't use this method on String array!"); + + auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); + + NDArray::prepareSpecialUse({&arr}, {&arr, &tmp}); + NativeOpExecutioner::execScalar(arr.getContext(), nd4j::scalar::ReverseDivide, arr.getBuffer(), arr.getShapeInfo(), arr.getSpecialBuffer(), arr.getSpecialShapeInfo(), arr.getBuffer(), arr.getShapeInfo(), arr.specialBuffer(), arr.specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); + NDArray::registerSpecialUse({&arr}, {&arr, &tmp}); + + return std::move(arr); + +} +template ND4J_EXPORT NDArray operator/(const double& scalar, NDArray&& arr); +template ND4J_EXPORT NDArray operator/(const float& scalar, NDArray&& arr); +template ND4J_EXPORT NDArray operator/(const float16& scalar, NDArray&& arr); +template ND4J_EXPORT NDArray operator/(const bfloat16& scalar, NDArray&& arr); +template ND4J_EXPORT NDArray operator/(const int& scalar, NDArray&& arr); + + +//////////////////////////////////////////////////////////////////////// +template +NDArray operator/(const T& scalar, const NDArray& arr) { + + if (arr.isS()) + throw std::runtime_error("operator/(const T& scalar, const NDArray& arr): you can't use this method on String array!"); + + auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); + NDArray result(arr.getShapeInfo(), DataTypeUtils::pickPairwiseResultType(arr.dataType(), DataTypeUtils::fromT()), false, arr.getContext()); + + NDArray::prepareSpecialUse({&result}, {&arr, &tmp}); + NativeOpExecutioner::execScalar(arr.getContext(), nd4j::scalar::ReverseDivide, arr.getBuffer(), arr.getShapeInfo(), arr.getSpecialBuffer(), arr.getSpecialShapeInfo(), result.getBuffer(), result.getShapeInfo(), result.specialBuffer(), result.specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); + NDArray::registerSpecialUse({&result}, {&arr, &tmp}); + + return result; +} +template ND4J_EXPORT NDArray operator/(const double& scalar, const NDArray& arr); +template ND4J_EXPORT NDArray operator/(const float& scalar, const NDArray& arr); +template ND4J_EXPORT NDArray operator/(const int& scalar, const NDArray& arr); + +//////////////////////////////////////////////////////////////////////// +// addition operator array + array +template +NDArray operator+(T1&& arr1, T2&& arr2) { + + if (arr1.isS() || arr2.isS()) + throw std::runtime_error("operator+(T&& arr1, T&& arr2): you can't use this method on String arrays!"); + if (!Environment::getInstance()->isExperimentalBuild() && arr1.dataType() != arr2.dataType() && (arr1.dataType() != DataType::BOOL || arr2.dataType() != BOOL)) + throw nd4j::datatype_exception::build("operator+(T&& arr1, T&& arr2): Cannot multiply different types", arr1.dataType(), arr2.dataType()); + + PointersManager pointersManager(arr1.getContext(), "operator+(T&& arr1, T&& arr2)"); + + if (arr1.lengthOf() == arr2.lengthOf() && arr1.rankOf() == arr2.rankOf()) { + + const bool isArr1Rvalue = !std::is_reference::value && !arr1.isView(); + const bool isArr2Rvalue = !std::is_reference::value && !arr2.isView(); + + NDArray* result = nullptr; + if(isArr1Rvalue) + result = const_cast(&arr1); + else if(isArr2Rvalue) + result = const_cast(&arr2); + else + result = new NDArray(arr1.getShapeInfo(), DataTypeUtils::pickPairwiseResultType(arr1.getShapeInfo(), arr2.getShapeInfo()), false, arr1.getContext()); + + NDArray::prepareSpecialUse({result}, {&arr1, &arr2}); + NativeOpExecutioner::execPairwiseTransform(arr1.getContext(), nd4j::pairwise::Add, arr1.getBuffer(), arr1.getShapeInfo(), arr1.getSpecialBuffer(), arr1.getSpecialShapeInfo(), arr2.getBuffer(), arr2.getShapeInfo(), arr2.getSpecialBuffer(), arr2.getSpecialShapeInfo(), result->buffer(), result->getShapeInfo(), result->specialBuffer(), result->getSpecialShapeInfo(), nullptr); + NDArray::registerSpecialUse({result}, {&arr1, &arr2}); + + if(!isArr1Rvalue && !isArr2Rvalue) { + NDArray res = std::move(*result); + delete result; + return std::move(res); + } + + return std::move(*result); + } + + return std::forward(arr1).applyTrueBroadcast(nd4j::BroadcastOpsTuple::Add(), std::forward(arr2)); +} +template ND4J_EXPORT NDArray operator+(NDArray& arr1, NDArray& arr2); +template ND4J_EXPORT NDArray operator+(NDArray& arr1, NDArray&& arr2); +template ND4J_EXPORT NDArray operator+(NDArray&& arr1, NDArray& arr2); +template ND4J_EXPORT NDArray operator+(NDArray& arr1, const NDArray& arr2); +template ND4J_EXPORT NDArray operator+(const NDArray& arr1, NDArray& arr2); +template ND4J_EXPORT NDArray operator+(const NDArray& arr1, NDArray&& arr2); +template ND4J_EXPORT NDArray operator+(const NDArray& arr1, const NDArray& arr2); +template ND4J_EXPORT NDArray operator+(NDArray&& arr1, const NDArray& arr2); +template ND4J_EXPORT NDArray operator+(NDArray&& arr1, NDArray&& arr2); + +//////////////////////////////////////////////////////////////////////// +// addition operator array - array +template +NDArray operator-(T1&& arr1, T2&& arr2) { + + if (arr1.isS() || arr2.isS()) + throw std::runtime_error("operator-(T&& arr1, T&& arr2): you can't use this method on String arrays!"); + if (!Environment::getInstance()->isExperimentalBuild() && arr1.dataType() != arr2.dataType() && (arr1.dataType() != DataType::BOOL || arr2.dataType() != BOOL)) + throw nd4j::datatype_exception::build("operator-(T&& arr1, T&& arr2): Cannot multiply different types", arr1.dataType(), arr2.dataType()); + + PointersManager pointersManager(arr1.getContext(), "operator-(T&& arr1, T&& arr2)"); + + if (arr1.lengthOf() == arr2.lengthOf() && arr1.rankOf() == arr2.rankOf()) { + + const bool isArr1Rvalue = !std::is_reference::value && !arr1.isView(); + const bool isArr2Rvalue = !std::is_reference::value && !arr2.isView(); + + NDArray* result = nullptr; + if(isArr1Rvalue) + result = const_cast(&arr1); + else if(isArr2Rvalue) + result = const_cast(&arr2); + else + result = new NDArray(arr1.getShapeInfo(), DataTypeUtils::pickPairwiseResultType(arr1.getShapeInfo(), arr2.getShapeInfo()), false, arr1.getContext()); + + NDArray::prepareSpecialUse({result}, {&arr1, &arr2}); + NativeOpExecutioner::execPairwiseTransform(arr1.getContext(), nd4j::pairwise::Subtract, arr1.getBuffer(), arr1.getShapeInfo(), arr1.getSpecialBuffer(), arr1.getSpecialShapeInfo(), arr2.getBuffer(), arr2.getShapeInfo(), arr2.getSpecialBuffer(), arr2.getSpecialShapeInfo(), result->buffer(), result->getShapeInfo(), result->specialBuffer(), result->getSpecialShapeInfo(), nullptr); + NDArray::registerSpecialUse({result}, {&arr1, &arr2}); + + if(!isArr1Rvalue && !isArr2Rvalue) { + NDArray res = std::move(*result); + delete result; + return std::move(res); + } + + return std::move(*result); + } + + return std::forward(arr1).applyTrueBroadcast(nd4j::BroadcastOpsTuple::Subtract(), std::forward(arr2)); +} +template ND4J_EXPORT NDArray operator-(NDArray& arr1, NDArray& arr2); +template ND4J_EXPORT NDArray operator-(NDArray& arr1, NDArray&& arr2); +template ND4J_EXPORT NDArray operator-(NDArray&& arr1, NDArray& arr2); +template ND4J_EXPORT NDArray operator-(NDArray& arr1, const NDArray& arr2); +template ND4J_EXPORT NDArray operator-(const NDArray& arr1, NDArray& arr2); +template ND4J_EXPORT NDArray operator-(const NDArray& arr1, NDArray&& arr2); +template ND4J_EXPORT NDArray operator-(const NDArray& arr1, const NDArray& arr2); +template ND4J_EXPORT NDArray operator-(NDArray&& arr1, const NDArray& arr2); +template ND4J_EXPORT NDArray operator-(NDArray&& arr1, NDArray&& arr2); + +//////////////////////////////////////////////////////////////////////// +// multiplication operator array*array +template +NDArray operator*(T1&& arr1, T2&& arr2) { + + if (arr1.isS() || arr2.isS()) + throw std::runtime_error("operator*(T&& arr1, T&& arr2): you can't use this method on String arrays!"); + if (!Environment::getInstance()->isExperimentalBuild() && arr1.dataType() != arr2.dataType() && (arr1.dataType() != DataType::BOOL || arr2.dataType() != BOOL)) + throw nd4j::datatype_exception::build("operator*(T&& arr1, T&& arr2): Cannot multiply different types", arr1.dataType(), arr2.dataType()); + + PointersManager pointersManager(arr1.getContext(), "operator*(T&& arr1, T&& arr2)"); + + if (arr1.lengthOf() == arr2.lengthOf() && arr1.rankOf() == arr2.rankOf()) { + + const bool isArr1Rvalue = !std::is_reference::value && !arr1.isView(); + const bool isArr2Rvalue = !std::is_reference::value && !arr2.isView(); + + NDArray* result = nullptr; + if(isArr1Rvalue) + result = const_cast(&arr1); + else if(isArr2Rvalue) + result = const_cast(&arr2); + else + result = new NDArray(arr1.getShapeInfo(), DataTypeUtils::pickPairwiseResultType(arr1.getShapeInfo(), arr2.getShapeInfo()), false, arr1.getContext()); + + NDArray::prepareSpecialUse({result}, {&arr1, &arr2}); + NativeOpExecutioner::execPairwiseTransform(arr1.getContext(), nd4j::pairwise::Multiply, arr1.getBuffer(), arr1.getShapeInfo(), arr1.getSpecialBuffer(), arr1.getSpecialShapeInfo(), arr2.getBuffer(), arr2.getShapeInfo(), arr2.getSpecialBuffer(), arr2.getSpecialShapeInfo(), result->buffer(), result->getShapeInfo(), result->specialBuffer(), result->getSpecialShapeInfo(), nullptr); + NDArray::registerSpecialUse({result}, {&arr1, &arr2}); + + if(!isArr1Rvalue && !isArr2Rvalue) { + NDArray res = std::move(*result); + delete result; + return std::move(res); + } + + return std::move(*result); + } + + return std::forward(arr1).applyTrueBroadcast(nd4j::BroadcastOpsTuple::Multiply(), std::forward(arr2)); +} +template ND4J_EXPORT NDArray operator*(NDArray& arr1, NDArray& arr2); +template ND4J_EXPORT NDArray operator*(NDArray& arr1, NDArray&& arr2); +template ND4J_EXPORT NDArray operator*(NDArray&& arr1, NDArray& arr2); +template ND4J_EXPORT NDArray operator*(NDArray& arr1, const NDArray& arr2); +template ND4J_EXPORT NDArray operator*(const NDArray& arr1, NDArray& arr2); +template ND4J_EXPORT NDArray operator*(const NDArray& arr1, NDArray&& arr2); +template ND4J_EXPORT NDArray operator*(const NDArray& arr1, const NDArray& arr2); +template ND4J_EXPORT NDArray operator*(NDArray&& arr1, const NDArray& arr2); +template ND4J_EXPORT NDArray operator*(NDArray&& arr1, NDArray&& arr2); + +//////////////////////////////////////////////////////////////////////// +// multiplication operator array*array +template +NDArray operator/(T1&& arr1, T2&& arr2) { + + if (arr1.isS() || arr2.isS()) + throw std::runtime_error("operator/(T&& arr1, T&& arr2): you can't use this method on String arrays!"); + if (!Environment::getInstance()->isExperimentalBuild() && arr1.dataType() != arr2.dataType() && (arr1.dataType() != DataType::BOOL || arr2.dataType() != BOOL)) + throw nd4j::datatype_exception::build("operator/(T&& arr1, T&& arr2): Cannot multiply different types", arr1.dataType(), arr2.dataType()); + + PointersManager pointersManager(arr1.getContext(), "operator/(T&& arr1, T&& arr2)"); + + if (arr1.lengthOf() == arr2.lengthOf() && arr1.rankOf() == arr2.rankOf()) { + + const bool isArr1Rvalue = !std::is_reference::value && !arr1.isView(); + const bool isArr2Rvalue = !std::is_reference::value && !arr2.isView(); + + NDArray* result = nullptr; + if(isArr1Rvalue) + result = const_cast(&arr1); + else if(isArr2Rvalue) + result = const_cast(&arr2); + else + result = new NDArray(arr1.getShapeInfo(), DataTypeUtils::pickPairwiseResultType(arr1.getShapeInfo(), arr2.getShapeInfo()), false, arr1.getContext()); + + NDArray::prepareSpecialUse({result}, {&arr1, &arr2}); + NativeOpExecutioner::execPairwiseTransform(arr1.getContext(), nd4j::pairwise::Divide, arr1.getBuffer(), arr1.getShapeInfo(), arr1.getSpecialBuffer(), arr1.getSpecialShapeInfo(), arr2.getBuffer(), arr2.getShapeInfo(), arr2.getSpecialBuffer(), arr2.getSpecialShapeInfo(), result->buffer(), result->getShapeInfo(), result->specialBuffer(), result->getSpecialShapeInfo(), nullptr); + NDArray::registerSpecialUse({result}, {&arr1, &arr2}); + + if(!isArr1Rvalue && !isArr2Rvalue) { + NDArray res = std::move(*result); + delete result; + return std::move(res); + } + + return std::move(*result); + } + + return std::forward(arr1).applyTrueBroadcast(nd4j::BroadcastOpsTuple::Divide(), std::forward(arr2)); +} +template ND4J_EXPORT NDArray operator/(NDArray& arr1, NDArray& arr2); +template ND4J_EXPORT NDArray operator/(NDArray& arr1, NDArray&& arr2); +template ND4J_EXPORT NDArray operator/(NDArray&& arr1, NDArray& arr2); +template ND4J_EXPORT NDArray operator/(NDArray& arr1, const NDArray& arr2); +template ND4J_EXPORT NDArray operator/(const NDArray& arr1, NDArray& arr2); +template ND4J_EXPORT NDArray operator/(const NDArray& arr1, NDArray&& arr2); +template ND4J_EXPORT NDArray operator/(const NDArray& arr1, const NDArray& arr2); +template ND4J_EXPORT NDArray operator/(NDArray&& arr1, const NDArray& arr2); +template ND4J_EXPORT NDArray operator/(NDArray&& arr1, NDArray&& arr2); + /* #ifndef __CLION_IDE__ diff --git a/libnd4j/blas/NativeOps.h b/libnd4j/blas/NativeOps.h index ea2b0b8a5..cd6274dfb 100755 --- a/libnd4j/blas/NativeOps.h +++ b/libnd4j/blas/NativeOps.h @@ -68,6 +68,7 @@ bool verbose = false; #include #include #include +#include #include #include #include @@ -75,6 +76,9 @@ bool verbose = false; #include #include #include +#include + +typedef nd4j::InteropDataBuffer OpaqueDataBuffer; extern "C" { @@ -118,11 +122,9 @@ ND4J_EXPORT void setTADThreshold(int num); */ ND4J_EXPORT void execIndexReduceScalar(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo); + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo); /** * @@ -137,13 +139,10 @@ ND4J_EXPORT void execIndexReduceScalar(Nd4jPointer *extraPointers, */ ND4J_EXPORT void execIndexReduce(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, - void *hDimension, Nd4jLong *hDimensionShape, - void *dDimension, Nd4jLong *dDimensionShape); + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape); /** * @@ -160,28 +159,20 @@ ND4J_EXPORT void execIndexReduce(Nd4jPointer *extraPointers, ND4J_EXPORT void execBroadcast( Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *hY, Nd4jLong *hYShapeInfo, - void *dY, Nd4jLong *dYShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, - void *hDimension, Nd4jLong *hDimensionShape, - void *dDimension, Nd4jLong *dDimensionShape); + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape); ND4J_EXPORT void execBroadcastBool( Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *hY, Nd4jLong *hYShapeInfo, - void *dY, Nd4jLong *dYShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, void *extraParams, - void *hDimension, Nd4jLong *hDimensionShape, - void *dDimension, Nd4jLong *dDimensionShape); + OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape); /** * @@ -198,23 +189,17 @@ ND4J_EXPORT void execBroadcastBool( ND4J_EXPORT void execPairwiseTransform( Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *hY, Nd4jLong *hYShapeInfo, - void *dY, Nd4jLong *dYShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, void *extraParams); ND4J_EXPORT void execPairwiseTransformBool( Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *hY, Nd4jLong *hYShapeInfo, - void *dY, Nd4jLong *dYShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, void *extraParams); /** @@ -228,36 +213,28 @@ ND4J_EXPORT void execPairwiseTransformBool( */ ND4J_EXPORT void execReduceFloat(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo); + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo); ND4J_EXPORT void execReduceSame(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo); + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo); ND4J_EXPORT void execReduceBool(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo); + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo); ND4J_EXPORT void execReduceLong(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo); + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo); /** * @@ -270,46 +247,34 @@ ND4J_EXPORT void execReduceLong(Nd4jPointer *extraPointers, */ ND4J_EXPORT void execReduceFloat2(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, - void *hDimension, Nd4jLong *hDimensionShape, - void *dDimension, Nd4jLong *dDimensionShape); + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape); ND4J_EXPORT void execReduceSame2(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, - void *hDimension, Nd4jLong *hDimensionShape, - void *dDimension, Nd4jLong *dDimensionShape); + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape); ND4J_EXPORT void execReduceBool2(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, - void *hDimension, Nd4jLong *hDimensionShape, - void *dDimension, Nd4jLong *dDimensionShape); + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape); ND4J_EXPORT void execReduceLong2(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, - void *hDimension, Nd4jLong *hDimensionShape, - void *dDimension, Nd4jLong *dDimensionShape); + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape); /** * @@ -324,13 +289,10 @@ ND4J_EXPORT void execReduceLong2(Nd4jPointer *extraPointers, */ ND4J_EXPORT void execReduce3(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParamsVals, - void *hY, Nd4jLong *hYShapeInfo, - void *dY, Nd4jLong *dYShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo); + OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo); /** * @@ -343,13 +305,10 @@ ND4J_EXPORT void execReduce3(Nd4jPointer *extraPointers, */ ND4J_EXPORT void execReduce3Scalar(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParamsVals, - void *hY, Nd4jLong *hYShapeInfo, - void *dY, Nd4jLong *dYShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo); + OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo); /** * * @param opNum @@ -365,30 +324,22 @@ ND4J_EXPORT void execReduce3Scalar(Nd4jPointer *extraPointers, */ ND4J_EXPORT void execReduce3Tad(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParamsVals, - void *hY, Nd4jLong *hYShapeInfo, - void *dY, Nd4jLong *dYShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, - void *hDimension, Nd4jLong *hDimensionShape, - void *dDimension, Nd4jLong *dDimensionShape, + OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *yTadOnlyShapeInfo, Nd4jLong *yTadOffsets); ND4J_EXPORT void execReduce3All(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParamsVals, - void *hY, Nd4jLong *hYShapeInfo, - void *dY, Nd4jLong *dYShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, - void *hDimension, Nd4jLong *hDimensionShape, - void *dDimension, Nd4jLong *dDimensionShape, + OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape, Nd4jLong *xTadShapeInfo, Nd4jLong *xOffsets, Nd4jLong *yTadShapeInfo, Nd4jLong *yOffsets); @@ -405,22 +356,16 @@ ND4J_EXPORT void execReduce3All(Nd4jPointer *extraPointers, */ ND4J_EXPORT void execScalar(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, - void *hScalar, Nd4jLong *hSscalarShapeInfo, - void *dScalar, Nd4jLong *dSscalarShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbScalar, Nd4jLong *hSscalarShapeInfo, Nd4jLong *dSscalarShapeInfo, void *extraParams); ND4J_EXPORT void execScalarBool(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, - void *hScalar, Nd4jLong *hSscalarShapeInfo, - void *dScalar, Nd4jLong *dSscalarShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbScalar, Nd4jLong *hSscalarShapeInfo, Nd4jLong *dSscalarShapeInfo, void *extraParams); /** @@ -432,11 +377,9 @@ ND4J_EXPORT void execScalarBool(Nd4jPointer *extraPointers, */ ND4J_EXPORT void execSummaryStatsScalar(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, bool biasCorrected); /** * @@ -449,11 +392,9 @@ ND4J_EXPORT void execSummaryStatsScalar(Nd4jPointer *extraPointers, */ ND4J_EXPORT void execSummaryStats(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, bool biasCorrected); /** * @@ -468,13 +409,10 @@ ND4J_EXPORT void execSummaryStats(Nd4jPointer *extraPointers, */ ND4J_EXPORT void execSummaryStatsTad(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, - void *hDimension, Nd4jLong *hDimensionShape, - void *dDimension, Nd4jLong *dDimensionShape, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape, bool biasCorrected, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets); @@ -490,42 +428,32 @@ ND4J_EXPORT void execSummaryStatsTad(Nd4jPointer *extraPointers, */ ND4J_EXPORT void execTransformFloat(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, void *extraParams); ND4J_EXPORT void execTransformSame(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, void *extraParams); ND4J_EXPORT void execTransformBool(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, void *extraParams); ND4J_EXPORT void execTransformAny(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, void *extraParams); ND4J_EXPORT void execTransformStrict(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, void *extraParams); /** @@ -543,29 +471,21 @@ ND4J_EXPORT void execTransformStrict(Nd4jPointer *extraPointers, */ ND4J_EXPORT void execScalarTad(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, - void *hScalars, Nd4jLong *hScalarShapeInfo, - void *dScalars, Nd4jLong *dScalarShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbScalars, Nd4jLong *hScalarShapeInfo, Nd4jLong *dScalarShapeInfo, void *extraParams, - void *hDimension, Nd4jLong *hDimensionShape, - void *dDimension, Nd4jLong *dDimensionShape, + OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ); ND4J_EXPORT void execScalarBoolTad(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, - void *hScalars, Nd4jLong *hScalarShapeInfo, - void *dScalars, Nd4jLong *dScalarShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbScalars, Nd4jLong *hScalarShapeInfo, Nd4jLong *dScalarShapeInfo, void *extraParams, - void *hDimension, Nd4jLong *hDimensionShape, - void *dDimension, Nd4jLong *dDimensionShape, + OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ); @@ -904,10 +824,8 @@ ND4J_EXPORT void deleteTadPack(OpaqueTadPack* ptr); * @param zTadOffsets */ ND4J_EXPORT void pullRows(Nd4jPointer *extraPointers, - void *x, Nd4jLong *xShapeInfo, - void *dx, Nd4jLong *dxShapeInfo, - void *z, Nd4jLong *zShapeInfo, - void *dz, Nd4jLong *dzShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *xShapeInfo, Nd4jLong *dxShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *zShapeInfo, Nd4jLong *dzShapeInfo, Nd4jLong n, Nd4jLong *indexes, Nd4jLong *tadShapeInfo, @@ -1086,8 +1004,7 @@ ND4J_EXPORT void execAggregateBatch(Nd4jPointer *extraPointers, ND4J_EXPORT void execRandom(Nd4jPointer *extraPointers, int opNum, Nd4jPointer state, - void *hZ, Nd4jLong *hZShapeBuffer, - void *dZ, Nd4jLong *dZShapeBuffer, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeBuffer, Nd4jLong *dZShapeBuffer, void *extraArguments); /** @@ -1106,12 +1023,9 @@ ND4J_EXPORT void execRandom(Nd4jPointer *extraPointers, ND4J_EXPORT void execRandom3(Nd4jPointer *extraPointers, int opNum, Nd4jPointer state, - void *hX, Nd4jLong *hXShapeBuffer, - void *dX, Nd4jLong *dXShapeBuffer, - void *hY, Nd4jLong *hYShapeBuffer, - void *dY, Nd4jLong *dYShapeBuffer, - void *hZ, Nd4jLong *hZShapeBuffer, - void *dZ, Nd4jLong *dZShapeBuffer, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeBuffer, Nd4jLong *dXShapeBuffer, + OpaqueDataBuffer *dbY, Nd4jLong *hYShapeBuffer, Nd4jLong *dYShapeBuffer, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeBuffer, Nd4jLong *dZShapeBuffer, void *extraArguments); /** @@ -1128,10 +1042,8 @@ ND4J_EXPORT void execRandom3(Nd4jPointer *extraPointers, ND4J_EXPORT void execRandom2(Nd4jPointer *extraPointers, int opNum, Nd4jPointer state, - void *hX, Nd4jLong *hXShapeBuffer, - void *dX, Nd4jLong *dXShapeBuffer, - void *hZ, Nd4jLong *hZShapeBuffer, - void *dZ, Nd4jLong *dZShapeBuffer, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeBuffer, Nd4jLong *dXShapeBuffer, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeBuffer, Nd4jLong *dZShapeBuffer, void *extraArguments); @@ -1174,52 +1086,6 @@ ND4J_EXPORT void reSeedBuffer(Nd4jPointer *extraPointers, */ ND4J_EXPORT void destroyRandom(Nd4jPointer ptrRandom); -/** - * Grid operations - */ - - - - -/** - * - * @param extras - * @param opTypeA - * @param opNumA - * @param opTypeB - * @param opNumB - * @param N - * @param dx - * @param xShapeInfo - * @param dy - * @param yShapeInfo - * @param dz - * @param zShapeInfo - * @param extraA - * @param extraB - * @param scalarA - * @param scalarB - */ - /* -ND4J_EXPORT void execMetaPredicateShape(Nd4jPointer *extras, - const int opTypeA, - const int opNumA, - const int opTypeB, - const int opNumB, - Nd4jLong N, - void *hX, Nd4jLong *hXShapeBuffer, - void *dX, Nd4jLong *dXShapeBuffer, - void *hY, Nd4jLong *hYShapeBuffer, - void *dY, Nd4jLong *dYShapeBuffer, - void *hZ, Nd4jLong *hZShapeBuffer, - void *dZ, Nd4jLong *dZShapeBuffer, - void *extraA, - void *extraB, - double scalarA, - double scalarB); - -*/ - } /** @@ -1561,11 +1427,10 @@ ND4J_EXPORT Nd4jPointer pointerForAddress(Nd4jLong address); * @return */ ND4J_EXPORT void tear(Nd4jPointer *extraPointers, - void *x, Nd4jLong *xShapeInfo, - void *dx, Nd4jLong *dxShapeInfo, - Nd4jPointer *targets, Nd4jLong *zShapeInfo, - Nd4jLong *tadShapeInfo, - Nd4jLong *tadOffsets); + OpaqueDataBuffer *dbX, Nd4jLong *xShapeInfo, Nd4jLong *dxShapeInfo, + Nd4jPointer *targets, Nd4jLong *zShapeInfo, + Nd4jLong *tadShapeInfo, + Nd4jLong *tadOffsets); ND4J_EXPORT Nd4jLong encodeBitmap(Nd4jPointer *extraPointers, void *dx, Nd4jLong *xShapeInfo, Nd4jLong N, int *dz, float threshold); ND4J_EXPORT void decodeBitmap(Nd4jPointer *extraPointers, void *dx, Nd4jLong N, void *dz, Nd4jLong *zShapeInfo); @@ -1734,10 +1599,13 @@ typedef nd4j::graph::RandomGenerator OpaqueRandomGenerator; ND4J_EXPORT OpaqueContext* createGraphContext(int nodeId); ND4J_EXPORT OpaqueRandomGenerator* getGraphContextRandomGenerator(OpaqueContext* ptr); ND4J_EXPORT void ctxAllowHelpers(OpaqueContext* ptr, bool reallyAllow); +ND4J_EXPORT void ctxShapeFunctionOverride(OpaqueContext* ptr, bool reallyOverride); ND4J_EXPORT void markGraphContextInplace(OpaqueContext* ptr, bool reallyInplace); ND4J_EXPORT void setGraphContextCudaContext(OpaqueContext* ptr, void *stream, void *reductionPointer, void *allocationPointer); ND4J_EXPORT void setGraphContextInputArray(OpaqueContext* ptr, int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo); ND4J_EXPORT void setGraphContextOutputArray(OpaqueContext* ptr, int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo); +ND4J_EXPORT void setGraphContextInputBuffer(OpaqueContext* ptr, int index, OpaqueDataBuffer *buffer, void *shapeInfo, void *specialShapeInfo); +ND4J_EXPORT void setGraphContextOutputBuffer(OpaqueContext* ptr, int index, OpaqueDataBuffer *buffer, void *shapeInfo, void *specialShapeInfo); ND4J_EXPORT void setGraphContextTArguments(OpaqueContext* ptr, double *arguments, int numberOfArguments); ND4J_EXPORT void setGraphContextIArguments(OpaqueContext* ptr, Nd4jLong *arguments, int numberOfArguments); ND4J_EXPORT void setGraphContextBArguments(OpaqueContext* ptr, bool *arguments, int numberOfArguments); @@ -1765,6 +1633,28 @@ ND4J_EXPORT Nd4jPointer lcCopyStream(OpaqueLaunchContext* lc); ND4J_EXPORT Nd4jPointer lcBlasHandle(OpaqueLaunchContext* lc); ND4J_EXPORT Nd4jPointer lcSolverHandle(OpaqueLaunchContext* lc); +ND4J_EXPORT OpaqueDataBuffer* allocateDataBuffer(Nd4jLong elements, int dataType, bool allocateBoth); +ND4J_EXPORT OpaqueDataBuffer* dbCreateView(OpaqueDataBuffer *dataBuffer, Nd4jLong length, Nd4jLong offset); +ND4J_EXPORT Nd4jPointer dbPrimaryBuffer(OpaqueDataBuffer *dataBuffer); +ND4J_EXPORT Nd4jPointer dbSpecialBuffer(OpaqueDataBuffer *dataBuffer); +ND4J_EXPORT void dbExpandBuffer(OpaqueDataBuffer *dataBuffer, Nd4jLong elements); +ND4J_EXPORT void dbAllocatePrimaryBuffer(OpaqueDataBuffer *dataBuffer); +ND4J_EXPORT void dbAllocateSpecialBuffer(OpaqueDataBuffer *dataBuffer); +ND4J_EXPORT void dbSetPrimaryBuffer(OpaqueDataBuffer *dataBuffer, Nd4jPointer primaryBuffer, Nd4jLong numBytes); +ND4J_EXPORT void dbSetSpecialBuffer(OpaqueDataBuffer *dataBuffer, Nd4jPointer specialBuffer, Nd4jLong numBytes); +ND4J_EXPORT void dbSyncToSpecial(OpaqueDataBuffer *dataBuffer); +ND4J_EXPORT void dbSyncToPrimary(OpaqueDataBuffer *dataBuffer); +ND4J_EXPORT int dbLocality(OpaqueDataBuffer *dataBuffer); +ND4J_EXPORT int dbDeviceId(OpaqueDataBuffer *dataBuffer); +ND4J_EXPORT void dbSetDeviceId(OpaqueDataBuffer *dataBuffer, int deviceId); +ND4J_EXPORT void dbTickHostRead(OpaqueDataBuffer *dataBuffer); +ND4J_EXPORT void dbTickHostWrite(OpaqueDataBuffer *dataBuffer); +ND4J_EXPORT void dbTickDeviceRead(OpaqueDataBuffer *dataBuffer); +ND4J_EXPORT void dbTickDeviceWrite(OpaqueDataBuffer *dataBuffer); +ND4J_EXPORT void dbClose(OpaqueDataBuffer *dataBuffer); +ND4J_EXPORT void deleteDataBuffer(OpaqueDataBuffer *dataBuffer); +ND4J_EXPORT void dbExpand(OpaqueDataBuffer *dataBuffer, Nd4jLong elements); + ND4J_EXPORT int binaryLevel(); ND4J_EXPORT int optimalLevel(); diff --git a/libnd4j/blas/cpu/GraphExecutioner.cpp b/libnd4j/blas/cpu/GraphExecutioner.cpp index ef45a3e0c..2190afbf1 100644 --- a/libnd4j/blas/cpu/GraphExecutioner.cpp +++ b/libnd4j/blas/cpu/GraphExecutioner.cpp @@ -104,7 +104,7 @@ namespace graph { if (node->id() == 13) nd4j_debug("",""); - // if true - this is special case: Graph-in-Graph. + // if true - this is special case: Graph-in-Graph. if (node->hasGraphEmbedded()) { auto embedded = node->getGraph(); @@ -128,12 +128,12 @@ namespace graph { int cnt = 0; for (Variable* v: *embedded->getPlaceholders()) { if (v->getName() != nullptr && v->getName()->size() > 0) { - + // trying symbolic lookup first if (variableSpace->hasVariable(v->getName())) { // symbolic feeder auto array = variableSpace->getVariable(v->getName())->getNDArray(); - auto vr = array->dup(); + auto vr = new NDArray(array->dup()); // deletables.push_back(vr); v->setNDArray(vr); } else { @@ -145,7 +145,7 @@ namespace graph { // if we're not using symbolic lookup - we'll use sequential approach then auto p = node->input()->at(cnt); auto array = variableSpace->getVariable(p)->getNDArray(); - auto vr = array->dup(); + auto vr = new NDArray(array->dup()); //deletables.push_back(vr); v->setNDArray(vr); } @@ -501,7 +501,7 @@ Nd4jStatus GraphExecutioner::execute(Graph *graph, VariableSpace* variableSpace) } /** - * This method is provided for IPC: + * This method is provided for IPC: * 1) it accepts pointer to FlatBuffers buffer * 2) restores Graph from it * 3) Executes this Graph diff --git a/libnd4j/blas/cpu/NDArray.cpp b/libnd4j/blas/cpu/NDArray.cpp index dc9d09231..9bdf41a16 100644 --- a/libnd4j/blas/cpu/NDArray.cpp +++ b/libnd4j/blas/cpu/NDArray.cpp @@ -71,44 +71,41 @@ void NDArray::makeBothBuffersActual() const { } //////////////////////////////////////////////////////////////////////// template -void NDArray::fillAsTriangular(const float val, int lower, int upper, const char direction, NDArray* target) { +void NDArray::fillAsTriangular(const float val, int lower, int upper, NDArray& target, const char direction) { if (isS()) throw std::runtime_error("NDArray::fillArrayAsTriangular: you can't use this method on String array!"); - if(target == nullptr) - target = this; - - if(!isSameShape(target) && !(rankOf() == 1 && target->rankOf() == 2 && sizeAt(0) == target->sizeAt(0) && sizeAt(0) == target->sizeAt(1))) + if(!isSameShape(target) && !(rankOf() == 1 && target.rankOf() == 2 && sizeAt(0) == target.sizeAt(0) && sizeAt(0) == target.sizeAt(1))) throw std::string("NDArray::fillArrayAsTriangular method: wrong shape of target array !"); if (direction == 'u') - lower = -target->sizeAt(-2); + lower = -target.sizeAt(-2); else if (direction == 'l') - upper = target->sizeAt(-1); + upper = target.sizeAt(-1); const T value = static_cast(val); const auto x = reinterpret_cast(getBuffer()); - auto z = reinterpret_cast(target->getBuffer()); + auto z = reinterpret_cast(target.getBuffer()); const int xRank = rankOf(); - const int zRank = target->rankOf(); + const int zRank = target.rankOf(); - const auto zLen = target->lengthOf(); + const auto zLen = target.lengthOf(); - const bool areSameOffsets = shape::haveSameShapeAndStrides(getShapeInfo(), target->getShapeInfo()); + const bool areSameOffsets = shape::haveSameShapeAndStrides(getShapeInfo(), target.getShapeInfo()); auto func = PRAGMA_THREADS_FOR { Nd4jLong coords[MAX_RANK]; for (auto i = start; i < stop; i += increment) { - shape::index2coords(i, target->getShapeInfo(), coords); - const auto zOffset = shape::getOffset(target->getShapeInfo(), coords); + shape::index2coords(i, target.getShapeInfo(), coords); + const auto zOffset = shape::getOffset(target.getShapeInfo(), coords); // if( (row + upper < col) || (row + lower > col) ) if ((coords[zRank - 2] + upper < coords[zRank - 1]) || (coords[zRank - 2] + lower > coords[zRank - 1])) z[zOffset] = value; - else if (this != target) { // when this and target are different arrays + else if (this != &target) { // when this and target are different arrays if (xRank != zRank) coords[0] = coords[1]; @@ -120,7 +117,7 @@ void NDArray::fillAsTriangular(const float val, int lower, int upper, const char samediff::Threads::parallel_for(func, 0, zLen); } -BUILD_SINGLE_TEMPLATE(template void NDArray::fillAsTriangular, (const float val, int lower, int upper, const char direction, NDArray* target), LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template void NDArray::fillAsTriangular, (const float val, int lower, int upper, NDArray& target, const char direction), LIBND4J_TYPES); //////////////////////////////////////////////////////////////////////// void NDArray::setIdentity() { @@ -187,16 +184,16 @@ void NDArray::synchronize(const char* msg) const { // no-op } -void NDArray::prepareSpecialUse(const std::initializer_list& writeList, const std::initializer_list& readList, bool synchronizeWritables) { +void NDArray::prepareSpecialUse(const std::vector& writeList, const std::vector& readList, bool synchronizeWritables) { // no-op } -void NDArray::registerSpecialUse(const std::initializer_list& writeList, const std::initializer_list& readList) { +void NDArray::registerSpecialUse(const std::vector& writeList, const std::vector& readList) { // no-op } -void NDArray::preparePrimaryUse(const std::initializer_list& writeList, const std::initializer_list& readList, bool synchronizeWritables) { +void NDArray::preparePrimaryUse(const std::vector& writeList, const std::vector& readList, bool synchronizeWritables) { // no-op } -void NDArray::registerPrimaryUse(const std::initializer_list& writeList, const std::initializer_list& readList) { +void NDArray::registerPrimaryUse(const std::vector& writeList, const std::vector& readList) { // no-op } @@ -405,11 +402,11 @@ static void repeat_(const NDArray& input, NDArray& output, const std::vector& repeats) const { +NDArray NDArray::repeat(const int axis, const std::vector& repeats) const { - auto output = new NDArray('c', ShapeUtils::evalRepeatShape(axis, repeats, *this), dataType(), getContext()); + NDArray output('c', ShapeUtils::evalRepeatShape(axis, repeats, *this), dataType(), getContext()); - BUILD_SINGLE_SELECTOR_TWICE(dataType(), repeat_, (*this, *output, repeats, axis), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR_TWICE(dataType(), repeat_, (*this, output, repeats, axis), LIBND4J_TYPES); return output; } diff --git a/libnd4j/blas/cpu/NDArrayLambda.hpp b/libnd4j/blas/cpu/NDArrayLambda.hpp index 6ce8e6823..86d798efc 100644 --- a/libnd4j/blas/cpu/NDArrayLambda.hpp +++ b/libnd4j/blas/cpu/NDArrayLambda.hpp @@ -2,35 +2,24 @@ template -void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function& func, NDArray* target) { - if (target == nullptr) - target = this; +void NDArray::applyTriplewiseLambda(NDArray& second, NDArray& third, const std::function& func, NDArray& target) { - if (second == nullptr) { - nd4j_printf("applyTriplewiseLambda requires three operands to be valid NDArrays, but Second is NULL\n",""); - throw std::runtime_error("second is null"); - } - - if (third == nullptr) { - nd4j_printf("applyTriplewiseLambda requires three operands to be valid NDArrays, but Third is NULL\n",""); - throw std::runtime_error("third is null"); - } if(dataType() != DataTypeUtils::fromT()) throw std::runtime_error("NDArray::applyTriplewiseLambda method: wrong template parameter T, its type should be the same as type of this array!"); - if(dataType() != second->dataType() || dataType() != third->dataType() || dataType() != target->dataType()) + if(dataType() != second.dataType() || dataType() != third.dataType() || dataType() != target.dataType()) throw std::runtime_error("NDArray::applyTriplewiseLambda method: bother four arrays (this, second, third, target) should have the same type !"); - if (this->lengthOf() != second->lengthOf() || this->lengthOf() != third->lengthOf() || !this->isSameShape(second) || !this->isSameShape(third)) { + if (this->lengthOf() != second.lengthOf() || this->lengthOf() != third.lengthOf() || !this->isSameShape(second) || !this->isSameShape(third)) { nd4j_printf("applyPairwiseLambda requires both operands to have the same shape\n",""); throw std::runtime_error("Shapes mismach"); } auto f = this->bufferAsT(); - auto s = second->bufferAsT(); - auto t = third->bufferAsT(); - auto z = target->bufferAsT(); + auto s = second.bufferAsT(); + auto t = third.bufferAsT(); + auto z = target.bufferAsT(); - if (this->ordering() == second->ordering() && this->ordering() == third->ordering() && this->ordering() == target->ordering() && (this->ews() == 1 && target->ews() == 1) && this->ews() == second->ews() && this->ews() == third->ews()) { + if (this->ordering() == second.ordering() && this->ordering() == third.ordering() && this->ordering() == target.ordering() && (this->ews() == 1 && target.ews() == 1) && this->ews() == second.ews() && this->ews() == third.ews()) { auto loop = PRAGMA_THREADS_FOR { for (auto e = start; e < stop; e += increment) @@ -44,8 +33,8 @@ void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std:: auto loop = PRAGMA_THREADS_FOR { for (auto e = start; e < stop; e += increment) { auto tOffset = this->getOffset(e); - auto uOffset = second->getOffset(e); - auto vOffset = third->getOffset(e); + auto uOffset = second.getOffset(e); + auto vOffset = third.getOffset(e); f[tOffset] = func(f[tOffset], s[uOffset], t[vOffset]); } @@ -57,9 +46,9 @@ void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std:: auto loop = PRAGMA_THREADS_FOR { for (auto e = start; e < stop; e += increment) { auto tOffset = this->getOffset(e); - auto uOffset = second->getOffset(e); - auto vOffset = third->getOffset(e); - auto zOffset = target->getOffset(e); + auto uOffset = second.getOffset(e); + auto vOffset = third.getOffset(e); + auto zOffset = target.getOffset(e); z[zOffset] = func(f[tOffset], s[uOffset], t[vOffset]); } @@ -69,46 +58,39 @@ void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std:: } } } -template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function& func, NDArray* target); -template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function& func, NDArray* target); -template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function& func, NDArray* target); -template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function& func, NDArray* target); -template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function& func, NDArray* target); -template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function& func, NDArray* target); -template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function& func, NDArray* target); -template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function& func, NDArray* target); -template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function& func, NDArray* target); -template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function& func, NDArray* target); -template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function& func, NDArray* target); -template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function& func, NDArray* target); -template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function& func, NDArray* target); +template void NDArray::applyTriplewiseLambda(NDArray& second, NDArray &third, const std::function& func, NDArray& target); +template void NDArray::applyTriplewiseLambda(NDArray& second, NDArray &third, const std::function& func, NDArray& target); +template void NDArray::applyTriplewiseLambda(NDArray& second, NDArray &third, const std::function& func, NDArray& target); +template void NDArray::applyTriplewiseLambda(NDArray& second, NDArray &third, const std::function& func, NDArray& target); +template void NDArray::applyTriplewiseLambda(NDArray& second, NDArray &third, const std::function& func, NDArray& target); +template void NDArray::applyTriplewiseLambda(NDArray& second, NDArray &third, const std::function& func, NDArray& target); +template void NDArray::applyTriplewiseLambda(NDArray& second, NDArray &third, const std::function& func, NDArray& target); +template void NDArray::applyTriplewiseLambda(NDArray& second, NDArray &third, const std::function& func, NDArray& target); +template void NDArray::applyTriplewiseLambda(NDArray& second, NDArray &third, const std::function& func, NDArray& target); +template void NDArray::applyTriplewiseLambda(NDArray& second, NDArray &third, const std::function& func, NDArray& target); +template void NDArray::applyTriplewiseLambda(NDArray& second, NDArray &third, const std::function& func, NDArray& target); +template void NDArray::applyTriplewiseLambda(NDArray& second, NDArray &third, const std::function& func, NDArray& target); +template void NDArray::applyTriplewiseLambda(NDArray& second, NDArray &third, const std::function& func, NDArray& target); ////////////////////////////////////////////////////////////////////////// template -void NDArray::applyPairwiseLambda(const NDArray* other, const std::function& func, NDArray* target) { - if (target == nullptr) - target = this; - - if (other == nullptr) { - nd4j_printf("applyPairwiseLambda requires both operands to be valid NDArrays, but Y is NULL\n",""); - throw std::runtime_error("Other is null"); - } +void NDArray::applyPairwiseLambda(const NDArray& other, const std::function& func, NDArray& target) { if(dataType() != DataTypeUtils::fromT()) throw std::runtime_error("NDArray::applyPairwiseLambda method: wrong template parameter T, its type should be the same as type of this array!"); - if(dataType() != other->dataType() || dataType() != target->dataType()) + if(dataType() != other.dataType() || dataType() != target.dataType()) throw std::runtime_error("NDArray::applyPairwiseLambda method: all three arrays (this, other, target) must have the same type !"); - if (this->lengthOf() != other->lengthOf()) { + if (this->lengthOf() != other.lengthOf()) { nd4j_printf("applyPairwiseLambda requires both operands to have the same shape\n",""); throw std::runtime_error("Shapes mismach"); } auto f = this->bufferAsT(); - auto s = other->bufferAsT(); - auto z = target->bufferAsT(); + auto s = other.bufferAsT(); + auto z = target.bufferAsT(); - if (this->ordering() == other->ordering() && this->ordering() == target->ordering() && (this->ews() == 1 && target->ews() == 1) && this->ews() == other->ews()) { + if (this->ordering() == other.ordering() && this->ordering() == target.ordering() && (this->ews() == 1 && target.ews() == 1) && this->ews() == other.ews()) { auto loop = PRAGMA_THREADS_FOR { for (auto e = start; e < stop; e += increment) @@ -122,7 +104,7 @@ void NDArray::applyPairwiseLambda(const NDArray* other, const std::functiongetOffset(e); - auto yOffset = other->getOffset(e); + auto yOffset = other.getOffset(e); f[xOffset] = func(f[xOffset], s[yOffset]); } @@ -134,8 +116,8 @@ void NDArray::applyPairwiseLambda(const NDArray* other, const std::functiongetOffset(e); - auto yOffset = other->getOffset(e); - auto zOffset = target->getOffset(e); + auto yOffset = other.getOffset(e); + auto zOffset = target.getOffset(e); z[zOffset] = func(f[xOffset], s[yOffset]); } @@ -145,35 +127,33 @@ void NDArray::applyPairwiseLambda(const NDArray* other, const std::function& func, NDArray* target); -template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function& func, NDArray* target); -template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function& func, NDArray* target); -template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function& func, NDArray* target); -template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function& func, NDArray* target); -template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function& func, NDArray* target); -template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function& func, NDArray* target); -template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function& func, NDArray* target); -template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function& func, NDArray* target); -template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function& func, NDArray* target); -template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function& func, NDArray* target); -template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function& func, NDArray* target); -template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function& func, NDArray* target); +template void NDArray::applyPairwiseLambda(const NDArray& other, const std::function& func, NDArray& target); +template void NDArray::applyPairwiseLambda(const NDArray& other, const std::function& func, NDArray& target); +template void NDArray::applyPairwiseLambda(const NDArray& other, const std::function& func, NDArray& target); +template void NDArray::applyPairwiseLambda(const NDArray& other, const std::function& func, NDArray& target); +template void NDArray::applyPairwiseLambda(const NDArray& other, const std::function& func, NDArray& target); +template void NDArray::applyPairwiseLambda(const NDArray& other, const std::function& func, NDArray& target); +template void NDArray::applyPairwiseLambda(const NDArray& other, const std::function& func, NDArray& target); +template void NDArray::applyPairwiseLambda(const NDArray& other, const std::function& func, NDArray& target); +template void NDArray::applyPairwiseLambda(const NDArray& other, const std::function& func, NDArray& target); +template void NDArray::applyPairwiseLambda(const NDArray& other, const std::function& func, NDArray& target); +template void NDArray::applyPairwiseLambda(const NDArray& other, const std::function& func, NDArray& target); +template void NDArray::applyPairwiseLambda(const NDArray& other, const std::function& func, NDArray& target); +template void NDArray::applyPairwiseLambda(const NDArray& other, const std::function& func, NDArray& target); ////////////////////////////////////////////////////////////////////////// template -void NDArray::applyLambda(const std::function& func, NDArray* target) { - if (target == nullptr) - target = this; +void NDArray::applyLambda(const std::function& func, NDArray& target) { if(dataType() != DataTypeUtils::fromT()) throw std::runtime_error("NDArray::applyLambda method: wrong template parameter T, its type should be the same as type of this array!"); - if(dataType() != target->dataType()) + if(dataType() != target.dataType()) throw std::runtime_error("NDArray::applyLambda method: types of this and target array should match !"); auto f = this->bufferAsT(); - auto z = target->bufferAsT(); + auto z = target.bufferAsT(); - if (this->ordering() == target->ordering() && (this->ews() == 1 && target->ews() == 1)) { + if (this->ordering() == target.ordering() && (this->ews() == 1 && target.ews() == 1)) { auto loop = PRAGMA_THREADS_FOR { for (auto e = start; e < stop; e += increment) @@ -198,7 +178,7 @@ void NDArray::applyLambda(const std::function& func, NDArray* target) { auto loop = PRAGMA_THREADS_FOR { for (auto e = start; e < stop; e += increment) { auto xOffset = this->getOffset(e); - auto zOffset = target->getOffset(e); + auto zOffset = target.getOffset(e); z[zOffset] = func(f[xOffset]); } @@ -208,35 +188,33 @@ void NDArray::applyLambda(const std::function& func, NDArray* target) { } } } -template void NDArray::applyLambda(const std::function& func, NDArray* target); -template void NDArray::applyLambda(const std::function& func, NDArray* target); -template void NDArray::applyLambda(const std::function& func, NDArray* target); -template void NDArray::applyLambda(const std::function& func, NDArray* target); -template void NDArray::applyLambda(const std::function& func, NDArray* target); -template void NDArray::applyLambda(const std::function& func, NDArray* target); -template void NDArray::applyLambda(const std::function& func, NDArray* target); -template void NDArray::applyLambda(const std::function& func, NDArray* target); -template void NDArray::applyLambda(const std::function& func, NDArray* target); -template void NDArray::applyLambda(const std::function& func, NDArray* target); -template void NDArray::applyLambda(const std::function& func, NDArray* target); -template void NDArray::applyLambda(const std::function& func, NDArray* target); -template void NDArray::applyLambda(const std::function& func, NDArray* target); +template void NDArray::applyLambda(const std::function& func, NDArray& target); +template void NDArray::applyLambda(const std::function& func, NDArray& target); +template void NDArray::applyLambda(const std::function& func, NDArray& target); +template void NDArray::applyLambda(const std::function& func, NDArray& target); +template void NDArray::applyLambda(const std::function& func, NDArray& target); +template void NDArray::applyLambda(const std::function& func, NDArray& target); +template void NDArray::applyLambda(const std::function& func, NDArray& target); +template void NDArray::applyLambda(const std::function& func, NDArray& target); +template void NDArray::applyLambda(const std::function& func, NDArray& target); +template void NDArray::applyLambda(const std::function& func, NDArray& target); +template void NDArray::applyLambda(const std::function& func, NDArray& target); +template void NDArray::applyLambda(const std::function& func, NDArray& target); +template void NDArray::applyLambda(const std::function& func, NDArray& target); ////////////////////////////////////////////////////////////////////////// template -void NDArray::applyIndexedLambda(const std::function& func, NDArray* target) { - if (target == nullptr) - target = this; +void NDArray::applyIndexedLambda(const std::function& func, NDArray& target) { if(dataType() != DataTypeUtils::fromT()) throw std::runtime_error("NDArray::applyIndexedLambda method: wrong template parameter T, its type should be the same as type of this array!"); - if(dataType() != target->dataType()) + if(dataType() != target.dataType()) throw std::runtime_error("NDArray::applyIndexedLambda method: types of this and target array should match !"); auto f = this->bufferAsT(); - auto z = target->bufferAsT(); + auto z = target.bufferAsT(); - if (this->ordering() == target->ordering() && (this->ews() == 1 && target->ews() == 1)) { + if (this->ordering() == target.ordering() && (this->ews() == 1 && target.ews() == 1)) { auto loop = PRAGMA_THREADS_FOR { for (auto e = start; e < stop; e += increment) @@ -261,7 +239,7 @@ void NDArray::applyIndexedLambda(const std::function& func, NDAr auto loop = PRAGMA_THREADS_FOR { for (auto e = start; e < stop; e += increment) { auto xOffset = this->getOffset(e); - auto zOffset = target->getOffset(e); + auto zOffset = target.getOffset(e); z[zOffset] = func(e, f[xOffset]); } @@ -271,44 +249,38 @@ void NDArray::applyIndexedLambda(const std::function& func, NDAr } } } -template void NDArray::applyIndexedLambda(const std::function& func, NDArray* target); -template void NDArray::applyIndexedLambda(const std::function& func, NDArray* target); -template void NDArray::applyIndexedLambda(const std::function& func, NDArray* target); -template void NDArray::applyIndexedLambda(const std::function& func, NDArray* target); -template void NDArray::applyIndexedLambda(const std::function& func, NDArray* target); -template void NDArray::applyIndexedLambda(const std::function& func, NDArray* target); -template void NDArray::applyIndexedLambda(const std::function& func, NDArray* target); -template void NDArray::applyIndexedLambda(const std::function& func, NDArray* target); -template void NDArray::applyIndexedLambda(const std::function& func, NDArray* target); -template void NDArray::applyIndexedLambda(const std::function& func, NDArray* target); -template void NDArray::applyIndexedLambda(const std::function& func, NDArray* target); -template void NDArray::applyIndexedLambda(const std::function& func, NDArray* target); -template void NDArray::applyIndexedLambda(const std::function& func, NDArray* target); +template void NDArray::applyIndexedLambda(const std::function& func, NDArray& target); +template void NDArray::applyIndexedLambda(const std::function& func, NDArray& target); +template void NDArray::applyIndexedLambda(const std::function& func, NDArray& target); +template void NDArray::applyIndexedLambda(const std::function& func, NDArray& target); +template void NDArray::applyIndexedLambda(const std::function& func, NDArray& target); +template void NDArray::applyIndexedLambda(const std::function& func, NDArray& target); +template void NDArray::applyIndexedLambda(const std::function& func, NDArray& target); +template void NDArray::applyIndexedLambda(const std::function& func, NDArray& target); +template void NDArray::applyIndexedLambda(const std::function& func, NDArray& target); +template void NDArray::applyIndexedLambda(const std::function& func, NDArray& target); +template void NDArray::applyIndexedLambda(const std::function& func, NDArray& target); +template void NDArray::applyIndexedLambda(const std::function& func, NDArray& target); +template void NDArray::applyIndexedLambda(const std::function& func, NDArray& target); ////////////////////////////////////////////////////////////////////////// template -void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function& func, NDArray* target) { - if (target == nullptr) - target = this; +void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function& func, NDArray& target) { - if (other == nullptr) { - nd4j_printf("applyIndexedPairwiseLambda requires both operands to be valid NDArrays, but Y is NULL\n",""); - throw std::runtime_error("Other is null"); - } if(dataType() != DataTypeUtils::fromT()) throw std::runtime_error("NDArray::applyIndexedPairwiseLambda method: wrong template parameter T, its type should be the same as type of this array!"); - if(dataType() != target->dataType()) + if(dataType() != target.dataType()) throw std::runtime_error("NDArray::applyIndexedPairwiseLambda method: types of this and target array should match !"); - if (this->lengthOf() != other->lengthOf()) { + if (this->lengthOf() != other.lengthOf()) { nd4j_printf("applyIndexedPairwiseLambda requires both operands to have the same shape\n",""); throw std::runtime_error("Shapes mismach"); } auto f = this->bufferAsT(); - auto s = other->bufferAsT(); - auto z = target->bufferAsT(); + auto s = other.bufferAsT(); + auto z = target.bufferAsT(); - if (this->ordering() == other->ordering() && this->ordering() == target->ordering() && (this->ews() == 1 && target->ews() == 1) && this->ews() == other->ews()) { + if (this->ordering() == other.ordering() && this->ordering() == target.ordering() && (this->ews() == 1 && target.ews() == 1) && this->ews() == other.ews()) { auto loop = PRAGMA_THREADS_FOR { for (auto e = start; e < stop; e += increment) @@ -322,7 +294,7 @@ void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::functiongetOffset(e); - auto yOffset = other->getOffset(e); + auto yOffset = other.getOffset(e); f[xOffset] = func((Nd4jLong) e, f[xOffset], s[yOffset]); } @@ -334,8 +306,8 @@ void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::functiongetOffset(e); - auto yOffset = other->getOffset(e); - auto zOffset = target->getOffset(e); + auto yOffset = other.getOffset(e); + auto zOffset = target.getOffset(e); z[zOffset] = func((Nd4jLong) e, f[xOffset], s[yOffset]); } @@ -345,16 +317,16 @@ void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function& func, NDArray* target); -template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function& func, NDArray* target); -template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function& func, NDArray* target); -template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function& func, NDArray* target); -template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function& func, NDArray* target); -template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function& func, NDArray* target); -template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function& func, NDArray* target); -template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function& func, NDArray* target); -template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function& func, NDArray* target); -template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function& func, NDArray* target); -template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function& func, NDArray* target); -template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function& func, NDArray* target); -template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function& func, NDArray* target); \ No newline at end of file +template void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function& func, NDArray& target); +template void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function& func, NDArray& target); +template void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function& func, NDArray& target); +template void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function& func, NDArray& target); +template void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function& func, NDArray& target); +template void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function& func, NDArray& target); +template void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function& func, NDArray& target); +template void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function& func, NDArray& target); +template void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function& func, NDArray& target); +template void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function& func, NDArray& target); +template void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function& func, NDArray& target); +template void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function& func, NDArray& target); +template void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function& func, NDArray& target); \ No newline at end of file diff --git a/libnd4j/blas/cpu/NativeOpExecutioner.cpp b/libnd4j/blas/cpu/NativeOpExecutioner.cpp index 75a68c984..c155bd781 100644 --- a/libnd4j/blas/cpu/NativeOpExecutioner.cpp +++ b/libnd4j/blas/cpu/NativeOpExecutioner.cpp @@ -398,7 +398,7 @@ void NativeOpExecutioner::execPairwiseTransform(nd4j::LaunchContext *lc, }; auto zLen = shape::length(hZShapeInfo); - samediff::Threads::parallel_for(func, 0, zLen, 1, nd4j::math::nd4j_max(1, nd4j::math::nd4j_min(zLen / 1024, nd4j::Environment::getInstance()->maxThreads()))); + samediff::Threads::parallel_for(func, 0, zLen, 1, nd4j::math::nd4j_max(1, nd4j::math::nd4j_min(zLen / 1024, nd4j::Environment::getInstance()->maxMasterThreads()))); #endif } @@ -433,7 +433,7 @@ void NativeOpExecutioner::execPairwiseBoolTransform(nd4j::LaunchContext *lc, }; auto zLen = shape::length(hZShapeInfo); - samediff::Threads::parallel_for(func, 0, zLen, 1, nd4j::math::nd4j_max(1, nd4j::math::nd4j_min(zLen / 1024, nd4j::Environment::getInstance()->maxThreads()))); + samediff::Threads::parallel_for(func, 0, zLen, 1, nd4j::math::nd4j_max(1, nd4j::math::nd4j_min(zLen / 1024, nd4j::Environment::getInstance()->maxMasterThreads()))); } @@ -466,7 +466,7 @@ void NativeOpExecutioner::execPairwiseIntTransform(nd4j::LaunchContext *lc, }; auto zLen = shape::length(hZShapeInfo); - samediff::Threads::parallel_for(func, 0, zLen, 1, nd4j::math::nd4j_max(1, nd4j::math::nd4j_min(zLen / 1024, nd4j::Environment::getInstance()->maxThreads()))); + samediff::Threads::parallel_for(func, 0, zLen, 1, nd4j::math::nd4j_max(1, nd4j::math::nd4j_min(zLen / 1024, nd4j::Environment::getInstance()->maxMasterThreads()))); } @@ -505,7 +505,7 @@ void NativeOpExecutioner::execReduceFloat(nd4j::LaunchContext *lc, const nd4j::LoopKind::Kind kindOfLoop = nd4j::LoopKind::deduceKindOfLoopTadXZ(hXShapeInfo, hZShapeInfo, tadShapeInfo); - samediff::Threads::parallel_tad(func, 0, shape::length(hZShapeInfo), 1, kindOfLoop == nd4j::LoopKind::Kind::SMALLARR2DX ? 1 : nd4j::Environment::getInstance()->maxThreads()); + samediff::Threads::parallel_tad(func, 0, shape::length(hZShapeInfo), 1, kindOfLoop == nd4j::LoopKind::Kind::SMALLARR2DX ? 1 : nd4j::Environment::getInstance()->maxMasterThreads()); } //////////////////////////////////////////////////////////////////////// @@ -533,7 +533,7 @@ void NativeOpExecutioner::execReduceSame(nd4j::LaunchContext *lc, const nd4j::LoopKind::Kind kindOfLoop = nd4j::LoopKind::deduceKindOfLoopTadXZ(hXShapeInfo, hZShapeInfo, tadShapeInfo); - samediff::Threads::parallel_tad(func, 0, shape::length(hZShapeInfo), 1, kindOfLoop == nd4j::LoopKind::Kind::SMALLARR2DX ? 1 : nd4j::Environment::getInstance()->maxThreads()); + samediff::Threads::parallel_tad(func, 0, shape::length(hZShapeInfo), 1, kindOfLoop == nd4j::LoopKind::Kind::SMALLARR2DX ? 1 : nd4j::Environment::getInstance()->maxMasterThreads()); } //////////////////////////////////////////////////////////////////////// @@ -561,7 +561,7 @@ void NativeOpExecutioner::execReduceBool(nd4j::LaunchContext *lc, const nd4j::LoopKind::Kind kindOfLoop = nd4j::LoopKind::deduceKindOfLoopTadXZ(hXShapeInfo, hZShapeInfo, tadShapeInfo); - samediff::Threads::parallel_tad(func, 0, shape::length(hZShapeInfo), 1, kindOfLoop == nd4j::LoopKind::Kind::SMALLARR2DX ? 1 : nd4j::Environment::getInstance()->maxThreads()); + samediff::Threads::parallel_tad(func, 0, shape::length(hZShapeInfo), 1, kindOfLoop == nd4j::LoopKind::Kind::SMALLARR2DX ? 1 : nd4j::Environment::getInstance()->maxMasterThreads()); } //////////////////////////////////////////////////////////////////////// @@ -589,7 +589,7 @@ void NativeOpExecutioner::execReduceLong(nd4j::LaunchContext *lc, const nd4j::LoopKind::Kind kindOfLoop = nd4j::LoopKind::deduceKindOfLoopTadXZ(hXShapeInfo, hZShapeInfo, tadShapeInfo); - samediff::Threads::parallel_tad(func, 0, shape::length(hZShapeInfo), 1, kindOfLoop == nd4j::LoopKind::Kind::SMALLARR2DX ? 1 : nd4j::Environment::getInstance()->maxThreads()); + samediff::Threads::parallel_tad(func, 0, shape::length(hZShapeInfo), 1, kindOfLoop == nd4j::LoopKind::Kind::SMALLARR2DX ? 1 : nd4j::Environment::getInstance()->maxMasterThreads()); } //////////////////////////////////////////////////////////////////////// @@ -876,7 +876,7 @@ void NativeOpExecutioner::execScalar(nd4j::LaunchContext *lc, }; auto zLen = shape::length(hZShapeInfo); - samediff::Threads::parallel_for(func, 0, zLen, 1, !allowParallelism ? 1 : nd4j::math::nd4j_max(1, nd4j::math::nd4j_min(zLen / 1024, nd4j::Environment::getInstance()->maxThreads()))); + samediff::Threads::parallel_for(func, 0, zLen, 1, !allowParallelism ? 1 : nd4j::math::nd4j_max(1, nd4j::math::nd4j_min(zLen / 1024, nd4j::Environment::getInstance()->maxMasterThreads()))); #endif } @@ -913,7 +913,7 @@ void NativeOpExecutioner::execScalar(nd4j::LaunchContext *lc, }; auto yLen = shape::length(hScalarShapeInfo); - samediff::Threads::parallel_tad(func, 0, yLen, 1, nd4j::math::nd4j_min(yLen, nd4j::Environment::getInstance()->maxThreads())); + samediff::Threads::parallel_tad(func, 0, yLen, 1, nd4j::math::nd4j_min(yLen, nd4j::Environment::getInstance()->maxMasterThreads())); #endif } @@ -947,7 +947,7 @@ void NativeOpExecutioner::execScalarBool(nd4j::LaunchContext *lc, }; auto zLen = shape::length(hZShapeInfo); - samediff::Threads::parallel_for(func, 0, zLen, 1, !allowParallelism ? 1 : nd4j::math::nd4j_max(1, nd4j::math::nd4j_min(zLen / 1024, nd4j::Environment::getInstance()->maxThreads()))); + samediff::Threads::parallel_for(func, 0, zLen, 1, !allowParallelism ? 1 : nd4j::math::nd4j_max(1, nd4j::math::nd4j_min(zLen / 1024, nd4j::Environment::getInstance()->maxMasterThreads()))); } @@ -983,7 +983,7 @@ void NativeOpExecutioner::execScalarBool(nd4j::LaunchContext *lc, }; auto yLen = shape::length(hScalarShapeInfo); - samediff::Threads::parallel_tad(func, 0, yLen, 1, nd4j::math::nd4j_min(yLen, nd4j::Environment::getInstance()->maxThreads())); + samediff::Threads::parallel_tad(func, 0, yLen, 1, nd4j::math::nd4j_min(yLen, nd4j::Environment::getInstance()->maxMasterThreads())); } //////////////////////////////////////////////////////////////////////// @@ -1015,7 +1015,7 @@ void NativeOpExecutioner::execScalarInt(nd4j::LaunchContext *lc, }; auto zLen = shape::length(hZShapeInfo); - samediff::Threads::parallel_for(func, 0, zLen, 1, !allowParallelism ? 1 : nd4j::math::nd4j_max(1, nd4j::math::nd4j_min(zLen / 1024, nd4j::Environment::getInstance()->maxThreads()))); + samediff::Threads::parallel_for(func, 0, zLen, 1, !allowParallelism ? 1 : nd4j::math::nd4j_max(1, nd4j::math::nd4j_min(zLen / 1024, nd4j::Environment::getInstance()->maxMasterThreads()))); } @@ -1051,7 +1051,7 @@ void NativeOpExecutioner::execScalarInt(nd4j::LaunchContext *lc, }; auto yLen = shape::length(hScalarShapeInfo); - samediff::Threads::parallel_tad(func, 0, yLen, 1, nd4j::math::nd4j_min(yLen, nd4j::Environment::getInstance()->maxThreads())); + samediff::Threads::parallel_tad(func, 0, yLen, 1, nd4j::math::nd4j_min(yLen, nd4j::Environment::getInstance()->maxMasterThreads())); } //////////////////////////////////////////////////////////////////////// @@ -1164,7 +1164,7 @@ void NativeOpExecutioner::execTransformFloat(nd4j::LaunchContext *lc, BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformFloat, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), LIBND4J_TYPES, FLOAT_TYPES); }; - samediff::Threads::parallel_do(func, nd4j::math::nd4j_max(1, nd4j::math::nd4j_min(shape::length(hZShapeInfo) / 1024, nd4j::Environment::getInstance()->maxThreads()))); + samediff::Threads::parallel_do(func, nd4j::math::nd4j_max(1, nd4j::math::nd4j_min(shape::length(hZShapeInfo) / 1024, nd4j::Environment::getInstance()->maxMasterThreads()))); } //////////////////////////////////////////////////////////////////////// @@ -1186,7 +1186,7 @@ void NativeOpExecutioner::execTransformBool(nd4j::LaunchContext *lc, BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformBool, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), LIBND4J_TYPES, BOOL_TYPES); }; - samediff::Threads::parallel_do(func, nd4j::math::nd4j_max(1, nd4j::math::nd4j_min(shape::length(hZShapeInfo) / 1024, nd4j::Environment::getInstance()->maxThreads()))); + samediff::Threads::parallel_do(func, nd4j::math::nd4j_max(1, nd4j::math::nd4j_min(shape::length(hZShapeInfo) / 1024, nd4j::Environment::getInstance()->maxMasterThreads()))); } //////////////////////////////////////////////////////////////////////// @@ -1208,7 +1208,7 @@ void NativeOpExecutioner::execTransformAny(nd4j::LaunchContext *lc, BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformAny, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), LIBND4J_TYPES, LIBND4J_TYPES); }; - samediff::Threads::parallel_do(func, nd4j::math::nd4j_max(1, nd4j::math::nd4j_min(shape::length(hZShapeInfo) / 1024, nd4j::Environment::getInstance()->maxThreads()))); + samediff::Threads::parallel_do(func, nd4j::math::nd4j_max(1, nd4j::math::nd4j_min(shape::length(hZShapeInfo) / 1024, nd4j::Environment::getInstance()->maxMasterThreads()))); } //////////////////////////////////////////////////////////////////////// @@ -1230,7 +1230,7 @@ void NativeOpExecutioner::execTransformSame(nd4j::LaunchContext *lc, BUILD_SINGLE_SELECTOR(xType, functions::transform::TransformSame, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), LIBND4J_TYPES); }; - samediff::Threads::parallel_do(func, nd4j::math::nd4j_max(1, nd4j::math::nd4j_min(shape::length(hZShapeInfo) / 1024, nd4j::Environment::getInstance()->maxThreads()))); + samediff::Threads::parallel_do(func, nd4j::math::nd4j_max(1, nd4j::math::nd4j_min(shape::length(hZShapeInfo) / 1024, nd4j::Environment::getInstance()->maxMasterThreads()))); } //////////////////////////////////////////////////////////////////////// @@ -1252,7 +1252,7 @@ void NativeOpExecutioner::execTransformStrict(nd4j::LaunchContext *lc, BUILD_SINGLE_SELECTOR(xType, functions::transform::TransformStrict, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), FLOAT_TYPES); }; - samediff::Threads::parallel_do(func, nd4j::math::nd4j_max(1, nd4j::math::nd4j_min(shape::length(hZShapeInfo) / 1024, nd4j::Environment::getInstance()->maxThreads()))); + samediff::Threads::parallel_do(func, nd4j::math::nd4j_max(1, nd4j::math::nd4j_min(shape::length(hZShapeInfo) / 1024, nd4j::Environment::getInstance()->maxMasterThreads()))); } //////////////////////////////////////////////////////////////////////// diff --git a/libnd4j/blas/cpu/NativeOps.cpp b/libnd4j/blas/cpu/NativeOps.cpp index e790c05d0..1b1d22fbf 100644 --- a/libnd4j/blas/cpu/NativeOps.cpp +++ b/libnd4j/blas/cpu/NativeOps.cpp @@ -102,13 +102,11 @@ void setTADThreshold(int num) { */ void execIndexReduceScalar(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo) { + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo) { try { - NativeOpExecutioner::execIndexReduceScalar(nullptr, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hZ, hZShapeInfo, dZ, dZShapeInfo); + NativeOpExecutioner::execIndexReduceScalar(nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), dXShapeInfo, extraParams, dbZ->primary(), hZShapeInfo, dbZ->special(), dZShapeInfo); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); @@ -127,15 +125,12 @@ void execIndexReduceScalar(Nd4jPointer *extraPointers, * @param dimensionLength */ void execIndexReduce(Nd4jPointer *extraPointers,int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, - void *hDimension, Nd4jLong *hDimensionShape, - void *dDimension, Nd4jLong *dDimensionShape) { + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape) { try { - auto dimension = reinterpret_cast(hDimension); + auto dimension = reinterpret_cast(dbDimension->primary()); int dimensionLength = static_cast(shape::length(hDimensionShape)); auto tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, dimension, @@ -144,17 +139,17 @@ void execIndexReduce(Nd4jPointer *extraPointers,int opNum, auto hTADShapeInfo = tadPack.primaryShapeInfo(); auto hTADOffsets = tadPack.primaryOffsets(); - auto hz = reinterpret_cast(hZ); + auto hz = reinterpret_cast(dbZ->primary()); NativeOpExecutioner::execIndexReduce(nullptr, opNum, - hX, + dbX->primary(), hXShapeInfo, - dX, + dbX->special(), dXShapeInfo, extraParams, hz, hZShapeInfo, - dZ, + dbZ->special(), dZShapeInfo, dimension, dimensionLength, @@ -181,16 +176,12 @@ void execIndexReduce(Nd4jPointer *extraPointers,int opNum, */ void execBroadcast(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *hY, Nd4jLong *hYShapeInfo, - void *dY, Nd4jLong *dYShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, - void *hDimension, Nd4jLong *hDimensionShape, - void *dDimension, Nd4jLong *dDimensionShape) { + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape) { try { - auto dimension = reinterpret_cast(hDimension); + auto dimension = reinterpret_cast(dbDimension->primary()); int dimensionLength = static_cast(shape::length(hDimensionShape)); auto tadPackX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, dimension, @@ -205,16 +196,16 @@ void execBroadcast(Nd4jPointer *extraPointers, NativeOpExecutioner::execBroadcast(nullptr, opNum, - hX, + dbX->primary(), hXShapeInfo, - dX, + dbX->special(), dXShapeInfo, - hY, + dbY->primary(), hYShapeInfo, - dY, + dbY->special(), dYShapeInfo, - hZ, hZShapeInfo, - dZ, dZShapeInfo, + dbZ->primary(), hZShapeInfo, + dbZ->special(), dZShapeInfo, dimension, dimensionLength, hTADShapeInfo, hTADOffsets, hTADShapeInfoZ, hTADOffsetsZ); } catch (std::exception &e) { @@ -225,17 +216,13 @@ void execBroadcast(Nd4jPointer *extraPointers, void execBroadcastBool(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *hY, Nd4jLong *hYShapeInfo, - void *dY, Nd4jLong *dYShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, void *extraParams, - void *hDimension, Nd4jLong *hDimensionShape, - void *dDimension, Nd4jLong *dDimensionShape) { + OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape) { try { - auto dimension = reinterpret_cast(hDimension); + auto dimension = reinterpret_cast(dbDimension->primary()); int dimensionLength = static_cast(shape::length(hDimensionShape)); auto tadPackX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, dimension, @@ -250,16 +237,16 @@ void execBroadcastBool(Nd4jPointer *extraPointers, NativeOpExecutioner::execBroadcastBool(nullptr, opNum, - hX, + dbX->primary(), hXShapeInfo, - dX, + dbX->special(), dXShapeInfo, - hY, + dbY->primary(), hYShapeInfo, - dY, + dbY->special(), dYShapeInfo, - hZ, hZShapeInfo, - dZ, dZShapeInfo, + dbZ->primary(), hZShapeInfo, + dbZ->special(), dZShapeInfo, extraParams, dimension, dimensionLength, hTADShapeInfo, hTADOffsets, hTADShapeInfoZ, @@ -285,27 +272,24 @@ void execBroadcastBool(Nd4jPointer *extraPointers, void execPairwiseTransform( Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *hY, Nd4jLong *hYShapeInfo, - void *dY, Nd4jLong *dYShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, void *extraParams) { try { NativeOpExecutioner::execPairwiseTransform(nullptr, opNum, - hX, + dbX->primary(), hXShapeInfo, - dX, + dbX->special(), dXShapeInfo, - hY, + dbY->primary(), hYShapeInfo, - dY, + dbY->special(), dYShapeInfo, - hZ, + dbZ->primary(), hZShapeInfo, - dZ, + dbZ->special(), dZShapeInfo, extraParams); } catch (std::exception &e) { @@ -317,28 +301,25 @@ void execPairwiseTransform( void execPairwiseTransformBool( Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *hY, Nd4jLong *hYShapeInfo, - void *dY, Nd4jLong *dYShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, void *extraParams) { try { NativeOpExecutioner::execPairwiseBoolTransform(nullptr, opNum, - hX, + dbX->primary(), hXShapeInfo, - dX, + dbX->special(), dXShapeInfo, - hY, + dbY->primary(), hYShapeInfo, - dY, + dbY->special(), dYShapeInfo, - hZ, + dbZ->primary(), hZShapeInfo, - dZ, + dbZ->special(), dZShapeInfo, extraParams); } catch (std::exception &e) { @@ -359,23 +340,21 @@ void execPairwiseTransformBool( void execReduceFloat( Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo) { + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo) { try { NativeOpExecutioner::execReduceFloatScalar(nullptr, opNum, - hX, + dbX->primary(), hXShapeInfo, - dX, + dbX->special(), dXShapeInfo, extraParams, - hZ, + dbZ->primary(), hZShapeInfo, - dZ, + dbZ->special(), dZShapeInfo); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); @@ -386,23 +365,21 @@ void execReduceFloat( void execReduceSame( Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo) { + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo) { try { NativeOpExecutioner::execReduceSameScalar(nullptr, opNum, - hX, + dbX->primary(), hXShapeInfo, - dX, + dbX->special(), dXShapeInfo, extraParams, - hZ, + dbZ->primary(), hZShapeInfo, - dZ, + dbZ->special(), dZShapeInfo); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); @@ -413,22 +390,20 @@ void execReduceSame( void execReduceBool( Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo) { + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo) { try { NativeOpExecutioner::execReduceBoolScalar(nullptr, opNum, - hX, + dbX->primary(), hXShapeInfo, - dX, + dbX->special(), dXShapeInfo, extraParams, - hZ, + dbZ->primary(), hZShapeInfo, - dZ, + dbZ->special(), dZShapeInfo); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); @@ -439,22 +414,20 @@ void execReduceBool( void execReduceLong( Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo) { + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo) { try { NativeOpExecutioner::execReduceLongScalar(nullptr, opNum, - hX, + dbX->primary(), hXShapeInfo, - dX, + dbX->special(), dXShapeInfo, extraParams, - hZ, + dbZ->primary(), hZShapeInfo, - dZ, + dbZ->special(), dZShapeInfo); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); @@ -473,15 +446,12 @@ void execReduceLong( */ void execReduceFloat2(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, - void *hDimension, Nd4jLong *hDimensionShape, - void *dDimension, Nd4jLong *dDimensionShape) { + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape) { try { - auto dimension = reinterpret_cast(hDimension); + auto dimension = reinterpret_cast(dbDimension->primary()); int dimensionLength = static_cast(shape::length(hDimensionShape)); auto tadPackX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, dimension, @@ -491,14 +461,14 @@ void execReduceFloat2(Nd4jPointer *extraPointers, auto hTADOffsets = tadPackX.primaryOffsets(); NativeOpExecutioner::execReduceFloat(nullptr, opNum, - hX, + dbX->primary(), hXShapeInfo, - dX, + dbX->special(), dXShapeInfo, extraParams, - hZ, + dbZ->primary(), hZShapeInfo, - dZ, + dbZ->special(), dZShapeInfo, dimension, dimensionLength, @@ -512,15 +482,12 @@ void execReduceFloat2(Nd4jPointer *extraPointers, void execReduceBool2(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, - void *hDimension, Nd4jLong *hDimensionShape, - void *dDimension, Nd4jLong *dDimensionShape) { + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape) { try { - auto dimension = reinterpret_cast(hDimension); + auto dimension = reinterpret_cast(dbDimension->primary()); int dimensionLength = static_cast(shape::length(hDimensionShape)); auto tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, dimension, @@ -530,14 +497,14 @@ void execReduceBool2(Nd4jPointer *extraPointers, auto hTADOffsets = tadPack.primaryOffsets(); NativeOpExecutioner::execReduceBool(nullptr, opNum, - hX, + dbX->primary(), hXShapeInfo, - dX, + dbX->special(), dXShapeInfo, extraParams, - hZ, + dbZ->primary(), hZShapeInfo, - dZ, + dbZ->special(), dZShapeInfo, dimension, dimensionLength, @@ -551,15 +518,12 @@ void execReduceBool2(Nd4jPointer *extraPointers, void execReduceSame2(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, - void *hDimension, Nd4jLong *hDimensionShape, - void *dDimension, Nd4jLong *dDimensionShape) { + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape) { try { - auto dimension = reinterpret_cast(hDimension); + auto dimension = reinterpret_cast(dbDimension->primary()); int dimensionLength = static_cast(shape::length(hDimensionShape)); auto tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, dimension, @@ -569,14 +533,14 @@ void execReduceSame2(Nd4jPointer *extraPointers, auto hTADOffsets = tadPack.primaryOffsets(); NativeOpExecutioner::execReduceSame(nullptr, opNum, - hX, + dbX->primary(), hXShapeInfo, - dX, + dbX->special(), dXShapeInfo, extraParams, - hZ, + dbZ->primary(), hZShapeInfo, - dZ, + dbZ->special(), dZShapeInfo, dimension, dimensionLength, @@ -590,15 +554,12 @@ void execReduceSame2(Nd4jPointer *extraPointers, void execReduceLong2(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, - void *hDimension, Nd4jLong *hDimensionShape, - void *dDimension, Nd4jLong *dDimensionShape) { + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape) { try { - auto dimension = reinterpret_cast(hDimension); + auto dimension = reinterpret_cast(dbDimension->primary()); int dimensionLength = static_cast(shape::length(hDimensionShape)); auto tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, dimension, @@ -608,14 +569,14 @@ void execReduceLong2(Nd4jPointer *extraPointers, auto hTADOffsets = tadPack.primaryOffsets(); NativeOpExecutioner::execReduceLong(nullptr, opNum, - hX, + dbX->primary(), hXShapeInfo, - dX, + dbX->special(), dXShapeInfo, extraParams, - hZ, + dbZ->primary(), hZShapeInfo, - dZ, + dbZ->special(), dZShapeInfo, dimension, dimensionLength, @@ -640,16 +601,13 @@ void execReduceLong2(Nd4jPointer *extraPointers, */ void execReduce3(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, - void *hY, Nd4jLong *hYShapeInfo, - void *dY, Nd4jLong *dYShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo) { + OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo) { try { - NativeOpExecutioner::execReduce3(nullptr, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hY, hYShapeInfo, - dY, dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo); + NativeOpExecutioner::execReduce3(nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), dXShapeInfo, extraParams, dbY->primary(), hYShapeInfo, + dbY->special(), dYShapeInfo, dbZ->primary(), hZShapeInfo, dbZ->special(), dZShapeInfo); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); @@ -666,16 +624,13 @@ void execReduce3(Nd4jPointer *extraPointers, * @param hYShapeInfo */ void execReduce3Scalar(Nd4jPointer *extraPointers,int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, - void *hY, Nd4jLong *hYShapeInfo, - void *dY, Nd4jLong *dYShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo) { + OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo) { try { - NativeOpExecutioner::execReduce3Scalar(nullptr, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hY, - hYShapeInfo, dY, dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo); + NativeOpExecutioner::execReduce3Scalar(nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), dXShapeInfo, extraParams, dbY->primary(), + hYShapeInfo, dbY->special(), dYShapeInfo, dbZ->primary(), hZShapeInfo, dbZ->special(), dZShapeInfo); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); @@ -696,24 +651,20 @@ void execReduce3Scalar(Nd4jPointer *extraPointers,int opNum, */ void execReduce3Tad(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, - void *hY, Nd4jLong *hYShapeInfo, - void *dY, Nd4jLong *dYShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, - void *hDimension, Nd4jLong *hDimensionShape, - void *dDimension, Nd4jLong *dDimensionShape, + OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *yTadOnlyShapeInfo, Nd4jLong *yTadOffsets) { try { - auto dimension = reinterpret_cast(hDimension); + auto dimension = reinterpret_cast(dbDimension->primary()); int dimensionLength = static_cast(shape::length(hDimensionShape)); if (extraPointers == nullptr || extraPointers[2] == 0) { - NativeOpExecutioner::execReduce3(LaunchContext::defaultContext(), opNum, hX, hXShapeInfo, dX, dXShapeInfo, - extraParams, hY, hYShapeInfo, dY, dYShapeInfo, hZ, hZShapeInfo, dZ, + NativeOpExecutioner::execReduce3(LaunchContext::defaultContext(), opNum, dbX->primary(), hXShapeInfo, dbX->special(), dXShapeInfo, + extraParams, dbY->primary(), hYShapeInfo, dbY->special(), dYShapeInfo, dbZ->primary(), hZShapeInfo, dbZ->special(), dZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, yTadOnlyShapeInfo, yTadOffsets); } else { @@ -724,9 +675,9 @@ void execReduce3Tad(Nd4jPointer *extraPointers, auto hTADShapeInfo = tadPack.primaryShapeInfo(); auto hTADOffsets = tadPack.primaryOffsets(); - NativeOpExecutioner::execReduce3TAD(LaunchContext::defaultContext(), opNum, hX, hXShapeInfo, dX, - dXShapeInfo, extraParams, hY, hYShapeInfo, dY, dYShapeInfo, hZ, - hZShapeInfo, dZ, dZShapeInfo, dimension, dimensionLength, hTADShapeInfo, + NativeOpExecutioner::execReduce3TAD(LaunchContext::defaultContext(), opNum, dbX->primary(), hXShapeInfo, dbX->special(), + dXShapeInfo, extraParams, dbY->primary(), hYShapeInfo, dbY->special(), dYShapeInfo, dbZ->primary(), + hZShapeInfo, dbZ->special(), dZShapeInfo, dimension, dimensionLength, hTADShapeInfo, hTADOffsets, nullptr, nullptr); } } catch (std::exception &e) { @@ -753,27 +704,24 @@ bool isBlasVersionMatches(int major, int minor, int build) { void execScalar( Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, - void *hScalar, Nd4jLong *hScalarShapeInfo, - void *dScalar, Nd4jLong *dScalarShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbScalar, Nd4jLong *hScalarShapeInfo, Nd4jLong *dScalarShapeInfo, void *extraParams) { try { NativeOpExecutioner::execScalar(nullptr, opNum, - hX, + dbX->primary(), hXShapeInfo, - dX, + dbX->special(), dXShapeInfo, - hZ, + dbZ->primary(), hZShapeInfo, - dZ, + dbZ->special(), dZShapeInfo, - hScalar, + dbScalar->primary(), hScalarShapeInfo, - dScalar, + dbScalar->special(), dScalarShapeInfo, extraParams); } catch (std::exception &e) { @@ -785,27 +733,24 @@ void execScalar( void execScalarBool( Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, - void *hScalar, Nd4jLong *hScalarShapeInfo, - void *dScalar, Nd4jLong *dScalarShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbScalar, Nd4jLong *hScalarShapeInfo, Nd4jLong *dScalarShapeInfo, void *extraParams) { try { NativeOpExecutioner::execScalarBool(nullptr, opNum, - hX, + dbX->primary(), hXShapeInfo, - dX, + dbX->special(), dXShapeInfo, - hZ, + dbZ->primary(), hZShapeInfo, - dZ, + dbZ->special(), dZShapeInfo, - hScalar, + dbScalar->primary(), hScalarShapeInfo, - dScalar, + dbScalar->special(), dScalarShapeInfo, extraParams); } catch (std::exception &e) { @@ -823,23 +768,21 @@ void execScalarBool( */ void execSummaryStatsScalar(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, bool biasCorrected) { try { NativeOpExecutioner::execSummaryStatsScalar(nullptr, opNum, - hX, + dbX->primary(), hXShapeInfo, - dX, + dbX->special(), dXShapeInfo, extraParams, - hZ, + dbZ->primary(), hZShapeInfo, - dZ, + dbZ->special(), dZShapeInfo, biasCorrected); } catch (std::exception &e) { @@ -858,23 +801,21 @@ void execSummaryStatsScalar(Nd4jPointer *extraPointers, */ void execSummaryStats(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, bool biasCorrected) { try { NativeOpExecutioner::execSummaryStats(nullptr, opNum, - hX, + dbX->primary(), hXShapeInfo, - dX, + dbX->special(), dXShapeInfo, extraParams, - hZ, + dbZ->primary(), hZShapeInfo, - dZ, + dbZ->special(), dZShapeInfo, biasCorrected); } catch (std::exception &e) { @@ -895,30 +836,27 @@ void execSummaryStats(Nd4jPointer *extraPointers, */ void execSummaryStatsTad(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, - void *hDimension, Nd4jLong *hDimensionShape, - void *dDimension, Nd4jLong *dDimensionShape, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape, bool biasCorrected, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) { try { - auto dimension = reinterpret_cast(hDimension); + auto dimension = reinterpret_cast(dbDimension->primary()); int dimensionLength = static_cast(shape::length(hDimensionShape)); NativeOpExecutioner::execSummaryStats(nullptr, opNum, - hX, + dbX->primary(), hXShapeInfo, - dX, + dbX->special(), dXShapeInfo, extraParams, - hZ, + dbZ->primary(), hZShapeInfo, - dZ, + dbZ->special(), dZShapeInfo, dimension, dimensionLength, @@ -944,21 +882,19 @@ void execSummaryStatsTad(Nd4jPointer *extraPointers, void execTransformFloat( Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, void *extraParams) { try { NativeOpExecutioner::execTransformFloat(nullptr, opNum, - hX, + dbX->primary(), hXShapeInfo, - dZ, + dbX->special(), dXShapeInfo, - hZ, + dbZ->primary(), hZShapeInfo, - dZ, + dbZ->special(), dZShapeInfo, extraParams, nullptr, @@ -972,21 +908,19 @@ void execTransformFloat( void execTransformSame( Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, void *extraParams) { try { NativeOpExecutioner::execTransformSame(nullptr, opNum, - hX, + dbX->primary(), hXShapeInfo, - dX, + dbX->special(), dXShapeInfo, - hZ, + dbZ->primary(), hZShapeInfo, - dZ, + dbZ->special(), dZShapeInfo, extraParams, nullptr, @@ -1000,21 +934,19 @@ void execTransformSame( void execTransformBool( Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, void *extraParams) { try { NativeOpExecutioner::execTransformBool(nullptr, opNum, - hX, + dbX->primary(), hXShapeInfo, - dX, + dbX->special(), dXShapeInfo, - hZ, + dbZ->primary(), hZShapeInfo, - dZ, + dbZ->special(), dZShapeInfo, extraParams, nullptr, @@ -1028,21 +960,19 @@ void execTransformBool( void execTransformAny( Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, void *extraParams) { try { NativeOpExecutioner::execTransformAny(nullptr, opNum, - hX, + dbX->primary(), hXShapeInfo, - dX, + dbX->special(), dXShapeInfo, - hZ, + dbZ->primary(), hZShapeInfo, - dZ, + dbZ->special(), dZShapeInfo, extraParams, nullptr, @@ -1056,21 +986,19 @@ void execTransformAny( void execTransformStrict( Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, void *extraParams) { try { NativeOpExecutioner::execTransformStrict(nullptr, opNum, - hX, + dbX->primary(), hXShapeInfo, - dX, + dbX->special(), dXShapeInfo, - hZ, + dbZ->primary(), hZShapeInfo, - dZ, + dbZ->special(), dZShapeInfo, extraParams, nullptr, @@ -1083,27 +1011,23 @@ void execTransformStrict( void execReduce3All(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParamsVals, - void *hY, Nd4jLong *hYShapeInfo, - void *dY, Nd4jLong *dYShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, - void *hDimension, Nd4jLong *hDimensionShape, - void *dDimension, Nd4jLong *dDimensionShape, + OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape, Nd4jLong *xTadShapeInfo, Nd4jLong *xOffsets, Nd4jLong *yTadShapeInfo, Nd4jLong *yOffsets) { try { - auto dimension = reinterpret_cast(hDimension); + auto dimension = reinterpret_cast(dbDimension->primary()); int dimensionLength = static_cast(shape::length(hDimensionShape)); - NativeOpExecutioner::execReduce3All(nullptr, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParamsVals, hY, - hYShapeInfo, dY, dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, dimension, + NativeOpExecutioner::execReduce3All(nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), dXShapeInfo, extraParamsVals, dbY->primary(), + hYShapeInfo, dbY->special(), dYShapeInfo, dbZ->primary(), hZShapeInfo, dbZ->special(), dZShapeInfo, dimension, dimensionLength, xTadShapeInfo, xOffsets, yTadShapeInfo, yOffsets); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); @@ -1398,10 +1322,8 @@ void pullRowsGeneric(void *vx, } void pullRows(Nd4jPointer *extraPointers, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, Nd4jLong n, Nd4jLong *indexes, Nd4jLong *tadShapeInfo, @@ -1411,7 +1333,7 @@ void pullRows(Nd4jPointer *extraPointers, try { auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); - BUILD_SINGLE_SELECTOR(xType, pullRowsGeneric, (hX, hXShapeInfo, hZ, hZShapeInfo, n, indexes, tadShapeInfo, tadOffsets, zTadShapeInfo, zTadOffsets), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(xType, pullRowsGeneric, (dbX->primary(), hXShapeInfo, dbZ->primary(), hZShapeInfo, n, indexes, tadShapeInfo, tadOffsets, zTadShapeInfo, zTadOffsets), LIBND4J_TYPES); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); @@ -1459,8 +1381,7 @@ void tearGeneric(void *vx, } void tear(Nd4jPointer *extraPointers, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, Nd4jPointer *targets, Nd4jLong *hZShapeInfo, Nd4jLong *tadShapeInfo, @@ -1468,7 +1389,7 @@ void tear(Nd4jPointer *extraPointers, try { auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); - BUILD_SINGLE_SELECTOR(xType, tearGeneric, (hX, hXShapeInfo, targets, hZShapeInfo, tadShapeInfo, tadOffsets), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(xType, tearGeneric, (dbX->primary(), hXShapeInfo, targets, hZShapeInfo, tadShapeInfo, tadOffsets), LIBND4J_TYPES); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); @@ -1653,35 +1574,31 @@ int getDevice() { void execScalarTad(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, - void *hScalars, Nd4jLong *hScalarShapeInfo, - void *dScalars, Nd4jLong *dScalarShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbScalars, Nd4jLong *hScalarShapeInfo, Nd4jLong *dScalarShapeInfo, void *extraParams, - void *hDimension, Nd4jLong *hDimensionShape, - void *dDimension, Nd4jLong *dDimensionShape, + OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) { try { - auto dimension = reinterpret_cast(hDimension); + auto dimension = reinterpret_cast(dbDimension->primary()); int dimensionLength = static_cast(shape::length(hDimensionShape)); NativeOpExecutioner::execScalar(nullptr, opNum, - hX, + dbX->primary(), hXShapeInfo, - dX, + dbX->special(), dXShapeInfo, extraParams, - hZ, + dbZ->primary(), hZShapeInfo, - dZ, + dbZ->special(), dZShapeInfo, - hScalars, + dbScalars->primary(), hScalarShapeInfo, - dScalars, + dbScalars->special(), dScalarShapeInfo, dimension, shape::length(hDimensionShape), @@ -1697,35 +1614,31 @@ void execScalarTad(Nd4jPointer *extraPointers, void execScalarBoolTad(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, - void *hScalars, Nd4jLong *hScalarShapeInfo, - void *dScalars, Nd4jLong *dScalarShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbScalars, Nd4jLong *hScalarShapeInfo, Nd4jLong *dScalarShapeInfo, void *extraParams, - void *hDimension, Nd4jLong *hDimensionShape, - void *dDimension, Nd4jLong *dDimensionShape, + OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) { try { - auto dimension = reinterpret_cast(hDimension); + auto dimension = reinterpret_cast(dbDimension->primary()); int dimensionLength = static_cast(shape::length(hDimensionShape)); NativeOpExecutioner::execScalarBool(nullptr, opNum, - hX, + dbX->primary(), hXShapeInfo, - dX, + dbX->special(), dXShapeInfo, extraParams, - hZ, + dbZ->primary(), hZShapeInfo, - dZ, + dbZ->special(), dZShapeInfo, - hScalars, + dbScalars->primary(), hScalarShapeInfo, - dScalars, + dbScalars->special(), dScalarShapeInfo, dimension, dimensionLength, @@ -1809,11 +1722,10 @@ void execAggregateBatch(Nd4jPointer *extraPointers, void execRandom(Nd4jPointer *extraPointers, int opNum, Nd4jPointer state, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, void *extraArguments) { try { - NativeOpExecutioner::execRandom(nullptr, opNum, state, hZ, hZShapeInfo, dZ, dZShapeInfo, extraArguments); + NativeOpExecutioner::execRandom(nullptr, opNum, state, dbZ->primary(), hZShapeInfo, dbZ->special(), dZShapeInfo, extraArguments); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); @@ -1823,15 +1735,12 @@ void execRandom(Nd4jPointer *extraPointers, void execRandom3(Nd4jPointer *extraPointers, int opNum, Nd4jPointer state, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *hY, Nd4jLong *hYShapeInfo, - void *dY, Nd4jLong *dYShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, void *extraArguments) { try { - NativeOpExecutioner::execRandom(nullptr, opNum, state, hX, hXShapeInfo, dX, dXShapeInfo, hY, hYShapeInfo, dY, dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, extraArguments); + NativeOpExecutioner::execRandom(nullptr, opNum, state, dbX->primary(), hXShapeInfo, dbX->special(), dXShapeInfo, dbY->primary(), hYShapeInfo, dbY->special(), dYShapeInfo, dbZ->primary(), hZShapeInfo, dbZ->special(), dZShapeInfo, extraArguments); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); @@ -1841,13 +1750,11 @@ void execRandom3(Nd4jPointer *extraPointers, void execRandom2(Nd4jPointer *extraPointers, int opNum, Nd4jPointer state, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, void *extraArguments) { try { - NativeOpExecutioner::execRandom(nullptr, opNum, state, hX, hXShapeInfo, dX, dXShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, extraArguments); + NativeOpExecutioner::execRandom(nullptr, opNum, state, dbX->primary(), hXShapeInfo, dbX->special(), dXShapeInfo, dbZ->primary(), hZShapeInfo, dbZ->special(), dZShapeInfo, extraArguments); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); @@ -2717,25 +2624,25 @@ static void _scatterUpdate(Nd4jPointer *extraPointers, int opCode, int numOfSub switch (opCode) { case 0: - inSubArr.applyPairwiseTransform(pairwise::Add, &updSubArr, &inSubArr, nullptr); + inSubArr.applyPairwiseTransform(pairwise::Add, updSubArr, inSubArr); break; case 1: - inSubArr.applyPairwiseTransform(pairwise::Subtract, &updSubArr, &inSubArr, nullptr); + inSubArr.applyPairwiseTransform(pairwise::Subtract, updSubArr, inSubArr); break; case 2: - inSubArr.applyPairwiseTransform(pairwise::Multiply, &updSubArr, &inSubArr, nullptr); + inSubArr.applyPairwiseTransform(pairwise::Multiply, updSubArr, inSubArr); break; case 3: - inSubArr.applyPairwiseTransform(pairwise::Divide, &updSubArr, &inSubArr, nullptr); + inSubArr.applyPairwiseTransform(pairwise::Divide, updSubArr, inSubArr); break; case 4: - inSubArr.applyPairwiseTransform(pairwise::ReverseSubtract, &updSubArr, &inSubArr, nullptr); + inSubArr.applyPairwiseTransform(pairwise::ReverseSubtract, updSubArr, inSubArr); break; case 5: - inSubArr.applyPairwiseTransform(pairwise::ReverseDivide, &updSubArr, &inSubArr, nullptr); + inSubArr.applyPairwiseTransform(pairwise::ReverseDivide, updSubArr, inSubArr); break; case 6: - inSubArr.applyPairwiseTransform(pairwise::CopyPws, &updSubArr, &inSubArr, nullptr); + inSubArr.applyPairwiseTransform(pairwise::CopyPws, updSubArr, inSubArr); break; default: continue; @@ -2863,6 +2770,15 @@ void setGraphContextInputArray(nd4j::graph::Context* ptr, int index, void *buffe void setGraphContextOutputArray(nd4j::graph::Context* ptr, int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo) { ptr->setOutputArray(index, buffer, shapeInfo, specialBuffer, specialShapeInfo); } + +void setGraphContextInputBuffer(OpaqueContext* ptr, int index, OpaqueDataBuffer *buffer, void *shapeInfo, void *specialShapeInfo) { + ptr->setInputArray(index, buffer, shapeInfo, specialShapeInfo); +} + +void setGraphContextOutputBuffer(OpaqueContext* ptr, int index, OpaqueDataBuffer *buffer, void *shapeInfo, void *specialShapeInfo) { + ptr->setOutputArray(index, buffer, shapeInfo, specialShapeInfo); +} + void setGraphContextTArguments(nd4j::graph::Context* ptr, double *arguments, int numberOfArguments) { ptr->setTArguments(arguments, numberOfArguments); } @@ -3104,10 +3020,13 @@ const char* lastErrorMessage() { return nd4j::LaunchContext::defaultContext()->errorReference()->errorMessage(); } +void ctxShapeFunctionOverride(OpaqueContext* ptr, bool reallyOverride) { + ptr->setShapeFunctionOverride(reallyOverride); +} + int binaryLevel() { #ifdef CPU_FEATURES - #if defined(F_X64) return 1; #elif defined (F_AVX2) @@ -3173,6 +3092,102 @@ bool isOptimalRequirementsMet() { #endif } +OpaqueDataBuffer* allocateDataBuffer(Nd4jLong elements, int dataType, bool allocateBoth) { + try { + auto dtype = DataTypeUtils::fromInt(dataType); + return new nd4j::InteropDataBuffer(elements * DataTypeUtils::sizeOf(dtype) , dtype, allocateBoth); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + return nullptr; + } +} + +Nd4jPointer dbPrimaryBuffer(OpaqueDataBuffer *dataBuffer) { + return dataBuffer->primary(); +} + +Nd4jPointer dbSpecialBuffer(OpaqueDataBuffer *dataBuffer) { + return dataBuffer->special(); +} + +void deleteDataBuffer(OpaqueDataBuffer *dataBuffer) { + delete dataBuffer; +} + +void dbSetPrimaryBuffer(OpaqueDataBuffer *dataBuffer, Nd4jPointer primaryBuffer, Nd4jLong numBytes) { + dataBuffer->setPrimary(primaryBuffer, numBytes); +} + +void dbSetSpecialBuffer(OpaqueDataBuffer *dataBuffer, Nd4jPointer specialBuffer, Nd4jLong numBytes) { + dataBuffer->setSpecial(specialBuffer, numBytes); +} + +void dbAllocatePrimaryBuffer(OpaqueDataBuffer *dataBuffer) { + dataBuffer->dataBuffer()->allocatePrimary(); +} + +void dbAllocateSpecialBuffer(OpaqueDataBuffer *dataBuffer) { + dataBuffer->dataBuffer()->allocateSpecial(); +} + +void dbExpandBuffer(OpaqueDataBuffer *dataBuffer, Nd4jLong elements) { + try { + dataBuffer->dataBuffer()->expand(elements * DataTypeUtils::sizeOf(dataBuffer->dataBuffer()->getDataType())); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } +} + +OpaqueDataBuffer* dbCreateView(OpaqueDataBuffer *dataBuffer, Nd4jLong length, Nd4jLong offset) { + return new InteropDataBuffer(*dataBuffer, length, offset); +} + +void dbSyncToSpecial(OpaqueDataBuffer *dataBuffer) { + dataBuffer->dataBuffer()->syncToSpecial(); +} + +void dbSyncToPrimary(OpaqueDataBuffer *dataBuffer) { + dataBuffer->dataBuffer()->syncToPrimary(nullptr); +} + +void dbTickHostRead(OpaqueDataBuffer *dataBuffer) { + dataBuffer->dataBuffer()->readPrimary(); +} + +void dbTickHostWrite(OpaqueDataBuffer *dataBuffer) { + dataBuffer->dataBuffer()->writePrimary(); +} + +void dbTickDeviceRead(OpaqueDataBuffer *dataBuffer) { + dataBuffer->dataBuffer()->readSpecial(); +} + +void dbTickDeviceWrite(OpaqueDataBuffer *dataBuffer) { + dataBuffer->dataBuffer()->writeSpecial(); +} + +void dbExpand(OpaqueDataBuffer *dataBuffer, Nd4jLong elements) { + dataBuffer->expand(elements); +} + +int dbLocality(OpaqueDataBuffer *dataBuffer) { + return 0; +} + +void dbSetDeviceId(OpaqueDataBuffer *dataBuffer, int deviceId) { + dataBuffer->setDeviceId(deviceId); +} + +int dbDeviceId(OpaqueDataBuffer *dataBuffer) { + return dataBuffer->deviceId(); +} + +void dbClose(OpaqueDataBuffer *dataBuffer) { + dataBuffer->getDataBuffer()->close(); +} + BUILD_SINGLE_TEMPLATE(template void pullRowsGeneric, (void *, Nd4jLong*, void*, Nd4jLong*, const int, Nd4jLong*, Nd4jLong*, Nd4jLong*, Nd4jLong*, Nd4jLong*), LIBND4J_TYPES); BUILD_SINGLE_TEMPLATE(template void tearGeneric, (void *, Nd4jLong*, Nd4jPointer*, Nd4jLong*, Nd4jLong*, Nd4jLong*), LIBND4J_TYPES); BUILD_SINGLE_TEMPLATE(template void shuffleGeneric, (void**, Nd4jLong**, void**, Nd4jLong**, int, int*, Nd4jLong**, Nd4jLong**), LIBND4J_TYPES); diff --git a/libnd4j/blas/cuda/NDArray.cu b/libnd4j/blas/cuda/NDArray.cu index be90a22ae..81c8070b3 100644 --- a/libnd4j/blas/cuda/NDArray.cu +++ b/libnd4j/blas/cuda/NDArray.cu @@ -122,35 +122,32 @@ __global__ static void fillAsTriangularCuda(const void* vx, const Nd4jLong* xSha /////////////////////////////////////////////////////////////////// template -void NDArray::fillAsTriangular(const float val, int lower, int upper, const char direction, NDArray* target) { +void NDArray::fillAsTriangular(const float val, int lower, int upper, NDArray& target, const char direction) { if (isS()) throw std::runtime_error("NDArray::fillAsTriangular: you can't use this method on String array!"); - if(target == nullptr) - target = this; - - if(!isSameShape(target) && !(rankOf() == 1 && target->rankOf() == 2 && sizeAt(0) == target->sizeAt(0) && sizeAt(0) == target->sizeAt(1))) + if(!isSameShape(target) && !(rankOf() == 1 && target.rankOf() == 2 && sizeAt(0) == target.sizeAt(0) && sizeAt(0) == target.sizeAt(1))) throw std::string("NDArray::fillAsTriangular method: wrong shape of target array !"); if (direction == 'u') - lower = -target->sizeAt(-2); + lower = -target.sizeAt(-2); else if (direction == 'l') - upper = target->sizeAt(-1); + upper = target.sizeAt(-1); const int threadsPerBlock = MAX_NUM_THREADS / 4; - const int blocksPerGrid = (target->lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = threadsPerBlock * sizeof(decltype(*target->getShapeInfo())) * target->rankOf() + 128; + const int blocksPerGrid = (target.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + const int sharedMem = threadsPerBlock * sizeof(decltype(*target.getShapeInfo())) * target.rankOf() + 128; PointersManager manager(getContext(), "NDArray::fillAsTriangular"); - NDArray::prepareSpecialUse({target}, {this}); - fillAsTriangularCuda<<getCudaStream()>>>(getPlatformBuffer(), getPlatformShapeInfo(), target->getPlatformBuffer(), target->getPlatformShapeInfo(), static_cast(val), lower, upper); - NDArray::registerSpecialUse({target}, {this}); + NDArray::prepareSpecialUse({&target}, {this}); + fillAsTriangularCuda<<getCudaStream()>>>(getPlatformBuffer(), getPlatformShapeInfo(), target.getPlatformBuffer(), target.getPlatformShapeInfo(), static_cast(val), lower, upper); + NDArray::registerSpecialUse({&target}, {this}); manager.synchronize(); } -BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT void NDArray::fillAsTriangular, (const float val, int lower, int upper, const char direction, NDArray* target), LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT void NDArray::fillAsTriangular, (const float val, int lower, int upper, NDArray& target, const char direction), LIBND4J_TYPES); //////////////////////////////////////////////////////////////////////// template @@ -239,7 +236,7 @@ void NDArray::synchronize(const char* msg) const { } //////////////////////////////////////////////////////////////////////// -void NDArray::prepareSpecialUse(const std::initializer_list& writeList, const std::initializer_list& readList, bool synchronizeWritables) { +void NDArray::prepareSpecialUse(const std::vector& writeList, const std::vector& readList, bool synchronizeWritables) { for (const auto& a : readList) if(a != nullptr) @@ -255,7 +252,7 @@ void NDArray::prepareSpecialUse(const std::initializer_list& wri } //////////////////////////////////////////////////////////////////////// -void NDArray::registerSpecialUse(const std::initializer_list& writeList, const std::initializer_list& readList) { +void NDArray::registerSpecialUse(const std::vector& writeList, const std::vector& readList) { for (const auto& p : readList) if(p != nullptr) @@ -267,7 +264,7 @@ void NDArray::registerSpecialUse(const std::initializer_list& wr } //////////////////////////////////////////////////////////////////////// -void NDArray::preparePrimaryUse(const std::initializer_list& writeList, const std::initializer_list& readList, bool synchronizeWritables) { +void NDArray::preparePrimaryUse(const std::vector& writeList, const std::vector& readList, bool synchronizeWritables) { for (const auto& a : readList) if(a != nullptr) @@ -283,7 +280,7 @@ void NDArray::preparePrimaryUse(const std::initializer_list& wri } //////////////////////////////////////////////////////////////////////// -void NDArray::registerPrimaryUse(const std::initializer_list& writeList, const std::initializer_list& readList) { +void NDArray::registerPrimaryUse(const std::vector& writeList, const std::vector& readList) { for (const auto& p : readList) if(p != nullptr) @@ -457,21 +454,21 @@ BUILD_DOUBLE_TEMPLATE(template void repeatCudaLauncher, (const int blocksPerGrid ////////////////////////////////////////////////////////////////////////// // create new array by repeating it the number of times given by repeats -NDArray* NDArray::repeat(const int axis, const std::vector& repeats) const { +NDArray NDArray::repeat(const int axis, const std::vector& repeats) const { - auto output = new NDArray('c', ShapeUtils::evalRepeatShape(axis, repeats, *this), dataType(), getContext()); + NDArray output('c', ShapeUtils::evalRepeatShape(axis, repeats, *this), dataType(), getContext()); const int threadsPerBlock = MAX_NUM_THREADS / 2; - const int blocksPerGrid = (output->lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = output->rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128; + const int blocksPerGrid = (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + const int sharedMem = output.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128; PointersManager manager(getContext(), "NDArray::repeat(const int axis, const std::vector& repeats)"); const int* reps = reinterpret_cast(manager.replicatePointer(repeats.data(), repeats.size() * sizeof(int))); - prepareSpecialUse({output}, {this}); - BUILD_SINGLE_SELECTOR_TWICE(dataType(), repeatCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, getContext()->getCudaStream(), getSpecialBuffer(), getSpecialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), reps, repeats.size(), axis), LIBND4J_TYPES); - prepareSpecialUse({output}, {this}); + prepareSpecialUse({&output}, {this}); + BUILD_SINGLE_SELECTOR_TWICE(dataType(), repeatCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, getContext()->getCudaStream(), getSpecialBuffer(), getSpecialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), reps, repeats.size(), axis), LIBND4J_TYPES); + prepareSpecialUse({&output}, {this}); manager.synchronize(); diff --git a/libnd4j/blas/cuda/NDArrayLambda.hpp b/libnd4j/blas/cuda/NDArrayLambda.hpp index c27476bfb..15028dfaa 100644 --- a/libnd4j/blas/cuda/NDArrayLambda.hpp +++ b/libnd4j/blas/cuda/NDArrayLambda.hpp @@ -247,73 +247,73 @@ static _CUDA_G void lambdaTriplewiseKernel(void* vw, Nd4jLong *wShapeInfo, void* ////////////////////////////////////////////////////////////////////////// template -void NDArray::applyLambda(Lambda func, NDArray* target) { - auto result = target == nullptr ? this : target; +void NDArray::applyLambda(Lambda func, NDArray& target) { + auto dtype = this->dataType(); - if (dtype != result->dataType()) + if (dtype != target.dataType()) throw std::runtime_error("NDArray::applyLambda X/Z data types must be the same"); - //throw datatype_exception::build("NDArray::applyLambda X/Z data types must be the same", dtype, result->dataType()); - prepareSpecialUse({result}, {this}); - BUILD_SINGLE_SELECTOR(dtype, LambdaHelper ,::lambdaLauncher(this->_context->getCudaStream(), this->specialBuffer(), this->specialShapeInfo(), result->specialBuffer(), result->specialShapeInfo(), func), LIBND4J_TYPES); - registerSpecialUse({result}, {this}); + //throw datatype_exception::build("NDArray::applyLambda X/Z data types must be the same", dtype, target.dataType()); + prepareSpecialUse({&target}, {this}); + BUILD_SINGLE_SELECTOR(dtype, LambdaHelper ,::lambdaLauncher(this->_context->getCudaStream(), this->specialBuffer(), this->specialShapeInfo(), target.specialBuffer(), target.specialShapeInfo(), func), LIBND4J_TYPES); + registerSpecialUse({&target}, {this}); } ////////////////////////////////////////////////////////////////////////// template -void NDArray::applyPairwiseLambda(const NDArray* other, Lambda func, NDArray* target) { - auto result = target == nullptr ? this : target; +void NDArray::applyPairwiseLambda(const NDArray& other, Lambda func, NDArray& target) { + auto dtype = this->dataType(); - if (dtype != result->dataType() || dtype != other->dataType()) + if (dtype != target.dataType() || dtype != other.dataType()) throw std::runtime_error("NDArray::applyPairwiseLambda X/Y/Z data types must be the same"); - //throw datatype_exception::build("NDArray::applyLambda X/Z data types must be the same", dtype, result->dataType()); + //throw datatype_exception::build("NDArray::applyLambda X/Z data types must be the same", dtype, target.dataType()); - prepareSpecialUse({result}, {this, other}); - BUILD_SINGLE_SELECTOR(dtype, LambdaHelper ,::lambdaPairwiseLauncher(this->_context->getCudaStream(), this->specialBuffer(), this->specialShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), result->specialBuffer(), result->specialShapeInfo(), func), LIBND4J_TYPES); - registerSpecialUse({result}, {this, other}); + prepareSpecialUse({&target}, {this, &other}); + BUILD_SINGLE_SELECTOR(dtype, LambdaHelper ,::lambdaPairwiseLauncher(this->_context->getCudaStream(), this->specialBuffer(), this->specialShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), target.specialBuffer(), target.specialShapeInfo(), func), LIBND4J_TYPES); + registerSpecialUse({&target}, {this, &other}); } ////////////////////////////////////////////////////////////////////////// template -void NDArray::applyIndexedLambda(Lambda func, NDArray* target) { - auto result = target == nullptr ? this : target; +void NDArray::applyIndexedLambda(Lambda func, NDArray& target) { + auto dtype = this->dataType(); - if (dtype != result->dataType()) + if (dtype != target.dataType()) throw std::runtime_error("NDArray::applyIndexedLambda X/Z data types must be the same"); - prepareSpecialUse({result}, {this}); - BUILD_SINGLE_SELECTOR(dtype, LambdaHelper ,::lambdaIndexedLauncher(this->_context->getCudaStream(), this->specialBuffer(), this->specialShapeInfo(), result->specialBuffer(), result->specialShapeInfo(), func), LIBND4J_TYPES); - registerSpecialUse({result}, {this}); + prepareSpecialUse({&target}, {this}); + BUILD_SINGLE_SELECTOR(dtype, LambdaHelper ,::lambdaIndexedLauncher(this->_context->getCudaStream(), this->specialBuffer(), this->specialShapeInfo(), target.specialBuffer(), target.specialShapeInfo(), func), LIBND4J_TYPES); + registerSpecialUse({&target}, {this}); } ////////////////////////////////////////////////////////////////////////// template -void NDArray::applyIndexedPairwiseLambda(NDArray* other, Lambda func, NDArray* target) { - auto result = target == nullptr ? this : target; +void NDArray::applyIndexedPairwiseLambda(NDArray& other, Lambda func, NDArray& target) { + auto dtype = this->dataType(); - if (dtype != result->dataType() || dtype != other->dataType()) + if (dtype != target.dataType() || dtype != other.dataType()) throw std::runtime_error("NDArray::applyIndexedPairwiseLambda X/Y/Z data types must be the same"); - prepareSpecialUse({result}, {this, other}); - BUILD_SINGLE_SELECTOR(dtype, LambdaHelper ,::lambdaIndexedPairwiseLauncher(this->_context->getCudaStream(), this->specialBuffer(), this->specialShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), result->specialBuffer(), result->specialShapeInfo(), func), LIBND4J_TYPES); - registerSpecialUse({result}, {this, other}); + prepareSpecialUse({&target}, {this, &other}); + BUILD_SINGLE_SELECTOR(dtype, LambdaHelper ,::lambdaIndexedPairwiseLauncher(this->_context->getCudaStream(), this->specialBuffer(), this->specialShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), target.specialBuffer(), target.specialShapeInfo(), func), LIBND4J_TYPES); + registerSpecialUse({&target}, {this, &other}); } ////////////////////////////////////////////////////////////////////////// template -void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, Lambda func, NDArray* target) { - auto result = target == nullptr ? this : target; +void NDArray::applyTriplewiseLambda(NDArray& second, NDArray& third, Lambda func, NDArray& target) { + auto dtype = this->dataType(); - if (dtype != result->dataType() || dtype != second->dataType() || dtype != third->dataType()) + if (dtype != target.dataType() || dtype != second.dataType() || dtype != third.dataType()) throw std::runtime_error("NDArray::applyTriplewiseLambda X/Y/Z data types must be the same"); - prepareSpecialUse({result}, {this, second, third}); - BUILD_SINGLE_SELECTOR(dtype, LambdaHelper ,::lambdaTriplewiseLauncher(this->_context->getCudaStream(), this->specialBuffer(), this->specialShapeInfo(), second->specialBuffer(), second->specialShapeInfo(), third->specialBuffer(), third->specialShapeInfo(), result->specialBuffer(), result->specialShapeInfo(), func), LIBND4J_TYPES); - registerSpecialUse({result}, {this, second, third}); + prepareSpecialUse({&target}, {this, &second, &third}); + BUILD_SINGLE_SELECTOR(dtype, LambdaHelper ,::lambdaTriplewiseLauncher(this->_context->getCudaStream(), this->specialBuffer(), this->specialShapeInfo(), second.specialBuffer(), second.specialShapeInfo(), third.specialBuffer(), third.specialShapeInfo(), target.specialBuffer(), target.specialShapeInfo(), func), LIBND4J_TYPES); + registerSpecialUse({&target}, {this, &second, &third}); } diff --git a/libnd4j/blas/cuda/NativeOpExecutioner.cu b/libnd4j/blas/cuda/NativeOpExecutioner.cu index 1f074f39b..1e0685dc4 100644 --- a/libnd4j/blas/cuda/NativeOpExecutioner.cu +++ b/libnd4j/blas/cuda/NativeOpExecutioner.cu @@ -488,7 +488,7 @@ void NativeOpExecutioner::execReduceSame(nd4j::LaunchContext *lc, throw datatype_exception::build("NativeOpExecutioner::execReduceSame requires both X & Z operands to have same type", xType, zType); auto numBlocks = shape::length(hZShapeInfo); - dim3 launchDims(numBlocks, 256, 8192); + dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, 256, 8192); BUILD_SINGLE_SELECTOR(xType, functions::reduce::ReduceSameFunction, ::execReduceXD(launchDims, stream, opNum, xRank, dX, dXShapeInfo, hXShapeInfo, extraParams, dZ, dZShapeInfo, hZShapeInfo, dimension, dimensionLength, reductionPointer, tadShapeInfo, tadOffsets), LIBND4J_TYPES); @@ -523,7 +523,7 @@ void NativeOpExecutioner::execReduceLong(nd4j::LaunchContext *lc, auto xRank = shape::rank(hXShapeInfo); auto numBlocks = shape::length(hZShapeInfo); - dim3 launchDims(numBlocks, 256, 32768); + dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, 256, 32768); BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceLongFunction, ::execReduceXD(launchDims, stream, opNum, xRank, dX, dXShapeInfo, hXShapeInfo, extraParams, dZ, dZShapeInfo, hZShapeInfo, dimension, dimensionLength, reductionPointer, tadShapeInfo, tadOffsets), LIBND4J_TYPES, LONG_TYPES); @@ -559,7 +559,7 @@ void NativeOpExecutioner::execReduceBool(nd4j::LaunchContext *lc, auto xRank = shape::rank(hXShapeInfo); auto numBlocks = shape::length(hZShapeInfo); - dim3 launchDims(numBlocks, 256, 32768); + dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, 256, 32768); BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceBoolFunction, ::execReduceXD(launchDims, stream, opNum, xRank, dX, dXShapeInfo, hXShapeInfo, extraParams, dZ, dZShapeInfo, hZShapeInfo, dimension, dimensionLength, reductionPointer, tadShapeInfo, tadOffsets), LIBND4J_TYPES, BOOL_TYPES); @@ -601,7 +601,7 @@ void NativeOpExecutioner::execIndexReduce(nd4j::LaunchContext *lc, auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); auto numBlocks = shape::length(hZShapeInfo); - dim3 launchDims(numBlocks, 256, 32768); + dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, 256, 32768); if (zType != nd4j::DataType::INT64 && zType != nd4j::DataType::INT32) throw datatype_exception::build("NativeOpExecutioner::execIndexReduce requires Z operand to have INT32/INT64 type", zType); @@ -647,7 +647,7 @@ void NativeOpExecutioner::execReduceFloat(nd4j::LaunchContext *lc, auto xRank = shape::rank(hXShapeInfo); auto numBlocks = shape::length(hZShapeInfo); - dim3 launchDims(numBlocks, 256, 32768); + dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, 256, 32768); BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceFloatFunction, ::execReduceXD(launchDims, stream, opNum, xRank, dX, dXShapeInfo, hXShapeInfo, extraParams, dZ, dZShapeInfo, hZShapeInfo, dimension, dimensionLength, reductionPointer, tadShapeInfo, tadOffsets), LIBND4J_TYPES, FLOAT_TYPES); @@ -684,7 +684,7 @@ void NativeOpExecutioner::execIndexReduceScalar(nd4j::LaunchContext *lc, auto xLength = shape::length(hXShapeInfo); auto blockWidth = 256; auto numBlocks = CudaLaunchHelper::getReductionBlocks(xLength, blockWidth); - dim3 launchDims(numBlocks, blockWidth, 32768); + dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, blockWidth, 32768); if (nd4j::Environment::getInstance()->isDebugAndVerbose() && launchDims.x == 1) printf("AF1 opNum:[%i]\n", opNum); @@ -734,7 +734,7 @@ void NativeOpExecutioner::execReduceFloatScalar(nd4j::LaunchContext *lc, auto xLength = shape::length(hXShapeInfo); auto blockWidth = 256; auto numBlocks = CudaLaunchHelper::getReductionBlocks(xLength, blockWidth); - dim3 launchDims(numBlocks, blockWidth, 32768); + dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, blockWidth, 32768); BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceFloatFunction, ::execReduceScalar(launchDims, stream, opNum, dX,dXShapeInfo, hXShapeInfo, extraParams, dZ,dZShapeInfo, hZShapeInfo, nullptr, 0, reductionPointer, nullptr), LIBND4J_TYPES, FLOAT_TYPES); @@ -766,7 +766,7 @@ void NativeOpExecutioner::execReduceBoolScalar(nd4j::LaunchContext *lc, auto xLength = shape::length(hXShapeInfo); auto blockWidth = 256; auto numBlocks = CudaLaunchHelper::getReductionBlocks(xLength, blockWidth); - dim3 launchDims(numBlocks, blockWidth, 32768); + dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, blockWidth, 32768); BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceBoolFunction, ::execReduceScalar(launchDims, stream, opNum, dX, dXShapeInfo, hXShapeInfo, extraParams, dZ, dZShapeInfo, hZShapeInfo, nullptr, 0, reductionPointer, nullptr), LIBND4J_TYPES, BOOL_TYPES); @@ -797,7 +797,7 @@ void NativeOpExecutioner::execReduceSameScalar(nd4j::LaunchContext *lc, auto xLength = shape::length(hXShapeInfo); auto blockWidth = 256; auto numBlocks = CudaLaunchHelper::getReductionBlocks(xLength, blockWidth); - dim3 launchDims(numBlocks, blockWidth, 32768); + dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, blockWidth, 32768); BUILD_SINGLE_SELECTOR(xType, functions::reduce::ReduceSameFunction, ::execReduceScalar(launchDims, stream, opNum, dX, dXShapeInfo, hXShapeInfo, extraParams, dZ, dZShapeInfo, hZShapeInfo, nullptr, 0, reductionPointer, nullptr), LIBND4J_TYPES); @@ -828,7 +828,7 @@ void NativeOpExecutioner::execReduceLongScalar(nd4j::LaunchContext *lc, auto xLength = shape::length(hXShapeInfo); auto blockWidth = 256; auto numBlocks = CudaLaunchHelper::getReductionBlocks(xLength, blockWidth); - dim3 launchDims(numBlocks, blockWidth, 32768); + dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, blockWidth, 32768); BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceLongFunction, ::execReduceScalar(launchDims, stream, opNum, dX, dXShapeInfo, hXShapeInfo, extraParams, dZ, dZShapeInfo, hZShapeInfo, nullptr, 0, reductionPointer, nullptr), LIBND4J_TYPES, LONG_TYPES); @@ -1085,7 +1085,7 @@ void NativeOpExecutioner::execReduce3(nd4j::LaunchContext *lc, auto blockWidth = 256; auto numBlocks = CudaLaunchHelper::getReductionBlocks(shape::length(hXShapeInfo), blockWidth); - dim3 launchDims(numBlocks, blockWidth, 32768); + dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, blockWidth, 32768); if (xType != yType) throw nd4j::datatype_exception::build("NativeOpExecutioner::execReduce3 requires Y operand to have X type", xType, yType); @@ -1135,7 +1135,7 @@ void NativeOpExecutioner::execReduce3(nd4j::LaunchContext *lc, auto numBlocks = shape::length(hZShapeInfo); - dim3 launchDims(numBlocks, 256, 32768); + dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, 256, 32768); BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce3::Reduce3, ::exec(launchDims, stream, opNum, dX, dXShapeInfo, @@ -1177,7 +1177,7 @@ void NativeOpExecutioner::execReduce3Scalar(nd4j::LaunchContext *lc, auto xLength = shape::length(hXShapeInfo); auto blockWidth = 256; auto numBlocks = CudaLaunchHelper::getReductionBlocks(xLength, blockWidth); - dim3 launchDims(numBlocks, blockWidth, 32768); + dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, blockWidth, 32768); if (xType != yType) throw nd4j::datatype_exception::build("NativeOpExecutioner::execReduce3Scalar requires Y operand to have X type", xType, yType); @@ -1595,7 +1595,7 @@ void NativeOpExecutioner::execReduce3TAD(nd4j::LaunchContext *lc, throw nd4j::datatype_exception::build("NativeOpExecutioner::execReduce3TAD requires Z operand to have floating point data type", zType); auto numBlocks = shape::length(hZShapeInfo); - dim3 launchDims(numBlocks, 256, 32768); + dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, 256, 32768); BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce3::Reduce3, ::exec(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, extraParams, dZ, dZShapeInfo, dimension, dimensionLength, 1, allocationPointer, tadShapeInfo, tadOffsets, yTadShapeInfo, yTadOffsets), LIBND4J_TYPES, FLOAT_TYPES); diff --git a/libnd4j/blas/cuda/NativeOps.cu b/libnd4j/blas/cuda/NativeOps.cu index c8cb3a616..419cadef5 100755 --- a/libnd4j/blas/cuda/NativeOps.cu +++ b/libnd4j/blas/cuda/NativeOps.cu @@ -229,17 +229,19 @@ public: void execPairwiseTransform( Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *hY, Nd4jLong *hYShapeInfo, - void *dY, Nd4jLong *dYShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, void *extraParams) { try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbY}); + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execPairwiseTransform(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, hY, hYShapeInfo, dY, - dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, extraParams); + NativeOpExecutioner::execPairwiseTransform(&lc, opNum, dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), + dbY->primary(), hYShapeInfo, dbY->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hYShapeInfo).specialAsT(), + dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT(), extraParams); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX, dbY}); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); @@ -249,17 +251,21 @@ void execPairwiseTransform( Nd4jPointer *extraPointers, //////////////////////////////////////////////////////////////////////// void execPairwiseTransformBool(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *hY, Nd4jLong *hYShapeInfo, - void *dY, Nd4jLong *dYShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, void *extraParams) { try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbY}); + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execPairwiseBoolTransform(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, hY, hYShapeInfo, - dY, dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, extraParams); + NativeOpExecutioner::execPairwiseBoolTransform(&lc, opNum, + dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), + dbY->primary(), hYShapeInfo, dbY->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hYShapeInfo).specialAsT(), + dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT(), + extraParams); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX, dbY}); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); @@ -269,16 +275,21 @@ void execPairwiseTransformBool(Nd4jPointer *extraPointers, //////////////////////////////////////////////////////////////////////// void execSummaryStatsScalar(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, bool biasCorrected) { try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execSummaryStatsScalar(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hZ, - hZShapeInfo, dZ, dZShapeInfo, biasCorrected); + NativeOpExecutioner::execSummaryStatsScalar(&lc, opNum, + dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), + extraParams, + dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT(), + biasCorrected); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); @@ -288,22 +299,16 @@ void execSummaryStatsScalar(Nd4jPointer *extraPointers, //////////////////////////////////////////////////////////////////////// void execBroadcastBool(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *hY, Nd4jLong *hYShapeInfo, - void *dY, Nd4jLong *dYShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, - void *extraParams, - void *hDimension, Nd4jLong *hDimensionShape, - void *dDimension, Nd4jLong *dDimensionShape) { + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, + void *extraParams, + OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape) { try { - //Nd4jLong *tadOnlyShapeInfo = reinterpret_cast(extraPointers[0]); - //Nd4jLong *tadOffsets = reinterpret_cast(extraPointers[1]); - //Nd4jLong *tadOnlyShapeInfoZ = reinterpret_cast(extraPointers[2]); - //Nd4jLong *tadOffsetsZ = reinterpret_cast(extraPointers[3]); + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbY}); + InteropDataBuffer::preparePrimaryUse({}, {dbDimension}); - auto dimension = reinterpret_cast(dDimension); + auto dimension = reinterpret_cast(dbDimension->primary()); int dimensionLength = static_cast(shape::length(hDimensionShape)); auto hTADShapeInfo = reinterpret_cast(extraPointers[9]); @@ -313,10 +318,15 @@ void execBroadcastBool(Nd4jPointer *extraPointers, auto tadOffsetsZ = reinterpret_cast(extraPointers[13]); LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execBroadcastBool(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, hY, hYShapeInfo, dY, - dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, extraParams, dimension, - dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, - tadOffsetsZ); + NativeOpExecutioner::execBroadcastBool(&lc, opNum, + dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), + dbY->primary(), hYShapeInfo, dbY->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hYShapeInfo).specialAsT(), + dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT(), + extraParams, + dimension, dimensionLength, + tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX, dbY}); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); @@ -338,16 +348,15 @@ void execBroadcastBool(Nd4jPointer *extraPointers, void execBroadcast( Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *hY, Nd4jLong *hYShapeInfo, - void *dY, Nd4jLong *dYShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, - void *hDimension, Nd4jLong *hDimensionShape, - void *dDimension, Nd4jLong *dDimensionShape) { + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape) { try { - auto dimension = reinterpret_cast(dDimension); + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbY}); + InteropDataBuffer::preparePrimaryUse({}, {dbDimension}); + + auto dimension = reinterpret_cast(dbDimension->primary()); int dimensionLength = static_cast(shape::length(hDimensionShape)); cudaStream_t *stream = reinterpret_cast(extraPointers[1]); @@ -362,13 +371,15 @@ void execBroadcast( auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); - if (nd4j::Environment::getInstance()->isDebugAndVerbose()) - printf("F3 opNum:[%i]\n", opNum); - LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execBroadcast(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, hY, hYShapeInfo, dY, - dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, dimension, dimensionLength, - tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ); + NativeOpExecutioner::execBroadcast(&lc, opNum, + dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), + dbY->primary(), hYShapeInfo, dbY->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hYShapeInfo).specialAsT(), + dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT(), + dimension, dimensionLength, + tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX, dbY}); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); @@ -388,15 +399,19 @@ void execBroadcast( //////////////////////////////////////////////////////////////////////// void execReduceFloat(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo) { + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo) { try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execReduceFloatScalar(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hZ, - hZShapeInfo, dZ, dZShapeInfo); + NativeOpExecutioner::execReduceFloatScalar(&lc, opNum, + dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), + extraParams, + dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT()); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); @@ -406,15 +421,19 @@ void execReduceFloat(Nd4jPointer *extraPointers, //////////////////////////////////////////////////////////////////////// void execReduceSame(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo) { + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo) { try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execReduceSameScalar(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hZ, - hZShapeInfo, dZ, dZShapeInfo); + NativeOpExecutioner::execReduceSameScalar(&lc, opNum, + dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), + extraParams, + dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT()); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); @@ -424,25 +443,30 @@ void execReduceSame(Nd4jPointer *extraPointers, //////////////////////////////////////////////////////////////////////// void execReduceSame2(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, - void *hDimension, Nd4jLong *hDimensionShape, - void *dDimension, Nd4jLong *dDimensionShape) { + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape) { try { - auto dimension = reinterpret_cast(dDimension); + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); + InteropDataBuffer::preparePrimaryUse({}, {dbDimension}); + + auto dimension = reinterpret_cast(dbDimension->primary()); int dimensionLength = static_cast(shape::length(hDimensionShape)); auto tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, - reinterpret_cast(hDimension), + dimension, shape::length(hDimensionShape)); LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execReduceSame(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hZ, hZShapeInfo, - dZ, dZShapeInfo, dimension, dimensionLength, tadPack.specialShapeInfo(), - tadPack.specialOffsets()); + NativeOpExecutioner::execReduceSame(&lc, opNum, + dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), + extraParams, + dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT(), + dimension, dimensionLength, + tadPack.specialShapeInfo(), tadPack.specialOffsets()); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); @@ -452,25 +476,30 @@ void execReduceSame2(Nd4jPointer *extraPointers, //////////////////////////////////////////////////////////////////////// void execReduceLong2(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, - void *hDimension, Nd4jLong *hDimensionShape, - void *dDimension, Nd4jLong *dDimensionShape) { + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape) { try { - auto dimension = reinterpret_cast(dDimension); + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); + InteropDataBuffer::preparePrimaryUse({}, {dbDimension}); + + auto dimension = reinterpret_cast(dbDimension->primary()); int dimensionLength = static_cast(shape::length(hDimensionShape)); auto tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, - reinterpret_cast(hDimension), + dimension, shape::length(hDimensionShape)); LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execReduceLong(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hZ, hZShapeInfo, - dZ, dZShapeInfo, dimension, dimensionLength, tadPack.specialShapeInfo(), - tadPack.specialOffsets()); + NativeOpExecutioner::execReduceLong(&lc, opNum, + dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), + extraParams, + dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT(), + dimension, dimensionLength, + tadPack.specialShapeInfo(), tadPack.specialOffsets()); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); @@ -480,19 +509,16 @@ void execReduceLong2(Nd4jPointer *extraPointers, //////////////////////////////////////////////////////////////////////// void execReduceLong(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo) { + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo) { try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); + auto stream = reinterpret_cast(extraPointers[1]); auto hTADShapeInfo = reinterpret_cast(extraPointers[9]); auto dTADShapeInfo = reinterpret_cast(extraPointers[10]); - if (nd4j::Environment::getInstance()->isDebugAndVerbose()) - printf("LF7 opNum:[%i]\n", opNum); - auto reductionPointer = reinterpret_cast(extraPointers[4]); auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); @@ -507,11 +533,15 @@ void execReduceLong(Nd4jPointer *extraPointers, dim3 launchDims(numBlocks, blockWidth, 32768); BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceLongFunction, - ::execReduceScalar(launchDims, stream, opNum, dX, dXShapeInfo, hXShapeInfo, extraParams, - dZ, dZShapeInfo, hXShapeInfo, nullptr, 0, reductionPointer, - dTADShapeInfo), LIBND4J_TYPES, LONG_TYPES); + ::execReduceScalar(launchDims, stream, opNum, + dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), hXShapeInfo, + extraParams, + dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT(), hXShapeInfo, + nullptr, 0, reductionPointer, dTADShapeInfo), LIBND4J_TYPES, LONG_TYPES); nd4j::DebugHelper::checkErrorCode(stream, "execReduceLong(...) failed"); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); @@ -521,25 +551,30 @@ void execReduceLong(Nd4jPointer *extraPointers, //////////////////////////////////////////////////////////////////////// void execReduceBool2(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, - void *hDimension, Nd4jLong *hDimensionShape, - void *dDimension, Nd4jLong *dDimensionShape) { + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape) { try { - auto dimension = reinterpret_cast(dDimension); + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); + InteropDataBuffer::preparePrimaryUse({}, {dbDimension}); + + auto dimension = reinterpret_cast(dbDimension->primary()); int dimensionLength = static_cast(shape::length(hDimensionShape)); auto tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, - reinterpret_cast(hDimension), + dimension, shape::length(hDimensionShape)); LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execReduceBool(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hZ, hZShapeInfo, - dZ, dZShapeInfo, dimension, dimensionLength, tadPack.specialShapeInfo(), - tadPack.specialOffsets()); + NativeOpExecutioner::execReduceBool(&lc, opNum, + dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), + extraParams, + dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT(), + dimension, dimensionLength, + tadPack.specialShapeInfo(), tadPack.specialOffsets()); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); @@ -549,19 +584,16 @@ void execReduceBool2(Nd4jPointer *extraPointers, //////////////////////////////////////////////////////////////////////// void execReduceBool(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo) { + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo) { try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); + auto stream = reinterpret_cast(extraPointers[1]); auto hTADShapeInfo = reinterpret_cast(extraPointers[9]); auto dTADShapeInfo = reinterpret_cast(extraPointers[10]); - if (nd4j::Environment::getInstance()->isDebugAndVerbose()) - printf("BF7 opNum:[%i]\n", opNum); - auto reductionPointer = reinterpret_cast(extraPointers[4]); auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); @@ -576,11 +608,15 @@ void execReduceBool(Nd4jPointer *extraPointers, dim3 launchDims(numBlocks, blockWidth, 32768); BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceBoolFunction, - ::execReduceScalar(launchDims, stream, opNum, dX, dXShapeInfo, hXShapeInfo, extraParams, - dZ, dZShapeInfo, hZShapeInfo, nullptr, 0, reductionPointer, - dTADShapeInfo), LIBND4J_TYPES, BOOL_TYPES); + ::execReduceScalar(launchDims, stream, opNum, + dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), hXShapeInfo, + extraParams, + dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT(), hZShapeInfo, + nullptr, 0, reductionPointer, dTADShapeInfo), LIBND4J_TYPES, BOOL_TYPES); nd4j::DebugHelper::checkErrorCode(stream, "execReduceBool(...) failed"); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); @@ -601,25 +637,30 @@ void execReduceBool(Nd4jPointer *extraPointers, //////////////////////////////////////////////////////////////////////// void execIndexReduce(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, - void *hDimension, Nd4jLong *hDimensionShape, - void *dDimension, Nd4jLong *dDimensionShape) { + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape) { try { - auto dimension = reinterpret_cast(dDimension); + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); + InteropDataBuffer::preparePrimaryUse({}, {dbDimension}); + + auto dimension = reinterpret_cast(dbDimension->primary()); int dimensionLength = static_cast(shape::length(hDimensionShape)); auto tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, - reinterpret_cast(hDimension), + dimension, shape::length(hDimensionShape)); LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execIndexReduce(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hZ, hZShapeInfo, - dZ, dZShapeInfo, dimension, dimensionLength, tadPack.specialShapeInfo(), - tadPack.specialOffsets()); + NativeOpExecutioner::execIndexReduce(&lc, opNum, + dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), + extraParams, + dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT(), + (int *) dbDimension->special(), dimensionLength, + tadPack.specialShapeInfo(), tadPack.specialOffsets()); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); @@ -638,25 +679,30 @@ void execIndexReduce(Nd4jPointer *extraPointers, //////////////////////////////////////////////////////////////////////// void execReduceFloat2(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, - void *hDimension, Nd4jLong *hDimensionShape, - void *dDimension, Nd4jLong *dDimensionShape) { + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape) { try { - auto dimension = reinterpret_cast(dDimension); + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); + InteropDataBuffer::preparePrimaryUse({}, {dbDimension}); + + auto dimension = reinterpret_cast(dbDimension->primary()); int dimensionLength = static_cast(shape::length(hDimensionShape)); auto tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, - reinterpret_cast(hDimension), + dimension, shape::length(hDimensionShape)); LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execReduceFloat(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hZ, hZShapeInfo, - dZ, dZShapeInfo, dimension, dimensionLength, tadPack.specialShapeInfo(), - tadPack.specialOffsets()); + NativeOpExecutioner::execReduceFloat(&lc, opNum, + dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), + extraParams, + dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT(), + dimension, dimensionLength, + tadPack.specialShapeInfo(), tadPack.specialOffsets()); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); @@ -674,15 +720,19 @@ void execReduceFloat2(Nd4jPointer *extraPointers, void execIndexReduceScalar( Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo){ + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo){ try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execIndexReduceScalar(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hZ, - hZShapeInfo, dZ, dZShapeInfo); + NativeOpExecutioner::execIndexReduceScalar(&lc, opNum, + dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), + extraParams, + dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT()); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); @@ -691,18 +741,23 @@ void execIndexReduceScalar( //////////////////////////////////////////////////////////////////////// void execTransformSame(Nd4jPointer *extraPointers,int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, void *extraParams) { try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); + auto tadShapeInfo = reinterpret_cast(extraPointers != nullptr ? extraPointers[0] : nullptr); auto tadOffsets = reinterpret_cast(extraPointers != nullptr ? extraPointers[1] : nullptr); LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execTransformSame(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, hZ, hZShapeInfo, dZ, - dZShapeInfo, extraParams, tadShapeInfo, tadOffsets); + NativeOpExecutioner::execTransformSame(&lc, opNum, + dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), + dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT(), + extraParams, + tadShapeInfo, tadOffsets); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); @@ -711,18 +766,23 @@ void execTransformSame(Nd4jPointer *extraPointers,int opNum, //////////////////////////////////////////////////////////////////////// void execTransformBool(Nd4jPointer *extraPointers,int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, void *extraParams) { try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); + auto tadShapeInfo = reinterpret_cast(extraPointers != nullptr ? extraPointers[0] : nullptr); auto tadOffsets = reinterpret_cast(extraPointers != nullptr ? extraPointers[1] : nullptr); LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execTransformBool(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, hZ, hZShapeInfo, dZ, - dZShapeInfo, extraParams, tadShapeInfo, tadOffsets); + NativeOpExecutioner::execTransformBool(&lc, opNum, + dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), + dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT(), + extraParams, + tadShapeInfo, tadOffsets); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); @@ -731,19 +791,24 @@ void execTransformBool(Nd4jPointer *extraPointers,int opNum, //////////////////////////////////////////////////////////////////////// void execTransformAny(Nd4jPointer *extraPointers,int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, - void *extraParams) { + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, + void *extraParams) { try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); + auto stream = reinterpret_cast(extraPointers[1]); auto streamSpecial = reinterpret_cast(extraPointers[4]); LaunchContext lc(stream, streamSpecial, extraPointers[5], extraPointers[3], reinterpret_cast(extraPointers[6])); - NativeOpExecutioner::execTransformAny(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, hZ, hZShapeInfo, dZ, - dZShapeInfo, extraParams, nullptr, nullptr); + NativeOpExecutioner::execTransformAny(&lc, opNum, + dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), + dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT(), + extraParams, + nullptr, nullptr); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); @@ -752,18 +817,23 @@ void execTransformAny(Nd4jPointer *extraPointers,int opNum, //////////////////////////////////////////////////////////////////////// void execTransformStrict(Nd4jPointer *extraPointers,int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, - void *extraParams) { + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, + void *extraParams) { try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); + auto tadShapeInfo = reinterpret_cast(extraPointers != nullptr ? extraPointers[10] : nullptr); auto tadOffsets = reinterpret_cast(extraPointers != nullptr ? extraPointers[11] : nullptr); LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execTransformStrict(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, hZ, hZShapeInfo, dZ, - dZShapeInfo, extraParams, tadShapeInfo, tadOffsets); + NativeOpExecutioner::execTransformStrict(&lc, opNum, + dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), + dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT(), + extraParams, + tadShapeInfo, tadOffsets); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); @@ -772,18 +842,23 @@ void execTransformStrict(Nd4jPointer *extraPointers,int opNum, //////////////////////////////////////////////////////////////////////// void execTransformFloat(Nd4jPointer *extraPointers,int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, void *extraParams) { try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); + auto tadShapeInfo = reinterpret_cast(extraPointers != nullptr ? extraPointers[10] : nullptr); auto tadOffsets = reinterpret_cast(extraPointers != nullptr ? extraPointers[11] : nullptr); LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execTransformFloat(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, hZ, hZShapeInfo, dZ, - dZShapeInfo, extraParams, tadShapeInfo, tadOffsets); + NativeOpExecutioner::execTransformFloat(&lc, opNum, + dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), + dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT(), + extraParams, + tadShapeInfo, tadOffsets); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); @@ -1094,7 +1169,43 @@ Nd4jLong getDeviceTotalMemory(int device) { } int memcpySync(Nd4jPointer dst, Nd4jPointer src, Nd4jLong size, int flags, Nd4jPointer reserved) { - return memcpyAsync(dst, src, size, flags, reserved); + cudaMemcpyKind kind; + + switch (flags) { + case 0: { + kind = cudaMemcpyHostToHost; + } + break; + case 1: { + kind = cudaMemcpyHostToDevice; + } + break; + case 2: { + kind = cudaMemcpyDeviceToHost; + } + break; + case 3: { + kind = cudaMemcpyDeviceToDevice; + } + break; + default: { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage("UNDEFNED MEMCPY"); + return 0; + } + } + + auto dZ = cudaMemcpy(reinterpret_cast(dst), const_cast(reinterpret_cast(src)), static_cast(size), kind); + if (dZ != 0) { + printf("Failed on [%p] -> [%p], size: [%i], direction: [%i], dZ: [%i]\n", src, dst, size, flags, static_cast(dZ)); + fflush(stdout); + fflush(stderr); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(dZ); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage("cudaMemcpy failed"); + return 0; + } + + return 1; } int memcpyAsync(Nd4jPointer dst, Nd4jPointer src, Nd4jLong size, int flags, Nd4jPointer reserved) { @@ -1131,11 +1242,12 @@ int memcpyAsync(Nd4jPointer dst, Nd4jPointer src, Nd4jLong size, int flags, Nd4j auto dZ = cudaMemcpyAsync(reinterpret_cast(dst), const_cast(reinterpret_cast(src)), static_cast(size), kind, *pStream); //auto dZ = cudaMemcpy(reinterpret_cast(dst), const_cast(reinterpret_cast(src)), static_cast(size), kind); if (dZ != 0) { - printf("Failed on [%lu] -> [%lu], size: [%i], direction: [%i], dZ: [%i]\n", src, dst, size, flags, static_cast(dZ)); + printf("Failed on [%p] -> [%p], size: [%i], direction: [%i], dZ: [%i]\n", src, dst, size, flags, static_cast(dZ)); fflush(stdout); fflush(stderr); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(dZ); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage("cudaMemcpyAsync failed"); + return 0; } return 1; @@ -1348,10 +1460,8 @@ Nd4jPointer getConstantSpace() { } void pullRows(Nd4jPointer *extraPointers, - void *x, Nd4jLong *xShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *z, Nd4jLong *zShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *xShapeInfo, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *zShapeInfo, Nd4jLong *dZShapeInfo, Nd4jLong n, Nd4jLong *indexes, Nd4jLong *tadShapeInfo, @@ -1359,14 +1469,18 @@ void pullRows(Nd4jPointer *extraPointers, Nd4jLong *zTadShapeInfo, Nd4jLong *zTadOffsets) { try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); + cudaStream_t *stream = reinterpret_cast(extraPointers[1]); dim3 launchDims(64, 256, 1024); auto xType = nd4j::ArrayOptions::dataType(xShapeInfo); BUILD_SINGLE_SELECTOR(xType, pullRowsKernelGeneric, - (launchDims, stream, dX, dZ, n, indexes, tadShapeInfo, tadOffsets, zTadShapeInfo, zTadOffsets), + (launchDims, stream, dbX->special(), dbZ->special(), n, indexes, tadShapeInfo, tadOffsets, zTadShapeInfo, zTadOffsets), LIBND4J_TYPES); DEBUG_KERNEL(stream, -1); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); @@ -1500,16 +1614,21 @@ void setTADThreshold(int num) { //////////////////////////////////////////////////////////////////////// void execSummaryStats(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, bool biasCorrected) { try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execSummaryStats(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hZ, - hZShapeInfo, dZ, dZShapeInfo, biasCorrected); + NativeOpExecutioner::execSummaryStats(&lc, opNum, + dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), + extraParams, + dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT(), + biasCorrected); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); @@ -1519,22 +1638,29 @@ void execSummaryStats(Nd4jPointer *extraPointers, //////////////////////////////////////////////////////////////////////// void execSummaryStatsTad(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, - void *hDimension, Nd4jLong *hDimensionShape, void *dDimension, Nd4jLong *dDimensionShape, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape, bool biasCorrected, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) { try { - auto dimension = reinterpret_cast(dDimension); + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbDimension}); + InteropDataBuffer::preparePrimaryUse({}, {dbDimension}); + + auto dimension = reinterpret_cast(dbDimension->primary()); int dimensionLength = static_cast(shape::length(hDimensionShape)); LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execSummaryStats(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hZ, - hZShapeInfo, dZ, dZShapeInfo, dimension, dimensionLength, tadShapeInfo, - tadOffsets, biasCorrected); + NativeOpExecutioner::execSummaryStats(&lc, opNum, + dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), + extraParams, + dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT(), + reinterpret_cast(dbDimension->special()), dimensionLength, + tadShapeInfo, tadOffsets, + biasCorrected); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX, dbDimension}); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); @@ -1544,17 +1670,21 @@ void execSummaryStatsTad(Nd4jPointer *extraPointers, //////////////////////////////////////////////////////////////////////// void execReduce3(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, - void *hY, Nd4jLong *hYShapeInfo, - void *dY, Nd4jLong *dYShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo) { + OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo) { try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbY}); + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execReduce3(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hY, hYShapeInfo, dY, - dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo); + NativeOpExecutioner::execReduce3(&lc, opNum, + dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), + extraParams, + dbY->primary(), hYShapeInfo, dbY->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hYShapeInfo).specialAsT(), + dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT()); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX, dbY}); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); @@ -1564,22 +1694,22 @@ void execReduce3(Nd4jPointer *extraPointers, //////////////////////////////////////////////////////////////////////// void execReduce3Tad(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, - void *hY, Nd4jLong *hYShapeInfo, - void *dY, Nd4jLong *dYShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, - void *hDimension, Nd4jLong *hDimensionShape, void *dDimension, Nd4jLong *dDimensionShape, + OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *yTadOnlyShapeInfo, Nd4jLong *yTadOffsets) { try { - auto dimension = reinterpret_cast(dDimension); + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbY}); + InteropDataBuffer::preparePrimaryUse({}, {dbDimension}); + + auto dimension = reinterpret_cast(dbDimension->primary()); int dimensionLength = static_cast(shape::length(hDimensionShape)); auto tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, - reinterpret_cast(hDimension), + dimension, shape::length(hDimensionShape)); auto tadLength = shape::length(tadPack.primaryShapeInfo()); auto yLength = shape::length(hYShapeInfo); @@ -1589,16 +1719,23 @@ void execReduce3Tad(Nd4jPointer *extraPointers, if (tadLength == yLength || tadLength == xLength) { // nd4j_printf("== way\n",""); - NativeOpExecutioner::execReduce3(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hY, hYShapeInfo, - dY, - dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, dimension, dimensionLength, - tadOnlyShapeInfo, tadOffsets, yTadOnlyShapeInfo, yTadOffsets); + NativeOpExecutioner::execReduce3(&lc, opNum, + dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), + extraParams, + dbY->primary(), hYShapeInfo, dbY->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hYShapeInfo).specialAsT(), + dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT(), + dimension, dimensionLength, + tadOnlyShapeInfo, tadOffsets, yTadOnlyShapeInfo, yTadOffsets); } else - NativeOpExecutioner::execReduce3TAD(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hY, - hYShapeInfo, dY, dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, - dimension, dimensionLength, tadOnlyShapeInfo, yTadOffsets, - yTadOnlyShapeInfo, yTadOffsets); + NativeOpExecutioner::execReduce3TAD(&lc, opNum, + dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), + extraParams, + dbY->primary(), hYShapeInfo, dbY->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hYShapeInfo).specialAsT(), + dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT(), + dimension, dimensionLength, + tadOnlyShapeInfo, yTadOffsets, yTadOnlyShapeInfo, yTadOffsets); + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX, dbY}); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); @@ -1607,17 +1744,21 @@ void execReduce3Tad(Nd4jPointer *extraPointers, //////////////////////////////////////////////////////////////////////// void execReduce3Scalar(Nd4jPointer *extraPointers,int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *extraParams, - void *hY, Nd4jLong *hYShapeInfo, - void *dY, Nd4jLong *dYShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo) { + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, + void *extraParams, + OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo) { try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbY}); + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execReduce3Scalar(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hY, - hYShapeInfo, dY, dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo); + NativeOpExecutioner::execReduce3Scalar(&lc, opNum, + dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), + extraParams, + dbY->primary(), hYShapeInfo, dbY->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hYShapeInfo).specialAsT(), + dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT()); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX, dbY}); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); @@ -1627,18 +1768,21 @@ void execReduce3Scalar(Nd4jPointer *extraPointers,int opNum, //////////////////////////////////////////////////////////////////////// void execScalarBool(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, - void *hScalar, Nd4jLong *hScalarShapeInfo, - void *dScalar, Nd4jLong *dScalarShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbScalar, Nd4jLong *hScalarShapeInfo, Nd4jLong *dScalarShapeInfo, void *extraParams) { try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbScalar}); + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execScalarBool(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, hZ, hZShapeInfo, dZ, - dZShapeInfo, hScalar, hScalarShapeInfo, dScalar, dScalarShapeInfo, - extraParams); + NativeOpExecutioner::execScalarBool(&lc, opNum, + dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), + dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT(), + dbScalar->primary(), hScalarShapeInfo, dbScalar->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hScalarShapeInfo).specialAsT(), + extraParams); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX, dbScalar}); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); @@ -1648,25 +1792,30 @@ void execScalarBool(Nd4jPointer *extraPointers, //////////////////////////////////////////////////////////////////////// void execScalarBoolTad(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, - void *hScalars, Nd4jLong *hScalarShapeInfo, - void *dScalars, Nd4jLong *dScalarShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbScalars, Nd4jLong *hScalarShapeInfo, Nd4jLong *dScalarShapeInfo, void *extraParams, - void *hDimension, Nd4jLong *hDimensionShape, void *dDimension, Nd4jLong *dDimensionShape, + OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) { try { - auto dimension = reinterpret_cast(dDimension); + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbScalars}); + InteropDataBuffer::preparePrimaryUse({}, {dbDimension}); + + auto dimension = reinterpret_cast(dbDimension->primary()); int dimensionLength = static_cast(shape::length(hDimensionShape)); LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execScalarBool(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hZ, hZShapeInfo, - dZ, dZShapeInfo, hScalars, hScalarShapeInfo, dScalars, dScalarShapeInfo, - dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, - tadOffsetsZ); + NativeOpExecutioner::execScalarBool(&lc, opNum, + dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), + extraParams, + dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT(), + dbScalars->primary(), hScalarShapeInfo, dbScalars->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hScalarShapeInfo).specialAsT(), + dimension, dimensionLength, + tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX, dbScalars}); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); @@ -1676,17 +1825,21 @@ void execScalarBoolTad(Nd4jPointer *extraPointers, //////////////////////////////////////////////////////////////////////// void execScalar(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, - void *hScalar, Nd4jLong *hScalarShapeInfo, - void *dScalar, Nd4jLong *dScalarShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbScalar, Nd4jLong *hScalarShapeInfo, Nd4jLong *dScalarShapeInfo, void *extraParams) { try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbScalar}); + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execScalar(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, - hScalar, hScalarShapeInfo, dScalar, dScalarShapeInfo, extraParams); + NativeOpExecutioner::execScalar(&lc, opNum, + dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), + dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT(), + dbScalar->primary(), hScalarShapeInfo, dbScalar->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hScalarShapeInfo).specialAsT(), + extraParams); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX, dbScalar}); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); @@ -1696,19 +1849,18 @@ void execScalar(Nd4jPointer *extraPointers, //////////////////////////////////////////////////////////////////////// void execScalarTad(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, - void *hScalars, Nd4jLong *hScalarShapeInfo, - void *dScalars, Nd4jLong *dScalarShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbScalars, Nd4jLong *hScalarShapeInfo, Nd4jLong *dScalarShapeInfo, void *extraParams, - void *hDimension, Nd4jLong *hDimensionShape, - void *dDimension, Nd4jLong *dDimensionShape, + OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) { try { - auto dimension = reinterpret_cast(dDimension); + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbScalars}); + InteropDataBuffer::preparePrimaryUse({}, {dbDimension}); + + auto dimension = reinterpret_cast(dbDimension->primary()); int dimensionLength = static_cast(shape::length(hDimensionShape)); cudaStream_t *stream = reinterpret_cast(extraPointers[1]); @@ -1725,10 +1877,12 @@ void execScalarTad(Nd4jPointer *extraPointers, #ifdef __ND4J_EXPERIMENTAL__ BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functions::scalar::ScalarTransform, ::executeCudaAlongDimension(launchDims, stream, opNum, dX, dXShapeInfo, dZ, dZShapeInfo, dScalars, extraParams, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, LIBND4J_TYPES); #else - BUILD_SINGLE_SELECTOR_THRICE(xType, functions::scalar::ScalarTransform, ::executeCudaAlongDimension(launchDims, stream, opNum, dX, dXShapeInfo, dZ, dZShapeInfo, dScalars, extraParams, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR_THRICE(xType, functions::scalar::ScalarTransform, ::executeCudaAlongDimension(launchDims, stream, opNum, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT(), dbScalars->special(), extraParams, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES); #endif DEBUG_KERNEL(stream, opNum); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX, dbScalars}); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); @@ -1777,12 +1931,17 @@ void execAggregateBatch(Nd4jPointer *extraPointers, void execRandom(Nd4jPointer *extraPointers, int opNum, Nd4jPointer stateHost, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, void *extraArguments) { try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {}); + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execRandom(&lc, opNum, stateHost, hZ, hZShapeInfo, dZ, dZShapeInfo, extraArguments); + NativeOpExecutioner::execRandom(&lc, opNum, stateHost, + dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT(), + extraArguments); + + InteropDataBuffer::registerSpecialUse({dbZ}, {}); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); @@ -1791,15 +1950,19 @@ void execRandom(Nd4jPointer *extraPointers, //////////////////////////////////////////////////////////////////////// void execRandom2(Nd4jPointer *extraPointers, int opNum, Nd4jPointer stateHost, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, void *extraArguments) { try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execRandom(&lc, opNum, stateHost, hX, hXShapeInfo, dX, dXShapeInfo, hZ, hZShapeInfo, dZ, - dZShapeInfo, extraArguments); + NativeOpExecutioner::execRandom(&lc, opNum, stateHost, + dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), + dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT(), + extraArguments); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); @@ -1808,17 +1971,21 @@ void execRandom2(Nd4jPointer *extraPointers, int opNum, Nd4jPointer stateHost, //////////////////////////////////////////////////////////////////////// void execRandom3(Nd4jPointer *extraPointers, int opNum, Nd4jPointer stateHost, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *hY, Nd4jLong *hYShapeInfo, - void *dY, Nd4jLong *dYShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, void *extraArguments) { try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbY}); + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execRandom(&lc, opNum, stateHost, hX, hXShapeInfo, dX, dXShapeInfo, hY, hYShapeInfo, dY, - dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, extraArguments); + NativeOpExecutioner::execRandom(&lc, opNum, stateHost, + dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), + dbY->primary(), hYShapeInfo, dbY->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hYShapeInfo).specialAsT(), + dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT(), + extraArguments); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX, dbY}); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); @@ -1924,21 +2091,24 @@ Nd4jPointer pointerForAddress(Nd4jLong address) { } void tear(Nd4jPointer *extras, - void *x, Nd4jLong *xShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *xShapeInfo, Nd4jLong *dXShapeInfo, Nd4jPointer *targets, Nd4jLong *zShapeInfo, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) { try { + InteropDataBuffer::prepareSpecialUse({}, {dbX}); + cudaStream_t *stream = reinterpret_cast(extras[1]); dim3 launchDims(512, 512, 512); auto xType = nd4j::ArrayOptions::dataType(xShapeInfo); BUILD_SINGLE_SELECTOR(xType, tearKernelGeneric, - (launchDims, stream, dX, dXShapeInfo, targets, zShapeInfo, tadShapeInfo, tadOffsets), + (launchDims, stream, dbX->special(), dXShapeInfo, targets, zShapeInfo, tadShapeInfo, tadOffsets), LIBND4J_TYPES); nd4j::DebugHelper::checkErrorCode(stream, "tearFloat(...) failed"); + + InteropDataBuffer::registerSpecialUse({}, {dbX}); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); @@ -2100,25 +2270,30 @@ void decodeThreshold(Nd4jPointer *extraPointers, void *dx, Nd4jLong N, void *dz, //////////////////////////////////////////////////////////////////////// void execReduce3All(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParamsVals, - void *hY, Nd4jLong *hYShapeInfo, - void *dY, Nd4jLong *dYShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, - void *hDimension, Nd4jLong *hDimensionShape, - void *dDimension, Nd4jLong *dDimensionShape, + OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape, Nd4jLong *xTadShapeInfo, Nd4jLong *xOffsets, Nd4jLong *yTadShapeInfo, Nd4jLong *yOffsets) { try { - auto dimension = reinterpret_cast(dDimension); + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbY, dbDimension}); + InteropDataBuffer::preparePrimaryUse({}, {dbDimension}); + + auto dimension = reinterpret_cast(dbDimension->primary()); int dimensionLength = static_cast(shape::length(hDimensionShape)); LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execReduce3All(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParamsVals, hY, - hYShapeInfo, dY, dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, dimension, - dimensionLength, xTadShapeInfo, xOffsets, yTadShapeInfo, yOffsets); + NativeOpExecutioner::execReduce3All(&lc, opNum, + dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), + extraParamsVals, + dbY->primary(), hYShapeInfo, dbY->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hYShapeInfo).specialAsT(), + dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT(), + reinterpret_cast(dbDimension->special()), dimensionLength, + xTadShapeInfo, xOffsets, yTadShapeInfo, yOffsets); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX, dbY}); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); @@ -3384,34 +3559,56 @@ nd4j::graph::Context* createGraphContext(int nodeId) { nd4j::graph::RandomGenerator* getGraphContextRandomGenerator(nd4j::graph::Context* ptr) { return &ptr->randomGenerator(); } + void markGraphContextInplace(nd4j::graph::Context* ptr, bool reallyInplace) { ptr->markInplace(reallyInplace); } + void setGraphContextCudaContext(nd4j::graph::Context* ptr, void *stream, void *reductionPointer, void *allocationPointer) { ptr->setCudaContext(stream, reductionPointer, allocationPointer); } + void setGraphContextInputArray(nd4j::graph::Context* ptr, int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo) { ptr->setInputArray(index, buffer, shapeInfo, specialBuffer, specialShapeInfo); } + void setGraphContextOutputArray(nd4j::graph::Context* ptr, int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo) { ptr->setOutputArray(index, buffer, shapeInfo, specialBuffer, specialShapeInfo); } + +void setGraphContextInputBuffer(OpaqueContext* ptr, int index, OpaqueDataBuffer *buffer, void *shapeInfo, void *specialShapeInfo) { + ptr->setInputArray(index, buffer, shapeInfo, specialShapeInfo); +} + +void setGraphContextOutputBuffer(OpaqueContext* ptr, int index, OpaqueDataBuffer *buffer, void *shapeInfo, void *specialShapeInfo) { + ptr->setOutputArray(index, buffer, shapeInfo, specialShapeInfo); +} + void setGraphContextTArguments(nd4j::graph::Context* ptr, double *arguments, int numberOfArguments) { ptr->setTArguments(arguments, numberOfArguments); } + void setGraphContextIArguments(nd4j::graph::Context* ptr, Nd4jLong *arguments, int numberOfArguments) { ptr->setIArguments(arguments, numberOfArguments); } + void setGraphContextBArguments(nd4j::graph::Context* ptr, bool *arguments, int numberOfArguments) { ptr->setBArguments(arguments, numberOfArguments); } + void deleteGraphContext(nd4j::graph::Context* ptr) { delete ptr; } nd4j::graph::RandomGenerator* createRandomGenerator(Nd4jLong rootSeed, Nd4jLong nodeSeed) { - return new nd4j::graph::RandomGenerator(rootSeed, nodeSeed); + try { + return new nd4j::graph::RandomGenerator(rootSeed, nodeSeed); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + return nullptr; + } } Nd4jLong getRandomGeneratorRootState(nd4j::graph::RandomGenerator* ptr) { @@ -3559,6 +3756,10 @@ const char* lastErrorMessage() { return nd4j::LaunchContext::defaultContext()->errorReference()->errorMessage(); } +void ctxShapeFunctionOverride(OpaqueContext* ptr, bool reallyOverride) { + ptr->setShapeFunctionOverride(reallyOverride); +} + int binaryLevel() { return 0; } @@ -3577,4 +3778,108 @@ bool isOptimalRequirementsMet() { void ctxAllowHelpers(OpaqueContext* ptr, bool reallyAllow) { ptr->allowHelpers(reallyAllow); +} + +OpaqueDataBuffer* allocateDataBuffer(Nd4jLong elements, int dataType, bool allocateBoth) { + try { + auto dtype = DataTypeUtils::fromInt(dataType); + return new nd4j::InteropDataBuffer(elements * DataTypeUtils::sizeOf(dtype), dtype, allocateBoth); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + return nullptr; + } +} + +Nd4jPointer dbPrimaryBuffer(OpaqueDataBuffer *dataBuffer) { + return dataBuffer->primary(); +} + +Nd4jPointer dbSpecialBuffer(OpaqueDataBuffer *dataBuffer) { + return dataBuffer->special(); +} + +void deleteDataBuffer(OpaqueDataBuffer *dataBuffer) { + delete dataBuffer; +} + +void dbSetPrimaryBuffer(OpaqueDataBuffer *dataBuffer, Nd4jPointer primaryBuffer, Nd4jLong numBytes) { + dataBuffer->setPrimary(primaryBuffer, numBytes); +} + +void dbSetSpecialBuffer(OpaqueDataBuffer *dataBuffer, Nd4jPointer specialBuffer, Nd4jLong numBytes) { + dataBuffer->setSpecial(specialBuffer, numBytes); +} + +void dbAllocatePrimaryBuffer(OpaqueDataBuffer *dataBuffer) { + dataBuffer->dataBuffer()->allocatePrimary(); +} + +void dbAllocateSpecialBuffer(OpaqueDataBuffer *dataBuffer) { + dataBuffer->dataBuffer()->allocateSpecial(); +} + +void dbExpandBuffer(OpaqueDataBuffer *dataBuffer, Nd4jLong elements) { + try { + dataBuffer->dataBuffer()->expand(elements * DataTypeUtils::sizeOf(dataBuffer->dataBuffer()->getDataType())); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } +} + +OpaqueDataBuffer* dbCreateView(OpaqueDataBuffer *dataBuffer, Nd4jLong length, Nd4jLong offset) { + return new InteropDataBuffer(*dataBuffer, length, offset); +} + +void dbSyncToSpecial(OpaqueDataBuffer *dataBuffer) { + dataBuffer->dataBuffer()->syncToSpecial(); +} + +void dbSyncToPrimary(OpaqueDataBuffer *dataBuffer) { + dataBuffer->dataBuffer()->syncToPrimary(nullptr); +} + +void dbTickHostRead(OpaqueDataBuffer *dataBuffer) { + dataBuffer->dataBuffer()->readPrimary(); +} + +void dbTickHostWrite(OpaqueDataBuffer *dataBuffer) { + dataBuffer->dataBuffer()->writePrimary(); +} + +void dbTickDeviceRead(OpaqueDataBuffer *dataBuffer) { + dataBuffer->dataBuffer()->readSpecial(); +} + +void dbTickDeviceWrite(OpaqueDataBuffer *dataBuffer) { + dataBuffer->dataBuffer()->writeSpecial(); +} + +void dbExpand(OpaqueDataBuffer *dataBuffer, Nd4jLong elements) { + dataBuffer->expand(elements); +} + +void dbClose(OpaqueDataBuffer *dataBuffer) { + dataBuffer->getDataBuffer()->close(); +} + +int dbDeviceId(OpaqueDataBuffer *dataBuffer) { + return dataBuffer->deviceId(); +} + +void dbSetDeviceId(OpaqueDataBuffer *dataBuffer, int deviceId) { + dataBuffer->setDeviceId(deviceId); +} + +int dbLocality(OpaqueDataBuffer *dataBuffer) { + auto p = dataBuffer->dataBuffer()->isPrimaryActual(); + auto d = dataBuffer->dataBuffer()->isSpecialActual(); + + if (p && d) + return 0; + else if (p) + return -1; + else + return 1; } \ No newline at end of file diff --git a/libnd4j/buildnativeoperations.sh b/libnd4j/buildnativeoperations.sh index 119b04f93..a8b45e918 100755 --- a/libnd4j/buildnativeoperations.sh +++ b/libnd4j/buildnativeoperations.sh @@ -489,6 +489,7 @@ mkbuilddir() { cd "blasbuild/$CHIP" } +HELPERS="" if [ "$HELPER" == "" ]; then echo "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!" echo "!! !!" @@ -503,6 +504,14 @@ if [ "$HELPER" == "" ]; then echo "!! !!" echo "!! !!" echo "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!" +else + # if helpers were defined, we'll propagate them to CMake + IFS=',' + read -ra HLP <<< "$HELPER" + for i in "${HLP[@]}"; do + HELPERS="${HELPERS} -DHELPERS_$i=true" + done + IFS=' ' fi echo PACKAGING = "${PACKAGING}" @@ -519,10 +528,10 @@ echo MINIFIER = "${MINIFIER_ARG}" echo TESTS = "${TESTS_ARG}" echo NAME = "${NAME_ARG}" echo OPENBLAS_PATH = "$OPENBLAS_PATH" -echo HELPERS = "$HELPER" +echo HELPERS = "$HELPERS" mkbuilddir pwd -eval $CMAKE_COMMAND "$BLAS_ARG" "$ARCH_ARG" "$NAME_ARG" -DHELPERS_"$HELPER"=true "$SHARED_LIBS_ARG" "$MINIFIER_ARG" "$OPERATIONS_ARG" "$BUILD_TYPE" "$PACKAGING_ARG" "$EXPERIMENTAL_ARG" "$TESTS_ARG" "$CUDA_COMPUTE" -DOPENBLAS_PATH="$OPENBLAS_PATH" -DDEV=FALSE -DCMAKE_NEED_RESPONSE=YES -DMKL_MULTI_THREADED=TRUE ../.. +eval $CMAKE_COMMAND "$BLAS_ARG" "$ARCH_ARG" "$NAME_ARG" $HELPERS "$SHARED_LIBS_ARG" "$MINIFIER_ARG" "$OPERATIONS_ARG" "$BUILD_TYPE" "$PACKAGING_ARG" "$EXPERIMENTAL_ARG" "$TESTS_ARG" "$CUDA_COMPUTE" -DOPENBLAS_PATH="$OPENBLAS_PATH" -DDEV=FALSE -DCMAKE_NEED_RESPONSE=YES -DMKL_MULTI_THREADED=TRUE ../.. if [ "$PARALLEL" == "true" ]; then MAKE_ARGUMENTS="$MAKE_ARGUMENTS -j $MAKEJ" fi diff --git a/libnd4j/include/array/ArrayOptions.h b/libnd4j/include/array/ArrayOptions.h index a753be1bf..484228fb7 100644 --- a/libnd4j/include/array/ArrayOptions.h +++ b/libnd4j/include/array/ArrayOptions.h @@ -34,10 +34,12 @@ #define ARRAY_SPARSE 2 #define ARRAY_COMPRESSED 4 #define ARRAY_EMPTY 8 +#define ARRAY_RAGGED 16 -#define ARRAY_CSR 16 -#define ARRAY_CSC 32 -#define ARRAY_COO 64 + +#define ARRAY_CSR 32 +#define ARRAY_CSC 64 +#define ARRAY_COO 128 // complex values #define ARRAY_COMPLEX 512 @@ -72,8 +74,10 @@ // boolean values #define ARRAY_BOOL 524288 -// utf-8 values -#define ARRAY_STRING 1048576 +// UTF values +#define ARRAY_UTF8 1048576 +#define ARRAY_UTF16 4194304 +#define ARRAY_UTF32 16777216 // flag for extras #define ARRAY_EXTRAS 2097152 @@ -173,8 +177,12 @@ namespace nd4j { return nd4j::DataType ::UINT32; else if (hasPropertyBitSet(shapeInfo, ARRAY_LONG)) return nd4j::DataType ::UINT64; - else if (hasPropertyBitSet(shapeInfo, ARRAY_STRING)) + else if (hasPropertyBitSet(shapeInfo, ARRAY_UTF8)) return nd4j::DataType ::UTF8; + else if (hasPropertyBitSet(shapeInfo, ARRAY_UTF16)) + return nd4j::DataType ::UTF16; + else if (hasPropertyBitSet(shapeInfo, ARRAY_UTF32)) + return nd4j::DataType ::UTF32; else { //shape::printShapeInfoLinear("Bad unsigned datatype (not)stored in shape", const_cast(shapeInfo)); #ifndef __CUDA_ARCH__ @@ -190,8 +198,12 @@ namespace nd4j { return nd4j::DataType::INT32; else if (hasPropertyBitSet(shapeInfo, ARRAY_LONG)) return nd4j::DataType::INT64; - else if (hasPropertyBitSet(shapeInfo, ARRAY_STRING)) + else if (hasPropertyBitSet(shapeInfo, ARRAY_UTF8)) return nd4j::DataType::UTF8; + else if (hasPropertyBitSet(shapeInfo, ARRAY_UTF16)) + return nd4j::DataType::UTF16; + else if (hasPropertyBitSet(shapeInfo, ARRAY_UTF32)) + return nd4j::DataType::UTF32; else { //shape::printShapeInfoLinear("Bad signed datatype (not)stored in shape", const_cast(shapeInfo)); #ifndef __CUDA_ARCH__ @@ -224,6 +236,8 @@ namespace nd4j { return ArrayType::COMPRESSED; else if (hasPropertyBitSet(shapeInfo, ARRAY_EMPTY)) return ArrayType::EMPTY; + else if (hasPropertyBitSet(shapeInfo, ARRAY_RAGGED)) + return ArrayType::RAGGED; else // by default we return DENSE type here return ArrayType::DENSE; } @@ -333,7 +347,13 @@ namespace nd4j { setPropertyBit(shapeInfo, ARRAY_LONG); break; case nd4j::DataType::UTF8: - setPropertyBit(shapeInfo, ARRAY_STRING); + setPropertyBit(shapeInfo, ARRAY_UTF8); + break; + case nd4j::DataType::UTF16: + setPropertyBit(shapeInfo, ARRAY_UTF16); + break; + case nd4j::DataType::UTF32: + setPropertyBit(shapeInfo, ARRAY_UTF32); break; default: #ifndef __CUDA_ARCH__ diff --git a/libnd4j/include/array/ArrayType.h b/libnd4j/include/array/ArrayType.h index 2300bf841..d4d6c9729 100644 --- a/libnd4j/include/array/ArrayType.h +++ b/libnd4j/include/array/ArrayType.h @@ -27,6 +27,7 @@ namespace nd4j { SPARSE = 2, COMPRESSED = 3, EMPTY = 4, + RAGGED = 5, }; } diff --git a/libnd4j/include/array/ConstantDescriptor.h b/libnd4j/include/array/ConstantDescriptor.h index f2f2f46a6..f32c1c8bf 100644 --- a/libnd4j/include/array/ConstantDescriptor.h +++ b/libnd4j/include/array/ConstantDescriptor.h @@ -22,7 +22,7 @@ #define DEV_TESTS_CONSTANTDESCRIPTOR_H #include -#include +#include #include #include #include diff --git a/libnd4j/include/array/DataBuffer.h b/libnd4j/include/array/DataBuffer.h index 034f16a25..cd27c20b8 100644 --- a/libnd4j/include/array/DataBuffer.h +++ b/libnd4j/include/array/DataBuffer.h @@ -36,13 +36,14 @@ class ND4J_EXPORT DataBuffer { private: - void* _primaryBuffer; - void* _specialBuffer; - size_t _lenInBytes; + void* _primaryBuffer = nullptr; + void* _specialBuffer = nullptr; + size_t _lenInBytes = 0; DataType _dataType; - memory::Workspace* _workspace; + memory::Workspace* _workspace = nullptr; bool _isOwnerPrimary; bool _isOwnerSpecial; + std::atomic _deviceId; #ifdef __CUDABLAS__ mutable std::atomic _counter; @@ -52,51 +53,52 @@ class ND4J_EXPORT DataBuffer { mutable std::atomic _readSpecial; #endif - void setCountersToZero(); - void copyCounters(const DataBuffer& other); - void deleteSpecial(); - FORCEINLINE void deletePrimary(); - FORCEINLINE void deleteBuffers(); - FORCEINLINE void setAllocFlags(const bool isOwnerPrimary, const bool isOwnerSpecial = false); - void allocateBuffers(const bool allocBoth = false); - void setSpecial(void* special, const bool isOwnerSpecial); - void copyBufferFromHost(const void* hostBuffer, size_t sizeToCopyinBytes = 0, const Nd4jLong offsetThis = 0, const Nd4jLong offsetHostBuffer = 0); + void setCountersToZero(); + void copyCounters(const DataBuffer& other); + void deleteSpecial(); + void deletePrimary(); + void deleteBuffers(); + void setAllocFlags(const bool isOwnerPrimary, const bool isOwnerSpecial = false); + void allocateBuffers(const bool allocBoth = false); + void setSpecial(void* special, const bool isOwnerSpecial); + void copyBufferFromHost(const void* hostBuffer, size_t sizeToCopyinBytes = 0, const Nd4jLong offsetThis = 0, const Nd4jLong offsetHostBuffer = 0); public: - FORCEINLINE DataBuffer(void* primary, void* special, + DataBuffer(void* primary, void* special, const size_t lenInBytes, const DataType dataType, const bool isOwnerPrimary = false, const bool isOwnerSpecial = false, memory::Workspace* workspace = nullptr); - FORCEINLINE DataBuffer(void* primary, + DataBuffer(void* primary, const size_t lenInBytes, const DataType dataType, const bool isOwnerPrimary = false, memory::Workspace* workspace = nullptr); - FORCEINLINE DataBuffer(const void* hostBuffer, // copies data from hostBuffer to own memory buffer + DataBuffer(const void* hostBuffer, // copies data from hostBuffer to own memory buffer const DataType dataType, const size_t lenInBytes, memory::Workspace* workspace = nullptr); - FORCEINLINE DataBuffer(const size_t lenInBytes, const DataType dataType, memory::Workspace* workspace = nullptr, const bool allocBoth = false); + DataBuffer(const size_t lenInBytes, const DataType dataType, memory::Workspace* workspace = nullptr, const bool allocBoth = false); - FORCEINLINE DataBuffer(const DataBuffer& other); - FORCEINLINE DataBuffer(DataBuffer&& other); - FORCEINLINE explicit DataBuffer(); - FORCEINLINE ~DataBuffer(); + DataBuffer(const DataBuffer& other); + DataBuffer(DataBuffer&& other); + explicit DataBuffer(); + ~DataBuffer(); - FORCEINLINE DataBuffer& operator=(const DataBuffer& other); - FORCEINLINE DataBuffer& operator=(DataBuffer&& other) noexcept; + DataBuffer& operator=(const DataBuffer& other); + DataBuffer& operator=(DataBuffer&& other) noexcept; - FORCEINLINE DataType getDataType(); - FORCEINLINE size_t getLenInBytes() const; + DataType getDataType(); + void setDataType(DataType dataType); + size_t getLenInBytes() const; - FORCEINLINE void* primary(); - FORCEINLINE void* special(); + void* primary(); + void* special(); - FORCEINLINE void allocatePrimary(); - void allocateSpecial(); + void allocatePrimary(); + void allocateSpecial(); void writePrimary() const; void writeSpecial() const; @@ -105,6 +107,10 @@ class ND4J_EXPORT DataBuffer { bool isPrimaryActual() const; bool isSpecialActual() const; + void expand(const uint64_t size); + + int deviceId() const; + void setDeviceId(int deviceId); void migrate(); template FORCEINLINE T* primaryAsT(); @@ -118,256 +124,28 @@ class ND4J_EXPORT DataBuffer { void copyBufferFrom(const DataBuffer& other, size_t sizeToCopyinBytes = 0, const Nd4jLong offsetThis = 0, const Nd4jLong offsetOther = 0); static void memcpy(const DataBuffer &dst, const DataBuffer &src); + + void setPrimaryBuffer(void *buffer, size_t length); + void setSpecialBuffer(void *buffer, size_t length); + + /** + * This method deletes buffers, if we're owners + */ + void close(); }; - - - ///// IMLEMENTATION OF INLINE METHODS ///// - //////////////////////////////////////////////////////////////////////// -// default constructor -DataBuffer::DataBuffer() { - - _primaryBuffer = nullptr; - _specialBuffer = nullptr; - _lenInBytes = 0; - _dataType = INT8; - _workspace = nullptr; - _isOwnerPrimary = false; - _isOwnerSpecial = false; - - setCountersToZero(); -} - -//////////////////////////////////////////////////////////////////////// -// copy constructor -DataBuffer::DataBuffer(const DataBuffer &other) { - - throw std::runtime_error("DataBuffer copy constructor: we don't expect using of this constructor!"); - - _lenInBytes = other._lenInBytes; - _dataType = other._dataType; - _workspace = other._workspace; - - _primaryBuffer = nullptr; - _specialBuffer = nullptr; - - setCountersToZero(); - - allocateBuffers(); - copyBufferFrom(other); -} - -//////////////////////////////////////////////////////////////////////// -DataBuffer::DataBuffer(void* primary, void* special, - const size_t lenInBytes, const DataType dataType, - const bool isOwnerPrimary, const bool isOwnerSpecial, - memory::Workspace* workspace) { - - if (primary == nullptr && special == nullptr) - throw std::runtime_error("DataBuffer constructor: can't be initialized with both nullptr buffers !"); - - _primaryBuffer = primary; - _specialBuffer = special; - _lenInBytes = lenInBytes; - _dataType = dataType; - _workspace = workspace; - _isOwnerPrimary = isOwnerPrimary; - _isOwnerSpecial = isOwnerSpecial; - - setCountersToZero(); - - if(primary != nullptr) - readPrimary(); - if(special != nullptr) - readSpecial(); -} - -//////////////////////////////////////////////////////////////////////// -DataBuffer::DataBuffer(void* primary, const size_t lenInBytes, const DataType dataType, const bool isOwnerPrimary, memory::Workspace* workspace): - DataBuffer(primary, nullptr, lenInBytes, dataType, isOwnerPrimary, false, workspace) { - - syncToSpecial(true); -} - -//////////////////////////////////////////////////////////////////////// -// copies data from hostBuffer to own memory buffer -DataBuffer::DataBuffer(const void* hostBuffer, const DataType dataType, const size_t lenInBytes, memory::Workspace* workspace) { - - if (hostBuffer == nullptr) - throw std::runtime_error("DataBuffer constructor: can't be initialized with nullptr host buffer !"); - if (lenInBytes == 0) - throw std::runtime_error("DataBuffer constructor: can't be initialized with zero length !"); - - _primaryBuffer = nullptr; - _specialBuffer = nullptr; - _lenInBytes = lenInBytes; - _dataType = dataType; - _workspace = workspace; - - setCountersToZero(); - - allocateBuffers(); - - copyBufferFromHost(hostBuffer, lenInBytes); -} - -//////////////////////////////////////////////////////////////////////// -DataBuffer::DataBuffer(const size_t lenInBytes, const DataType dataType, memory::Workspace* workspace, const bool allocBoth) { - - _dataType = dataType; - _workspace = workspace; - _lenInBytes = lenInBytes; - - _primaryBuffer = nullptr; - _specialBuffer = nullptr; - - setCountersToZero(); - - if(lenInBytes != 0) { - allocateBuffers(allocBoth); - writeSpecial(); + template + T* DataBuffer::primaryAsT() { + return reinterpret_cast(_primaryBuffer); } -} //////////////////////////////////////////////////////////////////////// -// move constructor -DataBuffer::DataBuffer(DataBuffer&& other) { - - _primaryBuffer = other._primaryBuffer; - _specialBuffer = other._specialBuffer; - _lenInBytes = other._lenInBytes; - _dataType = other._dataType; - _workspace = other._workspace; - _isOwnerPrimary = other._isOwnerPrimary; - _isOwnerSpecial = other._isOwnerSpecial; - - copyCounters(other); - - other._primaryBuffer = other._specialBuffer = nullptr; - other.setAllocFlags(false, false); - other._lenInBytes = 0; -} - -//////////////////////////////////////////////////////////////////////// -// assignment operator -DataBuffer& DataBuffer::operator=(const DataBuffer& other) { - - if (this == &other) - return *this; - - deleteBuffers(); - - _lenInBytes = other._lenInBytes; - _dataType = other._dataType; - _workspace = other._workspace; - - allocateBuffers(); - copyBufferFrom(other); - - return *this; -} - -//////////////////////////////////////////////////////////////////////// -// move assignment operator -DataBuffer& DataBuffer::operator=(DataBuffer&& other) noexcept { - - if (this == &other) - return *this; - - deleteBuffers(); - - _primaryBuffer = other._primaryBuffer; - _specialBuffer = other._specialBuffer; - _lenInBytes = other._lenInBytes; - _dataType = other._dataType; - _workspace = other._workspace; - _isOwnerPrimary = other._isOwnerPrimary; - _isOwnerSpecial = other._isOwnerSpecial; - - copyCounters(other); - - other._primaryBuffer = other._specialBuffer = nullptr; - other.setAllocFlags(false, false); - other._lenInBytes = 0; - - return *this; -} - -//////////////////////////////////////////////////////////////////////// -void* DataBuffer::primary() { - return _primaryBuffer; -} - -//////////////////////////////////////////////////////////////////////// -void* DataBuffer::special() { - return _specialBuffer; -} - -//////////////////////////////////////////////////////////////////////// -DataType DataBuffer::getDataType() { - return _dataType; -} - -//////////////////////////////////////////////////////////////////////// -size_t DataBuffer::getLenInBytes() const { - return _lenInBytes; -} - -//////////////////////////////////////////////////////////////////////// -template -T* DataBuffer::primaryAsT() { - return reinterpret_cast(_primaryBuffer); -} - -//////////////////////////////////////////////////////////////////////// -template -T* DataBuffer::specialAsT() { - return reinterpret_cast(_specialBuffer); -} - -//////////////////////////////////////////////////////////////////////// -void DataBuffer::allocatePrimary() { - - if (_primaryBuffer == nullptr && getLenInBytes() > 0) { - ALLOCATE(_primaryBuffer, _workspace, getLenInBytes(), int8_t); - _isOwnerPrimary = true; + template + T* DataBuffer::specialAsT() { + return reinterpret_cast(_specialBuffer); } -} - -//////////////////////////////////////////////////////////////////////// -void DataBuffer::setAllocFlags(const bool isOwnerPrimary, const bool isOwnerSpecial) { - - _isOwnerPrimary = isOwnerPrimary; - _isOwnerSpecial = isOwnerSpecial; -} - -//////////////////////////////////////////////////////////////////////// -void DataBuffer::deletePrimary() { - - if(_isOwnerPrimary && _primaryBuffer != nullptr && getLenInBytes() != 0) { - auto p = reinterpret_cast(_primaryBuffer); - RELEASE(p, _workspace); - _primaryBuffer = nullptr; - _isOwnerPrimary = false; - } -} - -//////////////////////////////////////////////////////////////////////// -void DataBuffer::deleteBuffers() { - - deletePrimary(); - deleteSpecial(); - _lenInBytes = 0; -} - -//////////////////////////////////////////////////////////////////////// -DataBuffer::~DataBuffer() { - - deleteBuffers(); -} - } diff --git a/libnd4j/include/array/DataType.h b/libnd4j/include/array/DataType.h index b3e21840d..8ec55342e 100644 --- a/libnd4j/include/array/DataType.h +++ b/libnd4j/include/array/DataType.h @@ -42,6 +42,8 @@ namespace nd4j { QINT16 = 16, BFLOAT16 = 17, UTF8 = 50, + UTF16 = 51, + UTF32 = 52, ANY = 100, AUTO = 200, }; diff --git a/libnd4j/include/array/DataTypeUtils.h b/libnd4j/include/array/DataTypeUtils.h index 4e879d247..7561e96cc 100644 --- a/libnd4j/include/array/DataTypeUtils.h +++ b/libnd4j/include/array/DataTypeUtils.h @@ -91,6 +91,10 @@ namespace nd4j { template FORCEINLINE static bool castShapeInfo(const Nd4jLong *originalShapeInfo, T *newShapeInfo); + + template + // struct scalarTypesForNDarray { static bool const value = std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value; }; + struct scalarTypesForNDarray { static bool const value = std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value; }; }; diff --git a/libnd4j/include/array/InteropDataBuffer.h b/libnd4j/include/array/InteropDataBuffer.h new file mode 100644 index 000000000..3cbfc2f94 --- /dev/null +++ b/libnd4j/include/array/InteropDataBuffer.h @@ -0,0 +1,71 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include +#include +#include +#include + +#ifndef LIBND4J_INTEROPDATABUFFER_H +#define LIBND4J_INTEROPDATABUFFER_H + +namespace nd4j { + /** + * This class is a wrapper for DataBuffer, suitable for sharing DataBuffer between front-end and back-end languages + */ + class ND4J_EXPORT InteropDataBuffer { + private: + std::shared_ptr _dataBuffer; + uint64_t _offset = 0; + public: + InteropDataBuffer(InteropDataBuffer &dataBuffer, uint64_t length, uint64_t offset); + InteropDataBuffer(std::shared_ptr databuffer); + InteropDataBuffer(size_t elements, nd4j::DataType dtype, bool allocateBoth); + ~InteropDataBuffer() = default; + +#ifndef __JAVACPP_HACK__ + std::shared_ptr getDataBuffer() const; + std::shared_ptr dataBuffer(); +#endif + + void* primary() const; + void* special() const; + + uint64_t offset() const ; + void setOffset(uint64_t offset); + + void setPrimary(void* ptr, size_t length); + void setSpecial(void* ptr, size_t length); + + void expand(size_t newlength); + + int deviceId() const; + void setDeviceId(int deviceId); + + static void registerSpecialUse(const std::vector& writeList, const std::vector& readList); + static void prepareSpecialUse(const std::vector& writeList, const std::vector& readList, bool synchronizeWritables = false); + + static void registerPrimaryUse(const std::vector& writeList, const std::vector& readList); + static void preparePrimaryUse(const std::vector& writeList, const std::vector& readList, bool synchronizeWritables = false); + }; +} + + +#endif //LIBND4J_INTEROPDATABUFFER_H diff --git a/libnd4j/include/array/NDArrayList.h b/libnd4j/include/array/NDArrayList.h index 358469602..843b69a91 100644 --- a/libnd4j/include/array/NDArrayList.h +++ b/libnd4j/include/array/NDArrayList.h @@ -25,7 +25,7 @@ #include #include -#include +#include #include #include #include diff --git a/libnd4j/include/array/ShapeDescriptor.h b/libnd4j/include/array/ShapeDescriptor.h index 25839cfa9..ddfd45a38 100644 --- a/libnd4j/include/array/ShapeDescriptor.h +++ b/libnd4j/include/array/ShapeDescriptor.h @@ -21,7 +21,7 @@ #ifndef DEV_TESTS_SHAPEDESCRIPTOR_H #define DEV_TESTS_SHAPEDESCRIPTOR_H -#include +#include #include #include #include diff --git a/libnd4j/include/array/cpu/DataBuffer.cpp b/libnd4j/include/array/cpu/DataBuffer.cpp index d13ca0def..ccd782adc 100644 --- a/libnd4j/include/array/cpu/DataBuffer.cpp +++ b/libnd4j/include/array/cpu/DataBuffer.cpp @@ -23,6 +23,24 @@ #include namespace nd4j { + void DataBuffer::expand(const uint64_t size) { + if (size > _lenInBytes) { + // allocate new buffer + int8_t *newBuffer = nullptr; + ALLOCATE(newBuffer, _workspace, size, int8_t); + + // copy data from existing buffer + std::memcpy(newBuffer, _primaryBuffer, _lenInBytes); + + if (_isOwnerPrimary) { + RELEASE(reinterpret_cast(_primaryBuffer), _workspace); + } + + _primaryBuffer = newBuffer; + _lenInBytes = size; + _isOwnerPrimary = true; + } + } //////////////////////////////////////////////////////////////////////// void DataBuffer::setCountersToZero() { @@ -99,14 +117,17 @@ void DataBuffer::allocateSpecial() { void DataBuffer::migrate() { } -/////////////////////////////////////////////////////////////////////// -void DataBuffer::memcpy(const DataBuffer &dst, const DataBuffer &src) { - if (src._lenInBytes < dst._lenInBytes) - throw std::runtime_error("DataBuffer::memcpy: Source data buffer is smaller than destination"); - std::memcpy(dst._primaryBuffer, src._primaryBuffer, dst._lenInBytes); +///////////////////////// +void DataBuffer::memcpy(const DataBuffer &dst, const DataBuffer &src) { + if (src._lenInBytes > dst._lenInBytes) + throw std::runtime_error("DataBuffer::memcpy: Source data buffer is larger than destination"); + + std::memcpy(dst._primaryBuffer, src._primaryBuffer, src._lenInBytes); + dst.readPrimary(); } + //////////////////////////////////////////////////////////////////////// void DataBuffer::writePrimary() const { } void DataBuffer::writeSpecial() const { } diff --git a/libnd4j/include/array/cuda/DataBuffer.cu b/libnd4j/include/array/cuda/DataBuffer.cu index 5cb227e69..28e0c432f 100644 --- a/libnd4j/include/array/cuda/DataBuffer.cu +++ b/libnd4j/include/array/cuda/DataBuffer.cu @@ -23,22 +23,72 @@ #include #include #include +#include +#include +#include namespace nd4j { + void DataBuffer::expand(const uint64_t size) { + if (size > _lenInBytes) { + // allocate new buffer + int8_t *newBuffer = nullptr; + int8_t *newSpecialBuffer = nullptr; + ALLOCATE_SPECIAL(newSpecialBuffer, _workspace, size, int8_t); + + // copy data from existing buffer + if (_primaryBuffer != nullptr) { + // there's non-zero chance that primary buffer doesn't exist yet + ALLOCATE(newBuffer, _workspace, size, int8_t); + std::memcpy(newBuffer, _primaryBuffer, _lenInBytes); + + if (_isOwnerPrimary) { + auto ipb = reinterpret_cast(_primaryBuffer); + RELEASE(ipb, _workspace); + } + + _primaryBuffer = newBuffer; + _isOwnerPrimary = true; + } + + cudaMemcpy(newSpecialBuffer, _specialBuffer, _lenInBytes, cudaMemcpyDeviceToDevice); + + if (_isOwnerSpecial) { + auto isb = reinterpret_cast(_specialBuffer); + RELEASE_SPECIAL(isb, _workspace); + } + + _specialBuffer = newSpecialBuffer; + _lenInBytes = size; + _isOwnerSpecial = true; + } + } //////////////////////////////////////////////////////////////////////// void DataBuffer::allocateSpecial() { if (_specialBuffer == nullptr && getLenInBytes() > 0) { + auto deviceId = nd4j::AffinityManager::currentDeviceId(); + + if (_workspace == nullptr) + if (!nd4j::memory::MemoryCounter::getInstance()->validate(getLenInBytes())) + throw nd4j::allocation_exception::build("Requested amount exceeds device limits", nd4j::memory::MemoryCounter::getInstance()->deviceLimit(deviceId), getLenInBytes()); + + ALLOCATE_SPECIAL(_specialBuffer, _workspace, getLenInBytes(), int8_t); _isOwnerSpecial = true; + + if (_workspace == nullptr) { + nd4j::memory::MemoryCounter::getInstance()->countIn(deviceId, getLenInBytes()); + nd4j::memory::MemoryCounter::getInstance()->countIn(nd4j::memory::MemoryType::DEVICE, getLenInBytes()); + } } } //////////////////////////////////////////////////////////////////////// void DataBuffer::syncToPrimary(const LaunchContext* context, const bool forceSync) { - if(isPrimaryActual() && !forceSync) + if(isPrimaryActual() && !forceSync) { return; + } allocatePrimary(); @@ -46,7 +96,9 @@ void DataBuffer::syncToPrimary(const LaunchContext* context, const bool forceSyn if (res != 0) throw cuda_exception::build("DataBuffer::syncToPrimary failed to to some previous kernel failre", res); - cudaMemcpy(_primaryBuffer, _specialBuffer, getLenInBytes(), cudaMemcpyDeviceToHost); + res = cudaMemcpy(_primaryBuffer, _specialBuffer, getLenInBytes(), cudaMemcpyDeviceToHost); + if (res != 0) + throw cuda_exception::build("DataBuffer::syncToPrimary cudaMemcpy failed", res); readPrimary(); } @@ -54,13 +106,19 @@ void DataBuffer::syncToPrimary(const LaunchContext* context, const bool forceSyn //////////////////////////////////////////////////////////////////////// void DataBuffer::syncToSpecial(const bool forceSync) { - - if(isSpecialActual() && !forceSync) + // in this case there's nothing to do here + if (_primaryBuffer == nullptr) return; + if(isSpecialActual() && !forceSync) { + return; + } + allocateSpecial(); - cudaMemcpy(_specialBuffer, _primaryBuffer, getLenInBytes(), cudaMemcpyHostToDevice); + auto res = cudaMemcpy(_specialBuffer, _primaryBuffer, getLenInBytes(), cudaMemcpyHostToDevice); + if (res != 0) + throw cuda_exception::build("DataBuffer::syncToSpecial cudaMemcpy failed", res); readSpecial(); } @@ -74,6 +132,12 @@ void DataBuffer::deleteSpecial() { RELEASE_SPECIAL(p, _workspace); _specialBuffer = nullptr; _isOwnerSpecial = false; + + // count out towards DataBuffer device, only if we're not in workspace + if (_workspace == nullptr) { + nd4j::memory::MemoryCounter::getInstance()->countOut(_deviceId, getLenInBytes()); + nd4j::memory::MemoryCounter::getInstance()->countOut(nd4j::memory::MemoryType::DEVICE, getLenInBytes()); + } } } @@ -97,19 +161,6 @@ void DataBuffer::copyCounters(const DataBuffer& other) { _readPrimary.store(other._writeSpecial); _readSpecial.store(other._writePrimary); } -//////////////////////////////////////////////////////////////////////// -void DataBuffer::memcpy(const DataBuffer &dst, const DataBuffer &src) { - if (src._lenInBytes < dst._lenInBytes) - throw std::runtime_error("DataBuffer::memcpy: Source data buffer is smaller than destination"); - - if (src.isSpecialActual()) { - cudaMemcpy(dst._specialBuffer, src._specialBuffer, dst.getLenInBytes(), cudaMemcpyDeviceToDevice); - } else if (src.isPrimaryActual()) { - cudaMemcpy(dst._specialBuffer, src._primaryBuffer, dst.getLenInBytes(), cudaMemcpyHostToDevice); - } - - dst.writeSpecial(); -} //////////////////////////////////////////////////////////////////////// void DataBuffer::copyBufferFrom(const DataBuffer& other, size_t sizeToCopyinBytes, const Nd4jLong offsetThis, const Nd4jLong offsetOther) { // copies only to special buffer @@ -176,8 +227,11 @@ void DataBuffer::allocateBuffers(const bool allocBoth) { // always allocate s //////////////////////////////////////////////////////////////////////// void DataBuffer::setToZeroBuffers(const bool both) { + cudaMemsetAsync(special(), 0, getLenInBytes(), *LaunchContext::defaultContext()->getCudaStream()); + auto res = cudaStreamSynchronize(*LaunchContext::defaultContext()->getCudaStream()); + if (res != 0) + throw cuda_exception::build("DataBuffer::setToZeroBuffers: streamSync failed!", res); - cudaMemset(special(), 0, getLenInBytes()); writeSpecial(); if(both) { @@ -186,12 +240,37 @@ void DataBuffer::setToZeroBuffers(const bool both) { } } +///////////////////////// +void DataBuffer::memcpy(const DataBuffer &dst, const DataBuffer &src) { + if (src._lenInBytes > dst._lenInBytes) + throw std::runtime_error("DataBuffer::memcpy: Source data buffer is larger than destination"); + + + int res = 0; + if (src.isSpecialActual()) { + res = cudaMemcpyAsync(dst._specialBuffer, src._specialBuffer, src.getLenInBytes(), cudaMemcpyDeviceToDevice, *LaunchContext::defaultContext()->getCudaStream()); + } else if (src.isPrimaryActual()) { + res = cudaMemcpyAsync(dst._specialBuffer, src._primaryBuffer, src.getLenInBytes(), cudaMemcpyHostToDevice, *LaunchContext::defaultContext()->getCudaStream()); + } + + if (res != 0) + throw cuda_exception::build("DataBuffer::memcpy: cudaMemcpyAsync failed!", res); + + res = cudaStreamSynchronize(*LaunchContext::defaultContext()->getCudaStream()); + if (res != 0) + throw cuda_exception::build("DataBuffer::memcpy: streamSync failed!", res); + + dst.writeSpecial(); +} + //////////////////////////////////////////////////////////////////////// void DataBuffer::migrate() { memory::Workspace* newWorkspace = nullptr; void* newBuffer; ALLOCATE_SPECIAL(newBuffer, newWorkspace, getLenInBytes(), int8_t); - cudaMemcpy(newBuffer, _specialBuffer, getLenInBytes(), cudaMemcpyDeviceToDevice); + auto res = cudaMemcpy(newBuffer, _specialBuffer, getLenInBytes(), cudaMemcpyDeviceToDevice); + if (res != 0) + throw cuda_exception::build("DataBuffer::migrate: cudaMemcpyAsync failed!", res); if (_isOwnerSpecial) { // now we're releasing original buffer @@ -203,7 +282,7 @@ void DataBuffer::migrate() { } //////////////////////////////////////////////////////////////////////// -void DataBuffer::writePrimary() const { _writePrimary = ++_counter; } +void DataBuffer::writePrimary() const {_writePrimary = ++_counter; } void DataBuffer::writeSpecial() const { _writeSpecial = ++_counter; } void DataBuffer::readPrimary() const { _readPrimary = ++_counter; } void DataBuffer::readSpecial() const { _readSpecial = ++_counter; } diff --git a/libnd4j/include/array/impl/DataBuffer.cpp b/libnd4j/include/array/impl/DataBuffer.cpp new file mode 100644 index 000000000..49527026c --- /dev/null +++ b/libnd4j/include/array/impl/DataBuffer.cpp @@ -0,0 +1,333 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// @author Yurii Shyrma (iuriish@yahoo.com) +// + +#include +#include +#include +#include +#include +#include + +namespace nd4j { + ///// IMLEMENTATION OF COMMON METHODS ///// + + +//////////////////////////////////////////////////////////////////////// +// default constructor + DataBuffer::DataBuffer() { + + _primaryBuffer = nullptr; + _specialBuffer = nullptr; + _lenInBytes = 0; + _dataType = INT8; + _workspace = nullptr; + _isOwnerPrimary = false; + _isOwnerSpecial = false; + _deviceId = nd4j::AffinityManager::currentDeviceId(); + + setCountersToZero(); + } + +//////////////////////////////////////////////////////////////////////// +// copy constructor + DataBuffer::DataBuffer(const DataBuffer &other) { + + throw std::runtime_error("DataBuffer copy constructor: we don't expect using of this constructor!"); + + _lenInBytes = other._lenInBytes; + _dataType = other._dataType; + _workspace = other._workspace; + + _primaryBuffer = nullptr; + _specialBuffer = nullptr; + + _deviceId.store(other._deviceId.load()); + + setCountersToZero(); + + allocateBuffers(); + copyBufferFrom(other); + } + +//////////////////////////////////////////////////////////////////////// + DataBuffer::DataBuffer(void* primary, void* special, + const size_t lenInBytes, const DataType dataType, + const bool isOwnerPrimary, const bool isOwnerSpecial, + memory::Workspace* workspace) { + + if (primary == nullptr && special == nullptr) + throw std::runtime_error("DataBuffer constructor: can't be initialized with both nullptr buffers !"); + + _primaryBuffer = primary; + _specialBuffer = special; + _lenInBytes = lenInBytes; + _dataType = dataType; + _workspace = workspace; + _isOwnerPrimary = isOwnerPrimary; + _isOwnerSpecial = isOwnerSpecial; + _deviceId = nd4j::AffinityManager::currentDeviceId(); + + setCountersToZero(); + + if(primary != nullptr) + readPrimary(); + if(special != nullptr) + readSpecial(); + } + +//////////////////////////////////////////////////////////////////////// + DataBuffer::DataBuffer(void* primary, const size_t lenInBytes, const DataType dataType, const bool isOwnerPrimary, memory::Workspace* workspace): + DataBuffer(primary, nullptr, lenInBytes, dataType, isOwnerPrimary, false, workspace) { + + syncToSpecial(true); + } + +//////////////////////////////////////////////////////////////////////// +// copies data from hostBuffer to own memory buffer + DataBuffer::DataBuffer(const void* hostBuffer, const DataType dataType, const size_t lenInBytes, memory::Workspace* workspace) { + + if (hostBuffer == nullptr) + throw std::runtime_error("DataBuffer constructor: can't be initialized with nullptr host buffer !"); + if (lenInBytes == 0) + throw std::runtime_error("DataBuffer constructor: can't be initialized with zero length !"); + + _primaryBuffer = nullptr; + _specialBuffer = nullptr; + _lenInBytes = lenInBytes; + _dataType = dataType; + _workspace = workspace; + + _deviceId = nd4j::AffinityManager::currentDeviceId(); + + setCountersToZero(); + + allocateBuffers(); + + copyBufferFromHost(hostBuffer, lenInBytes); + } + +//////////////////////////////////////////////////////////////////////// + DataBuffer::DataBuffer(const size_t lenInBytes, const DataType dataType, memory::Workspace* workspace, const bool allocBoth) { + + _dataType = dataType; + _workspace = workspace; + _lenInBytes = lenInBytes; + + _primaryBuffer = nullptr; + _specialBuffer = nullptr; + + _deviceId = nd4j::AffinityManager::currentDeviceId(); + + setCountersToZero(); + + if(lenInBytes != 0) { + allocateBuffers(allocBoth); + writeSpecial(); + } + } + +//////////////////////////////////////////////////////////////////////// +// move constructor + DataBuffer::DataBuffer(DataBuffer&& other) { + + _primaryBuffer = other._primaryBuffer; + _specialBuffer = other._specialBuffer; + _lenInBytes = other._lenInBytes; + _dataType = other._dataType; + _workspace = other._workspace; + _isOwnerPrimary = other._isOwnerPrimary; + _isOwnerSpecial = other._isOwnerSpecial; + _deviceId.store(other._deviceId); + + copyCounters(other); + + other._primaryBuffer = other._specialBuffer = nullptr; + other.setAllocFlags(false, false); + other._lenInBytes = 0; + } + +//////////////////////////////////////////////////////////////////////// +// assignment operator + DataBuffer& DataBuffer::operator=(const DataBuffer& other) { + + if (this == &other) + return *this; + + deleteBuffers(); + + _lenInBytes = other._lenInBytes; + _dataType = other._dataType; + _workspace = other._workspace; + + allocateBuffers(); + copyBufferFrom(other); + + return *this; + } + +//////////////////////////////////////////////////////////////////////// +// move assignment operator + DataBuffer& DataBuffer::operator=(DataBuffer&& other) noexcept { + + if (this == &other) + return *this; + + deleteBuffers(); + + _primaryBuffer = other._primaryBuffer; + _specialBuffer = other._specialBuffer; + _lenInBytes = other._lenInBytes; + _dataType = other._dataType; + _workspace = other._workspace; + _isOwnerPrimary = other._isOwnerPrimary; + _isOwnerSpecial = other._isOwnerSpecial; + + copyCounters(other); + + other._primaryBuffer = other._specialBuffer = nullptr; + other.setAllocFlags(false, false); + other._lenInBytes = 0; + + return *this; + } + +//////////////////////////////////////////////////////////////////////// + void* DataBuffer::primary() { + return _primaryBuffer; + } + +//////////////////////////////////////////////////////////////////////// + void* DataBuffer::special() { + return _specialBuffer; + } + +//////////////////////////////////////////////////////////////////////// + DataType DataBuffer::getDataType() { + return _dataType; + } + +//////////////////////////////////////////////////////////////////////// + size_t DataBuffer::getLenInBytes() const { + return _lenInBytes; + } + + +//////////////////////////////////////////////////////////////////////// + void DataBuffer::allocatePrimary() { + + if (_primaryBuffer == nullptr && getLenInBytes() > 0) { + auto deviceId = nd4j::AffinityManager::currentDeviceId(); + // check if this allocation won't bring us above limit + if (_workspace == nullptr) { + if (Environment::getInstance()->isCPU()) { + // on cpu backend we validate against device 0 for now + if (!nd4j::memory::MemoryCounter::getInstance()->validate(getLenInBytes())) + throw nd4j::allocation_exception::build("Requested amount exceeds HOST device limits", nd4j::memory::MemoryCounter::getInstance()->deviceLimit(deviceId), getLenInBytes()); + } else { + // in heterogenous mode we valdate against device group + if (!nd4j::memory::MemoryCounter::getInstance()->validateGroup(nd4j::memory::MemoryType::HOST, getLenInBytes())) + throw nd4j::allocation_exception::build("Requested amount exceeds HOST group limits", nd4j::memory::MemoryCounter::getInstance()->groupLimit(nd4j::memory::MemoryType::HOST), getLenInBytes()); + } + } + + ALLOCATE(_primaryBuffer, _workspace, getLenInBytes(), int8_t); + _isOwnerPrimary = true; + + // count in towards current deviceId if we're not in workspace mode + if (_workspace == nullptr) { + if (Environment::getInstance()->isCPU()) // we don't want this counter to be added to CUDA device + nd4j::memory::MemoryCounter::getInstance()->countIn(deviceId, getLenInBytes()); + + nd4j::memory::MemoryCounter::getInstance()->countIn(nd4j::memory::MemoryType::HOST, getLenInBytes()); + } + } + } + +//////////////////////////////////////////////////////////////////////// + void DataBuffer::setAllocFlags(const bool isOwnerPrimary, const bool isOwnerSpecial) { + _isOwnerPrimary = isOwnerPrimary; + _isOwnerSpecial = isOwnerSpecial; + } + +//////////////////////////////////////////////////////////////////////// + void DataBuffer::deletePrimary() { + + if(_isOwnerPrimary && _primaryBuffer != nullptr && getLenInBytes() != 0) { + auto p = reinterpret_cast(_primaryBuffer); + RELEASE(p, _workspace); + _primaryBuffer = nullptr; + _isOwnerPrimary = false; + + + // count out towards DataBuffer device, only if we're not in workspace + if (_workspace == nullptr) { + if (Environment::getInstance()->isCPU()) + nd4j::memory::MemoryCounter::getInstance()->countOut(_deviceId, getLenInBytes()); + + nd4j::memory::MemoryCounter::getInstance()->countOut(nd4j::memory::MemoryType::HOST, getLenInBytes()); + } + } + } + +//////////////////////////////////////////////////////////////////////// + void DataBuffer::deleteBuffers() { + + deletePrimary(); + deleteSpecial(); + _lenInBytes = 0; + } + +//////////////////////////////////////////////////////////////////////// + DataBuffer::~DataBuffer() { + + deleteBuffers(); + } + + void DataBuffer::setPrimaryBuffer(void *buffer, size_t length) { + if (_primaryBuffer != nullptr && _isOwnerPrimary) { + deletePrimary(); + } + _primaryBuffer = buffer; + _isOwnerPrimary = false; + _lenInBytes = length * DataTypeUtils::sizeOf(_dataType); + } + + void DataBuffer::setSpecialBuffer(void *buffer, size_t length) { + this->setSpecial(buffer, false); + _lenInBytes = length * DataTypeUtils::sizeOf(_dataType); + } + + void DataBuffer::setDataType(DataType dataType) { + _dataType = dataType; + } + + int DataBuffer::deviceId() const { + return _deviceId.load(); + } + + void DataBuffer::close() { + this->deleteBuffers(); + } + + void DataBuffer::setDeviceId(int deviceId) { + _deviceId = deviceId; + } +} diff --git a/libnd4j/include/array/impl/InteropDataBuffer.cpp b/libnd4j/include/array/impl/InteropDataBuffer.cpp new file mode 100644 index 000000000..cffc1462b --- /dev/null +++ b/libnd4j/include/array/impl/InteropDataBuffer.cpp @@ -0,0 +1,146 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include +#include +#include +#include + +namespace nd4j { + InteropDataBuffer::InteropDataBuffer(InteropDataBuffer &dataBuffer, uint64_t length, uint64_t offset) { + _dataBuffer = dataBuffer.getDataBuffer(); + + // offset is always absolute to the original buffer + _offset = offset; + + if (_offset + length > _dataBuffer->getLenInBytes()) { + throw std::runtime_error("offset + length is higher than original length"); + } + } + + InteropDataBuffer::InteropDataBuffer(std::shared_ptr databuffer) { + _dataBuffer = databuffer; + } + + InteropDataBuffer::InteropDataBuffer(size_t elements, nd4j::DataType dtype, bool allocateBoth) { + if (elements == 0) { + _dataBuffer = std::make_shared(); + _dataBuffer->setDataType(dtype); + } else { + _dataBuffer = std::make_shared(elements, dtype, nullptr, allocateBoth); + } + } + + std::shared_ptr InteropDataBuffer::getDataBuffer() const { + return _dataBuffer; + } + + std::shared_ptr InteropDataBuffer::dataBuffer() { + return _dataBuffer; + } + + void* InteropDataBuffer::primary() const { + return reinterpret_cast(_dataBuffer->primary()) + _offset; + } + + void* InteropDataBuffer::special() const { + return reinterpret_cast(_dataBuffer->special()) + _offset; + } + + void InteropDataBuffer::setPrimary(void* ptr, size_t length) { + _dataBuffer->setPrimaryBuffer(ptr, length); + } + + void InteropDataBuffer::setSpecial(void* ptr, size_t length) { + _dataBuffer->setSpecialBuffer(ptr, length); + } + + uint64_t InteropDataBuffer::offset() const { + return _offset; + } + + void InteropDataBuffer::setOffset(uint64_t offset) { + _offset = offset; + } + + int InteropDataBuffer::deviceId() const { + return _dataBuffer->deviceId(); + } + + + void InteropDataBuffer::registerSpecialUse(const std::vector& writeList, const std::vector& readList) { + for (const auto &v:writeList) { + if (v == nullptr) + continue; + + v->getDataBuffer()->writeSpecial(); + } + } + + void InteropDataBuffer::prepareSpecialUse(const std::vector& writeList, const std::vector& readList, bool synchronizeWritables) { + auto currentDeviceId = nd4j::AffinityManager::currentDeviceId(); + for (const auto &v:readList) { + if (v == nullptr) + continue; + + if (v->getDataBuffer()->deviceId() != currentDeviceId) + v->getDataBuffer()->migrate(); + + v->getDataBuffer()->syncToSpecial(); + } + + // we don't tick write list, only ensure the same device affinity + for (const auto &v:writeList) { + if (v == nullptr) + continue; + + // special case for legacy ops - views can be updated on host side, thus original array can be not updated + if (!v->getDataBuffer()->isSpecialActual()) + v->getDataBuffer()->syncToSpecial(); + + if (v->getDataBuffer()->deviceId() != currentDeviceId) + v->getDataBuffer()->migrate(); + } + } + + void InteropDataBuffer::registerPrimaryUse(const std::vector& writeList, const std::vector& readList) { + for (const auto &v:writeList) { + if (v == nullptr) + continue; + } + } + + void InteropDataBuffer::preparePrimaryUse(const std::vector& writeList, const std::vector& readList, bool synchronizeWritables) { + for (const auto &v:readList) { + if (v == nullptr) + continue; + + v->getDataBuffer()->syncToPrimary(LaunchContext::defaultContext()); + } + } + + void InteropDataBuffer::expand(size_t newlength) { + _dataBuffer->expand(newlength * DataTypeUtils::sizeOf(_dataBuffer->getDataType())); + } + + void InteropDataBuffer::setDeviceId(int deviceId) { + _dataBuffer->setDeviceId(deviceId); + } +} diff --git a/libnd4j/include/array/impl/NDArrayList.cpp b/libnd4j/include/array/impl/NDArrayList.cpp index 75df72e70..cb1461226 100644 --- a/libnd4j/include/array/impl/NDArrayList.cpp +++ b/libnd4j/include/array/impl/NDArrayList.cpp @@ -44,7 +44,7 @@ namespace nd4j { } NDArray* NDArrayList::read(int idx) { - return readRaw(idx)->dup(); + return new NDArray(readRaw(idx)->dup()); } nd4j::DataType NDArrayList::dataType() { @@ -114,7 +114,7 @@ namespace nd4j { } else return Status::CODE(ND4J_STATUS_BAD_INPUT, "NDArrayList: all arrays must have same size along inner dimensions"); } - + //_elements++; // storing reference @@ -136,11 +136,10 @@ namespace nd4j { std::vector args({axis}); auto newAxis = ShapeUtils::evalDimsToExclude(array->rankOf(), args); auto result = array->allTensorsAlongDimension(newAxis); - for (int e = 0; e < result->size(); e++) { - auto chunk = result->at(e);//->dup(array->ordering()); - write(e, chunk->dup(array->ordering())); + for (int e = 0; e < result.size(); e++) { + auto chunk = result.at(e);//->dup(array->ordering()); + write(e, new NDArray(chunk->dup(array->ordering()))); } - delete result; } NDArray* NDArrayList::stack() { @@ -161,7 +160,7 @@ namespace nd4j { auto result = op.execute(inputs, {}, {}, {}); - auto array = result->at(0)->dup(); + auto array = new NDArray(result->at(0)->dup()); delete result; @@ -214,13 +213,11 @@ namespace nd4j { auto tads = array->allTensorsAlongDimension(axis); int indicesSize = indices.size(); - if (tads->size() != indicesSize) + if (tads.size() != indicesSize) throw std::runtime_error("Number of TADs should match number of indices"); for (int e = 0; e < indicesSize; e++) - tads->at(e)->assign(_chunks[indices[e]]); - - delete tads; + tads.at(e)->assign(_chunks[indices[e]]); return array; } @@ -234,7 +231,7 @@ namespace nd4j { list->_elements.store(_elements.load()); for (auto const& v : _chunks) { - list->_chunks[v.first] = v.second->dup(); + list->_chunks[v.first] = new NDArray(v.second->dup()); } return list; diff --git a/libnd4j/include/config.h.in b/libnd4j/include/config.h.in index bdba3cc03..1e63552d0 100644 --- a/libnd4j/include/config.h.in +++ b/libnd4j/include/config.h.in @@ -13,4 +13,8 @@ #cmakedefine FLATBUFFERS_PATH "@FLATBUFFERS_PATH@" +#cmakedefine HAVE_CUDNN + +#cmakedefine DEFAULT_ENGINE @DEFAULT_ENGINE@ + #endif diff --git a/libnd4j/include/exceptions/allocation_exception.h b/libnd4j/include/exceptions/allocation_exception.h index 29756d253..458650037 100644 --- a/libnd4j/include/exceptions/allocation_exception.h +++ b/libnd4j/include/exceptions/allocation_exception.h @@ -40,6 +40,7 @@ namespace nd4j { ~allocation_exception() = default; static allocation_exception build(std::string message, Nd4jLong bytes); + static allocation_exception build(std::string message, Nd4jLong limit, Nd4jLong bytes); }; } diff --git a/libnd4j/include/exceptions/impl/allocation_exception.cpp b/libnd4j/include/exceptions/impl/allocation_exception.cpp index 76c6338da..85c3e72aa 100644 --- a/libnd4j/include/exceptions/impl/allocation_exception.cpp +++ b/libnd4j/include/exceptions/impl/allocation_exception.cpp @@ -31,4 +31,11 @@ namespace nd4j { message += "; Requested bytes: [" + bytes + "]"; return allocation_exception(message); } + + allocation_exception allocation_exception::build(std::string message, Nd4jLong limit, Nd4jLong numBytes) { + auto bytes = StringUtils::valueToString(numBytes); + auto lim = StringUtils::valueToString(limit); + message += "; Limit bytes: [" + lim + "]; Requested bytes: [" + bytes + "]"; + return allocation_exception(message); + } } \ No newline at end of file diff --git a/libnd4j/include/execution/Engine.h b/libnd4j/include/execution/Engine.h new file mode 100644 index 000000000..cd30867a9 --- /dev/null +++ b/libnd4j/include/execution/Engine.h @@ -0,0 +1,31 @@ +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#ifndef SD_ENGINE_H +#define SD_ENGINE_H + +namespace samediff { + enum Engine { + ENGINE_CPU = 0, + ENGINE_CUDA = 1, + }; +} + +#endif //SD_ENGINE_H diff --git a/libnd4j/include/execution/Executor.h b/libnd4j/include/execution/Executor.h index 8922e345d..26d5365ad 100644 --- a/libnd4j/include/execution/Executor.h +++ b/libnd4j/include/execution/Executor.h @@ -18,8 +18,8 @@ // @author raver119@gmail.com // -#ifndef DEV_TESTS_EXECUTOR_H -#define DEV_TESTS_EXECUTOR_H +#ifndef SD_EXECUTOR_H +#define SD_EXECUTOR_H namespace nd4j { class Executor { @@ -30,4 +30,4 @@ namespace nd4j { }; } -#endif //DEV_TESTS_EXECUTOR_H +#endif //SD_EXECUTOR_H diff --git a/libnd4j/include/execution/LaunchContext.h b/libnd4j/include/execution/LaunchContext.h index 076e2933b..689d79369 100644 --- a/libnd4j/include/execution/LaunchContext.h +++ b/libnd4j/include/execution/LaunchContext.h @@ -27,6 +27,7 @@ #include #include #include +#include "config.h" #endif // used for MKLDNN etc @@ -81,6 +82,7 @@ class ND4J_EXPORT LaunchContext { int* getAllocationPointer() const; void* getCublasHandle() const; void* getCusolverHandle() const; + void* getCuDnnHandle() const; cudaStream_t* getCudaStream() const; cudaStream_t* getCudaSpecialStream() const; diff --git a/libnd4j/include/execution/cuda/ContextBuffers.cu b/libnd4j/include/execution/cuda/ContextBuffers.cu index 435858462..e018cf807 100644 --- a/libnd4j/include/execution/cuda/ContextBuffers.cu +++ b/libnd4j/include/execution/cuda/ContextBuffers.cu @@ -138,7 +138,7 @@ namespace nd4j { if (res != 0) throw cuda_exception::build("_reductionPointer allocation failed", res); - res = cudaMalloc(reinterpret_cast(&_scalarPointer), 16); + res = cudaHostAlloc(reinterpret_cast(&_scalarPointer), 16, cudaHostAllocDefault); if (res != 0) throw cuda_exception::build("_scalarPointer allocation failed", res); diff --git a/libnd4j/include/execution/cuda/LaunchContext.cu b/libnd4j/include/execution/cuda/LaunchContext.cu index 5e2ac589c..3145ca8d3 100644 --- a/libnd4j/include/execution/cuda/LaunchContext.cu +++ b/libnd4j/include/execution/cuda/LaunchContext.cu @@ -166,6 +166,10 @@ LaunchContext::LaunchContext() { return contextBuffers.isInitialized(); } + void* LaunchContext::getCuDnnHandle() const { + return CublasHelper::getInstance()->cudnn(); + } + sd::ErrorReference* LaunchContext::errorReference() { return contextBuffers.errorReference(); } diff --git a/libnd4j/include/execution/impl/Threads.cpp b/libnd4j/include/execution/impl/Threads.cpp index f5ae5b5eb..982b59a4c 100644 --- a/libnd4j/include/execution/impl/Threads.cpp +++ b/libnd4j/include/execution/impl/Threads.cpp @@ -492,7 +492,7 @@ namespace samediff { auto itersY = delta_y / incY; auto itersZ = delta_z / incZ; - numThreads = 1; //ThreadsHelper::numberOfThreads3d(numThreads, itersX, itersY, itersZ); + numThreads = ThreadsHelper::numberOfThreads3d(numThreads, itersX, itersY, itersZ); if (numThreads == 1) { // loop is too small - executing function as is function(0, startX, stopX, incX, startY, stopY, incY, startZ, stopZ, incZ); diff --git a/libnd4j/include/graph/Context.h b/libnd4j/include/graph/Context.h index d5b85b543..f4fa6d16d 100644 --- a/libnd4j/include/graph/Context.h +++ b/libnd4j/include/graph/Context.h @@ -27,6 +27,7 @@ #include #include #include +#include // CUDA-specific includes #ifdef __CUDACC__ @@ -64,6 +65,9 @@ namespace nd4j { std::vector _handles; bool _helpersAllowed = true; + + // in some cases we might be able to skip shape function for validation purposes + bool _shapeFunctionOverride = false; public: Context(ContextPrototype* prototype, VariableSpace* variableSpace); @@ -99,12 +103,13 @@ namespace nd4j { // this method returns workspace for object allocations nd4j::memory::Workspace* oWorkspace(); - void setVariableSpace(VariableSpace* variableSpace); nd4j::random::RandomBuffer* getRNG(); void setRNG(nd4j::random::RandomBuffer* rng); + void setTargetEngine(samediff::Engine engine); + VariableSpace *getVariableSpace(); LaunchContext* launchContext(); @@ -182,9 +187,11 @@ namespace nd4j { void setInputArray(int index, NDArray *array, bool removable = false); void setInputArray(int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo); + void setInputArray(int index, void *databuffer, void *shapeInfo, void *specialShapeInfo); void setOutputArray(int index, NDArray *array, bool removable = false); void setOutputArray(int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo); + void setOutputArray(int index, void *databuffer, void *shapeInfo, void *specialShapeInfo); void setTArguments(double *arguments, int numberOfArguments); void setIArguments(Nd4jLong *arguments, int numberOfArguments); @@ -196,9 +203,11 @@ namespace nd4j { void setCudaContext(Nd4jPointer cudaStream, Nd4jPointer reductionPointer, Nd4jPointer allocationPointer); - void allowHelpers(bool reallyAllow); bool helpersAllowed(); + + void setShapeFunctionOverride(bool reallyOverride); + bool shapeFunctionOverride(); }; } } diff --git a/libnd4j/include/graph/ContextPrototype.h b/libnd4j/include/graph/ContextPrototype.h index 8aaa3e3d2..a9d05b7b4 100644 --- a/libnd4j/include/graph/ContextPrototype.h +++ b/libnd4j/include/graph/ContextPrototype.h @@ -27,6 +27,11 @@ #include #include #include +#include + +#ifndef __STANDALONE_BUILD__ +#include +#endif namespace nd4j { namespace graph { @@ -53,6 +58,8 @@ namespace nd4j { nd4j::ops::OpDescriptor* _opDescriptor; bool _useMKLDNN = nd4j::Environment::getInstance()->isUseMKLDNN(); + // target engine for execution + samediff::Engine _engine = DEFAULT_ENGINE; public: explicit ContextPrototype(nd4j::ops::OpDescriptor* opDescriptor = nullptr, int nodeId = 1, bool inPlace = false); ~ContextPrototype() = default; @@ -84,6 +91,8 @@ namespace nd4j { std::vector* getBArguments(); std::vector* getAxis(); + samediff::Engine engine(); + size_t numT(); size_t numI(); size_t numB(); diff --git a/libnd4j/include/graph/ExecutionResult.h b/libnd4j/include/graph/ExecutionResult.h index 7a632a998..b1a1b1737 100644 --- a/libnd4j/include/graph/ExecutionResult.h +++ b/libnd4j/include/graph/ExecutionResult.h @@ -23,7 +23,7 @@ #include #include -#include +#include #include #include #include diff --git a/libnd4j/include/graph/Graph.h b/libnd4j/include/graph/Graph.h index 18c23ec0f..00efb3c52 100644 --- a/libnd4j/include/graph/Graph.h +++ b/libnd4j/include/graph/Graph.h @@ -23,7 +23,7 @@ #include #include -#include +#include //#include #include #include diff --git a/libnd4j/include/graph/GraphHolder.h b/libnd4j/include/graph/GraphHolder.h index a60d088e7..f740ad4ca 100644 --- a/libnd4j/include/graph/GraphHolder.h +++ b/libnd4j/include/graph/GraphHolder.h @@ -20,7 +20,7 @@ #include #include -#include +#include #include #include #include diff --git a/libnd4j/include/graph/GraphState.h b/libnd4j/include/graph/GraphState.h index 95f8a016d..52c6f9e16 100644 --- a/libnd4j/include/graph/GraphState.h +++ b/libnd4j/include/graph/GraphState.h @@ -25,7 +25,7 @@ #include #include #include -#include +#include #include #include #include diff --git a/libnd4j/include/graph/RandomGenerator.h b/libnd4j/include/graph/RandomGenerator.h index e58f415c5..de475b8f8 100644 --- a/libnd4j/include/graph/RandomGenerator.h +++ b/libnd4j/include/graph/RandomGenerator.h @@ -28,6 +28,7 @@ #include #include #include +#include #ifdef __CUDACC__ #include @@ -46,7 +47,10 @@ namespace nd4j { public: void *operator new(size_t len) { void *ptr; - cudaHostAlloc(&ptr, len, cudaHostAllocDefault); + auto res = cudaHostAlloc(&ptr, len, cudaHostAllocDefault); + if (res != 0) + throw std::runtime_error("CudaManagedRandomGenerator: failed to allocate memory"); + return ptr; } diff --git a/libnd4j/include/graph/Scope.h b/libnd4j/include/graph/Scope.h index 4e322b5b2..5cbbf8bc0 100644 --- a/libnd4j/include/graph/Scope.h +++ b/libnd4j/include/graph/Scope.h @@ -22,7 +22,7 @@ #define LIBND4J_SCOPE_H #include -#include +#include #include namespace nd4j { diff --git a/libnd4j/include/graph/Stash.h b/libnd4j/include/graph/Stash.h index 6613ec859..83a7ec066 100644 --- a/libnd4j/include/graph/Stash.h +++ b/libnd4j/include/graph/Stash.h @@ -23,7 +23,7 @@ //#include #include -#include +#include #include #include #include diff --git a/libnd4j/include/graph/VariableSpace.h b/libnd4j/include/graph/VariableSpace.h index 21bdc608b..9443d34b1 100644 --- a/libnd4j/include/graph/VariableSpace.h +++ b/libnd4j/include/graph/VariableSpace.h @@ -26,7 +26,7 @@ #include #include #include -#include +#include #include #include #include diff --git a/libnd4j/include/graph/execution/impl/LogicConditional.cpp b/libnd4j/include/graph/execution/impl/LogicConditional.cpp index 62a533ee7..fb1f0fa1e 100644 --- a/libnd4j/include/graph/execution/impl/LogicConditional.cpp +++ b/libnd4j/include/graph/execution/impl/LogicConditional.cpp @@ -48,7 +48,7 @@ namespace nd4j { } else { // FIXME: in some cases it's possible to have no NDArray if (inputVar->hasNDArray()) - innerVar->setNDArray(inputVar->getNDArray()->dup()); + innerVar->setNDArray(new NDArray(inputVar->getNDArray()->dup())); } } diff --git a/libnd4j/include/graph/execution/impl/LogicWhile.cpp b/libnd4j/include/graph/execution/impl/LogicWhile.cpp index bdabdc6bc..147c35248 100644 --- a/libnd4j/include/graph/execution/impl/LogicWhile.cpp +++ b/libnd4j/include/graph/execution/impl/LogicWhile.cpp @@ -56,7 +56,7 @@ namespace nd4j { } else { // FIXME: in some cases it's possible to have no NDArray if (inputVar->hasNDArray()) - innerVar->setNDArray(inputVar->getNDArray()->dup()); + innerVar->setNDArray(new NDArray(inputVar->getNDArray()->dup())); } } diff --git a/libnd4j/include/graph/impl/Context.cpp b/libnd4j/include/graph/impl/Context.cpp index ed9321ccd..4876675dc 100644 --- a/libnd4j/include/graph/impl/Context.cpp +++ b/libnd4j/include/graph/impl/Context.cpp @@ -21,6 +21,7 @@ #include #include #include +#include namespace nd4j { @@ -106,6 +107,10 @@ namespace nd4j { delete _context; } + void Context::setTargetEngine(samediff::Engine engine) { + _engine = engine; + } + bool Context::hasWorkspaceProvided() { return this->_workspace != nullptr; } @@ -426,6 +431,44 @@ namespace nd4j { array->setContext(_context); } + void Context::setInputArray(int index, void *vdatabuffer, void *shapeInfo, void *specialShapeInfo) { + auto dataBuffer = reinterpret_cast(vdatabuffer); + + if (_fastpath_in.size() < index + 1) + _fastpath_in.resize(index+1); + + NDArray *array; + if (dataBuffer != nullptr) + array = new NDArray(dataBuffer->dataBuffer(), reinterpret_cast(shapeInfo), nd4j::LaunchContext::defaultContext(), dataBuffer->offset() / DataTypeUtils::sizeOf(ArrayOptions::dataType(reinterpret_cast(shapeInfo)))); + else + array = new NDArray(nullptr, nullptr, reinterpret_cast(shapeInfo)); + + _fastpath_in[index] = array; + _handles.emplace_back(array); + + if (_context != nullptr) + array->setContext(_context); + } + + void Context::setOutputArray(int index, void *vdatabuffer, void *shapeInfo, void *specialShapeInfo) { + auto dataBuffer = reinterpret_cast(vdatabuffer); + + if (_fastpath_out.size() < index + 1) + _fastpath_out.resize(index+1); + + NDArray *array; + if (dataBuffer != nullptr) + array = new NDArray(dataBuffer->dataBuffer(), reinterpret_cast(shapeInfo), nd4j::LaunchContext::defaultContext(), dataBuffer->offset() / DataTypeUtils::sizeOf(ArrayOptions::dataType(reinterpret_cast(shapeInfo)))); + else + array = new NDArray(nullptr, nullptr, reinterpret_cast(shapeInfo)); + + _fastpath_out[index] = array; + _handles.emplace_back(array); + + if (_context != nullptr) + array->setContext(_context); + } + void Context::setTArguments(double *arguments, int numberOfArguments) { _tArgs.clear(); _tArgs.reserve(numberOfArguments); @@ -484,6 +527,14 @@ namespace nd4j { for (auto b:bArgs) _bArgs.push_back(b); } + + void Context::setShapeFunctionOverride(bool reallyOverride) { + _shapeFunctionOverride = reallyOverride; + } + + bool Context::shapeFunctionOverride() { + return _shapeFunctionOverride; + } } } diff --git a/libnd4j/include/graph/impl/ContextPrototype.cpp b/libnd4j/include/graph/impl/ContextPrototype.cpp index 5bd2a69e7..0ddde97f4 100644 --- a/libnd4j/include/graph/impl/ContextPrototype.cpp +++ b/libnd4j/include/graph/impl/ContextPrototype.cpp @@ -59,6 +59,10 @@ namespace nd4j { } } + samediff::Engine ContextPrototype::engine() { + return _engine; + } + bool ContextPrototype::hasVariablesFilled() { return this->_inputs.size() > 0; } diff --git a/libnd4j/include/graph/impl/Variable.cpp b/libnd4j/include/graph/impl/Variable.cpp index d77bded2e..5b8f00b25 100644 --- a/libnd4j/include/graph/impl/Variable.cpp +++ b/libnd4j/include/graph/impl/Variable.cpp @@ -40,7 +40,7 @@ namespace nd4j { result->setIndex(this->_index); if (this->_ndarray != nullptr) - result->setNDArray(this->_ndarray->template asT()); + result->setNDArray(new NDArray(this->_ndarray->template asT())); // FIXME: add support for ArrayList if (this->_list != nullptr) { @@ -61,7 +61,7 @@ namespace nd4j { result->_index = this->_index; if (this->_ndarray != nullptr) - result->_ndarray = this->_ndarray->dup(this->_ndarray->ordering()); + result->_ndarray = new NDArray(this->_ndarray->dup(this->_ndarray->ordering())); if (this->_list != nullptr) result->_list = this->_list->clone(); diff --git a/libnd4j/include/graph/scheme/array.fbs b/libnd4j/include/graph/scheme/array.fbs index 91e338500..2ffce58bd 100644 --- a/libnd4j/include/graph/scheme/array.fbs +++ b/libnd4j/include/graph/scheme/array.fbs @@ -43,6 +43,8 @@ enum DType:byte { QINT16, BFLOAT16 = 17, UTF8 = 50, + UTF16 = 51, + UTF32 = 52, } // this structure describe NDArray diff --git a/libnd4j/include/helpers/DebugHelper.h b/libnd4j/include/helpers/DebugHelper.h index 945bebe8e..3c3fe1d58 100644 --- a/libnd4j/include/helpers/DebugHelper.h +++ b/libnd4j/include/helpers/DebugHelper.h @@ -34,8 +34,6 @@ #include #include -#define checkCudaErrors(ERR) if (ERR != 0) {throw std::runtime_error("CUDA stream synchronization failed");} - #endif #include namespace nd4j { diff --git a/libnd4j/include/helpers/StringUtils.h b/libnd4j/include/helpers/StringUtils.h index 1a450450f..2a562de4b 100644 --- a/libnd4j/include/helpers/StringUtils.h +++ b/libnd4j/include/helpers/StringUtils.h @@ -25,6 +25,8 @@ #include #include #include +#include +#include namespace nd4j { class ND4J_EXPORT StringUtils { @@ -53,6 +55,36 @@ namespace nd4j { return result; } + + /** + * This method returns number of needle matches within haystack + * PLEASE NOTE: this method operates on 8-bit arrays interpreted as uint8 + * + * @param haystack + * @param haystackLength + * @param needle + * @param needleLength + * @return + */ + static uint64_t countSubarrays(const void *haystack, uint64_t haystackLength, const void *needle, uint64_t needleLength); + + /** + * This method returns number of bytes used for string NDArrays content + * PLEASE NOTE: this doesn't include header + * + * @param array + * @return + */ + static uint64_t byteLength(const NDArray &array); + + /** + * This method splits a string into substring by delimiter + * + * @param haystack + * @param delimiter + * @return + */ + static std::vector split(const std::string &haystack, const std::string &delimiter); }; } diff --git a/libnd4j/include/helpers/benchmark/ScalarBenchmark.h b/libnd4j/include/helpers/benchmark/ScalarBenchmark.h index 67ca25e07..d24c31b84 100644 --- a/libnd4j/include/helpers/benchmark/ScalarBenchmark.h +++ b/libnd4j/include/helpers/benchmark/ScalarBenchmark.h @@ -93,7 +93,7 @@ namespace nd4j { } OpBenchmark* clone() override { - return new ScalarBenchmark((scalar::Ops) _opNum, _testName, _x == nullptr ? _x : _x->dup() , _y == nullptr ? _y : _y->dup(), _z == nullptr ? _z : _z->dup()); + return new ScalarBenchmark((scalar::Ops) _opNum, _testName, _x == nullptr ? _x : new NDArray(_x->dup()) , _y == nullptr ? _y : new NDArray(_y->dup()), _z == nullptr ? _z : new NDArray(_z->dup())); } }; } diff --git a/libnd4j/include/helpers/cpu/MmulHelper.cpp b/libnd4j/include/helpers/cpu/MmulHelper.cpp index 0f495bf96..f0e8846e3 100644 --- a/libnd4j/include/helpers/cpu/MmulHelper.cpp +++ b/libnd4j/include/helpers/cpu/MmulHelper.cpp @@ -202,6 +202,9 @@ NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, con if(C == nullptr) C = new NDArray(outOrder, {M,N}, DataTypeUtils::pickPairwiseResultType(A->dataType(), B->dataType()), A->getContext()); + if (C->isEmpty()) + return C; + const auto aType = A->dataType(); const auto bType = B->dataType(); const auto cType = C->dataType(); @@ -230,17 +233,17 @@ NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, con bool cNcont = N == 1 || C->strideAt(1) == 1; if(!aMcont && !aKcont) { - pA = A->dup('f'); + pA = new NDArray(A->dup('f')); toDelete.push_back(pA); aMcont = true; } if(!bKcont && !bNcont) { - pB = B->dup('f'); + pB = new NDArray(B->dup('f')); toDelete.push_back(pB); bKcont = true; } if(!cMcont && !cNcont) { - pC = C->dup('f'); + pC = new NDArray(C->dup('f')); toDelete.push_back(pC); cMcont = true; } @@ -307,6 +310,9 @@ NDArray* MmulHelper::mmulMxV(const NDArray* A, const NDArray* X, nd4j::NDArray* if(Y == nullptr) Y = new NDArray(outOrder, {M}, DataTypeUtils::pickPairwiseResultType(A->dataType(), X->dataType()), A->getContext()); + if (Y->isEmpty()) + return Y; + const int incx = X->stridesOf()[xLenDim]; const int incy = Y->stridesOf()[yLenDim]; @@ -332,7 +338,7 @@ NDArray* MmulHelper::mmulMxV(const NDArray* A, const NDArray* X, nd4j::NDArray* bool aNcont = N == 1 || A->strideAt(1) == 1; if(!aMcont && !aNcont) { - pA = A->dup('f'); + pA = new NDArray(A->dup('f')); aMcont = true; } const CBLAS_ORDER blasOrder = aMcont ? CblasColMajor : CblasRowMajor; @@ -511,6 +517,9 @@ NDArray* MmulHelper::mmulNxN(const NDArray* A, const NDArray* B, NDArray* C, con C = new NDArray(outOrder, cExpectedShape, B->dataType()); } + if (C->isEmpty()) + return C; + const int cRank = C->rankOf(); const int aMaxis(aRank-2), aKaxis(aRank-1), bKaxis(bRank-2), bNaxis(bRank-1), cMaxis(cRank-2), cNaxis(cRank-1); diff --git a/libnd4j/include/helpers/cpu/householder.cpp b/libnd4j/include/helpers/cpu/householder.cpp index 7fa82de8d..024695583 100644 --- a/libnd4j/include/helpers/cpu/householder.cpp +++ b/libnd4j/include/helpers/cpu/householder.cpp @@ -60,11 +60,10 @@ NDArray Householder::evalHHmatrix(const NDArray& x) { w.p(Nd4jLong(0), 1.f); wT.assign(&w); - auto identity = NDArrayFactory::create(x.ordering(), {(int)x.lengthOf(), (int)x.lengthOf()}, x.dataType(), x.getContext()); + NDArray identity = NDArrayFactory::create(x.ordering(), {(int)x.lengthOf(), (int)x.lengthOf()}, x.dataType(), x.getContext()); identity.setIdentity(); // identity matrix return identity - mmul(w, wT) * coeff; - } ////////////////////////////////////////////////////////////////////////// @@ -95,9 +94,9 @@ void Householder::evalHHmatrixData(const NDArray& x, NDArray& tail, T& coeff, coeff = -u0 / normX; if(x.isRowVector()) - tail.assign(x({0,0, 1,-1}) / u0); + tail.assign(static_cast(x({0,0, 1,-1})) / u0); else - tail.assign(x({1,-1, 0,0,}) / u0); + tail.assign(static_cast(x({1,-1, 0,0,})) / u0); } } diff --git a/libnd4j/include/helpers/cpu/jacobiSVD.cpp b/libnd4j/include/helpers/cpu/jacobiSVD.cpp index b8a51195e..4ba2bfe0a 100644 --- a/libnd4j/include/helpers/cpu/jacobiSVD.cpp +++ b/libnd4j/include/helpers/cpu/jacobiSVD.cpp @@ -269,7 +269,7 @@ void JacobiSVD::evalData(const NDArray& matrix) { HHcolPivQR qr(matrix / scale); _m.assign(qr._qr({0,_cols, 0,_cols})); - _m.fillAsTriangular(0., 0, 0, 'l'); + _m.fillAsTriangular(0., 0, 0, _m, 'l'); HHsequence hhSeg(qr._qr, qr._coeffs, 'u'); @@ -288,7 +288,7 @@ void JacobiSVD::evalData(const NDArray& matrix) { auto matrixT = matrix.transpose(); HHcolPivQR qr(matrixT / scale); _m.assign(qr._qr({0,_rows, 0,_rows})); - _m.fillAsTriangular(0., 0, 0, 'l'); + _m.fillAsTriangular(0., 0, 0, _m, 'l'); _m.transposei(); HHsequence hhSeg(qr._qr, qr._coeffs, 'u'); // type = 'u' is not mistake here ! @@ -305,7 +305,7 @@ void JacobiSVD::evalData(const NDArray& matrix) { } else { - _m.assign(matrix({0,_diagSize, 0,_diagSize}) / scale); + _m.assign(static_cast(matrix({0,_diagSize, 0,_diagSize})) / scale); if(_calcU) _u.setIdentity(); @@ -366,7 +366,7 @@ void JacobiSVD::evalData(const NDArray& matrix) { _s.p(i, math::nd4j_abs(_m.e(i,i))); if(_calcU && _m.e(i,i) < (T)0.) { auto temp = _u({0,0, i,i+1}, true); - temp.applyTransform(transform::Neg, &temp, nullptr); + temp.applyTransform(transform::Neg, temp, nullptr); } } diff --git a/libnd4j/include/helpers/cpu/loops/IndexReductionLoops.cpp b/libnd4j/include/helpers/cpu/loops/IndexReductionLoops.hpp similarity index 98% rename from libnd4j/include/helpers/cpu/loops/IndexReductionLoops.cpp rename to libnd4j/include/helpers/cpu/loops/IndexReductionLoops.hpp index 4bd456da2..1aaaaebc7 100644 --- a/libnd4j/include/helpers/cpu/loops/IndexReductionLoops.cpp +++ b/libnd4j/include/helpers/cpu/loops/IndexReductionLoops.hpp @@ -311,6 +311,4 @@ void nd4j::IndexReductionLoops::wrapIndexReduce(const int opNum, void* vx, auto extraParams = reinterpret_cast(vextraParams); DISPATCH_BY_OPNUM_TT(loopIndexReduce, PARAMS(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams), INDEX_REDUCE_OPS); -} - -BUILD_DOUBLE_TEMPLATE(template void nd4j::IndexReductionLoops, ::wrapIndexReduce(const int opNum, void* vx, Nd4jLong* xShapeInfo, void* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES, INDEXING_TYPES); \ No newline at end of file +} \ No newline at end of file diff --git a/libnd4j/include/helpers/cpu/loops/IndexReductionLoops_int32.cpp b/libnd4j/include/helpers/cpu/loops/IndexReductionLoops_int32.cpp new file mode 100644 index 000000000..8a4b3cd7d --- /dev/null +++ b/libnd4j/include/helpers/cpu/loops/IndexReductionLoops_int32.cpp @@ -0,0 +1,24 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Yurii Shyrma (iuriish@yahoo.com) +// + +#include "./IndexReductionLoops.hpp" + +BUILD_DOUBLE_TEMPLATE(template void nd4j::IndexReductionLoops, ::wrapIndexReduce(const int opNum, void* vx, Nd4jLong* xShapeInfo, void* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES, (nd4j::DataType::INT32, int32_t)); \ No newline at end of file diff --git a/libnd4j/include/helpers/cpu/loops/IndexReductionLoops_int64.cpp b/libnd4j/include/helpers/cpu/loops/IndexReductionLoops_int64.cpp new file mode 100644 index 000000000..4fcb63ebf --- /dev/null +++ b/libnd4j/include/helpers/cpu/loops/IndexReductionLoops_int64.cpp @@ -0,0 +1,24 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Yurii Shyrma (iuriish@yahoo.com) +// + +#include "./IndexReductionLoops.hpp" + +BUILD_DOUBLE_TEMPLATE(template void nd4j::IndexReductionLoops, ::wrapIndexReduce(const int opNum, void* vx, Nd4jLong* xShapeInfo, void* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES, (nd4j::DataType::INT64, Nd4jLong)); \ No newline at end of file diff --git a/libnd4j/include/helpers/cpu/svd.cpp b/libnd4j/include/helpers/cpu/svd.cpp index 38d3b9ff4..4bf2be639 100644 --- a/libnd4j/include/helpers/cpu/svd.cpp +++ b/libnd4j/include/helpers/cpu/svd.cpp @@ -223,26 +223,26 @@ void SVD::deflation(int col1, int col2, int ind, int row1W, int col1W, int sh const T almostZero = DataTypeUtils::min(); T maxElem; if(len == 1) - maxElem = math::nd4j_abs(diagInterval->template e(0)); + maxElem = math::nd4j_abs(diagInterval.template e(0)); else - maxElem = (*diagInterval)({1,-1, 0,0}, true).reduceNumber(reduce::AMax).template e(0); + maxElem = diagInterval({1,-1, 0,0}, true).reduceNumber(reduce::AMax).template e(0); T maxElem0 = colVec0->reduceNumber(reduce::AMax).template e(0); T eps = math::nd4j_max(almostZero, DataTypeUtils::eps() * maxElem); T epsBig = (T)8. * DataTypeUtils::eps() * math::nd4j_max(maxElem0, maxElem); - if(diagInterval->template e(0) < epsBig) - diagInterval->p(Nd4jLong(0), epsBig); + if(diagInterval.template e(0) < epsBig) + diagInterval.p(Nd4jLong(0), epsBig); for(int i=1; i < len; ++i) if(math::nd4j_abs(colVec0->template e(i)) < eps) colVec0->p(i, 0.f); for(int i=1; i < len; i++) - if(diagInterval->template e(i) < epsBig) { + if(diagInterval.template e(i) < epsBig) { deflation1(col1, shift, i, len); for(int i = 0; i < len; ++i) - diagInterval->p(i, _m.e(col1+shift+i,col1+shift+i)); + diagInterval.p(i, _m.e(col1+shift+i,col1+shift+i)); } { @@ -261,7 +261,7 @@ void SVD::deflation(int col1, int col2, int ind, int row1W, int col1W, int sh int p = 1; for(int i=1; i(diagInterval->template e(i)) < almostZero) + if(math::nd4j_abs(diagInterval.template e(i)) < almostZero) permut[p++] = i; int k = 1, m = ind+1; @@ -271,7 +271,7 @@ void SVD::deflation(int col1, int col2, int ind, int row1W, int col1W, int sh permut[p] = m++; else if(m >= len) permut[p] = k++; - else if(diagInterval->template e(k) < diagInterval->template e(m)) + else if(diagInterval.template e(k) < diagInterval.template e(m)) permut[p] = m++; else permut[p] = k++; @@ -281,7 +281,7 @@ void SVD::deflation(int col1, int col2, int ind, int row1W, int col1W, int sh if(totDefl) { for(int i=1; i(diagInterval->template e(ki)) < almostZero || diagInterval->template e(0) < diagInterval->template e(ki)) + if(math::nd4j_abs(diagInterval.template e(ki)) < almostZero || diagInterval.template e(0) < diagInterval.template e(ki)) permut[i-1] = permut[i]; else { permut[i-1] = 0; @@ -303,10 +303,10 @@ void SVD::deflation(int col1, int col2, int ind, int row1W, int col1W, int sh const int ki = permut[len - (totDefl ? i+1 : i)]; const int jac = tCol[ki]; - T _e0 = diagInterval->template e(jac); + T _e0 = diagInterval.template e(jac); //math::nd4j_swap(diagInterval)(i), (*diagInterval)(jac)); - diagInterval->p(jac, diagInterval->template e(i)); - diagInterval->p(i, _e0); + diagInterval.p(jac, diagInterval.template e(i)); + diagInterval.p(i, _e0); if(i!=0 && jac!=0) { _e0 = colVec0->template e(jac); @@ -315,9 +315,8 @@ void SVD::deflation(int col1, int col2, int ind, int row1W, int col1W, int sh colVec0->p(i, _e0); } - NDArray* temp1 = nullptr, *temp2 = nullptr; if (_calcU) { - auto temp1 = _u({col1,col1+len+1, col1+i, col1+i+1}, true); + auto temp1 = _u({col1,col1+len+1, col1+i, col1+i+1}, true); auto temp2 = _u({col1,col1+len+1, col1+jac,col1+jac+1}, true); auto temp3 = temp1; temp1.assign(temp2); @@ -352,12 +351,12 @@ void SVD::deflation(int col1, int col2, int ind, int row1W, int col1W, int sh { int i = len-1; - while(i > 0 && (math::nd4j_abs(diagInterval->template e(i)) < almostZero || math::nd4j_abs(colVec0->template e(i)) < almostZero)) + while(i > 0 && (math::nd4j_abs(diagInterval.template e(i)) < almostZero || math::nd4j_abs(colVec0->template e(i)) < almostZero)) --i; for(; i > 1; --i) { - if( (diagInterval->template e(i) - diagInterval->template e(i-1)) < DataTypeUtils::eps()*maxElem ) { - if (math::nd4j_abs(diagInterval->template e(i) - diagInterval->template e(i-1)) >= epsBig) + if( (diagInterval.template e(i) - diagInterval.template e(i-1)) < DataTypeUtils::eps()*maxElem ) { + if (math::nd4j_abs(diagInterval.template e(i) - diagInterval.template e(i-1)) >= epsBig) throw std::runtime_error("ops::helpers::SVD::deflation: diagonal elements are not properly sorted !"); deflation2(col1, col1 + shift, row1W, col1W, i-1, i, len); } @@ -365,7 +364,6 @@ void SVD::deflation(int col1, int col2, int ind, int row1W, int col1W, int sh } delete colVec0; - delete diagInterval; } @@ -609,9 +607,7 @@ void SVD::calcBlockSVD(int col1, int size, NDArray& U, NDArray& singVals, NDA const T almostZero = DataTypeUtils::min(); auto col0 = _m({col1, col1+size, col1, col1+1}, true); - auto diagP = _m({col1, col1+size, col1, col1+size}, true).diagonal('c'); - auto diag = *diagP; - delete diagP; + auto diag = static_cast(_m({col1, col1+size, col1, col1+size}, true).diagonal('c')); diag.p(Nd4jLong(0), T(0)); singVals = NDArrayFactory::create(_m.ordering(), {size, 1}, _m.getContext()); @@ -730,8 +726,7 @@ void SVD::DivideAndConquer(int col1, int col2, int row1W, int col1W, int shif auto temp = _m({col1+shift,col1+shift+n+1, col1+shift,col1+shift+n}, true); temp.assign(0.); auto diag = _m.diagonal('c'); - (*diag)({col1+shift, col1+shift+n, 0,0}, true).assign(jac._s({0,n, 0,0}, true)); - delete diag; + diag({col1+shift, col1+shift+n, 0,0}, true).assign(jac._s({0,n, 0,0}, true)); return; } @@ -762,11 +757,6 @@ void SVD::DivideAndConquer(int col1, int col2, int row1W, int col1W, int shif f.assign(_u({0,1, col1+k+1, col1+n}, true)); } - // UofSVD.printIndexedBuffer(); - // VofSVD.printIndexedBuffer(); - // singVals.printIndexedBuffer(); - // printf("!! \n"); - if (_calcV) _v.p(row1W+k, col1W, 1.f); @@ -789,14 +779,10 @@ void SVD::DivideAndConquer(int col1, int col2, int row1W, int col1W, int shif temp.assign(_u({col1, col1+k+1, i, i+1}, true)); } - auto temp1 = _u({col1,col1+k+1, col1,col1+1}, true); - temp1.assign(q1 * c0); - auto temp2 = _u({col1,col1+k+1, col2+1,col2+2}, true); - temp2.assign(q1 * (-s0)); - auto temp3 = _u({col1+k+1,col1+n+1, col1, col1+1}, true); - temp3.assign(_u({col1+k+1, col1+n+1, col2+1, col2+2}, true) * s0); - auto temp4 =_u({col1+k+1,col1+n+1, col2+1,col2+2}, true); - temp4 *= c0; + _u({col1,col1+k+1, col1,col1+1}, true).assign(q1 * c0); + _u({col1,col1+k+1, col2+1,col2+2}, true).assign(q1 * (-s0)); + _u({col1+k+1,col1+n+1, col1, col1+1}, true).assign(static_cast(_u({col1+k+1, col1+n+1, col2+1, col2+2}, true)) * s0); + _u({col1+k+1,col1+n+1, col2+1,col2+2}, true) *= c0; } else { @@ -844,8 +830,7 @@ void SVD::DivideAndConquer(int col1, int col2, int row1W, int col1W, int shif auto blockM = _m({col1+shift,col1+shift+n, col1+shift,col1+shift+n}, true); blockM = 0.f; auto diag = blockM.diagonal('c'); - diag->assign(singVals); - delete diag; + diag.assign(singVals); } ////////////////////////////////////////////////////////////////////////// diff --git a/libnd4j/include/helpers/cublasHelper.h b/libnd4j/include/helpers/cublasHelper.h index 53d30abf6..f07cc178c 100644 --- a/libnd4j/include/helpers/cublasHelper.h +++ b/libnd4j/include/helpers/cublasHelper.h @@ -34,12 +34,14 @@ namespace nd4j { std::vector _cache; std::vector _solvers; + std::vector _cudnn; CublasHelper(); ~CublasHelper(); public: static CublasHelper* getInstance(); + void* cudnn(); void* solver(); void* handle(); diff --git a/libnd4j/include/helpers/cuda_off/MmulHelper.cu b/libnd4j/include/helpers/cuda_off/MmulHelper.cu index d191c7803..bf366dc29 100644 --- a/libnd4j/include/helpers/cuda_off/MmulHelper.cu +++ b/libnd4j/include/helpers/cuda_off/MmulHelper.cu @@ -235,6 +235,9 @@ NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, dou if(C == nullptr) C = new NDArray(outOrder, {M,N}, DataTypeUtils::pickPairwiseResultType(A->dataType(), B->dataType()), A->getContext()); + if (C->isEmpty()) + return C; + const int major = Environment::getInstance()->capabilities()[AffinityManager::currentDeviceId()].first(); const auto aType = A->dataType(); @@ -285,17 +288,17 @@ NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, dou bool cNcont = N == 1 || C->strideAt(1) == 1; if(!aMcont && !aKcont) { - pA = A->dup('f'); + pA = new NDArray(A->dup('f')); toDelete.push_back(pA); aMcont = true; } if(!bKcont && !bNcont) { - pB = B->dup('f'); + pB = new NDArray(B->dup('f')); toDelete.push_back(pB); bKcont = true; } if(!cMcont) { - pC = C->dup('f'); + pC = new NDArray(C->dup('f')); toDelete.push_back(pC); cMcont = true; } @@ -376,6 +379,9 @@ NDArray* MmulHelper::mmulMxV(const NDArray* A, const NDArray* X, nd4j::NDArray* if(Y == nullptr) Y = new NDArray(outOrder, {M}, DataTypeUtils::pickPairwiseResultType(A->dataType(), X->dataType()), A->getContext()); + if (Y->isEmpty()) + return Y; + const int incx = X->strideAt(xLenDim); const int incy = Y->strideAt(yLenDim); @@ -418,7 +424,7 @@ NDArray* MmulHelper::mmulMxV(const NDArray* A, const NDArray* X, nd4j::NDArray* bool aNcont = N == 1 || A->strideAt(1) == 1; if(!aMcont && !aNcont) { - pA = A->dup('f'); + pA = new NDArray(A->dup('f')); aMcont = true; } @@ -634,6 +640,9 @@ NDArray* MmulHelper::mmulNxN(const NDArray* A, const NDArray* B, NDArray* C, con else C = new NDArray(outOrder, cExpectedShape, DataTypeUtils::pickPairwiseResultType(A->dataType(), B->dataType()), A->getContext()); + if (C->isEmpty()) + return C; + const int cRank = C->rankOf(); const int aMaxis(aRank-2), aKaxis(aRank-1), bKaxis(bRank-2), bNaxis(bRank-1), cMaxis(cRank-2), cNaxis(cRank-1); @@ -866,12 +875,12 @@ NDArray* MmulHelper::mmulNxNold2(const NDArray* A, const NDArray* B, NDArray* C, bool cNcont = N == 1 || C->strideAt(-1) == 1; if(!aMcont && !aKcont) { - pA = A->dup('c'); + pA = new NDArray(A->dup('c')); toDelete.push_back(pA); aKcont = true; } if(!bKcont && !bNcont) { - pB = B->dup('c'); + pB = new NDArray(B->dup('c')); toDelete.push_back(pB); bNcont = true; } diff --git a/libnd4j/include/helpers/cuda_off/cublasHelper.cu b/libnd4j/include/helpers/cuda_off/cublasHelper.cu index d9784eaa2..7204862eb 100644 --- a/libnd4j/include/helpers/cuda_off/cublasHelper.cu +++ b/libnd4j/include/helpers/cuda_off/cublasHelper.cu @@ -25,6 +25,13 @@ #include #include #include +#include "config.h" + +#ifdef HAVE_CUDNN + +#include + +#endif namespace nd4j { std::mutex CublasHelper::_mutex; @@ -47,6 +54,18 @@ namespace nd4j { return cusolverH; } + static void* cudnn_() { +#ifdef HAVE_CUDNN + auto cudnnH = new cudnnHandle_t(); + auto status = cudnnCreate(cudnnH); + if (status != CUDNN_STATUS_SUCCESS) + throw cuda_exception::build("cuDNN handle creation failed !", status); + + return cudnnH; +#endif + return nullptr; + } + static void destroyHandle_(void* handle) { auto ch = reinterpret_cast(handle); auto status = cublasDestroy_v2(*ch); @@ -62,11 +81,13 @@ namespace nd4j { auto currentDevice = AffinityManager::currentDeviceId(); _cache.resize(numDevices); _solvers.resize(numDevices); + _cudnn.resize(numDevices); for (int e = 0; e < numDevices; e++) { AffinityManager::setCurrentNativeDevice(e); _cache[e] = handle_(); _solvers[e] = solver_(); + _cudnn[e] = cudnn_(); } // don't forget to restore back original device @@ -90,6 +111,14 @@ namespace nd4j { return _INSTANCE; } + void* CublasHelper::cudnn() { + auto deviceId = AffinityManager::currentDeviceId(); + if (deviceId < 0 || deviceId > _cudnn.size()) + throw cuda_exception::build("requested deviceId doesn't look valid", deviceId); + + return _cudnn[deviceId]; + } + void* CublasHelper::handle() { auto deviceId = AffinityManager::currentDeviceId(); return handle(deviceId); diff --git a/libnd4j/include/helpers/impl/MmulHelper.cpp b/libnd4j/include/helpers/impl/MmulHelper.cpp index ab97ad137..716062a53 100644 --- a/libnd4j/include/helpers/impl/MmulHelper.cpp +++ b/libnd4j/include/helpers/impl/MmulHelper.cpp @@ -236,6 +236,9 @@ nd4j::NDArray* MmulHelper::mmul(const nd4j::NDArray* A, const nd4j::NDArray* B, throw std::invalid_argument(""); } + if (z->isEmpty()) + return; + NDArray* xT(const_cast(x)), *yT(const_cast(y)), *zT(z); if((transX && xRank > 1) || (transY && yRank > 1)) { diff --git a/libnd4j/include/helpers/impl/StringUtils.cpp b/libnd4j/include/helpers/impl/StringUtils.cpp index cd0383a75..faace2c63 100644 --- a/libnd4j/include/helpers/impl/StringUtils.cpp +++ b/libnd4j/include/helpers/impl/StringUtils.cpp @@ -19,7 +19,58 @@ // #include +#include namespace nd4j { + static FORCEINLINE bool match(const uint8_t *haystack, const uint8_t *needle, uint64_t length) { + for (int e = 0; e < length; e++) + if (haystack[e] != needle[e]) + return false; + return true; + } + + uint64_t StringUtils::countSubarrays(const void *vhaystack, uint64_t haystackLength, const void *vneedle, uint64_t needleLength) { + auto haystack = reinterpret_cast(vhaystack); + auto needle = reinterpret_cast(vneedle); + + uint64_t number = 0; + + for (uint64_t e = 0; e < haystackLength - needleLength; e++) { + if (match(&haystack[e], needle, needleLength)) + number++; + } + + return number; + } + + + uint64_t StringUtils::byteLength(const NDArray &array) { + if (!array.isS()) + throw nd4j::datatype_exception::build("StringUtils::byteLength expects one of String types;", array.dataType()); + + uint64_t result = 0; + + // our buffer stores offsets, and the last value is basically number of bytes used + auto buffer = array.bufferAsT(); + result = buffer[array.lengthOf()]; + + return result; + } + + std::vector StringUtils::split(const std::string &haystack, const std::string &delimiter) { + std::vector output; + + std::string::size_type prev_pos = 0, pos = 0; + + // iterating through the haystack till the end + while((pos = haystack.find(delimiter, pos)) != std::string::npos) { + output.emplace_back(haystack.substr(prev_pos, pos-prev_pos)); + prev_pos = ++pos; + } + + output.emplace_back(haystack.substr(prev_pos, pos - prev_pos)); // Last word + + return output; + } } diff --git a/libnd4j/include/loops/BroadcastPairwiseConverter.h b/libnd4j/include/loops/BroadcastPairwiseConverter.h index fb5acf19b..f1fda4a9a 100644 --- a/libnd4j/include/loops/BroadcastPairwiseConverter.h +++ b/libnd4j/include/loops/BroadcastPairwiseConverter.h @@ -53,6 +53,7 @@ inline pairwise::Ops fromBroadcastToPairwise(broadcast::Ops op) { case broadcast::LogicalXor: return pairwise::LogicalXor; case broadcast::LogicalNot: return pairwise::LogicalNot; case broadcast::LogicalAnd: return pairwise::LogicalAnd; + case broadcast::PowDerivative: return pairwise::PowDerivative; default: throw std::runtime_error("fromBroadcastToPairwise: Not convertible operation"); } diff --git a/libnd4j/include/loops/cpu/broadcast/broadcast_p0.cpp b/libnd4j/include/loops/cpu/compilation_units/broadcast_p0.cpp similarity index 100% rename from libnd4j/include/loops/cpu/broadcast/broadcast_p0.cpp rename to libnd4j/include/loops/cpu/compilation_units/broadcast_p0.cpp diff --git a/libnd4j/include/loops/cpu/broadcast/broadcast_p1.cpp b/libnd4j/include/loops/cpu/compilation_units/broadcast_p1.cpp similarity index 100% rename from libnd4j/include/loops/cpu/broadcast/broadcast_p1.cpp rename to libnd4j/include/loops/cpu/compilation_units/broadcast_p1.cpp diff --git a/libnd4j/include/loops/cpu/broadcast/broadcast_p2.cpp b/libnd4j/include/loops/cpu/compilation_units/broadcast_p2.cpp similarity index 100% rename from libnd4j/include/loops/cpu/broadcast/broadcast_p2.cpp rename to libnd4j/include/loops/cpu/compilation_units/broadcast_p2.cpp diff --git a/libnd4j/include/loops/cpu/broadcast/broadcast_p3.cpp b/libnd4j/include/loops/cpu/compilation_units/broadcast_p3.cpp similarity index 100% rename from libnd4j/include/loops/cpu/broadcast/broadcast_p3.cpp rename to libnd4j/include/loops/cpu/compilation_units/broadcast_p3.cpp diff --git a/libnd4j/include/loops/cpu/broadcast/broadcast_p4.cpp b/libnd4j/include/loops/cpu/compilation_units/broadcast_p4.cpp similarity index 100% rename from libnd4j/include/loops/cpu/broadcast/broadcast_p4.cpp rename to libnd4j/include/loops/cpu/compilation_units/broadcast_p4.cpp diff --git a/libnd4j/include/loops/cpu/broadcast/broadcast_p5.cpp b/libnd4j/include/loops/cpu/compilation_units/broadcast_p5.cpp similarity index 100% rename from libnd4j/include/loops/cpu/broadcast/broadcast_p5.cpp rename to libnd4j/include/loops/cpu/compilation_units/broadcast_p5.cpp diff --git a/libnd4j/include/loops/cpu/broadcast/broadcast_p6.cpp b/libnd4j/include/loops/cpu/compilation_units/broadcast_p6.cpp similarity index 100% rename from libnd4j/include/loops/cpu/broadcast/broadcast_p6.cpp rename to libnd4j/include/loops/cpu/compilation_units/broadcast_p6.cpp diff --git a/libnd4j/include/loops/cpu/broadcast/broadcast_p7.cpp b/libnd4j/include/loops/cpu/compilation_units/broadcast_p7.cpp similarity index 100% rename from libnd4j/include/loops/cpu/broadcast/broadcast_p7.cpp rename to libnd4j/include/loops/cpu/compilation_units/broadcast_p7.cpp diff --git a/libnd4j/include/loops/cpu/broadcast/broadcast_p8.cpp b/libnd4j/include/loops/cpu/compilation_units/broadcast_p8.cpp similarity index 100% rename from libnd4j/include/loops/cpu/broadcast/broadcast_p8.cpp rename to libnd4j/include/loops/cpu/compilation_units/broadcast_p8.cpp diff --git a/libnd4j/include/loops/cpu/broadcast/broadcast_p9.cpp b/libnd4j/include/loops/cpu/compilation_units/broadcast_p9.cpp similarity index 100% rename from libnd4j/include/loops/cpu/broadcast/broadcast_p9.cpp rename to libnd4j/include/loops/cpu/compilation_units/broadcast_p9.cpp diff --git a/libnd4j/include/loops/cpu/compilation_units/indexreduce_int32.cpp b/libnd4j/include/loops/cpu/compilation_units/indexreduce_int32.cpp new file mode 100644 index 000000000..7b87535c2 --- /dev/null +++ b/libnd4j/include/loops/cpu/compilation_units/indexreduce_int32.cpp @@ -0,0 +1,28 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include "../indexreduce.hpp" + +namespace functions { + namespace indexreduce { + BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT IndexReduce, , LIBND4J_TYPES, (nd4j::DataType::INT32, int32_t)); + } +} \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/indexreduce_int64.cpp b/libnd4j/include/loops/cpu/compilation_units/indexreduce_int64.cpp new file mode 100644 index 000000000..d1005699c --- /dev/null +++ b/libnd4j/include/loops/cpu/compilation_units/indexreduce_int64.cpp @@ -0,0 +1,28 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include "../indexreduce.hpp" + +namespace functions { + namespace indexreduce { + BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT IndexReduce, , LIBND4J_TYPES, (nd4j::DataType::INT64, Nd4jLong)); + } +} \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/pairwise/pairwise_p0.cpp b/libnd4j/include/loops/cpu/compilation_units/pairwise_p0.cpp similarity index 100% rename from libnd4j/include/loops/cpu/pairwise/pairwise_p0.cpp rename to libnd4j/include/loops/cpu/compilation_units/pairwise_p0.cpp diff --git a/libnd4j/include/loops/cpu/pairwise/pairwise_p1.cpp b/libnd4j/include/loops/cpu/compilation_units/pairwise_p1.cpp similarity index 100% rename from libnd4j/include/loops/cpu/pairwise/pairwise_p1.cpp rename to libnd4j/include/loops/cpu/compilation_units/pairwise_p1.cpp diff --git a/libnd4j/include/loops/cpu/pairwise/pairwise_p2.cpp b/libnd4j/include/loops/cpu/compilation_units/pairwise_p2.cpp similarity index 100% rename from libnd4j/include/loops/cpu/pairwise/pairwise_p2.cpp rename to libnd4j/include/loops/cpu/compilation_units/pairwise_p2.cpp diff --git a/libnd4j/include/loops/cpu/pairwise/pairwise_p3.cpp b/libnd4j/include/loops/cpu/compilation_units/pairwise_p3.cpp similarity index 100% rename from libnd4j/include/loops/cpu/pairwise/pairwise_p3.cpp rename to libnd4j/include/loops/cpu/compilation_units/pairwise_p3.cpp diff --git a/libnd4j/include/loops/cpu/pairwise/pairwise_p4.cpp b/libnd4j/include/loops/cpu/compilation_units/pairwise_p4.cpp similarity index 100% rename from libnd4j/include/loops/cpu/pairwise/pairwise_p4.cpp rename to libnd4j/include/loops/cpu/compilation_units/pairwise_p4.cpp diff --git a/libnd4j/include/loops/cpu/pairwise/pairwise_p5.cpp b/libnd4j/include/loops/cpu/compilation_units/pairwise_p5.cpp similarity index 100% rename from libnd4j/include/loops/cpu/pairwise/pairwise_p5.cpp rename to libnd4j/include/loops/cpu/compilation_units/pairwise_p5.cpp diff --git a/libnd4j/include/loops/cpu/pairwise/pairwise_p6.cpp b/libnd4j/include/loops/cpu/compilation_units/pairwise_p6.cpp similarity index 100% rename from libnd4j/include/loops/cpu/pairwise/pairwise_p6.cpp rename to libnd4j/include/loops/cpu/compilation_units/pairwise_p6.cpp diff --git a/libnd4j/include/loops/cpu/pairwise/pairwise_p7.cpp b/libnd4j/include/loops/cpu/compilation_units/pairwise_p7.cpp similarity index 100% rename from libnd4j/include/loops/cpu/pairwise/pairwise_p7.cpp rename to libnd4j/include/loops/cpu/compilation_units/pairwise_p7.cpp diff --git a/libnd4j/include/loops/cpu/pairwise/pairwise_p8.cpp b/libnd4j/include/loops/cpu/compilation_units/pairwise_p8.cpp similarity index 100% rename from libnd4j/include/loops/cpu/pairwise/pairwise_p8.cpp rename to libnd4j/include/loops/cpu/compilation_units/pairwise_p8.cpp diff --git a/libnd4j/include/loops/cpu/pairwise/pairwise_p9.cpp b/libnd4j/include/loops/cpu/compilation_units/pairwise_p9.cpp similarity index 100% rename from libnd4j/include/loops/cpu/pairwise/pairwise_p9.cpp rename to libnd4j/include/loops/cpu/compilation_units/pairwise_p9.cpp diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce3_bfloat16.cpp b/libnd4j/include/loops/cpu/compilation_units/reduce3_bfloat16.cpp new file mode 100644 index 000000000..8df61ad29 --- /dev/null +++ b/libnd4j/include/loops/cpu/compilation_units/reduce3_bfloat16.cpp @@ -0,0 +1,28 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include "../reduce3.hpp" + +namespace functions { + namespace reduce3 { + BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT Reduce3, , LIBND4J_TYPES, FLOAT_TYPES_3); + } +} \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce3_double.cpp b/libnd4j/include/loops/cpu/compilation_units/reduce3_double.cpp new file mode 100644 index 000000000..10e78e914 --- /dev/null +++ b/libnd4j/include/loops/cpu/compilation_units/reduce3_double.cpp @@ -0,0 +1,28 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include "../reduce3.hpp" + +namespace functions { + namespace reduce3 { + BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT Reduce3, , LIBND4J_TYPES, FLOAT_TYPES_2); + } +} \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce3_float.cpp b/libnd4j/include/loops/cpu/compilation_units/reduce3_float.cpp new file mode 100644 index 000000000..5362352b6 --- /dev/null +++ b/libnd4j/include/loops/cpu/compilation_units/reduce3_float.cpp @@ -0,0 +1,28 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include "../reduce3.hpp" + +namespace functions { + namespace reduce3 { + BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT Reduce3, , LIBND4J_TYPES, FLOAT_TYPES_1); + } +} \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce3_float16.cpp b/libnd4j/include/loops/cpu/compilation_units/reduce3_float16.cpp new file mode 100644 index 000000000..8a738acf9 --- /dev/null +++ b/libnd4j/include/loops/cpu/compilation_units/reduce3_float16.cpp @@ -0,0 +1,28 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include "../reduce3.hpp" + +namespace functions { + namespace reduce3 { + BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT Reduce3, , LIBND4J_TYPES, FLOAT_TYPES_0); + } +} \ No newline at end of file diff --git a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/src/main/java/org/deeplearning4j/aws/dataset/DataSetLoader.java b/libnd4j/include/loops/cpu/compilation_units/reduce_float_0.cpp old mode 100755 new mode 100644 similarity index 73% rename from deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/src/main/java/org/deeplearning4j/aws/dataset/DataSetLoader.java rename to libnd4j/include/loops/cpu/compilation_units/reduce_float_0.cpp index a61a5da69..de4619f29 --- a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/src/main/java/org/deeplearning4j/aws/dataset/DataSetLoader.java +++ b/libnd4j/include/loops/cpu/compilation_units/reduce_float_0.cpp @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2019 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -14,21 +15,14 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -package org.deeplearning4j.aws.dataset; +// +// @author raver119@gmail.com +// -import org.deeplearning4j.aws.s3.reader.S3Downloader; - -import java.io.InputStream; - -public class DataSetLoader { - - private String bucket; - - - - public void onData(InputStream is) { - S3Downloader downloader = new S3Downloader(); +#include "../reduce/reduce_float.hpp" +namespace functions { + namespace reduce { + BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT ReduceFloatFunction, , LIBND4J_TYPES, FLOAT_TYPES_0); } - } diff --git a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/src/main/java/org/deeplearning4j/aws/ec2/provision/DistributedDeepLearningTrainer.java b/libnd4j/include/loops/cpu/compilation_units/reduce_float_1.cpp old mode 100755 new mode 100644 similarity index 73% rename from deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/src/main/java/org/deeplearning4j/aws/ec2/provision/DistributedDeepLearningTrainer.java rename to libnd4j/include/loops/cpu/compilation_units/reduce_float_1.cpp index 846d8b200..bfa88bc3b --- a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/src/main/java/org/deeplearning4j/aws/ec2/provision/DistributedDeepLearningTrainer.java +++ b/libnd4j/include/loops/cpu/compilation_units/reduce_float_1.cpp @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2019 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -14,18 +15,14 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -package org.deeplearning4j.aws.ec2.provision; +// +// @author raver119@gmail.com +// -public class DistributedDeepLearningTrainer { - - private DistributedDeepLearningTrainer() {} - - /** - * @param args - */ - public static void main(String[] args) { - ClusterSetup clusterSet = new ClusterSetup(args); +#include "../reduce/reduce_float.hpp" +namespace functions { + namespace reduce { + BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT ReduceFloatFunction, , LIBND4J_TYPES, FLOAT_TYPES_1); } - } diff --git a/datavec/datavec-camel/src/main/java/org/datavec/camel/component/DataVecMarshaller.java b/libnd4j/include/loops/cpu/compilation_units/reduce_float_2.cpp similarity index 69% rename from datavec/datavec-camel/src/main/java/org/datavec/camel/component/DataVecMarshaller.java rename to libnd4j/include/loops/cpu/compilation_units/reduce_float_2.cpp index 62eba29d7..8cc2795a4 100644 --- a/datavec/datavec-camel/src/main/java/org/datavec/camel/component/DataVecMarshaller.java +++ b/libnd4j/include/loops/cpu/compilation_units/reduce_float_2.cpp @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2019 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -14,23 +15,14 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -package org.datavec.camel.component; +// +// @author raver119@gmail.com +// -import org.apache.camel.Exchange; -import org.datavec.api.split.InputSplit; - -/** - * Marshals na exchange in to an input split - * @author Adam Gibson - */ -public interface DataVecMarshaller { - - - /** - * - * @param exchange - * @return - */ - InputSplit getSplit(Exchange exchange); +#include "../reduce/reduce_float.hpp" +namespace functions { + namespace reduce { + BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT ReduceFloatFunction, , LIBND4J_TYPES, FLOAT_TYPES_2); + } } diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce_float_3.cpp b/libnd4j/include/loops/cpu/compilation_units/reduce_float_3.cpp new file mode 100644 index 000000000..0b94831c3 --- /dev/null +++ b/libnd4j/include/loops/cpu/compilation_units/reduce_float_3.cpp @@ -0,0 +1,28 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include "../reduce/reduce_float.hpp" + +namespace functions { + namespace reduce { + BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT ReduceFloatFunction, , LIBND4J_TYPES, FLOAT_TYPES_3); + } +} diff --git a/libnd4j/include/loops/cpu/scalar/scalar_p0.cpp b/libnd4j/include/loops/cpu/compilation_units/scalar_p0.cpp similarity index 100% rename from libnd4j/include/loops/cpu/scalar/scalar_p0.cpp rename to libnd4j/include/loops/cpu/compilation_units/scalar_p0.cpp diff --git a/libnd4j/include/loops/cpu/scalar/scalar_p1.cpp b/libnd4j/include/loops/cpu/compilation_units/scalar_p1.cpp similarity index 100% rename from libnd4j/include/loops/cpu/scalar/scalar_p1.cpp rename to libnd4j/include/loops/cpu/compilation_units/scalar_p1.cpp diff --git a/libnd4j/include/loops/cpu/scalar/scalar_p2.cpp b/libnd4j/include/loops/cpu/compilation_units/scalar_p2.cpp similarity index 100% rename from libnd4j/include/loops/cpu/scalar/scalar_p2.cpp rename to libnd4j/include/loops/cpu/compilation_units/scalar_p2.cpp diff --git a/libnd4j/include/loops/cpu/scalar/scalar_p3.cpp b/libnd4j/include/loops/cpu/compilation_units/scalar_p3.cpp similarity index 100% rename from libnd4j/include/loops/cpu/scalar/scalar_p3.cpp rename to libnd4j/include/loops/cpu/compilation_units/scalar_p3.cpp diff --git a/libnd4j/include/loops/cpu/scalar/scalar_p4.cpp b/libnd4j/include/loops/cpu/compilation_units/scalar_p4.cpp similarity index 100% rename from libnd4j/include/loops/cpu/scalar/scalar_p4.cpp rename to libnd4j/include/loops/cpu/compilation_units/scalar_p4.cpp diff --git a/libnd4j/include/loops/cpu/scalar/scalar_p5.cpp b/libnd4j/include/loops/cpu/compilation_units/scalar_p5.cpp similarity index 100% rename from libnd4j/include/loops/cpu/scalar/scalar_p5.cpp rename to libnd4j/include/loops/cpu/compilation_units/scalar_p5.cpp diff --git a/libnd4j/include/loops/cpu/scalar/scalar_p6.cpp b/libnd4j/include/loops/cpu/compilation_units/scalar_p6.cpp similarity index 100% rename from libnd4j/include/loops/cpu/scalar/scalar_p6.cpp rename to libnd4j/include/loops/cpu/compilation_units/scalar_p6.cpp diff --git a/libnd4j/include/loops/cpu/scalar/scalar_p7.cpp b/libnd4j/include/loops/cpu/compilation_units/scalar_p7.cpp similarity index 100% rename from libnd4j/include/loops/cpu/scalar/scalar_p7.cpp rename to libnd4j/include/loops/cpu/compilation_units/scalar_p7.cpp diff --git a/libnd4j/include/loops/cpu/scalar/scalar_p8.cpp b/libnd4j/include/loops/cpu/compilation_units/scalar_p8.cpp similarity index 100% rename from libnd4j/include/loops/cpu/scalar/scalar_p8.cpp rename to libnd4j/include/loops/cpu/compilation_units/scalar_p8.cpp diff --git a/libnd4j/include/loops/cpu/scalar/scalar_p9.cpp b/libnd4j/include/loops/cpu/compilation_units/scalar_p9.cpp similarity index 100% rename from libnd4j/include/loops/cpu/scalar/scalar_p9.cpp rename to libnd4j/include/loops/cpu/compilation_units/scalar_p9.cpp diff --git a/libnd4j/include/loops/cpu/indexreduce.cpp b/libnd4j/include/loops/cpu/indexreduce.hpp similarity index 98% rename from libnd4j/include/loops/cpu/indexreduce.cpp rename to libnd4j/include/loops/cpu/indexreduce.hpp index df3fd64a9..829f60a18 100644 --- a/libnd4j/include/loops/cpu/indexreduce.cpp +++ b/libnd4j/include/loops/cpu/indexreduce.hpp @@ -151,8 +151,5 @@ void IndexReduce::exec(void *vx, Nd4jLong *xShapeInfo, nd4j::IndexReductionLoops::template loopIndexReduce(x, xShapeInfo, z, zShapeInfo, tadOnlyShapeInfo, tadOffsets, extraParams); } - -BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT IndexReduce, , LIBND4J_TYPES, INDEXING_TYPES); - } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/reduce/reduce_bool.cpp b/libnd4j/include/loops/cpu/reduce/reduce_bool.cpp index 882b1740e..1ee820853 100644 --- a/libnd4j/include/loops/cpu/reduce/reduce_bool.cpp +++ b/libnd4j/include/loops/cpu/reduce/reduce_bool.cpp @@ -20,7 +20,6 @@ // #include -#include #include #include #include diff --git a/libnd4j/include/loops/cpu/reduce/reduce_float.cpp b/libnd4j/include/loops/cpu/reduce/reduce_float.hpp similarity index 98% rename from libnd4j/include/loops/cpu/reduce/reduce_float.cpp rename to libnd4j/include/loops/cpu/reduce/reduce_float.hpp index 112656852..d0a80a3f5 100644 --- a/libnd4j/include/loops/cpu/reduce/reduce_float.cpp +++ b/libnd4j/include/loops/cpu/reduce/reduce_float.hpp @@ -20,7 +20,6 @@ // #include -#include #include #include #include @@ -269,8 +268,5 @@ namespace functions { // return result return OpType::postProcess(intermediate[0], length, extraParams); } - - - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT ReduceFloatFunction, , LIBND4J_TYPES, FLOAT_TYPES); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/reduce/reduce_long.cpp b/libnd4j/include/loops/cpu/reduce/reduce_long.cpp index 76dc209f6..e53c9ac8e 100644 --- a/libnd4j/include/loops/cpu/reduce/reduce_long.cpp +++ b/libnd4j/include/loops/cpu/reduce/reduce_long.cpp @@ -20,7 +20,6 @@ // #include -#include #include #include #include diff --git a/libnd4j/include/loops/cpu/reduce/reduce_same.cpp b/libnd4j/include/loops/cpu/reduce/reduce_same.cpp index cbd7e6e12..929d9c4ff 100644 --- a/libnd4j/include/loops/cpu/reduce/reduce_same.cpp +++ b/libnd4j/include/loops/cpu/reduce/reduce_same.cpp @@ -20,7 +20,6 @@ // #include -#include #include #include #include diff --git a/libnd4j/include/loops/cpu/reduce3.cpp b/libnd4j/include/loops/cpu/reduce3.hpp similarity index 99% rename from libnd4j/include/loops/cpu/reduce3.cpp rename to libnd4j/include/loops/cpu/reduce3.hpp index dbe93620a..8d50aedbc 100644 --- a/libnd4j/include/loops/cpu/reduce3.cpp +++ b/libnd4j/include/loops/cpu/reduce3.hpp @@ -254,10 +254,5 @@ void Reduce3::execAll(const int opNum, DISPATCH_BY_OPNUM_TT(execAll, PARAMS(vx, xShapeInfo, extraParamsVals, vy, yShapeInfo, vz, zShapeInfo, dimension, dimensionLength, xTadShapeInfo, xOffsets, yTadShapeInfo, yOffsets, start, stop), REDUCE3_OPS); } - - - -BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT Reduce3, , LIBND4J_TYPES, FLOAT_TYPES); - } } \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/indexreduce.cu b/libnd4j/include/loops/cuda/indexreduce.cu index 1bd5d10cb..aeb2d9d36 100644 --- a/libnd4j/include/loops/cuda/indexreduce.cu +++ b/libnd4j/include/loops/cuda/indexreduce.cu @@ -35,12 +35,12 @@ static __global__ void simpleIndexReduceGeneric(const int op, Nd4jLong *xShapeInfo, int xRank, void *extraParams, void *result, - Nd4jLong *resultShapeInfo, int zRank, + Nd4jLong *zShapeInfo, int zRank, int *dimension, int dimensionLength, int postProcessOrNot, int *allocationBuffer, void *reductionBuffer, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets) { - functions::indexreduce::IndexReduce::transform(op,dx,xShapeInfo,extraParams,result,resultShapeInfo,dimension,dimensionLength,postProcessOrNot,allocationBuffer,reductionBuffer,tadOnlyShapeInfo,tadOffsets); + functions::indexreduce::IndexReduce::transform(op,dx,xShapeInfo,extraParams,result,zShapeInfo,dimension,dimensionLength,postProcessOrNot,allocationBuffer,reductionBuffer,tadOnlyShapeInfo,tadOffsets); } namespace functions { @@ -52,7 +52,7 @@ namespace functions { void *dx, Nd4jLong *xShapeInfo, int xRank, void *extraParams, - void *result, Nd4jLong *resultShapeInfo, + void *result, Nd4jLong *zShapeInfo, int zRank, int *dimension, int dimensionLength, int postProcessOrNot, @@ -62,7 +62,7 @@ namespace functions { simpleIndexReduceGeneric<<>>(opNum, dx, xShapeInfo, xRank, extraParams, - result, resultShapeInfo, 0, + result, zShapeInfo, 0, nullptr, 0, 1, allocationBuffer, reductionBuffer, @@ -70,14 +70,14 @@ namespace functions { } template - _CUDA_H void IndexReduce::executeIndexReduce(dim3 launchDims, cudaStream_t *stream, const int opNum, void *dx, Nd4jLong *xShapeInfo, int xRank, void *extraParams, void *result, Nd4jLong *resultShapeInfo, int zRank, int *dimension, int dimensionLength, int postProcessOrNot, int *allocationBuffer, void *reductionBuffer, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets) { + _CUDA_H void IndexReduce::executeIndexReduce(dim3 launchDims, cudaStream_t *stream, const int opNum, void *dx, Nd4jLong *xShapeInfo, int xRank, void *extraParams, void *result, Nd4jLong *zShapeInfo, int zRank, int *dimension, int dimensionLength, int postProcessOrNot, int *allocationBuffer, void *reductionBuffer, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets) { simpleIndexReduceGeneric<<>>( opNum, dx, xShapeInfo, xRank, extraParams, result, - resultShapeInfo, zRank, + zShapeInfo, zRank, dimension, dimensionLength, 1, allocationBuffer, reductionBuffer, tadOnlyShapeInfo, tadOffsets); @@ -158,7 +158,7 @@ namespace functions { Nd4jLong *xShapeInfo, void *extraParams, void *result, - Nd4jLong *resultShapeInfo, + Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, int postProcessOrNot, @@ -166,7 +166,7 @@ namespace functions { void *reductionBuffer, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffset) { - DISPATCH_BY_OPNUM_TT(transform, PARAMS(x, xShapeInfo, extraParams, result, resultShapeInfo, dimension, dimensionLength, postProcessOrNot, allocationBuffer, reductionBuffer, tadShapeInfo, tadOffset), INDEX_REDUCE_OPS); + DISPATCH_BY_OPNUM_TT(transform, PARAMS(x, xShapeInfo, extraParams, result, zShapeInfo, dimension, dimensionLength, postProcessOrNot, allocationBuffer, reductionBuffer, tadShapeInfo, tadOffset), INDEX_REDUCE_OPS); } @@ -174,7 +174,7 @@ namespace functions { template __device__ void IndexReduce::transform(void *vdx, Nd4jLong *xShapeInfo, void *vextraParams, - void *vresult, Nd4jLong *resultShapeInfo, + void *vz, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, int postProcessOrNot, int *allocationBuffer, void *vreductionBuffer, @@ -183,7 +183,7 @@ namespace functions { * Gpu information for the problem */ auto dx = reinterpret_cast(vdx); - auto result = reinterpret_cast(vresult); + auto z = reinterpret_cast(vz); auto extraParams = static_cast(vextraParams); auto reductionBuffer = static_cast(vreductionBuffer); auto order = shape::order(xShapeInfo); @@ -203,19 +203,19 @@ namespace functions { //length for the tad __shared__ volatile Nd4jLong xLength; - __shared__ volatile Nd4jLong resultLength; + __shared__ volatile Nd4jLong zLen; //only compute the tad indexes once IndexValue reduction = OpType::startingIndexValue(dx); if (threadIdx.x == 0) { - if (resultShapeInfo != nullptr) - resultLength = shape::length(resultShapeInfo); - else resultLength = 1; + if (zShapeInfo != nullptr) + zLen = shape::length(zShapeInfo); + else zLen = 1; if (dimensionLength == 1) { - if (resultLength == 1 && (dimension == nullptr || dimension[0] == MAX_DIMENSION)) + if (zLen == 1 && (dimension == nullptr || dimension[0] == MAX_DIMENSION)) resultScalar = 1; else resultScalar = 0; @@ -223,13 +223,24 @@ namespace functions { else resultScalar = 0; - if (resultLength == 1) + if (zLen == 1) resultScalar = 1; xLength = shape::length(xShapeInfo); } __syncthreads(); + if(nd4j::ArrayOptions::arrayType(xShapeInfo) == nd4j::ArrayType::EMPTY) { + + if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY) + return; + + for (uint i = blockIdx.x * blockDim.x + threadIdx.x; i < zLen; i += gridDim.x * blockDim.x) + z[i] = (Z) reduction.index; + + return; + } + if (!resultScalar) { __shared__ Nd4jLong tadLength; @@ -261,7 +272,7 @@ namespace functions { __syncthreads(); if (threadIdx.x == 0) { - result[r] = (Z) sPartials[threadIdx.x].index; + z[r] = (Z) sPartials[threadIdx.x].index; } __syncthreads(); } @@ -282,7 +293,7 @@ namespace functions { __syncthreads(); if (threadIdx.x == 0) { - result[i] = (Z) sPartials[threadIdx.x].index; //postProcess(sPartials[0],tadLength ,extraParams); + z[i] = (Z) sPartials[threadIdx.x].index; //postProcess(sPartials[0],tadLength ,extraParams); } __syncthreads(); } @@ -345,14 +356,14 @@ namespace functions { __syncthreads(); if (tid == 0) { - result[0] = (Z) sPartials[0].index; + z[0] = (Z) sPartials[0].index; } } } else { if (tid == 0) { auto tc = reinterpret_cast(reductionBuffer); tc[16384] = 0; - result[0] = (Z) sPartials[0].index; + z[0] = (Z) sPartials[0].index; } } diff --git a/libnd4j/include/loops/impl/type_conversions.cpp b/libnd4j/include/loops/impl/type_conversions.cpp index 5a4a9db41..b12ff5796 100644 --- a/libnd4j/include/loops/impl/type_conversions.cpp +++ b/libnd4j/include/loops/impl/type_conversions.cpp @@ -82,7 +82,7 @@ namespace nd4j { // now we actually apply quantization auto func = PRAGMA_THREADS_FOR { for (auto e = start; e < stop; e += increment) { - rz[e] = static_cast(nd4j::math::nd4j_round(1.0f * x[e] / nd4j::math::nd4j_max(amax, amin) * max_byte)); + rz[e] = static_cast(nd4j::math::nd4j_round( 1.0f * static_cast(x[e]) / nd4j::math::nd4j_max(amax, amin) * max_byte)); } }; @@ -180,7 +180,7 @@ PRAGMA_OMP_ATOMIC_ARGS(write) for (auto e = start; e < stop; e += increment) { int el = x[e]; int ael = nd4j::math::nd4j_abs(el) - 1; - z[ael] += el > 0 ? threshold : -threshold; + z[ael] += el > 0 ? static_cast(threshold) : static_cast(-threshold); } }; diff --git a/libnd4j/include/loops/legacy_ops.h b/libnd4j/include/loops/legacy_ops.h index 7de54a858..ea32b154c 100644 --- a/libnd4j/include/loops/legacy_ops.h +++ b/libnd4j/include/loops/legacy_ops.h @@ -80,7 +80,8 @@ (30, LogicalAnd), \ (31, DivideNoNan), \ (32, IGamma), \ - (33, IGammac) + (33, IGammac),\ + (34, PowDerivative) // these ops return same data type as input #define TRANSFORM_SAME_OPS \ diff --git a/libnd4j/include/memory/MemoryCounter.h b/libnd4j/include/memory/MemoryCounter.h new file mode 100644 index 000000000..bf8ff60dc --- /dev/null +++ b/libnd4j/include/memory/MemoryCounter.h @@ -0,0 +1,146 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#ifndef SD_MEMORYCOUNTER_H +#define SD_MEMORYCOUNTER_H + +#include +#include +#include +#include +#include + +namespace nd4j { + namespace memory { + /** + * This class provides simple per-device counter + */ + class ND4J_EXPORT MemoryCounter { + private: + static MemoryCounter* _INSTANCE; + + // used for synchronization + std::mutex _locker; + + // per-device counters + std::map _deviceCounters; + + // TODO: change this wrt heterogenous stuff on next iteration + // per-group counters + std::map _groupCounters; + + // per-device limits + std::map _deviceLimits; + + // per-group limits + std::map _groupLimits; + + MemoryCounter(); + ~MemoryCounter() = default; + + public: + static MemoryCounter *getInstance(); + + /** + * This method checks if allocation of numBytes won't break through per-group or per-device limit + * @param numBytes + * @return TRUE if allocated ammount will keep us below limit, FALSE otherwise + */ + bool validate(Nd4jLong numBytes); + + /** + * This method checks if allocation of numBytes won't break through per-device limit + * @param deviceId + * @param numBytes + * @return TRUE if allocated ammount will keep us below limit, FALSE otherwise + */ + bool validateDevice(int deviceId, Nd4jLong numBytes); + + /** + * This method checks if allocation of numBytes won't break through per-group limit + * @param deviceId + * @param numBytes + * @return TRUE if allocated ammount will keep us below limit, FALSE otherwise + */ + bool validateGroup(nd4j::memory::MemoryType group, Nd4jLong numBytes); + + /** + * This method adds specified number of bytes to specified counter + * @param deviceId + * @param numBytes + */ + void countIn(int deviceId, Nd4jLong numBytes); + void countIn(nd4j::memory::MemoryType group, Nd4jLong numBytes); + + /** + * This method subtracts specified number of bytes from specified counter + * @param deviceId + * @param numBytes + */ + void countOut(int deviceId, Nd4jLong numBytes); + void countOut(nd4j::memory::MemoryType group, Nd4jLong numBytes); + + /** + * This method returns amount of memory allocated on specified device + * @param deviceId + * @return + */ + Nd4jLong allocatedDevice(int deviceId); + + /** + * This method returns amount of memory allocated in specified group of devices + * @param group + * @return + */ + Nd4jLong allocatedGroup(nd4j::memory::MemoryType group); + + /** + * This method allows to set per-device memory limits + * @param deviceId + * @param numBytes + */ + void setDeviceLimit(int deviceId, Nd4jLong numBytes); + + /** + * This method returns current device limit in bytes + * @param deviceId + * @return + */ + Nd4jLong deviceLimit(int deviceId); + + /** + * This method allows to set per-group memory limits + * @param group + * @param numBytes + */ + void setGroupLimit(nd4j::memory::MemoryType group, Nd4jLong numBytes); + + /** + * This method returns current group limit in bytes + * @param group + * @return + */ + Nd4jLong groupLimit(nd4j::memory::MemoryType group); + }; + } +} + + +#endif //SD_MEMORYCOUNTER_H diff --git a/libnd4j/include/memory/MemoryTracker.h b/libnd4j/include/memory/MemoryTracker.h index 78ade5bcc..097d2903d 100644 --- a/libnd4j/include/memory/MemoryTracker.h +++ b/libnd4j/include/memory/MemoryTracker.h @@ -30,6 +30,9 @@ namespace nd4j { namespace memory { + /** + * This class is used for tracking memory allocation wrt their allocation points in code + */ class ND4J_EXPORT MemoryTracker { private: static MemoryTracker* _INSTANCE; diff --git a/libnd4j/include/memory/cuda/Workspace.cu b/libnd4j/include/memory/cuda/Workspace.cu index 18b5ebf3b..aeb6b4752 100644 --- a/libnd4j/include/memory/cuda/Workspace.cu +++ b/libnd4j/include/memory/cuda/Workspace.cu @@ -143,7 +143,7 @@ namespace nd4j { cudaFreeHost((void *)this->_ptrHost); if (this->_allocatedDevice && !_externalized) - cudaFree((void *)this->_ptrHost); + cudaFree((void *)this->_ptrDevice); freeSpills(); } diff --git a/libnd4j/include/memory/impl/MemoryCounter.cpp b/libnd4j/include/memory/impl/MemoryCounter.cpp new file mode 100644 index 000000000..0dc845e37 --- /dev/null +++ b/libnd4j/include/memory/impl/MemoryCounter.cpp @@ -0,0 +1,133 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include "../MemoryCounter.h" +#include +#include +#include + +namespace nd4j { + namespace memory { + + MemoryCounter::MemoryCounter() { + auto numDevices = nd4j::AffinityManager::numberOfDevices(); + + // setting default 0s + for (int e = 0; e < numDevices; e++) { + _deviceLimits[e] = 0; + _deviceCounters[e] = 0; + } + + // setting initial values for limits + _groupLimits[nd4j::memory::MemoryType::HOST] = nd4j::Environment::getInstance()->maxPrimaryMemory(); + _groupLimits[nd4j::memory::MemoryType::DEVICE] = nd4j::Environment::getInstance()->maxSpecialMemory(); + + // setting initial counter values + _groupCounters[nd4j::memory::MemoryType::HOST] = 0; + _groupCounters[nd4j::memory::MemoryType::DEVICE] = 0; + } + + MemoryCounter* MemoryCounter::getInstance() { + if (_INSTANCE == 0) + _INSTANCE = new MemoryCounter(); + + return _INSTANCE; + } + + void MemoryCounter::countIn(int deviceId, Nd4jLong numBytes) { + std::lock_guard lock(_locker); + _deviceCounters[deviceId] += numBytes; + } + + void MemoryCounter::countIn(nd4j::memory::MemoryType group, Nd4jLong numBytes) { + std::lock_guard lock(_locker); + _groupCounters[group] += numBytes; + } + + void MemoryCounter::countOut(int deviceId, Nd4jLong numBytes) { + std::lock_guard lock(_locker); + _deviceCounters[deviceId] -= numBytes; + } + + void MemoryCounter::countOut(nd4j::memory::MemoryType group, Nd4jLong numBytes) { + std::lock_guard lock(_locker); + _groupCounters[group] -= numBytes; + } + + bool MemoryCounter::validate(Nd4jLong numBytes) { + auto deviceId = nd4j::AffinityManager::currentDeviceId(); + return validateDevice(deviceId, numBytes); + } + + bool MemoryCounter::validateDevice(int deviceId, Nd4jLong numBytes) { + std::lock_guard lock(_locker); + auto dLimit = _deviceLimits[deviceId]; + if (dLimit <= 0) + return true; + + auto dAlloc = _deviceCounters[deviceId]; + + return numBytes + dAlloc <= dLimit; + } + + bool MemoryCounter::validateGroup(nd4j::memory::MemoryType group, Nd4jLong numBytes) { + std::lock_guard lock(_locker); + auto gLimit = _groupLimits[group]; + if (gLimit <= 0) + return true; + + auto gAlloc = _groupCounters[group]; + + return numBytes + gAlloc <= gLimit; + } + + Nd4jLong MemoryCounter::allocatedDevice(int deviceId) { + std::lock_guard lock(_locker); + return _deviceCounters[deviceId]; + } + + Nd4jLong MemoryCounter::allocatedGroup(nd4j::memory::MemoryType group) { + std::lock_guard lock(_locker); + return _groupCounters[group]; + } + + void MemoryCounter::setDeviceLimit(int deviceId, Nd4jLong numBytes) { + std::lock_guard lock(_locker); + _deviceLimits[deviceId] = numBytes; + } + + void MemoryCounter::setGroupLimit(nd4j::memory::MemoryType group, Nd4jLong numBytes) { + std::lock_guard lock(_locker); + _groupLimits[group] = numBytes; + } + + Nd4jLong MemoryCounter::deviceLimit(int deviceId) { + std::lock_guard lock(_locker); + return _deviceLimits[deviceId]; + } + + Nd4jLong MemoryCounter::groupLimit(nd4j::memory::MemoryType group) { + std::lock_guard lock(_locker); + return _groupLimits[group]; + } + + MemoryCounter* MemoryCounter::_INSTANCE = 0; + } +} \ No newline at end of file diff --git a/libnd4j/include/op_boilerplate.h b/libnd4j/include/op_boilerplate.h index 4a6561f3b..97f33569b 100644 --- a/libnd4j/include/op_boilerplate.h +++ b/libnd4j/include/op_boilerplate.h @@ -1242,7 +1242,9 @@ #if defined(_MSC_VER) || defined(_WIN64) || defined(_WIN32) || defined(__CLION_IDE__) || defined(__VSCODE__) #define NOT_EXCLUDED(NAME) 1>0 #else -#define NOT_EXCLUDED(NAME) defined(LIBND4J_ALL_OPS) || defined(NAME) +// for now we don't want minifier mechanics working +//#define NOT_EXCLUDED(NAME) defined(LIBND4J_ALL_OPS) || defined(NAME) +#define NOT_EXCLUDED(NAME) 1>0 #endif #ifdef __JAVACPP_HACK__ @@ -1622,4 +1624,9 @@ #define PARAMETRIC_D() [&] (Parameters &p) -> Context* + +#ifdef __CUDABLAS__ +#define checkCudaErrors(ERR) if (ERR != 0) {throw std::runtime_error("CUDA stream synchronization failed");} +#endif + #endif diff --git a/libnd4j/include/ops/BroadcastOpsTuple.h b/libnd4j/include/ops/BroadcastOpsTuple.h index 256e37341..1bcd2df8b 100644 --- a/libnd4j/include/ops/BroadcastOpsTuple.h +++ b/libnd4j/include/ops/BroadcastOpsTuple.h @@ -52,6 +52,9 @@ namespace nd4j { static BroadcastOpsTuple Subtract(); static BroadcastOpsTuple IGamma(); static BroadcastOpsTuple IGammac(); + + static BroadcastOpsTuple Pow(); + static BroadcastOpsTuple PowDerivative(); }; } diff --git a/libnd4j/include/ops/declarable/CustomOperations.h b/libnd4j/include/ops/declarable/CustomOperations.h index 5aea215c1..0b0e42809 100644 --- a/libnd4j/include/ops/declarable/CustomOperations.h +++ b/libnd4j/include/ops/declarable/CustomOperations.h @@ -40,7 +40,11 @@ #include #include #include +#include +#include +#include #include +#include #include #include #include diff --git a/libnd4j/include/ops/declarable/OpRegistrator.h b/libnd4j/include/ops/declarable/OpRegistrator.h index effb71c67..789b361f3 100644 --- a/libnd4j/include/ops/declarable/OpRegistrator.h +++ b/libnd4j/include/ops/declarable/OpRegistrator.h @@ -23,10 +23,11 @@ #include #include -#include +#include #include #include #include +#include // handlers part #include @@ -66,8 +67,8 @@ namespace nd4j { std::vector _uniqueD; // pointers to platform-specific helpers - std::map _helpersLH; - std::map _helpersH; + std::map, nd4j::ops::platforms::PlatformHelper*> _helpersLH; + std::map, nd4j::ops::platforms::PlatformHelper*> _helpersH; std::vector _uniqueH; std::mutex _locker; @@ -98,13 +99,13 @@ namespace nd4j { void registerHelper(nd4j::ops::platforms::PlatformHelper* op); - bool hasHelper(Nd4jLong hash); + bool hasHelper(Nd4jLong hash, samediff::Engine engine); nd4j::ops::DeclarableOp* getOperation(const char *name); nd4j::ops::DeclarableOp* getOperation(Nd4jLong hash); nd4j::ops::DeclarableOp* getOperation(std::string &name); - nd4j::ops::platforms::PlatformHelper* getPlatformHelper(Nd4jLong hash); + nd4j::ops::platforms::PlatformHelper* getPlatformHelper(Nd4jLong hash, samediff::Engine engine); std::vector getAllHashes(); diff --git a/libnd4j/include/ops/declarable/PlatformHelper.h b/libnd4j/include/ops/declarable/PlatformHelper.h index 6fbbae3b8..afa0107fc 100644 --- a/libnd4j/include/ops/declarable/PlatformHelper.h +++ b/libnd4j/include/ops/declarable/PlatformHelper.h @@ -22,6 +22,7 @@ #define SD_PLATFORMHELPER_H #include +#include #include #include #include @@ -35,18 +36,23 @@ namespace nd4j { */ class ND4J_EXPORT PlatformHelper { protected: + // target engine for this impl + samediff::Engine _engine; + // name of the operation this helper is built for std::string _name; // hash of the operation this helper is built for Nd4jLong _hash; public: - PlatformHelper(const char *name); + PlatformHelper(const char *name, samediff::Engine engine); ~PlatformHelper() = default; std::string name(); + samediff::Engine engine(); + Nd4jLong hash(); /** diff --git a/libnd4j/include/ops/declarable/generic/activations/crelu.cpp b/libnd4j/include/ops/declarable/generic/activations/crelu.cpp index 42b171226..8ce3cbf75 100644 --- a/libnd4j/include/ops/declarable/generic/activations/crelu.cpp +++ b/libnd4j/include/ops/declarable/generic/activations/crelu.cpp @@ -32,21 +32,19 @@ namespace nd4j { REQUIRE_TRUE(x->isR(), 0, "CRELU: input must be real type"); auto tmp = x->dup(); - tmp->applyTransform(nd4j::transform::Neg, nullptr, nullptr); + tmp.applyTransform(nd4j::transform::Neg, tmp); auto z = OUTPUT_VARIABLE(0); - helpers::concat(block.launchContext(), {x, tmp}, *z, x->rankOf()-1); + helpers::concat(block.launchContext(), {x, &tmp}, *z, x->rankOf()-1); // NDArrayFactory::concat({x, tmp}, -1, z); // TODO: make this configurable? double threshold = 0.0; - z->applyScalar(nd4j::scalar::RELU, threshold); + z->applyScalar(nd4j::scalar::RELU, threshold, *z); STORE_RESULT(z); - delete tmp; - return Status::OK(); } @@ -61,7 +59,7 @@ namespace nd4j { std::vector shape; for (int e = 0; e < shape::rank(inShape); e++) shape.emplace_back(shape::shapeOf(inShape)[e]); - + shape[shape.size()-1] *= 2; auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inShape), shape::order(inShape), shape); @@ -94,7 +92,7 @@ namespace nd4j { auto pos = dec->at(0); auto neg = dec->at(1); - pos->applyPairwiseTransform(nd4j::pairwise::Subtract, neg, epsilon, nullptr); + pos->applyPairwiseTransform(nd4j::pairwise::Subtract, *neg, *epsilon); delete tmpResult; delete dec; diff --git a/libnd4j/include/ops/declarable/generic/activations/cube.cpp b/libnd4j/include/ops/declarable/generic/activations/cube.cpp index 075da4b00..75a33ab79 100644 --- a/libnd4j/include/ops/declarable/generic/activations/cube.cpp +++ b/libnd4j/include/ops/declarable/generic/activations/cube.cpp @@ -31,9 +31,9 @@ namespace nd4j { auto input = INPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0); - input->applyTransform(nd4j::transform::Cube, output, nullptr); + input->applyTransform(nd4j::transform::Cube, *output); STORE_RESULT(output); - + return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/activations/elu.cpp b/libnd4j/include/ops/declarable/generic/activations/elu.cpp index 03670ddab..85becd858 100644 --- a/libnd4j/include/ops/declarable/generic/activations/elu.cpp +++ b/libnd4j/include/ops/declarable/generic/activations/elu.cpp @@ -32,7 +32,7 @@ namespace nd4j { const auto alpha = block.numT() > 0 ? T_ARG(0) : 1.f; - input->applyScalar(nd4j::scalar::ELU, alpha, output); + input->applyScalar(nd4j::scalar::ELU, alpha, *output); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/activations/hardsigmoid.cpp b/libnd4j/include/ops/declarable/generic/activations/hardsigmoid.cpp index 40a98575a..d8b937a0a 100644 --- a/libnd4j/include/ops/declarable/generic/activations/hardsigmoid.cpp +++ b/libnd4j/include/ops/declarable/generic/activations/hardsigmoid.cpp @@ -30,9 +30,9 @@ namespace nd4j { auto input = INPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0); - input->applyTransform(nd4j::transform::HardSigmoid, output, nullptr); + input->applyTransform(nd4j::transform::HardSigmoid, *output); STORE_RESULT(output); - + return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/activations/hardtanh.cpp b/libnd4j/include/ops/declarable/generic/activations/hardtanh.cpp index 287dcc113..a4d9fe4e6 100644 --- a/libnd4j/include/ops/declarable/generic/activations/hardtanh.cpp +++ b/libnd4j/include/ops/declarable/generic/activations/hardtanh.cpp @@ -30,9 +30,9 @@ namespace nd4j { auto input = INPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0); - input->applyTransform(nd4j::transform::HardTanh, output, nullptr); + input->applyTransform(nd4j::transform::HardTanh, *output); STORE_RESULT(output); - + return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/activations/identity.cpp b/libnd4j/include/ops/declarable/generic/activations/identity.cpp index f65448d92..5ae5b0690 100644 --- a/libnd4j/include/ops/declarable/generic/activations/identity.cpp +++ b/libnd4j/include/ops/declarable/generic/activations/identity.cpp @@ -30,7 +30,7 @@ namespace nd4j { auto z = this->getZ(block); // just for lulz - first->applyTransform(nd4j::transform::Identity, z, nullptr); + first->applyTransform(nd4j::transform::Identity, *z); STORE_RESULT(*z); diff --git a/libnd4j/include/ops/declarable/generic/activations/identity_n.cpp b/libnd4j/include/ops/declarable/generic/activations/identity_n.cpp index 0bb47e4b4..b96ab9a3f 100644 --- a/libnd4j/include/ops/declarable/generic/activations/identity_n.cpp +++ b/libnd4j/include/ops/declarable/generic/activations/identity_n.cpp @@ -33,7 +33,7 @@ namespace nd4j { auto x = INPUT_VARIABLE(i); auto z = OUTPUT_VARIABLE(i); - x->applyTransform(transform::Identity, z, nullptr); + x->applyTransform(transform::Identity, *z); } } diff --git a/libnd4j/include/ops/declarable/generic/activations/lrelu.cpp b/libnd4j/include/ops/declarable/generic/activations/lrelu.cpp index ef65c4822..80404135f 100644 --- a/libnd4j/include/ops/declarable/generic/activations/lrelu.cpp +++ b/libnd4j/include/ops/declarable/generic/activations/lrelu.cpp @@ -31,7 +31,7 @@ namespace nd4j { float alpha = block.numT() > 0 ? T_ARG(0) : 0.01f; - input->applyScalar(nd4j::scalar::LeakyRELU, alpha, output); + input->applyScalar(nd4j::scalar::LeakyRELU, alpha, *output); STORE_RESULT(output); return Status::OK(); diff --git a/libnd4j/include/ops/declarable/generic/activations/rationaltanh.cpp b/libnd4j/include/ops/declarable/generic/activations/rationaltanh.cpp index 7e85ab9bf..5bae4d2dc 100644 --- a/libnd4j/include/ops/declarable/generic/activations/rationaltanh.cpp +++ b/libnd4j/include/ops/declarable/generic/activations/rationaltanh.cpp @@ -30,9 +30,9 @@ namespace nd4j { auto input = INPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0); - input->applyTransform(nd4j::transform::RationalTanh, output, nullptr); + input->applyTransform(nd4j::transform::RationalTanh, *output); STORE_RESULT(output); - + return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/activations/rectifiedtanh.cpp b/libnd4j/include/ops/declarable/generic/activations/rectifiedtanh.cpp index 69d5faa2a..40738c343 100644 --- a/libnd4j/include/ops/declarable/generic/activations/rectifiedtanh.cpp +++ b/libnd4j/include/ops/declarable/generic/activations/rectifiedtanh.cpp @@ -30,9 +30,9 @@ namespace nd4j { auto input = INPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0); - input->applyTransform(nd4j::transform::RectifiedTanh, output, nullptr); + input->applyTransform(nd4j::transform::RectifiedTanh, *output); STORE_RESULT(output); - + return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/activations/relu.cpp b/libnd4j/include/ops/declarable/generic/activations/relu.cpp index 3b556ef1f..2c8b978ff 100644 --- a/libnd4j/include/ops/declarable/generic/activations/relu.cpp +++ b/libnd4j/include/ops/declarable/generic/activations/relu.cpp @@ -32,7 +32,7 @@ namespace nd4j { auto scalar = block.numT() > 0 ? block.getTArguments()->at(0) : 0.0; - first->applyScalar(nd4j::scalar::RELU, scalar, z); + first->applyScalar(nd4j::scalar::RELU, scalar, *z); STORE_RESULT(*z); diff --git a/libnd4j/include/ops/declarable/generic/activations/relu6.cpp b/libnd4j/include/ops/declarable/generic/activations/relu6.cpp index a6861b3f7..cf12d1592 100644 --- a/libnd4j/include/ops/declarable/generic/activations/relu6.cpp +++ b/libnd4j/include/ops/declarable/generic/activations/relu6.cpp @@ -33,8 +33,8 @@ CONFIGURABLE_OP_IMPL(relu6, 1, 1, true, 1, 0) { auto input = INPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0); - input->applyScalar(nd4j::scalar::RELU6, T_ARG(0), output); - + input->applyScalar(nd4j::scalar::RELU6, T_ARG(0), *output); + return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/activations/selu.cpp b/libnd4j/include/ops/declarable/generic/activations/selu.cpp index 20ac42db2..ca16f6832 100644 --- a/libnd4j/include/ops/declarable/generic/activations/selu.cpp +++ b/libnd4j/include/ops/declarable/generic/activations/selu.cpp @@ -30,7 +30,7 @@ namespace nd4j { auto first = INPUT_VARIABLE(0); auto z = OUTPUT_VARIABLE(0); - first->applyTransform(nd4j::transform::SELU, z, nullptr); + first->applyTransform(nd4j::transform::SELU, *z); STORE_RESULT(*z); diff --git a/libnd4j/include/ops/declarable/generic/activations/sigmoid.cpp b/libnd4j/include/ops/declarable/generic/activations/sigmoid.cpp index d6f341298..fb8e507a7 100644 --- a/libnd4j/include/ops/declarable/generic/activations/sigmoid.cpp +++ b/libnd4j/include/ops/declarable/generic/activations/sigmoid.cpp @@ -29,7 +29,7 @@ namespace nd4j { auto first = INPUT_VARIABLE(0); auto z = OUTPUT_VARIABLE(0); - first->applyTransform(nd4j::transform::Sigmoid, z, nullptr); + first->applyTransform(nd4j::transform::Sigmoid, *z); STORE_RESULT(*z); diff --git a/libnd4j/include/ops/declarable/generic/activations/softplus.cpp b/libnd4j/include/ops/declarable/generic/activations/softplus.cpp index 7b3ba74f2..bd538ab71 100644 --- a/libnd4j/include/ops/declarable/generic/activations/softplus.cpp +++ b/libnd4j/include/ops/declarable/generic/activations/softplus.cpp @@ -30,7 +30,7 @@ namespace nd4j { auto first = INPUT_VARIABLE(0); auto z = OUTPUT_VARIABLE(0); - first->applyTransform(nd4j::transform::SoftPlus, z, nullptr); + first->applyTransform(nd4j::transform::SoftPlus, *z); STORE_RESULT(*z); diff --git a/libnd4j/include/ops/declarable/generic/activations/softsign.cpp b/libnd4j/include/ops/declarable/generic/activations/softsign.cpp index 50ce3a817..99e52ab68 100644 --- a/libnd4j/include/ops/declarable/generic/activations/softsign.cpp +++ b/libnd4j/include/ops/declarable/generic/activations/softsign.cpp @@ -30,7 +30,7 @@ namespace nd4j { auto first = INPUT_VARIABLE(0); auto z = OUTPUT_VARIABLE(0); - first->applyTransform(nd4j::transform::SoftSign, z, nullptr); + first->applyTransform(nd4j::transform::SoftSign, *z); STORE_RESULT(*z); diff --git a/libnd4j/include/ops/declarable/generic/activations/tanh.cpp b/libnd4j/include/ops/declarable/generic/activations/tanh.cpp index b27d07806..5677da728 100644 --- a/libnd4j/include/ops/declarable/generic/activations/tanh.cpp +++ b/libnd4j/include/ops/declarable/generic/activations/tanh.cpp @@ -30,7 +30,7 @@ namespace nd4j { auto first = INPUT_VARIABLE(0); auto z = OUTPUT_VARIABLE(0); - first->applyTransform(nd4j::transform::Tanh, z, nullptr); + first->applyTransform(nd4j::transform::Tanh, *z); STORE_RESULT(*z); diff --git a/libnd4j/include/ops/declarable/generic/bitwise/bitwise_and.cpp b/libnd4j/include/ops/declarable/generic/bitwise/bitwise_and.cpp index 52d01429f..6eb3728ed 100644 --- a/libnd4j/include/ops/declarable/generic/bitwise/bitwise_and.cpp +++ b/libnd4j/include/ops/declarable/generic/bitwise/bitwise_and.cpp @@ -34,7 +34,7 @@ namespace nd4j { BROADCAST_CHECK_EMPTY(x,y,z); - x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::IntOps::IntAnd, pairwise::IntOps::IntAnd, broadcast::IntOps::IntAnd), y, z, false); + x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::IntOps::IntAnd, pairwise::IntOps::IntAnd, broadcast::IntOps::IntAnd), *y, *z, false); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/bitwise/bitwise_or.cpp b/libnd4j/include/ops/declarable/generic/bitwise/bitwise_or.cpp index b8469d83a..4683e3f3e 100644 --- a/libnd4j/include/ops/declarable/generic/bitwise/bitwise_or.cpp +++ b/libnd4j/include/ops/declarable/generic/bitwise/bitwise_or.cpp @@ -34,7 +34,7 @@ namespace nd4j { BROADCAST_CHECK_EMPTY(x,y,z); - x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::IntOps::IntOr, pairwise::IntOps::IntOr, broadcast::IntOps::IntOr), y, z, false); + x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::IntOps::IntOr, pairwise::IntOps::IntOr, broadcast::IntOps::IntOr), *y, *z, false); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/bitwise/bitwise_xor.cpp b/libnd4j/include/ops/declarable/generic/bitwise/bitwise_xor.cpp index f7f3f479a..1d79a84f3 100644 --- a/libnd4j/include/ops/declarable/generic/bitwise/bitwise_xor.cpp +++ b/libnd4j/include/ops/declarable/generic/bitwise/bitwise_xor.cpp @@ -34,7 +34,7 @@ namespace nd4j { BROADCAST_CHECK_EMPTY(x,y,z); - x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::IntOps::IntXor, pairwise::IntOps::IntXor, broadcast::IntOps::IntXor), y, z, false); + x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::IntOps::IntXor, pairwise::IntOps::IntXor, broadcast::IntOps::IntXor), *y, *z, false); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/bitwise/cyclic_rshift.cpp b/libnd4j/include/ops/declarable/generic/bitwise/cyclic_rshift.cpp index 89d380d02..7a2c61c95 100644 --- a/libnd4j/include/ops/declarable/generic/bitwise/cyclic_rshift.cpp +++ b/libnd4j/include/ops/declarable/generic/bitwise/cyclic_rshift.cpp @@ -34,7 +34,7 @@ namespace nd4j { BROADCAST_CHECK_EMPTY(x,y,z); - x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::CyclicShiftRight, pairwise::CyclicShiftRight, broadcast::CyclicShiftRight), y, z, false); + x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::CyclicShiftRight, pairwise::CyclicShiftRight, broadcast::CyclicShiftRight), *y, *z, false); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/bitwise/cyclic_shift.cpp b/libnd4j/include/ops/declarable/generic/bitwise/cyclic_shift.cpp index f18314910..0a1c3d5c8 100644 --- a/libnd4j/include/ops/declarable/generic/bitwise/cyclic_shift.cpp +++ b/libnd4j/include/ops/declarable/generic/bitwise/cyclic_shift.cpp @@ -34,7 +34,7 @@ namespace nd4j { BROADCAST_CHECK_EMPTY(x,y,z); - x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::CyclicShiftLeft, pairwise::CyclicShiftLeft, broadcast::CyclicShiftLeft), y, z, false); + x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::CyclicShiftLeft, pairwise::CyclicShiftLeft, broadcast::CyclicShiftLeft), *y, *z, false); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/bitwise/rshift.cpp b/libnd4j/include/ops/declarable/generic/bitwise/rshift.cpp index 36b0defd0..0543cc72d 100644 --- a/libnd4j/include/ops/declarable/generic/bitwise/rshift.cpp +++ b/libnd4j/include/ops/declarable/generic/bitwise/rshift.cpp @@ -34,7 +34,7 @@ namespace nd4j { BROADCAST_CHECK_EMPTY(x,y,z); - x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::ShiftRight, pairwise::ShiftRight, broadcast::ShiftRight), y, z, false); + x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::ShiftRight, pairwise::ShiftRight, broadcast::ShiftRight), *y, *z, false); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/bitwise/shift.cpp b/libnd4j/include/ops/declarable/generic/bitwise/shift.cpp index ab4ed9880..4f0fec82d 100644 --- a/libnd4j/include/ops/declarable/generic/bitwise/shift.cpp +++ b/libnd4j/include/ops/declarable/generic/bitwise/shift.cpp @@ -34,7 +34,7 @@ namespace nd4j { BROADCAST_CHECK_EMPTY(x,y,z); - x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::ShiftLeft, pairwise::ShiftLeft, broadcast::ShiftLeft), y, z, false); + x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::ShiftLeft, pairwise::ShiftLeft, broadcast::ShiftLeft), *y, *z, false); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/blas/axpy.cpp b/libnd4j/include/ops/declarable/generic/blas/axpy.cpp index 1b949eb35..65d20589f 100644 --- a/libnd4j/include/ops/declarable/generic/blas/axpy.cpp +++ b/libnd4j/include/ops/declarable/generic/blas/axpy.cpp @@ -37,14 +37,14 @@ namespace nd4j { if (block.width() > 2) { auto alpha = INPUT_VARIABLE(2); - REQUIRE_TRUE(alpha->isScalar(), 0, "Axpy: alpha argument should be scalar or TArg"); + REQUIRE_TRUE(alpha->isScalar(), 0, "Axpy: alpha argument should be scalar or TArg"); } else if (block.getTArguments()->size() > 0) { a = T_ARG(0); } ExtraArguments arguments({a}); - y->applyPairwiseTransform(pairwise::Axpy, x, z, &arguments); + y->applyPairwiseTransform(pairwise::Axpy, *x, *z, &arguments); return ND4J_STATUS_OK; } diff --git a/libnd4j/include/ops/declarable/generic/blas/svd.cpp b/libnd4j/include/ops/declarable/generic/blas/svd.cpp index d62c621dd..8db2c2ff3 100644 --- a/libnd4j/include/ops/declarable/generic/blas/svd.cpp +++ b/libnd4j/include/ops/declarable/generic/blas/svd.cpp @@ -33,8 +33,12 @@ CUSTOM_OP_IMPL(svd, 1, 1, false, 0, 3) { const int rank = x->rankOf(); REQUIRE_TRUE(rank >= 2 , 0, "SVD OP: the rank of input array must be >=2, but got %i instead!", rank); - const bool fullUV = (bool)INT_ARG(0); + bool fullUV = (bool)INT_ARG(0); const bool calcUV = (bool)INT_ARG(1); + + if(calcUV == false) + fullUV = false; + const int switchNum = INT_ARG(2); // #ifndef __CUDABLAS__ diff --git a/libnd4j/include/ops/declarable/generic/boolean/boolean_not.cpp b/libnd4j/include/ops/declarable/generic/boolean/boolean_not.cpp index 651e21aab..83cbc9004 100644 --- a/libnd4j/include/ops/declarable/generic/boolean/boolean_not.cpp +++ b/libnd4j/include/ops/declarable/generic/boolean/boolean_not.cpp @@ -29,7 +29,7 @@ namespace nd4j { auto x = INPUT_VARIABLE(0); auto z = OUTPUT_VARIABLE(0); - x->applyTransform(transform::Not, z, nullptr); + x->applyTransform(transform::Not, *z); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/boolean/select.cpp b/libnd4j/include/ops/declarable/generic/boolean/select.cpp index 56b2c3238..92cb5e421 100644 --- a/libnd4j/include/ops/declarable/generic/boolean/select.cpp +++ b/libnd4j/include/ops/declarable/generic/boolean/select.cpp @@ -70,17 +70,13 @@ namespace nd4j { auto tadsY = y->allTensorsAlongDimension(dims); auto tadsZ = z->allTensorsAlongDimension(dims); - for (int e = 0; e < tadsX->size(); e++) { + for (int e = 0; e < tadsX.size(); e++) { if (!cond->e(e)) { - tadsZ->at(e)->assign(tadsY->at(e)); + tadsZ.at(e)->assign(tadsY.at(e)); } else { - tadsZ->at(e)->assign(tadsX->at(e)); + tadsZ.at(e)->assign(tadsX.at(e)); } } - - delete tadsX; - delete tadsY; - delete tadsZ; } } diff --git a/libnd4j/include/ops/declarable/generic/boolean/where.cpp b/libnd4j/include/ops/declarable/generic/boolean/where.cpp index b5800d3d6..6aa646cb6 100644 --- a/libnd4j/include/ops/declarable/generic/boolean/where.cpp +++ b/libnd4j/include/ops/declarable/generic/boolean/where.cpp @@ -59,17 +59,13 @@ namespace nd4j { auto tadsY = y->allTensorsAlongDimension(dims); auto tadsZ = z->allTensorsAlongDimension(dims); - for (int e = 0; e < tadsX->size(); e++) { + for (int e = 0; e < tadsX.size(); e++) { if (!condition->e(e)) { - tadsZ->at(e)->assign(tadsY->at(e)); + tadsZ.at(e)->assign(tadsY.at(e)); } else { - tadsZ->at(e)->assign(tadsX->at(e)); + tadsZ.at(e)->assign(tadsX.at(e)); } } - - delete tadsX; - delete tadsY; - delete tadsZ; } } else { // in this case we return 2D matrix, which basically contains coordinates fo true diff --git a/libnd4j/include/ops/declarable/generic/boolean/where_np.cpp b/libnd4j/include/ops/declarable/generic/boolean/where_np.cpp index 19a9a0ce9..c06ef07d1 100644 --- a/libnd4j/include/ops/declarable/generic/boolean/where_np.cpp +++ b/libnd4j/include/ops/declarable/generic/boolean/where_np.cpp @@ -89,16 +89,12 @@ namespace nd4j { auto tadsY = y->allTensorsAlongDimension(dims); auto tadsZ = z->allTensorsAlongDimension(dims); - for (int e = 0; e < tadsX->size(); e++) { + for (int e = 0; e < tadsX.size(); e++) { if (!condition->e(e)) - tadsZ->at(e)->assign(tadsY->at(e)); + tadsZ.at(e)->assign(tadsY.at(e)); else - tadsZ->at(e)->assign(tadsX->at(e)); + tadsZ.at(e)->assign(tadsX.at(e)); } - - delete tadsX; - delete tadsY; - delete tadsZ; } } else { // in this case we return 2D matrix, which basically contains coordinates fo true diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/add.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/add.cpp index 7d7e6f965..415a2c37a 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/add.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/add.cpp @@ -30,7 +30,7 @@ namespace nd4j { auto x = INPUT_VARIABLE(0); auto y = INPUT_VARIABLE(1); auto z = OUTPUT_VARIABLE(0); - + BROADCAST_CHECK_EMPTY(x,y,z); auto tZ = BroadcastHelper::broadcastApply(nd4j::BroadcastOpsTuple::Add(), x, y, z); @@ -82,14 +82,12 @@ namespace nd4j { if (axisX.size() > 0) { auto sum = epsNext->reduceAlongDimension(nd4j::reduce::Sum, axisX); gradX->assign(sum); - delete sum; - } else + } else gradX->assign(epsNext); if (axisY.size() > 0) { auto sum = epsNext->reduceAlongDimension(nd4j::reduce::Sum, axisY); gradY->assign(sum); - delete sum; } else gradY->assign(epsNext); } diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/assign.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/assign.cpp index f11e18be6..24b673a8c 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/assign.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/assign.cpp @@ -39,7 +39,7 @@ namespace nd4j { else if (tZ != z) { OVERWRITE_RESULT(tZ); } - + return ND4J_STATUS_OK; } DECLARE_SYN(set, assign); @@ -80,7 +80,6 @@ namespace nd4j { if (axisY.size() > 0) { auto sum = epsNext->reduceAlongDimension(nd4j::reduce::Sum, axisY); gradY->assign(sum); - delete sum; } else gradY->assign(epsNext); } @@ -98,7 +97,7 @@ namespace nd4j { Nd4jLong *shapeE; Nd4jLong *shapeG; - + COPY_SHAPE(x, shapeE); COPY_SHAPE(y, shapeG); diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/atan2.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/atan2.cpp index 8b894ac6d..32a7d7d65 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/atan2.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/atan2.cpp @@ -28,7 +28,7 @@ namespace nd4j { namespace ops { BROADCASTABLE_OP_IMPL(tf_atan2, 0, 0) { - + auto y = INPUT_VARIABLE(0); auto x = INPUT_VARIABLE(1); auto z = OUTPUT_VARIABLE(0); @@ -36,8 +36,8 @@ BROADCASTABLE_OP_IMPL(tf_atan2, 0, 0) { BROADCAST_CHECK_EMPTY(x,y,z); // auto tZ = BroadcastHelper::template broadcastApply>(y, x, z); - x->applyTrueBroadcast(nd4j::BroadcastOpsTuple::custom(scalar::Atan2, pairwise::Atan2, broadcast::Atan2), y, z, true); - + x->applyTrueBroadcast(nd4j::BroadcastOpsTuple::custom(scalar::Atan2, pairwise::Atan2, broadcast::Atan2), *y, *z, true); + // if (tZ == nullptr) // return ND4J_STATUS_KERNEL_FAILURE; // else if (tZ != z) { diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/divide.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/divide.cpp index 84d739ee2..1811781f1 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/divide.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/divide.cpp @@ -81,7 +81,7 @@ namespace nd4j { // Y gradient //epsNext->applyTriplewiseLambda(x, y, lambdaY, gradY); gradY->assign((*epsNext) * (*x) / ((*y) * (*y))); - gradY->applyTransform(transform::Neg, nullptr, nullptr); + gradY->applyTransform(transform::Neg, *gradY); } else if (y->isScalar()) { // scalar case @@ -91,17 +91,17 @@ namespace nd4j { //tmpX.printBuffer("SumX"); //tmp.printBuffer("Sum Eps"); gradY->assign(tmp * tmpX / ((*y) * (*y))); - gradY->applyTransform(transform::Neg, nullptr, nullptr); + gradY->applyTransform(transform::Neg, *gradY); - //epsNext->applyLambda(lambdaS, gradX); - epsNext->applyScalarArr(scalar::Divide, y, gradX, nullptr); + //epsNext->applyLambda(lambdaS, *gradX); + epsNext->applyScalarArr(scalar::Divide, *y, *gradX); } else { // broadcast case auto preX = *epsNext / *y; NDArray negX(*x); - x->applyTransform(transform::Neg, &negX); + x->applyTransform(transform::Neg, negX); auto preY = *epsNext * negX / ((*y) * (*y)); auto axisX = ShapeUtils::evalBroadcastBackwardAxis(x->shapeInfo(), epsNext->shapeInfo()); @@ -110,14 +110,12 @@ namespace nd4j { if (axisX.size() > 0) { auto sum = preX.reduceAlongDimension(reduce::Sum, axisX); gradX->assign(sum); - delete sum; - } else + } else gradX->assign(preX); if (axisY.size() > 0) { auto sum = preY.reduceAlongDimension(reduce::Sum, axisY); gradY->assign(sum); - delete sum; } else gradY->assign(preY); } diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/floormod.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/floormod.cpp index ea60c2f21..d442d89e7 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/floormod.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/floormod.cpp @@ -69,7 +69,7 @@ namespace nd4j { std::unique_ptr tmpResult(op.execute({x, y}, {}, {}, {})); if (gradY->rankOf() == gradX->rankOf()) - epsNext->applyPairwiseTransform(pairwise::Multiply, tmpResult->at(0), gradY, nullptr); + epsNext->applyPairwiseTransform(pairwise::Multiply, *tmpResult->at(0), *gradY); else // epsNext is greater than gradY { std::vector dims(epsNext->rankOf() * 2); @@ -78,7 +78,7 @@ namespace nd4j { dims[d * 2 + 1] = 1; } auto tempIn((*tmpResult->at(0))(dims)); - (*epsNext)(dims).applyPairwiseTransform(pairwise::Multiply, &tempIn, gradY, nullptr); + (*epsNext)(dims).applyPairwiseTransform(pairwise::Multiply, tempIn, *gradY); } return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/multiply.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/multiply.cpp index d2a9f6260..d50ffacaa 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/multiply.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/multiply.cpp @@ -79,42 +79,42 @@ CUSTOM_OP_IMPL(multiply_bp, 3, 2, false, 0, 0) { const Nd4jLong yLen = y->lengthOf(); if(x->isScalar() && y->isScalar()) { // both are scalars - y->applyPairwiseTransform(pairwise::Multiply, dLdz, dLdx, nullptr); - x->applyPairwiseTransform(pairwise::Multiply, dLdz, dLdy, nullptr); + y->applyPairwiseTransform(pairwise::Multiply, *dLdz, *dLdx); + x->applyPairwiseTransform(pairwise::Multiply, *dLdz, *dLdy); //dLdx->assign((*y) * (*dLdz)); //dLdy->assign((*x) * (*dLdz)); } else if(x->isScalar()) { // x is scalar and y is not dLdx->assign((*y * *dLdz).reduceNumber(reduce::Sum)); - dLdz->applyScalarArr(scalar::Multiply, x, dLdy, nullptr); + dLdz->applyScalarArr(scalar::Multiply, *x, *dLdy); //dLdz->applyTrueBroadcast(broadcast::Multiply, x, dLdy, true); } else if(y->isScalar()) { // y is scalar and x is not dLdy->assign((*x * *dLdz).reduceNumber(reduce::Sum)); - dLdz->applyScalarArr(scalar::Multiply, y, dLdx); - } + dLdz->applyScalarArr(scalar::Multiply, *y, *dLdx); + } else if(x->isSameShape(y)) { - x->applyPairwiseTransform(pairwise::Multiply, dLdz, dLdy, nullptr); - y->applyPairwiseTransform(pairwise::Multiply, dLdz, dLdx, nullptr); + x->applyPairwiseTransform(pairwise::Multiply, *dLdz, *dLdy); + y->applyPairwiseTransform(pairwise::Multiply, *dLdz, *dLdx); } else if (x->isSameShape(dLdz)) { - + auto yTiled = NDArray(dLdz, false, block.launchContext()); y->tile(yTiled); std::vector axesForY = ShapeUtils::evalBroadcastBackwardAxis(y->getShapeInfo(), dLdz->getShapeInfo()); - - dLdy->assign( (*x * *dLdz).reduceAlongDims(reduce::Sum, axesForY) ); - yTiled.applyPairwiseTransform(pairwise::Multiply, dLdz, dLdx, nullptr); - } + + dLdy->assign( (*x * *dLdz).reduceAlongDimension(reduce::Sum, axesForY) ); + yTiled.applyPairwiseTransform(pairwise::Multiply, *dLdz, *dLdx); + } else if (y->isSameShape(dLdz)) { auto xTiled = NDArray(dLdz, false, block.launchContext()); x->tile(xTiled); std::vector axesForX = ShapeUtils::evalBroadcastBackwardAxis(x->getShapeInfo(), dLdz->getShapeInfo()); - - dLdx->assign( (*y * *dLdz).reduceAlongDims(reduce::Sum, axesForX) ); - xTiled.applyPairwiseTransform(pairwise::Multiply, dLdz, dLdy, nullptr); + + dLdx->assign( (*y * *dLdz).reduceAlongDimension(reduce::Sum, axesForX) ); + xTiled.applyPairwiseTransform(pairwise::Multiply, *dLdz, *dLdy); } else { @@ -124,16 +124,16 @@ CUSTOM_OP_IMPL(multiply_bp, 3, 2, false, 0, 0) { y->tile(yTiled); std::vector axesForX = ShapeUtils::evalBroadcastBackwardAxis(x->getShapeInfo(), dLdz->getShapeInfo()); std::vector axesForY = ShapeUtils::evalBroadcastBackwardAxis(y->getShapeInfo(), dLdz->getShapeInfo()); - - dLdx->assign( (*y * *dLdz).reduceAlongDims(reduce::Sum, axesForX) ); - dLdy->assign( (*x * *dLdz).reduceAlongDims(reduce::Sum, axesForY) ); + + dLdx->assign( (*y * *dLdz).reduceAlongDimension(reduce::Sum, axesForX) ); + dLdy->assign( (*x * *dLdz).reduceAlongDimension(reduce::Sum, axesForY) ); } return Status::OK(); } DECLARE_SHAPE_FN(multiply_bp) { - + auto xShapeInfo = inputShape->at(0); auto yShapeInfo = inputShape->at(1); @@ -181,8 +181,8 @@ DECLARE_SHAPE_FN(multiply_bp) { T tmpX = x->template reduceNumber>(); gradY->assign(tmpX); - - epsNext->applyLambda(lambdaS, gradX); + + epsNext->applyLambda(lambdaS, *gradX); } else { // broadcast case @@ -201,7 +201,7 @@ DECLARE_SHAPE_FN(multiply_bp) { auto sum = preX->template reduceAlongDimension>(axisX); gradX->assign(sum); delete sum; - } else + } else gradX->assign(preX); if (axisY.size() > 0) { diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/pow.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/pow.cpp index 7f7efd80c..56f77737d 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/pow.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/pow.cpp @@ -16,6 +16,7 @@ // // @author raver119@gmail.com +// @author Oleh Semeniv (oleg.semeniv@gmail.com) // #include @@ -25,7 +26,7 @@ #include namespace nd4j { - namespace ops { +namespace ops { BROADCASTABLE_OP_IMPL(Pow, 0, 0) { auto x = INPUT_VARIABLE(0); auto y = INPUT_VARIABLE(1); @@ -51,7 +52,76 @@ namespace nd4j { ->setAllowedInputTypes(1, {ALL_FLOATS, ALL_INTS}) ->setAllowedOutputTypes(0, {ALL_FLOATS, ALL_INTS}); } - } + + CUSTOM_OP_IMPL(Pow_bp, 3, 2, false, 0, 0) { + + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto dLdz = INPUT_VARIABLE(2); + + auto dLdx = OUTPUT_VARIABLE(0); + auto dLdy = OUTPUT_VARIABLE(1); + + Nd4jLong* dLdzShapeInfo = nullptr; + const bool areShapesBroadcastable = ShapeUtils::evalBroadcastShapeInfo(x->getShapeInfo(), y->getShapeInfo(), true, dLdzShapeInfo, block.getWorkspace()); + REQUIRE_TRUE(areShapesBroadcastable, 0, "POW_BP OP: the shapes of x %s" + " and y %s are not suitable for broadcast !", + ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str()); + REQUIRE_TRUE(shape::equalsSoft(dLdz->shapeInfo(), dLdzShapeInfo), 0, + "POW_BP OP: wrong shape of next epsilon array (dLdOut)," + " expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(dLdzShapeInfo).c_str(), ShapeUtils::shapeAsString(dLdz).c_str()); + + // dL/dy = x^y * log(x) * dL/dz + auto temp = x->applyTrueBroadcast(BroadcastOpsTuple::Pow(), *y); // a = x^y + x->applyTransform(transform::Log, *dLdx); // b = log(x) + dLdx->applyScalar(nd4j::scalar::ReplaceNans, 0, *dLdx); + temp *= *dLdx; // c = b*a + temp *= *dLdz; // dL/dy = c * dL/dz + if (dLdy->isSameShape(*dLdz)) { + dLdy->assign(temp); + } + else { + std::vector axesForY = ShapeUtils::evalBroadcastBackwardAxis(y->getShapeInfo(), dLdz->getShapeInfo()); + dLdy->assign(temp.reduceAlongDimension(reduce::Sum, axesForY)); // dL/dy = sum(c * dL/dz) + } + + // dL/dx = y*x^(y-1) * dL/dz + x->applyTrueBroadcast(BroadcastOpsTuple::PowDerivative(), *y, temp); // a = y*x^(y-1) + temp *= *dLdz; // dLdx = a*dL/dz + + if (dLdx->isSameShape(*dLdz)) { + dLdx->assign(temp); // dLdx = a*dL/dz + } + else { + std::vector axesForX = ShapeUtils::evalBroadcastBackwardAxis(x->getShapeInfo(), dLdz->getShapeInfo()); + dLdx->assign(temp.reduceAlongDimension(reduce::Sum, axesForX)); // dLdx = a*dL/dz + } + + return Status::OK(); + } + + DECLARE_SHAPE_FN(Pow_bp) { + + auto xShapeInfo = inputShape->at(0); + auto yShapeInfo = inputShape->at(1); + + Nd4jLong* dLdxShapeInfo = nullptr; + Nd4jLong* dLdyShapeInfo = nullptr; + + COPY_SHAPE(xShapeInfo, dLdxShapeInfo); + COPY_SHAPE(yShapeInfo, dLdyShapeInfo); + + return SHAPELIST(CONSTANT(dLdxShapeInfo), CONSTANT(dLdyShapeInfo)); + } + + DECLARE_TYPES(Pow_bp) { + getOpDescriptor() + ->setAllowedInputTypes({ ALL_FLOATS, ALL_INTS }) + ->setAllowedOutputTypes({ ALL_FLOATS }); // TODO maybe wourth to add ALL_INTS + } + +} } #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/realdiv.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/realdiv.cpp index 7b4e374d5..3e7445cf0 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/realdiv.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/realdiv.cpp @@ -71,7 +71,7 @@ namespace nd4j { // X gradient //epsNext->applyPairwiseLambda(y, lambdaX, gradX); - epsNext->applyPairwiseTransform(pairwise::Divide, y, gradX, nullptr); + epsNext->applyPairwiseTransform(pairwise::Divide, *y, *gradX); // Y gradient //epsNext->applyTriplewiseLambda(x, y, lambdaY, gradY); @@ -84,16 +84,16 @@ namespace nd4j { auto tmp = epsNext->reduceNumber(reduce::Sum); auto tmpX = x->reduceNumber(reduce::Sum); gradY->assign(tmp * -tmpX / ((*y) * (*y))); - + //epsNext->applyLambda(lambdaS, gradX); - epsNext->applyScalarArr(scalar::Divide, y, gradX, nullptr); + epsNext->applyScalarArr(scalar::Divide, *y, *gradX); } else { // broadcast case auto preX = *epsNext / *y; NDArray negX(*x); - x->applyTransform(transform::Neg, &negX); + x->applyTransform(transform::Neg, negX); auto preY = *epsNext * negX / ((*y) * (*y)); auto axisX = ShapeUtils::evalBroadcastBackwardAxis(x->shapeInfo(), epsNext->shapeInfo()); @@ -102,14 +102,12 @@ namespace nd4j { if (axisX.size() > 0) { auto sum = preX.reduceAlongDimension(reduce::Sum, axisX); gradX->assign(sum); - delete sum; - } else + } else gradX->assign(preX); if (axisY.size() > 0) { auto sum = preY.reduceAlongDimension(reduce::Sum, axisY); gradY->assign(sum); - delete sum; } else gradY->assign(preY); } diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/reverse_divide.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/reverse_divide.cpp index 6abe8ff9c..04c4c926e 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/reverse_divide.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/reverse_divide.cpp @@ -34,7 +34,7 @@ namespace nd4j { BROADCAST_CHECK_EMPTY(x,y,z); REQUIRE_TRUE(!x->isB(), 0, "REVERSEDIVIDE OP: you can't divide by bool array!"); - x->applyTrueBroadcast(BROADCAST(ReverseDivide), y, z, true); + x->applyTrueBroadcast(BROADCAST(ReverseDivide), *y, *z, true); return Status::OK(); } @@ -67,7 +67,7 @@ namespace nd4j { // X gradient //epsNext->applyTriplewiseLambda(x, y, lambdaX, gradX); gradX->assign((*epsNext) * (*y) / ((*x) * (*x))); - gradX->applyTransform(transform::Neg, nullptr, nullptr); + gradX->applyTransform(transform::Neg, *gradX); // Y gradient //epsNext->applyPairwiseLambda(x, lambdaY, gradY); gradY->assign((*epsNext) / (*x)); @@ -78,14 +78,14 @@ namespace nd4j { gradY->assign(tmp / tmpX); gradX->assign((*epsNext) * (*y) / ((*x) * (*x))); - gradX->applyTransform(transform::Neg, nullptr, nullptr); + gradX->applyTransform(transform::Neg, *gradX); } else { // broadcast case auto preY = (*epsNext) / (*x); auto preX = *epsNext * (*y) / ((*x) * (*x)); - preX.applyTransform(transform::Neg, nullptr, nullptr); + preX.applyTransform(transform::Neg, preX); auto axisX = ShapeUtils::evalBroadcastBackwardAxis(x->shapeInfo(), epsNext->shapeInfo()); auto axisY = ShapeUtils::evalBroadcastBackwardAxis(y->shapeInfo(), epsNext->shapeInfo()); @@ -93,14 +93,12 @@ namespace nd4j { if (axisX.size() > 0) { auto sum = preX.reduceAlongDimension(reduce::Sum, axisX); gradX->assign(sum); - delete sum; - } else + } else gradX->assign(preX); if (axisY.size() > 0) { auto sum = preY.reduceAlongDimension(reduce::Sum, axisY); gradY->assign(sum); - delete sum; } else gradY->assign(preY); } diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/reverse_subtract.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/reverse_subtract.cpp index af282fe7c..dbb14c78b 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/reverse_subtract.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/reverse_subtract.cpp @@ -61,13 +61,13 @@ namespace nd4j { if (x->isSameShape(y)) { // PWT case case - epsNext->applyTransform(transform::Neg, gradX, nullptr); + epsNext->applyTransform(transform::Neg, *gradX); gradY->assign(epsNext); } else if (y->isScalar()) { // scalar case auto tmp = epsNext->reduceNumber(reduce::Sum); gradY->assign(tmp); - epsNext->applyTransform(transform::Neg, gradX, nullptr); + epsNext->applyTransform(transform::Neg, *gradX); } else { // broadcastable auto axisX = ShapeUtils::evalBroadcastBackwardAxis(x->shapeInfo(), epsNext->shapeInfo()); @@ -75,20 +75,18 @@ namespace nd4j { if (axisX.size() > 0) { auto sum = epsNext->reduceAlongDimension(reduce::Sum, axisX); - sum->applyTransform(transform::Neg, gradX); - delete sum; + sum.applyTransform(transform::Neg, *gradX); } else { - epsNext->applyTransform(transform::Neg, gradX, nullptr); + epsNext->applyTransform(transform::Neg, *gradX); } if (axisY.size() > 0) { auto sum = epsNext->reduceAlongDimension(reduce::Sum, axisY); gradY->assign(sum); - delete sum; } else { gradY->assign(epsNext); } - } + } return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/squared_subtract.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/squared_subtract.cpp index 280a09857..ae9c93d4d 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/squared_subtract.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/squared_subtract.cpp @@ -87,7 +87,7 @@ namespace nd4j { // scalar case auto tmpX = x->reduceNumber(reduce::Sum); gradY->assign(tmpX); - + //epsNext->applyPairwiseLambda(x, lambdaS, gradX); gradX->assign((*epsNext) * ts * ((*x) - (*y))); } else { @@ -98,37 +98,31 @@ namespace nd4j { auto targetShape = epsNext->getShapeAsVector(); - preX->tileToShape(targetShape); - preY->tileToShape(targetShape); + preX.tileToShape(targetShape, preX); + preY.tileToShape(targetShape, preY); //epsNext->applyTriplewiseLambda(x, y, lambdaX, preX); //epsNext->applyTriplewiseLambda(x, y, lambdaY, preY); auto resX = (*epsNext) * ts * ((*x) - (*y)); - preX->assign(resX); + preX.assign(resX); auto resY = (*epsNext) * ts * ((*y) - (*x)); - preY->assign(resY); + preY.assign(resY); auto axisX = ShapeUtils::evalBroadcastBackwardAxis(x->shapeInfo(), epsNext->shapeInfo()); auto axisY = ShapeUtils::evalBroadcastBackwardAxis(y->shapeInfo(), epsNext->shapeInfo()); if (axisX.size() > 0) { - auto sum = preX->reduceAlongDimension(reduce::Sum, axisX); + auto sum = preX.reduceAlongDimension(reduce::Sum, axisX); gradX->assign(sum); - delete sum; - } else + } else gradX->assign(preX); if (axisY.size() > 0) { - auto sum = preY->reduceAlongDimension(reduce::Sum, axisY); + auto sum = preY.reduceAlongDimension(reduce::Sum, axisY); gradY->assign(sum); - delete sum; } else gradY->assign(preY); - - - delete preX; - delete preY; } return Status::OK(); diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/subtract.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/subtract.cpp index 76f2d6830..40bbb8559 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/subtract.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/subtract.cpp @@ -62,7 +62,7 @@ namespace nd4j { if (x->isSameShape(y)) { // PWT case case - epsNext->applyTransform(transform::Neg, gradY, nullptr); + epsNext->applyTransform(transform::Neg, *gradY); gradX->assign(epsNext); } else if (y->isScalar()) { // scalar case @@ -77,18 +77,16 @@ namespace nd4j { if (axisX.size() > 0) { auto sum = epsNext->reduceAlongDimension(reduce::Sum, axisX); gradX->assign(sum); - delete sum; - } else + } else gradX->assign(epsNext); if (axisY.size() > 0) { auto sum = epsNext->reduceAlongDimension(reduce::Sum, axisY); - sum->applyTransform(transform::Neg, gradY); - delete sum; + sum.applyTransform(transform::Neg, *gradY); } else { - epsNext->applyTransform(transform::Neg, gradY); + epsNext->applyTransform(transform::Neg, *gradY); } - } + } return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/compat/README.md b/libnd4j/include/ops/declarable/generic/compat/README.md new file mode 100644 index 000000000..ff44ae4c1 --- /dev/null +++ b/libnd4j/include/ops/declarable/generic/compat/README.md @@ -0,0 +1 @@ +This folder contains operations required for compatibility with TF and other frameworks. \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/compat/compat_sparse_to_dense.cpp b/libnd4j/include/ops/declarable/generic/compat/compat_sparse_to_dense.cpp new file mode 100644 index 000000000..4a84dbdac --- /dev/null +++ b/libnd4j/include/ops/declarable/generic/compat/compat_sparse_to_dense.cpp @@ -0,0 +1,73 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include +#if NOT_EXCLUDED(OP_split_string) + +#include +#include + +namespace nd4j { + namespace ops { + CUSTOM_OP_IMPL(compat_sparse_to_dense, 4, 1, false, 0, 0) { + auto indices = INPUT_VARIABLE(0); + auto shape = INPUT_VARIABLE(1); + auto values = INPUT_VARIABLE(2); + NDArray *def = nullptr; + + auto output = OUTPUT_VARIABLE(0); + + if (block.width() > 3) + def = INPUT_VARIABLE(3); + + nd4j::ops::helpers::compat_sparse_to_dense(*values, *indices, def, *output); + + return Status::OK(); + }; + + DECLARE_SHAPE_FN(compat_sparse_to_dense) { + auto indices = INPUT_VARIABLE(0); + auto shape = INPUT_VARIABLE(1); + auto values = INPUT_VARIABLE(2); + + if (block.width() > 3) { + auto def = INPUT_VARIABLE(3); + + REQUIRE_TRUE(def->dataType() == values->dataType() && def->isScalar(), 0, "compat_sparse_to_dense: default value must be a scalar of the same data type as actual values") + }; + + auto dtype = values->dataType(); + + // basically output shape is defined by the type of input, and desired shape input + return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(dtype, 'c', shape->getBufferAsVector())); + } + + DECLARE_TYPES(compat_sparse_to_dense) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_INTS}) // indices + ->setAllowedInputTypes(1, {ALL_INTS}) // shape + ->setAllowedInputTypes(2,nd4j::DataType::ANY) // sparse values + ->setAllowedInputTypes(3,nd4j::DataType::ANY) // default value + ->setAllowedOutputTypes(nd4j::DataType::ANY); + } + } +} + +#endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/compat/compat_string_split.cpp b/libnd4j/include/ops/declarable/generic/compat/compat_string_split.cpp new file mode 100644 index 000000000..9d7b57ee4 --- /dev/null +++ b/libnd4j/include/ops/declarable/generic/compat/compat_string_split.cpp @@ -0,0 +1,140 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include +#if NOT_EXCLUDED(OP_split_string) + +#include +#include + +namespace nd4j { + namespace ops { + CUSTOM_OP_IMPL(compat_string_split, 2, 2, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto delim = INPUT_VARIABLE(1); + + auto indices = OUTPUT_VARIABLE(0); + auto values = OUTPUT_VARIABLE(1); + + auto d = delim->e(0); + + input->syncToHost(); + delim->syncToHost(); + + // output rank N+1 wrt input rank + std::vector ocoords(input->rankOf() + 1); + std::vector icoords(input->rankOf()); + + // getting buffer lengths + // FIXME: it'll be bigger, since it'll include delimiters, + auto outputLength = StringUtils::byteLength(*input); + + uint64_t ss = 0L; + Nd4jLong ic = 0L; + // loop through each string within tensor + for (auto e = 0L; e < input->lengthOf(); e++) { + // now we should map substring to indices + auto s = input->e(e); + + // getting base index + shape::index2coords(e, input->shapeInfo(), icoords.data()); + + // getting number of substrings + auto cnt = StringUtils::countSubarrays(s.c_str(), s.length(), d.c_str(), d.length()) + 1; + + // filling output indices + for (uint64_t f = 0; f < cnt; f++) { + for (auto v: icoords) + indices->p(ic++, v); + + // last index + indices->p(ic++, f); + } + + ss += cnt; + } + + // process strings now + std::vector strings; + for (auto e = 0L; e < input->lengthOf(); e++) { + auto split = StringUtils::split(input->e(e), d); + + for (const auto &s:split) + strings.emplace_back(s); + } + + // now once we have all strings in single vector time to fill + auto tmp = NDArrayFactory::string('c', {(Nd4jLong) strings.size()}, strings); + auto blen = StringUtils::byteLength(tmp) + ShapeUtils::stringBufferHeaderRequirements(strings.size()); + + // for CUDA mostly + values->dataBuffer()->allocatePrimary(); + values->dataBuffer()->expand(blen); + memcpy(values->buffer(), tmp.buffer(), blen); + values->tickWriteHost(); + + // special case, for future use + indices->syncToDevice(); + values->syncToDevice(); + + // we have to tick buffers + values->dataBuffer()->writePrimary(); + values->dataBuffer()->readSpecial(); + + return Status::OK(); + }; + + DECLARE_SHAPE_FN(compat_string_split) { + auto input = INPUT_VARIABLE(0); + auto delim = INPUT_VARIABLE(1); + + auto d = delim->e(0); + + // count number of delimiter substrings in all strings within input tensor + uint64_t cnt = 0; + for (auto e = 0L; e < input->lengthOf(); e++) { + // FIXME: bad, not UTF-compatible + auto s = input->e(e); + + // each substring we see in haystack, splits string in two parts. so we should add 1 to the number of subarrays + cnt += StringUtils::countSubarrays(s.c_str(), s.length(), d.c_str(), d.length()) + 1; + } + + // shape calculations + // virtual tensor rank will be N+1, for N rank input array, where data will be located at the biggest dimension + // values tensor is going to be vector always + // indices tensor is going to be vector with length equal to values.length * output rank + + auto valuesShape = ConstantShapeHelper::getInstance()->vectorShapeInfo(cnt, nd4j::DataType::UTF8); + auto indicesShape = ConstantShapeHelper::getInstance()->vectorShapeInfo(cnt * (input->rankOf() + 1), nd4j::DataType::INT64); + + return SHAPELIST(indicesShape, valuesShape); + } + + DECLARE_TYPES(compat_string_split) { + getOpDescriptor() + ->setAllowedInputTypes({ALL_STRINGS}) + ->setAllowedOutputTypes(0, {ALL_INDICES}) + ->setAllowedOutputTypes(1, {ALL_STRINGS}); + } + } +} + +#endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/datatypes/bitcast.cpp b/libnd4j/include/ops/declarable/generic/datatypes/bitcast.cpp index 24f96f7a7..8591d3449 100644 --- a/libnd4j/include/ops/declarable/generic/datatypes/bitcast.cpp +++ b/libnd4j/include/ops/declarable/generic/datatypes/bitcast.cpp @@ -47,8 +47,7 @@ namespace nd4j { } // just memcpy data -// output->dataBuffer()->copyBufferFrom(*input->dataBuffer()); // as variant - DataBuffer::memcpy(*output->dataBuffer(), *input->dataBuffer()); // this is modern approach + DataBuffer::memcpy(*output->dataBuffer(), *input->dataBuffer()); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/flow/flow_control_ops.cpp b/libnd4j/include/ops/declarable/generic/flow/flow_control_ops.cpp index 472cb060d..5296e8844 100644 --- a/libnd4j/include/ops/declarable/generic/flow/flow_control_ops.cpp +++ b/libnd4j/include/ops/declarable/generic/flow/flow_control_ops.cpp @@ -26,7 +26,7 @@ namespace nd4j { namespace ops { /** * This operation is, basically IF statement - * + * * arg_0 is our "signal" * arg_1 is condition that will determine transition */ @@ -41,10 +41,10 @@ namespace nd4j { // but we'll ensure only one node is active, and other is disabled if (condition->e(0) == 0) { block.setBranch(0); - this->storeResult(block, 0, input->dup()); + this->storeResult(block, 0, new NDArray(input->dup())); } else { block.setBranch(1); - this->storeResult(block, 1, *input->dup()); + this->storeResult(block, 1, new NDArray(input->dup())); } return Status::OK(); diff --git a/libnd4j/include/ops/declarable/generic/helpers/BroadcastHelper.h b/libnd4j/include/ops/declarable/generic/helpers/BroadcastHelper.h index 5e91641ca..e497be416 100644 --- a/libnd4j/include/ops/declarable/generic/helpers/BroadcastHelper.h +++ b/libnd4j/include/ops/declarable/generic/helpers/BroadcastHelper.h @@ -30,7 +30,7 @@ namespace nd4j { namespace ops { class BroadcastHelper { - public: + public: static FORCEINLINE NDArray* broadcastApply(nd4j::BroadcastOpsTuple op, NDArray* x, NDArray* y, NDArray* z, ExtraArguments *extraArgs = nullptr) { if(x->isEmpty() || y->isEmpty()) { @@ -42,34 +42,34 @@ namespace nd4j { std::unique_ptr ptr; if (!Environment::getInstance()->isExperimentalBuild()) { if (y->dataType() != x->dataType()) { - y = y->cast(x->dataType()); + y = new NDArray(y->cast(x->dataType())); std::unique_ptr ptr2(y); ptr.swap(ptr2); } } if (!x->isScalar() && !y->isScalar() && x->isSameShape(y)) { - x->applyPairwiseTransform(op.p, y, z, nullptr); + x->applyPairwiseTransform(op.p, *y, *z); } else if (!x->isScalar() && y->isScalar()) { - x->applyScalarArr(op.s, const_cast(y), z); + x->applyScalarArr(op.s, const_cast(*y), *z); } else if (x->isScalar() && !y->isScalar()) { if (z->isSameShape(y)) { if (op.s == scalar::Add || op.s == scalar::Multiply ) { - y->applyScalarArr(op.s, x, z, nullptr); + y->applyScalarArr(op.s, *x, *z); } else if (op.s == scalar::SquaredSubtract) { - y->applyScalarArr(scalar::SquaredReverseSubtract, x, z, nullptr); + y->applyScalarArr(scalar::SquaredReverseSubtract, *x, *z); } else if (op.s == scalar::Subtract) { - y->applyScalarArr(scalar::ReverseSubtract, x, z, nullptr); + y->applyScalarArr(scalar::ReverseSubtract, *x, *z); } else if (op.s == scalar::Divide) { - y->applyScalarArr(scalar::ReverseDivide, x, z, nullptr); + y->applyScalarArr(scalar::ReverseDivide, *x, *z); } else if (op.s == scalar::Pow) { - y->applyScalarArr(scalar::ReversePow, x, z, nullptr); + y->applyScalarArr(scalar::ReversePow, *x, *z); } else if (op.s == scalar::ReverseSubtract) { - y->applyScalarArr(scalar::Subtract, x, z, nullptr); + y->applyScalarArr(scalar::Subtract, *x, *z); } else if (op.s == scalar::ReverseDivide) { - y->applyScalarArr(scalar::Divide, x, z, nullptr); + y->applyScalarArr(scalar::Divide, *x, *z); } else if (op.s == scalar::MaxPairwise || op.s == scalar::MinPairwise || op.s == scalar::AMaxPairwise || op.s == scalar::AMinPairwise) { - y->applyScalarArr(op.s, x, z, nullptr); + y->applyScalarArr(op.s, *x, *z); } else if (op.s == scalar::CopyPws) { z->assign(y); } else { @@ -84,9 +84,9 @@ namespace nd4j { return tZ; } } else if (x->isScalar() && y->isScalar()) { // x->isScalar() && y->isScalar() - x->applyScalarArr(op.s, const_cast(y), z, nullptr); + x->applyScalarArr(op.s, const_cast(*y), *z); } else if (ShapeUtils::areShapesBroadcastable(*x, *y)) { - x->applyTrueBroadcast(op, y, z, true, extraArgs); + x->applyTrueBroadcast(op, *y, *z, true, extraArgs); return z; } else { auto sx = ShapeUtils::shapeAsString(x); @@ -107,16 +107,16 @@ namespace nd4j { } if (!x->isScalar() && !y->isScalar() && x->isSameShape(y)) { - x->applyPairwiseTransform(op.p, y, z, nullptr); + x->applyPairwiseTransform(op.p, *y, *z); } else if (ShapeUtils::areShapesBroadcastable(*x, *y)) { - x->applyTrueBroadcast(op, y, z, true, extraArgs); + x->applyTrueBroadcast(op, *y, *z, true, extraArgs); return z; } else if (!x->isScalar() && y->isScalar()) { - x->applyScalarArr(op.s, const_cast(y), z); + x->applyScalarArr(op.s, const_cast(*y), *z); } else if (x->isScalar() && !y->isScalar()) { if (z->isSameShape(y)) { //z->assign(x); - x->applyPairwiseTransform(op.p, y, z, extraArgs); + x->applyPairwiseTransform(op.p, *y, *z, extraArgs); return z; } else { auto v = y->getShapeAsVector(); @@ -125,9 +125,9 @@ namespace nd4j { return tZ; } } else if (x->isScalar() && y->isScalar()) { // x->isScalar() && y->isScalar() - x->applyScalarArr(op.s, const_cast(y), z, nullptr); + x->applyScalarArr(op.s, const_cast(*y), *z); } else if (ShapeUtils::areShapesBroadcastable(*x, *y)) { - x->applyTrueBroadcast(op, y, z, true, extraArgs); + x->applyTrueBroadcast(op, *y, *z, true, extraArgs); return z; } else { auto sx = ShapeUtils::shapeAsString(x); diff --git a/libnd4j/include/ops/declarable/generic/images/hsvToRgb.cpp b/libnd4j/include/ops/declarable/generic/images/hsvToRgb.cpp new file mode 100644 index 000000000..f1f8522d7 --- /dev/null +++ b/libnd4j/include/ops/declarable/generic/images/hsvToRgb.cpp @@ -0,0 +1,59 @@ +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author AbdelRauf (rauf@konduit.ai) +// + +#include +#include +#include +#include + +namespace nd4j { +namespace ops { + +CONFIGURABLE_OP_IMPL(hsv_to_rgb, 1, 1, true, 0, 0) { + + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + if (input->isEmpty()) + return Status::OK(); + + const int rank = input->rankOf(); + const int argSize = block.getIArguments()->size(); + const int dimC = argSize > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1; + + REQUIRE_TRUE(rank >= 1, 0, "HSVtoRGB: Fails to meet the rank requirement: %i >= 1 ", rank); + if (argSize > 0) { + REQUIRE_TRUE(dimC >= 0 && dimC < rank, 0, "Index of the Channel dimension out of range: %i not in [%i,%i) ", INT_ARG(0), -rank, rank); + } + REQUIRE_TRUE(input->sizeAt(dimC) == 3, 0, "HSVtoRGB: operation expects 3 channels (H, S, V), but got %i instead", input->sizeAt(dimC)); + + helpers::transformHsvRgb(block.launchContext(), input, output, dimC); + + return Status::OK(); +} + +DECLARE_TYPES(hsv_to_rgb) { + getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS }) + ->setSameMode(true); +} + + +} +} diff --git a/libnd4j/include/ops/declarable/generic/images/rgbToGrs.cpp b/libnd4j/include/ops/declarable/generic/images/rgbToGrs.cpp new file mode 100644 index 000000000..aa2bec9da --- /dev/null +++ b/libnd4j/include/ops/declarable/generic/images/rgbToGrs.cpp @@ -0,0 +1,74 @@ +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Oleh Semeniv (oleg.semeniv@gmail.com) +// + +#include +#include +#include +#include + +namespace nd4j { +namespace ops { + +CUSTOM_OP_IMPL(rgb_to_grs, 1, 1, false, 0, 0) { + + const auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + const int inRank = input->rankOf(); + const int argSize = block.getIArguments()->size(); + const int dimC = argSize > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + inRank) : inRank - 1; + + REQUIRE_TRUE(inRank >= 1, 0, "RGBtoGrayScale: Fails to meet the inRank requirement: %i >= 1 ", inRank); + if (argSize > 0) { + REQUIRE_TRUE(dimC >= 0 && dimC < inRank, 0, "Index of the Channel dimension out of range: %i not in [%i,%i) ", INT_ARG(0), -inRank, inRank); + } + REQUIRE_TRUE(input->sizeAt(dimC) == 3, 0, "RGBGrayScale: operation expects 3 channels (R, G, B) in last dimention, but received %i instead", input->sizeAt(dimC)); + + helpers::transformRgbGrs(block.launchContext(), *input, *output, dimC); + return Status::OK(); +} + +DECLARE_TYPES(rgb_to_grs) { + getOpDescriptor()->setAllowedInputTypes( {ALL_INTS, ALL_FLOATS} ) + ->setSameMode(true); +} + +DECLARE_SHAPE_FN(rgb_to_grs) { + + const auto input = INPUT_VARIABLE(0); + const int inRank = input->rankOf(); + + const int argSize = block.getIArguments()->size(); + const int dimC = argSize > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + inRank) : inRank - 1; + + REQUIRE_TRUE(inRank >= 1, 0, "RGBtoGrayScale: Fails to meet the inRank requirement: %i >= 1 ", inRank); + if (argSize > 0) { + REQUIRE_TRUE(dimC >= 0 && dimC < inRank, 0, "Index of the Channel dimension out of range: %i not in [%i,%i) ", INT_ARG(0), -inRank, inRank); + } + REQUIRE_TRUE(input->sizeAt(dimC) == 3, 0, "RGBtoGrayScale: operation expects 3 channels (R, B, G) in last dimention, but received %i", dimC); + + auto nShape = input->getShapeAsVector(); + nShape[dimC] = 1; + + return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(input->dataType(), input->ordering(), nShape)); +} + +} +} diff --git a/libnd4j/include/ops/declarable/generic/images/rgbToHsv.cpp b/libnd4j/include/ops/declarable/generic/images/rgbToHsv.cpp new file mode 100644 index 000000000..2ba45bea9 --- /dev/null +++ b/libnd4j/include/ops/declarable/generic/images/rgbToHsv.cpp @@ -0,0 +1,62 @@ +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author AbdelRauf (rauf@konduit.ai) +// + + + +#include +#include +#include +#include + +namespace nd4j { +namespace ops { + +CONFIGURABLE_OP_IMPL(rgb_to_hsv, 1, 1, true, 0, 0) { + + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + if (input->isEmpty()) + return Status::OK(); + + const int rank = input->rankOf(); + const int argSize = block.getIArguments()->size(); + const int dimC = argSize > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1; + + REQUIRE_TRUE(rank >= 1, 0, "RGBtoHSV: Fails to meet the rank requirement: %i >= 1 ", rank); + if (argSize > 0) { + REQUIRE_TRUE(dimC >= 0 && dimC < rank, 0, "Index of the Channel dimension out of range: %i not in [%i,%i) ", INT_ARG(0), -rank, rank); + } + REQUIRE_TRUE(input->sizeAt(dimC) == 3, 0, "RGBtoHSV: operation expects 3 channels (R, G, B), but got %i instead", input->sizeAt(dimC)); + + helpers::transformRgbHsv(block.launchContext(), input, output, dimC); + + return Status::OK(); +} + + +DECLARE_TYPES(rgb_to_hsv) { + getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS }) + ->setSameMode(true); +} + + +} +} diff --git a/libnd4j/include/ops/declarable/generic/images/rgbToYiq.cpp b/libnd4j/include/ops/declarable/generic/images/rgbToYiq.cpp new file mode 100644 index 000000000..6d202ee4a --- /dev/null +++ b/libnd4j/include/ops/declarable/generic/images/rgbToYiq.cpp @@ -0,0 +1,60 @@ +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + + // + // @author AbdelRauf (rauf@konduit.ai) + // + +#include +#include +#include +#include + +namespace nd4j { + namespace ops { + + + + CONFIGURABLE_OP_IMPL(rgb_to_yiq, 1, 1, true, 0, 0) { + + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + if (input->isEmpty()) + return Status::OK(); + + const int rank = input->rankOf(); + const int arg_size = block.getIArguments()->size(); + const int dimC = arg_size > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1; + + REQUIRE_TRUE(rank >= 1, 0, "RGBtoYIQ: Fails to meet the rank requirement: %i >= 1 ", rank); + if (arg_size > 0) { + REQUIRE_TRUE(dimC >= 0 && dimC < rank, 0, "Index of the Channel dimension out of range: %i not in [%i,%i) ", INT_ARG(0), -rank, rank); + } + REQUIRE_TRUE(input->sizeAt(dimC) == 3, 0, "RGBtoYIQ: operation expects 3 channels (R, G, B), but got %i instead", input->sizeAt(dimC)); + + helpers::transformRgbYiq(block.launchContext(), input, output, dimC); + + return Status::OK(); + } + + + DECLARE_TYPES(rgb_to_yiq) { + getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS }) + ->setSameMode(true); + } + } +} diff --git a/libnd4j/include/ops/declarable/generic/images/rgbToYuv.cpp b/libnd4j/include/ops/declarable/generic/images/rgbToYuv.cpp new file mode 100644 index 000000000..58dd8a432 --- /dev/null +++ b/libnd4j/include/ops/declarable/generic/images/rgbToYuv.cpp @@ -0,0 +1,61 @@ +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Oleh Semeniv (oleg.semeniv@gmail.com) +// + + + +#include +#include +#include +#include + +namespace nd4j { +namespace ops { + +CONFIGURABLE_OP_IMPL(rgb_to_yuv, 1, 1, true, 0, 0) { + + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + // just skip op if input is empty + if (input->isEmpty()) + return Status::OK(); + + const int rank = input->rankOf(); + const int argSize = block.getIArguments()->size(); + const int dimC = argSize > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1; + + REQUIRE_TRUE(rank >= 1, 0, "RGBtoYUV: Fails to meet the rank requirement: %i >= 1 ", rank); + if (argSize > 0) { + REQUIRE_TRUE(dimC >= 0 && dimC < rank, 0, "Index of the Channel dimension out of range: %i not in [%i,%i) ", INT_ARG(0), -rank, rank); + } + REQUIRE_TRUE(input->sizeAt(dimC) == 3, 0, "RGBtoYUV: operation expects 3 channels (R, G, B), but got %i instead", input->sizeAt(dimC)); + + helpers::transformRgbYuv(block.launchContext(), *input, *output, dimC); + + return Status::OK(); +} + +DECLARE_TYPES(rgb_to_yuv) { + getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS }) + ->setSameMode(true); +} + +} +} diff --git a/libnd4j/include/ops/declarable/generic/images/yiqToRgb.cpp b/libnd4j/include/ops/declarable/generic/images/yiqToRgb.cpp new file mode 100644 index 000000000..287aa150a --- /dev/null +++ b/libnd4j/include/ops/declarable/generic/images/yiqToRgb.cpp @@ -0,0 +1,61 @@ +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author AbdelRauf (rauf@konduit.ai) +// + +#include +#include +#include +#include + +namespace nd4j { + namespace ops { + + + + CONFIGURABLE_OP_IMPL(yiq_to_rgb, 1, 1, true, 0, 0) { + + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + if (input->isEmpty()) + return Status::OK(); + + const int rank = input->rankOf(); + const int arg_size = block.getIArguments()->size(); + const int dimC = arg_size > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1; + + REQUIRE_TRUE(rank >= 1, 0, "YIQtoRGB: Fails to meet the rank requirement: %i >= 1 ", rank); + if (arg_size > 0) { + REQUIRE_TRUE(dimC >= 0 && dimC < rank, 0, "Index of the Channel dimension out of range: %i not in [%i,%i) ", INT_ARG(0), -rank, rank); + } + REQUIRE_TRUE(input->sizeAt(dimC) == 3, 0, "YIQtoRGB: operation expects 3 channels (Y, I, Q), but got %i instead", input->sizeAt(dimC)); + + helpers::transformYiqRgb(block.launchContext(), input, output, dimC); + + return Status::OK(); + } + + + DECLARE_TYPES(yiq_to_rgb) { + getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS }) + ->setSameMode(true); + } + + } +} diff --git a/libnd4j/include/ops/declarable/generic/images/yuvToRgb.cpp b/libnd4j/include/ops/declarable/generic/images/yuvToRgb.cpp new file mode 100644 index 000000000..90ca217ce --- /dev/null +++ b/libnd4j/include/ops/declarable/generic/images/yuvToRgb.cpp @@ -0,0 +1,60 @@ +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Oleh Semeniv (oleg.semeniv@gmail.com) +// + +#include +#include +#include +#include + +namespace nd4j { +namespace ops { + +CONFIGURABLE_OP_IMPL(yuv_to_rgb, 1, 1, true, 0, 0) { + + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + // just skip op if input is empty + if (input->isEmpty()) + return Status::OK(); + + const int rank = input->rankOf(); + const int argSize = block.getIArguments()->size(); + const int dimC = argSize > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1; + + REQUIRE_TRUE(rank >= 1, 0, "YUVtoRGB: Fails to meet the rank requirement: %i >= 1 ", rank); + if (argSize > 0) { + REQUIRE_TRUE(dimC >= 0 && dimC < rank, 0, "Index of the Channel dimension out of range: %i not in [%i,%i) ", INT_ARG(0), -rank, rank); + } + REQUIRE_TRUE(input->sizeAt(dimC) == 3, 0, "YUVtoRGB: operation expects 3 channels (Y, U, V), but got %i instead", input->sizeAt(dimC)); + + helpers::transformYuvRgb(block.launchContext(), *input, *output, dimC); + + return Status::OK(); +} + +DECLARE_TYPES(yuv_to_rgb) { + getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS }) + ->setSameMode(true); +} + + +} +} diff --git a/libnd4j/include/ops/declarable/generic/list/scatter_list.cpp b/libnd4j/include/ops/declarable/generic/list/scatter_list.cpp index c5bc3ca6c..2d854ae0b 100644 --- a/libnd4j/include/ops/declarable/generic/list/scatter_list.cpp +++ b/libnd4j/include/ops/declarable/generic/list/scatter_list.cpp @@ -51,12 +51,12 @@ namespace nd4j { std::vector axis = ShapeUtils::evalDimsToExclude(array->rankOf(), {0}); auto tads = array->allTensorsAlongDimension( axis); - for (int e = 0; e < tads->size(); e++) { + for (int e = 0; e < tads.size(); e++) { auto idx = indices->e(e); - if (idx >= tads->size()) + if (idx >= tads.size()) return ND4J_STATUS_BAD_ARGUMENTS; - auto arr = tads->at(e)->dup(array->ordering()); + auto arr = new NDArray(tads.at(e)->dup(array->ordering())); auto res = list->write(idx, arr); if (res != ND4J_STATUS_OK) return res; @@ -65,7 +65,6 @@ namespace nd4j { if (!hasList) //OVERWRITE_RESULT(list); setupResultList(list, block); - delete tads; return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/list/split_list.cpp b/libnd4j/include/ops/declarable/generic/list/split_list.cpp index f2399c9d3..5a403dd06 100644 --- a/libnd4j/include/ops/declarable/generic/list/split_list.cpp +++ b/libnd4j/include/ops/declarable/generic/list/split_list.cpp @@ -55,7 +55,7 @@ namespace nd4j { std::vector indices(2 * array->rankOf(), 0); for (Nd4jLong e = 0; e < sizes->lengthOf(); e++) { int c_size = sizes->e(e); - + REQUIRE_TRUE(c_size > 0, 0, "Slice size should have postive value, but got %i instead", c_size); REQUIRE_TRUE(cnt < array->sizeAt(0) && cnt + c_size <= array->sizeAt(0), 0, "Slices size should NOT be higher then number of TADs of source array. Source size: [%i]; Slice start: [%i]; Slice size: [%i]", array->sizeAt(0), cnt, c_size); @@ -63,11 +63,11 @@ namespace nd4j { indices[0] = cnt; indices[1] = cnt + c_size; cnt += c_size; - + auto subarray = (*array)(indices); - auto status = list->write(e, subarray.dup(array->ordering())); - + auto status = list->write(e, new NDArray(subarray.dup(array->ordering()))); + if (status != ND4J_STATUS_OK) return status; } diff --git a/libnd4j/include/ops/declarable/generic/list/write_list.cpp b/libnd4j/include/ops/declarable/generic/list/write_list.cpp index 8ac1935b3..c9b32234e 100644 --- a/libnd4j/include/ops/declarable/generic/list/write_list.cpp +++ b/libnd4j/include/ops/declarable/generic/list/write_list.cpp @@ -39,7 +39,7 @@ namespace nd4j { //nd4j_printf("Writing [%i]:\n", idx->e(0)); //input->printShapeInfo("input shape"); //input->printIndexedBuffer("input buffer"); - Nd4jStatus result = list->write(idx->e(0), input->dup()); + Nd4jStatus result = list->write(idx->e(0), new NDArray(input->dup())); auto res = NDArrayFactory::create_(list->counter(), block.launchContext()); //res->printShapeInfo("Write_list 2 output shape"); @@ -52,7 +52,7 @@ namespace nd4j { auto input = INPUT_VARIABLE(1); auto idx = INT_ARG(0); - Nd4jStatus result = list->write(idx, input->dup()); + Nd4jStatus result = list->write(idx, new NDArray(input->dup())); auto res = NDArrayFactory::create_(list->counter(), block.launchContext()); //res->printShapeInfo("Write_list 1 output shape"); diff --git a/libnd4j/include/ops/declarable/generic/loss/absoluteDifference.cpp b/libnd4j/include/ops/declarable/generic/loss/absoluteDifference.cpp index d028e5af8..ba488df65 100644 --- a/libnd4j/include/ops/declarable/generic/loss/absoluteDifference.cpp +++ b/libnd4j/include/ops/declarable/generic/loss/absoluteDifference.cpp @@ -169,10 +169,10 @@ CUSTOM_OP_IMPL(absolute_difference_loss_grad, 3, 3, false, 0, 1) { NDArray E = *predictions - *labels; // dE_i/dp_i = sign(p_i - y_i) - E.applyTransform(nd4j::transform::Sign, dLdp); // dE/dp + E.applyTransform(nd4j::transform::Sign, *dLdp); // dE/dp // dE_i/dy_i = -sign(p_i - y_i) - E.applyTransform(nd4j::transform::Abs); + E.applyTransform(nd4j::transform::Abs, E); switch (reductionMode) { @@ -184,7 +184,7 @@ CUSTOM_OP_IMPL(absolute_difference_loss_grad, 3, 3, false, 0, 1) { dLdw->assign(E.reduceNumber(reduce::Sum)); else if(weights != weightsBroad) { std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); - E.reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false); + E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); } else dLdw->assign(E); @@ -210,7 +210,7 @@ CUSTOM_OP_IMPL(absolute_difference_loss_grad, 3, 3, false, 0, 1) { *dLdw = 0.; else if(weights != weightsBroad) { std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); - ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false); + ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); } else dLdw->assign((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)); @@ -238,7 +238,7 @@ CUSTOM_OP_IMPL(absolute_difference_loss_grad, 3, 3, false, 0, 1) { dLdw->assign(E.reduceNumber(reduce::Sum) / double(numOfNonZeroWeights)); else if(weights != weightsBroad) { std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); - E.reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false); + E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); *dLdw /= numOfNonZeroWeightsScalar; } else diff --git a/libnd4j/include/ops/declarable/generic/loss/cosineDistance.cpp b/libnd4j/include/ops/declarable/generic/loss/cosineDistance.cpp index 0b4e3fe89..7fe75c03a 100644 --- a/libnd4j/include/ops/declarable/generic/loss/cosineDistance.cpp +++ b/libnd4j/include/ops/declarable/generic/loss/cosineDistance.cpp @@ -56,7 +56,7 @@ CUSTOM_OP_IMPL(cosine_distance_loss, 3, 1, false, 0, 2) { REQUIRE_TRUE(weights->isScalar() || ShapeUtils::areShapesBroadcastable(*weights, *output), 0, "COSINE_DISTANCE_LOSS OP: shapes of weights and output arrays should be broadcastable, but got weights = %s and output = %s instead!", ShapeUtils::shapeAsString(weights).c_str(), ShapeUtils::shapeAsString(labels).c_str()); } - NDArray E = 1. - (*predictions * *labels).reduceAlongDims(reduce::Sum, {dim}, true); + NDArray E = 1. - (*predictions * *labels).reduceAlongDimension(reduce::Sum, {dim}, true); // perform weights broadcasting/tile to E if it is necessary auto weightsBroad = weights; @@ -70,7 +70,7 @@ CUSTOM_OP_IMPL(cosine_distance_loss, 3, 1, false, 0, 2) { case 0: // 0 - "none", un-reduced weighted losses with the same shape as labels. output->assign(&E); break; - + case 1: { // 1 - "weighted_sum", output is scalar and equal to sum of all elements of E array output->assign(E.reduceNumber(reduce::Sum)); break; @@ -79,12 +79,12 @@ CUSTOM_OP_IMPL(cosine_distance_loss, 3, 1, false, 0, 2) { NDArray sum; if (weights->isScalar()) sum = *weights * E.lengthOf(); - else + else sum = weightsBroad->reduceNumber(reduce::Sum); - + if (sum.e(0) == 0.) *output = 0.; - else + else output->assign(E.reduceNumber(reduce::Sum) / sum); break; } @@ -99,9 +99,9 @@ CUSTOM_OP_IMPL(cosine_distance_loss, 3, 1, false, 0, 2) { if (numOfNonZeroWeights == 0) *output = 0.; - else + else output->assign(E.reduceNumber(reduce::Sum) / double(numOfNonZeroWeights)); - + break; } } @@ -111,7 +111,7 @@ CUSTOM_OP_IMPL(cosine_distance_loss, 3, 1, false, 0, 2) { if(weightsBroad != weights) delete weightsBroad; - + return Status::OK(); } @@ -124,7 +124,7 @@ DECLARE_TYPES(cosine_distance_loss) { ////////////////////////////////////////////////////////////////////////// DECLARE_SHAPE_FN(cosine_distance_loss) { - // labels and predictions must have the same shapes + // labels and predictions must have the same shapes auto predictionsShapeInfo = inputShape->at(0); auto weightsShapeInfo = inputShape->at(1); auto labelsShapeInfo = inputShape->at(2); @@ -194,7 +194,7 @@ CUSTOM_OP_IMPL(cosine_distance_loss_grad, 3, 3, false, 0, 2) { // input dimension can't be larger than labels/predictions/weights rank REQUIRE_TRUE(dim < labels->rankOf(), 0, "COSINE_DISTANCE_LOSS_GRAD OP: input reduction dimension (got %i) must be < labels rank %i!", dim, labels->rankOf()); - NDArray E = 1. - (*predictions * *labels).reduceAlongDims(reduce::Sum, {dim}, true); + NDArray E = 1. - (*predictions * *labels).reduceAlongDimension(reduce::Sum, {dim}, true); // perform weights broadcasting/tile to E if it is necessary auto weightsBroad = weights; @@ -216,7 +216,7 @@ CUSTOM_OP_IMPL(cosine_distance_loss_grad, 3, 3, false, 0, 2) { else { if(weights != weightsBroad) { std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); - E.reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false); + E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); } else dLdw->assign(E); @@ -249,7 +249,7 @@ CUSTOM_OP_IMPL(cosine_distance_loss_grad, 3, 3, false, 0, 2) { if(weights != weightsBroad) { std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); - ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false); + ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); } else dLdw->assign((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)); @@ -284,7 +284,7 @@ CUSTOM_OP_IMPL(cosine_distance_loss_grad, 3, 3, false, 0, 2) { if(weights != weightsBroad) { std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); - E.reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false); + E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); *dLdw /= numOfNonZeroWeights; } else diff --git a/libnd4j/include/ops/declarable/generic/loss/hingeLoss.cpp b/libnd4j/include/ops/declarable/generic/loss/hingeLoss.cpp index b62dffad8..8670bf9e1 100644 --- a/libnd4j/include/ops/declarable/generic/loss/hingeLoss.cpp +++ b/libnd4j/include/ops/declarable/generic/loss/hingeLoss.cpp @@ -52,7 +52,7 @@ namespace nd4j { // We first need to convert binary labels to -1/1 labels (as floats) NDArray E = 1.f - (*labels * 2.f - 1.f) * (*logits); - E.applyScalar(scalar::RELU, 0.0f, &E); + E.applyScalar(scalar::RELU, 0.0f, E); // multiply E on weights E *= *weightsBroad; @@ -172,11 +172,11 @@ namespace nd4j { NDArray z = (*labels * 2.f - 1.f); NDArray E = 1.f - z * (*logits); - E.applyScalar(scalar::RELU, 0.0f, &E); + E.applyScalar(scalar::RELU, 0.0f, E); // turn E into gradient mask NDArray gradientMask(E.getShapeInfo(), block.getWorkspace()); - E.applyTransform(nd4j::transform::Sign, &gradientMask); + E.applyTransform(nd4j::transform::Sign, gradientMask); dLdp->assign(-z * gradientMask); dLdl->assign(-2.f * (*logits) * gradientMask); @@ -192,7 +192,7 @@ namespace nd4j { dLdw->assign(E.reduceNumber(reduce::Sum)); else if(weights != weightsBroad) { std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); - E.reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false); + E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); } else dLdw->assign(E); @@ -220,7 +220,7 @@ namespace nd4j { *dLdw = 0.; else if(weights != weightsBroad) { std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); - ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false); + ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); } else dLdw->assign((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)); @@ -249,7 +249,7 @@ namespace nd4j { dLdw->assign(E.reduceNumber(reduce::Sum) / double(numOfNonZeroWeights)); else if(weights != weightsBroad) { std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); - E.reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false); + E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); *dLdw /= numOfNonZeroWeightsScalar; } else diff --git a/libnd4j/include/ops/declarable/generic/loss/huberLoss.cpp b/libnd4j/include/ops/declarable/generic/loss/huberLoss.cpp index 3e7686a3d..e844b4126 100644 --- a/libnd4j/include/ops/declarable/generic/loss/huberLoss.cpp +++ b/libnd4j/include/ops/declarable/generic/loss/huberLoss.cpp @@ -46,17 +46,17 @@ CUSTOM_OP_IMPL(huber_loss, 3, 1, false, 1, 1) { REQUIRE_TRUE(weights->isScalar() || ShapeUtils::areShapesBroadcastable(*weights, *labels), 0, "HUBER_LOSS OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and labels = %s instead!", ShapeUtils::shapeAsString(weights).c_str(), ShapeUtils::shapeAsString(labels).c_str()); // only 4 possible reduction modes exist REQUIRE_TRUE(reductionMode==0 || reductionMode==1 || reductionMode==2 || reductionMode==3, 0, "HUBER_LOSS OP: reduction mode value is not acceptable, possible values are 0, 1, 2, 3, but got %i instead!", reductionMode); - + // perform weights broadcasting/tile to predictions if needed auto weightsBroad = weights; if(!weights->isScalar() && !weights->isSameShape(predictions)) weightsBroad = new NDArray(weights->tileToShape(predictions->getShapeInfo())); auto error = *predictions - *labels; - error.applyTransform(transform::Abs); + error.applyTransform(transform::Abs, error); NDArray quadratic(error.getShapeInfo(), block.getWorkspace()); - error.applyScalar(scalar::MinPairwise, delta, &quadratic); - + error.applyScalar(scalar::MinPairwise, delta, quadratic); + NDArray E = quadratic * quadratic * 0.5f + (error - quadratic)*delta; // multiply E on weights @@ -75,12 +75,12 @@ CUSTOM_OP_IMPL(huber_loss, 3, 1, false, 1, 1) { NDArray sum; if (weights->isScalar()) sum = *weights * E.lengthOf(); - else + else sum = weightsBroad->reduceNumber(reduce::Sum); - + if (sum.e(0) == 0.) *output = 0.; - else + else output->assign(E.reduceNumber(reduce::Sum) / sum); break; } @@ -104,7 +104,7 @@ CUSTOM_OP_IMPL(huber_loss, 3, 1, false, 1, 1) { if(weightsBroad != weights) delete weightsBroad; - + return Status::OK(); } @@ -173,24 +173,24 @@ DECLARE_SHAPE_FN(huber_loss) { NDArray diff = *predictions - *labels; NDArray absDiff(diff); - absDiff.applyTransform(transform::Abs); + absDiff.applyTransform(transform::Abs, absDiff); NDArray quadratic(absDiff); - absDiff.applyScalar(scalar::MinPairwise, delta, &quadratic); + absDiff.applyScalar(scalar::MinPairwise, delta, quadratic); NDArray E = quadratic * quadratic * 0.5f + (absDiff - quadratic)*delta; NDArray lteMask(diff.getShapeInfo(), BOOL, true, block.launchContext()); - absDiff.applyScalar(scalar::LessThanOrEqual, delta, <eMask); + absDiff.applyScalar(scalar::LessThanOrEqual, delta, lteMask); NDArray gtMask(diff.getShapeInfo(), BOOL, true, block.launchContext()); - absDiff.applyScalar(scalar::GreaterThan, delta, >Mask); + absDiff.applyScalar(scalar::GreaterThan, delta, gtMask); NDArray signDiff(diff); - diff.applyTransform(transform::Sign, &signDiff); + diff.applyTransform(transform::Sign, signDiff); - auto gtMaskFloat = *gtMask.cast(diff.dataType()); - auto lteMaskFloat = *lteMask.cast(diff.dataType()); + auto gtMaskFloat = gtMask.cast(diff.dataType()); + auto lteMaskFloat = lteMask.cast(diff.dataType()); dLdp->assign( lteMaskFloat * diff + gtMaskFloat * delta * signDiff); @@ -207,7 +207,7 @@ DECLARE_SHAPE_FN(huber_loss) { dLdw->assign(E.reduceNumber(reduce::Sum)); else if(weights != weightsBroad) { std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); - E.reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false); + E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); } else dLdw->assign(E); @@ -235,7 +235,7 @@ DECLARE_SHAPE_FN(huber_loss) { *dLdw = 0.; else if(weights != weightsBroad) { std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); - ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false); + ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); } else dLdw->assign((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)); @@ -264,7 +264,7 @@ DECLARE_SHAPE_FN(huber_loss) { dLdw->assign(E.reduceNumber(reduce::Sum) / double(numOfNonZeroWeights)); else if(weights != weightsBroad) { std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); - E.reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false); + E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); *dLdw /= numOfNonZeroWeightsScalar; } else diff --git a/libnd4j/include/ops/declarable/generic/loss/logLoss.cpp b/libnd4j/include/ops/declarable/generic/loss/logLoss.cpp index 33d5c03ec..f83947c69 100644 --- a/libnd4j/include/ops/declarable/generic/loss/logLoss.cpp +++ b/libnd4j/include/ops/declarable/generic/loss/logLoss.cpp @@ -29,11 +29,11 @@ namespace ops { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(log_loss, 3, 1, false, 1, 1) { - + auto predictions = INPUT_VARIABLE(0); auto weights = INPUT_VARIABLE(1); auto labels = INPUT_VARIABLE(2); - + auto output = OUTPUT_VARIABLE(0); int reductionMode = INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - "weighted_sum_by_nonzero_weights" @@ -48,7 +48,7 @@ CUSTOM_OP_IMPL(log_loss, 3, 1, false, 1, 1) { REQUIRE_TRUE(weights->isScalar() || ShapeUtils::areShapesBroadcastable(*weights, *labels), 0, "LOG_LOSS OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and labels = %s instead!", ShapeUtils::shapeAsString(weights).c_str(), ShapeUtils::shapeAsString(labels).c_str()); // only 4 possible reduction modes exist REQUIRE_TRUE(reductionMode==0 || reductionMode==1 || reductionMode==2 || reductionMode==3, 0, "LOG_LOSS OP: reduction mode value is not acceptable, possible values are 0, 1, 2, 3, but got %i instead!", reductionMode); - + // perform weights broadcasting/tile to predictions if needed auto weightsBroad = weights; if(!weights->isScalar() && !weights->isSameShape(predictions)) @@ -58,7 +58,7 @@ CUSTOM_OP_IMPL(log_loss, 3, 1, false, 1, 1) { // multiply E on weights E *= *weightsBroad; - + switch (reductionMode) { case 0: { // 0 - "none", un-reduced weighted losses with the same shape as labels. output->assign(E); @@ -72,12 +72,12 @@ CUSTOM_OP_IMPL(log_loss, 3, 1, false, 1, 1) { NDArray sum; if (weights->isScalar()) sum = *weights * E.lengthOf(); - else + else sum = weightsBroad->reduceNumber(reduce::Sum); - + if (sum.e(0) == 0.) *output = 0.; - else + else output->assign(E.reduceNumber(reduce::Sum) / sum); break; } @@ -101,13 +101,13 @@ CUSTOM_OP_IMPL(log_loss, 3, 1, false, 1, 1) { if(weightsBroad != weights) delete weightsBroad; - + return Status::OK(); } ////////////////////////////////////////////////////////////////////////// DECLARE_TYPES(log_loss) { - + getOpDescriptor()->setAllowedInputTypes(nd4j::DataType::ANY)->setAllowedOutputTypes({ALL_FLOATS}); } @@ -118,11 +118,11 @@ DECLARE_SHAPE_FN(log_loss) { auto weightsShapeInfo = inputShape->at(1); auto labelsShapeInfo = inputShape->at(2); - // labels and predictions must have the same shapes - REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, predictionsShapeInfo), 0, "LOG_LOSS OP: labels and predictions arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(predictionsShapeInfo).c_str()); + // labels and predictions must have the same shapes + REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, predictionsShapeInfo), 0, "LOG_LOSS OP: labels and predictions arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(predictionsShapeInfo).c_str()); // weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || shape::rank(weightsShapeInfo) == shape::rank(labelsShapeInfo), 0, "LOG_LOSS OP: weights array should be scalar or have the same rank as labels array, but got %i and %i correspondingly!", shape::rank(weightsShapeInfo), shape::rank(labelsShapeInfo)); - // check whether broadcast operation is possible for weights array + // check whether broadcast operation is possible for weights array REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || ShapeUtils::areShapesBroadcastable(weightsShapeInfo, labelsShapeInfo), 0, "LOG_LOSS OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and labels = %s instead!", ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), ShapeUtils::shapeAsString(labelsShapeInfo).c_str()); DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(predictionsShapeInfo)); @@ -132,7 +132,7 @@ DECLARE_SHAPE_FN(log_loss) { outShapeInfo = ConstantShapeHelper::getInstance()->scalarShapeInfo(outType); else // in this case output has the same shape as labels and predictions outShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(outType, shape::order(labelsShapeInfo), shape::shapeOf(labelsShapeInfo), shape::rank(labelsShapeInfo))); - + return SHAPELIST(outShapeInfo); } @@ -143,33 +143,33 @@ DECLARE_SHAPE_FN(log_loss) { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(log_loss_grad, 3, 3, false, 1, 1) { - + auto predictions = INPUT_VARIABLE(0); auto weights = INPUT_VARIABLE(1); auto labels = INPUT_VARIABLE(2); - + auto dLdp = OUTPUT_VARIABLE(0); // dL/dpredictions auto dLdw = OUTPUT_VARIABLE(1); // dL/dweights auto dLdl = OUTPUT_VARIABLE(2); // dL/dlabels - int reductionMode = INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - "weighted_sum_by_nonzero_weights" + int reductionMode = INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - "weighted_sum_by_nonzero_weights" // take into account Alex's proposition to treat "none" the same as "weighted_sum" mode when calculating gradients if(reductionMode == 0) reductionMode = 1; - - // FIXME: double? + + // FIXME: double? double epsilon = T_ARG(0); // input validation REQUIRE_TRUE(labels->isSameShape(predictions), 0, "LOG_LOSS_GRAD OP: labels and predictions arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labels).c_str(), ShapeUtils::shapeAsString(predictions).c_str()); // weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels REQUIRE_TRUE(weights->isScalar() || weights->rankOf() == labels->rankOf(), 0, "LOG_LOSS_GRAD OP: weights array should be scalar or have the same rank as labels array, but got %i and %i correspondingly!", weights->rankOf(), labels->rankOf()); - // check whether broadcast operation is possible for weights array + // check whether broadcast operation is possible for weights array REQUIRE_TRUE(weights->isScalar() || ShapeUtils::areShapesBroadcastable(*weights, *labels), 0, "LOG_LOSS_GRAD OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and labels = %s instead!", ShapeUtils::shapeAsString(weights).c_str(), ShapeUtils::shapeAsString(labels).c_str()); // only 4 possible reduction modes exist REQUIRE_TRUE(reductionMode==0 || reductionMode==1 || reductionMode==2 || reductionMode==3, 0, "LOG_LOSS_GRAD OP: reduction mode value is not acceptable, possible values are 0, 1, 2, 3, but got %i instead!", reductionMode); - - // perform weights broadcasting/tile to labels if needed + + // perform weights broadcasting/tile to labels if needed auto weightsBroad = weights; if(!weights->isScalar() && !weights->isSameShape(predictions)) weightsBroad = new NDArray(weights->tileToShape(predictions->getShapeInfo())); @@ -179,24 +179,24 @@ CUSTOM_OP_IMPL(log_loss_grad, 3, 3, false, 1, 1) { NDArray onePlusEpsMinusPredict = (1. + epsilon) - *predictions; // dE_i/dp_i = (1-y_i)/(1-p_i+eps) - y_i/(p_i+eps) - dLdp->assign(oneMinusLabels / onePlusEpsMinusPredict - *labels / predictPlusEps); // dE/dp + dLdp->assign(oneMinusLabels / onePlusEpsMinusPredict - *labels / predictPlusEps); // dE/dp // dE_i/dy_i = log((1+2eps)/(p_i+eps) - 1) - ((1. + 2. * epsilon) / predictPlusEps - 1.).applyTransform(transform::Log, dLdl); // dE/dy + ((1. + 2. * epsilon) / predictPlusEps - 1.).applyTransform(transform::Log, *dLdl); // dE/dy NDArray E = -(*labels) * predictPlusEps.transform(transform::Log) - oneMinusLabels * onePlusEpsMinusPredict.transform(transform::Log); - + // process 3 possible reduction modes below - switch (reductionMode) { + switch (reductionMode) { case 1: { // 1 - "none" and "weighted_sum", output is scalar and equal to sum of all elements of E array *dLdp *= *weightsBroad; *dLdl *= *weightsBroad; - + if(weights->isScalar()) dLdw->assign(E.reduceNumber(reduce::Sum)); else if(weights != weightsBroad) { - std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); - E.reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false); + std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); + E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); } else dLdw->assign(E); @@ -208,9 +208,9 @@ CUSTOM_OP_IMPL(log_loss_grad, 3, 3, false, 1, 1) { NDArray sum; if (weights->isScalar()) sum = (*weights) * E.lengthOf(); - else + else sum = weightsBroad->reduceNumber(reduce::Sum); - + if (sum.e(0) == 0.) { *dLdp = 0.; *dLdl = 0.; @@ -221,27 +221,27 @@ CUSTOM_OP_IMPL(log_loss_grad, 3, 3, false, 1, 1) { NDArray temp = *weightsBroad / sum; *dLdp *= temp; *dLdl *= temp; - + if(weights->isScalar()) *dLdw = 0.; else if(weights != weightsBroad) { - std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); - ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false); - } - else - dLdw->assign((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)); + std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); + ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); + } + else + dLdw->assign((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)); } break; } - case 3: { // 3 - "weighted_sum_by_nonzero_weights", output is scalar and equal to scalar sum of all elements of E array divided by number of non-zero weights - + case 3: { // 3 - "weighted_sum_by_nonzero_weights", output is scalar and equal to scalar sum of all elements of E array divided by number of non-zero weights + Nd4jLong numOfNonZeroWeights = 0; if(weights->isScalar()) { if(weights->e(0) != 0.) numOfNonZeroWeights = E.lengthOf(); } - else - numOfNonZeroWeights = weightsBroad->reduceNumber(reduce::CountNonZero).e(0); + else + numOfNonZeroWeights = weightsBroad->reduceNumber(reduce::CountNonZero).e(0); if (numOfNonZeroWeights == 0) { *dLdp = 0.; @@ -254,12 +254,12 @@ CUSTOM_OP_IMPL(log_loss_grad, 3, 3, false, 1, 1) { dLdw->assign(E.reduceNumber(reduce::Sum) / numOfNonZeroWeights); else if(weights != weightsBroad) { std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); - E.reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false); + E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); *dLdw /= numOfNonZeroWeightsScalar; } else dLdw->assign(E / numOfNonZeroWeightsScalar); - + NDArray temp = *weightsBroad / numOfNonZeroWeightsScalar; *dLdp *= temp; *dLdl *= temp; @@ -270,13 +270,13 @@ CUSTOM_OP_IMPL(log_loss_grad, 3, 3, false, 1, 1) { if(weightsBroad != weights) delete weightsBroad; - + return Status::OK(); } ////////////////////////////////////////////////////////////////////////// DECLARE_TYPES(log_loss_grad) { - + getOpDescriptor()->setAllowedInputTypes(nd4j::DataType::ANY)->setAllowedOutputTypes({ALL_FLOATS}); } @@ -287,19 +287,19 @@ DECLARE_SHAPE_FN(log_loss_grad) { auto weightsShapeInfo = inputShape->at(1); auto labelsShapeInfo = inputShape->at(2); - // labels and predictions must have the same shapes - REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, predictionsShapeInfo), 0, "LOG_LOSS_GRAD OP: labels and predictions arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(predictionsShapeInfo).c_str()); + // labels and predictions must have the same shapes + REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, predictionsShapeInfo), 0, "LOG_LOSS_GRAD OP: labels and predictions arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(predictionsShapeInfo).c_str()); // weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || shape::rank(weightsShapeInfo) == shape::rank(labelsShapeInfo), 0, "LOG_LOSS_GRAD OP: weights array should be scalar or have the same rank as labels array, but got %i and %i correspondingly!", shape::rank(weightsShapeInfo), shape::rank(labelsShapeInfo)); - // check whether broadcast operation is possible for weights array + // check whether broadcast operation is possible for weights array REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || ShapeUtils::areShapesBroadcastable(weightsShapeInfo, labelsShapeInfo), 0, "LOG_LOSS_GRAD OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and labels = %s instead!", ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), ShapeUtils::shapeAsString(labelsShapeInfo).c_str()); - DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(predictionsShapeInfo)); + DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(predictionsShapeInfo)); auto dLdpShapeInfo = ShapeBuilders::copyShapeInfoAndType(predictionsShapeInfo, outType, false, block.getWorkspace()); auto dLdwShapeInfo = ShapeBuilders::copyShapeInfoAndType(weightsShapeInfo, outType, false, block.getWorkspace()); auto dLdlShapeInfo = ShapeBuilders::copyShapeInfoAndType(labelsShapeInfo, outType, false, block.getWorkspace()); - + return SHAPELIST(CONSTANT(dLdpShapeInfo), CONSTANT(dLdwShapeInfo), CONSTANT(dLdlShapeInfo)); } diff --git a/libnd4j/include/ops/declarable/generic/loss/log_poisson_loss.cpp b/libnd4j/include/ops/declarable/generic/loss/log_poisson_loss.cpp index 6f3d0c5dd..0d85c6e23 100644 --- a/libnd4j/include/ops/declarable/generic/loss/log_poisson_loss.cpp +++ b/libnd4j/include/ops/declarable/generic/loss/log_poisson_loss.cpp @@ -55,9 +55,9 @@ namespace ops { NDArray E(labels->getShapeInfo(), block.getWorkspace()); if (computeFullLoss) - labels->applyPairwiseTransform(pairwise::LogPoissonLossFull, log_predictions, &E, nullptr); + labels->applyPairwiseTransform(pairwise::LogPoissonLossFull, *log_predictions, E); else - labels->applyPairwiseTransform(pairwise::LogPoissonLoss, log_predictions, &E, nullptr); + labels->applyPairwiseTransform(pairwise::LogPoissonLoss, *log_predictions, E); // multiply E on weights @@ -176,19 +176,19 @@ namespace ops { NDArray E(labels->getShapeInfo(), block.getWorkspace()); if (computeFullLoss) { - labels->applyPairwiseTransform(pairwise::LogPoissonLossFull, log_predictions, &E, nullptr); + labels->applyPairwiseTransform(pairwise::LogPoissonLossFull, *log_predictions, E); NDArray rDiv(labels->getShapeInfo(), block.getWorkspace()); - labels->applyScalar(scalar::ReverseDivide, 0.5f, &rDiv); + labels->applyScalar(scalar::ReverseDivide, 0.5f, rDiv); dLdl->assign(rDiv + labels->transform(transform::Log) + -(*log_predictions)); } else { - labels->applyPairwiseTransform(pairwise::LogPoissonLoss, log_predictions, &E, nullptr); + labels->applyPairwiseTransform(pairwise::LogPoissonLoss, *log_predictions, E); dLdl->assign(-(*log_predictions)); } dLdp->assign(log_predictions->transform(transform::Exp) - (*labels)); - + switch (reductionMode) { case 1: { // 1 - "none" and "weighted_sum", output is scalar and equal to sum of all elements of E array @@ -200,7 +200,7 @@ namespace ops { dLdw->assign(E.reduceNumber(reduce::Sum)); else if(weights != weightsBroad) { std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); - E.reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false); + E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); } else dLdw->assign(E); @@ -228,7 +228,7 @@ namespace ops { *dLdw = 0.; else if(weights != weightsBroad) { std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); - ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false); + ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); } else dLdw->assign((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)); @@ -257,7 +257,7 @@ namespace ops { dLdw->assign(E.reduceNumber(reduce::Sum) / double(numOfNonZeroWeights)); else if(weights != weightsBroad) { std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); - E.reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false); + E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); *dLdw /= numOfNonZeroWeightsScalar; } else diff --git a/libnd4j/include/ops/declarable/generic/loss/meanPairWsSqErr.cpp b/libnd4j/include/ops/declarable/generic/loss/meanPairWsSqErr.cpp index 003ae815b..ef511921f 100644 --- a/libnd4j/include/ops/declarable/generic/loss/meanPairWsSqErr.cpp +++ b/libnd4j/include/ops/declarable/generic/loss/meanPairWsSqErr.cpp @@ -112,10 +112,10 @@ namespace nd4j { auto n = double(labels->sizeAt(1)); auto diffs = *predictions - *labels; - auto sumOfSquares = (diffs * diffs).reduceAlongDims(reduce::Sum, reductionIdx, true); + auto sumOfSquares = (diffs * diffs).reduceAlongDimension(reduce::Sum, reductionIdx, true); - auto squareOfSum = diffs.reduceAlongDims(reduce::Sum, reductionIdx, true); - squareOfSum.applyScalar(scalar::Pow, 2); + auto squareOfSum = diffs.reduceAlongDimension(reduce::Sum, reductionIdx, true); + squareOfSum.applyScalar(scalar::Pow, 2, squareOfSum); auto E = ((sumOfSquares * n) - squareOfSum) * (4/(n*(n-1))); @@ -240,15 +240,15 @@ namespace nd4j { auto diffs = *predictions - *labels; std::vector reductionIdx = ShapeUtils::evalDimsToExclude(labels->rankOf(), {0}); - auto sumOfSquares = (diffs * diffs).reduceAlongDims(reduce::Sum, reductionIdx, true); + auto sumOfSquares = (diffs * diffs).reduceAlongDimension(reduce::Sum, reductionIdx, true); - auto squareOfSum = diffs.reduceAlongDims(reduce::Sum, reductionIdx, true); - squareOfSum.applyScalar(scalar::Pow, 2); + auto squareOfSum = diffs.reduceAlongDimension(reduce::Sum, reductionIdx, true); + squareOfSum.applyScalar(scalar::Pow, 2, squareOfSum); auto E = ((sumOfSquares * n) - squareOfSum) * (4/(n*(n-1))); - auto sumPred = predictions->reduceAlongDims(reduce::Sum, reductionIdx, true); - auto sumLabel = labels->reduceAlongDims(reduce::Sum, reductionIdx, true); + auto sumPred = predictions->reduceAlongDimension(reduce::Sum, reductionIdx, true); + auto sumLabel = labels->reduceAlongDimension(reduce::Sum, reductionIdx, true); dLdp->assign(((diffs * n) - sumPred + sumLabel)*(8/(n*(n-1)))); @@ -273,7 +273,7 @@ namespace nd4j { dLdw->assign(E.reduceNumber(reduce::Sum)); else if(weights != weightsBroad) { std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); - E.reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false); + E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); } else dLdw->assign(E); @@ -299,7 +299,7 @@ namespace nd4j { *dLdw = 0.; else if(weights != weightsBroad) { std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); - ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false); + ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); } else dLdw->assign((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)); @@ -327,7 +327,7 @@ namespace nd4j { dLdw->assign(E.reduceNumber(reduce::Sum) / double(numOfNonZeroWeights)); else if(weights != weightsBroad) { std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); - E.reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false); + E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); *dLdw /= numOfNonZeroWeightsScalar; } else diff --git a/libnd4j/include/ops/declarable/generic/loss/meanSqErr.cpp b/libnd4j/include/ops/declarable/generic/loss/meanSqErr.cpp index c519ab020..f446d0bf0 100644 --- a/libnd4j/include/ops/declarable/generic/loss/meanSqErr.cpp +++ b/libnd4j/include/ops/declarable/generic/loss/meanSqErr.cpp @@ -35,8 +35,8 @@ CUSTOM_OP_IMPL(mean_sqerr_loss, 3, 1, false, 0, 1) { auto output = OUTPUT_VARIABLE(0); int reductionMode = INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - "weighted_sum_by_nonzero_weights" - - // inputs validation + + // inputs validation REQUIRE_TRUE(labels->isSameShape(predictions), 0, "MEAN_SQERR_LOSS OP: labels and predictions arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labels).c_str(), ShapeUtils::shapeAsString(predictions).c_str()); // weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels REQUIRE_TRUE(weights->isScalar() || weights->rankOf() == labels->rankOf(), 0, "MEAN_SQERR_LOSS OP: weights array should be scalar or have the same rank as labels array, but got %i and %i correspondingly!", weights->rankOf(), labels->rankOf()); @@ -45,13 +45,13 @@ CUSTOM_OP_IMPL(mean_sqerr_loss, 3, 1, false, 0, 1) { // only 4 possible reduction modes exist REQUIRE_TRUE(reductionMode==0 || reductionMode==1 || reductionMode==2 || reductionMode==3, 0, "MEAN_SQERR_LOSS OP: reduction mode value is not acceptable, possible values are 0, 1, 2, 3, but got %i instead!", reductionMode); - // perform weights broadcasting/tile to labels if needed + // perform weights broadcasting/tile to labels if needed auto weightsBroad = weights; - if(!weights->isScalar() && !weights->isSameShape(predictions)) + if(!weights->isScalar() && !weights->isSameShape(predictions)) weightsBroad = new NDArray(weights->tileToShape(predictions->getShapeInfo())); NDArray E(labels->getShapeInfo(), false, block.launchContext()); - predictions->applyPairwiseTransform(pairwise::SquaredSubtract, labels, &E, nullptr); + predictions->applyPairwiseTransform(pairwise::SquaredSubtract, *labels, E); // multiply E on weights E *= (*weightsBroad); @@ -60,7 +60,7 @@ CUSTOM_OP_IMPL(mean_sqerr_loss, 3, 1, false, 0, 1) { case 0: // 0 - "none", un-reduced weighted losses with the same shape as labels. output->assign(&E); break; - + case 1: { // 1 - "weighted_sum", output is scalar and equal to sum of all elements of E array E.reduceNumber(reduce::Sum, *output); break; @@ -69,12 +69,12 @@ CUSTOM_OP_IMPL(mean_sqerr_loss, 3, 1, false, 0, 1) { NDArray sum; if (weights->isScalar()) sum = (*weights) * E.lengthOf(); - else + else sum = weightsBroad->reduceNumber(reduce::Sum); - + if (sum.e(0) == 0.) (*output) = 0.; - else + else output->assign(E.reduceNumber(reduce::Sum) / sum); break; } @@ -101,12 +101,12 @@ CUSTOM_OP_IMPL(mean_sqerr_loss, 3, 1, false, 0, 1) { if(weightsBroad != weights) delete weightsBroad; - + return Status::OK(); } DECLARE_TYPES(mean_sqerr_loss) { - + getOpDescriptor()->setAllowedInputTypes(nd4j::DataType::ANY)->setAllowedOutputTypes({ALL_FLOATS}); } @@ -121,7 +121,7 @@ DECLARE_SHAPE_FN(mean_sqerr_loss) { REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, predictionsShapeInfo), 0, "MEAN_SQERR_LOSS OP: labels and predictions arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(predictionsShapeInfo).c_str()); // weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || shape::rank(weightsShapeInfo) == shape::rank(labelsShapeInfo), 0, "MEAN_SQERR_LOSS OP: weights array should be scalar or have the same rank as labels array, but got %i and %i correspondingly!", shape::rank(weightsShapeInfo), shape::rank(labelsShapeInfo)); - // check whether broadcast operation is possible for weights array + // check whether broadcast operation is possible for weights array REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || ShapeUtils::areShapesBroadcastable(weightsShapeInfo, labelsShapeInfo), 0, "MEAN_SQERR_LOSS OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and labels = %s instead!", ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), ShapeUtils::shapeAsString(labelsShapeInfo).c_str()); DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(predictionsShapeInfo)); @@ -132,7 +132,7 @@ DECLARE_SHAPE_FN(mean_sqerr_loss) { else // in this case output has the same shape as labels and predictions outShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(outType, shape::order(labelsShapeInfo), shape::shapeOf(labelsShapeInfo), shape::rank(labelsShapeInfo))); - return SHAPELIST(outShapeInfo); + return SHAPELIST(outShapeInfo); } @@ -144,11 +144,11 @@ DECLARE_SHAPE_FN(mean_sqerr_loss) { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(mean_sqerr_loss_grad, 3, 3, false, 0, 1) { - + auto predictions = INPUT_VARIABLE(0); auto weights = INPUT_VARIABLE(1); auto labels = INPUT_VARIABLE(2); - + auto dLdp = OUTPUT_VARIABLE(0); // dL/dpredictions auto dLdw = OUTPUT_VARIABLE(1); // dL/dweights auto dLdl = OUTPUT_VARIABLE(2); // dL/dlabels @@ -157,8 +157,8 @@ CUSTOM_OP_IMPL(mean_sqerr_loss_grad, 3, 3, false, 0, 1) { // take into account Alex's proposition to treat "none" the same as "weighted_sum" mode when calculating gradients if(reductionMode == 0) reductionMode = 1; - - // inputs validation + + // inputs validation REQUIRE_TRUE(labels->isSameShape(predictions), 0, "MEAN_SQERR_LOSS_GRAD OP: labels and predictions arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labels).c_str(), ShapeUtils::shapeAsString(predictions).c_str()); // weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels REQUIRE_TRUE(weights->isScalar() || weights->rankOf() == labels->rankOf(), 0, "MEAN_SQERR_LOSS_GRAD OP: weights array should be scalar or have the same rank as labels array, but got %i and %i correspondingly!", weights->rankOf(), labels->rankOf()); @@ -167,9 +167,9 @@ CUSTOM_OP_IMPL(mean_sqerr_loss_grad, 3, 3, false, 0, 1) { // only 4 possible reduction modes exist REQUIRE_TRUE(reductionMode==0 || reductionMode==1 || reductionMode==2 || reductionMode==3, 0, "MEAN_SQERR_LOSS_GRAD OP: reduction mode value is not acceptable, possible values are 0, 1, 2, 3, but got %i instead!", reductionMode); - // perform weights broadcasting/tile to labels if needed + // perform weights broadcasting/tile to labels if needed auto weightsBroad = weights; - if(!weights->isScalar() && !weights->isSameShape(predictions)) + if(!weights->isScalar() && !weights->isSameShape(predictions)) weightsBroad = new NDArray(weights->tileToShape(predictions->getShapeInfo())); NDArray diff = *predictions - *labels; @@ -178,20 +178,20 @@ CUSTOM_OP_IMPL(mean_sqerr_loss_grad, 3, 3, false, 0, 1) { dLdp->assign(2. * diff); // dE/dp // dE_i/dy_i = -2 * (p_i - y_i) // dLdl->assign(-(*dLdp)); // dE/dl - + NDArray E = diff * diff; - switch (reductionMode) { - + switch (reductionMode) { + case 1: { // 1 - "none" and "weighted_sum", output is scalar and equal to sum of all elements of E array *dLdp *= *weightsBroad; - + if(weights->isScalar()) dLdw->assign(E.reduceNumber(reduce::Sum)); else if(weights != weightsBroad) { - std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); - E.reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false); + std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); + E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); } else dLdw->assign(E); @@ -202,40 +202,40 @@ CUSTOM_OP_IMPL(mean_sqerr_loss_grad, 3, 3, false, 0, 1) { NDArray sum; if (weights->isScalar()) sum = (*weights) * E.lengthOf(); - else + else sum = weightsBroad->reduceNumber(reduce::Sum); - + if (sum.e(0) == 0.) { - *dLdp = 0.; + *dLdp = 0.; *dLdw = 0.; } else { - + *dLdp *= *weightsBroad / sum; - + if(weights->isScalar()) *dLdw = 0.; else if(weights != weightsBroad) { - std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); - ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false); - } - else - dLdw->assign((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)); + std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); + ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); + } + else + dLdw->assign((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)); } break; } - case 3: { // 3 - "weighted_sum_by_nonzero_weights", output is scalar and equal to scalar sum of all elements of E array divided by number of non-zero weights + case 3: { // 3 - "weighted_sum_by_nonzero_weights", output is scalar and equal to scalar sum of all elements of E array divided by number of non-zero weights Nd4jLong numOfNonZeroWeights = 0; if(weights->isScalar()) { if(weights->e(0) != 0.) numOfNonZeroWeights = E.lengthOf(); } - else - numOfNonZeroWeights = weightsBroad->reduceNumber(reduce::CountNonZero).e(0); + else + numOfNonZeroWeights = weightsBroad->reduceNumber(reduce::CountNonZero).e(0); if (numOfNonZeroWeights == 0) { - *dLdp = 0.; + *dLdp = 0.; *dLdw = 0.; } else { @@ -245,14 +245,14 @@ CUSTOM_OP_IMPL(mean_sqerr_loss_grad, 3, 3, false, 0, 1) { dLdw->assign(E.reduceNumber(reduce::Sum) / double(numOfNonZeroWeights)); else if(weights != weightsBroad) { std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); - E.reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false); + E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); *dLdw /= numOfNonZeroWeightsScalar; } else dLdw->assign(E / numOfNonZeroWeights); - + NDArray temp = *weightsBroad / numOfNonZeroWeightsScalar; - *dLdp *= temp; + *dLdp *= temp; } break; } @@ -262,12 +262,12 @@ CUSTOM_OP_IMPL(mean_sqerr_loss_grad, 3, 3, false, 0, 1) { if(weightsBroad != weights) delete weightsBroad; - + return Status::OK(); } DECLARE_TYPES(mean_sqerr_loss_grad) { - + getOpDescriptor()->setAllowedInputTypes(nd4j::DataType::ANY)->setAllowedOutputTypes({ALL_FLOATS}); } @@ -281,15 +281,15 @@ DECLARE_SHAPE_FN(mean_sqerr_loss_grad) { REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, predictionsShapeInfo), 0, "MEAN_SQERR_LOSS_GRAD OP: labels and predictions arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(predictionsShapeInfo).c_str()); // weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || shape::rank(weightsShapeInfo) == shape::rank(labelsShapeInfo), 0, "MEAN_SQERR_LOSS_GRAD OP: weights array should be scalar or have the same rank as labels array, but got %i and %i correspondingly!", shape::rank(weightsShapeInfo), shape::rank(labelsShapeInfo)); - // check whether broadcast operation is possible for weights array + // check whether broadcast operation is possible for weights array REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || ShapeUtils::areShapesBroadcastable(weightsShapeInfo, labelsShapeInfo), 0, "MEAN_SQERR_LOSS_GRAD OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and labels = %s instead!", ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), ShapeUtils::shapeAsString(labelsShapeInfo).c_str()); - DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(predictionsShapeInfo)); + DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(predictionsShapeInfo)); auto dLdpShapeInfo = ShapeBuilders::copyShapeInfoAndType(predictionsShapeInfo, outType, false, block.getWorkspace()); auto dLdwShapeInfo = ShapeBuilders::copyShapeInfoAndType(weightsShapeInfo, outType, false, block.getWorkspace()); auto dLdlShapeInfo = ShapeBuilders::copyShapeInfoAndType(labelsShapeInfo, outType, false, block.getWorkspace()); - + return SHAPELIST(CONSTANT(dLdpShapeInfo), CONSTANT(dLdwShapeInfo), CONSTANT(dLdlShapeInfo)); } diff --git a/libnd4j/include/ops/declarable/generic/loss/sigmCrossEntropy.cpp b/libnd4j/include/ops/declarable/generic/loss/sigmCrossEntropy.cpp index b3f707b23..5b0075466 100644 --- a/libnd4j/include/ops/declarable/generic/loss/sigmCrossEntropy.cpp +++ b/libnd4j/include/ops/declarable/generic/loss/sigmCrossEntropy.cpp @@ -38,27 +38,27 @@ CUSTOM_OP_IMPL(sigm_cross_entropy_loss, 3, 1, false, 1, 1) { int reductionMode = INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - "weighted_sum_by_nonzero_weights" auto labelsSmoothing = T_ARG(0); - // input validation + // input validation REQUIRE_TRUE(labels->isSameShape(logits), 0, "SIGM_CROSS_ENTROPY_LOSS OP: labels and logits arrays must have the same shapes, but got %s and %s correspondingly!", ShapeUtils::shapeAsString(labels).c_str(), ShapeUtils::shapeAsString(logits).c_str()); // weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels REQUIRE_TRUE(weights->isScalar() || weights->rankOf() == labels->rankOf(), 0, "SIGM_CROSS_ENTROPY_LOSS OP: weights array should be scalar or have the same rank as labels array, but got %i and %i correspondingly!", weights->rankOf(), labels->rankOf()); - // check whether broadcast operation is possible for weights array + // check whether broadcast operation is possible for weights array REQUIRE_TRUE(weights->isScalar() || ShapeUtils::areShapesBroadcastable(*weights, *labels), 0, "SIGM_CROSS_ENTROPY_LOSS OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and labels = %s instead!", ShapeUtils::shapeAsString(weights).c_str(), ShapeUtils::shapeAsString(labels).c_str()); // only 4 possible reduction modes exist REQUIRE_TRUE(reductionMode==0 || reductionMode==1 || reductionMode==2 || reductionMode==3, 0, "SIGM_CROSS_ENTROPY_LOSS OP: reduction mode value is not acceptable, possible values are 0, 1, 2, 3, but got %i instead!", reductionMode); - // perform weights broadcasting/tile to labels if needed + // perform weights broadcasting/tile to labels if needed auto weightsBroad = weights; if(!weights->isScalar() && !weights->isSameShape(logits)) weightsBroad = new NDArray(weights->tileToShape(logits->getShapeInfo())); - + // If labelsSmoothing is nonzero, smooth the labels towards 1/2: auto newLabels = labels; if(labelsSmoothing != 0.) { newLabels = new NDArray(*labels); - newLabels->applyScalar(scalar::SXELogitsSmoother, labelsSmoothing, newLabels, nullptr); + newLabels->applyScalar(scalar::SXELogitsSmoother, labelsSmoothing, *newLabels); } - + NDArray E(labels, false, block.launchContext()); // logits - labels * logits + log(1 + exp(-logits)) -> take into account numerical stability at large logits @@ -66,12 +66,12 @@ CUSTOM_OP_IMPL(sigm_cross_entropy_loss, 3, 1, false, 1, 1) { // multiply E on weights E *= *weightsBroad; - + switch (reductionMode) { case 0: // 0 - "none", un-reduced weighted losses with the same shape as labels. output->assign(E); break; - + case 1: { // 1 - "weighted_sum", output is scalar and equal to sum of all elements of E array E.reduceNumber(reduce::Sum, *output); break; @@ -80,12 +80,12 @@ CUSTOM_OP_IMPL(sigm_cross_entropy_loss, 3, 1, false, 1, 1) { NDArray sum; if (weights->isScalar()) sum = (*weights) * E.lengthOf(); - else + else sum = weightsBroad->reduceNumber(reduce::Sum); - + if (sum.e(0) == 0.) *output = 0.; - else + else output->assign(E.reduceNumber(reduce::Sum) / sum); break; } @@ -111,13 +111,13 @@ CUSTOM_OP_IMPL(sigm_cross_entropy_loss, 3, 1, false, 1, 1) { delete weightsBroad; if(newLabels != labels) delete newLabels; - + return Status::OK(); } ////////////////////////////////////////////////////////////////////////// DECLARE_TYPES(sigm_cross_entropy_loss) { - + getOpDescriptor()->setAllowedInputTypes(nd4j::DataType::ANY)->setAllowedOutputTypes({ALL_FLOATS}); } @@ -128,11 +128,11 @@ DECLARE_SHAPE_FN(sigm_cross_entropy_loss) { auto weightsShapeInfo = inputShape->at(1); auto labelsShapeInfo = inputShape->at(2); - // labels and logits must have the same shapes - REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, logitsShapeInfo), 0, "SIGM_CROSS_ENTROPY_LOSS OP: labels and logits arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(logitsShapeInfo).c_str()); + // labels and logits must have the same shapes + REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, logitsShapeInfo), 0, "SIGM_CROSS_ENTROPY_LOSS OP: labels and logits arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(logitsShapeInfo).c_str()); // weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || shape::rank(weightsShapeInfo) == shape::rank(labelsShapeInfo), 0, "SIGM_CROSS_ENTROPY_LOSS OP: weights array should be scalar or have the same rank as labels array, but got %i and %i correspondingly!", shape::rank(weightsShapeInfo), shape::rank(labelsShapeInfo)); - // check whether broadcast operation is possible for weights array + // check whether broadcast operation is possible for weights array REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || ShapeUtils::areShapesBroadcastable(weightsShapeInfo, labelsShapeInfo), 0, "SIGM_CROSS_ENTROPY_LOSS OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and labels = %s instead!", ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), ShapeUtils::shapeAsString(labelsShapeInfo).c_str()); DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(logitsShapeInfo)); @@ -142,8 +142,8 @@ DECLARE_SHAPE_FN(sigm_cross_entropy_loss) { outShapeInfo = ConstantShapeHelper::getInstance()->scalarShapeInfo(outType); else // in this case output has the same shape as labels and logits outShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(outType, shape::order(labelsShapeInfo), shape::shapeOf(labelsShapeInfo), shape::rank(labelsShapeInfo))); - - return SHAPELIST(outShapeInfo); + + return SHAPELIST(outShapeInfo); } @@ -155,12 +155,12 @@ CUSTOM_OP_IMPL(sigm_cross_entropy_loss_grad, 3, 3, false, 1, 1) { auto logits = INPUT_VARIABLE(0); auto weights = INPUT_VARIABLE(1); auto labels = INPUT_VARIABLE(2); - + auto dLdp = OUTPUT_VARIABLE(0); // dL/dlogits auto dLdw = OUTPUT_VARIABLE(1); // dL/dweights auto dLdl = OUTPUT_VARIABLE(2); // dL/dlabels - + NDArray labelsSmoothing = NDArrayFactory::create(logits->dataType(), T_ARG(0), block.launchContext()); int reductionMode = INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - "weighted_sum_by_nonzero_weights" @@ -168,27 +168,27 @@ CUSTOM_OP_IMPL(sigm_cross_entropy_loss_grad, 3, 3, false, 1, 1) { if(reductionMode == 0) reductionMode = 1; - // input validation + // input validation REQUIRE_TRUE(labels->isSameShape(logits), 0, "SIGM_CROSS_ENTROPY_LOSS_GRAD OP: labels and logits arrays must have the same shapes, but got %s and %s correspondingly!", ShapeUtils::shapeAsString(labels).c_str(), ShapeUtils::shapeAsString(logits).c_str()); // weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels REQUIRE_TRUE(weights->isScalar() || weights->rankOf() == labels->rankOf(), 0, "SIGM_CROSS_ENTROPY_LOSS_GRAD OP: weights array should be scalar or have the same rank as labels array, but got %i and %i correspondingly!", weights->rankOf(), labels->rankOf()); - // check whether broadcast operation is possible for weights array + // check whether broadcast operation is possible for weights array REQUIRE_TRUE(weights->isScalar() || ShapeUtils::areShapesBroadcastable(*weights, *labels), 0, "SIGM_CROSS_ENTROPY_LOSS_GRAD OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and labels = %s instead!", ShapeUtils::shapeAsString(weights).c_str(), ShapeUtils::shapeAsString(labels).c_str()); // only 4 possible reduction modes exist REQUIRE_TRUE(reductionMode==0 || reductionMode==1 || reductionMode==2 || reductionMode==3, 0, "SIGM_CROSS_ENTROPY_LOSS_GRAD OP: reduction mode value is not acceptable, possible values are 0, 1, 2, 3, but got %i instead!", reductionMode); - // perform weights broadcasting/tile to labels if needed + // perform weights broadcasting/tile to labels if needed auto weightsBroad = weights; if(!weights->isScalar() && !weights->isSameShape(logits)) weightsBroad = new NDArray(weights->tileToShape(logits->getShapeInfo())); - + // If labelsSmoothing is nonzero, smooth the labels towards 1/2: auto newLabels = labels; if(labelsSmoothing.e(0) != 0.f) { newLabels = new NDArray(*labels); - newLabels->applyScalar(scalar::SXELogitsSmoother, labelsSmoothing.e(0), newLabels, nullptr); + newLabels->applyScalar(scalar::SXELogitsSmoother, labelsSmoothing.e(0), *newLabels); } - + NDArray E(labels, false, block.launchContext()); // logits - labels * logits + log(1 + exp(-logits)) -> take into account numerical stability at large logits @@ -196,24 +196,24 @@ CUSTOM_OP_IMPL(sigm_cross_entropy_loss_grad, 3, 3, false, 1, 1) { // dLdp = 1 - labels - 1 / (1 + exp(logits)) helpers::sigmCrossEntropyGrad(block.launchContext(), logits, newLabels, dLdp); - + // dLdl = -logits labelsSmoothing -= 1.f; dLdl->assign(*logits * labelsSmoothing); - switch (reductionMode) { + switch (reductionMode) { case 1: { // 1 - "none" and "weighted_sum", output is scalar and equal to sum of all elements of E array *dLdp *= *weightsBroad; *dLdl *= *weightsBroad; - + if(weights->isScalar()) dLdw->assign(E.reduceNumber(reduce::Sum)); else if(weights != weightsBroad) { - std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); - E.reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false); + std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); + E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); } - else + else dLdw->assign(E); break; } @@ -221,9 +221,9 @@ CUSTOM_OP_IMPL(sigm_cross_entropy_loss_grad, 3, 3, false, 1, 1) { NDArray sum; if (weights->isScalar()) sum = (*weights) * E.lengthOf(); - else + else sum = weightsBroad->reduceNumber(reduce::Sum); - + if (sum.e(0) == 0.) { *dLdp = 0.; *dLdl = 0.; @@ -234,14 +234,14 @@ CUSTOM_OP_IMPL(sigm_cross_entropy_loss_grad, 3, 3, false, 1, 1) { NDArray temp = *weightsBroad / sum; *dLdp *= temp; *dLdl *= temp; - + if(weights->isScalar()) *dLdw = 0.; else if(weights != weightsBroad) { - std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); - ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false); - } - else + std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); + ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); + } + else dLdw->assign((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum * sum)); } break; @@ -252,8 +252,8 @@ CUSTOM_OP_IMPL(sigm_cross_entropy_loss_grad, 3, 3, false, 1, 1) { if(weights->e(0) != 0.) numOfNonZeroWeights = E.lengthOf(); } - else - numOfNonZeroWeights = weightsBroad->reduceNumber(reduce::CountNonZero).e(0); + else + numOfNonZeroWeights = weightsBroad->reduceNumber(reduce::CountNonZero).e(0); if (numOfNonZeroWeights == 0) { *dLdp = 0.; @@ -267,12 +267,12 @@ CUSTOM_OP_IMPL(sigm_cross_entropy_loss_grad, 3, 3, false, 1, 1) { dLdw->assign(E.reduceNumber(reduce::Sum) / numOfNonZeroWeightsScalar); else if(weights != weightsBroad) { std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); - E.reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false); + E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); *dLdw /= numOfNonZeroWeightsScalar; } else dLdw->assign(E / numOfNonZeroWeightsScalar); - + NDArray temp = *weightsBroad / numOfNonZeroWeightsScalar; *dLdp *= temp; *dLdl *= temp; @@ -285,13 +285,13 @@ CUSTOM_OP_IMPL(sigm_cross_entropy_loss_grad, 3, 3, false, 1, 1) { delete weightsBroad; if(newLabels != labels) delete newLabels; - + return Status::OK(); } ////////////////////////////////////////////////////////////////////////// DECLARE_TYPES(sigm_cross_entropy_loss_grad) { - + getOpDescriptor()->setAllowedInputTypes(nd4j::DataType::ANY)->setAllowedOutputTypes({ALL_FLOATS}); } @@ -302,11 +302,11 @@ DECLARE_SHAPE_FN(sigm_cross_entropy_loss_grad) { auto weightsShapeInfo = inputShape->at(1); auto labelsShapeInfo = inputShape->at(2); - // labels and logits must have the same shapes - REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, logitsShapeInfo), 0, "SIGM_CROSS_ENTROPY_LOSS_GRAD OP: labels and logits arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(logitsShapeInfo).c_str()); + // labels and logits must have the same shapes + REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, logitsShapeInfo), 0, "SIGM_CROSS_ENTROPY_LOSS_GRAD OP: labels and logits arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(logitsShapeInfo).c_str()); // weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || shape::rank(weightsShapeInfo) == shape::rank(labelsShapeInfo), 0, "SIGM_CROSS_ENTROPY_LOSS_GRAD OP: weights array should be scalar or have the same rank as labels array, but got %i and %i correspondingly!", shape::rank(weightsShapeInfo), shape::rank(labelsShapeInfo)); - // check whether broadcast operation is possible for weights array + // check whether broadcast operation is possible for weights array REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || ShapeUtils::areShapesBroadcastable(weightsShapeInfo, labelsShapeInfo), 0, "SIGM_CROSS_ENTROPY_LOSS_GRAD OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and labels = %s instead!", ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), ShapeUtils::shapeAsString(labelsShapeInfo).c_str()); DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(logitsShapeInfo)); @@ -314,7 +314,7 @@ DECLARE_SHAPE_FN(sigm_cross_entropy_loss_grad) { auto dLdpShapeInfo = ShapeBuilders::copyShapeInfoAndType(logitsShapeInfo, outType, false, block.getWorkspace()); auto dLdwShapeInfo = ShapeBuilders::copyShapeInfoAndType(weightsShapeInfo, outType, false, block.getWorkspace()); auto dLdlShapeInfo = ShapeBuilders::copyShapeInfoAndType(labelsShapeInfo, outType, false, block.getWorkspace()); - + return SHAPELIST(CONSTANT(dLdpShapeInfo), CONSTANT(dLdwShapeInfo), CONSTANT(dLdlShapeInfo)); } diff --git a/libnd4j/include/ops/declarable/generic/loss/softmaxCrossEntropy.cpp b/libnd4j/include/ops/declarable/generic/loss/softmaxCrossEntropy.cpp index faabc7c18..a1a197fae 100644 --- a/libnd4j/include/ops/declarable/generic/loss/softmaxCrossEntropy.cpp +++ b/libnd4j/include/ops/declarable/generic/loss/softmaxCrossEntropy.cpp @@ -54,11 +54,11 @@ CUSTOM_OP_IMPL(softmax_cross_entropy_loss, 3, 1, false, 1, 1) { // If label_smoothing is nonzero, smooth the labels towards 1/num_classes: new_onehot_labels = onehot_labels * (1 - label_smoothing) + label_smoothing / num_classes // num_classes = labels->sizeAt(1) - auto cLabels = labels->cast(weights->dataType()); - auto newLabels = cLabels; + NDArray* cLabels = new NDArray(labels->cast(weights->dataType())); + NDArray* newLabels = cLabels; if(labelsSmoothing != 0.) { newLabels = new NDArray(cLabels); - *newLabels = (1.f - labelsSmoothing) * *cLabels + labelsSmoothing / cLabels->sizeAt(1); + newLabels->assign((1.f - labelsSmoothing) * *cLabels + labelsSmoothing / cLabels->sizeAt(1)); } // main formula: result = - sum_i(lables_i * log(softmax_i)) - sum over last dimension @@ -70,9 +70,9 @@ CUSTOM_OP_IMPL(softmax_cross_entropy_loss, 3, 1, false, 1, 1) { std::vector dimensions = {-1}; - NDArray shiftedLogits = *logits - logits->reduceAlongDims(reduce::Max, dimensions, true); - NDArray logSumExp = shiftedLogits.transform(transform::Exp).reduceAlongDims(reduce::Sum, dimensions, true).transform(transform::Log); - NDArray E = (*newLabels * (logSumExp - shiftedLogits)).reduceAlongDims(reduce::Sum, dimensions); + NDArray shiftedLogits = *logits - logits->reduceAlongDimension(reduce::Max, dimensions, true); + NDArray logSumExp = shiftedLogits.transform(transform::Exp).reduceAlongDimension(reduce::Sum, dimensions, true).transform(transform::Log); + NDArray E = (*newLabels * (logSumExp - shiftedLogits)).reduceAlongDimension(reduce::Sum, dimensions); // perform weights broadcasting/tile to E if it is necessary auto weightsBroad = weights; @@ -217,25 +217,25 @@ CUSTOM_OP_IMPL(softmax_cross_entropy_loss_grad, 3, 3, false, 1, 1) { // If label_smoothing is nonzero, smooth the labels towards 1/num_classes: new_onehot_labels = onehot_labels * (1 - label_smoothing) + label_smoothing / num_classes // num_classes = labels->sizeAt(1) - auto cLabels = labels->cast(weights->dataType()); - auto newLabels = cLabels; + NDArray* cLabels = new NDArray(labels->cast(weights->dataType())); + NDArray* newLabels = cLabels; if(labelsSmoothing != 0.) { newLabels = new NDArray(labels->getShapeInfo(), dLdl->dataType(), false, block.launchContext()); newLabels->assign((1.f - labelsSmoothing) * *cLabels + labelsSmoothing / cLabels->sizeAt(1)); } - NDArray softmax = (*logits - logits->reduceAlongDims(reduce::Max, dimensions, true)).transform(transform::Exp); - softmax /= softmax.reduceAlongDims(reduce::Sum, dimensions, true); + NDArray softmax = (*logits - logits->reduceAlongDimension(reduce::Max, dimensions, true)).transform(transform::Exp); + softmax /= softmax.reduceAlongDimension(reduce::Sum, dimensions, true); // dEdp = softmax * sum_i(lables_i) - labels - dLdp->assign(softmax * newLabels->reduceAlongDims(reduce::Sum, dimensions, true) - *newLabels); + dLdp->assign(softmax * newLabels->reduceAlongDimension(reduce::Sum, dimensions, true) - *newLabels); // dEdl = -log(softmax) dLdl->assign(-softmax.transform(transform::Log)* (1.f - labelsSmoothing)); - NDArray shiftedLogits = *logits - logits->reduceAlongDims(reduce::Max, dimensions, true); - NDArray logSumExp = shiftedLogits.transform(transform::Exp).reduceAlongDims(reduce::Sum, dimensions, true).transform(transform::Log); - NDArray E = (*newLabels * (logSumExp - shiftedLogits)).reduceAlongDims(reduce::Sum, dimensions); + NDArray shiftedLogits = *logits - logits->reduceAlongDimension(reduce::Max, dimensions, true); + NDArray logSumExp = shiftedLogits.transform(transform::Exp).reduceAlongDimension(reduce::Sum, dimensions, true).transform(transform::Log); + NDArray E = (*newLabels * (logSumExp - shiftedLogits)).reduceAlongDimension(reduce::Sum, dimensions); // perform weights broadcasting/tile to E if it is necessary auto weightsBroad = weights; @@ -253,12 +253,12 @@ CUSTOM_OP_IMPL(softmax_cross_entropy_loss_grad, 3, 3, false, 1, 1) { *dLdl *= *weights; } else { - dLdp->applyBroadcast(nd4j::broadcast::Multiply, dimensions, weightsBroad); - dLdl->applyBroadcast(nd4j::broadcast::Multiply, dimensions, weightsBroad); + dLdp->applyBroadcast(nd4j::broadcast::Multiply, dimensions, *weightsBroad, *dLdp); + dLdl->applyBroadcast(nd4j::broadcast::Multiply, dimensions, *weightsBroad, *dLdl); if(weights != weightsBroad) { std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); - E.reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false); + E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); } else dLdw->assign(E); @@ -289,12 +289,12 @@ CUSTOM_OP_IMPL(softmax_cross_entropy_loss_grad, 3, 3, false, 1, 1) { else { NDArray temp = *weightsBroad / sum; - dLdp->applyBroadcast(nd4j::broadcast::Multiply, dimensions, &temp); - dLdl->applyBroadcast(nd4j::broadcast::Multiply, dimensions, &temp); + dLdp->applyBroadcast(nd4j::broadcast::Multiply, dimensions, temp, *dLdp); + dLdl->applyBroadcast(nd4j::broadcast::Multiply, dimensions, temp, *dLdl); if(weights != weightsBroad) { std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); - ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false); + ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); } else dLdw->assign((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)); @@ -326,12 +326,12 @@ CUSTOM_OP_IMPL(softmax_cross_entropy_loss_grad, 3, 3, false, 1, 1) { } else { NDArray temp = *weightsBroad / numOfNonZeroWeights; - dLdp->applyBroadcast(nd4j::broadcast::Multiply, dimensions, &temp); - dLdl->applyBroadcast(nd4j::broadcast::Multiply, dimensions, &temp); + dLdp->applyBroadcast(nd4j::broadcast::Multiply, dimensions, temp, *dLdp); + dLdl->applyBroadcast(nd4j::broadcast::Multiply, dimensions, temp, *dLdl); if(weights != weightsBroad) { std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); - E.reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false); + E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); *dLdw /= numOfNonZeroWeights; } else diff --git a/libnd4j/include/ops/declarable/generic/loss/softmaxCrossEntropyWithLogits.cpp b/libnd4j/include/ops/declarable/generic/loss/softmaxCrossEntropyWithLogits.cpp index b129dd483..5e88ec0e6 100644 --- a/libnd4j/include/ops/declarable/generic/loss/softmaxCrossEntropyWithLogits.cpp +++ b/libnd4j/include/ops/declarable/generic/loss/softmaxCrossEntropyWithLogits.cpp @@ -34,38 +34,38 @@ CUSTOM_OP_IMPL(softmax_cross_entropy_loss_with_logits, 2, 1, false, 0, 0) { auto output = OUTPUT_VARIABLE(0); const int classesDim = block.getIArguments()->size() > 0 ? INT_ARG(0) : logits->rankOf()-1; - - // input validation + + // input validation REQUIRE_TRUE(labels->isSameShape(logits), 0, "SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS OP: labels and logits arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labels).c_str(), ShapeUtils::shapeAsString(logits).c_str()); REQUIRE_TRUE(classesDim < logits->rankOf(), 0, "SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS OP: class dimension must be smaller than rank of logits, but got %i and %i correspondingly !", classesDim, logits->rankOf()); - - std::vector dimension = {classesDim}; - auto maxAlongDim = logits->reduceAlongDims(reduce::Max, {classesDim}, true); + std::vector dimension = {classesDim}; + + auto maxAlongDim = logits->reduceAlongDimension(reduce::Max, {classesDim}, true); auto logExp = (*logits - maxAlongDim).transform(transform::Exp); - auto logSoftMax = ( logExp / logExp.reduceAlongDims(reduce::Sum, {classesDim}, true) ).transform(transform::Log); - - (-(*labels) * logSoftMax).reduceAlongDimension(reduce::Sum, output, dimension); - + auto logSoftMax = ( logExp / logExp.reduceAlongDimension(reduce::Sum, {classesDim}, true) ).transform(transform::Log); + + (-(*labels) * logSoftMax).reduceAlongDimension(reduce::Sum, *output, dimension); + return Status::OK(); } ////////////////////////////////////////////////////////////////////////// DECLARE_TYPES(softmax_cross_entropy_loss_with_logits) { - + getOpDescriptor()->setAllowedInputTypes(nd4j::DataType::ANY)->setAllowedOutputTypes({ALL_FLOATS}); } ////////////////////////////////////////////////////////////////////////// DECLARE_SHAPE_FN(softmax_cross_entropy_loss_with_logits) { - + auto logitsShapeInfo = inputShape->at(0); auto labelsShapeInfo = inputShape->at(1); const int classesDim = block.getIArguments()->size() > 0 ? INT_ARG(0) : -1; std::vector dimensions = {classesDim}; - // labels and logits must have the same shapes + // labels and logits must have the same shapes REQUIRE_TRUE(shape::shapeEquals(logitsShapeInfo, labelsShapeInfo), 0, "SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS OP: labels and logits arrays must have the same shapes, but got %s and %s correspondingly!", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(logitsShapeInfo).c_str()); auto outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(logitsShapeInfo)); @@ -90,46 +90,46 @@ CUSTOM_OP_IMPL(softmax_cross_entropy_loss_with_logits_grad, 2, 2, false, 0, 0) { auto dLdl = OUTPUT_VARIABLE(1); // dL/dlabels const int classesDim = block.getIArguments()->size() > 0 ? INT_ARG(0) : logits->rankOf()-1; - - // input validation + + // input validation REQUIRE_TRUE(labels->isSameShape(logits), 0, "SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS_GRAD OP: labels and logits arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labels).c_str(), ShapeUtils::shapeAsString(logits).c_str()); REQUIRE_TRUE(classesDim < logits->rankOf(), 0, "SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS_GRAD OP: class dimension must be smaller than rank of logits, but got %i and %i correspondingly !", classesDim, logits->rankOf()); - - std::vector dimension = {classesDim}; - NDArray softmax = (*logits - logits->reduceAlongDims(reduce::Max, dimension, true)).transform(transform::Exp); - softmax /= softmax.reduceAlongDims(reduce::Sum, dimension, true); + std::vector dimension = {classesDim}; + + NDArray softmax = (*logits - logits->reduceAlongDimension(reduce::Max, dimension, true)).transform(transform::Exp); + softmax /= softmax.reduceAlongDimension(reduce::Sum, dimension, true); // dEdp = softmax * sum_i(labels_i) - labels - dLdp->assign(softmax * labels->reduceAlongDims(reduce::Sum, dimension, true) - *labels); + dLdp->assign(softmax * labels->reduceAlongDimension(reduce::Sum, dimension, true) - *labels); + + // dEdl = -log(softmax) + (-softmax).applyTransform(transform::Log, *dLdl); - // dEdl = -log(softmax) - (-softmax).applyTransform(transform::Log, dLdl); - return Status::OK(); } ////////////////////////////////////////////////////////////////////////// DECLARE_TYPES(softmax_cross_entropy_loss_with_logits_grad) { - + getOpDescriptor()->setAllowedInputTypes(nd4j::DataType::ANY)->setAllowedOutputTypes({ALL_FLOATS}); } ////////////////////////////////////////////////////////////////////////// DECLARE_SHAPE_FN(softmax_cross_entropy_loss_with_logits_grad) { - auto logitsShapeInfo = inputShape->at(0); + auto logitsShapeInfo = inputShape->at(0); auto labelsShapeInfo = inputShape->at(1); - // labels and logits must have the same shapes + // labels and logits must have the same shapes REQUIRE_TRUE(shape::shapeEquals(logitsShapeInfo, labelsShapeInfo), 0, "SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS_GRAD OP: labels and logits arrays must have the same shapes, but got %s and %s correspondingly!", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(logitsShapeInfo).c_str()); - DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(logitsShapeInfo)); + DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(logitsShapeInfo)); auto dLdpShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(outType, shape::order(logitsShapeInfo), shape::shapeOf(logitsShapeInfo), shape::rank(logitsShapeInfo))); auto dLdlShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(outType, shape::order(labelsShapeInfo), shape::shapeOf(labelsShapeInfo), shape::rank(labelsShapeInfo))); - + return SHAPELIST(dLdpShapeInfo, dLdlShapeInfo); } diff --git a/libnd4j/include/ops/declarable/generic/loss/sparseSoftmaxCrossEntropyWithLogits.cpp b/libnd4j/include/ops/declarable/generic/loss/sparseSoftmaxCrossEntropyWithLogits.cpp index 4c2da4d0b..e7c8da123 100644 --- a/libnd4j/include/ops/declarable/generic/loss/sparseSoftmaxCrossEntropyWithLogits.cpp +++ b/libnd4j/include/ops/declarable/generic/loss/sparseSoftmaxCrossEntropyWithLogits.cpp @@ -50,9 +50,9 @@ CUSTOM_OP_IMPL(sparse_softmax_cross_entropy_loss_with_logits, 2, 1, false, 0, 0) std::vector dimension = {-1}; - auto maxAlongDim = logits->reduceAlongDims(reduce::Max, dimension, true); + auto maxAlongDim = logits->reduceAlongDimension(reduce::Max, dimension, true); auto logitsExp = (*logits - maxAlongDim).transform(transform::Exp, nullptr); - auto logSoftMax = -(( logitsExp / logitsExp.reduceAlongDims(reduce::Sum, dimension, true) ).transform(transform::Log)); + auto logSoftMax = -(( logitsExp / logitsExp.reduceAlongDimension(reduce::Sum, dimension, true) ).transform(transform::Log)); helpers::scatterForLoss(block.launchContext(), *labels, logSoftMax, *output, false); @@ -117,8 +117,8 @@ CUSTOM_OP_IMPL(sparse_softmax_cross_entropy_loss_with_logits_grad, 2, 1, false, std::vector dimension = {-1}; - NDArray softmax = (*logits - logits->reduceAlongDims(reduce::Max, dimension, true)).transform(transform::Exp); - softmax /= softmax.reduceAlongDims(reduce::Sum, dimension, true); + NDArray softmax = (*logits - logits->reduceAlongDimension(reduce::Max, dimension, true)).transform(transform::Exp); + softmax /= softmax.reduceAlongDimension(reduce::Sum, dimension, true); // dEdp = softmax - 1 (or 0) dLdp->assign(softmax); diff --git a/libnd4j/include/ops/declarable/generic/nn/batchnorm.cpp b/libnd4j/include/ops/declarable/generic/nn/batchnorm.cpp index 8b6bd24bc..a8cd17131 100644 --- a/libnd4j/include/ops/declarable/generic/nn/batchnorm.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/batchnorm.cpp @@ -229,19 +229,19 @@ CUSTOM_OP_IMPL(batchnorm_bp, 4, 3, false, 1, 2) { // input - mean NDArray xMinusMean(input); // empty array with same shape as input - input->applyBroadcast(nd4j::broadcast::Subtract, axes, mean, &xMinusMean); + input->applyBroadcast(nd4j::broadcast::Subtract, axes, *mean, xMinusMean); // stdInv NDArray stdInv = *variance + epsilon; - stdInv.applyTransform(transform::Reciprocal); // 1 / (variance + epsilon) - stdInv.applyTransform(transform::Sqrt); // 1 / (variance + epsilon)^0.5 + stdInv.applyTransform(transform::Reciprocal, stdInv); // 1 / (variance + epsilon) + stdInv.applyTransform(transform::Sqrt, stdInv); // 1 / (variance + epsilon)^0.5 // dvdm (use dLdM as storage for dvdm) - xMinusMean.reduceAlongDimension(nd4j::reduce::Sum, dLdM, excludedAxes, keepUnitiesInShape); + xMinusMean.reduceAlongDimension(nd4j::reduce::Sum, *dLdM, excludedAxes, keepUnitiesInShape); *dLdM *= -Ninv; // g_sum - auto gSum = dLdO->reduceAlongDims(nd4j::reduce::Sum, excludedAxes, keepUnitiesInShape); + auto gSum = dLdO->reduceAlongDimension(nd4j::reduce::Sum, excludedAxes, keepUnitiesInShape); // dLdB if(applyOffset) @@ -249,11 +249,11 @@ CUSTOM_OP_IMPL(batchnorm_bp, 4, 3, false, 1, 2) { // stdInv * (g - g_sum/N) (use dLdI as storage for this expression) gSum *= Ninv; - dLdO->applyBroadcast(nd4j::broadcast::Subtract, axes, &gSum, dLdI); - dLdI->applyBroadcast(nd4j::broadcast::Multiply, axes, &stdInv); + dLdO->applyBroadcast(nd4j::broadcast::Subtract, axes, gSum, *dLdI); + dLdI->applyBroadcast(nd4j::broadcast::Multiply, axes, stdInv, *dLdI); // dLdV <- [g*(x - m)]_sum - (xMinusMean * *dLdO).reduceAlongDimension(nd4j::reduce::Sum, dLdV, excludedAxes, keepUnitiesInShape); + (xMinusMean * *dLdO).reduceAlongDimension(nd4j::reduce::Sum, *dLdV, excludedAxes, keepUnitiesInShape); // dLdG *dLdV *= stdInv; @@ -265,13 +265,13 @@ CUSTOM_OP_IMPL(batchnorm_bp, 4, 3, false, 1, 2) { *dLdV *= -Ninv; // -0.5f * (2 / N); // dfdv * (dvdm + (x - m)) (use xMinusMean as storage for this expression) - xMinusMean.applyBroadcast(nd4j::broadcast::Add, axes, dLdM); - xMinusMean.applyBroadcast(nd4j::broadcast::Multiply, axes, dLdV); + xMinusMean.applyBroadcast(nd4j::broadcast::Add, axes, *dLdM, xMinusMean); + xMinusMean.applyBroadcast(nd4j::broadcast::Multiply, axes, *dLdV, xMinusMean); // dLdI *dLdI += xMinusMean; if(applyScale) - dLdI->applyBroadcast(nd4j::broadcast::Multiply, axes, gamma); + dLdI->applyBroadcast(nd4j::broadcast::Multiply, axes, *gamma, *dLdI); *dLdM = 0; // put zeros so far *dLdV = 0; // put zeros so far diff --git a/libnd4j/include/ops/declarable/generic/nn/convo/conv3d.cpp b/libnd4j/include/ops/declarable/generic/nn/convo/conv3d.cpp index 98223c5b4..0652f1840 100644 --- a/libnd4j/include/ops/declarable/generic/nn/convo/conv3d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/convo/conv3d.cpp @@ -199,16 +199,16 @@ CUSTOM_OP_IMPL(conv3dnew_bp, 3, 2, false, 0, 13) { int dH = INT_ARG(10); // dilations height int dW = INT_ARG(11); // dilations width int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID - int isNDHWC = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW + int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNDHWC, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); + ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); int trueoD, trueoH, trueoW; // true output depth/height/width ConvolutionUtils::calcOutSizePool3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, paddingMode); - REQUIRE_TRUE(paddingMode < 2, 0, "CUSTOM CONV3D OP: causal padding mode (paddingMode = 2) is not allowed for this operation !"); + REQUIRE_TRUE(paddingMode < 2, 0, "CUSTOM CONV3D_BP OP: causal padding mode (paddingMode = 2) is not allowed for this operation !"); std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoD,trueoH,trueoW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2})); std::string expectedWeightsShape = ShapeUtils::shapeAsString({kD, kH, kW, iC, oC}); REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, "CUSTOM CONV3D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str()); @@ -222,7 +222,7 @@ CUSTOM_OP_IMPL(conv3dnew_bp, 3, 2, false, 0, 13) { std::vector gradOaxesForDot; - if(!isNDHWC) { + if(!isNCDHW) { gradOaxesForDot = {0,1,2,3}; // bS, oD, oH, oW input = new NDArray(input->permute({0,4,1,2,3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW] gradI = new NDArray(gradI->permute({0,4,1,2,3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW] @@ -240,7 +240,7 @@ CUSTOM_OP_IMPL(conv3dnew_bp, 3, 2, false, 0, 13) { if(gradB) { if(gradB->rankOf() == 2) gradB = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()})); - gradO->reduceAlongDimension(reduce::Sum, gradB, gradOaxesForDot); // sum over bS oD oH oW + gradO->reduceAlongDimension(reduce::Sum, *gradB, gradOaxesForDot); // sum over bS oD oH oW if(gradB != OUTPUT_VARIABLE(2)) delete gradB; } @@ -249,7 +249,7 @@ CUSTOM_OP_IMPL(conv3dnew_bp, 3, 2, false, 0, 13) { MmulHelper::tensorDot(weights, gradO, &columns, {indWoC}, {indIOioC}, {2,3,4,1,0,5,6,7}); // [kD, kH, kW, iC, oC] x [bS, oD, oH, oW, oC]/[bS, oC, oD, oH, oW] = [kD, kH, kW, iC, bS, oD, oH, oW] ConvolutionUtils::col2vol(block, columns, *gradI, sD, sH, sW, pD, pH, pW, dD, dH, dW); // columns [bS, iC, kD, kH, kW, oD, oH, oW] is de-convoluted to [bS, iC, iD, iH, iW] - if(!isNDHWC) { + if(!isNCDHW) { delete input; delete gradI; } @@ -287,7 +287,7 @@ DECLARE_SHAPE_FN(conv3dnew_bp) { int dH = INT_ARG(10); // dilations height int dW = INT_ARG(11); // dilations width int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID - int isNDHWC = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW + int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW const int rank = 5; REQUIRE_TRUE(paddingMode < 2, 0, "CUSTOM CONV3D OP: causal padding mode (paddingMode = 2) is not allowed for this operation !"); @@ -296,7 +296,7 @@ DECLARE_SHAPE_FN(conv3dnew_bp) { REQUIRE_TRUE(gradOShapeInfo[0] == rank, 0, "CUSTOM CONV3D_BP OP: rank of output gradients (next epsilon) array must be equal to %i, but got %i instead !", rank, gradOShapeInfo); int indIOioC, indIiD, indWoC(4); - if(!isNDHWC) { + if(!isNCDHW) { indIOioC = 4; indIiD = 1; } else { diff --git a/libnd4j/include/ops/declarable/generic/nn/convo/deconv2d.cpp b/libnd4j/include/ops/declarable/generic/nn/convo/deconv2d.cpp index 3b8c51bc7..4a5bbd845 100644 --- a/libnd4j/include/ops/declarable/generic/nn/convo/deconv2d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/convo/deconv2d.cpp @@ -234,7 +234,7 @@ CUSTOM_OP_IMPL(deconv2d_bp, 3, 2, false, 0, 9) { if(gradB) { if(gradB->rankOf() == 2) gradB = new NDArray(gradB->reshape(gradB->ordering(), {gradB->lengthOf()})); - gradO->reduceAlongDimension(reduce::Sum, gradB, {0, 2, 3}); // sum over bS, oH, oW + gradO->reduceAlongDimension(reduce::Sum, *gradB, {0, 2, 3}); // sum over bS, oH, oW if(gradB != OUTPUT_VARIABLE(2)) delete gradB; } diff --git a/libnd4j/include/ops/declarable/generic/nn/convo/deconv3d.cpp b/libnd4j/include/ops/declarable/generic/nn/convo/deconv3d.cpp index 1baccbe0e..1b832ea68 100644 --- a/libnd4j/include/ops/declarable/generic/nn/convo/deconv3d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/convo/deconv3d.cpp @@ -244,7 +244,7 @@ CUSTOM_OP_IMPL(deconv3d_bp, 3, 2, false, 0, 13) { if(gradB) { if(gradB->rankOf() == 2) gradB = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()})); - gradO->reduceAlongDimension(reduce::Sum, gradB, {0, 2, 3, 4}); // sum over bS, oD, oH, oW + gradO->reduceAlongDimension(reduce::Sum, *gradB, {0, 2, 3, 4}); // sum over bS, oD, oH, oW if(gradB != OUTPUT_VARIABLE(2)) delete gradB; } diff --git a/libnd4j/include/ops/declarable/generic/nn/convo/depthwiseConv2d.cpp b/libnd4j/include/ops/declarable/generic/nn/convo/depthwiseConv2d.cpp index 1a0652462..e18836688 100644 --- a/libnd4j/include/ops/declarable/generic/nn/convo/depthwiseConv2d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/convo/depthwiseConv2d.cpp @@ -15,7 +15,7 @@ ******************************************************************************/ // -// created by Yurii Shyrma on 08.03.2018 +// @author Yurii Shyrma (iuriish@yahoo.com) // #include @@ -30,11 +30,11 @@ namespace nd4j { namespace ops { CUSTOM_OP_IMPL(depthwise_conv2d, 2, 1, false, 0, 9) { - + auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, mC] always auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] = iC*mC - + auto output = OUTPUT_VARIABLE(0); // [bS, oH, oW, iC*mC] (NHWC) or [bS, iC*mC, oH, oW] (NCHW) REQUIRE_TRUE(input->rankOf() == 4, 0, "CUSTOM DEPTHWISECONV2D OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf()); @@ -56,14 +56,14 @@ CUSTOM_OP_IMPL(depthwise_conv2d, 2, 1, false, 0, 9) { ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); mC = weights->sizeAt(indWmC); // channels multiplier - std::string expectedWeightsShape = ShapeUtils::shapeAsString({kH, kW, iC, mC}); - REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weights), 0, "CUSTOM DEPTHWISECONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weights).c_str()); + std::vector expectedWeightsShape = {kH, kW, iC, mC}; + REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DEPTHWISECONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); REQUIRE_TRUE(output->sizeAt(indIOioC) == iC*mC, 0, "CUSTOM DEPTHWISECONV2D OP: the output_channels must be equal to input_channels * channels_multiplier = %i !", iC*mC); if (bias) REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DEPTHWISECONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); ConvolutionUtils::depthwiseConv2d(block, input, weights, bias, output, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW); - + return Status::OK(); } @@ -79,8 +79,8 @@ DECLARE_SHAPE_FN(depthwise_conv2d) { Nd4jLong* biasShapeInfo = block.width() > 2 ? inputShape->at(2) : nullptr; // [oC] = iC*mC const int rank = 4; - REQUIRE_TRUE(inputShapeInfo[0] == rank, 0, "CUSTOM DEPTHWISECONV2D OP: rank of input array must be equal to %i, but got %i instead !", rank, inputShapeInfo[0]); - REQUIRE_TRUE(weightsShapeInfo[0] == rank, 0, "CUSTOM DEPTHWISECONV2D OP: rank of weights array must be equal to %i, but got %i instead !", rank, weightsShapeInfo[0]); + REQUIRE_TRUE(shape::rank(inputShapeInfo) == rank, 0, "CUSTOM DEPTHWISECONV2D OP: rank of input array must be equal to %i, but got %i instead !", rank, inputShapeInfo[0]); + REQUIRE_TRUE(shape::rank(weightsShapeInfo) == rank, 0, "CUSTOM DEPTHWISECONV2D OP: rank of weights array must be equal to %i, but got %i instead !", rank, weightsShapeInfo[0]); int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(shape::sizeAt(weightsShapeInfo, 0));// filter(kernel) height int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(shape::sizeAt(weightsShapeInfo, 1));// filter(kernel) width @@ -97,25 +97,26 @@ DECLARE_SHAPE_FN(depthwise_conv2d) { if(!isNCHW) { indIOioC = 3; indIiH = 1; } - else { + else { indIOioC = 1; indIiH = 2; - } + } - const int bS = inputShapeInfo[1]; // batch size - const int iH = inputShapeInfo[indIiH+1]; // input height - const int iW = inputShapeInfo[indIiH+2]; // input width - const int iC = inputShapeInfo[indIOioC+1]; // input channels - const int mC = weightsShapeInfo[indWmC+1]; // channels multiplier(oC = iC*mC) - const int oC = iC*mC; // output channels + const int bS = shape::sizeAt(inputShapeInfo, 0); // batch size + const int iH = shape::sizeAt(inputShapeInfo, indIiH); // input height + const int iW = shape::sizeAt(inputShapeInfo, indIiH+1); // input width + const int iC = shape::sizeAt(inputShapeInfo, indIOioC); // input channels + const int mC = shape::sizeAt(weightsShapeInfo, indWmC); // channels multiplier(oC = iC*mC) + const int oC = iC*mC; // output channels - std::string expectedWeightsShape = ShapeUtils::shapeAsString({kH, kW, iC, mC}); - REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weightsShapeInfo), 0, "DEPTHWISECONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); - if (biasShapeInfo) - REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "DEPTHWISECONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo)); + + std::vector expectedWeightsShape = {kH, kW, iC, mC}; + REQUIRE_TRUE(shape::shapeEquals(4, expectedWeightsShape.data(), shape::rank(weightsShapeInfo), shape::shapeOf(weightsShapeInfo)), 0, "DEPTHWISECONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); + if (biasShapeInfo) + REQUIRE_TRUE(shape::rank(biasShapeInfo) <= 2 && oC == shape::length(biasShapeInfo), 0, "DEPTHWISECONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, shape::rank(biasShapeInfo), shape::length(biasShapeInfo)); int oH, oW; // output height, width ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); - + Nd4jLong* outputShapeInfo = nullptr; ALLOCATE(outputShapeInfo, block.getWorkspace(), shape::shapeInfoLength(inputShapeInfo), Nd4jLong); @@ -131,7 +132,7 @@ DECLARE_SHAPE_FN(depthwise_conv2d) { outputShapeInfo[3] = oW; outputShapeInfo[4] = oC; } - + ShapeUtils::updateStridesAndType(outputShapeInfo, weightsShapeInfo, shape::order(inputShapeInfo)); return SHAPELIST(CONSTANT(outputShapeInfo)); @@ -143,14 +144,14 @@ DECLARE_SHAPE_FN(depthwise_conv2d) { ->setAllowedOutputTypes({ALL_FLOATS}); } -////////////////////////////////////////////////////////////////////////// +////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(depthwise_conv2d_bp, 3, 2, false, 0, 9) { - + auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW) auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, mC] always auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] = [iC*mC] auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NDHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next - + auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW), epsilon auto gradW = OUTPUT_VARIABLE(1); // [kH, kW, iC, mC] always auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] @@ -173,36 +174,35 @@ CUSTOM_OP_IMPL(depthwise_conv2d_bp, 3, 2, false, 0, 9) { int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); - mC = weights->sizeAt(indWmC); // channels multiplier + mC = weights->sizeAt(indWmC); // channels multiplier int trueoH, trueoW; // correct output height, width ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); - std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1})); - std::string expectedWeightsShape = ShapeUtils::shapeAsString({kH, kW, iC, mC}); - REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, "CUSTOM DEPTHWISECONV2D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str()); - REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weights), 0, "CUSTOM DEPTHWISECONV2D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weights).c_str()); + std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1}); + std::vector expectedWeightsShape = {kH, kW, iC, mC}; + REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM DEPTHWISECONV2D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); + REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DEPTHWISECONV2D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); if(bias) - REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DEPTHWISECONV2D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); + REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DEPTHWISECONV2D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); ConvolutionUtils::depthwiseConv2dBP(block, input, weights, bias, gradO, gradI, gradW, gradB, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW); return Status::OK(); } - - +////////////////////////////////////////////////////////////////////// DECLARE_SHAPE_FN(depthwise_conv2d_bp) { Nd4jLong* inputShapeInfo = inputShape->at(0); Nd4jLong* weightsShapeInfo = inputShape->at(1); - Nd4jLong* biasShapeInfo = block.width() > 3 ? inputShape->at(2) : nullptr; - Nd4jLong* gradOShapeInfo = block.width() > 3 ? inputShape->at(3) : inputShape->at(2); + Nd4jLong* biasShapeInfo = block.width() > 3 ? inputShape->at(2) : nullptr; + Nd4jLong* gradOShapeInfo = block.width() > 3 ? inputShape->at(3) : inputShape->at(2); const int rank = 4; - REQUIRE_TRUE(inputShapeInfo[0] == rank, 0, "CUSTOM DEPTHWISECONV2D_BP OP: rank of input array must be equal to %i, but got %i instead !", rank, inputShapeInfo[0]); - REQUIRE_TRUE(weightsShapeInfo[0] == rank, 0, "CUSTOM DEPTHWISECONV2D_BP OP: rank of weights array must be equal to %i, but got %i instead !", rank, weightsShapeInfo[0]); - REQUIRE_TRUE(gradOShapeInfo[0] == rank, 0, "CUSTOM DEPTHWISECONV2D_BP OP: rank of output gradients (next epsilon) array must be equal to %i, but got %i instead !", rank, gradOShapeInfo[0]); + REQUIRE_TRUE(shape::rank(inputShapeInfo) == rank, 0, "CUSTOM DEPTHWISECONV2D_BP OP: rank of input array must be equal to %i, but got %i instead !", rank, shape::rank(inputShapeInfo)); + REQUIRE_TRUE(shape::rank(weightsShapeInfo) == rank, 0, "CUSTOM DEPTHWISECONV2D_BP OP: rank of weights array must be equal to %i, but got %i instead !", rank, shape::rank(weightsShapeInfo)); + REQUIRE_TRUE(shape::rank(gradOShapeInfo) == rank, 0, "CUSTOM DEPTHWISECONV2D_BP OP: rank of output gradients (next epsilon) array must be equal to %i, but got %i instead !", rank, shape::rank(gradOShapeInfo)); int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(shape::sizeAt(weightsShapeInfo, 0));// filter(kernel) height int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(shape::sizeAt(weightsShapeInfo, 1));// filter(kernel) width @@ -219,26 +219,26 @@ DECLARE_SHAPE_FN(depthwise_conv2d_bp) { if(!isNCHW) { indIOioC = 3; indIiH = 1; } - else { + else { indIOioC = 1; indIiH = 2; - } + } - const int bS = inputShapeInfo[1]; // batch size - const int iH = inputShapeInfo[indIiH+1]; // input height - const int iW = inputShapeInfo[indIiH+2]; // input width - const int iC = inputShapeInfo[indIOioC+1]; // input channels - const int mC = weightsShapeInfo[indWmC+1]; // channels multiplier(oC = iC*mC) - const int oC = iC*mC; // output channels + const int bS = shape::sizeAt(inputShapeInfo, 0); // batch size + const int iH = shape::sizeAt(inputShapeInfo, indIiH); // input height + const int iW = shape::sizeAt(inputShapeInfo, indIiH+1); // input width + const int iC = shape::sizeAt(inputShapeInfo, indIOioC); // input channels + const int mC = shape::sizeAt(weightsShapeInfo, indWmC); // channels multiplier(oC = iC*mC) + const int oC = iC*mC; // output channels int trueoH, trueoW; // correct output height, width ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); - std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indIiH,indIiH+1})); - std::string expectedWeightsShape = ShapeUtils::shapeAsString({kH, kW, iC, mC}); - REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradOShapeInfo), 0, "CUSTOM DEPTHWISECONV2D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str()); - REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weightsShapeInfo), 0, "CUSTOM DEPTHWISECONV2D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); + std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indIiH,indIiH+1}); + std::vector expectedWeightsShape = {kH, kW, iC, mC}; + REQUIRE_TRUE(shape::shapeEquals(4, expectedGradOShape.data(), shape::rank(gradOShapeInfo), shape::shapeOf(gradOShapeInfo)), 0, "CUSTOM DEPTHWISECONV2D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str()); + REQUIRE_TRUE(shape::shapeEquals(4, expectedWeightsShape.data(), shape::rank(weightsShapeInfo), shape::shapeOf(weightsShapeInfo)), 0, "CUSTOM DEPTHWISECONV2D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); if(biasShapeInfo) - REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM DEPTHWISECONV2D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo)); + REQUIRE_TRUE(shape::rank(biasShapeInfo) <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM DEPTHWISECONV2D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, shape::rank(biasShapeInfo), shape::length(biasShapeInfo)); auto gradIshapeInfo = ShapeBuilders::copyShapeInfoAndType(inputShapeInfo, gradOShapeInfo, false, block.getWorkspace()); auto gradWshapeInfo = ShapeBuilders::copyShapeInfoAndType(weightsShapeInfo, gradOShapeInfo, false, block.getWorkspace()); @@ -246,7 +246,7 @@ DECLARE_SHAPE_FN(depthwise_conv2d_bp) { if(biasShapeInfo) { Nd4jLong* gradBshapeInfo = ShapeBuilders::copyShapeInfoAndType(biasShapeInfo, gradOShapeInfo, false, block.getWorkspace()); return SHAPELIST(CONSTANT(gradIshapeInfo), CONSTANT(gradWshapeInfo), CONSTANT(gradBshapeInfo)); - } + } return SHAPELIST(CONSTANT(gradIshapeInfo), CONSTANT(gradWshapeInfo)); } diff --git a/libnd4j/include/ops/declarable/generic/nn/fusedBatchNorm.cpp b/libnd4j/include/ops/declarable/generic/nn/fusedBatchNorm.cpp index f2215503c..0754000a3 100644 --- a/libnd4j/include/ops/declarable/generic/nn/fusedBatchNorm.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/fusedBatchNorm.cpp @@ -32,7 +32,7 @@ namespace ops { ->setAllowedOutputTypes({ALL_FLOATS}); } -CUSTOM_OP_IMPL(fused_batch_norm, 3, 1, false, 0, 2) { +CUSTOM_OP_IMPL(fused_batch_norm, 3, 3, false, 0, 2) { auto x = INPUT_VARIABLE(0); // [bS,iH,iW,iD] (NHWC) or [bS,iD,iH,iW] (NCHW) auto scale = INPUT_VARIABLE(1); // [iD] auto offset = INPUT_VARIABLE(2); // [iD] @@ -42,35 +42,35 @@ CUSTOM_OP_IMPL(fused_batch_norm, 3, 1, false, 0, 2) { auto batchVar = OUTPUT_VARIABLE(2); // [iD] const bool dataFormat = (bool)INT_ARG(0); // 0->NHWC, 1->NCHW - const bool isTraining = (bool)INT_ARG(1); + const bool isTraining = (bool)INT_ARG(1); - REQUIRE_TRUE(x->rankOf() == 4, 0, "CUSTOM_OP fused_batch_norm: the rank of input x array must be equal to 4, but got %i instead !", x->rankOf()); + REQUIRE_TRUE(x->rankOf() == 4, 0, "CUSTOM_OP fused_batch_norm: the rank of input x array must be equal to 4, but got %i instead !", x->rankOf()); int bS = x->sizeAt(0); // batch size - int iH, iW, iD; // input height, input width, input depth(number of channels) + int iH, iW, iD; // input height, input width, input depth(number of channels) if(dataFormat) { iD = x->sizeAt(1); iH = x->sizeAt(2); iW = x->sizeAt(3); } else { - iD = x->sizeAt(3); + iD = x->sizeAt(3); iH = x->sizeAt(1); - iW = x->sizeAt(2); - } + iW = x->sizeAt(2); + } REQUIRE_TRUE(scale->rankOf() == 1 && scale->sizeAt(0) == iD, 0, "CUSTOM_OP fused_batch_norm: wrong shape of input scale array, expected is [%i], but got %s instead", iD, ShapeUtils::shapeAsString(scale).c_str()); REQUIRE_TRUE(offset->rankOf() == 1 && offset->sizeAt(0) == iD, 0, "CUSTOM_OP fused_batch_norm: wrong shape of input offset array, expected is [%i], but got %s instead", iD, ShapeUtils::shapeAsString(offset).c_str()); NDArray *mean(nullptr), *variance(nullptr); if(!isTraining){ - mean = INPUT_VARIABLE(3); - variance = INPUT_VARIABLE(4); + mean = INPUT_VARIABLE(3); + variance = INPUT_VARIABLE(4); REQUIRE_TRUE(mean->rankOf() == 1 && mean->sizeAt(0) == iD, 0, "CUSTOM_OP fused_batch_norm: wrong shape of input mean array, expected is [%i], but got %s instead", iD, ShapeUtils::shapeAsString(mean).c_str()); REQUIRE_TRUE(variance->rankOf() == 1 && variance->sizeAt(0) == iD, 0, "CUSTOM_OP fused_batch_norm: wrong shape of input variance array, expected is [%i], but got %s instead", iD, ShapeUtils::shapeAsString(variance).c_str()); } else { - REQUIRE_TRUE(block.width() == 3, 0, "CUSTOM_OP fused_batch_norm: when isTraining=true then number of input arrays must be equal to 3, but got %i instead !", block.width()); + //REQUIRE_TRUE(block.width() == 3, 0, "CUSTOM_OP fused_batch_norm: when isTraining=true then number of input arrays must be equal to 3, but got %i instead !", block.width()); std::vector shape = {iD}; mean = NDArrayFactory::create_(scale->ordering(), shape, scale->dataType(), block.launchContext()); variance = NDArrayFactory::create_(scale->ordering(), shape, scale->dataType(), block.launchContext()); @@ -78,13 +78,13 @@ CUSTOM_OP_IMPL(fused_batch_norm, 3, 1, false, 0, 2) { // FIXME: double? double epsilon; - if(block.getTArguments()->size() > 0) + if(block.getTArguments()->size() > 0) epsilon = T_ARG(0) > 1.001e-5 ? T_ARG(0) : 1.001e-5; - else + else epsilon = 0.001; - - const int restSize = x->lengthOf() / iD; - auto xAffected = NDArrayFactory::create(x->ordering(), {restSize, iD}, x->dataType(), block.launchContext()); + + const int restSize = x->lengthOf() / iD; + auto xAffected = NDArrayFactory::create(x->ordering(), {restSize, iD}, mean->dataType(), block.launchContext()); xAffected.assign(x); const int restSizeMinusOne = (restSize > 1) ? (restSize - 1) : 1; @@ -93,28 +93,28 @@ CUSTOM_OP_IMPL(fused_batch_norm, 3, 1, false, 0, 2) { const double restSizeAdjust = (double)restSize / restSizeMinusOne; if(isTraining) { - auto sum = xAffected.reduceAlongDims(reduce::Sum, {0}); + auto sum = xAffected.reduceAlongDimension(reduce::Sum, {0}); sum *= restSizeInv; mean->assign(sum); *batchMean = *mean; //delete sum; } - else + else *batchMean = 0.; - + xAffected -= *mean; - if(isTraining) { + if(isTraining) { int power = 2; - xAffected.applyScalar(scalar::Pow, power); - auto sum = xAffected.reduceAlongDims(reduce::Sum, {0}); + xAffected.applyScalar(scalar::Pow, power, xAffected); + auto sum = xAffected.reduceAlongDimension(reduce::Sum, {0}); sum *= restSizeInv; variance->assign(sum); *batchVar = (*variance) * restSizeAdjust; //delete sum; } - else - *batchVar = 0.; + else + *batchVar = 0.; xAffected *= (*variance + epsilon).transform(transform::RSqrt) * (*scale) + (*offset); y->assign( xAffected ); @@ -136,13 +136,13 @@ DECLARE_SHAPE_FN(fused_batch_norm) { const int iD = dataFormat ? xShapeInfo[2] : xShapeInfo[4]; REQUIRE_TRUE(scaleShapeInfo[0] == 1 && scaleShapeInfo[1] == iD, 0, "CUSTOM_OP fused_batch_norm: wrong shape of input scale array, expected is [%i], but got %s instead", iD, ShapeUtils::shapeAsString(scaleShapeInfo).c_str()); - + Nd4jLong* outShapeInfo(nullptr), *batchMeanShapeInfo(nullptr), *batchVarShapeInfo(nullptr); - + COPY_SHAPE(xShapeInfo, outShapeInfo); COPY_SHAPE(scaleShapeInfo, batchMeanShapeInfo); - COPY_SHAPE(scaleShapeInfo, batchVarShapeInfo); - + COPY_SHAPE(scaleShapeInfo, batchVarShapeInfo); + return SHAPELIST(CONSTANT(outShapeInfo), CONSTANT(batchMeanShapeInfo), CONSTANT(batchVarShapeInfo)); } diff --git a/libnd4j/include/ops/declarable/generic/nn/logSoftmax.cpp b/libnd4j/include/ops/declarable/generic/nn/logSoftmax.cpp index 6dffead8b..f5cc78e2b 100644 --- a/libnd4j/include/ops/declarable/generic/nn/logSoftmax.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/logSoftmax.cpp @@ -37,7 +37,7 @@ namespace ops { CONFIGURABLE_OP_IMPL(log_softmax, 1, 1, true, 0, 0) { auto input = INPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0); - + const int rank = input->rankOf(); const int dim = block.getIArguments()->size() > 0 ? INT_ARG(0) : rank - 1; @@ -67,8 +67,8 @@ CONFIGURABLE_OP_IMPL(log_softmax_bp, 2, 1, true, 0, 0) { REQUIRE_TRUE(dim < rank, 0, "LOG_SOFTMAX_BP OP: the value of input integer parameter (dimension) must be less than input array rank %i, but got dimension = %i instead !", rank, dim); helpers::softmax(block.launchContext(), *input, *gradI, dim); - - gradI->assign( *gradO - (*gradI * *gradO).reduceAlongDims(reduce::Sum, {dim}, true) ); + + gradI->assign( *gradO - (*gradI * *gradO).reduceAlongDimension(reduce::Sum, {dim}, true) ); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/nn/relu_layer.cpp b/libnd4j/include/ops/declarable/generic/nn/relu_layer.cpp index 3f5c16c17..4e62abc60 100644 --- a/libnd4j/include/ops/declarable/generic/nn/relu_layer.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/relu_layer.cpp @@ -31,10 +31,10 @@ namespace nd4j { REQUIRE_TRUE(w->isMatrix(), 0, "relu_layer: weights argument should be a 2D tensor, but got rank %i instead!", w->rankOf()); REQUIRE_TRUE(b->isVector(), 0, "relu_layer: biases argument should be a 1D tensor, but got rank %i instead!", b->rankOf()); REQUIRE_TRUE(b->lengthOf() == w->sizeAt(1), 0, "relu_layer: biases array length should match to columns of weights matrix, however got length = %i and columns = %i!", b->lengthOf(), w->sizeAt(1)); - REQUIRE_TRUE(x->sizeAt(1) == w->sizeAt(0), 0, "relu_layer: number of x columns should match to row number of weights matrix, but got x_columns = %i and weights_rows = %i!", + REQUIRE_TRUE(x->sizeAt(1) == w->sizeAt(0), 0, "relu_layer: number of x columns should match to row number of weights matrix, but got x_columns = %i and weights_rows = %i!", x->sizeAt(1), w->sizeAt(0)); - + auto output = OUTPUT_VARIABLE(0); //T bound = (T)0.f; //nd4j_printf("Matrix x(%ix%i), Matrix w(%ix%i), b(1x%i)\n", x->sizeAt(0), x->sizeAt(1), w->sizeAt(0), w->sizeAt(1), b->lengthOf()); @@ -46,7 +46,7 @@ namespace nd4j { auto scalar = block.numT() > 0 ? block.getTArguments()->at(0) : 0.0; auto xw = result->at(0); - xw->applyScalar(nd4j::scalar::RELU, scalar, output); + xw->applyScalar(nd4j::scalar::RELU, scalar, *output); return Status::OK(); } @@ -55,7 +55,7 @@ namespace nd4j { auto inShape = inputShape->at(0); auto weightsShape = inputShape->at(1); auto outputShape = ShapeUtils::matrixProductShape(inShape, weightsShape, false, false, ArrayOptions::dataType(inShape), block.getWorkspace()); - + return SHAPELIST(CONSTANT(outputShape)); } diff --git a/libnd4j/include/ops/declarable/generic/nn/softmax.cpp b/libnd4j/include/ops/declarable/generic/nn/softmax.cpp index d96f97c10..06bd6d379 100644 --- a/libnd4j/include/ops/declarable/generic/nn/softmax.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/softmax.cpp @@ -38,7 +38,7 @@ namespace ops { CONFIGURABLE_OP_IMPL(softmax, 1, 1, true, 0, 0) { auto input = INPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0); - + const int rank = input->rankOf(); const int dim = block.getIArguments()->size() > 0 ? INT_ARG(0) : rank - 1; @@ -59,10 +59,10 @@ CONFIGURABLE_OP_IMPL(softmax_bp, 2, 1, true, 0, 0) { const int dim = block.getIArguments()->size() > 0 ? INT_ARG(0) : rank - 1; REQUIRE_TRUE(dim < rank, 0, "SOFTMAX_BP OP: the value of input integer parameter (dimension) must be less than input array rank %i, but got dimension = %i instead !", rank, dim); - + helpers::softmax(block.launchContext(), *input, *gradI, dim); - auto sumAlongDim = (*gradI * *gradO).reduceAlongDims(reduce::Sum, {dim}, true); + auto sumAlongDim = (*gradI * *gradO).reduceAlongDimension(reduce::Sum, {dim}, true); gradI->assign(*gradI * (*gradO - sumAlongDim)); return Status::OK(); diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/adjust_contrast.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/adjust_contrast.cpp index 27b6a4302..65f01cf6c 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/adjust_contrast.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/adjust_contrast.cpp @@ -56,7 +56,7 @@ CONFIGURABLE_OP_IMPL(adjust_contrast, 1, 1, true, 0, 0) { axes[i] = i; // mean as reduction for last dimension set - auto mean = input->reduceAlongDims(reduce::Mean, axes); + auto mean = input->reduceAlongDimension(reduce::Mean, axes); // this is contrast calculation output->assign((*input - mean) * (*factor) + mean); @@ -104,13 +104,13 @@ CONFIGURABLE_OP_IMPL(adjust_contrast_v2, 1, 1, true, 0, 0) { std::vector axes({1}); // dim 1 of pseudoresult // mean as reduction for last dimension set over size (dim 1) of result3D - auto mean = input3D.reduceAlongDims(reduce::Mean, axes); + auto mean = input3D.reduceAlongDimension(reduce::Mean, axes); // result as (x - mean) * factor + mean auto temp = input3D.ulike(); - input3D.applyBroadcast(broadcast::Subtract, {0, 2}, &mean, &temp, nullptr); - temp.applyScalarArr(scalar::Multiply, factor); - temp.applyBroadcast(broadcast::Add, {0, 2}, &mean, &output3D); + input3D.applyBroadcast(broadcast::Subtract, {0, 2}, mean, temp); + temp.applyScalarArr(scalar::Multiply, *factor, temp); + temp.applyBroadcast(broadcast::Add, {0, 2}, mean, output3D); output->assign(output3D); if(block.width() == 1) delete factor; diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/adjust_hue.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/adjust_hue.cpp index 32e51bdb9..003ff6e75 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/adjust_hue.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/adjust_hue.cpp @@ -39,10 +39,14 @@ CONFIGURABLE_OP_IMPL(adjust_hue, 1, 1, true, 0, 0) { return Status::OK(); const int rank = input->rankOf(); - const int dimC = block.getIArguments()->size() > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1; + const int arg_size = block.getIArguments()->size(); + const int dimC = arg_size > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1; REQUIRE_TRUE(block.numT() > 0 || block.width() > 1, 0, "ADJUST_HUE: delta factor is required !"); REQUIRE_TRUE(rank >= 3, 0, "ADJUST_HUE: op expects rank of input array to be >= 3, but got %i instead", rank); + if (arg_size > 0) { + REQUIRE_TRUE(dimC >= 0 && dimC < rank, 0, "Index of the Channel dimension out of range: %i not in [%i,%i) ", INT_ARG(0), -rank, rank); + } REQUIRE_TRUE(input->sizeAt(dimC) == 3, 0, "ADJUST_HUE: operation expects image with 3 channels (R, G, B), but got %i instead", input->sizeAt(dimC)); NDArray* delta = nullptr; diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/adjust_saturation.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/adjust_saturation.cpp index de947c9ae..0a8eaf0c3 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/adjust_saturation.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/adjust_saturation.cpp @@ -38,9 +38,13 @@ CONFIGURABLE_OP_IMPL(adjust_saturation, 1, 1, true, 0, 0) { return Status::OK(); const int rank = input->rankOf(); - const int dimC = block.getIArguments()->size() > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1; + const int arg_size = block.getIArguments()->size(); + const int dimC = arg_size > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1; REQUIRE_TRUE(rank >= 3, 0, "ADJUST_SATURATION: op expects rank of input array to be >= 3, but got %i instead", rank); + if (arg_size > 0) { + REQUIRE_TRUE(dimC >= 0 && dimC < rank, 0, "Index of the Channel dimension out of range: %i not in [%i,%i) ", INT_ARG(0), -rank, rank); + } REQUIRE_TRUE(input->sizeAt(dimC) == 3, 0, "ADJUST_SATURATION: operation expects image with 3 channels (R, G, B), but got %i instead", input->sizeAt(dimC)); REQUIRE_TRUE(block.numT() > 0 || block.width() > 1, 0, "ADJUST_SATURATION: scale factor is required !"); diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/argmax.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/argmax.cpp index 3fd5e2250..10e036b61 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/argmax.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/argmax.cpp @@ -44,11 +44,11 @@ namespace nd4j { auto axisVector = INPUT_VARIABLE(1); helpers::adjustAxis(input->rankOf(), axisVector, axis); - input->applyIndexReduce(indexreduce::IndexMax, output, axis); + input->applyIndexReduce(indexreduce::IndexMax, *output, axis); } else { helpers::adjustAxis(input->rankOf(), axis); - input->applyIndexReduce(indexreduce::IndexMax, output, axis); + input->applyIndexReduce(indexreduce::IndexMax, *output, axis); } STORE_RESULT(output); diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/argmin.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/argmin.cpp index 91e9d5a41..554b7b95b 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/argmin.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/argmin.cpp @@ -44,11 +44,11 @@ namespace nd4j { auto axisVector = INPUT_VARIABLE(1); helpers::adjustAxis(input->rankOf(), axisVector, axis); - input->applyIndexReduce(indexreduce::IndexMin, output, axis); + input->applyIndexReduce(indexreduce::IndexMin, *output, axis); } else { helpers::adjustAxis(input->rankOf(), axis); - input->applyIndexReduce(indexreduce::IndexMin, output, axis); + input->applyIndexReduce(indexreduce::IndexMin, *output, axis); } STORE_RESULT(output); diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/bias_add.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/bias_add.cpp index b43895a31..0c88a9c53 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/bias_add.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/bias_add.cpp @@ -82,7 +82,7 @@ CUSTOM_OP_IMPL(biasadd_bp, 3, 2, false, 0, 0) { gradI->assign(gradO); - gradO->reduceAlongDimension(nd4j::reduce::Sum, gradB, ShapeUtils::evalDimsToExclude(gradO->rankOf(), {channelDim})); + gradO->reduceAlongDimension(nd4j::reduce::Sum, *gradB, ShapeUtils::evalDimsToExclude(gradO->rankOf(), {channelDim})); return ND4J_STATUS_OK; } diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/embedding_lookup.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/embedding_lookup.cpp index fc928c3cd..822b4b91b 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/embedding_lookup.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/embedding_lookup.cpp @@ -45,7 +45,7 @@ CUSTOM_OP_IMPL(embedding_lookup, 2, 1, false, 0, 1) { v = i++; } - std::unique_ptr outputView(output->allTensorsAlongDimension(dims)); + ResultSet outputView = output->allTensorsAlongDimension(dims); REQUIRE_TRUE(block.width() > output->sizeAt(0), 0, "embedding_lookup: input list should be greater then %i, but %i given.", output->sizeAt(0), block.width() ); @@ -53,7 +53,7 @@ CUSTOM_OP_IMPL(embedding_lookup, 2, 1, false, 0, 1) { Nd4jLong thisIndex = (*indeces).e(e); input = INPUT_VARIABLE(thisIndex); // lookup param - outputView->at(e)->assign(input); + outputView.at(e)->assign(input); } } else { @@ -87,7 +87,7 @@ DECLARE_SHAPE_FN(embedding_lookup) { int inRank = shape::rank(inShapeInfo); if (inputShape->size() == 2u) { int outRank = inRank; - + std::vector shapeInfo(outRank); shapeInfo[0] = indecesShapeInfo[1]; // vector - how many elements @@ -98,14 +98,14 @@ DECLARE_SHAPE_FN(embedding_lookup) { return SHAPELIST(outShapeInfo); } - - int outRank = inRank + 1; + + int outRank = inRank + 1; std::vector shapeInfo(outRank); auto indeces = INPUT_VARIABLE(block.width() - 1); shapeInfo[0] = indeces->lengthOf(); // vector - how many elements for (int e = 1; e < outRank; e++) shapeInfo[e] = shape::sizeAt(inShapeInfo, e); - + auto outShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inShapeInfo), shape::order(inShapeInfo), shapeInfo); return SHAPELIST(outShapeInfo); } diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/lgamma.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/lgamma.cpp new file mode 100644 index 000000000..615190c2f --- /dev/null +++ b/libnd4j/include/ops/declarable/generic/parity_ops/lgamma.cpp @@ -0,0 +1,50 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author George A. Shulinok +// + +#include +#if NOT_EXCLUDED(OP_lgamma) + +#include +#include + +namespace nd4j { +namespace ops { + +OP_IMPL(lgamma, 1, 1, true) { + + auto x = INPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); + + helpers::lgamma(block.launchContext(), *x, *z); + + return Status::OK(); +} + +DECLARE_TYPES(lgamma) { + getOpDescriptor() + ->setAllowedInputTypes({ALL_FLOATS}) // as TF says + ->setSameMode(true); +} + +} +} + +#endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/lup.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/lup.cpp new file mode 100644 index 000000000..e0e960159 --- /dev/null +++ b/libnd4j/include/ops/declarable/generic/parity_ops/lup.cpp @@ -0,0 +1,68 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// Created by GS at 12/10/2019 +// + +#include +#if NOT_EXCLUDED(OP_matrix_inverse) + +#include +#include +namespace nd4j { + namespace ops { + CUSTOM_OP_IMPL(lu, 1, 2, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); + + auto p = OUTPUT_VARIABLE(1); + if (block.getIArguments()->size()) { + DataType dtype = (DataType)INT_ARG(0); + REQUIRE_TRUE(dtype == nd4j::DataType::INT32 || dtype == nd4j::DataType::INT64, 0, "lu: Permutation data type should be 32bit or 64bit int only, but '%s' given.", DataTypeUtils::asString(dtype).c_str()); } + + REQUIRE_TRUE(input->rankOf() >=2, 0, "lu: The rank of input array should not less than 2, but %i is given", input->rankOf()); + REQUIRE_TRUE(input->sizeAt(-1) == input->sizeAt(-2), 0, "lu: The last two dimmensions should be equal, but %i and %i are given", input->sizeAt(-1), input->sizeAt(-2)); + + helpers::lu(block.launchContext(), input, z, p); + return Status::OK(); + } + + DECLARE_SHAPE_FN(lu) { + auto in = inputShape->at(0); + auto shapeVector = ShapeUtils::shapeAsVector(in); + auto luShape = ShapeBuilders::copyShapeInfoAndType(in, in, true, block.workspace()); + auto dtype = nd4j::DataType::INT32; + if (block.getIArguments()->size()) { + dtype = (DataType)INT_ARG(0); + REQUIRE_TRUE(dtype == nd4j::DataType::INT32 || dtype == nd4j::DataType::INT64, 0, "lu: Permutation data type should be 32bit or 64bit int only, but '%s' given.", DataTypeUtils::asString(dtype).c_str()); + } + auto luP = ShapeBuilders::createShapeInfo(dtype, shape::order(in), shapeVector.size() - 1, + shapeVector.data(), block.workspace()); + return SHAPELIST(CONSTANT(luShape), CONSTANT(luP)); + } + + DECLARE_TYPES(lu) { + getOpDescriptor() + ->setAllowedInputTypes({ALL_FLOATS}) + ->setAllowedOutputTypes(0, {ALL_FLOATS}) + ->setAllowedOutputTypes(1, {nd4j::DataType::INT32, nd4j::DataType::INT64}) + ->setSameMode(false); + } + } +} + +#endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/moments.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/moments.cpp index 12b6c9e07..5e76fefec 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/moments.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/moments.cpp @@ -49,8 +49,8 @@ namespace nd4j { } std::vector& dims = axis; - input->varianceAlongDimension(variance::SummaryStatsVariance, variances, false, axis); - input->reduceAlongDimension(reduce::Mean, means, axis, keepDims); + input->varianceAlongDimension(variance::SummaryStatsVariance, *variances, false, axis); + input->reduceAlongDimension(reduce::Mean, *means, axis, keepDims); return Status::OK(); } @@ -74,10 +74,10 @@ namespace nd4j { } //std::vector dims = ShapeUtils::evalDimsToExclude(input->rankOf(), {axis}); const bool keepDims = block.getTArguments()->size() > 0 ? (bool)T_ARG(0) : false; - + auto meanShape = ShapeUtils::evalReduceShapeInfo('c', axis, *input, keepDims, false, block.workspace()); auto varianceShape = ShapeUtils::evalReduceShapeInfo('c', axis, *input, keepDims, false, block.workspace()); - return SHAPELIST(meanShape, varianceShape); + return SHAPELIST(meanShape, varianceShape); } DECLARE_TYPES(moments) { diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/norm.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/norm.cpp index 983f18bd9..e74a28184 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/norm.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/norm.cpp @@ -52,31 +52,31 @@ namespace nd4j { case 0: { REQUIRE_TRUE(dims.size() == 2 || (input->rankOf() == 2 && dims.size() == 0), 0, "Norm: Frobenius is defined for 2D matrices or TADS only"); // fro - input->reduceAlongDimension(reduce::NormFrobenius, output, dims, false, output->rankOf() == 2); + input->reduceAlongDimension(reduce::NormFrobenius, *output, dims, false, output->rankOf() == 2); } break; case 1: { // euclidean if ((input->rankOf() == 2 && dims.size() == 0) || dims.size() == 2) { - input->reduceAlongDimension(reduce::NormFrobenius, output, dims, false, output->rankOf() == 2); + input->reduceAlongDimension(reduce::NormFrobenius, *output, dims, false, output->rankOf() == 2); } else { - input->reduceAlongDimension(reduce::Norm2, output, dims, false, output->rankOf() == 2); + input->reduceAlongDimension(reduce::Norm2, *output, dims, false, output->rankOf() == 2); } } break; case 2: { // 1 - input->reduceAlongDimension(reduce::Norm1, output, dims, false, output->rankOf() == 2); + input->reduceAlongDimension(reduce::Norm1, *output, dims, false, output->rankOf() == 2); } break; case 3: { - // 2 - input->reduceAlongDimension(reduce::Norm2, output, dims, false, output->rankOf() == 2); + // 2 + input->reduceAlongDimension(reduce::Norm2, *output, dims, false, output->rankOf() == 2); } break; case 4: { // inf-norm - input->reduceAlongDimension(reduce::NormMax, output, dims, false, output->rankOf() == 2); + input->reduceAlongDimension(reduce::NormMax, *output, dims, false, output->rankOf() == 2); } break; default: { @@ -84,7 +84,7 @@ namespace nd4j { REQUIRE_TRUE(block.getIArguments()->size() > 1, 0, "P-Norm reductions requires 2 TArguments, but only 1 was provided"); // FIXME: p is required here //T p = T_ARG(1); - input->reduceAlongDimension(reduce::NormP, output, dims, false, output->rankOf() == 2); + input->reduceAlongDimension(reduce::NormP, *output, dims, false, output->rankOf() == 2); } } diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/normalize_moments.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/normalize_moments.cpp index 23fd9a79e..15f295995 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/normalize_moments.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/normalize_moments.cpp @@ -40,23 +40,20 @@ namespace nd4j { shift.assign(T_ARG(0)); } - means->applyScalarArr(scalar::Divide, counts, resMeans, nullptr); + means->applyScalarArr(scalar::Divide, *counts, *resMeans); - NDArray* squareMeans = resMeans->dup('c'); - NDArray* tempVariances = resVariances->dup('c'); + NDArray squareMeans = resMeans->dup('c'); + NDArray tempVariances = resVariances->dup('c'); - squareMeans->applyTransform(transform::Square, squareMeans, nullptr); - variances->applyScalarArr(scalar::Divide, counts, tempVariances, nullptr); -// tempVariances->printIndexedBuffer("varianced divided by count"); - tempVariances->applyPairwiseTransform(pairwise::Subtract, squareMeans, resVariances, nullptr); + squareMeans.applyTransform(transform::Square, squareMeans, nullptr); + variances->applyScalarArr(scalar::Divide, *counts, tempVariances); +// tempVariances.printIndexedBuffer("varianced divided by count"); + tempVariances.applyPairwiseTransform(pairwise::Subtract, squareMeans, *resVariances); if (shift.e(0) != 0) { - resMeans->applyScalarArr(scalar::Add, &shift, resMeans, nullptr); + resMeans->applyScalarArr(scalar::Add, shift, *resMeans); } - delete squareMeans; - delete tempVariances; - return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/qr.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/qr.cpp new file mode 100644 index 000000000..32247a1cd --- /dev/null +++ b/libnd4j/include/ops/declarable/generic/parity_ops/qr.cpp @@ -0,0 +1,88 @@ +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// Created by GS at 12/20/2019 +// + +#include +#include +#include + +#if NOT_EXCLUDED(OP_qr) +namespace nd4j { + namespace ops { + CUSTOM_OP_IMPL(qr, 1, 2, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto outputQ = OUTPUT_VARIABLE(0); + auto outputR = OUTPUT_VARIABLE(1); + auto fullMatricies = false; + if (block.getBArguments()->size()) + fullMatricies = B_ARG(0); + REQUIRE_TRUE(input->rankOf() >=2, 0, "qr: The rank of input array should not be less than 2, but %i is given", input->rankOf()); + REQUIRE_TRUE((fullMatricies && outputQ->sizeAt(-1) == input->sizeAt(-2)) || (!fullMatricies && outputQ->isSameShape(input)), 0, "qr: The last dimmensions should be equal to result Q, but %i and %i are given", outputQ->sizeAt(-1), input->sizeAt(-2)); + REQUIRE_TRUE((fullMatricies && outputR->sizeAt(-1) == input->sizeAt(-1)) || (!fullMatricies && outputR->sizeAt(-1) == outputR->sizeAt(-2)), 0, "qr: The last dimmensions should be equal to result R, but %i and %i are given", outputR->sizeAt(-1), input->sizeAt(-1)); + + helpers::qr(block.launchContext(), input, outputQ, outputR, fullMatricies); + return Status::OK(); + } + + DECLARE_SHAPE_FN(qr) { + auto inShape = inputShape->at(0); + + Nd4jLong* shapeQ; + Nd4jLong* shapeR; + int targetRank = shape::rank(inShape); // last two dimensions will be reduced to scalar + + auto fullMatricies = false; + if (block.getBArguments()->size()) + fullMatricies = B_ARG(0); + + auto shape = ShapeUtils::shapeAsVector(inShape); + + if (!fullMatricies) { // outputs are: Q is MxN and R is NxN + shape[targetRank - 1] = shape::sizeAt(inShape, -1); + shape[targetRank - 2] = shape[targetRank - 1]; + shapeQ = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inShape), + shape::order(inShape), targetRank, + shape::shapeOf(inShape)); + shapeR = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inShape), + shape::order(inShape), shape); + + } + else {// otherwise outputs are Q is MxM and R is MxN with zero filled rows + shape[targetRank - 1] = shape::sizeAt(inShape, -2); + shape[targetRank - 2] = shape[targetRank - 1]; + shapeR = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inShape), + shape::order(inShape), targetRank, + shape::shapeOf(inShape)); + shapeQ = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inShape), + shape::order(inShape), shape); + } + + return SHAPELIST(shapeQ, shapeR); + + } + + DECLARE_TYPES(qr) { + getOpDescriptor() + ->setAllowedInputTypes({ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_FLOATS}); + } + } +} + +#endif diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/reduceMean.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/reduceMean.cpp index 0beec605a..f83994606 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/reduceMean.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/reduceMean.cpp @@ -47,7 +47,7 @@ CUSTOM_OP_IMPL(reduce_mean, 1, 1, false, 0, 0) { for(const auto& item : dimensions) REQUIRE_TRUE(item >= -input->rankOf() && item < input->rankOf(), 0, "REDUCE_MEAN OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , input->rankOf(), input->rankOf(), item); - input->reduceAlongDimension(reduce::Mean, output, dimensions, keepDims); + input->reduceAlongDimension(reduce::Mean, *output, dimensions, keepDims); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/reduceStDev.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/reduceStDev.cpp index f1ebf91d1..6a3e7c050 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/reduceStDev.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/reduceStDev.cpp @@ -55,7 +55,7 @@ CUSTOM_OP_IMPL(reduce_stdev, 1, 1, false, 0, 0) { for(const auto& item : dimensions) REQUIRE_TRUE(item >= -input->rankOf() && item < input->rankOf(), 0, "REDUCE_STDEV OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , input->rankOf(), input->rankOf(), item); - input->varianceAlongDimension(variance::SummaryStatsStandardDeviation, output, biasCorrected, dimensions); + input->varianceAlongDimension(variance::SummaryStatsStandardDeviation, *output, biasCorrected, dimensions); return Status::OK(); } @@ -130,10 +130,10 @@ CUSTOM_OP_IMPL(reduce_stdev_bp, 2, 1, false, 0, 0) { const Nd4jLong N = input->lengthOf() / gradO->lengthOf(); const Nd4jLong NminusOne = biasCorrected ? N - 1 : N; - auto mean = input->reduceAlongDims(reduce::Mean, dimensions, true); + auto mean = input->reduceAlongDimension(reduce::Mean, dimensions, true); NDArray variance(mean.getShapeInfo(), true, block.launchContext()); // create empty array with shape matching shape of mean array - input->varianceAlongDimension(variance::SummaryStatsStandardDeviation, &variance, biasCorrected, dimensions); + input->varianceAlongDimension(variance::SummaryStatsStandardDeviation, variance, biasCorrected, dimensions); gradI->assign( (*input - mean) / (variance * NminusOne)); // automatic broadcasting happens here diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/reduceVariance.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/reduceVariance.cpp index dbf470935..16bfdc8a9 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/reduceVariance.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/reduceVariance.cpp @@ -54,8 +54,8 @@ CUSTOM_OP_IMPL(reduce_variance, 1, 1, false, 0, 0) { for(const auto& item : dimensions) REQUIRE_TRUE(item >= -input->rankOf() && item < input->rankOf(), 0, "REDUCE_VARIANCE OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , input->rankOf(), input->rankOf(), item); - - input->varianceAlongDimension(variance::SummaryStatsVariance, output, biasCorrected, dimensions); + + input->varianceAlongDimension(variance::SummaryStatsVariance, *output, biasCorrected, dimensions); return Status::OK(); } @@ -77,7 +77,7 @@ DECLARE_SHAPE_FN(reduce_variance) { } REQUIRE_TRUE(dimensions.size() <= INPUT_VARIABLE(0)->rankOf(), 0, "REDUCE_VARIANCE OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); - + for(const auto& item : dimensions) REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], 0, "REDUCE_VARIANCE OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , inputShape->at(0)[0], inputShape->at(0)[0], item); @@ -128,9 +128,9 @@ CUSTOM_OP_IMPL(reduce_variance_bp, 2, 1, false, 0, 0) { const Nd4jLong NminusOne = biasCorrected ? N - 1 : N; const double factor1 = 2.0 / NminusOne; const double factor2 = 2.0 / (N * NminusOne); - - auto mean = input->reduceAlongDims(reduce::Mean, dimensions, true); - + + auto mean = input->reduceAlongDimension(reduce::Mean, dimensions, true); + gradI->assign( (*input - mean) * (2.0f / NminusOne)); // automatic broadcasting happens here if(!keepDims) { @@ -153,13 +153,13 @@ DECLARE_SHAPE_FN(reduce_variance_bp) { } REQUIRE_TRUE(dimensions.size() <= rank, 0, "REDUCE_VARIANCE_BP OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); - + for(const auto& item : dimensions) REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], 0, "REDUCE_VARIANCE_BP OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , inputShape->at(0)[0], inputShape->at(0)[0], item); - + Nd4jLong* gradIshapeInfo(nullptr); COPY_SHAPE(in, gradIshapeInfo); - + return SHAPELIST(CONSTANT(gradIshapeInfo)); } diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/reduce_logsumexp.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/reduce_logsumexp.cpp index 0cf0e1f9e..a02b4db9b 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/reduce_logsumexp.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/reduce_logsumexp.cpp @@ -45,9 +45,9 @@ namespace ops { //void* whereMax = (void*)(); auto internal = (*input); internal -= maxVals; - internal.applyTransform(transform::Exp, nullptr, nullptr); - internal.reduceAlongDimension(reduce::Sum, output, axes, keepDims, false); //, (void*)&maxVals); - output->applyTransform(transform::Log, nullptr, nullptr); + internal.applyTransform(transform::Exp, internal); + internal.reduceAlongDimension(reduce::Sum, *output, axes, keepDims, false); //, (void*)&maxVals); + output->applyTransform(transform::Log, *output); (*output) += maxVals; return ND4J_STATUS_OK; } @@ -56,7 +56,7 @@ namespace ops { -> setAllowedInputTypes({ALL_INTS, ALL_FLOATS}) -> setAllowedOutputTypes({ALL_FLOATS}); } - DECLARE_SHAPE_FN(reduce_logsumexp) { + DECLARE_SHAPE_FN(reduce_logsumexp) { const bool keepDims = block.getTArguments()->size() > 0 ? (bool)T_ARG(0) : false; auto input = INPUT_VARIABLE(0); @@ -74,6 +74,6 @@ namespace ops { return SHAPELIST(outShapeInfo); } -#endif +#endif } } diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/reduce_max.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/reduce_max.cpp index 4ab9954b0..870017e8d 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/reduce_max.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/reduce_max.cpp @@ -52,7 +52,7 @@ CUSTOM_OP_IMPL(reduce_max, 1, 1, false, 0, 0) { else if (block.getTArguments()->size() > 0) keepDims = (bool)T_ARG(0); - input->reduceAlongDimension(reduce::Max, output, dimensions, keepDims); + input->reduceAlongDimension(reduce::Max, *output, dimensions, keepDims); return Status::OK(); } @@ -122,8 +122,7 @@ CUSTOM_OP_IMPL(reduce_max_bp, 2, 1, false, 0, 0) { else { auto indicesArr = input->applyIndexReduce(nd4j::indexreduce::IndexMax, dimensions); - helpers::scatterSimple(block.launchContext(), 6, *gradI, *gradO, *indicesArr, ShapeUtils::evalDimsToExclude(gradI->rankOf(), dimensions)); // 6 corresponds to copy operation - delete indicesArr; + helpers::scatterSimple(block.launchContext(), 6, *gradI, *gradO, indicesArr, ShapeUtils::evalDimsToExclude(gradI->rankOf(), dimensions)); // 6 corresponds to copy operation } return Status::OK(); diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/reduce_min.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/reduce_min.cpp index cb9b9e21b..e8b073de8 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/reduce_min.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/reduce_min.cpp @@ -52,7 +52,7 @@ CUSTOM_OP_IMPL(reduce_min, 1, 1, false, 0, 0) { else if (block.getTArguments()->size() > 0) keepDims = (bool)T_ARG(0); - input->reduceAlongDimension(reduce::Min, output, dimensions, keepDims); + input->reduceAlongDimension(reduce::Min, *output, dimensions, keepDims); return Status::OK(); } @@ -89,7 +89,7 @@ DECLARE_TYPES(reduce_min) { } -#endif +#endif #if NOT_EXCLUDED(OP_reduce_min_bp) @@ -125,8 +125,7 @@ CUSTOM_OP_IMPL(reduce_min_bp, 2, 1, false, 0, 0) { else { auto indicesArr = input->applyIndexReduce(nd4j::indexreduce::IndexMin, dimensions); - helpers::scatterSimple(block.launchContext(), 6, *gradI, *gradO, *indicesArr, ShapeUtils::evalDimsToExclude(gradI->rankOf(), dimensions)); // 6 corresponds to copy operation - delete indicesArr; + helpers::scatterSimple(block.launchContext(), 6, *gradI, *gradO, indicesArr, ShapeUtils::evalDimsToExclude(gradI->rankOf(), dimensions)); // 6 corresponds to copy operation } return Status::OK(); diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/reduce_norm1.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/reduce_norm1.cpp index 8da05c3f4..172f3df8e 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/reduce_norm1.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/reduce_norm1.cpp @@ -51,7 +51,7 @@ CUSTOM_OP_IMPL(reduce_norm1, 1, 1, false, 0, 0) { else if (block.getTArguments()->size()) keepDims = (bool)T_ARG(0); - input->reduceAlongDimension(reduce::Norm1, output, dimensions, keepDims); + input->reduceAlongDimension(reduce::Norm1, *output, dimensions, keepDims); return Status::OK(); } @@ -85,7 +85,7 @@ DECLARE_TYPES(reduce_norm1) { ->setAllowedInputTypes(nd4j::DataType::ANY) ->setAllowedOutputTypes({ALL_FLOATS}); } -#endif +#endif #if NOT_EXCLUDED(OP_reduce_norm1_bp) ////////////////////////////////////////////////////////////////////////// @@ -100,7 +100,7 @@ CUSTOM_OP_IMPL(reduce_norm1_bp, 2, 1, false, 0, 0) { auto gradO = INPUT_VARIABLE(1); auto gradI = OUTPUT_VARIABLE(0); - input->applyTransform(nd4j::transform::Sign, gradI); + input->applyTransform(nd4j::transform::Sign, *gradI); if (gradO->lengthOf() == 1) { *gradI *= *gradO; diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/reduce_norm2.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/reduce_norm2.cpp index 1a7e0a911..e54518359 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/reduce_norm2.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/reduce_norm2.cpp @@ -50,7 +50,7 @@ CUSTOM_OP_IMPL(reduce_norm2, 1, 1, false, 0, 0) { else if (block.getTArguments()->size()) keepDims = (bool)T_ARG(0); - input->reduceAlongDimension(reduce::Norm2, output, dimensions, keepDims); + input->reduceAlongDimension(reduce::Norm2, *output, dimensions, keepDims); return Status::OK(); } @@ -124,7 +124,7 @@ CUSTOM_OP_IMPL(reduce_norm2_bp, 2, 1, false, 0, 0) { // *** calculations *** // - *gradI /= input->reduceAlongDims(reduce::Norm2, dimensions, true); + *gradI /= input->reduceAlongDimension(reduce::Norm2, dimensions, true); if(!keepDims) { auto gradOShapeKeepDims = ShapeUtils::evalReduceShapeInfo(gradO->ordering(), dimensions, *input, true, false, block.getWorkspace()); diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/reduce_norm_max.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/reduce_norm_max.cpp index 902b1d699..c71310947 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/reduce_norm_max.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/reduce_norm_max.cpp @@ -52,7 +52,7 @@ CUSTOM_OP_IMPL(reduce_norm_max, 1, 1, false, 0, 0) { else if (block.getTArguments()->size()) keepDims = (bool)T_ARG(0); - input->reduceAlongDimension(reduce::NormMax, output, dimensions, keepDims); + input->reduceAlongDimension(reduce::NormMax, *output, dimensions, keepDims); return Status::OK(); } @@ -87,7 +87,7 @@ DECLARE_TYPES(reduce_norm_max) { ->setAllowedInputTypes(nd4j::DataType::ANY) ->setAllowedOutputTypes({ALL_FLOATS}); } -#endif +#endif #if NOT_EXCLUDED(OP_reduce_norm_max_bp) @@ -124,9 +124,8 @@ CUSTOM_OP_IMPL(reduce_norm_max_bp, 2, 1, false, 0, 0) { else { auto indicesArr = input->applyIndexReduce(nd4j::indexreduce::IndexAbsoluteMax, dimensions); - helpers::scatterSimple(block.launchContext(), 6, *gradI, *gradO, *indicesArr, ShapeUtils::evalDimsToExclude(gradI->rankOf(), dimensions)); // 6 corresponds to copy operation + helpers::scatterSimple(block.launchContext(), 6, *gradI, *gradO, indicesArr, ShapeUtils::evalDimsToExclude(gradI->rankOf(), dimensions)); // 6 corresponds to copy operation *gradI *= input->transform(nd4j::transform::Sign); - delete indicesArr; } return Status::OK(); @@ -139,7 +138,7 @@ DECLARE_SHAPE_FN(reduce_norm_max_bp) { auto axesVector = INPUT_VARIABLE(2); helpers::adjustAxis(INPUT_VARIABLE(0)->rankOf(), axesVector, dimensions); } - + REQUIRE_TRUE(dimensions.size() <= inputShape->at(0)[0], 0, "REDUCE_NORM_MAX_BP OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); for(const auto& item : dimensions) diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/reduce_prod.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/reduce_prod.cpp index 7f3afc1c6..965b6dcaa 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/reduce_prod.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/reduce_prod.cpp @@ -51,7 +51,7 @@ CUSTOM_OP_IMPL(reduce_prod, 1, 1, false, 0, 0) { else if (block.getTArguments()->size()) keepDims = (bool)T_ARG(0); - input->reduceAlongDimension(reduce::Prod, output, dimensions, keepDims); + input->reduceAlongDimension(reduce::Prod, *output, dimensions, keepDims); return Status::OK(); } @@ -123,8 +123,8 @@ CUSTOM_OP_IMPL(reduce_prod_bp, 2, 1, false, 0, 0) { // *** calculations *** // - auto products = input->reduceAlongDims(reduce::Prod, dimensions, true); - gradI->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Assign(), &products, gradI); + auto products = input->reduceAlongDimension(reduce::Prod, dimensions, true); + gradI->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Assign(), products, *gradI); *gradI /= *input; if(!keepDims) { diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/reduce_sqnorm.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/reduce_sqnorm.cpp index 00d277ec7..e42050ff6 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/reduce_sqnorm.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/reduce_sqnorm.cpp @@ -50,7 +50,7 @@ CUSTOM_OP_IMPL(reduce_sqnorm, 1, 1, false, 0, 0) { for(const auto& item : dimensions) REQUIRE_TRUE(item >= -input->rankOf() && item < input->rankOf(), 0, "REDUCE_SQNORM OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , input->rankOf(), input->rankOf(), item); - input->reduceAlongDimension(reduce::SquaredNorm, gradI, dimensions, keepDims); + input->reduceAlongDimension(reduce::SquaredNorm, *gradI, dimensions, keepDims); return Status::OK(); } @@ -86,7 +86,7 @@ DECLARE_TYPES(reduce_sqnorm) { ->setAllowedOutputTypes({ALL_FLOATS}); } -#endif +#endif #if NOT_EXCLUDED(OP_reduce_sqnorm_bp) diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/reduce_sum.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/reduce_sum.cpp index 4631e4807..522164593 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/reduce_sum.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/reduce_sum.cpp @@ -51,7 +51,7 @@ CUSTOM_OP_IMPL(reduce_sum, 1, 1, false, 0, 0) { else if (block.getTArguments()->size()) keepDims = (bool)T_ARG(0); - input->reduceAlongDimension(reduce::Sum, output, dimensions, keepDims); + input->reduceAlongDimension(reduce::Sum, *output, dimensions, keepDims); return Status::OK(); } @@ -85,7 +85,7 @@ DECLARE_TYPES(reduce_sum) { ->setAllowedInputTypes(nd4j::DataType::ANY) ->setSameMode(true); } -#endif +#endif #if NOT_EXCLUDED(OP_reduce_sum_bp) ////////////////////////////////////////////////////////////////////////// @@ -123,9 +123,9 @@ CUSTOM_OP_IMPL(reduce_sum_bp, 2, 1, false, 0, 0) { if(!keepDims) { auto gradOShapeKeepDims = ShapeUtils::evalReduceShapeInfo(gradO->ordering(), dimensions, *input, true, false, block.getWorkspace()); auto r = gradO->reshape(gradO->ordering(), ShapeUtils::pullShapeFromShapeInfo(gradOShapeKeepDims)); // for example could be something like [a,b] -> [1,a,1,b] - gradI->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Assign(), &r, gradI); + gradI->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Assign(), r, *gradI); } else - gradI->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Assign(), gradO, gradI); + gradI->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Assign(), *gradO, *gradI); } return Status::OK(); diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/resize_area.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/resize_area.cpp new file mode 100644 index 000000000..b0f637c45 --- /dev/null +++ b/libnd4j/include/ops/declarable/generic/parity_ops/resize_area.cpp @@ -0,0 +1,122 @@ +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author sgazeos@gmail.com +// + +#include +#if NOT_EXCLUDED(OP_resize_area) + +#include +#include + +namespace nd4j { + namespace ops { + CUSTOM_OP_IMPL(resize_area, 1, 1, false, 0, -2) { + + auto image = INPUT_VARIABLE(0); + int width; + int height; + + if (block.width() == 2) { + auto size = INPUT_VARIABLE(1); // integer vector with shape {2} and content (new_height, new_width) + REQUIRE_TRUE(size->rankOf() == 1, size->lengthOf() == 2, 0, "resize_area: Resize params is a pair of values, not %i.", size->lengthOf()); + size->syncToHost(); + width = size->e(1); + height = size->e(0); + } + else { + REQUIRE_TRUE(block.numI() == 2, 0, "resize_area: Resize params already given by the second param. Int params are expensive."); + width = INT_ARG(1); + height = INT_ARG(0); + } + + auto output = OUTPUT_VARIABLE(0); + if (output->isEmpty()) return Status::OK(); + auto inRank = image->rankOf(); + + REQUIRE_TRUE(inRank == 3 || inRank == 4, 0, "resize_area: Source tensor should have rank 4, but %i given.", inRank); + REQUIRE_TRUE(output->rankOf() == inRank, 0, "resize_area: Source tensor and output should have the same rank, but %i and %i given.", inRank, output->rankOf()); + REQUIRE_TRUE(width > 0 , 0, "resize_area: picture width should be positive 32 bit integer, but %i given", width); + REQUIRE_TRUE(height > 0 , 0, "resize_area: picture height should be positive 32 bit integer, but %i given", height); + REQUIRE_TRUE(image->lengthOf() > 0, 0, "resize_area: Only non-zero images allowed to processing."); + + auto alignCorners = false; + if (block.numB() > 0) { + alignCorners = B_ARG(0); + } + + auto source = inRank == 4?image->reshape(image->ordering(), {image->sizeAt(0), image->sizeAt(1), image->sizeAt(2), image->sizeAt(3)}):image->reshape(image->ordering(), {1, image->sizeAt(0), image->sizeAt(1), image->sizeAt(2)}); + auto target = inRank == 4?output->reshape(output->ordering(), {output->sizeAt(0), output->sizeAt(1), output->sizeAt(2), output->sizeAt(3)}):output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)}); + + return helpers::resizeAreaFunctor(block.launchContext(), &source, width, height, alignCorners, &target); + } + + DECLARE_SHAPE_FN(resize_area) { + auto shapeList = SHAPELIST(); + auto in = inputShape->at(0); + + Nd4jLong* outputShape; + auto inRank = shape::rank(in); + int width; + int height; + if (block.width() == 2) { + auto newImageSize = INPUT_VARIABLE(1); + REQUIRE_TRUE(newImageSize->lengthOf() == 2, 0, + "resize_area: Resize params is a pair of values, not %i.", newImageSize->lengthOf()); + REQUIRE_TRUE(block.numI() <= 1, 0, + "resize_area: Resize params already given by the second param. Int params are expensive."); + width = newImageSize->e(0); + height = newImageSize->e(1); + } + else { + REQUIRE_TRUE(block.numI() == 2, 0, "resize_area: Resize params ommited as pair ints nor int tensor."); + width = INT_ARG(1); + height = INT_ARG(0); + } + + REQUIRE_TRUE(inRank == 4 || inRank == 3, 0, "resize_area: Source tensor should have rank 4, but %i given.", inRank); + + ALLOCATE(outputShape, block.getWorkspace(), shape::shapeInfoLength(inRank), Nd4jLong); + outputShape[0] = inRank; + if (inRank == 4) { + outputShape[1] = in[1]; + outputShape[2] = width; + outputShape[3] = height; + outputShape[4] = in[4]; + } + else { + outputShape[1] = width; + outputShape[2] = height; + outputShape[3] = in[3]; + } + ShapeUtils::updateStridesAndType(outputShape, DataType::FLOAT32, shape::order(in)); + + shapeList->push_back(CONSTANT(outputShape)); + return shapeList; + } + DECLARE_TYPES(resize_area) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_FLOATS, ALL_INTS}) + ->setAllowedInputTypes(1, DataType::INT32) + ->setAllowedOutputTypes({DataType::FLOAT32}); + } + + } +} + +#endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/resize_bicubic.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/resize_bicubic.cpp index da98c1702..26ca7eec9 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/resize_bicubic.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/resize_bicubic.cpp @@ -1,5 +1,5 @@ /******************************************************************************* - * Copyright (c) 2019 Konduit K.K. + * Copyright (c) 2019-2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -76,8 +76,8 @@ namespace nd4j { int width; int height; auto newImageSize = INPUT_VARIABLE(1); - REQUIRE_TRUE(newImageSize->lengthOf() == 2, 0, "resize_bilinear: Resize params is a pair of values, not %i.", newImageSize->lengthOf()); - REQUIRE_TRUE(block.numI() <= 1, 0, "resize_bilinear: Resize params already given by the second param. Int params are expensive."); + REQUIRE_TRUE(newImageSize->lengthOf() == 2, 0, "resize_bicubic: Resize params is a pair of values, not %i.", newImageSize->lengthOf()); + REQUIRE_TRUE(block.numI() <= 1, 0, "resize_bicubic: Resize params already given by the second param. Int params are expensive."); width = newImageSize->e(0); height = newImageSize->e(1); diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/resize_linear.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/resize_linear.cpp index f1f79b08f..652b78cf1 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/resize_linear.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/resize_linear.cpp @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2019-2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/resize_neighbor.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/resize_neighbor.cpp index 6c18e61e1..db477f569 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/resize_neighbor.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/resize_neighbor.cpp @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2019-2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -19,7 +20,7 @@ // #include -#if NOT_EXCLUDED(OP_resize_bilinear) +#if NOT_EXCLUDED(OP_resize_nearest_neighbor) //#include #include @@ -54,7 +55,7 @@ namespace nd4j { if (block.numB() > 1) halfPixelCenter = B_ARG(1); - REQUIRE_TRUE(width <= (1 << 24) || height <= (1 << 24), 0, "resize_nearest_neighbour: the image resize should be limited to 2^24 pixels both for height and width, but %d and %d were given.", height, width); + REQUIRE_TRUE(width <= (1 << 24) || height <= (1 << 24), 0, "resize_nearest_neighbor: the image resize should be limited to 2^24 pixels both for height and width, but %d and %d were given.", height, width); REQUIRE_TRUE(inRank == 4 || inRank == 3, 0, "resize_nearest_neighbor: Input should be 4D tensor, but rank %i occured"); REQUIRE_TRUE(inRank == output->rankOf(), 0, "resize_nearest_neighbor: Input and output ranks should be equals, but %i and %i occured.", inRank, output->rankOf()); REQUIRE_TRUE(image->dataType() == output->dataType(), 0, "resize_nearest_neighbor: Input and output types should be the same, but `%s' occured instead.", DataTypeUtils::asString(output->dataType()).c_str()); @@ -73,7 +74,7 @@ namespace nd4j { auto inRank = shape::rank(in); Nd4jLong* outputShape; - REQUIRE_TRUE(inRank == 4 || inRank == 3, 0, "resize_bilinear: input image should be 4D " + REQUIRE_TRUE(inRank == 4 || inRank == 3, 0, "resize_nearest_neighbor: input image should be 4D " "tensor, but input has rank %i", inRank); diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/rint.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/rint.cpp index 7a5753edc..fbff41c47 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/rint.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/rint.cpp @@ -29,7 +29,7 @@ namespace nd4j { auto x = INPUT_VARIABLE(0); auto z = OUTPUT_VARIABLE(0); - x->applyTransform(transform::Rint, z); + x->applyTransform(transform::Rint, *z); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/roll.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/roll.cpp index 61f592f1d..f93fc198e 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/roll.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/roll.cpp @@ -84,7 +84,7 @@ namespace ops { if (block.isInplace()) output = input; - shiftIsLinear = axes.size() == 0; + shiftIsLinear = (axes.size() == 0) || (input->rankOf() == 1); if (shiftIsLinear) { helpers::rollFunctorLinear(block.launchContext(), input, output, shifts[0], block.isInplace()); diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/square.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/square.cpp index b5014bb7b..34da37897 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/square.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/square.cpp @@ -30,7 +30,7 @@ namespace nd4j { auto output = OUTPUT_VARIABLE(0); int extras = 2; - input->applyScalar(scalar::Pow, extras, output); + input->applyScalar(scalar::Pow, extras, *output); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/stop_gradient.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/stop_gradient.cpp index a43637788..81f81c326 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/stop_gradient.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/stop_gradient.cpp @@ -29,7 +29,7 @@ namespace nd4j { auto x = INPUT_VARIABLE(0); auto out = OUTPUT_VARIABLE(0); // just for lulz - x->applyTransform(transform::Identity, out, nullptr); + x->applyTransform(transform::Identity, *out); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/sufficient_statistics.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/sufficient_statistics.cpp index 63aa80e0a..ed7698d15 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/sufficient_statistics.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/sufficient_statistics.cpp @@ -38,8 +38,8 @@ namespace nd4j { // axis might be dynamic (i.e. tf mode) helpers::adjustAxis(input->rankOf(), axisVector, axis); - input->reduceAlongDimension(reduce::SquaredNorm, squares, axis); - input->reduceAlongDimension(reduce::Sum, sum, axis); + input->reduceAlongDimension(reduce::SquaredNorm, *squares, axis); + input->reduceAlongDimension(reduce::Sum, *sum, axis); auto count = NDArrayFactory::create(input->dataType(), input->lengthOf() / sum->lengthOf()); dataCount->assign(count); if (block.numT() > 0) { @@ -79,7 +79,7 @@ namespace nd4j { auto shapeList = SHAPELIST(scalarShape, sumShape, squareShape); if (block.numT() > 0) shapeList->push_back(ConstantShapeHelper::getInstance()->scalarShapeInfo(ArrayOptions::dataType(inputShape->at(0)))); - + return shapeList; } } diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/tear.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/tear.cpp index 090c29504..c76435622 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/tear.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/tear.cpp @@ -38,18 +38,17 @@ namespace nd4j { REQUIRE_TRUE(v >= 0 && v < input->rankOf(), 0, "Tear dimensions should be non-negative values, and lower then input rank. Got %i instead", v); auto tads = input->allTensorsAlongDimension(dims); - for (Nd4jLong e = 0; e < tads->size(); e++) { + for (Nd4jLong e = 0; e < tads.size(); e++) { auto outE = OUTPUT_VARIABLE(e); - outE->assign(tads->at(e)); + outE->assign(tads.at(e)); // just for debugging purposes this->storeResult(block, e, *outE); } - delete tads; - return Status::OK(); } + DECLARE_SHAPE_FN(tear) { auto inShape = inputShape->at(0); diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/triangular_solve.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/triangular_solve.cpp new file mode 100644 index 000000000..181f47d3d --- /dev/null +++ b/libnd4j/include/ops/declarable/generic/parity_ops/triangular_solve.cpp @@ -0,0 +1,82 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit, K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// Created by GS at 01/14/2020 +// + +#include +#if NOT_EXCLUDED(OP_triangual_solve) + +#include +#include +namespace nd4j { + namespace ops { + CUSTOM_OP_IMPL(triangular_solve, 2, 1, false, 0, 0) { + auto a = INPUT_VARIABLE(0); + auto b = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); + bool isLower = true; + bool useAdjoint = false; + + if (block.numB() > 0) { + if (block.numB() > 1) { + isLower = B_ARG(0); + useAdjoint = B_ARG(1); + } + else { + isLower = B_ARG(0); + } + } + + REQUIRE_TRUE(a->rankOf() >=2, 0, "triangular_solve: The rank of input left tensor should not be less than 2, but %i is given", a->rankOf()); + REQUIRE_TRUE(b->rankOf() >=2, 0, "triangular_solve: The rank of input right tensor should not be less than 2, but %i is given", b->rankOf()); + + REQUIRE_TRUE(a->sizeAt(-1) == a->sizeAt(-2), 0, "triangular_solve: The last two dimmensions should be equal, but %i and %i are given", a->sizeAt(-1), a->sizeAt(-2)); + REQUIRE_TRUE(a->sizeAt(-1) == b->sizeAt(-2), 0, "triangular_solve: The last dimmension of left part should be equal to prelast of right part, but %i and %i are given", a->sizeAt(-1), b->sizeAt(-2)); + auto input = a; + if (useAdjoint) { + auto adjointA = a->ulike(); + helpers::adjointMatrix(block.launchContext(), a, isLower, &adjointA); + input = new NDArray(adjointA); //.detach(); + isLower = !isLower; + }; + + auto res = helpers::triangularSolveFunctor(block.launchContext(), input, b, isLower, useAdjoint, z); + if (input != a) + delete input; + + return Status::OK(); + } + + DECLARE_SHAPE_FN(triangular_solve) { + auto in0 = inputShape->at(1); + auto in1 = inputShape->at(1); + auto luShape = ShapeBuilders::copyShapeInfoAndType(in1, in0, true, block.workspace()); + + return SHAPELIST(CONSTANT(luShape)); + } + + DECLARE_TYPES(triangular_solve) { + getOpDescriptor() + ->setAllowedInputTypes({ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_FLOATS}) + ->setSameMode(false); + } + } +} + +#endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/unstack.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/unstack.cpp index f6ac319ab..a44510104 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/unstack.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/unstack.cpp @@ -52,21 +52,20 @@ namespace nd4j { } auto tads = input->allTensorsAlongDimension(dims); - //nd4j_printf("Tad size: %d\n",tads->size()); - for (int e = 0; e < tads->size(); e++) { + //nd4j_printf("Tad size: %d\n",tads.size()); + for (int e = 0; e < tads.size(); e++) { //nd4j_printf("Calling assign at index %d\n",e); auto outE = OUTPUT_VARIABLE(e); - auto tadAtE = tads->at(e); + auto tadAtE = tads.at(e); outE->assign(tadAtE); this->storeResult(block, e, *outE); } - delete tads; - return Status::OK(); } + DECLARE_SYN(unpack, unstack); DECLARE_SHAPE_FN(unstack) { diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/xw_plus_b.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/xw_plus_b.cpp index 4e86690b4..ce68df1a0 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/xw_plus_b.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/xw_plus_b.cpp @@ -41,7 +41,7 @@ namespace nd4j { MmulHelper::mmul(x, y, z, 1.0, 0.0); // adding b vector - z->addiRowVector(b); + z->addiRowVector(*b); return Status::OK(); } @@ -49,7 +49,7 @@ namespace nd4j { DECLARE_SHAPE_FN(xw_plus_b) { auto outputShape = ShapeUtils::matrixProductShape(inputShape->at(0), inputShape->at(1), false, false, ArrayOptions::dataType(inputShape->at(0)), block.getWorkspace()); - + return SHAPELIST(CONSTANT(outputShape)); } diff --git a/libnd4j/include/ops/declarable/generic/random/multinomial.cpp b/libnd4j/include/ops/declarable/generic/random/multinomial.cpp new file mode 100644 index 000000000..bbdee17f4 --- /dev/null +++ b/libnd4j/include/ops/declarable/generic/random/multinomial.cpp @@ -0,0 +1,114 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Oleh Semeniv (oleg.semeniv@gmail.com) +// + +#include +#if NOT_EXCLUDED(OP_random_multinomial) + +#include +#include +#include + +namespace nd4j { + namespace ops { + /////////////////////// + /** + * multinomial (categorical) random generator + * takes 2D ndarray with logits with shape [batch_size (N), num_classes (K)] + * and array with one scalar value of samples number, number of independent samples to draw for each experiment 1,N. + * represents the unnormalized log-probabilities for all classes. + * Int arguments: 0 - optional argument, corresponds to dimension with batch_size + * Int arguments: 1 - optional argument, integer type to use for the output. Default int64. + */ + // used https://en.wikipedia.org/wiki/Categorical_distribution + // methods: gumbel trick + softmax + argmax + CUSTOM_OP_IMPL(random_multinomial, 2, 1, false, 0, 0) { + + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + auto inputSamples = INPUT_VARIABLE(1); + + + REQUIRE_TRUE(!input->isEmpty(), 0, "RANDOM_MULTINOMIAL OP: Have to be provided at least one logits. "); + + REQUIRE_TRUE(inputSamples->lengthOf() == 1, 0, "RANDOM_MULTINOMIAL OP: Have to be specified at least one sample," + " but got no argumets instead."); + + Nd4jLong numOfSamples = static_cast(inputSamples->e(0)); + // do nothing if number of samples = 0 + if (0 == numOfSamples) + return Status::OK(); + + REQUIRE_TRUE(numOfSamples > 0, 0, "RANDOM_MULTINOMIAL OP: Number of samples should be greater then 0, got %i. ", numOfSamples); + + const int rank = input->rankOf(); + REQUIRE_TRUE(rank == 2, 0, "RANDOM_MULTINOMIAL OP: Logits should be a matrix with rank = 2, but got instead rank = %i.", rank); + + const int argSize = block.getIArguments()->size(); + const int dimC = argSize > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1; + + auto dimA = (0 == dimC) ? 1 : 0; + if (1 == input->sizeAt(dimA)) { + *output = 0; + return Status::OK(); + } + + auto rng = block.randomGenerator(); + helpers::fillRandomMultiNomial(block.launchContext(), rng, *input, *output, numOfSamples, dimC); + return Status::OK(); + } + + + DECLARE_SHAPE_FN(random_multinomial) { + + auto input = INPUT_VARIABLE(0); + auto inputSamples = INPUT_VARIABLE(1); + + REQUIRE_TRUE(inputSamples->lengthOf() == 1, 0, "RANDOM_MULTINOMIAL OP: Have to be specified at least one sample," + " but got no argumets instead."); + + Nd4jLong numOfSamples = static_cast(inputSamples->e(0)); + + REQUIRE_TRUE(numOfSamples > 0, 0, "RANDOM_MULTINOMIAL OP: Number of samples should be greater then 0, got %i. ", numOfSamples); + + const int rank = input->rankOf(); + REQUIRE_TRUE(rank == 2, 0, "RANDOM_MULTINOMIAL OP: Logits should be a matrix with rank = 2, but got instead rank = %i.", rank); + + const int argSize = block.getIArguments()->size(); + const int dimC = argSize > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1; + + auto nShape = input->getShapeAsVector(); + auto dimA = (0 == dimC) ? 1 : 0; + nShape[dimA] = numOfSamples; + + DataType nType = (argSize > 1) ? ( INT_ARG(1) >= 0 ? static_cast(INT_ARG(1)) : nd4j::DataType::INT64) : nd4j::DataType::INT64; + return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(nType, input->ordering(), nShape)); + } + + DECLARE_TYPES(random_multinomial) { + getOpDescriptor() + ->setAllowedInputTypes(0, { ALL_FLOATS, ALL_INTS }) + ->setAllowedInputTypes(1, { nd4j::DataType::INT32 }) + ->setAllowedOutputTypes(0, { ALL_INDICES }); + } + } +} + +#endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/recurrent/sru.cpp b/libnd4j/include/ops/declarable/generic/recurrent/sru.cpp index 7754844d2..6ca57d297 100644 --- a/libnd4j/include/ops/declarable/generic/recurrent/sru.cpp +++ b/libnd4j/include/ops/declarable/generic/recurrent/sru.cpp @@ -73,7 +73,7 @@ CUSTOM_OP_IMPL(sru, 5, 2, false, 0, 0) { auto xm = x; if(mask) { xm = new NDArray(x->getShapeInfo(), true, block.launchContext()); - x->applyBroadcast(broadcast::Multiply, {0, 1}, mask, xm, nullptr); + x->applyBroadcast(broadcast::Multiply, {0, 1}, *mask, *xm); } // time loop @@ -180,7 +180,7 @@ CUSTOM_OP_IMPL(sru_bp, 8, 4, true, 0, 0) { // x = x * mask if(applyMask) - x->applyBroadcast(broadcast::Multiply, {0, 1}, mask, x, nullptr); // apply mask + x->applyBroadcast(broadcast::Multiply, {0, 1}, *mask, *x); // apply mask // multiplication matrix wi = matmul(w,x), U = WX auto wi = MmulHelper::mmul(w, x, nullptr, 1., 0.); // U [bS x 3K x N] @@ -226,52 +226,52 @@ CUSTOM_OP_IMPL(sru_bp, 8, 4, true, 0, 0) { ///////////////// forward // ft = sigmoid(ft + bf), rt = sigmoid(rt + bR) - ft.addRowVector(&bF, &ft); - rt.addRowVector(&bR, &rt); - ft.applyTransform(transform::Sigmoid, nullptr, nullptr); - rt.applyTransform(transform::Sigmoid, nullptr, nullptr); + ft.addRowVector(bF, ft); + rt.addRowVector(bR, rt); + ft.applyTransform(transform::Sigmoid, ft); + rt.applyTransform(transform::Sigmoid, rt); // TODO T val = (activation_type == 1) ? tanh(cur) : ((activation_type == 2) ? reluf(cur) : cur ); - ct.applyTransform(transform::Tanh, gct); + ct.applyTransform(transform::Tanh, *gct); // ftMinus = 1-ft, rtMinus = 1-rt - ft.applyTransform(transform::OneMinus, ftMinus); - rt.applyTransform(transform::OneMinus, rtMinus); + ft.applyTransform(transform::OneMinus, *ftMinus); + rt.applyTransform(transform::OneMinus, *rtMinus); ///////////////// backward // bR, *grad_brt_ptr = inGradHt * (g_ct - xt) * (1.0f - rt) * rt; - gct->applyPairwiseTransform(pairwise::Subtract, &xt, temp1, nullptr); // temp1 = (g_ct - xt) - rtMinus->applyPairwiseTransform(pairwise::Multiply, &rt, temp2, nullptr); // temp2 = (1.0f - rt) * rt; - temp1->applyPairwiseTransform(pairwise::Multiply, *temp2, nullptr); // temp1 = (g_ct - xt) * (1.0f - rt) * rt; - inGradHt.applyPairwiseTransform(pairwise::Multiply, temp1, &gradBRt, nullptr); // = inGradHt * (g_ct - xt) * (1.0f - rt) * rt; + gct->applyPairwiseTransform(pairwise::Subtract, xt, *temp1); // temp1 = (g_ct - xt) + rtMinus->applyPairwiseTransform(pairwise::Multiply, rt, *temp2); // temp2 = (1.0f - rt) * rt; + temp1->applyPairwiseTransform(pairwise::Multiply, *temp2); // temp1 = (g_ct - xt) * (1.0f - rt) * rt; + inGradHt.applyPairwiseTransform(pairwise::Multiply, *temp1, gradBRt); // = inGradHt * (g_ct - xt) * (1.0f - rt) * rt; // bF, TODO - tanh // gradTanh = (1.0f - g_ct * g_ct); - gct->applyPairwiseTransform(pairwise::Multiply, gct, gradTanh, nullptr); // gradTanh = g_ct * g_ct - gradTanh->applyTransform(transform::OneMinus, gradTanh); // gradTanh = (1.0f - g_ct * g_ct) + gct->applyPairwiseTransform(pairwise::Multiply, *gct, *gradTanh); // gradTanh = g_ct * g_ct + gradTanh->applyTransform(transform::OneMinus, *gradTanh); // gradTanh = (1.0f - g_ct * g_ct) // gradCt = inGradHt * rt * gradTanh - rt.applyPairwiseTransform(pairwise::Multiply, gradTanh, gradCt, nullptr); // gradCt = rt * gradTanh - inGradHt.applyPairwiseTransform(pairwise::Multiply, gradCt, gradCt, nullptr); // gradCt = inGradHt * rt * gradTanh + rt.applyPairwiseTransform(pairwise::Multiply, *gradTanh, *gradCt); // gradCt = rt * gradTanh + inGradHt.applyPairwiseTransform(pairwise::Multiply, *gradCt, *gradCt); // gradCt = inGradHt * rt * gradTanh // gradBFt = (gradCt + inGradCt) * (ct_1 - zt) * (1 - ft) * ft; - gradCt->applyPairwiseTransform(pairwise::Add, inGradCt, temp1, nullptr); // temp1 = (gradCt + inGradCt) - ct_1->applyPairwiseTransform(pairwise::Subtract, &zt, temp2, nullptr); // temp2 = (ct_1 - zt) - temp1->applyPairwiseTransform(pairwise::Multiply, ftMinus, temp1, nullptr); // temp1 = (gradCt + inGradCt)*(1-ft) - temp1->applyPairwiseTransform(pairwise::Multiply, &ft, temp1, nullptr); // temp1 = (gradCt + inGradCt)*(1-ft)*ft - temp1->applyPairwiseTransform(pairwise::Multiply, temp2, &gradBFt, nullptr); // gradBFt = (gradCt + inGradCt) * (ct_1 - zt) * (1 - ft) * ft; + gradCt->applyPairwiseTransform(pairwise::Add, *inGradCt, *temp1); // temp1 = (gradCt + inGradCt) + ct_1->applyPairwiseTransform(pairwise::Subtract, zt, *temp2); // temp2 = (ct_1 - zt) + temp1->applyPairwiseTransform(pairwise::Multiply, *ftMinus, *temp1); // temp1 = (gradCt + inGradCt)*(1-ft) + temp1->applyPairwiseTransform(pairwise::Multiply, ft, *temp1); // temp1 = (gradCt + inGradCt)*(1-ft)*ft + temp1->applyPairwiseTransform(pairwise::Multiply, *temp2, gradBFt); // gradBFt = (gradCt + inGradCt) * (ct_1 - zt) * (1 - ft) * ft; // x_t (highway connection), gradHXt = inGradHt * (1.0f - rt); - inGradHt.applyPairwiseTransform(pairwise::Multiply, rtMinus, &gradHXt, nullptr); + inGradHt.applyPairwiseTransform(pairwise::Multiply, *rtMinus, gradHXt); // U_t, gradUZt = (inGradHt * rt * grad_tanh + inGradCt) * (1.0f - ft); - rt.applyPairwiseTransform(pairwise::Multiply, gradTanh, temp1, nullptr); // temp1 = rt * grad_tanh - inGradHt.applyPairwiseTransform(pairwise::Multiply, temp1, temp1, nullptr); // temp1 = inGradHt * rt * grad_tanh - temp1->applyPairwiseTransform(pairwise::Add, inGradCt, temp1, nullptr); // temp1 = inGradHt * rt * grad_tanh + inGradCt - temp1->applyPairwiseTransform(pairwise::Multiply, ftMinus, &gradUZt, nullptr); // gradUZt = (inGradHt * rt * grad_tanh + inGradCt) * (1.0f - ft); + rt.applyPairwiseTransform(pairwise::Multiply, *gradTanh, *temp1); // temp1 = rt * grad_tanh + inGradHt.applyPairwiseTransform(pairwise::Multiply, *temp1, *temp1); // temp1 = inGradHt * rt * grad_tanh + temp1->applyPairwiseTransform(pairwise::Add, *inGradCt, *temp1); // temp1 = inGradHt * rt * grad_tanh + inGradCt + temp1->applyPairwiseTransform(pairwise::Multiply, *ftMinus, gradUZt); // gradUZt = (inGradHt * rt * grad_tanh + inGradCt) * (1.0f - ft); gradUFt.assign(&gradBFt); gradURt.assign(&gradBRt); // c_{t-1}, inGradCt = (gradCt + inGradCt) * ft; - gradCt->applyPairwiseTransform(pairwise::Add, inGradCt, temp1, nullptr); // temp1 = (gradCt + inGradCt) - temp1->applyPairwiseTransform(pairwise::Multiply, &ft, inGradCt, nullptr); // inGradCt = (gradCt + inGradCt) * ft; + gradCt->applyPairwiseTransform(pairwise::Add, *inGradCt, *temp1); // temp1 = (gradCt + inGradCt) + temp1->applyPairwiseTransform(pairwise::Multiply, ft, *inGradCt); // inGradCt = (gradCt + inGradCt) * ft; if(t != 0) delete ct_1; @@ -283,9 +283,9 @@ CUSTOM_OP_IMPL(sru_bp, 8, 4, true, 0, 0) { // gradX auto weightsT = w->transpose(); // [K x 3K] MmulHelper::mmul(&weightsT, gradU, gradX, 1., 0.); // [bS x K x N] - gradX->applyPairwiseTransform(pairwise::Add, gradHX, gradX, nullptr); // + grad_highway_x + gradX->applyPairwiseTransform(pairwise::Add, *gradHX, *gradX); // + grad_highway_x if(applyMask) - gradX->applyBroadcast(broadcast::Multiply, {0,1}, mask, gradX, nullptr); // apply mask + gradX->applyBroadcast(broadcast::Multiply, {0,1}, *mask, *gradX); // apply mask // gradB auto temp3 = gradBias->reduceAlongDimension(reduce::Sum, {0,2}, false, true); // [1 x 2K] @@ -296,7 +296,7 @@ CUSTOM_OP_IMPL(sru_bp, 8, 4, true, 0, 0) { MmulHelper::mmul(gradU, x, gradW, 1., 0.); // [bS x 3K x K] delete gct; delete gradU; delete gradHX; - delete temp1; delete temp2; delete temp3; delete gradCt; delete wi; + delete temp1; delete temp2; delete gradCt; delete wi; delete gradTanh; delete ftMinus; delete rtMinus; delete gradBias; return Status::OK(); @@ -941,7 +941,7 @@ DECLARE_SHAPE_FN(sru_bi_bp) { // gradX->applyBroadcast(broadcast::Multiply, {0,1}, mask, gradX, nullptr); // apply mask // // gradB -// gradBias.reduceAlongDimension(reduce::Sum, gradB, {0,2}, false, true); // [1 x 2K] +// gradBias.reduceAlongDimension(reduce::Sum, *gradB, {0,2}, false, true); // [1 x 2K] // // gradW [bS x 3K x inSize] // x->permutei({0, 2, 1}); // [bS x time x inSize] diff --git a/libnd4j/include/ops/declarable/generic/shape/reshape.cpp b/libnd4j/include/ops/declarable/generic/shape/reshape.cpp index 0bc80fa91..1d76138f2 100644 --- a/libnd4j/include/ops/declarable/generic/shape/reshape.cpp +++ b/libnd4j/include/ops/declarable/generic/shape/reshape.cpp @@ -173,13 +173,6 @@ namespace nd4j { order = shape::order(inp); e = 0; } - -// //Special case: empty.reshape(-1) -> return empty -// if (INPUT_VARIABLE(0)->isEmpty()) { -// // -// auto newShape = ConstantShapeHelper::getInstance()->emptyShapeInfo(ArrayOptions::dataType(inp)); -// return SHAPELIST(newShape); -// } std::vector shapeNew; @@ -226,11 +219,25 @@ namespace nd4j { //REQUIRE_TRUE(y->lengthOf() == 1 && y->e(0) == -1, 0, "Reshape: when input is empty, shape must be [-1]"); auto shapeOf = y->getBufferAsVector(); Nd4jLong prod = 1; - for (auto v:shapeOf) + bool hasNegs = false; + for (auto v:shapeOf) { + if (v < 0) { + hasNegs = true; + v = 0; + } + prod *= v; + } REQUIRE_TRUE(prod == 0, 0, "Reshape: in case of empty arrays reshape must return empty array as well"); + // if there are -1s - we turn them into zeros + if (hasNegs) { + for (int e = 0; e < shapeOf.size(); e++) + if (shapeOf[e] < 0) + shapeOf[e] = 0; + } + auto newShape = ShapeBuilders::createShapeInfo(ArrayOptions::dataType(inp), shape::order(inp), y->lengthOf(), shapeOf.data()); return SHAPELIST(CONSTANT(newShape)); } diff --git a/libnd4j/include/ops/declarable/generic/shape/tile_to_shape.cpp b/libnd4j/include/ops/declarable/generic/shape/tile_to_shape.cpp index 2cc454438..878c4c0a3 100644 --- a/libnd4j/include/ops/declarable/generic/shape/tile_to_shape.cpp +++ b/libnd4j/include/ops/declarable/generic/shape/tile_to_shape.cpp @@ -26,16 +26,16 @@ namespace nd4j { namespace ops { CUSTOM_OP_IMPL(tile_to_shape, 1, 1, true, 0, -1) { - + auto input = INPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0); - + std::vector outShape(block.getIArguments()->begin(), block.getIArguments()->end()); if (block.isInplace()) { - input->tileToShape(outShape); + input->tileToShape(outShape, *input); } else { - input->tileToShape(outShape, output); + input->tileToShape(outShape, *output); } return Status::OK(); @@ -44,7 +44,7 @@ namespace ops { DECLARE_SHAPE_FN(tile_to_shape) { auto in = inputShape->at(0); - // output shape always equals to arguments + // output shape always equals to arguments auto conv = ArrayUtils::toLongVector(*block.getIArguments()); @@ -73,9 +73,9 @@ namespace ops { auto gradX = OUTPUT_VARIABLE(0); auto axisX = ShapeUtils::evalBroadcastBackwardAxis(input->shapeInfo(), epsNext->shapeInfo()); - // FIX ME: reduceAlongDims should have a signature with result pass to avoid assigning twice + // FIX ME: reduceAlongDimension should have a signature with result pass to avoid assigning twice if (!axisX.empty()) { - auto tempRes = epsNext->reduceAlongDims(reduce::Sum, axisX); + auto tempRes = epsNext->reduceAlongDimension(reduce::Sum, axisX); gradX->assign(tempRes); } else gradX->assign(epsNext); diff --git a/nd4j/nd4j-serde/nd4j-camel-routes/nd4j-kafka/src/main/java/org/nd4j/camel/kafka/Nd4jKafkaConsumer.java b/libnd4j/include/ops/declarable/generic/strings/split_string.cpp similarity index 50% rename from nd4j/nd4j-serde/nd4j-camel-routes/nd4j-kafka/src/main/java/org/nd4j/camel/kafka/Nd4jKafkaConsumer.java rename to libnd4j/include/ops/declarable/generic/strings/split_string.cpp index 7e73e4716..4af4e3aac 100644 --- a/nd4j/nd4j-serde/nd4j-camel-routes/nd4j-kafka/src/main/java/org/nd4j/camel/kafka/Nd4jKafkaConsumer.java +++ b/libnd4j/include/ops/declarable/generic/strings/split_string.cpp @@ -14,32 +14,37 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -package org.nd4j.camel.kafka; +// +// @author raver119@gmail.com +// -import lombok.AllArgsConstructor; -import lombok.Builder; -import org.apache.camel.CamelContext; -import org.apache.camel.ConsumerTemplate; -import org.nd4j.linalg.api.ndarray.INDArray; +#include +#if NOT_EXCLUDED(OP_split_string) -/** - * Created by agibsonccc on 7/19/16. - */ -@AllArgsConstructor -@Builder -public class Nd4jKafkaConsumer { - private KafkaConnectionInformation connectionInformation; - private ConsumerTemplate consumerTemplate; - private CamelContext camelContext; +#include - /** - * Receive an ndarray - * @return - */ - public INDArray receive() { - if (consumerTemplate == null) - consumerTemplate = camelContext.createConsumerTemplate(); - return consumerTemplate.receiveBody("direct:receive", INDArray.class); +namespace nd4j { + namespace ops { + CUSTOM_OP_IMPL(split_string, 2, 1, true, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto delim = INPUT_VARIABLE(1); + + return Status::OK(); + }; + + DECLARE_SHAPE_FN(split_string) { + auto input = INPUT_VARIABLE(0); + auto delim = INPUT_VARIABLE(1); + + return SHAPELIST(); + } + + DECLARE_TYPES(split_string) { + getOpDescriptor() + ->setAllowedInputTypes({ALL_STRINGS}) + ->setAllowedOutputTypes({ALL_STRINGS}); + } } - } + +#endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/tests/testop2i2o.cpp b/libnd4j/include/ops/declarable/generic/tests/testop2i2o.cpp index f11d760b7..7c257b903 100644 --- a/libnd4j/include/ops/declarable/generic/tests/testop2i2o.cpp +++ b/libnd4j/include/ops/declarable/generic/tests/testop2i2o.cpp @@ -35,8 +35,8 @@ namespace nd4j { auto xO = OUTPUT_VARIABLE(0); auto yO = OUTPUT_VARIABLE(1); - x->applyScalar(scalar::Add, 1.0, xO, nullptr); - y->applyScalar(scalar::Add, 2.0, yO, nullptr); + x->applyScalar(scalar::Add, 1.0, *xO); + y->applyScalar(scalar::Add, 2.0, *yO); STORE_2_RESULTS(*xO, *yO); diff --git a/libnd4j/include/ops/declarable/generic/thrid_party/firas_sparse.cpp b/libnd4j/include/ops/declarable/generic/thrid_party/firas_sparse.cpp index 8ade17504..21164f520 100644 --- a/libnd4j/include/ops/declarable/generic/thrid_party/firas_sparse.cpp +++ b/libnd4j/include/ops/declarable/generic/thrid_party/firas_sparse.cpp @@ -63,11 +63,11 @@ namespace nd4j { sparse2dense.insert(pair); } - std::unique_ptr rows(x->allTensorsAlongDimension({1})); + ResultSet rows = x->allTensorsAlongDimension({1}); //PRAGMA_OMP_PARALLEL_FOR for (int r = 0; r < batchSize; r++) { - auto row = rows->at(r); + auto row = rows.at(r); for (int e = 0; e < numColumns; e += 2) { int idx = row->e(e); diff --git a/libnd4j/include/ops/declarable/generic/transforms/cumprod.cpp b/libnd4j/include/ops/declarable/generic/transforms/cumprod.cpp index 45060ad43..e7cd1ccb9 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/cumprod.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/cumprod.cpp @@ -104,34 +104,34 @@ namespace nd4j { } nd4j::ops::helpers::prefix(block.launchContext(), scalar::Multiply, input, output, dims, exclusive, reverse); - std::unique_ptr val(output->dup()); + NDArray val = NDArray(output->dup()); - gradOut->applyPairwiseTransform(pairwise::Multiply, output, val.get(), nullptr); - val->applyPairwiseTransform(pairwise::Divide, input, val.get(), nullptr); + gradOut->applyPairwiseTransform(pairwise::Multiply, *output, val); + val.applyPairwiseTransform(pairwise::Divide, *input, val); if (!exclusive && !reverse) { if (dims.size()) - nd4j::ops::helpers::prefix(block.launchContext(), scalar::Add, val.get(), output, dims, true, false); + nd4j::ops::helpers::prefix(block.launchContext(), scalar::Add, &val, output, dims, true, false); else - nd4j::ops::helpers::prefix(block.launchContext(), scalar::Add, val.get(), output, false, true); + nd4j::ops::helpers::prefix(block.launchContext(), scalar::Add, &val, output, false, true); } else if (!exclusive && reverse){ if (dims.size()) - nd4j::ops::helpers::prefix(block.launchContext(), scalar::Add, val.get(), output, dims, false, false); + nd4j::ops::helpers::prefix(block.launchContext(), scalar::Add, &val, output, dims, false, false); else - nd4j::ops::helpers::prefix(block.launchContext(), scalar::Add, val.get(), output, false, false); + nd4j::ops::helpers::prefix(block.launchContext(), scalar::Add, &val, output, false, false); } else if (exclusive && !reverse) { if (dims.size()) - nd4j::ops::helpers::prefix(block.launchContext(), scalar::Add, val.get(), output, dims, true, true); + nd4j::ops::helpers::prefix(block.launchContext(), scalar::Add, &val, output, dims, true, true); else - nd4j::ops::helpers::prefix(block.launchContext(), scalar::Add, val.get(), output, true, true); + nd4j::ops::helpers::prefix(block.launchContext(), scalar::Add, &val, output, true, true); } else { if (dims.size()) - nd4j::ops::helpers::prefix(block.launchContext(), scalar::Add, val.get(), output, dims, true, false); + nd4j::ops::helpers::prefix(block.launchContext(), scalar::Add, &val, output, dims, true, false); else - nd4j::ops::helpers::prefix(block.launchContext(), scalar::Add, val.get(), output, true, false); + nd4j::ops::helpers::prefix(block.launchContext(), scalar::Add, &val, output, true, false); } return Status::OK(); diff --git a/libnd4j/include/ops/declarable/generic/transforms/floor.cpp b/libnd4j/include/ops/declarable/generic/transforms/floor.cpp index f89494fd1..5a8559075 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/floor.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/floor.cpp @@ -29,7 +29,7 @@ namespace nd4j { auto first = INPUT_VARIABLE(0); auto z = OUTPUT_VARIABLE(0); - first->applyTransform(transform::Floor, z, nullptr); + first->applyTransform(transform::Floor, *z); STORE_RESULT(*z); diff --git a/libnd4j/include/ops/declarable/generic/transforms/hashcode.cpp b/libnd4j/include/ops/declarable/generic/transforms/hashcode.cpp index 644d15a58..123001dda 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/hashcode.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/hashcode.cpp @@ -27,7 +27,7 @@ namespace nd4j { namespace ops { - REDUCTION_OP_IMPL(hashcode, 1, 1, false, 0, 0) { + CUSTOM_OP_IMPL(hashcode, 1, 1, false, 0, 0) { REQUIRE_TRUE(block.width() == 1, 0, "hashcode: this op can't be applied along dimension"); auto input = INPUT_VARIABLE(0); @@ -40,6 +40,10 @@ namespace nd4j { return Status::OK(); }; + DECLARE_SHAPE_FN(hashcode) { + return SHAPELIST(ConstantShapeHelper::getInstance()->scalarShapeInfo(nd4j::DataType::INT64)); + } + DECLARE_TYPES(hashcode) { getOpDescriptor() diff --git a/libnd4j/include/ops/declarable/generic/transforms/layer_norm.cpp b/libnd4j/include/ops/declarable/generic/transforms/layer_norm.cpp index 06656b9de..7afe9b3ed 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/layer_norm.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/layer_norm.cpp @@ -56,7 +56,7 @@ namespace ops { standardizeOp.execute(inputs, outputs, targs, longAxis, bargs); // output->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Multiply(), gain, output); - output->applyBroadcast(nd4j::broadcast::Multiply, {dimC}, gain); + output->applyBroadcast(nd4j::broadcast::Multiply, {dimC}, *gain, *output); if(bias != nullptr) { // output->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Add(), bias, output); // output->applyBroadcast(nd4j::broadcast::Add, {dimC}, bias); @@ -93,8 +93,8 @@ namespace ops { if(bias != nullptr) { REQUIRE_TRUE(bias->rankOf() == 1 && bias->sizeAt(0) == input->sizeAt(dimC), 0, "LAYER_NORM_BP OP: wrong shape of bias array, expected is {%i}, but got %s instead !", input->sizeAt(dimC), ShapeUtils::shapeAsString(bias).c_str()); - // eps->reduceAlongDimension(nd4j::reduce::Sum, dLdb, {0}, true); - eps->reduceAlongDimension(nd4j::reduce::Sum, dLdb, ShapeUtils::evalDimsToExclude(input->rankOf(), {dimC})); + // eps->reduceAlongDimension(nd4j::reduce::Sum, *dLdb, {0}, true); + eps->reduceAlongDimension(nd4j::reduce::Sum, *dLdb, ShapeUtils::evalDimsToExclude(input->rankOf(), {dimC})); } NDArray standardized(input->shapeInfo(), false, block.launchContext()); @@ -106,18 +106,17 @@ namespace ops { std::vector bargs = {}; standardizeOp.execute(inputs, outputs, targs, longAxis, bargs); - standardized.applyPairwiseTransform(nd4j::pairwise::Multiply, eps, &standardized, nullptr); - standardized.reduceAlongDimension(nd4j::reduce::Sum, dLdg, ShapeUtils::evalDimsToExclude(input->rankOf(), {dimC})); + standardized.applyPairwiseTransform(nd4j::pairwise::Multiply, *eps, standardized); + standardized.reduceAlongDimension(nd4j::reduce::Sum, *dLdg, ShapeUtils::evalDimsToExclude(input->rankOf(), {dimC})); nd4j::ops::standardize_bp standardizeBp; // eps->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Multiply(), gain, dLdx); - eps->applyBroadcast(nd4j::broadcast::Multiply, {dimC}, gain, dLdx); + eps->applyBroadcast(nd4j::broadcast::Multiply, {dimC}, *gain, *dLdx); auto dLdx_tmp = dLdx->dup(); - std::vector standardizeBpArgs = {input, dLdx_tmp}; + std::vector standardizeBpArgs = {input, &dLdx_tmp}; std::vector standardizeBpOut = {dLdx}; standardizeBp.execute(standardizeBpArgs, standardizeBpOut, targs, longAxis, bargs); - delete dLdx_tmp; return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/transforms/log1p.cpp b/libnd4j/include/ops/declarable/generic/transforms/log1p.cpp index 3d45bcf42..ef9bdb925 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/log1p.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/log1p.cpp @@ -29,10 +29,10 @@ namespace nd4j { auto x = INPUT_VARIABLE(0); auto z = OUTPUT_VARIABLE(0); - x->applyTransform(transform::Log1p, z, nullptr); + x->applyTransform(transform::Log1p, *z); STORE_RESULT(z); - + return Status::OK(); } DECLARE_SYN(log1p, Log1p); diff --git a/libnd4j/include/ops/declarable/generic/transforms/mirrorPad.cpp b/libnd4j/include/ops/declarable/generic/transforms/mirrorPad.cpp index fac8451a5..2de8ee5a2 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/mirrorPad.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/mirrorPad.cpp @@ -56,7 +56,7 @@ CUSTOM_OP_IMPL(mirror_pad, 2, 1, false, 0, 1) { DECLARE_TYPES(mirror_pad) { getOpDescriptor()->setAllowedInputTypes(0, {ALL_FLOATS}); - getOpDescriptor()->setAllowedInputTypes(1, {DataType::INT32}); // to conform with TF + getOpDescriptor()->setAllowedInputTypes(1, {DataType::INT32, DataType::INT64}); // to conform with TF getOpDescriptor()->setAllowedOutputTypes(0, {ALL_FLOATS}); } diff --git a/libnd4j/include/ops/declarable/generic/transforms/pad.cpp b/libnd4j/include/ops/declarable/generic/transforms/pad.cpp index 9d410a6c3..c6c8c8ff8 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/pad.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/pad.cpp @@ -78,7 +78,7 @@ CUSTOM_OP_IMPL(pad, 2, 1, false, 0, 1) { DECLARE_TYPES(pad) { getOpDescriptor() ->setAllowedInputTypes(0, nd4j::DataType::ANY) - ->setAllowedInputTypes(1, {DataType::INT32}) // INT32 with TF + ->setAllowedInputTypes(1, {DataType::INT32, DataType::INT64}) // INT32 with TF // ->setAllowedInputTypes(1, {DataType::INT32, DataType::INT64}) // INT32 with TF, but used also INT64 due long shapes ->setSameMode(true); } diff --git a/libnd4j/include/ops/declarable/generic/transforms/standardize.cpp b/libnd4j/include/ops/declarable/generic/transforms/standardize.cpp index 67ef3aa24..25efc1a73 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/standardize.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/standardize.cpp @@ -29,28 +29,28 @@ namespace nd4j { namespace ops { CONFIGURABLE_OP_IMPL(standardize, 1, 1, true, 0, -2) { - + auto input = INPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0); - - std::vector axis; - if (block.width() > 1) + std::vector axis; + + if (block.width() > 1) axis = INPUT_VARIABLE(1)->template asVectorT(); - else if (block.numI() > 0) - axis = *block.getIArguments(); + else if (block.numI() > 0) + axis = *block.getIArguments(); REQUIRE_TRUE(!axis.empty(), 0, "STANDARDIZE OP: axis has to be non-empty") shape::checkDimensions(input->rankOf(), axis); - auto means = input->reduceAlongDims(reduce::Mean, axis, true); - auto stdev = input->varianceAlongDims(variance::SummaryStatsStandardDeviation, false, axis); + auto means = input->reduceAlongDimension(reduce::Mean, axis, true); + auto stdev = input->varianceAlongDimension(variance::SummaryStatsStandardDeviation, false, axis); stdev.reshapei(means.getShapeAsVector()); - input->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Subtract(), &means, output, false); - output->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Divide(), &stdev, output, false); - output->applyScalar(nd4j::scalar::ReplaceNans, 0, output, nullptr); + input->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Subtract(), means, *output, false); + output->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Divide(), stdev, *output, false); + output->applyScalar(nd4j::scalar::ReplaceNans, 0, *output); return Status::OK(); } @@ -69,9 +69,9 @@ namespace ops { auto output = OUTPUT_VARIABLE(0); std::vector axis; - if (block.width() == 3) + if (block.width() == 3) axis = INPUT_VARIABLE(1)->template asVectorT(); - else if (block.numI() > 0) + else if (block.numI() > 0) axis = *block.getIArguments(); REQUIRE_TRUE(!axis.empty(), 0, "STANDARDIZE OP: axis has to be non-empty") @@ -80,13 +80,13 @@ namespace ops { shape::checkDimensions(input->rankOf(), axis); auto longAxis = ArrayUtils::toLongVector(axis); - auto means = input->reduceAlongDims(reduce::Mean, axis, true); - auto stdev = input->varianceAlongDims(variance::SummaryStatsStandardDeviation, false, axis); + auto means = input->reduceAlongDimension(reduce::Mean, axis, true); + auto stdev = input->varianceAlongDimension(variance::SummaryStatsStandardDeviation, false, axis); stdev.reshapei(means.getShapeAsVector()); - eps->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Divide(), &stdev, output, false); + eps->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Divide(), stdev, *output, false); - auto dldu_sum = -output->reduceAlongDims(reduce::Sum, axis, true); + NDArray dldu_sum = -output->reduceAlongDimension(reduce::Sum, axis, true); NDArray dldx_u(input->shapeInfo(), false, block.launchContext()); std::vector meanBpArgs = {input, &dldu_sum}; @@ -100,12 +100,12 @@ namespace ops { // (eps * (means - input) / (stdev * stdev)) NDArray tmp(eps->shapeInfo(), false, block.launchContext()); - means.applyTrueBroadcast(nd4j::BroadcastOpsTuple::Subtract(), input, &tmp, false); - tmp.applyPairwiseTransform(nd4j::pairwise::Multiply, eps, &tmp, nullptr); - stdev.applyPairwiseTransform(nd4j::pairwise::Multiply, &stdev, &stdev, nullptr); - tmp.applyTrueBroadcast(nd4j::BroadcastOpsTuple::Divide(), &stdev, &tmp, false); + means.applyTrueBroadcast(nd4j::BroadcastOpsTuple::Subtract(), *input, tmp, false); + tmp.applyPairwiseTransform(nd4j::pairwise::Multiply, *eps, tmp); + stdev.applyPairwiseTransform(nd4j::pairwise::Multiply, stdev, stdev); + tmp.applyTrueBroadcast(nd4j::BroadcastOpsTuple::Divide(), stdev, tmp, false); - auto dlds_sum = tmp.reduceAlongDims(reduce::Sum, axis, true); + auto dlds_sum = tmp.reduceAlongDimension(reduce::Sum, axis, true); NDArray dldx_s(input->shapeInfo(), false, block.launchContext()); std::vector stdevBpArgs = {input, &dlds_sum}; std::vector stdevBpOutput = {&dldx_s}; @@ -115,7 +115,7 @@ namespace ops { stdevBp.execute(stdevBpArgs, stdevBpOutput, stdevBpTArgs, longAxis, stdevBpBArgs); *output += dldx_s; - output->applyScalar(nd4j::scalar::ReplaceNans, 0, output, nullptr); + output->applyScalar(nd4j::scalar::ReplaceNans, 0, *output); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/transforms/tri.cpp b/libnd4j/include/ops/declarable/generic/transforms/tri.cpp index 727d42ba5..a6106f197 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/tri.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/tri.cpp @@ -32,8 +32,8 @@ CUSTOM_OP_IMPL(tri, -2, 1, false, 0, 1) { const int diag = block.numI() > 2 ? INT_ARG(2) : 0; - BUILD_SINGLE_SELECTOR(output->dataType(), output->fillAsTriangular, (1., diag + 1, 0, 'l'), LIBND4J_TYPES); // fill with unities lower triangular block of matrix - BUILD_SINGLE_SELECTOR(output->dataType(), output->fillAsTriangular, (0., 0, diag, 'u'), LIBND4J_TYPES); // fill with zeros upper triangular block of matrix + BUILD_SINGLE_SELECTOR(output->dataType(), output->fillAsTriangular, (1., diag + 1, 0, *output, 'l'), LIBND4J_TYPES); // fill with unities lower triangular block of matrix + BUILD_SINGLE_SELECTOR(output->dataType(), output->fillAsTriangular, (0., 0, diag, *output, 'u'), LIBND4J_TYPES); // fill with zeros upper triangular block of matrix // output->setValueInDiagMatrix(1., diag, 'l'); // output->setValueInDiagMatrix(0., diag+1, 'u'); diff --git a/libnd4j/include/ops/declarable/generic/transforms/triu.cpp b/libnd4j/include/ops/declarable/generic/transforms/triu.cpp index 1c4214e9b..b382cbfb1 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/triu.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/triu.cpp @@ -36,7 +36,7 @@ CUSTOM_OP_IMPL(triu, 1, 1, false, 0, 0) { const int diag = block.getIArguments()->size() > 0 ? INT_ARG(0) : 0; - BUILD_SINGLE_SELECTOR(input->dataType(), input->fillAsTriangular, (0, diag, 0, 'l', output), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(input->dataType(), input->fillAsTriangular, (0, diag, 0, *output, 'l' ), LIBND4J_TYPES); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/util/print_affinity.cpp b/libnd4j/include/ops/declarable/generic/util/print_affinity.cpp new file mode 100644 index 000000000..6b1514ab9 --- /dev/null +++ b/libnd4j/include/ops/declarable/generic/util/print_affinity.cpp @@ -0,0 +1,52 @@ +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include +#if NOT_EXCLUDED(OP_print_affinity) + +#include +#include + +namespace nd4j { + namespace ops { + CUSTOM_OP_IMPL(print_affinity, 1, 1, true, 0, 0) { + // TODO: make this op compatible with ArrayList etc + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + nd4j_printf(": Actuality: [HOST: %s; DEVICE: %s]; affinity: [%i]; Pointers: [HOST: %p; DEVICE: %p]; DataBuffer length: %lld\n", block.nodeId(), input->isActualOnHostSide() ? "true" : "false", input->isActualOnDeviceSide() ? "true" : "false", input->dataBuffer()->deviceId(), input->getBuffer(), input->getSpecialBuffer(), input->dataBuffer()->getLenInBytes()); + + return Status::OK(); + } + + DECLARE_TYPES(print_affinity) { + getOpDescriptor() + ->setAllowedInputTypes(0, nd4j::DataType::ANY) + ->setAllowedInputTypes(1, {ALL_STRINGS}) + ->setAllowedOutputTypes(0, nd4j::DataType::INT32); + } + + DECLARE_SHAPE_FN(print_affinity) { + return SHAPELIST(ConstantShapeHelper::getInstance()->scalarShapeInfo(DataType::INT32)); + } + } +} + +#endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/util/print_variable.cpp b/libnd4j/include/ops/declarable/generic/util/print_variable.cpp new file mode 100644 index 000000000..6828b2f90 --- /dev/null +++ b/libnd4j/include/ops/declarable/generic/util/print_variable.cpp @@ -0,0 +1,77 @@ +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// +#include +#if NOT_EXCLUDED(OP_print_variable) + +#include +#include + +namespace nd4j { + namespace ops { + CUSTOM_OP_IMPL(print_variable, 1, 1, true, 0, 0) { + // TODO: make this op compatible with ArrayList etc + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + std::string str; + + if (block.width() == 2) { + auto message = INPUT_VARIABLE(1); + REQUIRE_TRUE(message->isS(), 0, "print_variable: message variable must be a String"); + + str = message->e(0); + } + + bool printSpecial = false; + if (block.numB() > 0) + printSpecial = B_ARG(0); + + if (printSpecial && !nd4j::Environment::getInstance()->isCPU()) { + // only specific backends support special printout. for cpu-based backends it's the same as regular print + + if (block.width() == 2) + helpers::print_special(*block.launchContext(), *input, str); + else + helpers::print_special(*block.launchContext(), *input); + } else { + // optionally add message to the print out + if (block.width() == 2) { + input->printIndexedBuffer(str.c_str()); + } else { + input->printIndexedBuffer(); + } + } + + return Status::OK(); + } + + DECLARE_TYPES(print_variable) { + getOpDescriptor() + ->setAllowedInputTypes(0, nd4j::DataType::ANY) + ->setAllowedInputTypes(1, {ALL_STRINGS}) + ->setAllowedOutputTypes(0, nd4j::DataType::INT32); + } + + DECLARE_SHAPE_FN(print_variable) { + return SHAPELIST(ConstantShapeHelper::getInstance()->scalarShapeInfo(DataType::INT32)); + } + } +} + +#endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/headers/broadcastable.h b/libnd4j/include/ops/declarable/headers/broadcastable.h index 7ee53b52a..9a2dc9f62 100644 --- a/libnd4j/include/ops/declarable/headers/broadcastable.h +++ b/libnd4j/include/ops/declarable/headers/broadcastable.h @@ -356,6 +356,7 @@ namespace nd4j { */ #if NOT_EXCLUDED(OP_Pow) DECLARE_BROADCASTABLE_OP(Pow, 0, 0); + DECLARE_CUSTOM_OP(Pow_bp, 3, 2, false, 0, 0); #endif /** diff --git a/libnd4j/include/ops/declarable/headers/compat.h b/libnd4j/include/ops/declarable/headers/compat.h new file mode 100644 index 000000000..8ce73153e --- /dev/null +++ b/libnd4j/include/ops/declarable/headers/compat.h @@ -0,0 +1,54 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#ifndef SAMEDIFF_COMPAT_H +#define SAMEDIFF_COMPAT_H + +#include + +namespace nd4j { + namespace ops { + /** + * This operation splits input string into pieces separated by delimiter + * PLEASE NOTE: This implementation is compatible with TF 1.x + * + * Input[0] - string to split + * Input[1] - delimiter + * + * Returns: + * Output[0] - indices tensor + * Output[1] - values tensor + */ + #if NOT_EXCLUDED(OP_compat_string_split) + DECLARE_CUSTOM_OP(compat_string_split, 2, 2, false, 0, 0); + #endif + + /** + * This operation converts TF sparse array representation to dense NDArray + */ + #if NOT_EXCLUDED(OP_compat_sparse_to_dense) + DECLARE_CUSTOM_OP(compat_sparse_to_dense, 4, 1, false, 0, 0); + #endif + + } +} + + +#endif //SAMEDIFF_COMPAT_H diff --git a/libnd4j/include/ops/declarable/headers/images.h b/libnd4j/include/ops/declarable/headers/images.h new file mode 100644 index 000000000..14acd1877 --- /dev/null +++ b/libnd4j/include/ops/declarable/headers/images.h @@ -0,0 +1,115 @@ +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Oleh Semeniv (oleg.semeniv@gmail.com) +// +// +// @author AbdelRauf (rauf@konduit.ai) +// + +#ifndef LIBND4J_HEADERS_IMAGES_H +#define LIBND4J_HEADERS_IMAGES_H + +#include +#include +#include +#include +#include + +namespace nd4j { +namespace ops { + + +/** + * Rgb To Hsv + * Input arrays: + * 0 - input array with rank >= 1, must have at least one dimension equal 3, that is dimension containing channels. + * Int arguments: + * 0 - optional argument, corresponds to dimension with 3 channels + */ +#if NOT_EXCLUDED(OP_rgb_to_hsv) + DECLARE_CONFIGURABLE_OP(rgb_to_hsv, 1, 1, true, 0, 0); +#endif + +/** + * Hsv To Rgb + * Input arrays: + * 0 - input array with rank >= 1, must have at least one dimension equal 3, that is dimension containing channels. + * Int arguments: + * 0 - optional argument, corresponds to dimension with 3 channels + */ +#if NOT_EXCLUDED(OP_hsv_to_rgb) + DECLARE_CONFIGURABLE_OP(hsv_to_rgb, 1, 1, true, 0, 0); +#endif + +/** +* Rgb To GrayScale +* Input arrays: +* 0 - input array with rank >= 1, the RGB tensor to convert. Last dimension must have size 3 and should contain RGB values. +*/ +#if NOT_EXCLUDED(OP_rgb_to_grs) + DECLARE_CUSTOM_OP(rgb_to_grs, 1, 1, false, 0, 0); +#endif + + /** + * Rgb To Yuv + * Input arrays: + * 0 - input array with rank >= 1, must have at least one dimension equal 3, that is dimension containing channels. + * Int arguments: + * 0 - optional argument, corresponds to dimension with 3 channels + */ +#if NOT_EXCLUDED(OP_rgb_to_yuv) + DECLARE_CONFIGURABLE_OP(rgb_to_yuv, 1, 1, true, 0, 0); +#endif + + /** + * Yuv To Rgb + * Input arrays: + * 0 - input array with rank >= 1, must have at least one dimension equal 3, that is dimension containing channels. + * Int arguments: + * 0 - optional argument, corresponds to dimension with 3 channels + */ +#if NOT_EXCLUDED(OP_rgb_to_yuv) + DECLARE_CONFIGURABLE_OP(yuv_to_rgb, 1, 1, true, 0, 0); + +/** +* Rgb To Yiq +* Input arrays: +* 0 - input array with rank >= 1, must have at least one dimension equal 3, that is dimension containing channels. +* Int arguments: +* 0 - optional argument, corresponds to dimension with 3 channels +*/ +#if NOT_EXCLUDED(OP_rgb_to_yiq) + DECLARE_CONFIGURABLE_OP(rgb_to_yiq, 1, 1, true, 0, 0); +#endif + +/** +* Yiq To Rgb +* Input arrays: +* 0 - input array with rank >= 1, must have at least one dimension equal 3, that is dimension containing channels. +* Int arguments: +* 0 - optional argument, corresponds to dimension with 3 channels +*/ +#if NOT_EXCLUDED(OP_yiq_to_rgb) + DECLARE_CONFIGURABLE_OP(yiq_to_rgb, 1, 1, true, 0, 0); +#endif + +} +} + +#endif +#endif diff --git a/libnd4j/include/ops/declarable/headers/parity_ops.h b/libnd4j/include/ops/declarable/headers/parity_ops.h index e56ba9d6e..791027baa 100644 --- a/libnd4j/include/ops/declarable/headers/parity_ops.h +++ b/libnd4j/include/ops/declarable/headers/parity_ops.h @@ -1,6 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. - * Copyright (c) 2019 Konduit K.K. + * Copyright (c) 2019-2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -162,8 +162,24 @@ namespace nd4j { * Input : batched tensor with rank >=2 * Output: tensor with rank lesser by 1 from input */ + #if NOT_EXCLUDED(OP_matrix_diag_part) DECLARE_CUSTOM_OP(matrix_diag_part, 1, 1, false, 0, 0); + #endif + /** + * QR decomposition: A = QR, where Q is ortogonal (Q * QT = I) and R is upper triangular. + * For A (MxN) Q is M x M and R is (NxN). + * + * Input : + * 0 - float (or complex float) tensor with shape {.,..,...,M,N} - batch of float matricies + * + * Output: + * 0 - float tensor with shape {.,..,...,MxN} - batch of ortogonal matricies {Qs} + * 1 - float tensor with shape {.,..,...,NxN} - batch of upper triangular matricies {Rs} + */ + #if NOT_EXCLUDED(OP_qr) + DECLARE_CUSTOM_OP(qr, 1, 2, false, 0, 0); + #endif /** * This operation takes 2 arrays: original values, and values to be excluded. And returns 2 arrays: values left after exclusion, and indices in original array for surivals. @@ -527,6 +543,20 @@ namespace nd4j { DECLARE_CONFIGURABLE_OP(polygamma, 2, 1, false, 0, 0); #endif + /** + * This op calculates lgamma function lgamma(x) = log(Gamma(x)) + * + * Input arrays: + * 0: x - input matrix + * + * Output array: + * 0: log of Gamma(x) + * + */ + #if NOT_EXCLUDED(OP_lgamma) + DECLARE_OP(lgamma, 1, 1, true); + #endif + /** * This op calculates digamma function psi(x) = derivative of log(Gamma(x)) * @@ -1027,6 +1057,43 @@ namespace nd4j { DECLARE_OP(matrix_inverse, 1, 1, true); #endif + /** + * triangular_solve op. - reverse Gaussian method for solve systems of linear equations. + * + * input params: + * 0 - the tensor with dimension (x * y * z * ::: * M * M) - left parts of equations + * 1 - the tensor with dimension (x * y * z * ::: * M * K) - right parts of equations + * + * boolean args: + * 0 - lower - default is true (optional) - left part is lower triangular matrix + * 1 - adjoint - default is false (optional) - indicate input matrix or its adjoint (hermitian addition) should be used + * + * return value: + * tensor with dimension (x * y * z * ::: * M * K) with solutions + * + */ + #if NOT_EXCLUDED(OP_triangular_solve) + DECLARE_CUSTOM_OP(triangular_solve, 2, 1, true, 0, 0); + #endif + + /** + * lu op. - make LUP decomposition of given batch of 2D square matricies + * + * input params: + * 0 - float tensor with dimension (x * y * z * ::: * M * M) + * + * return value: + * 0 - float tensor with dimension (x * y * z * ::: * M * M) with LU M x M matricies in it + * 1 - int (32 or 64) batched vector of permutations with length M - shape (x * y * z * ::: * M) + * + * int argument: + * 0 - data type of output permutaion vector (int32 or int64), optional, default INT32 + */ + + #if NOT_EXCLUDED(OP_matrix_inverse) + DECLARE_CUSTOM_OP(lu, 1, 2, false, 0, 0); + #endif + /** * sequence_mask op. - make mask for given tensor filled by (j > x[i_1, i_2,...,i_n]) -> z[i_1, i_2,...,i_n,j] * @@ -1690,6 +1757,27 @@ namespace nd4j { DECLARE_CUSTOM_OP(resize_bicubic, 1, 1, false, 0, -2); #endif + /** + * This op make area interpolated resize (as OpenCV INTER_AREA algorithm) for given tensor + * + * input array: + * 0 - images - 4D-Tensor with shape (batch, sizeX, sizeY, channels) + * 1 - size - 1D-Tensor with 2 values (newWidth, newHeight) (if missing a pair of integer args should be provided). + * + * int args: - proveded only when size tensor is missing + * 0 - new height + * 1 - new width + * boolean args: + * 0 - align_corners - optional (default is false) + * + * output array: + * the 4D-Tensor with resized image (shape is {batch, newWidth, newHeight, channels}) + * + */ + #if NOT_EXCLUDED(OP_resize_area) + DECLARE_CUSTOM_OP(resize_area, 1, 1, false, 0, -2); + #endif + /** * This op make interpolated resize for given tensor with given algorithm. * Supported algorithms are bilinear, bicubic, nearest_neighbor. diff --git a/libnd4j/include/ops/declarable/headers/random.h b/libnd4j/include/ops/declarable/headers/random.h index a361c8fde..f52534411 100644 --- a/libnd4j/include/ops/declarable/headers/random.h +++ b/libnd4j/include/ops/declarable/headers/random.h @@ -49,6 +49,22 @@ namespace nd4j { #if NOT_EXCLUDED(OP_randomuniform) DECLARE_CUSTOM_OP(randomuniform, 1, 1, false, 0, 0); #endif + /* + * multinomial (categorical) random generator draws samples from a multinomial distribution + * + * Input array: + * 0 - 2D ndarray with unnormalized log-probabilities with shape [batch_size (N), num_classes (K)] + * 1 - array with one int value of samples number, number of independent samples to draw for each experiment 1,N. + * Int arguments: + * 0 - optional argument, corresponds to dimension with batch_size + * 1 - optional argument, integer type to use for the output. Default int64. + * + * Output array: + * 0 - 2D ndarray with the drawn samples of shape [batch_size, num_samples] + */ + #if NOT_EXCLUDED(OP_random_multinomial) + DECLARE_CUSTOM_OP(random_multinomial, 2, 1, false, 0, 0); + #endif #if NOT_EXCLUDED(OP_random_normal) DECLARE_CUSTOM_OP(random_normal, 1, 1, true, 2, 0); diff --git a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/src/main/java/org/deeplearning4j/aws/s3/reader/BucketKeyListener.java b/libnd4j/include/ops/declarable/headers/strings.h old mode 100755 new mode 100644 similarity index 60% rename from deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/src/main/java/org/deeplearning4j/aws/s3/reader/BucketKeyListener.java rename to libnd4j/include/ops/declarable/headers/strings.h index 95c356c78..0849f118a --- a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/src/main/java/org/deeplearning4j/aws/s3/reader/BucketKeyListener.java +++ b/libnd4j/include/ops/declarable/headers/strings.h @@ -14,25 +14,29 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -package org.deeplearning4j.aws.s3.reader; +// +// @author raver119@gmail.com +// -import com.amazonaws.services.s3.AmazonS3; +#ifndef SAMEDIFF_STRINGS_H +#define SAMEDIFF_STRINGS_H -/** - * When paginating through a result applyTransformToDestination, - * allows the user to react to a bucket result being found - * @author Adam Gibson - * - */ -public interface BucketKeyListener { - - /** - * - * @param s3 an s3 client - * @param bucket the bucket being iterated on - * @param key the current key - */ - void onKey(AmazonS3 s3, String bucket, String key); +#include +namespace nd4j { + namespace ops { + /** + * This operation splits input string into pieces separated by delimiter + * + * Input[0] - string to split + * Input[1] - delimiter + */ + #if NOT_EXCLUDED(OP_split_string) + DECLARE_CUSTOM_OP(split_string, 2, 1, true, 0, 0); + #endif + } } + + +#endif //SAMEDIFF_STRINGS_H diff --git a/libnd4j/include/ops/declarable/headers/transforms.h b/libnd4j/include/ops/declarable/headers/transforms.h index 6c82aa19f..ab4e962a3 100644 --- a/libnd4j/include/ops/declarable/headers/transforms.h +++ b/libnd4j/include/ops/declarable/headers/transforms.h @@ -213,7 +213,7 @@ namespace nd4j { * This operation calculates hash code, optionally along dimension */ #if NOT_EXCLUDED(OP_hashcode) - DECLARE_REDUCTION_OP(hashcode, 1, 1, false, 0, 0); + DECLARE_CUSTOM_OP(hashcode, 1, 1, false, 0, 0); #endif /** diff --git a/datavec/datavec-camel/src/test/java/org/datavec/camel/component/ListStringInputMarshaller.java b/libnd4j/include/ops/declarable/headers/util.h similarity index 55% rename from datavec/datavec-camel/src/test/java/org/datavec/camel/component/ListStringInputMarshaller.java rename to libnd4j/include/ops/declarable/headers/util.h index 77b376d90..aa1f52363 100644 --- a/datavec/datavec-camel/src/test/java/org/datavec/camel/component/ListStringInputMarshaller.java +++ b/libnd4j/include/ops/declarable/headers/util.h @@ -14,28 +14,31 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -package org.datavec.camel.component; +// +// @author raver119@gmail.com +// -import org.apache.camel.Exchange; -import org.datavec.api.split.InputSplit; -import org.datavec.api.split.ListStringSplit; +#ifndef LIBND4J_UTILS_H +#define LIBND4J_UTILS_H -import java.util.List; +#include -/** - * Marshals List> - * - * @author Adam Gibson - */ -public class ListStringInputMarshaller implements DataVecMarshaller { - /** - * @param exchange - * @return - */ - @Override - public InputSplit getSplit(Exchange exchange) { - List> data = (List>) exchange.getIn().getBody(); - InputSplit listSplit = new ListStringSplit(data); - return listSplit; +namespace nd4j { + namespace ops { + /** + * This operation prints out NDArray content, either on host or device. + */ + #if NOT_EXCLUDED(OP_print_variable) + DECLARE_CUSTOM_OP(print_variable, 1, 1, true, 0, 0); + #endif + + /** + * This operation prints out affinity & locality status of given NDArray + */ + #if NOT_EXCLUDED(OP_print_affinity) + DECLARE_CUSTOM_OP(print_affinity, 1, 1, true, 0, 0); + #endif } } + +#endif //LIBND4J_UTILS_H diff --git a/libnd4j/include/ops/declarable/helpers/adjust_hue.h b/libnd4j/include/ops/declarable/helpers/adjust_hue.h index 3ccdfdd60..afa7b2436 100644 --- a/libnd4j/include/ops/declarable/helpers/adjust_hue.h +++ b/libnd4j/include/ops/declarable/helpers/adjust_hue.h @@ -17,6 +17,7 @@ // // @author raver119@gmail.com // @author Yurii Shyrma (iuriish@yahoo.com) +// @author Oleh Semeniv (oleg.semeniv@gmail.com) // #include @@ -41,33 +42,33 @@ FORCEINLINE _CUDA_HD void rgbToHsv(const T& r, const T& g, const T& b, T& h, T& const T max = nd4j::math::nd4j_max(r, nd4j::math::nd4j_max(g, b)); const T min = nd4j::math::nd4j_min(r, nd4j::math::nd4j_min(g, b)); const T c = max - min; - + const T _p6 = (T)1 / (T)6; // calculate h if(c == 0) { h = 0; } else if(max == r) { - h = 60.f * ((g - b) / c) + (g >= b ? 0 : 360); + h = _p6 * ((g - b) / c) + (g >= b ? (T)0 : (T)1); } else if(max == g) { - h = 60.f * ((b - r) / c) + 120; + h = _p6 * ((b - r) / c + (T)2); } else { // max == b - h = 60.f * ((r - g) / c) + 240; + h = _p6 * ((r - g) / c + (T)4); } // calculate s s = max == (T)0 ? (T)0 : c / max; // calculate v - v = max / 255.f; + v = max;// / 255.f; } //////////////////////////////////////////////////////////////////////////////// template FORCEINLINE _CUDA_HD void hsvToRgb(const T& h, const T& s, const T& v, T& r, T& g, T& b) { - const float sector = h / 60.f; + const float sector = h * 6.f; const T c = v * s; if(0.f <= sector && sector < 1.f) { @@ -101,9 +102,25 @@ FORCEINLINE _CUDA_HD void hsvToRgb(const T& h, const T& s, const T& v, T& r, T& b = v - c * (sector - 5); } - r *= 255; - g *= 255; - b *= 255; +// r *= 255; +// g *= 255; +// b *= 255; +} + +//////////////////////////////////////////////////////////////////////////////// +template +FORCEINLINE _CUDA_HD void rgbYuv(const T& r, const T& g, const T& b, T& y, T& u, T& v) { + y = static_cast(0.299) * r + static_cast(0.587) *g + static_cast(0.114) * b; + u = -static_cast(0.14714119) * r - static_cast(0.2888691) * g + static_cast(0.43601035) * b; + v = static_cast(0.61497538) * r - static_cast(0.51496512) * g - static_cast(0.10001026) * b; +} + +//////////////////////////////////////////////////////////////////////////////// +template +FORCEINLINE _CUDA_HD void yuvRgb(const T& y, const T& u, const T& v, T& r, T& g, T& b) { + r = y + static_cast(1.13988303) * v; + g = y - static_cast(0.394642334) * u - static_cast(0.58062185) * v; + b = y + static_cast(2.03206185) * u; } /*//////////////////////////////////////////////////////////////////////////////// diff --git a/libnd4j/include/ops/declarable/helpers/convolutions.h b/libnd4j/include/ops/declarable/helpers/convolutions.h index 68b39cfd5..e8bf735bc 100644 --- a/libnd4j/include/ops/declarable/helpers/convolutions.h +++ b/libnd4j/include/ops/declarable/helpers/convolutions.h @@ -41,8 +41,10 @@ namespace nd4j { static inline void calcOutSizePool2D(int& oH, int& oW, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int iH, const int iW, const int paddingMode) { if(paddingMode == 0) { // valid - oH = (iH - (kH + (kH-1)*(dH-1)) + 2*pH)/sH + 1; - oW = (iW - (kW + (kW-1)*(dW-1)) + 2*pW)/sW + 1; + // oH = (iH - (kH + (kH-1)*(dH-1)) + 2*pH)/sH + 1; + // oW = (iW - (kW + (kW-1)*(dW-1)) + 2*pW)/sW + 1; + oH = (iH - ((kH - 1) * dH + 1) + 2 * pH) / sH + 1; + oW = (iW - ((kW - 1) * dW + 1) + 2 * pW) / sW + 1; } else if (paddingMode == 1) { // same oH = (int) math::nd4j_ceil(iH * 1. / sH); @@ -57,9 +59,9 @@ namespace nd4j { static inline void calcOutSizePool3D(int& oD, int& oH, int& oW, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int iD, const int iH, const int iW, const int paddingMode) { if(paddingMode == 0) { // valid - oD = (iD - (kD + (kD - 1) * (dD - 1)) + 2 * pD) / sD + 1; - oH = (iH - (kH + (kH - 1) * (dH - 1)) + 2 * pH) / sH + 1; - oW = (iW - (kW + (kW - 1) * (dW - 1)) + 2 * pW) / sW + 1; + oD = (iD - ((kD - 1) * dD + 1) + 2 * pD) / sD + 1; + oH = (iH - ((kH - 1) * dH + 1) + 2 * pH) / sH + 1; + oW = (iW - ((kW - 1) * dW + 1) + 2 * pW) / sW + 1; } else if(paddingMode == 1) { // same oD = (int) nd4j::math::nd4j_ceil(iD * 1. / sD); diff --git a/libnd4j/include/ops/declarable/helpers/cpu/BarnesHutTsne.cpp b/libnd4j/include/ops/declarable/helpers/cpu/BarnesHutTsne.cpp index a7123d42f..f8704d7b0 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/BarnesHutTsne.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/BarnesHutTsne.cpp @@ -195,7 +195,7 @@ namespace helpers { return res; }; - input->applyTriplewiseLambda(gradX, epsilon, gainsInternal, output); + input->applyTriplewiseLambda(*gradX, *epsilon, gainsInternal, *output); } void barnes_gains(NDArray* input, NDArray* gradX, NDArray* epsilon, NDArray* output) { diff --git a/libnd4j/include/ops/declarable/helpers/cpu/activations.cpp b/libnd4j/include/ops/declarable/helpers/cpu/activations.cpp index ba0f36eb5..9a11baf37 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/activations.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/activations.cpp @@ -115,9 +115,9 @@ static void softMaxForVector_(void *input, Nd4jLong *inShapeInfo, void *output, BUILD_SINGLE_SELECTOR(input.dataType(), _softMaxDerivForVector, (context, input.getBuffer(), input.getShapeInfo(), output.buffer()), FLOAT_TYPES); } else { - auto maxAlongDim = const_cast(input).reduceAlongDims(reduce::Max, {dimension}, true); - (input - maxAlongDim).applyTransform(transform::Exp, &output); // output contains exponents temporarily - auto sumAlongDim = output.reduceAlongDims(reduce::Sum, {dimension}, true); + auto maxAlongDim = const_cast(input).reduceAlongDimension(reduce::Max, {dimension}, true); + (input - maxAlongDim).applyTransform(transform::Exp, output); // output contains exponents temporarily + auto sumAlongDim = output.reduceAlongDimension(reduce::Sum, {dimension}, true); output /= sumAlongDim; output *= (1.f - output); // derivative } @@ -204,7 +204,7 @@ static void softmax_(nd4j::LaunchContext * context, const NDArray& input, NDArra else output = 1.; } - else if(input.isSameShapeStrict(&output)) { + else if(input.isSameShapeStrict(output)) { TadPack tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input.getShapeInfo(), dimension); Nd4jLong* tadShapeInfo = tadPack.primaryShapeInfo(); @@ -275,10 +275,10 @@ static void softmax_(nd4j::LaunchContext * context, const NDArray& input, NDArra } } else { - NDArray max = input.reduceAlongDims(nd4j::reduce::Max, {dimension}, true); - input.applyTrueBroadcast(nd4j::BroadcastOpsTuple::Subtract(), &max, &output, false); - output.applyTransform(nd4j::transform::Exp); - NDArray sum = output.reduceAlongDims(nd4j::reduce::Sum, {dimension}, true); + NDArray max = input.reduceAlongDimension(nd4j::reduce::Max, {dimension}, true); + input.applyTrueBroadcast(nd4j::BroadcastOpsTuple::Subtract(), max, output, false); + output.applyTransform(nd4j::transform::Exp, output); + NDArray sum = output.reduceAlongDimension(nd4j::reduce::Sum, {dimension}, true); output /= sum; } } @@ -347,7 +347,7 @@ void preluBP(nd4j::LaunchContext * context, const NDArray& input, const NDArray& auto routine = LAMBDA_T(_x, threshold) { return _x > (T)threshold? _x: (T)0.f; }; - const_cast(input).applyLambda(routine, &output); + const_cast(input).applyLambda(routine, output); } void thresholdRelu(nd4j::LaunchContext * context, NDArray const& input, double threshold, NDArray& output) { @@ -358,7 +358,7 @@ void preluBP(nd4j::LaunchContext * context, const NDArray& input, const NDArray& static void thresholdReluDerivative_(nd4j::LaunchContext * context, NDArray* input, double theta, NDArray* dLdO, NDArray* output) { auto derivative = LAMBDA_TT(_x, grO, theta) {if (_x > theta) return grO; else return static_cast(0); }; - input->applyPairwiseLambda(dLdO, derivative, output); + input->applyPairwiseLambda(*dLdO, derivative, *output); } @@ -381,11 +381,11 @@ void preluBP(nd4j::LaunchContext * context, const NDArray& input, const NDArray& } else { - auto maxAlongDim = const_cast(input).reduceAlongDims(reduce::Max, {dimension}, true); - (input - maxAlongDim).applyTransform(transform::Exp, &output); // output contains exponents temporarily - auto sumAlongDim = output.reduceAlongDims(reduce::Sum, {dimension}, true); + auto maxAlongDim = const_cast(input).reduceAlongDimension(reduce::Max, {dimension}, true); + (input - maxAlongDim).applyTransform(transform::Exp, output); // output contains exponents temporarily + auto sumAlongDim = output.reduceAlongDimension(reduce::Sum, {dimension}, true); output /= sumAlongDim; - output.applyTransform(transform::Log); + output.applyTransform(transform::Log, output); } } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/addBias.cpp b/libnd4j/include/ops/declarable/helpers/cpu/addBias.cpp index a64864b1b..a910a854c 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/addBias.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/addBias.cpp @@ -83,15 +83,28 @@ static void addBias_(const NDArray& input, const NDArray& bias, NDArray &output, const Nd4jLong xStrideH = isNCHW ? input.stridesOf()[2] : input.stridesOf()[1]; const Nd4jLong xStrideW = isNCHW ? input.stridesOf()[3] : input.stridesOf()[2]; - auto func = PRAGMA_THREADS_FOR_3D { - for (uint b = start_x; b < stop_x; b += inc_x) - for (uint c = start_y; c < stop_y; c += inc_y) - for (uint h = start_z; h < stop_z; h += inc_z) - for (uint w = 0; w < oW; ++w) - z[b * zStrideB + c * zStrideC + h * zStrideH + w * zStrideW] = x[b * xStrideB + c * xStrideC + h * xStrideH + w * xStrideW] + static_cast(y[c * yStrideC]); - }; + if (isNCHW) { - samediff::Threads::parallel_for(func, 0, bS, 1, 0, C, 1, 0, oH, 1); + auto func = PRAGMA_THREADS_FOR_3D { + for (uint b = start_x; b < stop_x; b += inc_x) + for (uint c = start_y; c < stop_y; c += inc_y) + for (uint h = start_z; h < stop_z; h += inc_z) + for (uint w = 0; w < oW; ++w) + z[b * zStrideB + c * zStrideC + h * zStrideH + w * zStrideW] = x[b * xStrideB + c * xStrideC + h * xStrideH + w * xStrideW] + static_cast(y[c * yStrideC]); + }; + + samediff::Threads::parallel_for(func, 0, bS, 1, 0, C, 1, 0, oH, 1); + } else { + auto func = PRAGMA_THREADS_FOR_3D { + for (uint b = start_x; b < stop_x; b++) + for (uint h = start_y; h < stop_y; h++) + for (uint w = start_z; w < stop_z; w++) + for (uint c = 0; c < C; c++) + z[b * zStrideB + c * zStrideC + h * zStrideH + w * zStrideW] = x[b * xStrideB + c * xStrideC + h * xStrideH + w * xStrideW] + y[c * yStrideC]; + }; + + samediff::Threads::parallel_for(func, 0, bS, 1, 0, oH, 1, 0, oW, 1); + } } } else if(output.rankOf() == 5) { @@ -141,7 +154,7 @@ static void addBias_(const NDArray& input, const NDArray& bias, NDArray &output, } else { const int channelDim = isNCHW ? 1 : input.rankOf() - 1; // second or last - const_cast(input).applyBroadcast(nd4j::broadcast::Add, {channelDim}, &bias, &output); + const_cast(input).applyBroadcast(nd4j::broadcast::Add, {channelDim}, bias, output); } } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/adjust_hue.cpp b/libnd4j/include/ops/declarable/helpers/cpu/adjust_hue.cpp index ae76f0289..978c037fa 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/adjust_hue.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/adjust_hue.cpp @@ -45,11 +45,11 @@ static void adjustHue_(const NDArray *input, const NDArray* deltaScalarArr, NDAr rgbToHsv(x[i], x[i + 1], x[i + 2], h, s, v); - h += delta * 360; - if (h > 360) - h -= 360; + h += delta ; + if (h > (T)1) + h -= (T)1; else if (h < 0) - h += 360; + h += (T)1; hsvToRgb(h, s, v, z[i], z[i + 1], z[i + 2]); } @@ -76,11 +76,11 @@ static void adjustHue_(const NDArray *input, const NDArray* deltaScalarArr, NDAr rgbToHsv(xTad[0], xTad[xDimCstride], xTad[2 * xDimCstride], h, s, v); - h += delta * 360; - if (h > 360) - h -= 360; + h += delta ; + if (h > (T)1) + h -= (T)1; else if (h < 0) - h += 360; + h += (T)1; hsvToRgb(h, s, v, zTad[0], zTad[zDimCstride], zTad[2 * zDimCstride]); diff --git a/libnd4j/include/ops/declarable/helpers/cpu/batchnorm.cpp b/libnd4j/include/ops/declarable/helpers/cpu/batchnorm.cpp index 7a0d8b97b..ad2e29a97 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/batchnorm.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/batchnorm.cpp @@ -15,7 +15,7 @@ ******************************************************************************/ // -// @author Yurii Shyrma, created on 25.02.2018 +// @author Yurii Shyrma (iuriish@yahoo.com) // @@ -31,112 +31,160 @@ namespace helpers { ////////////////////////////////////////////////////////////////////////// template -static void batchnorm_(const NDArray* input, const NDArray* mean, const NDArray* variance, const NDArray* gamma, const NDArray* beta, NDArray* output, const std::vector& axes, const double epsilon) { +static void batchnorm_(const NDArray* input, const NDArray* mean, const NDArray* variance, const NDArray* gamma, const NDArray* beta, + NDArray* output, + const std::vector& axes, const double epsilon) { // formula: output = gamma * ((input - mean) / sqrt(variance + epsilon)) + beta - NDArray sigmaInvGam(mean); // do not copy mean's buffer, take only its shapeInfo - T eps = epsilon; + const T* x = input->bufferAsT(); + T* z = output->bufferAsT(); + const T* m = mean->bufferAsT(); + const T* v = variance->bufferAsT(); + const T* g = gamma == nullptr ? nullptr : gamma->bufferAsT(); + const T* b = beta == nullptr ? nullptr : beta->bufferAsT(); - if(gamma != nullptr) { - auto lambda = LAMBDA_TT(x, y, eps) {return x / nd4j::math::nd4j_sqrt(y + eps);}; - const_cast(gamma)->applyPairwiseLambda(variance, lambda, &sigmaInvGam); - } - else { - auto lambda = LAMBDA_T(x, eps) { return 1. / nd4j::math::nd4j_sqrt(x + eps); }; - const_cast(variance)->applyLambda(lambda, &sigmaInvGam); - } + const bool xzSameOffset = shape::haveSameShapeAndStrides(input->getShapeInfo(), output->getShapeInfo()); - // auto sigmaInvGam = (*variance + epsilon).transform(transform::RSqrt); // sigmaInvGam = 1 / sqrt(variance + epsilon) - // if(gamma != nullptr) sigmaInvGam *= *gamma; - - const T* sigmaBuff = sigmaInvGam.bufferAsT(); - const T* meanBuff = mean->bufferAsT(); - const T* inBuff = input->bufferAsT(); - T* outBuff = output->bufferAsT(); + bool paramSameOffset = shape::haveSameShapeAndStrides(mean->getShapeInfo(), variance->getShapeInfo()); + if(paramSameOffset && gamma != nullptr) + paramSameOffset &= shape::haveSameShapeAndStrides(mean->getShapeInfo(), gamma->getShapeInfo()); + if(paramSameOffset && beta != nullptr) + paramSameOffset &= shape::haveSameShapeAndStrides(mean->getShapeInfo(), beta->getShapeInfo()); const Nd4jLong lenBig = input->lengthOf(); const Nd4jLong lenSmall = mean->lengthOf(); - const Nd4jLong* inShapeInfo = input->getShapeInfo(); - const Nd4jLong* meanShapeInfo = mean->getShapeInfo(); - uint inShapeInfoCast[MAX_RANK]; - uint meanShapeInfoCast[MAX_RANK]; - bool canCastIn = nd4j::DataTypeUtils::castShapeInfo(inShapeInfo, inShapeInfoCast); - bool canCastMean = nd4j::DataTypeUtils::castShapeInfo(meanShapeInfo, meanShapeInfoCast); - - const Nd4jLong step = lenBig / lenSmall; + const Nd4jLong steps = lenBig / lenSmall; std::vector dimsToExclude = ShapeUtils::evalDimsToExclude(input->rankOf(), axes); OmpLaunchHelper info(lenBig, lenSmall); - if(beta != nullptr) { - const T* betaBuff = beta->bufferAsT(); - auto func = PRAGMA_THREADS_DO { - const auto threadNum = thread_id; - Nd4jLong* inOffsets = new Nd4jLong[step]; - Nd4jLong* memBuff = new Nd4jLong[2 * inShapeInfo[0]]; + auto func = PRAGMA_THREADS_DO { - for (int j = 0; j < lenSmall; ++j) { + Nd4jLong* xOffsets = new Nd4jLong[steps]; + Nd4jLong* zOffsets = xzSameOffset ? xOffsets : new Nd4jLong[steps]; + Nd4jLong* auxBuff = new Nd4jLong[2 * input->rankOf()]; - const bool isOwner = j < info._numThreads ? threadNum == j : threadNum == j % info._numThreads; - if (!isOwner) continue; + for (int j = 0; j < lenSmall; ++j) { - const Nd4jLong start = j * step; - const Nd4jLong end = start + step; + const bool isOwner = (j < info._numThreads) ? thread_id == j : thread_id == (j % info._numThreads); - // calculate offset for mean, variance, gamma, beta (all of them have the same shape) - auto offsetSmall = shape::indexOffset(j, meanShapeInfo, meanShapeInfoCast, canCastMean); - // calculate offset for input and output (all of them have the same shape) - shape::outerArrayOffsets(inOffsets, j, inShapeInfo, meanShapeInfo, memBuff, dimsToExclude.data()); + if(!isOwner) + continue; - PRAGMA_OMP_SIMD - for (Nd4jLong i = 0; i < step; ++i) { - auto offsetBig = inOffsets[i]; - outBuff[offsetBig] = (inBuff[offsetBig] - meanBuff[offsetSmall]) * sigmaBuff[offsetSmall] + betaBuff[offsetSmall]; + const auto meanOffset = shape::getIndexOffset(j, mean->getShapeInfo()); + const auto varOffset = paramSameOffset ? meanOffset : shape::getIndexOffset(j, variance->getShapeInfo()); + + const auto meanVal = m[meanOffset]; + auto sigmaInvGam = static_cast(1) / nd4j::math::nd4j_sqrt(v[varOffset] + epsilon); + + if(g != nullptr) { + const auto gammaOffset = paramSameOffset ? meanOffset : shape::getIndexOffset(j, gamma->getShapeInfo()); + sigmaInvGam *= g[gammaOffset]; + } + + T betaVal = static_cast(0); + if(b != nullptr) { + const auto betaOffset = paramSameOffset ? meanOffset : shape::getIndexOffset(j, beta->getShapeInfo()); + betaVal = b[betaOffset]; + } + + // calculate offsets for input and output + shape::outerArrayOffsets(xOffsets, j, input->getShapeInfo(), mean->getShapeInfo(), auxBuff, dimsToExclude.data()); + if(!xzSameOffset) + shape::outerArrayOffsets(zOffsets, j, output->getShapeInfo(), mean->getShapeInfo(), auxBuff, dimsToExclude.data()); + + PRAGMA_OMP_SIMD + for (uint i = 0; i < steps; ++i) + z[zOffsets[i]] = (x[xOffsets[i]] - meanVal) * sigmaInvGam + betaVal; + } + + delete []auxBuff; + delete []xOffsets; + if(!xzSameOffset) + delete []zOffsets; + }; + + samediff::Threads::parallel_do(func, info._numThreads); +} + +////////////////////////////////////////////////////////////////////////// +template +static void batchnorm2_(const NDArray* input, const NDArray* mean, const NDArray* variance, const NDArray* gamma, const NDArray* beta, + NDArray* output, + const std::vector& axes, const double epsilon) { + + // formula: output = gamma * ((input - mean) / sqrt(variance + epsilon)) + beta + + const auto x = input->bufferAsT(); + auto z = output->bufferAsT(); + const auto m = mean->bufferAsT(); + const auto v = variance->bufferAsT(); + const auto g = gamma == nullptr ? nullptr : gamma->bufferAsT(); + const auto b = beta == nullptr ? nullptr : beta->bufferAsT(); + + // xRank == zRank, minRank = meanRank = varianceRank = gammaRank = betaRank + const uint xRank = input->rankOf(); + const uint minRank = mean->rankOf(); + const uint numAxes = axes.size(); + + const bool xzSameOffset = shape::haveSameShapeAndStrides(input->getShapeInfo(), output->getShapeInfo()); + + bool paramSameOffset = shape::haveSameShapeAndStrides(mean->getShapeInfo(), variance->getShapeInfo()); + if(paramSameOffset && gamma != nullptr) + paramSameOffset &= shape::haveSameShapeAndStrides(mean->getShapeInfo(), gamma->getShapeInfo()); + if(paramSameOffset && beta != nullptr) + paramSameOffset &= shape::haveSameShapeAndStrides(mean->getShapeInfo(), beta->getShapeInfo()); + + auto func = PRAGMA_THREADS_FOR { + + Nd4jLong coords[MAX_RANK]; + + for (auto i = start; i < stop; i += increment) { + + shape::index2coords(i, input->getShapeInfo(), coords); + + const auto xOffset = shape::getOffset(input->getShapeInfo(), coords); + const auto zOffset = xzSameOffset ? xOffset : shape::getOffset(output->getShapeInfo(), coords); + + if(minRank == xRank) { + for (uint i = 0, j = 0; i < xRank; ++i) { + if(j < numAxes && i != axes[j]) + coords[i] = 0; + else + ++j; } } - delete []inOffsets; - delete []memBuff; - }; + else // minRank = numAxes = 1 in this case + coords[0] = coords[axes[0]]; - samediff::Threads::parallel_do(func, info._numThreads); - } - else { - auto func = PRAGMA_THREADS_DO { - const auto threadNum = thread_id; - Nd4jLong* inOffsets = new Nd4jLong[step]; - Nd4jLong* memBuff = new Nd4jLong[2 * inShapeInfo[0]]; + const auto meanOffset = shape::getOffset(mean->getShapeInfo(), coords); + const auto varianceOffset = paramSameOffset ? meanOffset : shape::getOffset(variance->getShapeInfo(), coords); - for (int j = 0; j < lenSmall; ++j) { - const bool isOwner = j < info._numThreads ? threadNum == j : threadNum == j % info._numThreads; - if (!isOwner) continue; + T sigmaInvGam = 1. / nd4j::math::nd4j_sqrt(v[varianceOffset] + epsilon); - const Nd4jLong start = j * step; - const Nd4jLong end = start + step; - - // calculate offset for mean, variance, gamma, beta (all of them have the same shape) - auto offsetSmall = shape::indexOffset(j, meanShapeInfo, meanShapeInfoCast, canCastMean); - // calculate offset for input and output (all of them have the same shape) - shape::outerArrayOffsets(inOffsets, j, inShapeInfo, meanShapeInfo, memBuff, dimsToExclude.data()); - - PRAGMA_OMP_SIMD - for (Nd4jLong i = 0; i < step; ++i) { - auto offsetBig = inOffsets[i]; - outBuff[offsetBig] = (inBuff[offsetBig] - meanBuff[offsetSmall]) * sigmaBuff[offsetSmall]; - } + if(g != nullptr) { + const auto gammaOffset = paramSameOffset ? meanOffset : shape::getOffset(gamma->getShapeInfo(), coords); + sigmaInvGam *= g[gammaOffset]; } - delete []inOffsets; - delete []memBuff; - }; - samediff::Threads::parallel_do(func, info._numThreads); - } + z[zOffset] = (x[xOffset] - m[meanOffset]) * sigmaInvGam; + + if(b != nullptr) { + const auto betaOffset = paramSameOffset ? meanOffset : shape::getOffset(beta->getShapeInfo(), coords); + z[zOffset] += b[betaOffset]; + } + } + }; + + samediff::Threads::parallel_for(func, 0, input->lengthOf()); } ////////////////////////////////////////////////////////////////////////// void batchnorm(const NDArray* input, const NDArray* mean, const NDArray* variance, const NDArray* gamma, const NDArray* beta, NDArray* output, const std::vector& axes, const double epsilon) { + // batchnorm2_ is slower BUILD_SINGLE_SELECTOR(input->dataType(), batchnorm_, (input, mean, variance, gamma, beta, output, axes, epsilon), FLOAT_TYPES); } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/compare_elem.cpp b/libnd4j/include/ops/declarable/helpers/cpu/compare_elem.cpp index 5f7fbf694..e5e51d38f 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/compare_elem.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/compare_elem.cpp @@ -56,7 +56,7 @@ namespace nd4j { sumt = samediff::Threads::parallel_long(func, LAMBDA_SUML, 0, length - 1); } - nd4j_printf("Sum: %lld\n", sumt) + //nd4j_printf("Sum: %lld\n", sumt) output = (sumt > -1); diff --git a/libnd4j/include/ops/declarable/helpers/cpu/confusion.cpp b/libnd4j/include/ops/declarable/helpers/cpu/confusion.cpp index e2d24c591..4f8989caf 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/confusion.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/confusion.cpp @@ -28,7 +28,7 @@ namespace helpers { template void _confusionFunctor(NDArray* labels, NDArray* predictions, NDArray* weights, NDArray* output) { - std::unique_ptr arrs(output->allTensorsAlongDimension({1})); + ResultSet arrs = output->allTensorsAlongDimension({1}); int lLen = labels->lengthOf(); auto func = PRAGMA_THREADS_FOR { @@ -36,7 +36,7 @@ namespace helpers { auto label = labels->e(j); auto pred = predictions->e(j); T value = (weights == nullptr ? (T) 1.0f : weights->e(j)); - (*arrs->at(label)).p(pred, value); + arrs.at(label)->p(pred, value); } }; diff --git a/libnd4j/include/ops/declarable/helpers/cpu/convolutions.cpp b/libnd4j/include/ops/declarable/helpers/cpu/convolutions.cpp index 47938e9fb..db09f0d3c 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/convolutions.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/convolutions.cpp @@ -373,7 +373,7 @@ namespace nd4j { NDArray* gradBR = gradB; if(gradB->rankOf() == 2) gradBR = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()})); - gradO->reduceAlongDimension(reduce::Sum, gradBR, gradOaxesForDot); // sum over bS, oH, oW + gradO->reduceAlongDimension(reduce::Sum, *gradBR, gradOaxesForDot); // sum over bS, oH, oW if(gradBR != gradB) delete gradBR; } @@ -506,7 +506,7 @@ namespace nd4j { NDArray* gradBR = gradB; if(gradB->rankOf() == 2) gradBR = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()})); - gradO->reduceAlongDimension(reduce::Sum, gradBR, {0,indOoH,indOoH+1}); // sum over bS, oH, oW + gradO->reduceAlongDimension(reduce::Sum, *gradBR, {0,indOoH,indOoH+1}); // sum over bS, oH, oW if(gradBR != gradB) delete gradBR; diff --git a/libnd4j/include/ops/declarable/helpers/cpu/cross.cpp b/libnd4j/include/ops/declarable/helpers/cpu/cross.cpp index 3150c0cfd..0adb0e249 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/cross.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/cross.cpp @@ -36,23 +36,19 @@ void crossBatched(nd4j::LaunchContext * context, NDArray *a, NDArray *b, NDArray auto tadsB = _b.allTensorsAlongDimension({1}); auto tadsO = _o.allTensorsAlongDimension({1}); - int tads = tadsA->size(); + int tads = tadsA.size(); auto func = PRAGMA_THREADS_FOR { for (auto e = start; e < stop; e += increment) { - auto a_ = tadsA->at(e); - auto b_ = tadsB->at(e); - auto o_ = tadsO->at(e); + auto a_ = tadsA.at(e); + auto b_ = tadsB.at(e); + auto o_ = tadsO.at(e); helpers::cross(context, a_, b_, o_); } }; samediff::Threads::parallel_tad(func, 0, tads); - - delete tadsA; - delete tadsB; - delete tadsO; } } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/dynamic.cpp b/libnd4j/include/ops/declarable/helpers/cpu/dynamic.cpp index 073167f18..281e6c809 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/dynamic.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/dynamic.cpp @@ -34,7 +34,7 @@ namespace nd4j { for (int i = sourceDimsLen; i > 0; i--) sourceDims[sourceDimsLen - i] = input->rankOf() - i; - std::unique_ptr listOfTensors(input->allTensorsAlongDimension(sourceDims)); + ResultSet listOfTensors = input->allTensorsAlongDimension(sourceDims); unsigned int outSize = outputList.size(); @@ -48,15 +48,14 @@ namespace nd4j { for (int k = 1; k < r; k++) outDims[k - 1] = k; - std::unique_ptr listOutForCurrent( - outputs[i].first->allTensorsAlongDimension(outDims)); + ResultSet listOutForCurrent = outputs[i].first->allTensorsAlongDimension(outDims); outputs[i].second = 0; //PRAGMA_OMP_PARALLEL_FOR_IF(indices->lengthOf() > Environment::getInstance()->elementwiseThreshold()) for (int e = 0; e < indices->lengthOf(); ++e) if ((*indices).e(e) == i) - listOutForCurrent->at(outputs[i].second++)->assign(listOfTensors->at(e)); + listOutForCurrent.at(outputs[i].second++)->assign(listOfTensors.at(e)); } } else { @@ -104,7 +103,7 @@ namespace nd4j { for (int i = restDims.size(); i > 0; i--) restDims[restDims.size() - i] = output->rankOf() - i; - std::unique_ptr listOfOutTensors(output->allTensorsAlongDimension(restDims)); + ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); for (int e = 0; e < numOfData; e++) { auto data = inputs[e]; @@ -113,7 +112,7 @@ namespace nd4j { for (int i = sourceDims.size(); i > 0; i--) sourceDims[sourceDims.size() - i] = data->rankOf() - i; - std::unique_ptr listOfTensors(data->allTensorsAlongDimension(sourceDims)); + ResultSet listOfTensors = data->allTensorsAlongDimension(sourceDims) ; for (int i = 0; i < index->lengthOf(); i++) { auto pos = index->e(i); @@ -127,7 +126,7 @@ namespace nd4j { return ND4J_STATUS_VALIDATION; } - listOfOutTensors->at(pos)->assign(listOfTensors->at(i)); + listOfOutTensors.at(pos)->assign(listOfTensors.at(i)); } } } @@ -145,7 +144,7 @@ namespace nd4j { for (int i = sourceDimsLen; i > 0; i--) sourceDims[sourceDimsLen - i] = input->rankOf() - i; - std::unique_ptr listOfTensors(outputList[0]->allTensorsAlongDimension(sourceDims)); + ResultSet listOfTensors = outputList[0]->allTensorsAlongDimension(sourceDims); for (unsigned int i = 0; i < inputGradientList.size(); i++) { outputs[i].first = inputGradientList[i]; @@ -155,14 +154,13 @@ namespace nd4j { for (int k = 1; k < outputs[i].first->rankOf(); k++) outDims[k - 1] = k; - std::unique_ptr listOutForCurrent( - outputs[i].first->allTensorsAlongDimension(outDims)); + ResultSet listOutForCurrent = outputs[i].first->allTensorsAlongDimension(outDims); outputs[i].second = 0; for (int e = 0; e < indices->lengthOf(); ++e) if (indices->e(e) == i) - listOfTensors->at(e)->assign(listOutForCurrent->at(outputs[i].second++)); + listOfTensors.at(e)->assign(listOutForCurrent.at(outputs[i].second++)); } } else { // one-dimensional case diff --git a/libnd4j/include/ops/declarable/helpers/cpu/extract_patches.cpp b/libnd4j/include/ops/declarable/helpers/cpu/extract_patches.cpp index f3fe89103..0a46c995e 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/extract_patches.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/extract_patches.cpp @@ -28,14 +28,14 @@ namespace helpers { template static void _extractPatches(NDArray* images, NDArray* output, int sizeRow, int sizeCol, int strideRow, int strideCol, int rateRow, int rateCol, bool theSame){ std::vector restDims({1, 2, 3}); // the first and the last dims - std::unique_ptr listOfMatricies(images->allTensorsAlongDimension(restDims)); - std::unique_ptr listOfOutputs(output->allTensorsAlongDimension(restDims)); + ResultSet listOfMatricies = images->allTensorsAlongDimension(restDims); + ResultSet listOfOutputs = output->allTensorsAlongDimension(restDims); // 3D matricies - 2D matricies of vectors (if last dim is greater than 1) //int e = 0; const int ksizeRowsEffective = sizeRow + (sizeRow - 1) * (rateRow - 1); const int ksizeColsEffective = sizeCol + (sizeCol - 1) * (rateCol - 1); const int ksize = ksizeRowsEffective * ksizeColsEffective; - int batchCount = listOfMatricies->size(); //lengthOf() / ksize; + int batchCount = listOfMatricies.size(); //lengthOf() / ksize; Nd4jLong lastDim = images->sizeAt(3); Nd4jLong outLastDim = output->sizeAt(3); Nd4jLong rowDim = images->sizeAt(1); @@ -51,8 +51,8 @@ namespace helpers { auto func = PRAGMA_THREADS_FOR { for (auto batch = 0; batch < stop; batch += increment) { - auto patch = listOfMatricies->at(batch); - auto outMatrix = listOfOutputs->at(batch); + auto patch = listOfMatricies.at(batch); + auto outMatrix = listOfOutputs.at(batch); for (Nd4jLong i = 0; i < outRowDim; i++) { for (Nd4jLong j = 0; j < outColDim; j++) { diff --git a/libnd4j/include/ops/declarable/helpers/cpu/fake_quantization.cpp b/libnd4j/include/ops/declarable/helpers/cpu/fake_quantization.cpp index 81f00b066..f18f48fac 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/fake_quantization.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/fake_quantization.cpp @@ -105,7 +105,7 @@ namespace helpers { return (nd4j::math::nd4j_floor(val / scale + T(0.5f)) * scale + nudgedMin); }; - input->applyLambda(fakeQuantizationWithMinMax, output); + input->applyLambda(fakeQuantizationWithMinMax, *output); } void fakeQuantWithMinMaxVars(NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output) { diff --git a/libnd4j/include/ops/declarable/helpers/cpu/gradient.cpp b/libnd4j/include/ops/declarable/helpers/cpu/gradient.cpp index ecc4cf24a..f6756dd88 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/gradient.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/gradient.cpp @@ -30,7 +30,7 @@ static void applyGradientDescent_(NDArray* input, NDArray* step, double weight, return _x - (_y * weight); }; - input->applyPairwiseLambda(step, lambda, output); + input->applyPairwiseLambda(*step, lambda, *output); } void applyGradientDescent(nd4j::LaunchContext* context, NDArray* input, NDArray* step, double weight, NDArray* output) { diff --git a/libnd4j/include/ops/declarable/helpers/cpu/gru.cpp b/libnd4j/include/ops/declarable/helpers/cpu/gru.cpp index 579ab2612..3db5e5373 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/gru.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/gru.cpp @@ -77,15 +77,15 @@ void gruCell(nd4j::LaunchContext * context, const NDArray* x, const NDArray* hLa // reset gate r->assign(mmul(*x, Wrx) + mmul(*hLast, Wrh) + br); // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU] - r->applyTransform(transform::Sigmoid); + r->applyTransform(transform::Sigmoid, *r); // update gate u->assign(mmul(*x, Wux) + mmul(*hLast, Wuh) + bu); // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU] - u->applyTransform(transform::Sigmoid); + u->applyTransform(transform::Sigmoid, *u); // cell gate c = activation(x × Wcx + (r * hlast) × Wch + bc) c->assign(mmul(*x, Wcx) + mmul(*r * *hLast, Wch) + *bc); // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU] - c->applyTransform(transform::Tanh); + c->applyTransform(transform::Tanh, *c); NDArray temp = 1.f - *c * *c; @@ -231,15 +231,15 @@ void gruCellBP(nd4j::LaunchContext* context, // reset gate NDArray r = mmul(*x, Wrx) + mmul(*hLast, Wrh) + br; // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU] - r.applyTransform(transform::Sigmoid); + r.applyTransform(transform::Sigmoid, r); // update gate NDArray u = mmul(*x, Wux) + mmul(*hLast, Wuh) + bu; // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU] - u.applyTransform(transform::Sigmoid); + u.applyTransform(transform::Sigmoid, u); // cell gate c = activation(x×Wcx + (r*hlast)×Wcu + bc) NDArray c = mmul(*x, Wcx) + mmul(r * *hLast, Wch) + *bc; // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU] - c.applyTransform(transform::Tanh); + c.applyTransform(transform::Tanh, c); // h = (1 - u) * c + u * hPrev @@ -352,10 +352,10 @@ void gruCellBP(nd4j::LaunchContext* context, dLdWcx.assign(mmul(xT, dLdZc)); // [iS, bS] × [bS, nU] = [iS, nU] dLdWch.assign(mmul((r * *hLast).transpose(), dLdZc)); // [nU, bS] × [bS, nU] = [nU, nU] - dLdbr.assign(dLdZr.reduceAlongDims(reduce::Sum, {0})); // [nU] - dLdbu.assign(dLdZu.reduceAlongDims(reduce::Sum, {0})); // [nU] + dLdbr.assign(dLdZr.reduceAlongDimension(reduce::Sum, {0})); // [nU] + dLdbu.assign(dLdZu.reduceAlongDimension(reduce::Sum, {0})); // [nU] - dLdbc->assign(dLdZc.reduceAlongDims(reduce::Sum, {0})); // [nU] + dLdbc->assign(dLdZc.reduceAlongDimension(reduce::Sum, {0})); // [nU] } // ////////////////////////////////////////////////////////////////////////// diff --git a/libnd4j/include/ops/declarable/helpers/cpu/histogram.cpp b/libnd4j/include/ops/declarable/helpers/cpu/histogram.cpp index 97cd2f84e..911230367 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/histogram.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/histogram.cpp @@ -29,26 +29,16 @@ namespace nd4j { auto result = reinterpret_cast(zBuffer); int length = shape::length(xShapeInfo); - // FIXME: 2??? - int _threads = 2; - - int span = (length / _threads) + 8; X binSize = (max_val - min_val) / (numBins); - PRAGMA_OMP_PARALLEL_THREADS(_threads) + // FIXME: this op should be parallelized { - int tid, start, end; - int *bins = new int[numBins]; std::memset(bins, 0, sizeof(int) * numBins); - tid = omp_get_thread_num(); - start = span * tid; - end = span * (tid + 1); - if (end > length) end = length; PRAGMA_OMP_SIMD - for (int x = start; x < end; x++) { + for (int x = 0; x < length; x++) { int idx = (int) ((dx[x] - min_val) / binSize); if (idx < 0) idx = 0; @@ -58,15 +48,12 @@ namespace nd4j { bins[idx]++; } - PRAGMA_OMP_CRITICAL - { - PRAGMA_OMP_SIMD - for (int x = 0; x < numBins; x++) { - result[x] += bins[x]; - } - + PRAGMA_OMP_SIMD + for (int x = 0; x < numBins; x++) { + result[x] += bins[x]; } + delete[] bins; } } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/image_resize.cpp b/libnd4j/include/ops/declarable/helpers/cpu/image_resize.cpp index b17167b9a..d4089359f 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/image_resize.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/image_resize.cpp @@ -1,6 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. - * Copyright (c) 2019 Konduit K.K. + * Copyright (c) 2019-2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -35,6 +35,8 @@ limitations under the License. #include #include +#include +#include "../cross.h" namespace nd4j { namespace ops { @@ -55,8 +57,9 @@ namespace helpers { : inSize / static_cast(outSize); } - struct ImageResizerState { - explicit ImageResizerState(bool alignCorners, bool halfPixelCenters) + template + struct ImageResizerStateCommon { + explicit ImageResizerStateCommon(bool alignCorners, bool halfPixelCenters) : _alignCorners(alignCorners), _halfPixelCenters(halfPixelCenters) {} @@ -94,14 +97,14 @@ namespace helpers { return validateAndCalculateOutputSize(input, width, height); } - Nd4jLong batchSize; - Nd4jLong outHeight; - Nd4jLong outWidth; - Nd4jLong inHeight; - Nd4jLong inWidth; - Nd4jLong channels; - float heightScale; - float widthScale; + I batchSize; + I outHeight; + I outWidth; + I inHeight; + I inWidth; + I channels; + F heightScale; + F widthScale; NDArray* output = nullptr; private: @@ -109,6 +112,8 @@ namespace helpers { bool _halfPixelCenters; }; + typedef ImageResizerStateCommon ImageResizerState; + // Half pixel scaler scales assuming that the pixel centers are at 0.5, i.e. the // floating point coordinates of the top,left pixel is 0.5,0.5. struct HalfPixelScaler { @@ -255,7 +260,7 @@ namespace helpers { // Handle no-op resizes efficiently. if (outHeight == inHeight && outWidth == inWidth) { output->assign(images); - return ND4J_STATUS_OK; + return Status::OK(); } std::vector ys(outHeight + 1); @@ -283,7 +288,7 @@ namespace helpers { samediff::Threads::parallel_for(func, 0, xsSize); resizeImage_(images->getDataBuffer()->primaryAsT(), batchSize, inHeight, inWidth, outHeight, outWidth, channels, xs, ys, output->dataBuffer()->primaryAsT()); - return ND4J_STATUS_OK; + return Status::OK(); } template @@ -353,6 +358,7 @@ namespace helpers { int resizeBilinearFunctor(nd4j::LaunchContext * context, NDArray const *images, int const width, int const height, bool const alignCorners, bool const halfPixelCenter, NDArray *output) { BUILD_DOUBLE_SELECTOR(images->dataType(), output->dataType(), return resizeBilinearFunctor_, (images, width, height, alignCorners, halfPixelCenter, output), NUMERIC_TYPES, FLOAT_TYPES); + return Status::OK(); } int resizeNeighborFunctor(nd4j::LaunchContext * context, NDArray const *images, int const width, int const height, @@ -682,27 +688,29 @@ namespace helpers { pY2[pt_index], pY3[pt_index]); } - template + template static void bicubicInterpolateWithCaching(NDArray const* image, ImageResizerState const& resizerState, bool const halfPixelCenters, NDArray* output) { std::vector xWais(resizerState.outWidth); - computeXWeightsAndIndices(resizerState, halfPixelCenters, &xWais); const auto numChannels = resizerState.channels; const Nd4jLong inRowWidth = resizerState.inWidth * numChannels; const Nd4jLong inBatchWidth = resizerState.inHeight * inRowWidth; + const auto batchNum = resizerState.batchSize; + const auto outHeight = resizerState.outHeight; + const auto outWidth = resizerState.outWidth; - const T* inputPtr = image->getDataBuffer()->primaryAsT(); - float* pOutputY = output->dataBuffer()->primaryAsT(); // output is float anyway - std::vector cachedValue(numChannels == 3 ? 0 : 4 * numChannels, 0); + auto func = PRAGMA_THREADS_FOR { + const T* inputPtr = image->getDataBuffer()->primaryAsT(); + F* pOutputY = output->dataBuffer()->primaryAsT(); // output is float anyway + std::vector cachedValue(numChannels == 3 ? 0 : 4 * numChannels, 0); - auto func = PRAGMA_THREADS_FOR { for (auto b = start; b < stop; ++b) { auto pInput = inputPtr + b * inBatchWidth; - for (auto y = 0; y < resizerState.outHeight; ++y) { - auto pOutput = &pOutputY[(b * resizerState.outHeight + y) * resizerState.outWidth * numChannels]; + for (auto y = 0; y < outHeight; ++y) { + auto pOutput = &pOutputY[(b * outHeight + y) * outWidth * numChannels]; WeightsAndIndices yWai; if (halfPixelCenters) { @@ -713,16 +721,16 @@ namespace helpers { resizerState.heightScale, y, resizerState.inHeight, &yWai); } // Make pointers represent offsets of data in inputBPtr. - const T *y_ptr_0 = pInput + yWai._index0 * inRowWidth; - const T *y_ptr_1 = pInput + yWai._index1 * inRowWidth; - const T *y_ptr_2 = pInput + yWai._index2 * inRowWidth; - const T *y_ptr_3 = pInput + yWai._index3 * inRowWidth; + const T* y_ptr_0 = pInput + yWai._index0 * inRowWidth; + const T* y_ptr_1 = pInput + yWai._index1 * inRowWidth; + const T* y_ptr_2 = pInput + yWai._index2 * inRowWidth; + const T* y_ptr_3 = pInput + yWai._index3 * inRowWidth; if (numChannels == 3) { // Manually unroll case of 3 channels. - float cached_value_0[4] = {0}; - float cached_value_1[4] = {0}; - float cached_value_2[4] = {0}; + F cached_value_0[4] = {0}; + F cached_value_1[4] = {0}; + F cached_value_2[4] = {0}; for (auto x = 0; x < resizerState.outWidth; ++x) { const WeightsAndIndices &xWai = xWais[x]; // Shift values in cached_value_* to fill first '_advance' values. @@ -854,7 +862,7 @@ namespace helpers { } for (auto c = 0; c < numChannels; ++c) { pOutput[x * numChannels + c] = - compute(&cachedValue[4 * c], xWai._weight0, xWai._weight1, + (F)compute(&cachedValue[4 * c], xWai._weight0, xWai._weight1, xWai._weight2, xWai._weight3); } } @@ -862,7 +870,7 @@ namespace helpers { } } }; - samediff::Threads::parallel_tad(func, 0, resizerState.batchSize); + samediff::Threads::parallel_tad(func, 0, batchNum); } // simplified bicubic resize without antialiasing @@ -873,7 +881,7 @@ namespace helpers { ImageResizerState st(alignCorners, halfPixelAlign); // align_corners, half_pixel_align int res = st.validateAndCreateOutput(image, width, height); if (res == Status::OK()) - bicubicInterpolateWithCaching(image, st, halfPixelAlign, output); + bicubicInterpolateWithCaching(image, st, halfPixelAlign, output); return res; } @@ -881,6 +889,206 @@ namespace helpers { bool const alignCorners, bool const halfPixelAlign, NDArray* output) { BUILD_SINGLE_SELECTOR(image->dataType(), return resizeBicubicFunctorA_, (context, image, width, height, alignCorners, halfPixelAlign, output), NUMERIC_TYPES); } +// ------------------------------------------------------------------------------------------------------------------ // + struct CachedInterpolation { + Nd4jLong start; + Nd4jLong end; + float startScale; + float endMinusOneScale; + bool needsBounding; + }; + + template + struct ScaleCache { + float yScale; + T const* yPtr; + }; + // Computes the sum of all x values defined by taken across + // the y offsets and scales defined by y_ptrs and y_scales, for channel c. + // + // Note that is a template parameter to avoid a performance + // penalty from dynamically checking it. + template + static void computePatchSumOf3Channels(float scale, + ImageResizerState const& st, + std::vector> const& yPtrs, + CachedInterpolation const& xCache, + float* outputPtr) { + + bool const needsXBounding = xCache.needsBounding; + + auto boundIfNeeded = [needsXBounding](Nd4jLong x, Nd4jLong y) -> Nd4jLong { + return (needsXBounding ? bound(x, y) : (x)); + }; + + float sum_0 = 0; + float sum_1 = 0; + float sum_2 = 0; + for (int i = 0; i < yPtrs.size(); ++i) { + const T* ptr = yPtrs[i].yPtr; + float scaleX = xCache.startScale; + Nd4jLong offset = 3 * boundIfNeeded(xCache.start, st.inWidth); + float sum_y_0 = static_cast(ptr[offset + 0]) * scaleX; + float sum_y_1 = static_cast(ptr[offset + 1]) * scaleX; + float sum_y_2 = static_cast(ptr[offset + 2]) * scaleX; + + if (xCache.start + 1 != xCache.end) { + for (Nd4jLong x = xCache.start + 1; x < xCache.end - 1; ++x) { + Nd4jLong offset = 3 * boundIfNeeded(x, st.inWidth); + sum_y_0 += static_cast(ptr[offset + 0]); + sum_y_1 += static_cast(ptr[offset + 1]); + sum_y_2 += static_cast(ptr[offset + 2]); + } + scaleX = xCache.endMinusOneScale; + offset = st.channels * boundIfNeeded(xCache.end - 1, st.inWidth); + sum_y_0 += static_cast(ptr[offset + 0]) * scaleX; + sum_y_1 += static_cast(ptr[offset + 1]) * scaleX; + sum_y_2 += static_cast(ptr[offset + 2]) * scaleX; + } + sum_0 += sum_y_0 * yPtrs[i].yScale; + sum_1 += sum_y_1 * yPtrs[i].yScale; + sum_2 += sum_y_2 * yPtrs[i].yScale; + } + + outputPtr[0] = sum_0 * scale; + outputPtr[1] = sum_1 * scale; + outputPtr[2] = sum_2 * scale; + } + + // Computes the sum of all x values defined by taken across + // the y offsets and scales defined by y_ptrs and y_scales, for channel c. + // + // Note that is a template parameter to avoid a performance + // penalty from dynamically checking it. + template + static void computePatchSum(float scale, const ImageResizerState& st, + const std::vector>& yPtrs, + const CachedInterpolation& xCache, + float* outputPtr) { + + bool const needsXBounding = xCache.needsBounding; + + auto boundIfNeeded = [needsXBounding](Nd4jLong x, Nd4jLong y) -> Nd4jLong { + return (needsXBounding ? bound(x, y) : (x)); + }; + + const auto numChannels = st.channels; + for (Nd4jLong c = 0; c < numChannels; ++c) { + float sum = 0; + for (int i = 0; i < yPtrs.size(); ++i) { + T const* ptr = yPtrs[i].yPtr; + float scaleX = xCache.startScale; + float sumY = static_cast(ptr[numChannels * boundIfNeeded(xCache.start, st.inWidth) + c]) * scaleX; + if (xCache.start + 1 != xCache.end) { + for (Nd4jLong x = xCache.start + 1; x < xCache.end - 1; ++x) { + sumY += static_cast( + ptr[numChannels * boundIfNeeded(x, st.inWidth) + c]); + } + scaleX = xCache.endMinusOneScale; + sumY += static_cast(ptr[numChannels * boundIfNeeded(xCache.end - 1, st.inWidth) + c]) * scaleX; + } + sum += sumY * yPtrs[i].yScale; + } + outputPtr[c] = sum * scale; + } + } + + + + template + static void resizeArea(ImageResizerState const& st, std::vector const& caches, NDArray const* input, NDArray* output) { + T const* inputPtr = input->bufferAsT(); + float scale = 1.f / (st.heightScale * st.widthScale); + auto outputPtr = output->bufferAsT(); // output is always float. TO DO: provide another float types also with template declaration + + auto batchProcess = PRAGMA_THREADS_FOR { + for (auto batch = start; batch < stop; batch += increment) { + for (auto y = 0; y < st.outHeight; ++y) { + const float inY = y * st.heightScale; + const float inY1 = (y + 1) * st.heightScale; + // The start and end height indices of all the cells that could + // contribute to the target cell. + const Nd4jLong yStart = math::nd4j_floor(inY); + const Nd4jLong yEnd = math::nd4j_ceil(inY1); + + std::vector> yCaches; + auto cacheLen = yEnd - yStart; + if (cacheLen) { + yCaches.resize(cacheLen); + }; + + for (auto i = yStart, k = 0LL; i < yEnd; ++i, ++k) { + ScaleCache scaleCache; + if (i < inY) { + scaleCache.yScale = (i + 1 > inY1 ? st.heightScale : i + 1 - inY); + } else { + scaleCache.yScale = (i + 1 > inY1 ? inY1 - i : 1.0); + } + scaleCache.yPtr = inputPtr + (batch * st.inHeight * st.inWidth * st.channels + + bound(i, st.inHeight) * st.inWidth * st.channels); + yCaches[k] = scaleCache; + } + float* output = outputPtr + (batch * st.outHeight + y) * st.channels * st.outWidth; + + if (st.channels == 3) { + for (Nd4jLong x = 0; x < st.outWidth; ++x) { + const CachedInterpolation &xCache = caches[x]; + computePatchSumOf3Channels(scale, st, yCaches, xCache, output); + output += st.channels; + } + } else { + for (Nd4jLong x = 0; x < st.outWidth; ++x) { + const CachedInterpolation &xCache = caches[x]; + computePatchSum(scale, st, yCaches, xCache, output); + output += st.channels; + } + } + } + } + }; + samediff::Threads::parallel_tad(batchProcess, 0, st.batchSize, 1); + } + + template + int resizeAreaFunctor_(nd4j::LaunchContext* context, NDArray const* image, int const width, int const height, + bool const alignCorners, NDArray* output) { + ImageResizerState st(alignCorners, false); // Create resize info + auto res = st.validateAndCalculateOutputSize(image, width, height); + if (Status::OK() == res) { + std::vector xCached(st.outWidth); + auto cachingProcedure = PRAGMA_THREADS_FOR { + for (auto x = start; x < stop; x += increment) { + auto &xCache = xCached[x]; + const float inX = x * st.widthScale; + const float inX1 = (x + 1) * st.widthScale; + + Nd4jLong v = math::nd4j_floor(inX); + xCache.start = v; + xCache.startScale = + v < inX ? (v + 1 > inX1 ? st.widthScale : v + 1 - inX) : (v + 1 > inX1 ? inX1 - v + : 1.f); + v = math::nd4j_ceil(inX1); + xCache.end = v--; + xCache.endMinusOneScale = + v < inX ? (v + 1 > inX1 ? st.widthScale : v + 1 - inX) : (v + 1 > inX1 ? inX1 - v + : 1.f); + xCache.needsBounding = bound(xCache.start, st.inWidth) != xCache.start || + bound(xCache.end - 1, st.inWidth) != (xCache.end - 1); + + } + }; + samediff::Threads::parallel_for(cachingProcedure, 0, xCached.size(), 1); + + resizeArea(st, xCached, image, output); + } + return res; + } + + int resizeAreaFunctor(nd4j::LaunchContext * context, NDArray const* image, int const width, int const height, + bool const alignCorners, NDArray* output) { + BUILD_SINGLE_SELECTOR(image->dataType(), return resizeAreaFunctor_, (context, image, width, height, alignCorners, output), NUMERIC_TYPES); + } + // ------------------------------------------------------------------------------------------------------------------ // int resizeFunctor(nd4j::LaunchContext * context, NDArray const* image, int const width, int const height, ImageResizeMethods method, bool preserveAspectRatio, bool antialias, NDArray* output) { @@ -888,9 +1096,9 @@ namespace helpers { case kResizeBilinear: return resizeBilinearFunctor(context, image, width, height, false, false, output); break; case kResizeNearest: return resizeNeighborFunctor(context, image, width, height, false, false, output); break; case kResizeBicubic: return resizeBicubicFunctor(context, image, width, height, preserveAspectRatio, antialias, output); break; + case kResizeArea: return resizeAreaFunctor(context, image, width, height, preserveAspectRatio, output); case kResizeLanczos5: case kResizeGaussian: - case kResizeArea: case kResizeMitchelcubic: throw std::runtime_error("helper::resizeFunctor: Non implemented yet."); } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/imagesHelpers.cpp b/libnd4j/include/ops/declarable/helpers/cpu/imagesHelpers.cpp new file mode 100644 index 000000000..e065174d5 --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cpu/imagesHelpers.cpp @@ -0,0 +1,288 @@ +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Oleh Semeniv (oleg.semeniv@gmail.com) +// @author AbdelRauf (rauf@konduit.ai) +// + +#include +#include +#include +#include + +namespace nd4j { +namespace ops { +namespace helpers { + +template +static void rgbToGrs_(const NDArray& input, NDArray& output, const int dimC) { + + const T* x = input.bufferAsT(); + T* z = output.bufferAsT(); + const int rank = input.rankOf(); + + if(dimC == rank - 1 && 'c' == input.ordering() && 1 == input.ews() && + 'c' == output.ordering() && 1 == output.ews()){ + + auto func = PRAGMA_THREADS_FOR{ + for (auto i = start; i < stop; i += increment) { + const auto xStep = i*3; + z[i] = 0.2989f*x[xStep] + 0.5870f*x[xStep + 1] + 0.1140f*x[xStep + 2]; + } + }; + + samediff::Threads::parallel_for(func, 0, output.lengthOf(), 1); + return; + } + + auto func = PRAGMA_THREADS_FOR{ + + Nd4jLong coords[MAX_RANK]; + for (auto i = start; i < stop; i += increment) { + shape::index2coords(i, output.getShapeInfo(), coords); + const auto zOffset = shape::getOffset(output.getShapeInfo(), coords); + const auto xOffset0 = shape::getOffset(input.getShapeInfo(), coords); + const auto xOffset1 = xOffset0 + input.strideAt(dimC); + const auto xOffset2 = xOffset1 + input.strideAt(dimC); + z[zOffset] = 0.2989f*x[xOffset0] + 0.5870f*x[xOffset1] + 0.1140f*x[xOffset2]; + } + }; + + samediff::Threads::parallel_for(func, 0, output.lengthOf(), 1); + return; +} + +void transformRgbGrs(nd4j::LaunchContext* context, const NDArray& input, NDArray& output, const int dimC) { + BUILD_SINGLE_SELECTOR(input.dataType(), rgbToGrs_, (input, output, dimC), NUMERIC_TYPES); +} + +template +FORCEINLINE static void rgbToFromYuv_(const NDArray& input, NDArray& output, const int dimC, Op op) { + + const T* x = input.bufferAsT(); + T* z = output.bufferAsT(); + const int rank = input.rankOf(); + bool bSimple = (dimC == rank - 1 && 'c' == input.ordering() && 1 == input.ews() && + 'c' == output.ordering() && 1 == output.ews()); + + if (bSimple) { + + auto func = PRAGMA_THREADS_FOR{ + for (auto i = start; i < stop; i += increment) { + op(x[i], x[i + 1], x[i + 2], z[i], z[i + 1], z[i + 2]); + } + }; + + samediff::Threads::parallel_for(func, 0, input.lengthOf(), 3); + return; + } + + auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input.getShapeInfo(), dimC); + auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output.getShapeInfo(), dimC); + + const Nd4jLong numOfTads = packX.numberOfTads(); + const Nd4jLong xDimCstride = input.stridesOf()[dimC]; + const Nd4jLong zDimCstride = output.stridesOf()[dimC]; + + auto func = PRAGMA_THREADS_FOR{ + for (auto i = start; i < stop; i += increment) { + const T* xTad = x + packX.platformOffsets()[i]; + T* zTad = z + packZ.platformOffsets()[i]; + op(xTad[0], xTad[xDimCstride], xTad[2 * xDimCstride], zTad[0], zTad[zDimCstride], zTad[2 * zDimCstride]); + } + }; + + samediff::Threads::parallel_tad(func, 0, numOfTads); + return; +} + +template +FORCEINLINE static void rgbYuv_(const NDArray& input, NDArray& output, const int dimC) { + auto op = nd4j::ops::helpers::rgbYuv; + return rgbToFromYuv_(input, output, dimC, op); +} + +void transformRgbYuv(nd4j::LaunchContext* context, const NDArray& input, NDArray& output, const int dimC) { + BUILD_SINGLE_SELECTOR(input.dataType(), rgbYuv_, (input, output, dimC), FLOAT_TYPES); +} + +template +FORCEINLINE static void yuvRgb_(const NDArray& input, NDArray& output, const int dimC) { + auto op = nd4j::ops::helpers::yuvRgb; + return rgbToFromYuv_(input, output, dimC, op); +} + +void transformYuvRgb(nd4j::LaunchContext* context, const NDArray& input, NDArray& output, const int dimC) { + BUILD_SINGLE_SELECTOR(input.dataType(), yuvRgb_, (input, output, dimC), FLOAT_TYPES); +} + +template +FORCEINLINE static void tripleTransformer(const NDArray* input, NDArray* output, const int dimC, Op op) { + + const int rank = input->rankOf(); + + const T* x = input->bufferAsT(); + T* z = output->bufferAsT(); + + if (dimC == rank - 1 && input->ews() == 1 && output->ews() == 1 && input->ordering() == 'c' && output->ordering() == 'c') { + + auto func = PRAGMA_THREADS_FOR{ + for (auto i = start; i < stop; i += increment) { + op(x[i], x[i + 1], x[i + 2], z[i], z[i + 1], z[i + 2]); + } + }; + + samediff::Threads::parallel_for(func, 0, input->lengthOf(), 3); + } + else { + auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimC); + auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimC); + + const Nd4jLong numOfTads = packX.numberOfTads(); + const Nd4jLong xDimCstride = input->stridesOf()[dimC]; + const Nd4jLong zDimCstride = output->stridesOf()[dimC]; + + auto func = PRAGMA_THREADS_FOR{ + for (auto i = start; i < stop; i += increment) { + const T* xTad = x + packX.platformOffsets()[i]; + T* zTad = z + packZ.platformOffsets()[i]; + op(xTad[0], xTad[xDimCstride], xTad[2 * xDimCstride], zTad[0], zTad[zDimCstride], zTad[2 * zDimCstride]); + + } + }; + + samediff::Threads::parallel_tad(func, 0, numOfTads); + } +} + + +template +FORCEINLINE static void tripleTransformer(const NDArray* input, NDArray* output, const int dimC , T (&tr)[3][3] ) { + + const int rank = input->rankOf(); + + const T* x = input->bufferAsT(); + T* z = output->bufferAsT(); + // TODO: Use tensordot or other optimizied helpers to see if we can get better performance. + + if (dimC == rank - 1 && input->ews() == 1 && output->ews() == 1 && input->ordering() == 'c' && output->ordering() == 'c') { + + auto func = PRAGMA_THREADS_FOR{ + for (auto i = start; i < stop; i += increment) { + //simple M*v //tr.T*v.T // v * tr //rule: (AB)' =B'A' + // v.shape (1,3) row vector + T x0, x1, x2; + x0 = x[i]; //just additional hint + x1 = x[i + 1]; + x2 = x[i + 2]; + z[i] = x0 * tr[0][0] + x1 * tr[1][0] + x2 * tr[2][0]; + z[i+1] = x0 * tr[0][1] + x1 * tr[1][1] + x2 * tr[2][1]; + z[i+2] = x0 * tr[0][2] + x1 * tr[1][2] + x2 * tr[2][2]; + + } + }; + + samediff::Threads::parallel_for(func, 0, input->lengthOf(), 3); + } + else { + auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimC); + auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimC); + + const Nd4jLong numOfTads = packX.numberOfTads(); + const Nd4jLong xDimCstride = input->stridesOf()[dimC]; + const Nd4jLong zDimCstride = output->stridesOf()[dimC]; + + auto func = PRAGMA_THREADS_FOR{ + for (auto i = start; i < stop; i += increment) { + const T* xTad = x + packX.platformOffsets()[i]; + T* zTad = z + packZ.platformOffsets()[i]; + //simple M*v //tr.T*v + T x0, x1, x2; + x0 = xTad[0]; + x1 = xTad[xDimCstride]; + x2 = xTad[2 * xDimCstride]; + zTad[0] = x0 * tr[0][0] + x1 * tr[1][0] + x2 * tr[2][0]; + zTad[zDimCstride] = x0 * tr[0][1] + x1 * tr[1][1] + x2 * tr[2][1]; + zTad[2 * zDimCstride] = x0 * tr[0][2] + x1 * tr[1][2] + x2 * tr[2][2]; + + } + }; + + samediff::Threads::parallel_tad(func, 0, numOfTads); + } +} + + + + +template +FORCEINLINE static void hsvRgb(const NDArray* input, NDArray* output, const int dimC) { + auto op = nd4j::ops::helpers::hsvToRgb; + return tripleTransformer(input, output, dimC, op); +} + +template +FORCEINLINE static void rgbHsv(const NDArray* input, NDArray* output, const int dimC) { + auto op = nd4j::ops::helpers::rgbToHsv; + return tripleTransformer(input, output, dimC, op); +} + + +template +FORCEINLINE static void rgbYiq(const NDArray* input, NDArray* output, const int dimC) { + T arr[3][3] = { + { (T)0.299, (T)0.59590059, (T)0.2115 }, + { (T)0.587, (T)-0.27455667, (T)-0.52273617 }, + { (T)0.114, (T)-0.32134392, (T)0.31119955 } + }; + return tripleTransformer(input, output, dimC, arr); +} + +template +FORCEINLINE static void yiqRgb(const NDArray* input, NDArray* output, const int dimC) { + //TODO: this operation does not use the clamp operation, so there is a possibility being out of range. + //Justify that it will not be out of range for images data + T arr[3][3] = { + { (T)1, (T)1, (T)1 }, + { (T)0.95598634, (T)-0.27201283, (T)-1.10674021 }, + { (T)0.6208248, (T)-0.64720424, (T)1.70423049 } + }; + return tripleTransformer(input, output, dimC, arr); +} + + + +void transformHsvRgb(nd4j::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC) { + BUILD_SINGLE_SELECTOR(input->dataType(), hsvRgb, (input, output, dimC), FLOAT_TYPES); +} + +void transformRgbHsv(nd4j::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC) { + BUILD_SINGLE_SELECTOR(input->dataType(), rgbHsv, (input, output, dimC), FLOAT_TYPES); +} + +void transformYiqRgb(nd4j::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC) { + BUILD_SINGLE_SELECTOR(input->dataType(), yiqRgb, (input, output, dimC), FLOAT_TYPES); +} + +void transformRgbYiq(nd4j::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC) { + BUILD_SINGLE_SELECTOR(input->dataType(), rgbYiq, (input, output, dimC), FLOAT_TYPES); +} + + +} +} +} \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/legacy_helper.cpp b/libnd4j/include/ops/declarable/helpers/cpu/legacy_helper.cpp index 62f8316ce..4db975ddf 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/legacy_helper.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/legacy_helper.cpp @@ -31,7 +31,7 @@ namespace helpers { return x > (T) 0.f ? y : T(0.f); }; - theFirst->applyPairwiseLambda(theSecond, functor, nullptr); + theFirst->applyPairwiseLambda(*theSecond, functor, *theFirst); } void reluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond) { @@ -46,7 +46,7 @@ namespace helpers { return x > zero ? y : zero; }; - input->applyPairwiseLambda(epsilon, functor, output); + input->applyPairwiseLambda(*epsilon, functor, *output); /* auto x = input->bufferAsT(); @@ -74,7 +74,7 @@ namespace helpers { return x > (T)0.f && x < (T)6.f? y : T(0.f); }; - input->applyPairwiseLambda(epsilon, functor, output); + input->applyPairwiseLambda(*epsilon, functor, *output); } void relu6Derivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { @@ -90,7 +90,7 @@ namespace helpers { return x < 0 ? alphaT * y : y; }; - input->applyPairwiseLambda(epsilon, functor, output); + input->applyPairwiseLambda(*epsilon, functor, *output); } void leakyReluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput, const float alpha) { @@ -106,7 +106,7 @@ namespace helpers { return y * nd4j::math::nd4j_eluderivative(x, alphaT); }; - input->applyPairwiseLambda(epsilon, functor, output); + input->applyPairwiseLambda(*epsilon, functor, *output); } void eluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput, const float alpha) { @@ -119,7 +119,7 @@ namespace helpers { return y * simdOps::SELUDerivative::op(x, nullptr); }; - input->applyPairwiseLambda(epsilon, functor, output); + input->applyPairwiseLambda(*epsilon, functor, *output); } void seluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { @@ -132,7 +132,7 @@ namespace helpers { return y * (3 * x * x); }; - input->applyPairwiseLambda(epsilon, functor, output); + input->applyPairwiseLambda(*epsilon, functor, *output); } void cubeDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { @@ -146,7 +146,7 @@ namespace helpers { return x > T(0.f)? y : -y; }; - input->applyPairwiseLambda(epsilon, functor, output); + input->applyPairwiseLambda(*epsilon, functor, *output); } void reduceNorm1(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { @@ -160,7 +160,7 @@ namespace helpers { return nd4j::math::nd4j_max(x, (T)0.f) - x * y + nd4j::math::nd4j_log((T)1.f + nd4j::math::nd4j_exp(-nd4j::math::nd4j_abs(x))); }; - logits->applyPairwiseLambda(labels, functor, output); + logits->applyPairwiseLambda(*labels, functor, *output); } void sigmCrossEntropy(nd4j::LaunchContext * context, NDArray* logits, NDArray* labels, NDArray* output) { @@ -178,7 +178,7 @@ namespace helpers { return static_cast(1.) - y - e / (static_cast(1.) + e); }; - logits->applyPairwiseLambda(labels, functor, output); + logits->applyPairwiseLambda(*labels, functor, *output); } void sigmCrossEntropyGrad(nd4j::LaunchContext * context, NDArray* logits, NDArray* labels, NDArray* output) { @@ -193,7 +193,7 @@ namespace helpers { return y * ((T)1.0f - (th * th)); }; - input->applyPairwiseLambda(epsilon, functor, output); + input->applyPairwiseLambda(*epsilon, functor, *output); } void tanhDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { @@ -208,7 +208,7 @@ namespace helpers { return y * simdOps::HardTanhDerivative::op(x, nullptr); }; - input->applyPairwiseLambda(epsilon, functor, output); + input->applyPairwiseLambda(*epsilon, functor, *output); } void hardTanhDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { @@ -221,7 +221,7 @@ namespace helpers { return y * simdOps::RationalTanhDerivative::op(x, nullptr); }; - input->applyPairwiseLambda(epsilon, functor, output); + input->applyPairwiseLambda(*epsilon, functor, *output); } void rationalTanhDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { @@ -234,7 +234,7 @@ namespace helpers { return x > (T) 0.0f ? y * (nd4j::math::nd4j_tanhderivative(x)) : (T) 0.0f; }; - input->applyPairwiseLambda(epsilon, functor, output); + input->applyPairwiseLambda(*epsilon, functor, *output); } void rectifiedTanhDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { @@ -251,7 +251,7 @@ namespace helpers { return y * ((T) 1.0f / (ss * ss)); }; - input->applyPairwiseLambda(epsilon, functor, output); + input->applyPairwiseLambda(*epsilon, functor, *output); } void softSignDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { @@ -265,7 +265,7 @@ namespace helpers { return y * (p / (p + 1.)); }; - input->applyPairwiseLambda(epsilon, functor, output); + input->applyPairwiseLambda(*epsilon, functor, *output); } void softPlusDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { @@ -282,7 +282,7 @@ namespace helpers { return y * (s * ((T) 1.0f - s)); }; - input->applyPairwiseLambda(epsilon, functor, output); + input->applyPairwiseLambda(*epsilon, functor, *output); } void sigmoidDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { @@ -295,7 +295,7 @@ namespace helpers { return y * simdOps::HardSigmoidDerivative::op(x, nullptr); }; - input->applyPairwiseLambda(epsilon, functor, output); + input->applyPairwiseLambda(*epsilon, functor, *output); } void hardSigmoidDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { @@ -305,24 +305,24 @@ namespace helpers { template static void logSumExp_(NDArray* input, NDArray* axis, NDArray* output) { // reduce along axis with - std::unique_ptr tempInput(input->dup()); - input->applyTransform(transform::Exp, tempInput.get()); + NDArray tempInput = input->dup(); + input->applyTransform(transform::Exp, tempInput); std::vector axisVector; if (axis != nullptr) { axisVector.resize(axis->lengthOf()); for (size_t i = 0; i < axisVector.size(); ++i) axisVector[i] = axis->e(i); } - tempInput->reduceAlongDimension(reduce::Sum, output, axisVector); - output->applyTransform(transform::Log, nullptr, nullptr); + tempInput.reduceAlongDimension(reduce::Sum, *output, axisVector); + output->applyTransform(transform::Log, *output); } template static void logSumExp_(NDArray* input, NDArray* subtrah, NDArray* axis, NDArray* output) { // reduce along axis with - std::unique_ptr tempInput(input->dup()); - input->applyPairwiseTransform(pairwise::Subtract, subtrah, tempInput.get(), nullptr); - tempInput->applyTransform(transform::Exp, nullptr, nullptr); + NDArray tempInput = input->dup(); + input->applyPairwiseTransform(pairwise::Subtract, *subtrah, tempInput); + tempInput.applyTransform(transform::Exp, tempInput); std::vector axisVector; if (axis != nullptr) { @@ -330,8 +330,8 @@ namespace helpers { for (size_t i = 0; i < axisVector.size(); ++i) axisVector[i] = axis->e(i); } - tempInput->reduceAlongDimension(reduce::Sum, output, axisVector); - output->applyTransform(transform::Log, nullptr, nullptr); + tempInput.reduceAlongDimension(reduce::Sum, *output, axisVector); + output->applyTransform(transform::Log, *output); } void logSumExp(nd4j::LaunchContext * context, NDArray* input, NDArray* axis, NDArray* output) { @@ -364,16 +364,16 @@ static void weightedCrossEntropyWithLogitsFunctor_(NDArray const* targets, NDArr if (weights->isScalar()) { - const_cast(input)->applyPairwiseLambda(const_cast(targets), mainRoutineT1, output); + const_cast(input)->applyPairwiseLambda(const_cast(*targets), mainRoutineT1, *output); } else { std::unique_ptr targetVector(new NDArray(*weights)); - targetVector->applyScalar(scalar::Add, -1.f); + targetVector->applyScalar(scalar::Add, -1.f, *targetVector); std::unique_ptr targetTensor(new NDArray(*targets)); *targetTensor = (*targetVector * *targetTensor) + T(1.f); - const_cast(input)->applyTriplewiseLambda(const_cast(targets), targetTensor.get(), mainRoutineT2, output); + const_cast(input)->applyTriplewiseLambda(const_cast(*targets), *targetTensor.get(), mainRoutineT2, *output); } } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/lgamma.cpp b/libnd4j/include/ops/declarable/helpers/cpu/lgamma.cpp new file mode 100644 index 000000000..2978a9d45 --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cpu/lgamma.cpp @@ -0,0 +1,53 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author George A. Shulinok +// + +#include +#include + +namespace nd4j { +namespace ops { +namespace helpers { + +////////////////////////////////////////////////////////////////////////// +// calculate digamma function for array elements +template +static void lgamma_(NDArray& x, NDArray& z) { + + auto lgammaProc = LAMBDA_T(x_) { + return T(DataTypeUtils::fromT() == DataType::DOUBLE?::lgamma(x_): ::lgammaf(x_)); //math::nd4j_log(math::nd4j_gamma(x)); + }; + + x.applyLambda(lgammaProc, z); +} + +void lgamma(nd4j::LaunchContext* context, NDArray& x, NDArray& z) { + + BUILD_SINGLE_SELECTOR(x.dataType(), lgamma_, (x, z), FLOAT_TYPES); +} + +BUILD_SINGLE_TEMPLATE(template void lgamma_, (NDArray& x, NDArray& z), FLOAT_TYPES); + + + +} +} +} + diff --git a/libnd4j/include/ops/declarable/helpers/cpu/lstm.cpp b/libnd4j/include/ops/declarable/helpers/cpu/lstm.cpp index 922fdc3a9..683a82392 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/lstm.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/lstm.cpp @@ -86,7 +86,7 @@ void lstmCell(nd4j::LaunchContext * context, const NDArray* xt, const NDArray* h // if clipping value is provided then cell state is clipped by this value prior to the cell output activation if(clippingCellValue > 0.0) - ct->applyScalar(scalar::LstmClip, clippingCellValue); + ct->applyScalar(scalar::LstmClip, clippingCellValue, *ct); if(peephole) zot += (*ct) * (*Wc)({{2*nOut, 3*nOut}}); // add peephole connections to output gate zot + ct*Wc @@ -99,7 +99,7 @@ void lstmCell(nd4j::LaunchContext * context, const NDArray* xt, const NDArray* h ht->assign( mmul(htNoPeepHole, *Wp) ); // [bS x nOut] * [ nOut x numProj] = [bS x numProj] // if clipping projection is provided then projected cell output state is clipped by this value if(clippingProjValue != 0.) - ht->applyScalar(scalar::LstmClip, clippingProjValue); + ht->applyScalar(scalar::LstmClip, clippingProjValue, *ht); } else ht->assign(&htNoPeepHole); @@ -199,13 +199,13 @@ void lstmBlockCell(const NDArray* xt, const NDArray* cLast, const NDArray* yLast PRAGMA_OMP_SINGLE { PRAGMA_OMP_TASK - zz.applyTransform(transform::Tanh, z); //z = tanh(zz) + zz.applyTransform(transform::Tanh, *z); //z = tanh(zz) PRAGMA_OMP_TASK - zi.applyTransform(transform::Sigmoid, i); //i = sigmoid(zi) + zi.applyTransform(transform::Sigmoid, *i); //i = sigmoid(zi) PRAGMA_OMP_TASK - zf.applyTransform(transform::Sigmoid, f); //f = sigmoid(zf); + zf.applyTransform(transform::Sigmoid, *f); //f = sigmoid(zf); } if (z->ews() == 1 && i->ews() == 1 && c->ews() == 1 && cLast->ews() == 1 && f->ews() == 1 && h->ews() == 1 && @@ -214,15 +214,15 @@ void lstmBlockCell(const NDArray* xt, const NDArray* cLast, const NDArray* yLast BUILD_SINGLE_SELECTOR(z->dataType(), fusedTanh, (z, i, c, cLast, f, h), FLOAT_TYPES); } else { //cell state = blockInput .* inputGate + prevCellState .* forgetGate - z->applyPairwiseTransform(pairwise::Multiply, i, c, nullptr); //c = z * i + z->applyPairwiseTransform(pairwise::Multiply, *i, *c); //c = z * i auto temp = (*f) * (*cLast); *c += temp; //c = (i * z) + (zf * (*cLast)) - c->applyTransform(transform::Tanh, h); //h = tanh(c) + c->applyTransform(transform::Tanh, *h); //h = tanh(c) } // if clipping value is provided then cell state is clipped by this value prior to the cell output activation if(clippingCellValue > 0.0) - c->applyScalar(scalar::LstmClip, clippingCellValue); + c->applyScalar(scalar::LstmClip, clippingCellValue, *c); // add peephole connections to output gate zot + ct*Wc if(peephole) { @@ -230,11 +230,11 @@ void lstmBlockCell(const NDArray* xt, const NDArray* cLast, const NDArray* yLast zo += prod; } - zo.applyTransform(transform::Sigmoid, o); // o = sigmoid(zo) + zo.applyTransform(transform::Sigmoid, *o); // o = sigmoid(zo) // current cell output = ot*tanh(ct) - c->applyTransform(transform::Tanh, h); //h = tanh(c) - o->applyPairwiseTransform(pairwise::Multiply, h, y, nullptr); //y = o * h + c->applyTransform(transform::Tanh, *h); //h = tanh(c) + o->applyPairwiseTransform(pairwise::Multiply, *h, *y); //y = o * h } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/lup.cpp b/libnd4j/include/ops/declarable/helpers/cpu/lup.cpp index 76817078b..9c7cb1bfe 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/lup.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/lup.cpp @@ -22,6 +22,8 @@ #include #include #include +#include +#include namespace nd4j { namespace ops { @@ -32,15 +34,30 @@ namespace helpers { if (theFirst != theSecond) for (int i = 0; i < matrix->columns(); i++) { - T e0 = matrix->e(theFirst, i); - T e1 = matrix->e(theSecond, i); - - matrix->p(theFirst, i, e1); - matrix->p(theSecond, i, e0); + math::nd4j_swap(matrix->t(theFirst, i), matrix->t(theSecond, i)); } } BUILD_SINGLE_TEMPLATE(template void swapRows_, (NDArray* matrix, int theFirst, int theSecond), FLOAT_TYPES); + template + static void swapRows(T* matrixBuf, Nd4jLong* matrixShape, Nd4jLong theFirst, Nd4jLong theSecond) { + if (theFirst != theSecond) { + auto n = shape::sizeAt(matrixShape, -1); + + auto loop = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i += increment) { + Nd4jLong theFirstPos[] = {theFirst, i}; + Nd4jLong theSecondPos[] = {theSecond, i}; + auto theFirstIndex = shape::getOffset(matrixShape, theFirstPos, 0); + auto theSecondIndex = shape::getOffset(matrixShape, theSecondPos, 0); + math::nd4j_swap(matrixBuf[theFirstIndex], matrixBuf[theSecondIndex]); + } + }; + + samediff::Threads::parallel_tad(loop, 0, n, 1); + } + } + void swapRows(NDArray* matrix, int theFirst, int theSecond) { BUILD_SINGLE_SELECTOR(matrix->dataType(), swapRows_, (matrix, theFirst, theSecond), FLOAT_TYPES); } @@ -106,7 +123,7 @@ namespace helpers { } - template + template static NDArray lup_(LaunchContext *context, NDArray* input, NDArray* compound, NDArray* permutation) { const int rowNum = input->rows(); @@ -132,7 +149,7 @@ namespace helpers { } } - if( pivotValue > T(0.00001)) { + if( pivotValue > DataTypeUtils::min()) { swapRows(&compoundMatrix, pivot, i); swapRows(&permutationMatrix, pivot, i); if (pivot != i) @@ -155,14 +172,113 @@ namespace helpers { if (swapCount % 2) determinant = -determinant; if (compound != nullptr) compound->assign(compoundMatrix); - if (permutation != nullptr) - permutation->assign(permutationMatrix); + if (permutation != nullptr) { + auto permutaionVector = NDArrayFactory::create('c', {rowNum}, DataTypeUtils::fromT(), input->getContext()); + for (auto i = 0; i < rowNum; i++) { + for (auto j = 0; j < columnNum; j++) { + if (permutationMatrix.t(i, j) != 0) { + permutaionVector.template t(i) = j; + } + } + } + if (permutationMatrix.isSameShape(permutation)) + permutation->assign(permutationMatrix); + else if (permutation->isSameShape(permutaionVector)) { + permutation->assign(permutaionVector); + } + } return determinant; } - BUILD_SINGLE_TEMPLATE(template NDArray lup_, (LaunchContext *context, NDArray* input, NDArray* output, NDArray* permutation), FLOAT_TYPES); + BUILD_DOUBLE_TEMPLATE(template NDArray lup_, (LaunchContext *context, NDArray* input, NDArray* output, NDArray* permutation), FLOAT_TYPES, INDEXING_TYPES); + /* + * lu decomposition with naive algorithm with partial pivoting + * */ + template + static I argmaxCol(I column, T* compoundBuffer, Nd4jLong* compoundShape) { + auto rowNum = shape::sizeAt(compoundShape, 0); + Nd4jLong xInitial[] = {column, column}; + auto xInitialIndex = shape::getOffset(compoundShape, xInitial, 0); + auto maxValue = T(0); //nd4j::math::nd4j_abs(compoundBuffer[xInitialIndex]); + auto result = -1; + //auto loop = PRAGMA_THREADS_FOR { + auto start = column, stop = rowNum, increment = 1; + for (auto rowCounter = start; rowCounter < stop; rowCounter += increment) { + Nd4jLong xPos[] = {rowCounter, column}; + auto xIndex = shape::getOffset(compoundShape, xPos, 0); + if (nd4j::math::nd4j_abs(compoundBuffer[xIndex]) > maxValue) { + maxValue = nd4j::math::nd4j_max(maxValue, nd4j::math::nd4j_abs(compoundBuffer[xIndex])); + result = rowCounter; + } + } + //}; + //samediff::Threads::parallel_for(loop, column, rowNum, 1); + return result; + } + template + void processColumns(int currentRow, int rowNum, T* compoundBuf, Nd4jLong* compoundShape) { + Nd4jLong xDiag[] = {currentRow, currentRow}; + auto diagIndex = shape::getOffset(compoundShape, xDiag, 0); + auto loop = PRAGMA_THREADS_FOR { + for (int j = start; j < stop; j += increment) { + Nd4jLong xRow[] = {j, currentRow}; + auto rowIndex = shape::getOffset(compoundShape, xRow, 0); + compoundBuf[rowIndex] /= compoundBuf[diagIndex]; //output->t(i, i); + for (int k = currentRow + 1; k < rowNum; k++) { + Nd4jLong yRow[] = {j, k}; + Nd4jLong yCol[] = {currentRow, k}; + auto rowIndexY = shape::getOffset(compoundShape, yRow, 0); + auto colIndex = shape::getOffset(compoundShape, yCol, 0); + compoundBuf[rowIndexY] -= compoundBuf[rowIndex] * compoundBuf[colIndex]; + } + } + }; + samediff::Threads::parallel_tad(loop, currentRow + 1, rowNum, 1); + } + template + static void luNN_(LaunchContext *context, NDArray* compound, NDArray* permutation, Nd4jLong rowNum) { + + //const int rowNum = compound->rows(); +// const int columnNum = output->columns(); + permutation->linspace(0); + auto permutationBuf = permutation->bufferAsT(); //dataBuffer()->primaryAsT(); + auto compoundBuf = compound->bufferAsT(); + auto compoundShape = compound->shapeInfo(); + auto permutationShape = permutation->shapeInfo(); + for (auto i = 0; i < rowNum - 1; i++) { + auto pivotIndex = argmaxCol(i, compoundBuf, compoundShape); + if (pivotIndex < 0) { + throw std::runtime_error("helpers::luNN_: input matrix is singular."); + } + math::nd4j_swap(permutationBuf[shape::getIndexOffset(i, permutationShape)], permutationBuf[shape::getIndexOffset(pivotIndex, permutationShape)]); + swapRows(compoundBuf, compoundShape, i, pivotIndex); + + processColumns(i, rowNum, compoundBuf, compoundShape); + } + } + + template + static void lu_(LaunchContext * context, NDArray* input, NDArray* output, NDArray* permutationVectors) { + auto n = input->sizeAt(-1); + + output->assign(input); // fill up output tensor with zeros + ResultSet outputs = output->allTensorsAlongDimension({-2, -1}); + ResultSet permutations = permutationVectors->allTensorsAlongDimension({-1}); + auto loop = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i += increment) { + luNN_(context, outputs.at(i), permutations.at(i), n); + } + }; + samediff::Threads::parallel_for(loop, 0, outputs.size(), 1); + } + + void lu(LaunchContext *context, NDArray* input, NDArray* output, NDArray* permutation) { + BUILD_DOUBLE_SELECTOR(input->dataType(), permutation->dataType(), lu_, (context, input, output, permutation), FLOAT_TYPES, INDEXING_TYPES); + } + +// BUILD_DOUBLE_TEMPLATE(template NDArray lu_, (LaunchContext *context, NDArray* input, NDArray* output, NDArray* permutation), FLOAT_TYPES, INDEXING_TYPES); template static int determinant_(LaunchContext *context, NDArray* input, NDArray* output) { @@ -175,7 +291,7 @@ namespace helpers { for (int e = 0; e < output->lengthOf(); e++) { for (int k = e * n2, row = 0; k < (e + 1) * n2; ++k, ++row) matrix.p(row, input->e(k)); - output->p(e, lup_(context, &matrix, (NDArray*)nullptr, (NDArray*)nullptr)); + output->p(e, lup_(context, &matrix, (NDArray*)nullptr, (NDArray*)nullptr)); } return Status::OK(); @@ -196,7 +312,7 @@ template for (int k = e * n2, row = 0; k < (e + 1) * n2; ++k, ++row) { matrix.p(row, input->e(k)); } - NDArray det = lup_(context, &matrix, (NDArray*)nullptr, (NDArray*)nullptr); + NDArray det = lup_(context, &matrix, (NDArray*)nullptr, (NDArray*)nullptr); if (det.e(0) != 0.f) output->p(e, nd4j::math::nd4j_log(nd4j::math::nd4j_abs(det.t(0)))); } @@ -229,7 +345,7 @@ template for (int k = e * n2, row = 0; k < (e + 1) * n2; k++) { matrix.p(row++, input->e(k)); } - T det = lup_(context, &matrix, &compound, &permutation).template e(0); + T det = lup_(context, &matrix, &compound, &permutation).template e(0); // FIXME: and how this is going to work on float16? if (nd4j::math::nd4j_abs(det) < T(0.000001)) { @@ -268,13 +384,13 @@ template template static bool checkCholeskyInput_(nd4j::LaunchContext * context, NDArray const* input) { //std::unique_ptr matrix(NDArrayFactory::create_('c', {n, n}, input->dataType())); //, block.getWorkspace()); - std::unique_ptr lastMatrixList(input->allTensorsAlongDimension({input->rankOf() - 2, input->rankOf()-1})); - for (size_t i = 0; i < lastMatrixList->size(); i++) { - auto thisMatrix = lastMatrixList->at(i); + ResultSet lastMatrixList = input->allTensorsAlongDimension({input->rankOf() - 2, input->rankOf()-1}); + for (size_t i = 0; i < lastMatrixList.size(); i++) { + auto thisMatrix = lastMatrixList.at(i); // check for symmetric for (Nd4jLong r = 0; r < thisMatrix->rows(); r++) for (Nd4jLong c = 0; c < thisMatrix->columns(); c++) - if (nd4j::math::nd4j_abs(thisMatrix->e(r, c) - lastMatrixList->at(i)->e(c,r)) > T(1.e-6f)) return false; + if (nd4j::math::nd4j_abs(thisMatrix->e(r, c) - lastMatrixList.at(i)->e(c,r)) > DataTypeUtils::min()) return false; NDArray output = NDArrayFactory::create(0., context); if (ND4J_STATUS_OK != determinant(context, thisMatrix, &output)) return false; @@ -343,21 +459,18 @@ template template int logdetFunctor_(LaunchContext *context, NDArray* input, NDArray* output) { - std::unique_ptr tempOutput(input->dup()); - int res = cholesky_(context, input, tempOutput.get(), false); + auto tempOutput = input->dup(); + int res = cholesky_(context, input, &tempOutput, false); if (res != ND4J_STATUS_OK) return res; auto n = input->sizeAt(-1); auto totalCount = output->lengthOf(); std::vector d(n); - std::unique_ptr matricies(tempOutput->allTensorsAlongDimension({input->rankOf()-2, input->rankOf() - 1})); - std::unique_ptr inputMatricies(input->allTensorsAlongDimension({input->rankOf()-2, input->rankOf() - 1})); - for (Nd4jLong e = 0; e < totalCount; e++) { + ResultSet matricies = tempOutput.allTensorsAlongDimension({input->rankOf()-2, input->rankOf() - 1}); - //d[0] = inputMatricies->at(e)->t(0, 0); - for (size_t i = 0; i < n; ++i) { - output->t(e) += nd4j::math::nd4j_log(nd4j::math::nd4j_pow(matricies->at(e)->t(i, i), T(2))); - } + for (Nd4jLong e = 0; e < totalCount; e++) { + for (size_t i = 0; i < n; ++i) + output->t(e) += nd4j::math::nd4j_log(nd4j::math::nd4j_pow(matricies.at(e)->t(i, i), T(2))); } return ND4J_STATUS_OK; } @@ -366,6 +479,11 @@ template BUILD_SINGLE_SELECTOR(input->dataType(), return logdetFunctor_, (context, input, output), FLOAT_TYPES); } + int lup(nd4j::LaunchContext * context, NDArray* input, NDArray* compound, NDArray* permutation) { + BUILD_DOUBLE_SELECTOR(input->dataType(), permutation->dataType(), lup_, (context, input, compound, permutation), FLOAT_NATIVE, INDEXING_TYPES); + return Status::OK(); + } + } } } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/matrix_band.cpp b/libnd4j/include/ops/declarable/helpers/cpu/matrix_band.cpp index 399d89e32..fbab49e80 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/matrix_band.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/matrix_band.cpp @@ -30,11 +30,11 @@ namespace helpers { Nd4jLong N = input->sizeAt(-1); Nd4jLong lastDim = input->rankOf() - 1; Nd4jLong preLastDim = input->rankOf() - 2; - std::unique_ptr listOut(output->allTensorsAlongDimension({(int)preLastDim, (int)lastDim})); - std::unique_ptr listDiag(input->allTensorsAlongDimension({(int)preLastDim, (int)lastDim})); - for (Nd4jLong e = 0; e < listOut->size(); ++e) { - NDArray* inputMatrix = listDiag->at(e); - NDArray* outputMatrix = listOut->at(e); + ResultSet listOut = output->allTensorsAlongDimension({(int)preLastDim, (int)lastDim}); + ResultSet listDiag = input->allTensorsAlongDimension({(int)preLastDim, (int)lastDim}); + for (Nd4jLong e = 0; e < listOut.size(); ++e) { + NDArray* inputMatrix = listDiag.at(e); + NDArray* outputMatrix = listOut.at(e); if (outputMatrix != inputMatrix) // if not inplace outputMatrix->assign(inputMatrix); if (lowerBand >= 0) { diff --git a/libnd4j/include/ops/declarable/helpers/cpu/matrix_diag_part.cpp b/libnd4j/include/ops/declarable/helpers/cpu/matrix_diag_part.cpp index e0e487e82..cc43c1866 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/matrix_diag_part.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/matrix_diag_part.cpp @@ -37,24 +37,21 @@ int _matrixDiagPart(const NDArray* input, NDArray* output) { auto listOut = output->allTensorsAlongDimension({output->rankOf() - 1}); auto listDiag = input->allTensorsAlongDimension({input->rankOf() - 2, input->rankOf() - 1}); - if (listOut->size() != listDiag->size()) { + if (listOut.size() != listDiag. size()) { nd4j_printf("matrix_diag_part: Input matrix has wrong shape.", ""); return ND4J_STATUS_VALIDATION; } int lastDimension = nd4j::math::nd4j_min(input->sizeAt(-2), input->sizeAt(-1)); // TODO: tune this properlys - int lO = listOut->size(); + int lO = listOut.size(); auto func = PRAGMA_THREADS_FOR { for (auto i = start; i < stop; i += increment) for (int j = 0; j < lastDimension; ++j) - listOut->at(i)->p(j, listDiag->at(i)->e(j, j)); + listOut.at(i)->p(j, listDiag.at(i)->e(j, j)); }; samediff::Threads::parallel_tad(func, 0, lO); - - delete listOut; - delete listDiag; return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/meshgrid.cpp b/libnd4j/include/ops/declarable/helpers/cpu/meshgrid.cpp index b59c16afe..a8a0d919d 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/meshgrid.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/meshgrid.cpp @@ -39,13 +39,11 @@ void meshgrid(nd4j::LaunchContext * context, const std::vector& inArrs inIndices[0] = 1; inIndices[1] = 0; } - - for(int i = 0; i < rank; ++i) { - auto list = outArrs[i]->allTensorsAlongDimension({inIndices[i]}); - for(int j = 0; j < list->size(); ++j) - list->at(j)->assign(inArrs[i]); - delete list; + for(int i = 0; i < rank; ++i) { + auto list = outArrs[i]->allTensorsAlongDimension({inIndices[i]}); + for(int j = 0; j < list.size(); ++j) + list.at(j)->assign(inArrs[i]); } } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/minimax.cpp b/libnd4j/include/ops/declarable/helpers/cpu/minimax.cpp index 61b6465ba..8d94d23ca 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/minimax.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/minimax.cpp @@ -27,7 +27,7 @@ namespace nd4j { namespace ops { namespace helpers { - template + template static void minimumBPFunctor_(NDArray* x, NDArray* y, NDArray* epsNext, NDArray* gradX, NDArray* gradY) { auto lambdaX = LAMBDA_TTT(_e, _x, _y) { @@ -43,10 +43,10 @@ namespace helpers { // PWT case case // X gradient - epsNext->applyTriplewiseLambda(x, y, lambdaX, gradX); + epsNext->applyTriplewiseLambda(*x, *y, lambdaX, *gradX); // Y gradient - epsNext->applyTriplewiseLambda(x, y, lambdaY, gradY); + epsNext->applyTriplewiseLambda(*x, *y, lambdaY, *gradY); } else if (y->isScalar()) { T s = y->e(0); @@ -60,8 +60,8 @@ namespace helpers { gradY->assign(tmp); else gradY->assign(0.0f); - - epsNext->applyPairwiseLambda(x, lambdaS, gradX); + + epsNext->applyPairwiseLambda(*x, lambdaS, *gradX); } else { // broadcast case @@ -71,8 +71,8 @@ namespace helpers { auto targetShape = epsNext->getShapeAsVector(); - preX->tileToShape(targetShape); - preY->tileToShape(targetShape); + preX.tileToShape(targetShape, preX); + preY.tileToShape(targetShape, preY); epsNext->applyTriplewiseLambda(preX, preY, lambdaX, preX); epsNext->applyTriplewiseLambda(preX, preY, lambdaY, preY); @@ -81,22 +81,16 @@ namespace helpers { auto axisY = ShapeUtils::evalBroadcastBackwardAxis(y->shapeInfo(), epsNext->shapeInfo()); if (axisX.size() > 0) { - auto sum = preX->reduceAlongDimension(reduce::Sum, axisX); + auto sum = preX.reduceAlongDimension(reduce::Sum, axisX); gradX->assign(sum); - delete sum; - } else + } else gradX->assign(preX); if (axisY.size() > 0) { - auto sum = preY->reduceAlongDimension(reduce::Sum, axisY); + auto sum = preY.reduceAlongDimension(reduce::Sum, axisY); gradY->assign(sum); - delete sum; } else gradY->assign(preY); - - - delete preX; - delete preY; } } @@ -116,10 +110,10 @@ namespace helpers { // PWT case case // X gradient - epsNext->applyTriplewiseLambda(x, y, lambdaX, gradX); + epsNext->applyTriplewiseLambda(*x, *y, lambdaX, *gradX); // Y gradient - epsNext->applyTriplewiseLambda(x, y, lambdaY, gradY); + epsNext->applyTriplewiseLambda(*x, *y, lambdaY, *gradY); } else if (y->isScalar()) { T s = y->e(0); @@ -133,8 +127,8 @@ namespace helpers { gradY->assign(tmp); else gradY->assign(0.0f); - - epsNext->applyPairwiseLambda(x, lambdaS, gradX); + + epsNext->applyPairwiseLambda(*x, lambdaS, *gradX); } else { // broadcast case @@ -144,8 +138,8 @@ namespace helpers { auto targetShape = epsNext->getShapeAsVector(); - preX->tileToShape(targetShape); - preY->tileToShape(targetShape); + preX.tileToShape(targetShape, preX); + preY.tileToShape(targetShape, preY); epsNext->applyTriplewiseLambda(preX, preY, lambdaX, preX); epsNext->applyTriplewiseLambda(preX, preY, lambdaY, preY); @@ -154,22 +148,16 @@ namespace helpers { auto axisY = ShapeUtils::evalBroadcastBackwardAxis(y->shapeInfo(), epsNext->shapeInfo()); if (axisX.size() > 0) { - auto sum = preX->reduceAlongDimension(reduce::Sum, axisX); + auto sum = preX.reduceAlongDimension(reduce::Sum, axisX); gradX->assign(sum); - delete sum; - } else + } else gradX->assign(preX); if (axisY.size() > 0) { - auto sum = preY->reduceAlongDimension(reduce::Sum, axisY); + auto sum = preY.reduceAlongDimension(reduce::Sum, axisY); gradY->assign(sum); - delete sum; } else gradY->assign(preY); - - - delete preX; - delete preY; } } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/nth_element.cpp b/libnd4j/include/ops/declarable/helpers/cpu/nth_element.cpp index 8c5332be6..dcca5075e 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/nth_element.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/nth_element.cpp @@ -51,12 +51,12 @@ namespace helpers { SpecialMethods::sortTadGeneric(sortedVals.buffer(), sortedVals.shapeInfo(), lastDims.data(), lastDims.size(), pack.primaryShapeInfo(), pack.primaryOffsets(), reverse); - std::unique_ptr rows(sortedVals.allTensorsAlongDimension(lastDims)); + ResultSet rows = sortedVals.allTensorsAlongDimension(lastDims); Nd4jLong oL = output->lengthOf(); auto func = PRAGMA_THREADS_FOR { for (auto e = start; e < stop; e += increment) { - auto row = rows->at(e); + auto row = rows.at(e); output->p(e, row->e(n)); } }; @@ -70,7 +70,7 @@ namespace helpers { } BUILD_SINGLE_TEMPLATE(template void nthElementFunctor_, (NDArray* input, Nd4jLong n, NDArray* output, bool reverse), LIBND4J_TYPES); - + } } } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/percentile.cpp b/libnd4j/include/ops/declarable/helpers/cpu/percentile.cpp index 5c1f3c28d..fa8061e54 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/percentile.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/percentile.cpp @@ -30,8 +30,8 @@ namespace helpers { ////////////////////////////////////////////////////////////////////////// template static void _percentile(const NDArray& input, NDArray& output, std::vector& axises, const float q, const int interpolation) { - - const int inputRank = input.rankOf(); + + const int inputRank = input.rankOf(); if(axises.empty()) for(int i=0; i& auto listOfSubArrs = input.allTensorsAlongDimension(axises); - - std::vector shapeOfSubArr(listOfSubArrs->at(0)->rankOf()); + + std::vector shapeOfSubArr(listOfSubArrs.at(0)->rankOf()); for(int i=0; iat(0)->shapeOf()[i]; + shapeOfSubArr[i] = listOfSubArrs.at(0)->shapeOf()[i]; auto flattenedArr = NDArrayFactory::create('c', shapeOfSubArr, input.dataType(), input.getContext()); const int len = flattenedArr.lengthOf(); - + const float fraction = 1.f - q / 100.; Nd4jLong position = 0; - + switch(interpolation) { case 0: // lower position = static_cast(math::nd4j_ceil((len - 1) * fraction)); @@ -67,15 +67,13 @@ static void _percentile(const NDArray& input, NDArray& output, std::vector& // FIXME: our sort impl should be used instead, so this operation might be implemented as generic // FIXME: parallelism ! - for(int i=0; isize(); ++i) { - + for(int i=0; i(flattenedArr.getBuffer()); - flattenedArr.assign(listOfSubArrs->at(i)); + flattenedArr.assign(listOfSubArrs.at(i)); std::sort(buff, buff + len); output.p(i, flattenedArr.e(position)); } - - delete listOfSubArrs; } void percentile(nd4j::LaunchContext * context, const NDArray& input, NDArray& output, std::vector& axises, const float q, const int interpolation) { diff --git a/libnd4j/include/ops/declarable/helpers/cpu/prefix.cpp b/libnd4j/include/ops/declarable/helpers/cpu/prefix.cpp index f46346876..43c65f14b 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/prefix.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/prefix.cpp @@ -101,17 +101,14 @@ namespace nd4j { static void prefix_(scalar::Ops op, const NDArray* x, NDArray* z, const std::vector& dims, bool exclusive, bool reverse) { auto xTads = x->allTensorsAlongDimension(dims); auto zTads = z->allTensorsAlongDimension(dims); - auto t = xTads->size(); + auto t = xTads.size(); for (int e = 0; e < t; e++) { - auto tx = xTads->at(e); - auto tz = zTads->at(e); + auto tx = xTads.at(e); + auto tz = zTads.at(e); prefix_(op, tx->buffer(), tx->shapeInfo(), tz->buffer(), tz->shapeInfo(), exclusive, reverse); } - - delete xTads; - delete zTads; }; template diff --git a/libnd4j/include/ops/declarable/helpers/cpu/print_variable.cpp b/libnd4j/include/ops/declarable/helpers/cpu/print_variable.cpp new file mode 100644 index 000000000..293518be6 --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cpu/print_variable.cpp @@ -0,0 +1,31 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include + +namespace nd4j { + namespace ops { + namespace helpers { + void print_special(LaunchContext &ctx, const NDArray &array, const std::string &message) { + array.printIndexedBuffer(message.c_str()); + } + } + } +} diff --git a/libnd4j/include/ops/declarable/helpers/cpu/qr.cpp b/libnd4j/include/ops/declarable/helpers/cpu/qr.cpp new file mode 100644 index 000000000..90b69ca6f --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cpu/qr.cpp @@ -0,0 +1,133 @@ +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author George A. Shulinok +// +#include +#include +#include +#include + +namespace nd4j { +namespace ops { +namespace helpers { + + template + NDArray matrixMinor(NDArray& in, Nd4jLong col) { + NDArray m = in.ulike(); + m.setIdentity(); + m({col, m.rows(), col, m.columns()}).assign(in({col, m.rows(), col, m.columns()})); + + return m; + } + +/* m = I - v v^T */ + template + NDArray vmul(NDArray const& v, int n) + { + NDArray res('c', {n,n}, v.dataType()); // x = matrix_new(n, n); + T const* vBuf = v.getDataBuffer()->primaryAsT(); + T* resBuf = res.dataBuffer()->primaryAsT(); + auto interloop = PRAGMA_THREADS_FOR_2D { + for (int i = start_x; i < n; i += inc_x) + for (int j = start_y; j < n; j += inc_y) + resBuf[i * n + j] = -2 * vBuf[i] * vBuf[j] + (i == j ? T(1) : T(0)); + }; + + samediff::Threads::parallel_for(interloop, 0, n, 1, 0, n, 1); + return res; + } + + template + void qrSingle(NDArray* matrix, NDArray* Q, NDArray* R, bool const fullMatricies) { + Nd4jLong M = matrix->sizeAt(-2); + Nd4jLong N = matrix->sizeAt(-1); + auto resQ = fullMatricies?Q->ulike():NDArrayFactory::create(matrix->ordering(), {M,M}, Q->getContext()); + auto resR = fullMatricies?R->ulike():matrix->ulike(); + std::vector q(M); + + NDArray z = *matrix; + NDArray e('c', {M}, DataTypeUtils::fromT()); // two internal buffers and scalar for squared norm + + for (auto k = 0; k < N && k < M - 1; k++) { // loop for columns, but not further then row number + e.nullify(); + z = matrixMinor(z, k); // minor computing for current column with given matrix z (initally is a input matrix) +// z.printIndexedBuffer("Minor!!!"); + + auto currentColumn = z({0, 0, k, k + 1}); // retrieve k column from z to x buffer + auto norm = currentColumn.reduceAlongDimension(reduce::Norm2, {0}); + if (matrix->t(k,k) > T(0.f)) // negate on positive matrix diagonal element + norm *= T(-1.f);//.applyTransform(transform::Neg, nullptr, nullptr); //t(0) = -norm.t(0); + //e.t(k) = T(1.f); // e - is filled by 0 vector except diagonal element (filled by 1) + //auto tE = e; + //tE *= norm; +// norm.printIndexedBuffer("Norm!!!"); + e.p(k, norm); + e += currentColumn;// e += tE; // e[i] = x[i] + a * e[i] for each i from 0 to n - 1 + auto normE = e.reduceAlongDimension(reduce::Norm2, {0}); + e /= normE; + q[k] = vmul(e, M); + auto qQ = z.ulike(); + MmulHelper::matmul(&q[k], &z, &qQ, false, false); + z = std::move(qQ); + } + resQ.assign(q[0]); // +// MmulHelper::matmul(&q[0], matrix, &resR, false, false); + for (int i = 1; i < N && i < M - 1; i++) { + auto tempResQ = resQ; + MmulHelper::matmul(&q[i], &resQ, &tempResQ, false, false); // use mmulMxM? + resQ = std::move(tempResQ); + } + MmulHelper::matmul(&resQ, matrix, &resR, false, false); + // resR *= -1.f; + resQ.transposei(); + if (fullMatricies) { + Q->assign(resQ); + R->assign(resR); + } + else { + Q->assign(resQ({0,0, 0, N})); + R->assign(resR({0,N, 0, 0})); + } + } + + template + void qr_(NDArray* input, NDArray* outputQ, NDArray* outputR, bool const fullMatricies) { + Nd4jLong lastDim = input->rankOf() - 1; + Nd4jLong preLastDim = input->rankOf() - 2; + ResultSet listOutQ(outputQ->allTensorsAlongDimension({(int)preLastDim, (int)lastDim})); + ResultSet listOutR(outputR->allTensorsAlongDimension({(int)preLastDim, (int)lastDim})); + ResultSet listInput(input->allTensorsAlongDimension({(int)preLastDim, (int)lastDim})); + auto batching = PRAGMA_THREADS_FOR { + for (auto batch = start; batch < stop; batch += increment) { + //qr here + qrSingle(listInput.at(batch), listOutQ.at(batch), listOutR.at(batch), fullMatricies); + } + }; + + samediff::Threads::parallel_tad(batching, 0, listOutQ.size(), 1); + + } + + void qr(nd4j::LaunchContext* context, NDArray* input, NDArray* outputQ, NDArray* outputR, bool const fullMatricies) { + BUILD_SINGLE_SELECTOR(input->dataType(), qr_, (input, outputQ, outputR, fullMatricies), FLOAT_TYPES); + } + +} +} +} + diff --git a/libnd4j/include/ops/declarable/helpers/cpu/random.cpp b/libnd4j/include/ops/declarable/helpers/cpu/random.cpp index 3f9788330..ad04db307 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/random.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/random.cpp @@ -24,6 +24,8 @@ //#include #include #include +#include +#include namespace nd4j { namespace ops { @@ -46,8 +48,8 @@ namespace helpers { NDArray alphaBroadcasted(broadcasted, alpha->dataType(), false, context); NDArray betaBroadcasted(broadcasted, beta->dataType(), false, context); - copyAlpha = (alphaBroadcasted.applyTrueBroadcast(BroadcastOpsTuple::Assign(), alpha)); - copyBeta = (betaBroadcasted.applyTrueBroadcast(BroadcastOpsTuple::Assign(), beta)); + copyAlpha = new NDArray(alphaBroadcasted.applyTrueBroadcast(BroadcastOpsTuple::Assign(), *alpha)); + copyBeta = new NDArray(betaBroadcasted.applyTrueBroadcast(BroadcastOpsTuple::Assign(), *beta)); } // bool directAlpha = alpha->ews() == 1 && alpha->ordering() == 'c'; @@ -150,6 +152,61 @@ namespace helpers { void fillRandomUniform(LaunchContext* context, graph::RandomGenerator& rng, NDArray* min, NDArray* max, NDArray* output) { BUILD_SINGLE_SELECTOR(output->dataType(), fillRandomUniform_, (context, rng, min, max, output), NUMERIC_TYPES); } + + // used https://en.wikipedia.org/wiki/Categorical_distribution + // methods: gumbel trick + softmax + argmax + template + void fillRandomMultiNomial_(LaunchContext* context, graph::RandomGenerator& rng, NDArray& input, NDArray& output, const Nd4jLong numOfSamples, const int dimC) { + + const Tx* x = input.bufferAsT(); + Tz* z = output.bufferAsT(); + + Tx minVal = DataTypeUtils::min(); + Tx maxVal = 1.0; + + auto dimA = (0 == dimC) ? 1 : 0; + const Nd4jLong batchValue = output.sizeAt(dimC); + const Nd4jLong numOfClassX = input.sizeAt(dimA); + + const Nd4jLong zDimAstride = output.stridesOf()[dimA]; + const Nd4jLong xDimAstride = input.stridesOf()[dimA]; + const Nd4jLong zDimCstride = output.stridesOf()[dimC]; + const Nd4jLong xDimCstride = input.stridesOf()[dimC]; + + auto func = PRAGMA_THREADS_FOR_2D{ + for (auto nBatchIndex = start_x; nBatchIndex < stop_x; nBatchIndex += inc_x) { + for (auto nSampleIndexInBatch = start_y; nSampleIndexInBatch < stop_y; nSampleIndexInBatch += inc_y) { + + const Tx* xTad = x + (nBatchIndex * xDimCstride); + Tz* zTad = z + (nBatchIndex * zDimCstride); + Tz& arg = zTad[nSampleIndexInBatch * zDimAstride]; + Tx Max = -minVal; + + auto nSamplesPerBatch = nBatchIndex * numOfClassX * numOfSamples; + auto nClassesPerSample = nSampleIndexInBatch * numOfClassX; + for (auto nClass = 0; nClass < numOfClassX; nClass += 1) { + auto nIndex = nSamplesPerBatch + nClassesPerSample + nClass; + auto unifornLog = nd4j::math::nd4j_log(-nd4j::math::nd4j_log(rng.relativeT(nIndex, minVal, maxVal))); + Tx tValue = (xTad[nClass * xDimAstride] - unifornLog); + if (tValue > Max) { + Max = tValue; + arg = nClass; + } + } + } + } + }; + + samediff::Threads::parallel_for(func, 0, batchValue, 1, 0, numOfSamples, 1); + rng.rewindH(output.lengthOf()*numOfClassX); + + return; + } + + void fillRandomMultiNomial(LaunchContext* context, graph::RandomGenerator& rng, NDArray& input, NDArray& output, const Nd4jLong numOfSamples, const int dimC) { + BUILD_DOUBLE_SELECTOR(input.dataType(), output.dataType(), fillRandomMultiNomial_, (context, rng, input, output, numOfSamples, dimC), FLOAT_TYPES, INDEXING_TYPES); + } + } } } \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/reverse.cpp b/libnd4j/include/ops/declarable/helpers/cpu/reverse.cpp index 9f424606d..9ee906bd5 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/reverse.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/reverse.cpp @@ -171,27 +171,21 @@ static void reverseSequence_(nd4j::LaunchContext * context, const NDArray* input auto inSubArrsSet = input->allTensorsAlongDimension(dimensions); auto outSubArrsSet = output->allTensorsAlongDimension(dimensions); - for(int i = 0; i < inSubArrsSet->size(); ++i) { + for(int i = 0; i < inSubArrsSet.size(); ++i) { Nd4jLong numOfElemsToReverse = seqLengths->e(i); if(numOfElemsToReverse == 0 || numOfElemsToReverse == 1) { - outSubArrsSet->at(i)->assign(inSubArrsSet->at(i)); + outSubArrsSet.at(i)->assign(inSubArrsSet.at(i)); } else { - auto inInnerSet = inSubArrsSet->at(i)->allTensorsAlongDimension({seqDim}); - auto outInnerSet = outSubArrsSet->at(i)->allTensorsAlongDimension({seqDim}); - for(int j = 0; j < inInnerSet->size(); ++j) - helpers::reverseArray(context, inInnerSet->at(j)->getBuffer(), inInnerSet->at(j)->getShapeInfo(), outInnerSet->at(j)->getBuffer(), outInnerSet->at(j)->getShapeInfo(), numOfElemsToReverse); - - delete inInnerSet; - delete outInnerSet; + auto inInnerSet = inSubArrsSet.at(i)->allTensorsAlongDimension({seqDim}); + auto outInnerSet = outSubArrsSet.at(i)->allTensorsAlongDimension({seqDim}); + for(int j = 0; j < inInnerSet.size(); ++j) + helpers::reverseArray(context, inInnerSet.at(j)->getBuffer(), inInnerSet.at(j)->getShapeInfo(), outInnerSet.at(j)->getBuffer(), outInnerSet.at(j)->getShapeInfo(), numOfElemsToReverse); } } - delete inSubArrsSet; - delete outSubArrsSet; } - } void reverseSequence(nd4j::LaunchContext * context, const NDArray* input, const NDArray* seqLengths, NDArray* output, int seqDim, const int batchDim) { @@ -209,14 +203,11 @@ void reverse(nd4j::LaunchContext * context, const NDArray* input, NDArray* outpu NDArray *subArrIn, *subArrOut; - for(int i = 0; i < listIn->size(); ++i) { // listIn->size() = listOut->size() - subArrIn = listIn->at(i); - subArrOut = listOut->at(i); + for(int i = 0; i < listIn.size(); ++i) { // listIn.size() = listOut.size() + subArrIn = listIn.at(i); + subArrOut = listOut.at(i); BUILD_SINGLE_SELECTOR(input->dataType(), helpers::reverseArray, (context, subArrIn->getBuffer(), subArrIn->getShapeInfo(), subArrOut->getBuffer(), subArrOut->getShapeInfo()), LIBND4J_TYPES); } - - delete listOut; - delete listIn; } BUILD_SINGLE_TEMPLATE(template void reverseSequence_, (nd4j::LaunchContext * context, const NDArray* input, const NDArray* seqLengths, NDArray* output, int seqDim, const int batchDim), LIBND4J_TYPES); diff --git a/libnd4j/include/ops/declarable/helpers/cpu/roll.cpp b/libnd4j/include/ops/declarable/helpers/cpu/roll.cpp index b3b65f816..8bfc1ca1a 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/roll.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/roll.cpp @@ -40,8 +40,8 @@ namespace helpers { if (actualShift) { int shiftCount = fullLen / actualShift - 1; - int remainShift = fullLen % actualShift; - + int remainShift = fullLen % actualShift; + // stage 1) swap last actualShift elements with first ones. //PRAGMA_OMP_PARALLEL_FOR //_IF(actualShift > Environment::getInstance()->elementwiseThreshold()) for (int e = 0; e < actualShift; ++e) { @@ -70,7 +70,7 @@ namespace helpers { output->p(sourceIndex, _e0); } } - + // stage 3) swap remainer of items. if (remainShift && shiftCount) for (int i = actualShift; i < 2 * actualShift; ++i) { @@ -94,9 +94,9 @@ namespace helpers { for (size_t i = 0; i < axes.size(); i++) { int axe = axes[i]; if (axe == source->rankOf() - 1) {// last dimension - std::unique_ptr listOfTensors(source->allTensorsAlongDimension({axe})); - std::unique_ptr listOfOutTensors(output->allTensorsAlongDimension({axe})); - int fullLen = listOfTensors->size(); + ResultSet listOfTensors = source->allTensorsAlongDimension({axe}); + ResultSet listOfOutTensors = output->allTensorsAlongDimension({axe}); + int fullLen = listOfTensors.size(); int theShift = shifts[i]; if (theShift > 0) { theShift %= fullLen; @@ -105,7 +105,7 @@ namespace helpers { theShift -= fullLen * (theShift / fullLen - 1); } for (int k = 0; k < fullLen; k++) { - rollFunctorLinear(context, listOfTensors->at(k), listOfOutTensors->at(k), theShift, true); + rollFunctorLinear(context, listOfTensors.at(k), listOfOutTensors.at(k), theShift, true); } } else { @@ -113,10 +113,10 @@ namespace helpers { for (int i = 0; i < dims.size(); ++i) dims[i] = axe + 1 + i; - std::unique_ptr listOfTensors(source->allTensorsAlongDimension({dims})); - std::unique_ptr listOfOutTensors(output->allTensorsAlongDimension({dims})); + ResultSet listOfTensors = source->allTensorsAlongDimension({dims}); + ResultSet listOfOutTensors = output->allTensorsAlongDimension({dims}); // - int fullLen = listOfTensors->size(); + int fullLen = listOfTensors.size(); int sizeAt = input->sizeAt(axe); int theShift = shifts[i]; @@ -131,16 +131,16 @@ namespace helpers { if (theShift) { for (int dim = 0; dim < fullLen / sizeAt; ++dim) { for (int e = theShift; e < sizeAt - theShift; ++e) { - auto sourceM = listOfTensors->at(dim * sizeAt + e - theShift); - auto targetM = listOfOutTensors->at(dim * sizeAt + e); + auto sourceM = listOfTensors.at(dim * sizeAt + e - theShift); + auto targetM = listOfOutTensors.at(dim * sizeAt + e); sourceM->swapUnsafe(*targetM); } - + for (int e = 0; e < theShift; ++e) { int sourceIndex = dim * sizeAt + sizeAt - theShift + e; - auto sourceM = listOfTensors->at(sourceIndex); - auto targetM = listOfOutTensors->at(dim * sizeAt + e); - + auto sourceM = listOfTensors.at(sourceIndex); + auto targetM = listOfOutTensors.at(dim * sizeAt + e); + sourceM->swapUnsafe(*targetM); } } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/scatter.cpp b/libnd4j/include/ops/declarable/helpers/cpu/scatter.cpp index 9ae191c76..a3f0c01be 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/scatter.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/scatter.cpp @@ -83,7 +83,7 @@ void scatter(nd4j::LaunchContext *context, pairwise::Ops op, const NDArray& ind Nd4jLong idx = indices.e(i); NDArray out = output({idx, idx + 1}); - out.applyPairwiseTransform(op, updates.e(i), nullptr); + out.applyPairwiseTransform(op, updates.e(i)); } }; @@ -103,7 +103,7 @@ void scatter(nd4j::LaunchContext *context, pairwise::Ops op, const NDArray& ind NDArray outSubArr = output(indices.e(i), std::vector({0})); NDArray updSubArr = updates(i, dimsToExcludeUpd); - outSubArr.applyPairwiseTransform(op, updSubArr, nullptr); + outSubArr.applyPairwiseTransform(op, updSubArr); } }; @@ -150,7 +150,7 @@ void scatterND(nd4j::LaunchContext *context, pairwise::Ops op, const NDArray& i NDArray outSubArr = output(idxRangeOut); NDArray updSubArr = updates(i, dimsToExcludeUpd); - outSubArr.applyPairwiseTransform(op, updSubArr, nullptr); + outSubArr.applyPairwiseTransform(op, updSubArr); } }; diff --git a/libnd4j/include/ops/declarable/helpers/cpu/segment.cpp b/libnd4j/include/ops/declarable/helpers/cpu/segment.cpp index 2884107f3..e20145735 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/segment.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/segment.cpp @@ -22,7 +22,7 @@ #include #include #include -#include +#include namespace nd4j { namespace ops { @@ -56,31 +56,29 @@ namespace helpers { auto numOfClasses = output->sizeAt(0); // number of classes std::vector> outputs(numOfClasses); - auto maxT = listOfOutTensors->at(idx); + auto maxT = listOfOutTensors.at(idx); //int pos = 0; - maxT->assign(listOfTensors->at(0)); + maxT->assign(listOfTensors.at(0)); for (Nd4jLong i = 1; i < indices->lengthOf(); i++) { if (indices->e(i) == idx) { for (Nd4jLong e = 0; e < maxT->lengthOf(); e++) { - maxT->t(e) = nd4j::math::nd4j_max(maxT->t(e), listOfTensors->at(i)->t(e)); + maxT->t(e) = nd4j::math::nd4j_max(maxT->t(e), listOfTensors.at(i)->t(e)); } } else { idx = indices->e(i); - maxT = listOfOutTensors->at(idx); - maxT->assign(listOfTensors->at(i)); + maxT = listOfOutTensors.at(idx); + maxT->assign(listOfTensors.at(i)); } } - delete listOfTensors; - delete listOfOutTensors; } } - // segmen min + // segmen min template static void segmentMinFunctor_(NDArray* input, NDArray* indices, NDArray* output) { //int numClasses = output->sizeAt(0); @@ -91,7 +89,7 @@ namespace helpers { for (int e = 1; e < indices->lengthOf(); e++) { if (idx == indices->e(e)) { - // min + // min val = nd4j::math::nd4j_min(val, input->t(e)); } else { @@ -104,27 +102,27 @@ namespace helpers { else { auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - std::unique_ptr listOfTensors( input->allTensorsAlongDimension(restDims) ); - std::unique_ptr listOfOutTensors( output->allTensorsAlongDimension(restDims) ); + ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); + ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); int numOfClasses = output->sizeAt(0); // number of classes std::vector> outputs(numOfClasses); - auto minT = listOfOutTensors->at(idx); + auto minT = listOfOutTensors.at(idx); int pos = 0; - minT->assign(listOfTensors->at(0)); + minT->assign(listOfTensors.at(0)); for (Nd4jLong i = 1; i < indices->lengthOf(); i++) { if (indices->e(i) == idx) { for (int e = 0; e < minT->lengthOf(); e++) { - minT->p(e, nd4j::math::nd4j_min(minT->e(e), listOfTensors->at(i)->e(e))); + minT->p(e, nd4j::math::nd4j_min(minT->e(e), listOfTensors.at(i)->e(e))); } } else { idx = indices->e(i); - minT = listOfOutTensors->at(idx); - minT->assign(listOfTensors->at(i)); + minT = listOfOutTensors.at(idx); + minT->assign(listOfTensors.at(i)); } } } @@ -142,7 +140,7 @@ namespace helpers { for (int e = 0; e < indices->lengthOf(); e++) { if (idx == indices->e(e)) { - // mean + // mean val += input->e(e); count++; } @@ -163,16 +161,16 @@ namespace helpers { int numOfClasses = output->sizeAt(0); // number of classes std::vector> outputs(numOfClasses); - auto meanT = listOfOutTensors->at(idx); + auto meanT = listOfOutTensors.at(idx); int count = 1; auto meanV = meanT->dup(); - meanV->assign(listOfTensors->at(0)); + meanV.assign(listOfTensors.at(0)); for (int i = 1; i < indices->lengthOf(); i++) { if (indices->e(i) == idx) { auto func = PRAGMA_THREADS_FOR { for (auto e = start; e < stop; e += increment) { - meanV->p(e, meanV->e(e) + listOfTensors->at(i)->e(e)); + meanV.p(e, meanV.e(e) + listOfTensors.at(i)->e(e)); } }; samediff::Threads::parallel_for(func, 0, meanT->lengthOf()); @@ -181,17 +179,14 @@ namespace helpers { } else { //meanT->assign(meanV); - meanV->applyScalar(scalar::Divide, count, meanT, nullptr); + meanV.applyScalar(scalar::Divide, count, *meanT); idx = indices->e(i); - meanT = listOfOutTensors->at(idx); - meanV->assign(listOfTensors->at(i)); + meanT = listOfOutTensors.at(idx); + meanV.assign(listOfTensors.at(i)); count = 1; } - meanV->applyScalar(scalar::Divide, count, meanT, nullptr); + meanV.applyScalar(scalar::Divide, count, *meanT); } - delete meanV; - delete listOfTensors; - delete listOfOutTensors; } } @@ -205,7 +200,7 @@ namespace helpers { int count = 0; for (int e = 0; e < indices->lengthOf(); e++) { if (idx == indices->e(e)) { - // sum + // sum val += input->t(e); } else { @@ -223,25 +218,23 @@ namespace helpers { int numOfClasses = output->sizeAt(0); // number of classes std::vector> outputs(numOfClasses); - auto sumT = listOfOutTensors->at(idx); + auto sumT = listOfOutTensors.at(idx); for (int i = 0; i < indices->lengthOf(); i++) { if (indices->e(i) == idx) { auto func = PRAGMA_THREADS_FOR { for (auto e = start; e < stop; e += increment) { - sumT->p(e, sumT->e(e) + listOfTensors->at(i)->e(e)); + sumT->p(e, sumT->e(e) + listOfTensors.at(i)->e(e)); } }; samediff::Threads::parallel_for(func, 0, sumT->lengthOf()); } else { idx = indices->e(i); - sumT = listOfOutTensors->at(idx); - sumT->assign(listOfTensors->at(i)); + sumT = listOfOutTensors.at(idx); + sumT->assign(listOfTensors.at(i)); } } - delete listOfTensors; - delete listOfOutTensors; } } @@ -257,7 +250,7 @@ namespace helpers { for (int e = 1; e < indices->lengthOf(); e++) { if (idx == indices->e(e)) { - // sum + // sum val *= input->e(e); } else { @@ -274,25 +267,23 @@ namespace helpers { auto listOfOutTensors = output->allTensorsAlongDimension(restDims); int numOfClasses = output->sizeAt(0); // number of classes - auto sumT = listOfOutTensors->at(idx); - sumT->assign(listOfTensors->at(0)); + auto sumT = listOfOutTensors.at(idx); + sumT->assign(listOfTensors.at(0)); for (int i = 1; i < indices->lengthOf(); i++) { if (indices->e(i) == idx) { auto func = PRAGMA_THREADS_FOR { for (auto e = start; e < stop; e += increment) { - sumT->p(e, sumT->e(e) * listOfTensors->at(i)->e(e)); + sumT->p(e, sumT->e(e) * listOfTensors.at(i)->e(e)); } }; samediff::Threads::parallel_for(func, 0, sumT->lengthOf()); } else { idx = indices->e(i); - sumT = listOfOutTensors->at(idx); - sumT->assign(listOfTensors->at(i)); + sumT = listOfOutTensors.at(idx); + sumT->assign(listOfTensors.at(i)); } } - delete listOfTensors; - delete listOfOutTensors; } } @@ -380,24 +371,23 @@ namespace helpers { else { auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - std::unique_ptr listOfTensors(input->allTensorsAlongDimension(restDims)); - std::unique_ptr listOfOutTensors(output->allTensorsAlongDimension(restDims)); + ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); + ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); T maxVal = DataTypeUtils::max(); output->assign(-maxVal); for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) { - auto outputT = listOfOutTensors->at(fi->first); - outputT->assign(listOfTensors->at(fi->second.at(0))); + auto outputT = listOfOutTensors.at(fi->first); + outputT->assign(listOfTensors.at(fi->second.at(0))); for (Nd4jLong idx = 1; idx < fi->second.size(); ++idx) { - auto maxT = listOfTensors->at(fi->second.at(idx)); + auto maxT = listOfTensors.at(fi->second.at(idx)); for (Nd4jLong e = 0; e < outputT->lengthOf(); ++e) { T val = nd4j::math::nd4j_max(maxT->e(e), outputT->e(e)); outputT->p(e, val); } } - //outputT->assign(maxT); } } } @@ -433,17 +423,17 @@ namespace helpers { else { auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - std::unique_ptr listOfTensors(input->allTensorsAlongDimension(restDims)); - std::unique_ptr listOfOutTensors(output->allTensorsAlongDimension(restDims)); + ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); + ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); T maxVal = DataTypeUtils::max(); output->assign(maxVal); for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) { - auto outputT = listOfOutTensors->at(fi->first); - outputT->assign(listOfTensors->at(fi->second.at(0))); + auto outputT = listOfOutTensors.at(fi->first); + outputT->assign(listOfTensors.at(fi->second.at(0))); for (Nd4jLong idx = 1; idx < fi->second.size(); ++idx) { - auto minT = listOfTensors->at(fi->second.at(idx)); + auto minT = listOfTensors.at(fi->second.at(idx)); for (Nd4jLong e = 0; e < outputT->lengthOf(); ++e) { outputT->t(e) = nd4j::math::nd4j_min(minT->t(e), outputT->t(e)); @@ -485,17 +475,17 @@ namespace helpers { else { auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - std::unique_ptr listOfTensors(input->allTensorsAlongDimension(restDims)); - std::unique_ptr listOfOutTensors(output->allTensorsAlongDimension(restDims)); + ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); + ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); // FIXME: parallelism here? for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) { - auto outputT = listOfOutTensors->at(fi->first); - outputT->assign(listOfTensors->at(fi->second.at(0))); + auto outputT = listOfOutTensors.at(fi->first); + outputT->assign(listOfTensors.at(fi->second.at(0))); Nd4jLong loopSize = fi->second.size(); for (Nd4jLong idx = 1; idx < loopSize; ++idx) { - auto current = listOfTensors->at(fi->second.at(idx)); + auto current = listOfTensors.at(fi->second.at(idx)); *outputT += *current; } (*outputT) /= double(fi->second.size()); @@ -524,17 +514,17 @@ namespace helpers { else { auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - std::unique_ptr listOfTensors(input->allTensorsAlongDimension(restDims)); - std::unique_ptr listOfOutTensors(output->allTensorsAlongDimension(restDims)); + ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); + ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) { - auto outputT = listOfOutTensors->at(fi->first); - outputT->assign(listOfTensors->at(fi->second.at(0))); + auto outputT = listOfOutTensors.at(fi->first); + outputT->assign(listOfTensors.at(fi->second.at(0))); Nd4jLong loop_size = fi->second.size(); // FIXME: parallelism here? for (Nd4jLong idx = 1; idx < loop_size; ++idx) { - auto current = listOfTensors->at(fi->second.at(idx)); + auto current = listOfTensors.at(fi->second.at(idx)); *(outputT) += *current; } //outputT->assign(maxT); @@ -564,14 +554,14 @@ namespace helpers { else { auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - std::unique_ptr listOfTensors(input->allTensorsAlongDimension(restDims)); - std::unique_ptr listOfOutTensors(output->allTensorsAlongDimension(restDims)); + ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); + ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) { - auto outputT = listOfOutTensors->at(fi->first); - outputT->assign(listOfTensors->at(fi->second.at(0))); + auto outputT = listOfOutTensors.at(fi->first); + outputT->assign(listOfTensors.at(fi->second.at(0))); for (Nd4jLong idx = 1; idx < fi->second.size(); ++idx) { - auto current = listOfTensors->at(fi->second.at(idx)); + auto current = listOfTensors.at(fi->second.at(idx)); *outputT *= *current; } @@ -603,14 +593,14 @@ namespace helpers { else { auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - std::unique_ptr listOfTensors(input->allTensorsAlongDimension(restDims)); - std::unique_ptr listOfOutTensors(output->allTensorsAlongDimension(restDims)); + ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); + ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) { - auto outputT = listOfOutTensors->at(fi->first); - outputT->assign(listOfTensors->at(fi->second.at(0))); + auto outputT = listOfOutTensors.at(fi->first); + outputT->assign(listOfTensors.at(fi->second.at(0))); for (Nd4jLong idx = 1; idx < fi->second.size(); ++idx) { - auto current = listOfTensors->at(fi->second.at(idx)); + auto current = listOfTensors.at(fi->second.at(idx)); *outputT += *current; } //outputT->assign(maxT); @@ -630,14 +620,14 @@ namespace helpers { //int numOfClasses = gradOut->sizeAt(0); // if input is a vector: (as if in doc sample) auto tempRes = gradOut->dup(); - segmentMaxFunctor_(input, indices, tempRes); + segmentMaxFunctor_(input, indices, &tempRes); if (input->isVector()) { Nd4jLong loop_size = input->lengthOf(); auto func = PRAGMA_THREADS_FOR { for (auto e = start; e < stop; e += increment) { auto classNum = indices->e(e); - if (nd4j::math::nd4j_abs(tempRes->e(classNum) - input->e(e)) <= T(1.e-6)) + if (nd4j::math::nd4j_abs(tempRes.e(classNum) - input->e(e)) <= T(1.e-6)) output->p(e, gradOut->e(classNum)); } }; @@ -646,23 +636,23 @@ namespace helpers { else { std::vector restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - std::unique_ptr listOfBPTensors(tempRes->allTensorsAlongDimension(restDims)); - std::unique_ptr listOfGradOuts(gradOut->allTensorsAlongDimension(restDims)); - std::unique_ptr listOfTensors(input->allTensorsAlongDimension(restDims)); - std::unique_ptr listOfOutTensors(output->allTensorsAlongDimension(restDims)); + ResultSet listOfBPTensors = tempRes.allTensorsAlongDimension(restDims); + ResultSet listOfGradOuts = gradOut->allTensorsAlongDimension(restDims); + ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); + ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); - //int numOfClasses = tempRes->sizeAt(0); // number of classes + //int numOfClasses = tempRes.sizeAt(0); // number of classes //std::vector> outputs(numOfClasses); auto func = PRAGMA_THREADS_FOR { for (auto i = start; i < stop; i += increment) { auto classNum = indices->e(i); - auto current = listOfTensors->at(i); - auto currentOut = listOfOutTensors->at(i); - auto currentGradOut = listOfGradOuts->at(classNum); + auto current = listOfTensors.at(i); + auto currentOut = listOfOutTensors.at(i); + auto currentGradOut = listOfGradOuts.at(classNum); for (uint64_t e = 0; e < current->lengthOf(); e++) { - if (nd4j::math::nd4j_abs(listOfBPTensors->at(classNum)->e(e) - current->e(e)) <= T(1.e-6)) + if (nd4j::math::nd4j_abs(listOfBPTensors.at(classNum)->e(e) - current->e(e)) <= T(1.e-6)) currentOut->p(e, currentGradOut->e(e)); } } @@ -670,7 +660,7 @@ namespace helpers { samediff::Threads::parallel_tad(func, 0, indices->lengthOf()); } - delete tempRes; + return ND4J_STATUS_OK; } @@ -681,13 +671,13 @@ namespace helpers { // segmen min int segmentMinFunctorBP(nd4j::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) { - std::unique_ptr tempRes(gradOut->dup()); - segmentMinFunctor(context, input, indices, tempRes.get()); + NDArray tempRes = gradOut->dup(); + segmentMinFunctor(context, input, indices, &tempRes); if (input->isVector()) { auto func = PRAGMA_THREADS_FOR { for (auto e = start; e < stop; e += increment) { auto classNum = indices->e(e); - if (nd4j::math::nd4j_abs(tempRes->e(classNum) - input->e(e)) < 1.e-5) + if (nd4j::math::nd4j_abs(tempRes.e(classNum) - input->e(e)) < 1.e-5) output->p(e, gradOut->e(classNum)); } }; @@ -696,12 +686,12 @@ namespace helpers { else { auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - std::unique_ptr listOfBPTensors(tempRes->allTensorsAlongDimension(restDims)); - std::unique_ptr listOfGradOuts(gradOut->allTensorsAlongDimension(restDims)); - std::unique_ptr listOfTensors(input->allTensorsAlongDimension(restDims)); - std::unique_ptr listOfOutTensors(output->allTensorsAlongDimension(restDims)); + ResultSet listOfBPTensors = tempRes.allTensorsAlongDimension(restDims); + ResultSet listOfGradOuts = gradOut->allTensorsAlongDimension(restDims); + ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); + ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); - //int numOfClasses = tempRes->sizeAt(0); // number of classes + //int numOfClasses = tempRes.sizeAt(0); // number of classes //std::vector> outputs(numOfClasses); output->assign(0.); int pos = 0; @@ -709,12 +699,12 @@ namespace helpers { auto func = PRAGMA_THREADS_FOR { for (auto i = start; i < stop; i += increment) { auto classNum = indices->e(i); - auto current = listOfTensors->at(i); - auto currentOut = listOfOutTensors->at(i); - auto currentGradOut = listOfGradOuts->at(classNum); + auto current = listOfTensors.at(i); + auto currentOut = listOfOutTensors.at(i); + auto currentGradOut = listOfGradOuts.at(classNum); for (int e = 0; e < current->lengthOf(); e++) { - if (nd4j::math::nd4j_abs(listOfBPTensors->at(classNum)->e(e) - current->e(e)) < + if (nd4j::math::nd4j_abs(listOfBPTensors.at(classNum)->e(e) - current->e(e)) < 1.e-5) currentOut->p(e, currentGradOut->e(e)); } @@ -749,20 +739,18 @@ namespace helpers { else { auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - std::unique_ptr listOfGradOuts(gradOut->allTensorsAlongDimension(restDims)); - std::unique_ptr listOfTensors(input->allTensorsAlongDimension(restDims)); - std::unique_ptr listOfOutTensors(output->allTensorsAlongDimension(restDims)); - - //int numOfClasses = tempRes->sizeAt(0); // number of classes - //std::vector> outputs(numOfClasses); + ResultSet listOfGradOuts = gradOut->allTensorsAlongDimension(restDims); + ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); + ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); +; int pos = 0; //auto func = [&](uint64_t thread_id, uint64_t start, uint64_t stop, uint64_t increment) -> void { for (auto i = 0; i < indices->lengthOf(); i++) { auto classNum = indices->e(i); - auto current = listOfTensors->at(i); - auto currentOut = listOfOutTensors->at(i); - auto currentGradOut = listOfGradOuts->at(classNum); + auto current = listOfTensors.at(i); + auto currentOut = listOfOutTensors.at(i); + auto currentGradOut = listOfGradOuts.at(classNum); for (int e = 0; e < current->lengthOf(); e++) { currentOut->p(e, currentGradOut->e(e) / classCount.at(classNum)); @@ -788,16 +776,16 @@ namespace helpers { else { auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - std::unique_ptr listOfGradOuts(gradOut->allTensorsAlongDimension(restDims)); - std::unique_ptr listOfTensors(input->allTensorsAlongDimension(restDims)); - std::unique_ptr listOfOutTensors(output->allTensorsAlongDimension(restDims)); + ResultSet listOfGradOuts = gradOut->allTensorsAlongDimension(restDims); + ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); + ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); //auto func = PRAGMA_THREADS_FOR { for (auto i = 0; i < indices->lengthOf(); i++) { auto classNum = indices->e(i); - auto current = listOfTensors->at(i); - auto currentOut = listOfOutTensors->at(i); - auto currentGradOut = listOfGradOuts->at(classNum); + auto current = listOfTensors.at(i); + auto currentOut = listOfOutTensors.at(i); + auto currentGradOut = listOfGradOuts.at(classNum); currentOut->assign(currentGradOut); } @@ -810,31 +798,31 @@ namespace helpers { int segmentProdFunctorBP(nd4j::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) { auto tempRes = gradOut->dup(); - segmentProdFunctor(context, input, indices, tempRes); + segmentProdFunctor(context, input, indices, &tempRes); if (input->isVector()) { for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) { Nd4jLong classNum = indices->e(e); - output->p(e, gradOut->e(classNum) * tempRes->e(classNum)/ input->e(e)); + output->p(e, gradOut->e(classNum) * tempRes.e(classNum)/ input->e(e)); } } else { auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - std::unique_ptr listOfBPTensors(tempRes->allTensorsAlongDimension(restDims)); - std::unique_ptr listOfGradOuts(gradOut->allTensorsAlongDimension(restDims)); - std::unique_ptr listOfTensors(input->allTensorsAlongDimension(restDims)); - std::unique_ptr listOfOutTensors(output->allTensorsAlongDimension(restDims)); + ResultSet listOfBPTensors = tempRes.allTensorsAlongDimension(restDims); + ResultSet listOfGradOuts = gradOut->allTensorsAlongDimension(restDims); + ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); + ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); - //int numOfClasses = tempRes->sizeAt(0); // number of classes + //int numOfClasses = tempRes.sizeAt(0); // number of classes //std::vector> outputs(numOfClasses); //auto func = PRAGMA_THREADS_FOR { for (auto i = 0; i < indices->lengthOf(); i++) { auto classNum = indices->e(i); - auto current = listOfTensors->at(i); - auto currentOut = listOfOutTensors->at(i); - auto currentGradOut = listOfGradOuts->at(classNum); - auto currentFFOut = listOfBPTensors->at(classNum); + auto current = listOfTensors.at(i); + auto currentOut = listOfOutTensors.at(i); + auto currentGradOut = listOfGradOuts.at(classNum); + auto currentFFOut = listOfBPTensors.at(classNum); currentOut->assign((*currentFFOut) * (*currentGradOut) / (*current)); } @@ -842,7 +830,7 @@ namespace helpers { //samediff::Threads::parallel_for(func, 0, indices->lengthOf()); } - delete tempRes; + return ND4J_STATUS_OK; } @@ -855,35 +843,35 @@ namespace helpers { // int numOfClasses = gradOut->sizeAt(0); // if input is a vector: (as if in doc sample) auto tempRes = gradOut->dup(); - unsortedSegmentMaxFunctor(context, input, indices, numOfClasses, tempRes); + unsortedSegmentMaxFunctor(context, input, indices, numOfClasses, &tempRes); if (input->isVector()) { for (Nd4jLong e = 0; e < input->lengthOf(); ++e) { Nd4jLong classNum = indices->e(e); - if (nd4j::math::nd4j_abs(tempRes->e(classNum) - input->e(e)) < 1.e-5) + if (nd4j::math::nd4j_abs(tempRes.e(classNum) - input->e(e)) < 1.e-5) output->p(e, gradOut->e(classNum)); } } else { auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - std::unique_ptr listOfBPTensors(tempRes->allTensorsAlongDimension(restDims)); - std::unique_ptr listOfGradOuts(gradOut->allTensorsAlongDimension(restDims)); - std::unique_ptr listOfTensors(input->allTensorsAlongDimension(restDims)); - std::unique_ptr listOfOutTensors(output->allTensorsAlongDimension(restDims)); + ResultSet listOfBPTensors = tempRes.allTensorsAlongDimension(restDims); + ResultSet listOfGradOuts = gradOut->allTensorsAlongDimension(restDims); + ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); + ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); for (int i = 0; i < indices->lengthOf(); i++) { Nd4jLong classNum = indices->e(i); - NDArray* current = listOfTensors->at(i); - NDArray* currentOut = listOfOutTensors->at(i); - NDArray* currentGradOut = listOfGradOuts->at(classNum); + NDArray* current = listOfTensors.at(i); + NDArray* currentOut = listOfOutTensors.at(i); + NDArray* currentGradOut = listOfGradOuts.at(classNum); for (int e = 0; e < current->lengthOf(); e++) { - if (nd4j::math::nd4j_abs(listOfBPTensors->at(classNum)->e(e) - current->e(e)) < 1.e-5) + if (nd4j::math::nd4j_abs(listOfBPTensors.at(classNum)->e(e) - current->e(e)) < 1.e-5) currentOut->p(e, currentGradOut->e(e)); } } } - delete tempRes; + return ND4J_STATUS_OK; } @@ -895,13 +883,13 @@ namespace helpers { template static int unsortedSegmentMinFunctorBP_(nd4j::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { auto tempRes = gradOut->dup(); - unsortedSegmentMinFunctor(context, input, indices, numOfClasses, tempRes); + unsortedSegmentMinFunctor(context, input, indices, numOfClasses, &tempRes); if (input->isVector()) { auto func = PRAGMA_THREADS_FOR { for (auto e = start; e < stop; e += increment) { auto classNum = indices->e(e); - if (nd4j::math::nd4j_abs(tempRes->t(classNum) - input->t(e)) < 1.e-6) + if (nd4j::math::nd4j_abs(tempRes.t(classNum) - input->t(e)) < 1.e-6) output->t(e) = gradOut->t(classNum); } }; @@ -911,20 +899,20 @@ namespace helpers { else { auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - std::unique_ptr listOfBPTensors(tempRes->allTensorsAlongDimension(restDims)); - std::unique_ptr listOfGradOuts(gradOut->allTensorsAlongDimension(restDims)); - std::unique_ptr listOfTensors(input->allTensorsAlongDimension(restDims)); - std::unique_ptr listOfOutTensors(output->allTensorsAlongDimension(restDims)); + ResultSet listOfBPTensors = tempRes.allTensorsAlongDimension(restDims); + ResultSet listOfGradOuts = gradOut->allTensorsAlongDimension(restDims); + ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); + ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); //auto func = PRAGMA_THREADS_FOR { for (auto i = 0; i < indices->lengthOf(); i++) { auto classNum = indices->e(i); - auto current = listOfTensors->at(i); - auto currentOut = listOfOutTensors->at(i); - auto currentGradOut = listOfGradOuts->at(classNum); + auto current = listOfTensors.at(i); + auto currentOut = listOfOutTensors.at(i); + auto currentGradOut = listOfGradOuts.at(classNum); for (int e = 0; e < current->lengthOf(); e++) { - if (nd4j::math::nd4j_abs(listOfBPTensors->at(classNum)->t(e) - current->t(e)) < 1.e-6) + if (nd4j::math::nd4j_abs(listOfBPTensors.at(classNum)->t(e) - current->t(e)) < 1.e-6) currentOut->t(e) = currentGradOut->t(e); } } @@ -932,7 +920,7 @@ namespace helpers { //samediff::Threads::parallel_for(func, 0, indices->lengthOf()); } - delete tempRes; + return ND4J_STATUS_OK; } @@ -963,15 +951,15 @@ namespace helpers { else { auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - std::unique_ptr listOfGradOuts(gradOut->allTensorsAlongDimension(restDims)); - std::unique_ptr listOfTensors(input->allTensorsAlongDimension(restDims)); - std::unique_ptr listOfOutTensors(output->allTensorsAlongDimension(restDims)); + ResultSet listOfGradOuts = gradOut->allTensorsAlongDimension(restDims); + ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); + ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); for (int i = 0; i < indices->lengthOf(); i++) { Nd4jLong classNum = indices->e(i); - NDArray* current = listOfTensors->at(i); - NDArray* currentOut = listOfOutTensors->at(i); - NDArray* currentGradOut = listOfGradOuts->at(classNum); + NDArray* current = listOfTensors.at(i); + NDArray* currentOut = listOfOutTensors.at(i); + NDArray* currentGradOut = listOfGradOuts.at(classNum); currentOut->assign(*currentGradOut / double(classCount[classNum])); } } @@ -991,15 +979,15 @@ namespace helpers { else { auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - std::unique_ptr listOfGradOuts(gradOut->allTensorsAlongDimension(restDims)); - std::unique_ptr listOfTensors(input->allTensorsAlongDimension(restDims)); - std::unique_ptr listOfOutTensors(output->allTensorsAlongDimension(restDims)); + ResultSet listOfGradOuts = gradOut->allTensorsAlongDimension(restDims); + ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); + ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); //auto func = PRAGMA_THREADS_FOR { for (auto i = 0; i < indices->lengthOf(); i++) { auto classNum = indices->e(i); - auto currentOut = listOfOutTensors->at(i); - auto currentGradOut = listOfGradOuts->at(classNum); + auto currentOut = listOfOutTensors.at(i); + auto currentGradOut = listOfGradOuts.at(classNum); currentOut->assign(currentGradOut); } @@ -1011,14 +999,14 @@ namespace helpers { } int unsortedSegmentProdFunctorBP(nd4j::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { - auto tempRes = gradOut->dup(); - unsortedSegmentProdFunctor(context, input, indices, numOfClasses, tempRes); + auto tempRes = gradOut->dup(); + unsortedSegmentProdFunctor(context, input, indices, numOfClasses, &tempRes); if (input->isVector()) { auto func = PRAGMA_THREADS_FOR { for (auto e = start; e < stop; e += increment) { auto classNum = indices->e(e); - output->p(e, gradOut->e(classNum) * tempRes->e(classNum) / input->e(e)); + output->p(e, gradOut->e(classNum) * tempRes.e(classNum) / input->e(e)); } }; @@ -1027,18 +1015,18 @@ namespace helpers { else { auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - std::unique_ptr listOfBPTensors(tempRes->allTensorsAlongDimension(restDims)); - std::unique_ptr listOfGradOuts(gradOut->allTensorsAlongDimension(restDims)); - std::unique_ptr listOfTensors(input->allTensorsAlongDimension(restDims)); - std::unique_ptr listOfOutTensors(output->allTensorsAlongDimension(restDims)); + ResultSet listOfBPTensors = tempRes.allTensorsAlongDimension(restDims); + ResultSet listOfGradOuts = gradOut->allTensorsAlongDimension(restDims); + ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); + ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); //auto func = PRAGMA_THREADS_FOR { for (auto i = 0; i < indices->lengthOf(); i++) { auto classNum = indices->e(i); - auto current = listOfTensors->at(i); - auto currentOut = listOfOutTensors->at(i); - auto currentGradOut = listOfGradOuts->at(classNum); - auto currentFFOut = listOfBPTensors->at(classNum); + auto current = listOfTensors.at(i); + auto currentOut = listOfOutTensors.at(i); + auto currentGradOut = listOfGradOuts.at(classNum); + auto currentFFOut = listOfBPTensors.at(classNum); currentOut->assign((*currentFFOut) * (*currentGradOut) / (*current)); } @@ -1046,7 +1034,7 @@ namespace helpers { //samediff::Threads::parallel_for(func, 0, indices->lengthOf()); } - delete tempRes; + return Status::OK(); } @@ -1076,16 +1064,16 @@ namespace helpers { else { auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - std::unique_ptr listOfGradOuts(gradOut->allTensorsAlongDimension(restDims)); - std::unique_ptr listOfTensors(input->allTensorsAlongDimension(restDims)); - std::unique_ptr listOfOutTensors(output->allTensorsAlongDimension(restDims)); + ResultSet listOfGradOuts =gradOut->allTensorsAlongDimension(restDims); + ResultSet listOfTensors =input->allTensorsAlongDimension(restDims); + ResultSet listOfOutTensors =output->allTensorsAlongDimension(restDims); //auto func = PRAGMA_THREADS_FOR { for (auto i = 0; i < indices->lengthOf(); i++) { auto classNum = indices->e(i); - auto current = listOfTensors->at(i); - auto currentOut = listOfOutTensors->at(i); - auto currentGradOut = listOfGradOuts->at(classNum); + auto current = listOfTensors.at(i); + auto currentOut = listOfOutTensors.at(i); + auto currentGradOut = listOfGradOuts.at(classNum); for (int e = 0; e < current->lengthOf(); e++) { currentOut->p(e, currentGradOut->e(e) / nd4j::math::nd4j_sqrt(classCount[classNum])); diff --git a/libnd4j/include/ops/declarable/helpers/cpu/shift.cpp b/libnd4j/include/ops/declarable/helpers/cpu/shift.cpp index 7a9b77b66..4b54c7362 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/shift.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/shift.cpp @@ -29,7 +29,7 @@ namespace nd4j { return x >> shift; }; - input.applyLambda(lambda, &output); + input.applyLambda(lambda, output); } void rshift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift) { @@ -42,7 +42,7 @@ namespace nd4j { return x << shift; }; - input.applyLambda(lambda, &output); + input.applyLambda(lambda, output); } void shift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift) { @@ -56,7 +56,7 @@ namespace nd4j { return x >> shift | x << step; }; - input.applyLambda(lambda, &output); + input.applyLambda(lambda, output); } void cyclic_rshift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift) { @@ -70,7 +70,7 @@ namespace nd4j { return x << shift | x >> step; }; - input.applyLambda(lambda, &output); + input.applyLambda(lambda, output); } void cyclic_shift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift) { diff --git a/libnd4j/include/ops/declarable/helpers/cpu/sru.cpp b/libnd4j/include/ops/declarable/helpers/cpu/sru.cpp index 1fea14824..642dd37da 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/sru.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/sru.cpp @@ -34,7 +34,7 @@ static FORCEINLINE NDArray activation(const NDArray& arr) { // return (const_cast&>(arr)).template transform>(); auto result = NDArray(&arr, false, arr.getContext()); - (const_cast(arr)).applyTransform(transform::Tanh, &result); + (const_cast(arr)).applyTransform(transform::Tanh, result); return result; } @@ -125,7 +125,7 @@ static void sruBI_(NDArray* x, const NDArray* w, const NDArray* b, const NDArray // x = x * mask if(mask) - x->applyBroadcast(broadcast::Multiply, {1, 2}, mask, x, nullptr); // apply mask + x->applyBroadcast(broadcast::Multiply, {1, 2}, *mask, *x); // apply mask // U = x * w NDArray wi = mmul(*x, *w); // U [time x bS x 6*K] @@ -212,7 +212,7 @@ static void sruBIBP_(NDArray* x, const NDArray* w, const NDArray* b, const NDArr // x = x * mask if(mask) - x->applyBroadcast(broadcast::Multiply, {1, 2}, mask, x, nullptr); // apply mask + x->applyBroadcast(broadcast::Multiply, {1, 2}, *mask, *x); // apply mask // U = x * w NDArray wi = mmul(*x, *w); // [time x bS x 2*K] * [2*K x 6*K] = [time x bS x 6*K] @@ -306,7 +306,7 @@ static void sruBIBP_(NDArray* x, const NDArray* w, const NDArray* b, const NDArr samediff::Threads::parallel_tad(func, 0, ncols); // gradB - gradBias.reduceAlongDimension(reduce::Sum, gradB, {0}); // [4*K] + gradBias.reduceAlongDimension(reduce::Sum, *gradB, {0}); // [4*K] // gradW x->permutei({0, 2, 1}); // [time x bS x 2*K] -> [time x 2*K x bS] diff --git a/libnd4j/include/ops/declarable/helpers/cpu/stack.cpp b/libnd4j/include/ops/declarable/helpers/cpu/stack.cpp index b974a236b..db9b6afff 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/stack.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/stack.cpp @@ -47,15 +47,13 @@ static void stack_(const std::vector& inArrs, NDArray* outArr, c std::vector dimsToExclude = ShapeUtils::evalDimsToExclude(outArr->rankOf(), {dim}); auto list = outArr->allTensorsAlongDimension(dimsToExclude); // list.size() == block.width() - int listSize = list->size(); + int listSize = list.size(); auto func = PRAGMA_THREADS_FOR { for (auto i = start; i < stop; i += increment) - list->at(i)->assign(inArrs[i]); + list.at(i)->assign(inArrs[i]); }; samediff::Threads::parallel_tad(func, 0, listSize); - - delete list; } } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/svd.cpp b/libnd4j/include/ops/declarable/helpers/cpu/svd.cpp index 35615287b..9d755f6b6 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/svd.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/svd.cpp @@ -221,26 +221,26 @@ void SVD::deflation(int col1, int col2, int ind, int row1W, int col1W, int sh const T almostZero = DataTypeUtils::min(); T maxElem; if(len == 1) - maxElem = math::nd4j_abs(diagInterval->template e(0)); + maxElem = math::nd4j_abs(diagInterval.template e(0)); else - maxElem = (*diagInterval)({1,-1, 0,0}, true).reduceNumber(reduce::AMax).template e(0); + maxElem = diagInterval({1,-1, 0,0}, true).reduceNumber(reduce::AMax).template e(0); T maxElem0 = colVec0->reduceNumber(reduce::AMax).template e(0); T eps = math::nd4j_max(almostZero, DataTypeUtils::eps() * maxElem); T epsBig = (T)8. * DataTypeUtils::eps() * math::nd4j_max(maxElem0, maxElem); - if(diagInterval->template e(0) < epsBig) - diagInterval->p(Nd4jLong(0), epsBig); + if(diagInterval.template e(0) < epsBig) + diagInterval.p(Nd4jLong(0), epsBig); for(int i=1; i < len; ++i) if(math::nd4j_abs(colVec0->template e(i)) < eps) colVec0->p(i, 0.f); for(int i=1; i < len; i++) - if(diagInterval->template e(i) < epsBig) { + if(diagInterval.template e(i) < epsBig) { deflation1(col1, shift, i, len); for(int i = 0; i < len; ++i) - diagInterval->p(i, _m.e(col1+shift+i,col1+shift+i)); + diagInterval.p(i, _m.e(col1+shift+i,col1+shift+i)); } { @@ -259,7 +259,7 @@ void SVD::deflation(int col1, int col2, int ind, int row1W, int col1W, int sh int p = 1; for(int i=1; i(diagInterval->template e(i)) < almostZero) + if(math::nd4j_abs(diagInterval.template e(i)) < almostZero) permut[p++] = i; int k = 1, m = ind+1; @@ -269,7 +269,7 @@ void SVD::deflation(int col1, int col2, int ind, int row1W, int col1W, int sh permut[p] = m++; else if(m >= len) permut[p] = k++; - else if(diagInterval->template e(k) < diagInterval->template e(m)) + else if(diagInterval.template e(k) < diagInterval.template e(m)) permut[p] = m++; else permut[p] = k++; @@ -279,7 +279,7 @@ void SVD::deflation(int col1, int col2, int ind, int row1W, int col1W, int sh if(totDefl) { for(int i=1; i(diagInterval->template e(ki)) < almostZero || diagInterval->template e(0) < diagInterval->template e(ki)) + if(math::nd4j_abs(diagInterval.template e(ki)) < almostZero || diagInterval.template e(0) < diagInterval.template e(ki)) permut[i-1] = permut[i]; else { permut[i-1] = 0; @@ -301,10 +301,10 @@ void SVD::deflation(int col1, int col2, int ind, int row1W, int col1W, int sh const int ki = permut[len - (totDefl ? i+1 : i)]; const int jac = tCol[ki]; - T _e0 = diagInterval->template e(jac); + T _e0 = diagInterval.template e(jac); //math::nd4j_swap(diagInterval)(i), (*diagInterval)(jac)); - diagInterval->p(jac, diagInterval->template e(i)); - diagInterval->p(i, _e0); + diagInterval.p(jac, diagInterval.template e(i)); + diagInterval.p(i, _e0); if(i!=0 && jac!=0) { _e0 = colVec0->template e(jac); @@ -349,12 +349,12 @@ void SVD::deflation(int col1, int col2, int ind, int row1W, int col1W, int sh { int i = len-1; - while(i > 0 && (math::nd4j_abs(diagInterval->template e(i)) < almostZero || math::nd4j_abs(colVec0->template e(i)) < almostZero)) + while(i > 0 && (math::nd4j_abs(diagInterval.template e(i)) < almostZero || math::nd4j_abs(colVec0->template e(i)) < almostZero)) --i; for(; i > 1; --i) { - if( (diagInterval->template e(i) - diagInterval->template e(i-1)) < DataTypeUtils::eps()*maxElem ) { - if (math::nd4j_abs(diagInterval->template e(i) - diagInterval->template e(i-1)) >= epsBig) + if( (diagInterval.template e(i) - diagInterval.template e(i-1)) < DataTypeUtils::eps()*maxElem ) { + if (math::nd4j_abs(diagInterval.template e(i) - diagInterval.template e(i-1)) >= epsBig) throw std::runtime_error("ops::helpers::SVD::deflation: diagonal elements are not properly sorted !"); deflation2(col1, col1 + shift, row1W, col1W, i-1, i, len); } @@ -362,7 +362,6 @@ void SVD::deflation(int col1, int col2, int ind, int row1W, int col1W, int sh } delete colVec0; - delete diagInterval; } @@ -606,9 +605,7 @@ void SVD::calcBlockSVD(int col1, int size, NDArray& U, NDArray& singVals, NDA const T almostZero = DataTypeUtils::min(); auto col0 = _m({col1, col1+size, col1, col1+1}, true); - auto diagP = _m({col1, col1+size, col1, col1+size}, true).diagonal('c'); - auto diag = *diagP; - delete diagP; + auto diag = static_cast(_m({col1, col1+size, col1, col1+size}, true).diagonal('c')); diag.p(Nd4jLong(0), T(0)); singVals = NDArrayFactory::create(_m.ordering(), {size, 1}, _m.getContext()); @@ -727,8 +724,7 @@ void SVD::DivideAndConquer(int col1, int col2, int row1W, int col1W, int shif auto temp = _m({col1+shift,col1+shift+n+1, col1+shift,col1+shift+n}, true); temp.assign(0.); auto diag = _m.diagonal('c'); - (*diag)({col1+shift, col1+shift+n, 0,0}, true).assign(jac._s({0,n, 0,0}, true)); - delete diag; + diag({col1+shift, col1+shift+n, 0,0}, true).assign(jac._s({0,n, 0,0}, true)); return; } @@ -786,14 +782,10 @@ void SVD::DivideAndConquer(int col1, int col2, int row1W, int col1W, int shif temp.assign(_u({col1, col1+k+1, i, i+1}, true)); } - auto temp1 = _u({col1,col1+k+1, col1,col1+1}, true); - temp1.assign(q1 * c0); - auto temp2 = _u({col1,col1+k+1, col2+1,col2+2}, true); - temp2.assign(q1 * (-s0)); - auto temp3 = _u({col1+k+1,col1+n+1, col1, col1+1}, true); - temp3.assign(_u({col1+k+1, col1+n+1, col2+1, col2+2}, true) * s0); - auto temp4 =_u({col1+k+1,col1+n+1, col2+1,col2+2}, true); - temp4 *= c0; + _u({col1,col1+k+1, col1,col1+1}, true).assign(q1 * c0); + _u({col1,col1+k+1, col2+1,col2+2}, true).assign(q1 * (-s0)); + _u({col1+k+1,col1+n+1, col1, col1+1}, true).assign(static_cast(_u({col1+k+1, col1+n+1, col2+1, col2+2}, true)) * s0); + _u({col1+k+1,col1+n+1, col2+1,col2+2}, true) *= c0; } else { @@ -841,8 +833,7 @@ void SVD::DivideAndConquer(int col1, int col2, int row1W, int col1W, int shif auto blockM = _m({col1+shift,col1+shift+n, col1+shift,col1+shift+n}, true); blockM = 0.f; auto diag = blockM.diagonal('c'); - diag->assign(singVals); - delete diag; + diag.assign(singVals); } ////////////////////////////////////////////////////////////////////////// @@ -958,16 +949,16 @@ static void svd_(const NDArray* x, const std::vector& outArrs, const b ResultSet* listU(nullptr), *listV(nullptr); if(calcUV) { - listU = u->allTensorsAlongDimension({rank-2, rank-1}); - listV = v->allTensorsAlongDimension({rank-2, rank-1}); + listU = new ResultSet(u->allTensorsAlongDimension({rank-2, rank-1})); + listV = new ResultSet(v->allTensorsAlongDimension({rank-2, rank-1})); } - for(int i = 0; i < listX->size(); ++i) { + for(int i = 0; i < listX.size(); ++i) { - // NDArray matrix(x->ordering(), {listX->at(i)->sizeAt(0), listX->at(i)->sizeAt(1)}, block.getContext()); - // matrix.assign(listX->at(i)); - helpers::SVD svdObj(*(listX->at(i)), switchNum, calcUV, calcUV, fullUV); - listS->at(i)->assign(svdObj._s); + // NDArray matrix(x->ordering(), {listX.at(i)->sizeAt(0), listX.at(i)->sizeAt(1)}, block.getContext()); + // matrix.assign(listX.at(i)); + helpers::SVD svdObj(*(listX.at(i)), switchNum, calcUV, calcUV, fullUV); + listS.at(i)->assign(svdObj._s); if(calcUV) { listU->at(i)->assign(svdObj._u); @@ -975,9 +966,6 @@ static void svd_(const NDArray* x, const std::vector& outArrs, const b } } - delete listX; - delete listS; - if(calcUV) { delete listU; delete listV; diff --git a/libnd4j/include/ops/declarable/helpers/cpu/toggle_bits.cpp b/libnd4j/include/ops/declarable/helpers/cpu/toggle_bits.cpp index 0fc6eea0b..481575297 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/toggle_bits.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/toggle_bits.cpp @@ -30,7 +30,7 @@ namespace nd4j { return BitwiseUtils::flip_bits(_x); }; - in.applyLambda(lambda, &out); + in.applyLambda(lambda, out); } void __toggle_bits(nd4j::LaunchContext * context, NDArray& in, NDArray& out) { diff --git a/libnd4j/include/ops/declarable/helpers/cpu/transforms.cpp b/libnd4j/include/ops/declarable/helpers/cpu/transforms.cpp index ea2fb348a..ea5e90cd8 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/transforms.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/transforms.cpp @@ -39,7 +39,7 @@ template static void triuBP_(nd4j::LaunchContext * context, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int diagonal) { auto dOdI = NDArray(&gradO); // dO/dI - const_cast(input).fillAsTriangular(0, diagonal, dOdI.sizeAt(-1), 'b', &dOdI); + const_cast(input).fillAsTriangular(0, diagonal, dOdI.sizeAt(-1), dOdI, 'b'); int dLen = dOdI.lengthOf(); auto func = PRAGMA_THREADS_FOR { @@ -66,11 +66,9 @@ static void trace_(const NDArray& input, NDArray& output) { auto func = PRAGMA_THREADS_FOR { for (auto i = start; i < stop; i += increment) - output.p(i, setOfSubArrs->at(i)->getTrace()); + output.p(i, setOfSubArrs.at(i)->getTrace()); }; - samediff::Threads::parallel_for(func, 0, setOfSubArrs->size()); - - delete setOfSubArrs; + samediff::Threads::parallel_for(func, 0, setOfSubArrs.size()); } void trace(nd4j::LaunchContext * context, const NDArray& input, NDArray& output) { @@ -137,7 +135,7 @@ void randomShuffle_(NDArray& input, NDArray& output, nd4j::graph::RandomGenerato if(i == r) continue; - subArrsListIn->at(i)->swapUnsafe(*subArrsListIn->at(r)); + subArrsListIn.at(i)->swapUnsafe(*subArrsListIn.at(r)); } } else { @@ -149,20 +147,18 @@ void randomShuffle_(NDArray& input, NDArray& output, nd4j::graph::RandomGenerato //PRAGMA_OMP_PARALLEL_FOR_IF((firstDim-1) > Environment::getInstance()->tadThreshold()) for(int i = firstDim - 1; i > 0; --i) { int r = rng.relativeInt(i) % i; - subArrsListOut->at(i)->assign(subArrsListIn->at(indices[r])); + subArrsListOut.at(i)->assign(subArrsListIn.at(indices[r])); if(r == 0) isZeroShuffled = true; if(i == r) continue; - subArrsListOut->at(r)->assign(subArrsListIn->at(indices[i])); + subArrsListOut.at(r)->assign(subArrsListIn.at(indices[i])); math::nd4j_swap(indices[i], indices[r]); } if(!isZeroShuffled) - subArrsListOut->at(0)->assign(subArrsListIn->at(0)); - delete subArrsListOut; + subArrsListOut.at(0)->assign(subArrsListIn.at(0)); } rng.rewindH(firstDim-1); - delete subArrsListIn; } } @@ -715,12 +711,10 @@ void eye(nd4j::LaunchContext * context, NDArray& output) { auto func = PRAGMA_THREADS_FOR { for (auto i = start; i < stop; i += increment) - arrs->at(i)->setIdentity(); + arrs.at(i)->setIdentity(); }; - samediff::Threads::parallel_tad(func, 0, arrs->size()); - - delete arrs; + samediff::Threads::parallel_tad(func, 0, arrs.size()); } ////////////////////////////////////////////////////////////////////////// @@ -752,25 +746,25 @@ void scatterUpdate(nd4j::LaunchContext * context, NDArray& input, NDArray& updat switch (opCode) { case 0: - inSubArr.applyPairwiseTransform(pairwise::Add, &updSubArr, &inSubArr, nullptr); + inSubArr.applyPairwiseTransform(pairwise::Add, updSubArr, inSubArr); break; case 1: - inSubArr.applyPairwiseTransform(pairwise::Subtract, &updSubArr, &inSubArr, nullptr); + inSubArr.applyPairwiseTransform(pairwise::Subtract, updSubArr, inSubArr); break; case 2: - inSubArr.applyPairwiseTransform(pairwise::Multiply, &updSubArr, &inSubArr, nullptr); + inSubArr.applyPairwiseTransform(pairwise::Multiply, updSubArr, inSubArr); break; case 3: - inSubArr.applyPairwiseTransform(pairwise::Divide, &updSubArr, &inSubArr, nullptr); + inSubArr.applyPairwiseTransform(pairwise::Divide, updSubArr, inSubArr); break; case 4: - inSubArr.applyPairwiseTransform(pairwise::ReverseSubtract, &updSubArr, &inSubArr, nullptr); + inSubArr.applyPairwiseTransform(pairwise::ReverseSubtract, updSubArr, inSubArr); break; case 5: - inSubArr.applyPairwiseTransform(pairwise::ReverseDivide, &updSubArr, &inSubArr, nullptr); + inSubArr.applyPairwiseTransform(pairwise::ReverseDivide, updSubArr, inSubArr); break; case 6: - inSubArr.applyPairwiseTransform(pairwise::CopyPws, &updSubArr, &inSubArr, nullptr); + inSubArr.applyPairwiseTransform(pairwise::CopyPws, updSubArr, inSubArr); break; default: continue; @@ -917,7 +911,7 @@ template static void clipByNorm_(NDArray& input, NDArray& output, const std::vector& dimensions, const NDArray& clipNorm, const bool isInplace) { const int rank = input.rankOf(); - const auto norm2 = input.reduceAlongDims(reduce::Norm2, dimensions); + const auto norm2 = input.reduceAlongDimension(reduce::Norm2, dimensions); const T normActual = norm2.e(0); const T normClip = clipNorm.e(0); @@ -937,12 +931,10 @@ static void clipByNorm_(NDArray& input, NDArray& output, const std::vector& for (auto i = start; i < stop; i += increment) { const T iNormActual = norm2.e(i); if (iNormActual > normClip) - *listOfInSubArrs->at(i) *= normClip / iNormActual; + *listOfInSubArrs.at(i) *= normClip / iNormActual; } }; - samediff::Threads::parallel_tad(func, 0, listOfInSubArrs->size()); - - delete listOfInSubArrs; + samediff::Threads::parallel_tad(func, 0, listOfInSubArrs.size()); } } else { @@ -961,8 +953,8 @@ static void clipByNorm_(NDArray& input, NDArray& output, const std::vector& auto func = PRAGMA_THREADS_FOR { for (auto i = start; i < stop; i += increment) { - auto inputSubArr = listOfInSubArrs->at(i); - auto outputSubArr = listOfOutSubArrs->at(i); + auto inputSubArr = listOfInSubArrs.at(i); + auto outputSubArr = listOfOutSubArrs.at(i); outputSubArr->assign(inputSubArr); const T iNormActual = norm2.e(i); @@ -971,10 +963,7 @@ static void clipByNorm_(NDArray& input, NDArray& output, const std::vector& *outputSubArr *= clipNorm / iNormActual; } }; - samediff::Threads::parallel_tad(func, 0, listOfInSubArrs->size()); - - delete listOfInSubArrs; - delete listOfOutSubArrs; + samediff::Threads::parallel_tad(func, 0, listOfInSubArrs.size()); } } } @@ -1021,7 +1010,7 @@ void clipByNorm(nd4j::LaunchContext * context, NDArray& input, NDArray& output, else { auto lambda = LAMBDA_T(_x, factor) { return _x * factor; }; - input->applyLambda(lambda, output); + input->applyLambda(lambda, *output); } } } @@ -1037,7 +1026,7 @@ static void clipByNormBP_(const NDArray& input, const NDArray& gradO, NDArray& g const int rank = input.rankOf(); - auto norm2 = input.reduceAlongDims(reduce::Norm2, dimensions); + auto norm2 = input.reduceAlongDimension(reduce::Norm2, dimensions); if(norm2.lengthOf() == 1) { @@ -1055,16 +1044,16 @@ static void clipByNormBP_(const NDArray& input, const NDArray& gradO, NDArray& g return cn * (factor1 * elem2 - factor3 * elem1 * sumOfProd); }; - (const_cast(input)).applyPairwiseLambda(const_cast(&gradO), lambda, &gradI); + (const_cast(input)).applyPairwiseLambda(const_cast(gradO), lambda, gradI); } else gradI.assign(gradO); } else { - const auto gradISubArrs = gradI.allTensorsAlongDimension({dimensions}); - const auto gradOSubArrs = gradO.allTensorsAlongDimension({dimensions}); - const auto inputSubArrs = input.allTensorsAlongDimension({dimensions}); + auto gradISubArrs = gradI.allTensorsAlongDimension({dimensions}); + auto gradOSubArrs = gradO.allTensorsAlongDimension({dimensions}); + auto inputSubArrs = input.allTensorsAlongDimension({dimensions}); auto cn = clipNorm.e(0); @@ -1072,11 +1061,11 @@ static void clipByNormBP_(const NDArray& input, const NDArray& gradO, NDArray& g for (auto i = start; i < stop; i += increment) { T N = norm2.e(i); - auto gradOSubArr = gradOSubArrs->at(i); - auto gradISubArr = gradISubArrs->at(i); + auto gradOSubArr = gradOSubArrs.at(i); + auto gradISubArr = gradISubArrs.at(i); if (N > cn) { - auto inputSubArr = inputSubArrs->at(i); + auto inputSubArr = inputSubArrs.at(i); const T sumOfProd = (*inputSubArr * *gradOSubArr).reduceNumber(reduce::Sum).e(0); // reduce to scalar const T factor1 = static_cast(1.f) / N; const T factor3 = factor1 / (N * N); // 1 / (N*N*N) @@ -1085,16 +1074,12 @@ static void clipByNormBP_(const NDArray& input, const NDArray& gradO, NDArray& g return cn * (factor1 * elem2 - factor3 * elem1 * sumOfProd); }; - inputSubArr->applyPairwiseLambda(gradOSubArr, lambda, gradISubArr); + inputSubArr->applyPairwiseLambda(*gradOSubArr, lambda, *gradISubArr); } else gradISubArr->assign(gradOSubArr); } }; - samediff::Threads::parallel_tad(func, 0, gradISubArrs->size()); - - delete gradISubArrs; - delete gradOSubArrs; - delete inputSubArrs; + samediff::Threads::parallel_tad(func, 0, gradISubArrs.size()); } } @@ -1120,25 +1105,24 @@ static void clipByAveraged_(NDArray& input, NDArray& output, const std::vector(lambda, &output); + input.applyLambda(lambda, output); } } else { // along dimension - auto norm2 = input.reduceAlongDims(reduce::Norm2, dimensions, false); + auto norm2 = input.reduceAlongDimension(reduce::Norm2, dimensions, false); if (!isInplace) output.assign(input); auto tads = output.allTensorsAlongDimension(dimensions); // TODO: make this CUDA-compliant somehow - for (int e = 0; e < tads->size(); e++) { - T n2 = norm2.e(e) / tads->at(e)->lengthOf(); + for (int e = 0; e < tads.size(); e++) { + T n2 = norm2.e(e) / tads.at(e)->lengthOf(); const T factor = cn / n2; if (n2 > cn) { auto lambda = LAMBDA_T(_x, factor) {return _x * factor;}; - tads->at(e)->applyLambda(lambda, &output); + tads.at(e)->applyLambda(lambda, output); } } - delete tads; } } @@ -1164,7 +1148,7 @@ static void clipByAveraged_(NDArray& input, NDArray& output, const std::vector(routine, &output); + input.applyLambda(routine, output); } void clipByValue(nd4j::LaunchContext * context, NDArray& input, double leftBound, double rightBound, NDArray& output) { diff --git a/libnd4j/include/ops/declarable/helpers/cpu/triangular_solve.cpp b/libnd4j/include/ops/declarable/helpers/cpu/triangular_solve.cpp new file mode 100644 index 000000000..ab409a0c6 --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cpu/triangular_solve.cpp @@ -0,0 +1,135 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit, K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author GS +// +#include +#include +#include +#include "../triangular_solve.h" + +namespace nd4j { +namespace ops { +namespace helpers { + /* + * lower triangular process for system of linear equations + * x_1 = b_1/a_1,1 + * x_2 = (b_2 - a_2,1 * x_1) / a_2,2 + * x_3 = (b_3 - a_3,1 * x_1 - a_3,2 * x_2) / a_3,3 + * ... + * x_M = (b_M - a_M,1 * x_1 - ... a_M,M-1 * x_M-1)/ a_M,M + * + * output == x + * a == leftInput + * b == rightInput + * + * */ + template + static void lowerTriangularSolve(nd4j::LaunchContext * context, NDArray* leftInput, NDArray* rightInput, bool adjoint, NDArray* output) { + auto rows = leftInput->rows(); + //output->t(0,0) = rightInput->t(0,0) / leftInput->t(0,0); + for (auto r = 0; r < rows; r++) { + auto sum = rightInput->t(r, 0); + for (auto c = 0; c < r; c++) { + sum -= leftInput->t(r,c) * output->t(c, 0); + } + output->t(r, 0) = sum / leftInput->t(r, r); + } + } + + /* + * upper triangular process for system of linear equations + * x_M = b_M/a_M,M + * x_M-1 = (b_M-1 - a_M-1,M-2 * x_M) / a_M-1,M-1 + * x_M-2 = (b_M-2 - a_M-2,M-3 * x_M-2 - a_M-2,M-1 * x_M) / a_3,3 + * ... + * x_1 = (b_1 - a_1,2 * x_2 - ... a_1,M * x_M)/ a_1,1 + * + * output == x + * a == leftInput + * b == rightInput + * + * */ + + template + static void upperTriangularSolve(nd4j::LaunchContext * context, NDArray* leftInput, NDArray* rightInput, bool adjoint, NDArray* output) { + auto rows = leftInput->rows(); + + for (auto r = rows; r > 0; r--) { + auto sum = rightInput->t(r - 1, 0); + for (auto c = r; c < rows; c++) { + sum -= leftInput->t(r - 1, c) * output->t(c, 0); + } + output->t(r - 1, 0) = sum / leftInput->t(r - 1, r - 1); + } + } + + template + static int triangularSolveFunctor_(nd4j::LaunchContext * context, NDArray* leftInput, NDArray* rightInput, bool lower, bool adjoint, NDArray* output) { + auto leftPart = leftInput->allTensorsAlongDimension({-2, -1}); + auto rightPart = rightInput->allTensorsAlongDimension({-2, -1}); + auto outputPart = output->allTensorsAlongDimension({-2, -1}); + + auto batchLoop = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i += increment) { + if (lower) { + lowerTriangularSolve(context, leftPart[i], rightPart[i], adjoint, outputPart[i]); + } else { + upperTriangularSolve(context, leftPart[i], rightPart[i], adjoint, outputPart[i]); + } + } + }; + + samediff::Threads::parallel_tad(batchLoop, 0, leftPart.size(), 1); + + return Status::OK(); + + } + template + static void adjointTriangularMatrix_(nd4j::LaunchContext* context, NDArray const* input, bool const lower, NDArray* output) { + auto inputPart = input->allTensorsAlongDimension({-2, -1}); + auto outputPart = output->allTensorsAlongDimension({-2, -1}); + auto batchLoop = PRAGMA_THREADS_FOR { + for (auto batch = start; batch < stop; batch += increment) { + if (!lower) { + for (auto r = 0; r < input->rows(); r++) { + for (auto c = 0; c <= r; c++) { + outputPart[batch]->t(r, c) = inputPart[batch]->t(c, r); + } + } + } else { + for (auto r = 0; r < input->rows(); r++) { + for (auto c = r; c < input->columns(); c++) { + outputPart[batch]->t(r, c) = inputPart[batch]->t(c, r); + } + } + } + } + }; + samediff::Threads::parallel_tad(batchLoop, 0, inputPart.size(), 1); + } + + int triangularSolveFunctor(nd4j::LaunchContext * context, NDArray* leftInput, NDArray* rightInput, bool lower, bool adjoint, NDArray* output) { + BUILD_SINGLE_SELECTOR(leftInput->dataType(), return triangularSolveFunctor_, (context, leftInput, rightInput, lower, adjoint, output), FLOAT_NATIVE); + } + + void adjointMatrix(nd4j::LaunchContext* context, NDArray const* input, bool const lower, NDArray* output) { + BUILD_SINGLE_SELECTOR(input->dataType(), adjointTriangularMatrix_, (context, input, lower, output), FLOAT_NATIVE); + } +} +} +} diff --git a/libnd4j/include/ops/declarable/helpers/cross.h b/libnd4j/include/ops/declarable/helpers/cross.h index d087a4849..31b386e7e 100644 --- a/libnd4j/include/ops/declarable/helpers/cross.h +++ b/libnd4j/include/ops/declarable/helpers/cross.h @@ -65,23 +65,19 @@ void FORCEINLINE cross(nd4j::LaunchContext * context, NDArray *a, NDArray *b, ND auto tadsB = b_.allTensorsAlongDimension({1}); auto tadsO = o_.allTensorsAlongDimension({1}); - int tads = tadsA->size(); + int tads = tadsA.size(); auto func = PRAGMA_THREADS_FOR { for (auto e = start; e < stop; e += increment) { - auto a_ = tadsA->at(e); - auto b_ = tadsB->at(e); - auto o_ = tadsO->at(e); + auto a_ = tadsA.at(e); + auto b_ = tadsB.at(e); + auto o_ = tadsO.at(e); helpers::cross(context, a_, b_, o_); } }; samediff::Threads::parallel_tad(func, 0, tads); - - delete tadsA; - delete tadsB; - delete tadsO; } void weightedCrossEntropyWithLogitsFunctor(nd4j::LaunchContext * context, NDArray const* targets, NDArray const* input, NDArray const* weights, NDArray* output); diff --git a/libnd4j/include/ops/declarable/helpers/cuda/BarnesHutTsne.cu b/libnd4j/include/ops/declarable/helpers/cuda/BarnesHutTsne.cu index fabe6800a..4c746f244 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/BarnesHutTsne.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/BarnesHutTsne.cu @@ -244,7 +244,7 @@ namespace helpers { return res; }; - input->applyTriplewiseLambda(gradX, epsilon, gainsInternal, output); + input->applyTriplewiseLambda(*gradX, *epsilon, gainsInternal, *output); } //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/libnd4j/include/ops/declarable/helpers/cuda/activations.cu b/libnd4j/include/ops/declarable/helpers/cuda/activations.cu index 21b2eecd4..7bddb00fe 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/activations.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/activations.cu @@ -40,15 +40,11 @@ __global__ void preluCuda(const void *vx, const Nd4jLong *xShapeInfo, const auto y = reinterpret_cast(vy); auto z = reinterpret_cast(vz); - __shared__ Nd4jLong xzLen, totalThreads, *sharedMem; + __shared__ Nd4jLong xzLen; __shared__ int xzRank, yRank; if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sharedMem = reinterpret_cast(shmem); - xzLen = shape::length(xShapeInfo); - totalThreads = gridDim.x * blockDim.x; xzRank = shape::rank(xShapeInfo); yRank = shape::rank(yShapeInfo); @@ -56,18 +52,15 @@ __global__ void preluCuda(const void *vx, const Nd4jLong *xShapeInfo, __syncthreads(); const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - Nd4jLong* coords = sharedMem + threadIdx.x * xzRank; - - for (int i = tid; i < xzLen; i += totalThreads) { + Nd4jLong coords[MAX_RANK]; + for (int i = tid; i < xzLen; i += blockDim.x * gridDim.x) { shape::index2coords(i, xShapeInfo, coords); const auto xzOffset = shape::getOffset(xShapeInfo, coords); - const auto xVal = x[xzOffset]; if(xVal < 0) { - for (uint j = 0; j < yRank; ++j) if(yShapeInfo[j + 1] == 1) coords[j + 1] = 0; @@ -82,7 +75,6 @@ __global__ void preluCuda(const void *vx, const Nd4jLong *xShapeInfo, /////////////////////////////////////////////////////////////////// template linkage void preluCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, void *vz) { - preluCuda<<>>(vx, xShapeInfo, vy, yShapeInfo, vz); } @@ -91,9 +83,9 @@ void prelu(nd4j::LaunchContext * context, const NDArray& input, const NDArray& a PointersManager manager(context, "prelu"); - const int threadsPerBlock = MAX_NUM_THREADS / 2; - const int blocksPerGrid = (input.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = input.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128; + const int threadsPerBlock = 256; + const int blocksPerGrid = 512; + const int sharedMem = 512; const auto xType = input.dataType(); const auto yType = alpha.dataType(); @@ -119,13 +111,10 @@ __global__ linkage void preluBPCuda(const void *vIn, const Nd4jLong *inShapeI auto dLdI = reinterpret_cast(vdLdI); auto dLdA = reinterpret_cast(vdLdA); - __shared__ Nd4jLong inLen, totalThreads, *sharedMem; + __shared__ Nd4jLong inLen, totalThreads; __shared__ int inRank, alphaRank; if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sharedMem = reinterpret_cast(shmem); - inLen = shape::length(inShapeInfo); totalThreads = gridDim.x * blockDim.x; @@ -135,10 +124,9 @@ __global__ linkage void preluBPCuda(const void *vIn, const Nd4jLong *inShapeI __syncthreads(); const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - Nd4jLong* coords = sharedMem + threadIdx.x * inRank; + Nd4jLong coords[MAX_RANK]; for (int i = tid; i < inLen; i += totalThreads) { - shape::index2coords(i, inShapeInfo, coords); const auto inOffset = shape::getOffset(inShapeInfo, coords); @@ -175,14 +163,13 @@ __host__ linkage void preluBPCudaLauncher(const int blocksPerGrid, const int thr ////////////////////////////////////////////////////////////////////////// void preluBP(nd4j::LaunchContext* context, const NDArray& input, const NDArray& alpha, const NDArray& dLdO, NDArray& dLdI, NDArray& dLdA) { - - dLdA.nullify(); + dLdA.nullify(); PointersManager manager(context, "preluBP"); - const int threadsPerBlock = MAX_NUM_THREADS / 2; - const int blocksPerGrid = (input.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = input.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128; + const int threadsPerBlock = 256; + const int blocksPerGrid = 512; + const int sharedMem = 512; const auto xType = input.dataType(); const auto zType = alpha.dataType(); @@ -345,9 +332,9 @@ void softmax(nd4j::LaunchContext * context, const NDArray& input, NDArray& outpu BUILD_SINGLE_SELECTOR(input.dataType(), softMaxCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), input.getSpecialBuffer(), packX.specialShapeInfo(), packX.specialOffsets(), output.specialBuffer(), packZ.specialShapeInfo(), packZ.specialOffsets()), FLOAT_TYPES); NDArray::registerSpecialUse({&output}, {&input}); - // auto maxAlongDim = const_cast(input).reduceAlongDims(reduce::Max, {dimension}, true); + // auto maxAlongDim = const_cast(input).reduceAlongDimension(reduce::Max, {dimension}, true); // (input - maxAlongDim).applyTransform(transform::Exp, &output); // output contains exponents temporarily - // auto sumAlongDim = output.reduceAlongDims(reduce::Sum, {dimension}, true); + // auto sumAlongDim = output.reduceAlongDimension(reduce::Sum, {dimension}, true); // output /= sumAlongDim; // input.tickReadDevice(); } @@ -463,11 +450,11 @@ void logSoftmax(nd4j::LaunchContext * context, const NDArray& input, NDArray& ou } else { - auto maxAlongDim = const_cast(input).reduceAlongDims(reduce::Max, {dimension}, true); - (input - maxAlongDim).applyTransform(transform::Exp, &output); // output contains exponents temporarily - auto sumAlongDim = output.reduceAlongDims(reduce::Sum, {dimension}, true); + auto maxAlongDim = const_cast(input).reduceAlongDimension(reduce::Max, {dimension}, true); + (input - maxAlongDim).applyTransform(transform::Exp, output); // output contains exponents temporarily + auto sumAlongDim = output.reduceAlongDimension(reduce::Sum, {dimension}, true); output /= sumAlongDim; - output.applyTransform(transform::Log); + output.applyTransform(transform::Log, output); input.tickReadDevice(); } @@ -580,9 +567,9 @@ void softmaxDerivative(nd4j::LaunchContext * context, const NDArray& input, NDAr } else { - auto maxAlongDim = const_cast(input).reduceAlongDims(reduce::Max, {dimension}, true); - (input - maxAlongDim).applyTransform(transform::Exp, &output); // output contains exponents temporarily - auto sumAlongDim = output.reduceAlongDims(reduce::Sum, {dimension}, true); + auto maxAlongDim = const_cast(input).reduceAlongDimension(reduce::Max, {dimension}, true); + (input - maxAlongDim).applyTransform(transform::Exp, output); // output contains exponents temporarily + auto sumAlongDim = output.reduceAlongDimension(reduce::Sum, {dimension}, true); output /= sumAlongDim; output *= (1.f - output); // derivative input.tickReadDevice(); @@ -600,7 +587,7 @@ void softmaxDerivative(nd4j::LaunchContext * context, const NDArray& input, NDAr auto routine = LAMBDA_T(_x, threshold) { return _x > (T)threshold ? _x: (T)0.f; }; - const_cast(input).applyLambda(routine, &output); + const_cast(input).applyLambda(routine, output); } void thresholdRelu(nd4j::LaunchContext * context, NDArray const& input, double threshold, NDArray& output) { @@ -611,7 +598,7 @@ void softmaxDerivative(nd4j::LaunchContext * context, const NDArray& input, NDAr linkage void thresholdReluDerivative_(NDArray* input, double theta, NDArray* dLdO, NDArray* output) { auto derivative = LAMBDA_TT(_x, grO, theta) {if (_x > theta) return grO; else return static_cast(0); }; - input->applyPairwiseLambda(dLdO, derivative, output); + input->applyPairwiseLambda(*dLdO, derivative, *output); } void thresholdReluDerivative(nd4j::LaunchContext * context, NDArray* input, double threshold, NDArray* dLdO, NDArray* output) { diff --git a/libnd4j/include/ops/declarable/helpers/cuda/adjust_hue.cu b/libnd4j/include/ops/declarable/helpers/cuda/adjust_hue.cu index 5b52d1b0b..5712887da 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/adjust_hue.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/adjust_hue.cu @@ -58,11 +58,11 @@ static void _CUDA_G adjustHueCuda(const void* vx, const Nd4jLong* xShapeInfo, co rgbToHsv(xTad[0], xTad[xDimCstride], xTad[2 * xDimCstride], h, s, v); - h += delta * 360; - if(h > 360) - h -= 360; + h += delta ; + if(h > 1) + h -= 1; else if(h < 0) - h += 360; + h += 1; hsvToRgb(h, s, v, zTad[0], zTad[zDimCstride], zTad[2 * zDimCstride]); } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/batched_gemm.cu b/libnd4j/include/ops/declarable/helpers/cuda/batched_gemm.cu index 450ac08cc..99fbd33a8 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/batched_gemm.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/batched_gemm.cu @@ -45,21 +45,21 @@ void bgemm(const std::vector& vA, const std::vector& vB, std for(int i = 0; i < bS; ++i) { if(vA[i]->ews() != 1) { - pA[i] = vA[i]->dup('f'); + pA[i] = new NDArray(vA[i]->dup('f')); toDelete.emplace_back(pA[i]); } else pA[i] = vA[i]; if(vB[i]->ews() != 1) { - pB[i] = vB[i]->dup('f'); + pB[i] = new NDArray(vB[i]->dup('f')); toDelete.emplace_back(pB[i]); } else pB[i] = vB[i]; if(vC[i]->ews() != 1) { - pC[i] = vC[i]->dup('f'); + pC[i] = new NDArray(vC[i]->dup('f')); toDelete.emplace_back(pC[i]); } else diff --git a/libnd4j/include/ops/declarable/helpers/cuda/batchnorm.cu b/libnd4j/include/ops/declarable/helpers/cuda/batchnorm.cu index d9188e3a8..eedbe1fdf 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/batchnorm.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/batchnorm.cu @@ -31,66 +31,66 @@ namespace helpers { ////////////////////////////////////////////////////////////////////////// -template -__global__ static void batchnormCuda(const void* vx, const Nd4jLong* xShapeInfo, - const void* vMean, const Nd4jLong* meanShapeInfo, - const void* vVariance, const Nd4jLong* varianceShapeInfo, - const void* vGamma, const Nd4jLong* gammaShapeInfo, - const void* vBeta, const Nd4jLong* betaShapeInfo, - void* vz, const Nd4jLong* zShapeInfo, - const Nd4jLong* xTadShapeInfo, const Nd4jLong* xTadOffsets, - const Nd4jLong* zTadShapeInfo, const Nd4jLong* zTadOffsets, - const T epsilon) { +// template +// __global__ static void batchnormCuda(const void* vx, const Nd4jLong* xShapeInfo, +// const void* vMean, const Nd4jLong* meanShapeInfo, +// const void* vVariance, const Nd4jLong* varianceShapeInfo, +// const void* vGamma, const Nd4jLong* gammaShapeInfo, +// const void* vBeta, const Nd4jLong* betaShapeInfo, +// void* vz, const Nd4jLong* zShapeInfo, +// const Nd4jLong* xTadShapeInfo, const Nd4jLong* xTadOffsets, +// const Nd4jLong* zTadShapeInfo, const Nd4jLong* zTadOffsets, +// const T epsilon) { - const auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - const auto mean = reinterpret_cast(vMean); - const auto variance = reinterpret_cast(vVariance); - const auto gamma = reinterpret_cast(vGamma); - const auto beta = reinterpret_cast(vBeta); +// const auto x = reinterpret_cast(vx); +// auto z = reinterpret_cast(vz); +// const auto mean = reinterpret_cast(vMean); +// const auto variance = reinterpret_cast(vVariance); +// const auto gamma = reinterpret_cast(vGamma); +// const auto beta = reinterpret_cast(vBeta); - // maxRank = xRank = zRank, minRank = meanRank = varianceRank = gammaRank = betaRank - __shared__ Nd4jLong minLen, tadLen, totalThreads; +// // maxRank = xRank = zRank, minRank = meanRank = varianceRank = gammaRank = betaRank +// __shared__ Nd4jLong minLen, tadLen, totalThreads; - if (threadIdx.x == 0) { - totalThreads = gridDim.x * blockDim.x; +// if (threadIdx.x == 0) { +// totalThreads = gridDim.x * blockDim.x; - minLen = shape::length(meanShapeInfo); - tadLen = shape::length(xShapeInfo) / minLen; - } - __syncthreads(); +// minLen = shape::length(meanShapeInfo); +// tadLen = shape::length(xShapeInfo) / minLen; +// } +// __syncthreads(); - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; +// const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - for (uint i = tid; i < minLen; i += totalThreads) { +// for (uint i = tid; i < minLen; i += totalThreads) { - const auto meanOffset = shape::getIndexOffset(i, meanShapeInfo); - const auto varianceOffset = shape::getIndexOffset(i, varianceShapeInfo); +// const auto meanOffset = shape::getIndexOffset(i, meanShapeInfo); +// const auto varianceOffset = shape::getIndexOffset(i, varianceShapeInfo); - T sigmaInvGam = 1. / nd4j::math::nd4j_sqrt(variance[varianceOffset] + epsilon); +// T sigmaInvGam = 1. / nd4j::math::nd4j_sqrt(variance[varianceOffset] + epsilon); - if(gamma != nullptr) - sigmaInvGam *= gamma[shape::getIndexOffset(i, gammaShapeInfo)]; +// if(gamma != nullptr) +// sigmaInvGam *= gamma[shape::getIndexOffset(i, gammaShapeInfo)]; - auto betaOffset = 0; - if(beta != nullptr) - betaOffset = shape::getIndexOffset(i, betaShapeInfo); +// auto betaOffset = 0; +// if(beta != nullptr) +// betaOffset = shape::getIndexOffset(i, betaShapeInfo); - const auto xTad = x + xTadOffsets[i]; - auto zTad = z + zTadOffsets[i]; +// const auto xTad = x + xTadOffsets[i]; +// auto zTad = z + zTadOffsets[i]; - for (uint j = 0; j < tadLen; ++j) { +// for (uint j = 0; j < tadLen; ++j) { - const auto xTadOffset = shape::getIndexOffset(j, xTadShapeInfo); - const auto zTadOffset = shape::getIndexOffset(j, zTadShapeInfo); +// const auto xTadOffset = shape::getIndexOffset(j, xTadShapeInfo); +// const auto zTadOffset = shape::getIndexOffset(j, zTadShapeInfo); - zTad[zTadOffset] = (xTad[xTadOffset] - mean[meanOffset]) * sigmaInvGam; +// zTad[zTadOffset] = (xTad[xTadOffset] - mean[meanOffset]) * sigmaInvGam; - if(beta != nullptr) - zTad[zTadOffset] += beta[betaOffset]; - } - } -} +// if(beta != nullptr) +// zTad[zTadOffset] += beta[betaOffset]; +// } +// } +// } ////////////////////////////////////////////////////////////////////////// template @@ -110,13 +110,12 @@ __global__ static void batchnormCuda2(const void* vx, const Nd4jLong* xShapeInfo const auto gamma = reinterpret_cast(vGamma); const auto beta = reinterpret_cast(vBeta); - __shared__ int xRank, minRank; // xRank == zRank. minRank = meanRank = varianceRank = gammaRank = betaRank - __shared__ Nd4jLong xLen, totalThreads, *sharedMem; // xLen = zLen + __shared__ int xRank, minRank; // xRank == zRank, minRank = meanRank = varianceRank = gammaRank = betaRank + __shared__ Nd4jLong xLen, totalThreads; // xLen = zLen if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sharedMem = reinterpret_cast(shmem); + totalThreads = gridDim.x * blockDim.x; xLen = shape::length(xShapeInfo); @@ -125,7 +124,8 @@ __global__ static void batchnormCuda2(const void* vx, const Nd4jLong* xShapeInfo } __syncthreads(); - auto coords = sharedMem + threadIdx.x * xRank; + Nd4jLong coords[MAX_RANK]; + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; for (uint i = tid; i < xLen; i += totalThreads) { @@ -166,24 +166,24 @@ __global__ static void batchnormCuda2(const void* vx, const Nd4jLong* xShapeInfo } /////////////////////////////////////////////////////////////////// -template -__host__ static void batchnormCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, - const void* vx, const Nd4jLong* xShapeInfo, - const void* vMean, const Nd4jLong* meanShapeInfo, - const void* vVariance, const Nd4jLong* varianceShapeInfo, - const void* vGamma, const Nd4jLong* gammaShapeInfo, - const void* vBeta, const Nd4jLong* betaShapeInfo, - void* vz, const Nd4jLong* zShapeInfo, - const Nd4jLong* xTadShapeInfo, const Nd4jLong* xTadOffsets, - const Nd4jLong* zTadShapeInfo, const Nd4jLong* zTadOffsets, - const double epsilon) { +// template +// __host__ static void batchnormCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, +// const void* vx, const Nd4jLong* xShapeInfo, +// const void* vMean, const Nd4jLong* meanShapeInfo, +// const void* vVariance, const Nd4jLong* varianceShapeInfo, +// const void* vGamma, const Nd4jLong* gammaShapeInfo, +// const void* vBeta, const Nd4jLong* betaShapeInfo, +// void* vz, const Nd4jLong* zShapeInfo, +// const Nd4jLong* xTadShapeInfo, const Nd4jLong* xTadOffsets, +// const Nd4jLong* zTadShapeInfo, const Nd4jLong* zTadOffsets, +// const double epsilon) { - batchnormCuda<<>>(vx, xShapeInfo, vMean, meanShapeInfo, vVariance, varianceShapeInfo, vGamma, gammaShapeInfo, vBeta, betaShapeInfo, vz, zShapeInfo, xTadShapeInfo, xTadOffsets, zTadShapeInfo, zTadOffsets, static_cast(epsilon)); -} +// batchnormCuda<<>>(vx, xShapeInfo, vMean, meanShapeInfo, vVariance, varianceShapeInfo, vGamma, gammaShapeInfo, vBeta, betaShapeInfo, vz, zShapeInfo, xTadShapeInfo, xTadOffsets, zTadShapeInfo, zTadOffsets, static_cast(epsilon)); +// } /////////////////////////////////////////////////////////////////// template -__host__ static void batchnormCudaLauncher2(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, +__host__ static void batchnormCudaLauncher2(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, const void* vMean, const Nd4jLong* meanShapeInfo, const void* vVariance, const Nd4jLong* varianceShapeInfo, @@ -193,42 +193,41 @@ __host__ static void batchnormCudaLauncher2(const int blocksPerGrid, const int t const int numDims, const int* dims, const double epsilon) { - batchnormCuda2<<>>(vx, xShapeInfo, vMean, meanShapeInfo, vVariance, varianceShapeInfo, vGamma, gammaShapeInfo, vBeta, betaShapeInfo, vz, zShapeInfo, numDims, dims, static_cast(epsilon)); + batchnormCuda2<<>>(vx, xShapeInfo, vMean, meanShapeInfo, vVariance, varianceShapeInfo, vGamma, gammaShapeInfo, vBeta, betaShapeInfo, vz, zShapeInfo, numDims, dims, static_cast(epsilon)); } ////////////////////////////////////////////////////////////////////////// void batchnorm(const NDArray* input, const NDArray* mean, const NDArray* variance, const NDArray* gamma, const NDArray* beta, NDArray* output, const std::vector& axes, const double epsilon) { - std::vector dimsToExclude = ShapeUtils::evalDimsToExclude(input->rankOf(), axes); + // std::vector dimsToExclude = ShapeUtils::evalDimsToExclude(input->rankOf(), axes); + + // auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimsToExclude); + // auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), dimsToExclude); + + // const int threadsPerBlock = MAX_NUM_THREADS / 2; + // const int blocksPerGrid = (mean->lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + + // PointersManager manager(input->getContext(), "batchnorm"); + + // NDArray::prepareSpecialUse({output}, {input, mean, variance, gamma, beta}); + // BUILD_SINGLE_SELECTOR(input->dataType(), batchnormCudaLauncher, (blocksPerGrid, threadsPerBlock, input->getContext()->getCudaStream(), input->getSpecialBuffer(), input->getSpecialShapeInfo(), mean->getSpecialBuffer(), mean->getSpecialShapeInfo(), variance->getSpecialBuffer(), variance->getSpecialShapeInfo(), gamma ? gamma->getSpecialBuffer() : nullptr, gamma ? gamma->getSpecialShapeInfo() : nullptr, beta ? beta->getSpecialBuffer() : nullptr, beta ? beta->getSpecialShapeInfo() : nullptr, output->specialBuffer(), output->specialShapeInfo(), packX.platformShapeInfo(), packX.platformOffsets(), packZ.platformShapeInfo(), packZ.platformOffsets(), epsilon), FLOAT_TYPES); + // NDArray::registerSpecialUse({output}, {input, mean, variance, gamma, beta}); + + // manager.synchronize(); - auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimsToExclude); - auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), dimsToExclude); const int threadsPerBlock = MAX_NUM_THREADS / 2; - const int blocksPerGrid = (mean->lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + const int blocksPerGrid = (input->lengthOf() + threadsPerBlock - 1) / threadsPerBlock; PointersManager manager(input->getContext(), "batchnorm"); + const int* dims = reinterpret_cast(manager.replicatePointer(axes.data(), axes.size() * sizeof(int))); + NDArray::prepareSpecialUse({output}, {input, mean, variance, gamma, beta}); - BUILD_SINGLE_SELECTOR(input->dataType(), batchnormCudaLauncher, (blocksPerGrid, threadsPerBlock, input->getContext()->getCudaStream(), input->getSpecialBuffer(), input->getSpecialShapeInfo(), mean->getSpecialBuffer(), mean->getSpecialShapeInfo(), variance->getSpecialBuffer(), variance->getSpecialShapeInfo(), gamma ? gamma->getSpecialBuffer() : nullptr, gamma ? gamma->getSpecialShapeInfo() : nullptr, beta ? beta->getSpecialBuffer() : nullptr, beta ? beta->getSpecialShapeInfo() : nullptr, output->specialBuffer(), output->specialShapeInfo(), packX.platformShapeInfo(), packX.platformOffsets(), packZ.platformShapeInfo(), packZ.platformOffsets(), epsilon), FLOAT_TYPES); + BUILD_SINGLE_SELECTOR(input->dataType(), batchnormCudaLauncher2, (blocksPerGrid, threadsPerBlock, input->getContext()->getCudaStream(), input->getSpecialBuffer(), input->getSpecialShapeInfo(), mean->getSpecialBuffer(), mean->getSpecialShapeInfo(), variance->getSpecialBuffer(), variance->getSpecialShapeInfo(), gamma ? gamma->getSpecialBuffer() : nullptr, gamma ? gamma->getSpecialShapeInfo() : nullptr, beta ? beta->getSpecialBuffer() : nullptr, beta ? beta->getSpecialShapeInfo() : nullptr, output->specialBuffer(), output->specialShapeInfo(), axes.size(), dims, epsilon), FLOAT_TYPES); NDArray::registerSpecialUse({output}, {input, mean, variance, gamma, beta}); manager.synchronize(); - - - // const int threadsPerBlock = MAX_NUM_THREADS / 4; - // const int blocksPerGrid = (input->lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - // const int sharedMem = sizeof(Nd4jLong) * threadsPerBlock * input->rankOf() + 128; - - // PointersManager manager(input->getContext(), "batchnorm"); - - // const int* dims = reinterpret_cast(manager.replicatePointer(axes.data(), axes.size() * sizeof(int))); - - // NDArray::prepareSpecialUse({output}, {input, mean, variance, gamma, beta}); - // BUILD_SINGLE_SELECTOR(input->dataType(), batchnormCudaLauncher2, (blocksPerGrid, threadsPerBlock, sharedMem, input->getContext()->getCudaStream(), input->getSpecialBuffer(), input->getSpecialShapeInfo(), mean->getSpecialBuffer(), mean->getSpecialShapeInfo(), variance->getSpecialBuffer(), variance->getSpecialShapeInfo(), gamma ? gamma->getSpecialBuffer() : nullptr, gamma ? gamma->getSpecialShapeInfo() : nullptr, beta ? beta->getSpecialBuffer() : nullptr, beta ? beta->getSpecialShapeInfo() : nullptr, output->specialBuffer(), output->specialShapeInfo(), axes.size(), dims, epsilon), FLOAT_TYPES); - // NDArray::registerSpecialUse({output}, {input, mean, variance, gamma, beta}); - - // manager.synchronize(); } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/concat.cu b/libnd4j/include/ops/declarable/helpers/cuda/concat.cu index 6f9a8c6ab..43c0e4af9 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/concat.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/concat.cu @@ -39,13 +39,10 @@ template __global__ static void concatCuda(void* pVx, void* pxShapeInfo, void* vz, Nd4jLong* zShapeInfo, const int axis) { T* z = reinterpret_cast(vz); - __shared__ Nd4jLong zLen, totalThreads, *sharedMem; + __shared__ Nd4jLong zLen, totalThreads; __shared__ int rank; if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sharedMem = reinterpret_cast(shmem); - zLen = shape::length(zShapeInfo); rank = shape::rank(zShapeInfo); totalThreads = gridDim.x * blockDim.x; @@ -54,27 +51,26 @@ __global__ static void concatCuda(void* pVx, void* pxShapeInfo, void* vz, Nd4jL const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - if(tid >= zLen) - return; + Nd4jLong coords[MAX_RANK]; - auto coords = sharedMem + threadIdx.x * rank; + for (uint64_t i = tid; i < zLen; i += totalThreads) { + shape::index2coords(i, zShapeInfo, coords); - shape::index2coords(tid, zShapeInfo, coords); + const auto zOffset = shape::getOffset(zShapeInfo, coords); - const auto zOffset = shape::getOffset(zShapeInfo, coords); + int inArrIdx = 0; + Nd4jLong *xShapeInfo = reinterpret_cast(pxShapeInfo)[inArrIdx]; - int inArrIdx = 0; - Nd4jLong *xShapeInfo = reinterpret_cast(pxShapeInfo)[inArrIdx]; + while (coords[axis] >= xShapeInfo[axis + 1]) { + coords[axis] -= xShapeInfo[axis + 1]; + xShapeInfo = reinterpret_cast(pxShapeInfo)[++inArrIdx]; + } - while(coords[axis] >= xShapeInfo[axis + 1]) { - coords[axis] -= xShapeInfo[axis + 1]; - xShapeInfo = reinterpret_cast(pxShapeInfo)[++inArrIdx]; + const auto *x = reinterpret_cast(reinterpret_cast(pVx)[inArrIdx]); + const auto xOffset = shape::getOffset(xShapeInfo, coords); + + z[zOffset] = x[xOffset]; } - - const auto* x = reinterpret_cast(reinterpret_cast(pVx)[inArrIdx]); - const auto xOffset = shape::getOffset(xShapeInfo, coords); - - z[zOffset] = x[xOffset]; } /////////////////////////////////////////////////////////////////// @@ -89,9 +85,9 @@ BUILD_SINGLE_TEMPLATE(template void concatCudaLauncher, (const int blocksPerGrid ////////////////////////////////////////////////////////////////////////// void concat(nd4j::LaunchContext * context, const std::vector& inArrs, NDArray& output, const int axis) { - const int threadsPerBlock = MAX_NUM_THREADS / 4; - const int blocksPerGrid = (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = threadsPerBlock * sizeof(Nd4jLong) * output.rankOf() + 128; + const int threadsPerBlock = 256; + const int blocksPerGrid = 512; + const int sharedMem = 512; const int numOfArrs = inArrs.size(); diff --git a/libnd4j/include/ops/declarable/helpers/cuda/convolutions.cu b/libnd4j/include/ops/declarable/helpers/cuda/convolutions.cu index 6b86ce302..4f77b2e7c 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/convolutions.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/convolutions.cu @@ -1228,7 +1228,7 @@ static void conv2dBP_(nd4j::graph::Context& block, const NDArray* input, const N NDArray* gradBR = gradB; if(gradB->rankOf() == 2) gradBR = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()})); - gradO->reduceAlongDimension(reduce::Sum, gradBR, gradOaxesForDot); // sum over bS, oH, oW + gradO->reduceAlongDimension(reduce::Sum, *gradBR, gradOaxesForDot); // sum over bS, oH, oW if(gradBR != gradB) delete gradBR; } @@ -1310,7 +1310,7 @@ static void depthwiseConv2dBP_(const NDArray* input, const NDArray* weights, con NDArray* gradBR = gradB; if(gradB->rankOf() == 2) gradBR = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()})); - gradO->reduceAlongDimension(reduce::Sum, gradBR, {0,indOoH,indOoH+1}); // sum over bS, oH, oW + gradO->reduceAlongDimension(reduce::Sum, *gradBR, {0,indOoH,indOoH+1}); // sum over bS, oH, oW if(gradBR != gradB) delete gradBR; } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/fake_quantization.cu b/libnd4j/include/ops/declarable/helpers/cuda/fake_quantization.cu index aa47e3e88..cbdff509d 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/fake_quantization.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/fake_quantization.cu @@ -49,8 +49,8 @@ namespace helpers { } return nd4j::math::nd4j_round(zeroPointFromMin); }(); - *nudgedMin = (quantMinF - nudgedZeroPoint) * (*scale); - *nudgedMax = (quantMaxF - nudgedZeroPoint) * (*scale); + *nudgedMax = (quantMaxF - static_cast(nudgedZeroPoint)) * (*scale); + *nudgedMin = (quantMinF - static_cast(nudgedZeroPoint)) * (*scale); } template @@ -75,7 +75,7 @@ namespace helpers { return (math::nd4j_floor((val - nudgedMin) / scale + T(0.5)) * scale + nudgedMin); }; - input->applyLambda(wiseMinMaxAndSoOn, output); + input->applyLambda(wiseMinMaxAndSoOn, *output); } template diff --git a/libnd4j/include/ops/declarable/helpers/cuda/gradient.cu b/libnd4j/include/ops/declarable/helpers/cuda/gradient.cu index 9d0e5e55b..a12b43973 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/gradient.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/gradient.cu @@ -31,7 +31,7 @@ void applyGradientDescent_(LaunchContext* context, NDArray* input, NDArray* step return _x - (_y * weight); }; - input->applyPairwiseLambda(step, lambda, output); + input->applyPairwiseLambda(*step, lambda, *output); } void applyGradientDescent(nd4j::LaunchContext* context, NDArray* input, NDArray* step, double weight, NDArray* output) { diff --git a/libnd4j/include/ops/declarable/helpers/cuda/gru.cu b/libnd4j/include/ops/declarable/helpers/cuda/gru.cu index cbbdf1439..82ab9d764 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/gru.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/gru.cu @@ -77,15 +77,15 @@ void gruCell(nd4j::LaunchContext * context, const NDArray* x, const NDArray* hLa // reset gate r->assign(mmul(*x, Wrx) + mmul(*hLast, Wrh) + br); // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU] - r->applyTransform(transform::Sigmoid); + r->applyTransform(transform::Sigmoid, *r); // update gate u->assign(mmul(*x, Wux) + mmul(*hLast, Wuh) + bu); // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU] - u->applyTransform(transform::Sigmoid); + u->applyTransform(transform::Sigmoid, *u); // cell gate c = activation(x × Wcx + (r * hlast) × Wch + bc) c->assign(mmul(*x, Wcx) + mmul(*r * *hLast, Wch) + *bc); // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU] - c->applyTransform(transform::Tanh); + c->applyTransform(transform::Tanh, *c); NDArray temp = 1.f - *c * *c; @@ -231,15 +231,15 @@ void gruCellBP(nd4j::LaunchContext* context, // reset gate NDArray r = mmul(*x, Wrx) + mmul(*hLast, Wrh) + br; // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU] - r.applyTransform(transform::Sigmoid); + r.applyTransform(transform::Sigmoid, r); // update gate NDArray u = mmul(*x, Wux) + mmul(*hLast, Wuh) + bu; // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU] - u.applyTransform(transform::Sigmoid); + u.applyTransform(transform::Sigmoid, u); // cell gate c = activation(x×Wcx + (r*hlast)×Wcu + bc) NDArray c = mmul(*x, Wcx) + mmul(r * *hLast, Wch) + *bc; // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU] - c.applyTransform(transform::Tanh); + c.applyTransform(transform::Tanh, c); // h = (1 - u) * c + u * hPrev @@ -352,10 +352,10 @@ void gruCellBP(nd4j::LaunchContext* context, dLdWcx.assign(mmul(xT, dLdZc)); // [iS, bS] × [bS, nU] = [iS, nU] dLdWch.assign(mmul((r * *hLast).transpose(), dLdZc)); // [nU, bS] × [bS, nU] = [nU, nU] - dLdbr.assign(dLdZr.reduceAlongDims(reduce::Sum, {0})); // [nU] - dLdbu.assign(dLdZu.reduceAlongDims(reduce::Sum, {0})); // [nU] + dLdbr.assign(dLdZr.reduceAlongDimension(reduce::Sum, {0})); // [nU] + dLdbu.assign(dLdZu.reduceAlongDimension(reduce::Sum, {0})); // [nU] - dLdbc->assign(dLdZc.reduceAlongDims(reduce::Sum, {0})); // [nU] + dLdbc->assign(dLdZc.reduceAlongDimension(reduce::Sum, {0})); // [nU] } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/image_resize.cu b/libnd4j/include/ops/declarable/helpers/cuda/image_resize.cu index ab3a96801..94df35964 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/image_resize.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/image_resize.cu @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2019-2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -29,7 +30,7 @@ limitations under the License. ==============================================================================*/ // -// @author sgazeos@gmail.com +// @author George A. Shulinok // #include @@ -639,7 +640,7 @@ namespace helpers { if (err != 0) { cuda_exception::build("helpers::computeXWeightsAndIndices: Cannot allocated device memory for interpolate calculator", err); } - err = cudaMemcpy(pCalcD, &calc, sizeof(CachedInterpolationCalculator), cudaMemcpyHostToDevice); + err = cudaMemcpyAsync(pCalcD, &calc, sizeof(CachedInterpolationCalculator), cudaMemcpyHostToDevice, *stream); if (err != 0) { cuda_exception::build("helpers::computeXWeightsAndIndices: Cannot set up device memory for interpolate calculator", err); } @@ -689,11 +690,17 @@ namespace helpers { } template - static __global__ void bicubicInterpolateWithCachingKernel(float const* cachedTable, float* cachedValue, T const* inputPtr, ImageResizerState* pResizerState, WeightsAndIndices* xWais, bool halfPixelCenters, Nd4jLong inBatchWidth, Nd4jLong inRowWidth, float* outputPtr) { + static __global__ void bicubicInterpolateWithCachingKernel(float const* cachedTable, T const* inputPtr, ImageResizerState* pResizerState, WeightsAndIndices* xWais, bool halfPixelCenters, Nd4jLong inBatchWidth, Nd4jLong inRowWidth, float* outputPtr) { // auto numChannels = pResizerState->channels; + for (Nd4jLong b = blockIdx.x; b < pResizerState->batchSize; b += gridDim.x) { auto pInput = inputPtr + b * inBatchWidth; + float* cachedValue; for (Nd4jLong y = threadIdx.x; y < pResizerState->outHeight; y += blockDim.x) { + if (threadIdx.x == 0) { + extern __shared__ char sharedChar[]; + cachedValue = reinterpret_cast(sharedChar); + } auto pos = (b * pResizerState->outHeight + y) * pResizerState->outWidth * pResizerState->channels; auto pOutput = &outputPtr[pos]; struct WeightsAndIndices yWai; @@ -841,25 +848,25 @@ namespace helpers { if (err != 0) { throw cuda_exception::build("helpers::bicubicInterpolateWithCaching: Cannot allocate memory for resizerState", err); } - err = cudaMemcpy(resizerStateD, &resizerState, sizeof(ImageResizerState), cudaMemcpyHostToDevice); + err = cudaMemcpyAsync(resizerStateD, &resizerState, sizeof(ImageResizerState), cudaMemcpyHostToDevice, *stream); if (err != 0) { throw cuda_exception::build("helpers::bicubicInterpolateWithCaching: Cannot set up memory for resizerState", err); } - float* cachedValue = nullptr; - size_t cachedSize = sizeof(float) * (numChannels == 3 ? 0 : 4 * numChannels); - if (cachedSize) { - err = cudaMalloc(reinterpret_cast(&cachedValue), cachedSize); - if (err != 0) { - throw cuda_exception::build( - "helpers::bicubicInterpolateWithCaching: Cannot allocate memory for cached values", err); - } - err = cudaMemset(cachedValue, 0, cachedSize); - if (err != 0) { - throw cuda_exception::build( - "helpers::bicubicInterpolateWithCaching: Cannot set up memory for cached values", err); - } - } +// float* cachedValue = nullptr; +// size_t cachedSize = sizeof(float) * (numChannels == 3 ? 0 : 4 * numChannels); +// if (cachedSize) { +// err = cudaMalloc(reinterpret_cast(&cachedValue), cachedSize); +// if (err != 0) { +// throw cuda_exception::build( +// "helpers::bicubicInterpolateWithCaching: Cannot allocate memory for cached values", err); +// } +// err = cudaMemset(cachedValue, 0, cachedSize); +// if (err != 0) { +// throw cuda_exception::build( +// "helpers::bicubicInterpolateWithCaching: Cannot set up memory for cached values", err); +// } +// } WeightsAndIndices* xWais; //(resizerState.outWidth); err = cudaMalloc(&xWais, sizeof(WeightsAndIndices) * resizerState.outWidth); @@ -878,7 +885,7 @@ namespace helpers { } const T* pInput = image->getDataBuffer()->specialAsT(); float* pOutput = output->dataBuffer()->specialAsT(); //_data.data(); - bicubicInterpolateWithCachingKernel<<<128, 1, 512, *stream>>>(coeffsTable, cachedValue, pInput, + bicubicInterpolateWithCachingKernel<<<128, 1, 512, *stream>>>(coeffsTable, pInput, resizerStateD, xWais, halfPixelCenters, inBatchWidth, inRowWidth, pOutput); err = cudaStreamSynchronize(*stream); if (err != 0) { @@ -889,11 +896,11 @@ namespace helpers { if (err != 0) { throw cuda_exception::build("helpers::bicubicInterpolateWithCaching: Cannot deallocate memory for resizerState", err); } - if (cachedSize) - err = cudaFree(cachedValue); - if (err != 0) { - throw cuda_exception::build("helpers::bicubicInterpolateWithCaching: Cannot deallocate memory for cached values", err); - } +// if (cachedSize) +// err = cudaFree(cachedValue); +// if (err != 0) { +// throw cuda_exception::build("helpers::bicubicInterpolateWithCaching: Cannot deallocate memory for cached values", err); +// } err = cudaFree(xWais); if (err != 0) { @@ -921,6 +928,227 @@ namespace helpers { BUILD_SINGLE_TEMPLATE(template int resizeBicubicFunctor_, (nd4j::LaunchContext * context, NDArray const* image, int width, int height, bool preserveAspectRatio, bool antialias, NDArray* output), NUMERIC_TYPES); // ------------------------------------------------------------------------------------------------------------------ // + struct CachedInterpolation { + Nd4jLong start; + Nd4jLong end; + float startScale; + float endMinusOneScale; + bool needsBounding; + }; + + static __global__ void fillInterpolationCache(CachedInterpolation* xCached, Nd4jLong cacheLen, Nd4jLong inWidth, float widthScale) { + auto start = blockIdx.x * blockDim.x + threadIdx.x; + auto increment = blockDim.x * gridDim.x; + + for (auto x = start; x < cacheLen; x += increment) { + auto& xCache = xCached[x]; + const float inX = x * widthScale; + const float inX1 = (x + 1) * widthScale; + + Nd4jLong v = math::nd4j_floor(inX); + xCache.start = v; + xCache.startScale = v < inX ? (v + 1 > inX1 ? widthScale : v + 1 - inX) : (v + 1 > inX1 ? inX1 - v : 1.f); + v = math::nd4j_ceil(inX1); + xCache.end = v--; + xCache.endMinusOneScale = v < inX ? (v + 1 > inX1 ? widthScale : v + 1 - inX) : (v + 1 > inX1 ? inX1 - v : 1.f); + xCache.needsBounding = bound(xCache.start, inWidth) != xCache.start || bound(xCache.end - 1, inWidth) != (xCache.end - 1); + } + } + +// ------------------------------------------------------------------------------------------------------------------ // + template + struct ScaleCache { + float yScale; + T const* yPtr; + }; + + // Computes the sum of all x values defined by taken across + // the y offsets and scales defined by y_ptrs and y_scales, for channel c. + // + // Note that is a template parameter to avoid a performance + // penalty from dynamically checking it. + template + static __device__ void computePatchSumOf3Channels(float scale, + const ImageResizerState& st, + ScaleCache const* yScaleCache, + Nd4jLong ptrsLen, + const CachedInterpolation& xCache, + float* outputPtr) { + + bool const needsXBounding = xCache.needsBounding; + + auto boundIfNeeded = [needsXBounding](Nd4jLong x, Nd4jLong y) -> Nd4jLong { + return (needsXBounding ? bound(x, y) : (x)); + }; + + float sum_0 = 0; + float sum_1 = 0; + float sum_2 = 0; + for (int i = 0; i < ptrsLen; ++i) { + const T* ptr = yScaleCache[i].yPtr; + float scaleX = xCache.startScale; + Nd4jLong offset = 3 * boundIfNeeded(xCache.start, st.inWidth); + float sum_y_0 = static_cast(ptr[offset + 0]) * scaleX; + float sum_y_1 = static_cast(ptr[offset + 1]) * scaleX; + float sum_y_2 = static_cast(ptr[offset + 2]) * scaleX; + + if (xCache.start + 1 != xCache.end) { + for (Nd4jLong x = xCache.start + 1; x < xCache.end - 1; ++x) { + Nd4jLong offset = 3 * boundIfNeeded(x, st.inWidth); + sum_y_0 += static_cast(ptr[offset + 0]); + sum_y_1 += static_cast(ptr[offset + 1]); + sum_y_2 += static_cast(ptr[offset + 2]); + } + scaleX = xCache.endMinusOneScale; + offset = st.channels * boundIfNeeded(xCache.end - 1, st.inWidth); + sum_y_0 += static_cast(ptr[offset + 0]) * scaleX; + sum_y_1 += static_cast(ptr[offset + 1]) * scaleX; + sum_y_2 += static_cast(ptr[offset + 2]) * scaleX; + } + sum_0 += sum_y_0 * yScaleCache[i].yScale; + sum_1 += sum_y_1 * yScaleCache[i].yScale; + sum_2 += sum_y_2 * yScaleCache[i].yScale; + } + + outputPtr[0] = sum_0 * scale; + outputPtr[1] = sum_1 * scale; + outputPtr[2] = sum_2 * scale; + } + + // Computes the sum of all x values defined by taken across + // the y offsets and scales defined by y_ptrs and y_scales, for channel c. + // + // Note that is a template parameter to avoid a performance + // penalty from dynamically checking it. + template + static __device__ void computePatchSum(float scale, const ImageResizerState& st, + ScaleCache const* yScaleCache, Nd4jLong ptrsLen, + const CachedInterpolation& xCache, + float* outputPtr) { + + bool const needsXBounding = xCache.needsBounding; + + auto boundIfNeeded = [needsXBounding](Nd4jLong x, Nd4jLong y) -> Nd4jLong { + return (needsXBounding ? bound(x, y) : (x)); + }; + + const auto numChannels = st.channels; + for (Nd4jLong c = 0; c < numChannels; ++c) { + float sum = 0; + for (int i = 0; i < ptrsLen; ++i) { + T const* ptr = yScaleCache[i].yPtr; + float scaleX = xCache.startScale; + float sumY = static_cast(ptr[numChannels * boundIfNeeded(xCache.start, st.inWidth) + c]) * scaleX; + if (xCache.start + 1 != xCache.end) { + for (Nd4jLong x = xCache.start + 1; x < xCache.end - 1; ++x) { + sumY += static_cast( + ptr[numChannels * boundIfNeeded(x, st.inWidth) + c]); + } + scaleX = xCache.endMinusOneScale; + sumY += static_cast(ptr[numChannels * boundIfNeeded(xCache.end - 1, st.inWidth) + c]) * scaleX; + } + sum += sumY * yScaleCache[i].yScale; + } + outputPtr[c] = sum * scale; + } + } + + template + static __global__ void resizeAreaKernel(ImageResizerState const* pSt, CachedInterpolation const* caches, float scale, + T const* inputPtr, Nd4jLong* inputShape, float* outputPtr, Nd4jLong* outputShape, ScaleCache* cachePool) { //batch * outWidth * outHeight + + for (auto batch = blockIdx.x; batch < pSt->batchSize; batch += gridDim.x) { + for (auto y = threadIdx.x; y < pSt->outHeight; y += blockDim.x) { + const float inY = y * pSt->heightScale; + const float inY1 = (y + 1) * pSt->heightScale; + // The start and end height indices of all the cells that could + // contribute to the target cell. + const Nd4jLong yStart = math::nd4j_floor(inY); + const Nd4jLong yEnd = math::nd4j_ceil(inY1); + auto scalesDim = yEnd - yStart; + auto yScaleCache = cachePool + (batch * pSt->outWidth + y) * scalesDim * sizeof(ScaleCache); + + //auto startPtr = sharedPtr + y * scalesDim * sizeof(float); + //float* yScales = yScalesShare + y * sizeof(float) * scalesDim;//reinterpret_cast(startPtr); //shared + y * scalesDim * y + scalesDim * sizeof(T const *) [scalesDim]; + //T const** yPtrs = yPtrsShare + y * sizeof(T const*) * scalesDim; //[scalesDim]; + //yPtrs = reinterpret_cast(sharedBuf); + float* output = outputPtr + (batch * pSt->outHeight + y) * pSt->channels * pSt->outWidth; + //int k = 0; + for (Nd4jLong i = yStart, k = 0; i < yEnd; ++i, ++k) { + float scaleY; + if (i < inY) { + scaleY = (i + 1 > inY1 ? pSt->heightScale : i + 1 - inY); + } else { + scaleY = (i + 1 > inY1 ? inY1 - i : 1.0); + } + yScaleCache[k].yScale = scaleY; + yScaleCache[k].yPtr = inputPtr + (batch * pSt->inHeight * pSt->inWidth * pSt->channels + bound(i, pSt->inHeight) * pSt->inWidth * pSt->channels); + } + + if (pSt->channels == 3) { + for (Nd4jLong x = 0; x < pSt->outWidth; ++x) { + const CachedInterpolation& xCache = caches[x]; + computePatchSumOf3Channels(scale, *pSt, yScaleCache, scalesDim, xCache, output); + output += pSt->channels; + } + } else { + for (Nd4jLong x = 0; x < pSt->outWidth; ++x) { + const CachedInterpolation &xCache = caches[x]; + computePatchSum(scale, *pSt, yScaleCache, scalesDim, xCache, output); + output += pSt->channels; + } + } + } + } + } + + template + static void resizeArea(cudaStream_t* stream, ImageResizerState const& st, CachedInterpolation* cache, + NDArray const* input, NDArray* output) { + + T const* inputPtr = reinterpret_cast(input->getSpecialBuffer()); +// float* yScales; +// T const** yPtrs; + float scale = 1.f / (st.heightScale * st.widthScale); + auto outputPtr = reinterpret_cast(output->specialBuffer()); // output is always float. TO DO: provide another float types also with template declaration + ImageResizerState* pSt; + auto err = cudaMalloc(&pSt, sizeof(ImageResizerState)); + err = cudaMemcpyAsync(pSt, &st, sizeof(ImageResizerState), cudaMemcpyHostToDevice, *stream); + ScaleCache* cachePool; + err = cudaMalloc(&cachePool, sizeof(ScaleCache) * st.batchSize * st.outWidth * st.outHeight); + resizeAreaKernel<<<128, 4, 2048, *stream>>>(pSt, cache, scale, inputPtr, input->getSpecialShapeInfo(), outputPtr, + output->specialShapeInfo(), cachePool); + err = cudaStreamSynchronize(*stream); + err = cudaFree(cachePool); + err = cudaFree(pSt); + } +// ------------------------------------------------------------------------------------------------------------------ // + template + int resizeAreaFunctor_(nd4j::LaunchContext* context, NDArray const* image, int const width, int const height, + bool const alignCorners, NDArray* output) { + + ImageResizerState st(alignCorners, false); // Create resize info + auto res = st.validateAndCalculateOutputSize(image, width, height); + auto stream = context->getCudaStream(); + if (Status::OK() == res) { + CachedInterpolation* xCached; + //(st.outWidth); + auto err = cudaMalloc(&xCached, sizeof(CachedInterpolation) * st.outWidth); + NDArray::prepareSpecialUse({output}, {image}); + fillInterpolationCache<<<128, 128, 256, *stream>>>(xCached, st.outWidth, st.inWidth, st.widthScale); + resizeArea(stream, st, xCached, image, output); + err = cudaStreamSynchronize(*stream); + err = cudaFree(xCached); + NDArray::registerSpecialUse({output}, {image}); + } + + return res; + } + int resizeAreaFunctor(nd4j::LaunchContext * context, NDArray const* image, int const width, int const height, + bool const alignCorners, NDArray* output) { + BUILD_SINGLE_SELECTOR(image->dataType(), return resizeAreaFunctor_, (context, image, width, height, alignCorners, output), NUMERIC_TYPES); + } + // ------------------------------------------------------------------------------------------------------------------ // // simplified bicubic resize without antialiasing // @@ -1115,8 +1343,12 @@ namespace helpers { I const* cropSizes = reinterpret_cast(cropSize->getSpecialBuffer()); T* outBuf = reinterpret_cast(crops->specialBuffer()); + int threadsPerBlock = math::nd4j_max(imageHeight * imageWidth, cropHeight * cropWidth); + if(threadsPerBlock > MAX_NUM_THREADS/4) + threadsPerBlock = MAX_NUM_THREADS/4; + NDArray::prepareSpecialUse({crops}, {images, boxes, indices, cropSize}); - cropAndResizeKernel<<>>(imagesBuf, images->getSpecialShapeInfo(), boxesBuf, boxes->getSpecialShapeInfo(), indexBuf, indices->getSpecialShapeInfo(), + cropAndResizeKernel<<>>(imagesBuf, images->getSpecialShapeInfo(), boxesBuf, boxes->getSpecialShapeInfo(), indexBuf, indices->getSpecialShapeInfo(), cropSizes, cropSize->getSpecialShapeInfo(), method, extrapolationVal, outBuf, crops->specialShapeInfo(), numBoxes, cropHeight, cropWidth, batchSize, imageHeight, imageWidth, depth); NDArray::registerSpecialUse({crops}, {images, boxes, indices, cropSize}); } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/imagesHelpers.cu b/libnd4j/include/ops/declarable/helpers/cuda/imagesHelpers.cu new file mode 100644 index 000000000..35393c48c --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cuda/imagesHelpers.cu @@ -0,0 +1,425 @@ +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Yurii Shyrma (iuriish@yahoo.com) +// @author Oleh Semeniv (oleg.semeniv@gmail.com) +// + +#include +#include +#include +#include +#include + + +namespace nd4j { +namespace ops { +namespace helpers { + + +/////////////////////////////////////////////////////////////////// +template +__global__ void rgbToYuvCuda(const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xTadOffsets, void* vz, const Nd4jLong *zShapeInfo, const Nd4jLong* zTadOffsets, const Nd4jLong numOfTads, const int dimC) { + + const T* x = reinterpret_cast(vx); + T* z = reinterpret_cast(vz); + + __shared__ int rank; + __shared__ Nd4jLong xDimCstride, zDimCstride; + + if (threadIdx.x == 0) { + rank = shape::rank(xShapeInfo); + xDimCstride = shape::stride(xShapeInfo)[dimC]; + zDimCstride = shape::stride(zShapeInfo)[dimC]; + } + __syncthreads(); + + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + + for (Nd4jLong i = tid; i < numOfTads; i += gridDim.x * blockDim.x) { + const T* xTad = x + xTadOffsets[i]; + T* zTad = z + zTadOffsets[i]; + + rgbYuv(xTad[0], xTad[xDimCstride], xTad[2 * xDimCstride], zTad[0], zTad[zDimCstride], zTad[2 * zDimCstride]); + } + +} + +/////////////////////////////////////////////////////////////////// +template +linkage void rgbToYuvCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xTadOffsets, void* vz, const Nd4jLong* zShapeInfo, const Nd4jLong* zTadOffsets, const Nd4jLong numOfTads, const int dimC) { + + rgbToYuvCuda << > > (vx, xShapeInfo, xTadOffsets, vz, zShapeInfo, zTadOffsets, numOfTads, dimC); +} + +/////////////////////////////////////////////////////////////////// +void transformRgbYuv(nd4j::LaunchContext* context, const NDArray& input, NDArray& output, const int dimC) { + + auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input.getShapeInfo(), { dimC }); + auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output.getShapeInfo(), { dimC }); + + const Nd4jLong numOfTads = packX.numberOfTads(); + + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = (numOfTads + threadsPerBlock - 1) / threadsPerBlock; + + PointersManager manager(context, "yuv_to_rgb"); + + NDArray::prepareSpecialUse({ &output }, { &input }); + BUILD_SINGLE_SELECTOR(input.dataType(), rgbToYuvCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), packX.platformOffsets(), output.specialBuffer(), output.specialShapeInfo(), packZ.platformOffsets(), numOfTads, dimC), FLOAT_TYPES); + NDArray::registerSpecialUse({ &output }, { &input }); + + manager.synchronize(); +} + +/////////////////////////////////////////////////////////////////// +template +__global__ void yuvToRgbCuda(const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xTadOffsets, void* vz, const Nd4jLong *zShapeInfo, const Nd4jLong* zTadOffsets, const Nd4jLong numOfTads, const int dimC) { + + const T* x = reinterpret_cast(vx); + T* z = reinterpret_cast(vz); + + __shared__ int rank; + __shared__ Nd4jLong xDimCstride, zDimCstride; + + if (threadIdx.x == 0) { + rank = shape::rank(xShapeInfo); + xDimCstride = shape::stride(xShapeInfo)[dimC]; + zDimCstride = shape::stride(zShapeInfo)[dimC]; + } + __syncthreads(); + + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + + for (Nd4jLong i = tid; i < numOfTads; i += gridDim.x * blockDim.x) { + const T* xTad = x + xTadOffsets[i]; + T* zTad = z + zTadOffsets[i]; + + yuvRgb(xTad[0], xTad[xDimCstride], xTad[2 * xDimCstride], zTad[0], zTad[zDimCstride], zTad[2 * zDimCstride]); + } + +} + +/////////////////////////////////////////////////////////////////// +template +linkage void yuvToRgbCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xTadOffsets, void* vz, const Nd4jLong* zShapeInfo, const Nd4jLong* zTadOffsets, const Nd4jLong numOfTads, const int dimC) { + + yuvToRgbCuda << > > (vx, xShapeInfo, xTadOffsets, vz, zShapeInfo, zTadOffsets, numOfTads, dimC); +} + +/////////////////////////////////////////////////////////////////// +void transformYuvRgb(nd4j::LaunchContext* context, const NDArray& input, NDArray& output, const int dimC) { + + auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input.getShapeInfo(), { dimC }); + auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output.getShapeInfo(), { dimC }); + + const Nd4jLong numOfTads = packX.numberOfTads(); + + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = (numOfTads + threadsPerBlock - 1) / threadsPerBlock; + + PointersManager manager(context, "yuv_to_rgb"); + + NDArray::prepareSpecialUse({ &output }, { &input }); + BUILD_SINGLE_SELECTOR(input.dataType(), yuvToRgbCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), packX.platformOffsets(), output.specialBuffer(), output.specialShapeInfo(), packZ.platformOffsets(), numOfTads, dimC), FLOAT_TYPES); + NDArray::registerSpecialUse({ &output }, { &input }); + + manager.synchronize(); +} + +/////////////////////////////////////////////////////////////////// +// for example xShapeInfo = {2,3,4}, zShapeInfo = {2,1,4} +template +__global__ void rgbToGrsCuda(const void *vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo, const int dimC) { + + const auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + + __shared__ Nd4jLong zLen, *sharedMem; + __shared__ int rank; // xRank == zRank + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sharedMem = reinterpret_cast(shmem); + + zLen = shape::length(zShapeInfo); + rank = shape::rank(zShapeInfo); + } + __syncthreads(); + + Nd4jLong* coords = sharedMem + threadIdx.x * rank; + + for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < zLen; i += gridDim.x * blockDim.x) { + + if (dimC == (rank - 1) && 'c' == shape::order(xShapeInfo) && 1 == shape::elementWiseStride(xShapeInfo) && 'c' == shape::order(zShapeInfo) && 1 == shape::elementWiseStride(zShapeInfo)) { + const auto xStep = i*3; + z[i] = 0.2989f * x[xStep] + 0.5870f * x[xStep + 1] + 0.1140f * x[xStep + 2]; + } + else { + + shape::index2coords(i, zShapeInfo, coords); + + const auto zOffset = shape::getOffset(zShapeInfo, coords); + const auto xOffset0 = shape::getOffset(xShapeInfo, coords); + const auto xOffset1 = xOffset0 + shape::stride(xShapeInfo)[dimC]; + const auto xOffset2 = xOffset1 + shape::stride(xShapeInfo)[dimC]; + + z[zOffset] = 0.2989f * x[xOffset0] + 0.5870f * x[xOffset1] + 0.1140f * x[xOffset2]; + } + } +} + +/////////////////////////////////////////////////////////////////// +template +linkage void rgbToGrsCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo, const int dimC) { + + rgbToGrsCuda<<>>(vx, xShapeInfo, vz, zShapeInfo, dimC); +} + +/////////////////////////////////////////////////////////////////// +void transformRgbGrs(nd4j::LaunchContext* context, const NDArray& input, NDArray& output, const int dimC) { + + PointersManager manager(context, "rgbToGrs"); + + const int threadsPerBlock = MAX_NUM_THREADS / 4; + const int blocksPerGrid = (input.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + const int sharedMem = input.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128; + + NDArray::prepareSpecialUse({&output}, {&input}); + BUILD_SINGLE_SELECTOR(input.dataType(), rgbToGrsCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), output.getSpecialBuffer(), output.getSpecialShapeInfo(), dimC), NUMERIC_TYPES); + NDArray::registerSpecialUse({&output}, {&input}); + + manager.synchronize(); +} + + +/////////////////////////////////////////////////////////////////// +template +static void _CUDA_G rgbToHsvCuda(const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xTadOffsets, + void* vz, const Nd4jLong *zShapeInfo, const Nd4jLong* zTadOffsets, + const Nd4jLong numOfTads, const int dimC) { + + const T* x = reinterpret_cast(vx); + T* z = reinterpret_cast(vz); + + __shared__ int rank; + __shared__ Nd4jLong xDimCstride, zDimCstride; + + if (threadIdx.x == 0) { + rank = shape::rank(xShapeInfo); + xDimCstride = shape::stride(xShapeInfo)[dimC]; + zDimCstride = shape::stride(zShapeInfo)[dimC]; + } + __syncthreads(); + + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + + for (Nd4jLong i = tid; i < numOfTads; i += gridDim.x * blockDim.x) { + const T* xTad = x + xTadOffsets[i]; + T* zTad = z + zTadOffsets[i]; + + rgbToHsv(xTad[0], xTad[xDimCstride], xTad[2 * xDimCstride], zTad[0], zTad[zDimCstride], zTad[2 * zDimCstride]); + } +} + +/////////////////////////////////////////////////////////////////// +template +static void _CUDA_G hsvToRgbCuda(const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xTadOffsets, + void* vz, const Nd4jLong *zShapeInfo, const Nd4jLong* zTadOffsets, + const Nd4jLong numOfTads, const int dimC) { + + const T* x = reinterpret_cast(vx); + T* z = reinterpret_cast(vz); + + __shared__ int rank; + __shared__ Nd4jLong xDimCstride, zDimCstride; + + if (threadIdx.x == 0) { + rank = shape::rank(xShapeInfo); + xDimCstride = shape::stride(xShapeInfo)[dimC]; + zDimCstride = shape::stride(zShapeInfo)[dimC]; + } + __syncthreads(); + + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + + for (Nd4jLong i = tid; i < numOfTads; i += gridDim.x * blockDim.x) { + const T* xTad = x + xTadOffsets[i]; + T* zTad = z + zTadOffsets[i]; + + hsvToRgb(xTad[0], xTad[xDimCstride], xTad[2 * xDimCstride], zTad[0], zTad[zDimCstride], zTad[2 * zDimCstride]); + } +} + +/////////////////////////////////////////////////////////////////// +template +static _CUDA_H void hsvToRgbCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, + const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xTadOffsets, + void* vz, const Nd4jLong* zShapeInfo, const Nd4jLong* zTadOffsets, + const Nd4jLong numOfTads, const int dimC) { + + hsvToRgbCuda<<>>(vx, xShapeInfo, xTadOffsets, vz, zShapeInfo, zTadOffsets, numOfTads, dimC); +} + +template +static _CUDA_H void rgbToHsvCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, + const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xTadOffsets, + void* vz, const Nd4jLong* zShapeInfo, const Nd4jLong* zTadOffsets, + const Nd4jLong numOfTads, const int dimC) { + + rgbToHsvCuda<<>>(vx, xShapeInfo, xTadOffsets, vz, zShapeInfo, zTadOffsets, numOfTads, dimC); +} + +/////////////////////////////////////////////////////////////////// +void transformHsvRgb(nd4j::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC) { + + auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), {dimC}); + auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), {dimC}); + + const Nd4jLong numOfTads = packX.numberOfTads(); + + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = (numOfTads + threadsPerBlock - 1) / threadsPerBlock; + + PointersManager manager(context, "hsv_to_rgb"); + + NDArray::prepareSpecialUse({output}, {input}); + BUILD_SINGLE_SELECTOR(input->dataType(), hsvToRgbCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), input->getSpecialBuffer(), input->getSpecialShapeInfo(), packX.platformOffsets(), output->specialBuffer(), output->specialShapeInfo(), packZ.platformOffsets(), numOfTads, dimC), FLOAT_TYPES); + NDArray::registerSpecialUse({output}, {input}); + + manager.synchronize(); +} + +/////////////////////////////////////////////////////////////////// +void transformRgbHsv(nd4j::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC) { + auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), {dimC}); + auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), {dimC}); + + const Nd4jLong numOfTads = packX.numberOfTads(); + + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = (numOfTads + threadsPerBlock - 1) / threadsPerBlock; + + PointersManager manager(context, "rgb_to_hsv"); + + NDArray::prepareSpecialUse({output}, {input}); + BUILD_SINGLE_SELECTOR(input->dataType(), rgbToHsvCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), input->getSpecialBuffer(), input->getSpecialShapeInfo(), packX.platformOffsets(), output->specialBuffer(), output->specialShapeInfo(), packZ.platformOffsets(), numOfTads, dimC), FLOAT_TYPES); + NDArray::registerSpecialUse({output}, {input}); + + manager.synchronize(); +} + +template +__global__ void tripleTransformerCuda(const void *vx, const Nd4jLong *xShapeInfo, const Nd4jLong *xTadShapeInfo, const Nd4jLong *xOffsets, void *vz, const Nd4jLong *zShapeInfo, const Nd4jLong *zTadShapeInfo, const Nd4jLong *zOffsets, const int dimC, int mode, uint64_t numTads) { + const auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + + __shared__ Nd4jLong zLen, *sharedMem; + __shared__ int rank; // xRank == zRank + + float yiqarr[3][3] = { + { 0.299f, 0.59590059f, 0.2115f }, + { 0.587f, -0.27455667f, -0.52273617f }, + { 0.114f, -0.32134392f, 0.31119955f } + }; + + float rgbarr[3][3] = { + { 1.f, 1.f, 1.f }, + { 0.95598634f, -0.27201283f, -1.10674021f }, + { 0.6208248f, -0.64720424f, 1.70423049f } + }; + + auto tr = mode == 1? yiqarr : rgbarr; + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sharedMem = reinterpret_cast(shmem); + + zLen = shape::length(zShapeInfo); + rank = shape::rank(zShapeInfo); + } + __syncthreads(); + + Nd4jLong* coords = sharedMem + threadIdx.x * rank; + + if (dimC == (rank - 1) && 'c' == shape::order(xShapeInfo) && 1 == shape::elementWiseStride(xShapeInfo) && 'c' == shape::order(zShapeInfo) && 1 == shape::elementWiseStride(zShapeInfo)) { + for (uint64_t f = blockIdx.x * blockDim.x + threadIdx.x; f < zLen / 3; f += gridDim.x * blockDim.x) { + auto i = f * 3; + + auto xi0 = x[i]; + auto xi1 = x[i+1]; + auto xi2 = x[i+2]; + + for (int e = 0; e < 3; e++) + z[i + e] = xi0 * tr[0][e] + xi1 * tr[1][e] + xi2 * tr[2][e]; + } + } else { + // TAD based case + const Nd4jLong xDimCstride = shape::stride(xShapeInfo)[dimC]; + const Nd4jLong zDimCstride = shape::stride(zShapeInfo)[dimC]; + + for (uint64_t i = blockIdx.x * blockDim.x + threadIdx.x; i < numTads; i += blockDim.x * gridDim.x) { + const T* xTad = x + xOffsets[i]; + T* zTad = z + zOffsets[i]; + + auto xi0 = xTad[0]; + auto xi1 = xTad[xDimCstride]; + auto xi2 = xTad[xDimCstride * 2]; + + for (int e = 0; e < 3; e++) + zTad[zDimCstride * e] = xi0 * tr[0][e] + xi1 * tr[1][e] + xi2 * tr[2][e]; + } + } +} + + +template +static void rgbYiq(nd4j::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC) { + auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimC); + auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimC); + + NDArray::prepareSpecialUse({output}, {input}); + return tripleTransformerCuda<<<256, 256, 8192, *context->getCudaStream()>>>(input->getSpecialBuffer(), input->getSpecialShapeInfo(), packX.platformShapeInfo(), packX.platformOffsets(), output->specialBuffer(), output->specialShapeInfo(), packZ.platformShapeInfo(), packZ.platformOffsets(), dimC, 1, packZ.numberOfTads()); + NDArray::registerSpecialUse({output}, {input}); +} + +template +FORCEINLINE static void yiqRgb(nd4j::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC) { + auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimC); + auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimC); + + NDArray::prepareSpecialUse({output}, {input}); + return tripleTransformerCuda<<<256, 256, 8192, *context->getCudaStream()>>>(input->getSpecialBuffer(), input->getSpecialShapeInfo(), packX.platformShapeInfo(), packX.platformOffsets(), output->specialBuffer(), output->specialShapeInfo(), packZ.platformShapeInfo(), packZ.platformOffsets(), dimC, 2, packZ.numberOfTads()); + NDArray::registerSpecialUse({output}, {input}); +} + +void transformYiqRgb(nd4j::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC) { + BUILD_SINGLE_SELECTOR(input->dataType(), yiqRgb, (context, input, output, dimC), FLOAT_TYPES); +} + +void transformRgbYiq(nd4j::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC) { + BUILD_SINGLE_SELECTOR(input->dataType(), rgbYiq, (context, input, output, dimC), FLOAT_TYPES); +} + + + + + +} +} +} + diff --git a/libnd4j/include/ops/declarable/helpers/cuda/ismax.cu b/libnd4j/include/ops/declarable/helpers/cuda/ismax.cu index a5d686dc2..bf6a943fa 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/ismax.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/ismax.cu @@ -48,13 +48,12 @@ static void ismax_(nd4j::LaunchContext * context, const NDArray* input, NDArray* * In case of vector-input for IsMax, it just turns into IndexReduce call + subsequent filler call */ auto indexMax = input->applyIndexReduce(indexreduce::IndexMax, dimensions); - auto targetIdx = indexMax->e(0); + auto targetIdx = indexMax.e(0); dim3 launchDims(128, 512, 1024); BUILD_SINGLE_SELECTOR(zType, fillIsMaxGeneric, (launchDims, stream, output->specialBuffer(), output->specialShapeInfo(), output->lengthOf(), targetIdx), LIBND4J_TYPES); manager.synchronize(); - delete indexMax; } else { Nd4jLong* hostYShapeInfo = nullptr; Nd4jLong* hostTShapeInfo = nullptr; @@ -71,10 +70,8 @@ static void ismax_(nd4j::LaunchContext * context, const NDArray* input, NDArray* dimension = (int *) manager.replicatePointer(dimensions.data(), dimensions.size() * sizeof(int)); // at this point, all IMax indexes are gathered, and we execute filler - BUILD_SINGLE_SELECTOR(zType, fillDimensionalIsMaxGeneric, (launchDims, stream, indexMaxArr->specialBuffer(), output->specialBuffer(), output->specialShapeInfo(), packZ.specialShapeInfo(), dimension, dimensionLength, packZ.specialOffsets()), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(zType, fillDimensionalIsMaxGeneric, (launchDims, stream, indexMaxArr.specialBuffer(), output->specialBuffer(), output->specialShapeInfo(), packZ.specialShapeInfo(), dimension, dimensionLength, packZ.specialOffsets()), LIBND4J_TYPES); manager.synchronize(); - - delete indexMaxArr; } } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/legacy/relu.cu b/libnd4j/include/ops/declarable/helpers/cuda/legacy/relu.cu index 753c8ae64..a3d24111a 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/legacy/relu.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/legacy/relu.cu @@ -33,7 +33,7 @@ namespace nd4j { return x > (T) 0.f ? y : T(0.f); }; - theFirst->applyPairwiseLambda(theSecond, functor, nullptr); + theFirst->applyPairwiseLambda(*theSecond, functor, *theFirst); } void reluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond) { @@ -46,7 +46,7 @@ namespace nd4j { return x > (T)0.f ? y : T(0.f); }; - input->applyPairwiseLambda(epsilon, functor, output); + input->applyPairwiseLambda(*epsilon, functor, *output); } void reluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { @@ -59,7 +59,7 @@ namespace nd4j { return x > (T)0.f && x < (T)6.f? y : T(0.f); }; - input->applyPairwiseLambda(epsilon, functor, output); + input->applyPairwiseLambda(*epsilon, functor, *output); } void relu6Derivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { @@ -75,7 +75,7 @@ namespace nd4j { return x < 0 ? alphaT * y : y; }; - input->applyPairwiseLambda(epsilon, functor, output); + input->applyPairwiseLambda(*epsilon, functor, *output); } void leakyReluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput, const float alpha) { @@ -91,7 +91,7 @@ namespace nd4j { return y * nd4j::math::nd4j_eluderivative(x, alphaT); }; - input->applyPairwiseLambda(epsilon, functor, output); + input->applyPairwiseLambda(*epsilon, functor, *output); } void eluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput, const float alpha) { @@ -104,7 +104,7 @@ namespace nd4j { return y * simdOps::SELUDerivative::op(x, nullptr); }; - input->applyPairwiseLambda(epsilon, functor, output); + input->applyPairwiseLambda(*epsilon, functor, *output); } void seluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { diff --git a/libnd4j/include/ops/declarable/helpers/cuda/legacy/tanh.cu b/libnd4j/include/ops/declarable/helpers/cuda/legacy/tanh.cu index 3a09f9a80..afd07cd48 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/legacy/tanh.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/legacy/tanh.cu @@ -34,7 +34,7 @@ namespace nd4j { return y * ((T)1.0f - (th * th)); }; - input->applyPairwiseLambda(epsilon, functor, output); + input->applyPairwiseLambda(*epsilon, functor, *output); } void tanhDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { @@ -49,7 +49,7 @@ namespace nd4j { return y * simdOps::HardTanhDerivative::op(x, nullptr); }; - input->applyPairwiseLambda(epsilon, functor, output); + input->applyPairwiseLambda(*epsilon, functor, *output); } void hardTanhDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { @@ -62,7 +62,7 @@ namespace nd4j { return y * simdOps::RationalTanhDerivative::op(x, nullptr); }; - input->applyPairwiseLambda(epsilon, functor, output); + input->applyPairwiseLambda(*epsilon, functor, *output); } void rationalTanhDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { @@ -75,7 +75,7 @@ namespace nd4j { return x > (T) 0.0f ? y * (nd4j::math::nd4j_tanhderivative(x)) : (T) 0.0f; }; - input->applyPairwiseLambda(epsilon, functor, output); + input->applyPairwiseLambda(*epsilon, functor, *output); } void rectifiedTanhDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { diff --git a/libnd4j/include/ops/declarable/helpers/cuda/legacy_helper.cu b/libnd4j/include/ops/declarable/helpers/cuda/legacy_helper.cu index fa97a3de2..fb4a94abb 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/legacy_helper.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/legacy_helper.cu @@ -34,7 +34,7 @@ namespace helpers { return y * (3 * x * x); }; - input->applyPairwiseLambda(epsilon, functor, output); + input->applyPairwiseLambda(*epsilon, functor, *output); } //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// @@ -50,7 +50,7 @@ namespace helpers { return x > T(0.f)? y : -y; }; - input->applyPairwiseLambda(epsilon, functor, output); + input->applyPairwiseLambda(*epsilon, functor, *output); } //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// @@ -66,7 +66,7 @@ namespace helpers { return nd4j::math::nd4j_max(x, (T)0.f) - x * y + nd4j::math::nd4j_log((T)1.f + nd4j::math::nd4j_exp(-nd4j::math::nd4j_abs(x))); }; - logits->applyPairwiseLambda(labels, functor, output); + logits->applyPairwiseLambda(*labels, functor, *output); } //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// @@ -86,7 +86,7 @@ namespace helpers { return static_cast(1.) - y - e / (static_cast(1.) + e); }; - logits->applyPairwiseLambda(labels, functor, output); + logits->applyPairwiseLambda(*labels, functor, *output); } //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// void sigmCrossEntropyGrad(nd4j::LaunchContext * context, NDArray* logits, NDArray* labels, NDArray* output) { @@ -104,7 +104,7 @@ namespace helpers { return y * ((T) 1.0f / (ss * ss)); }; - input->applyPairwiseLambda(epsilon, functor, output); + input->applyPairwiseLambda(*epsilon, functor, *output); } //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// @@ -120,7 +120,7 @@ namespace helpers { return y * (p / (p + 1.)); }; - input->applyPairwiseLambda(epsilon, functor, output); + input->applyPairwiseLambda(*epsilon, functor, *output); } void softPlusDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { @@ -138,7 +138,7 @@ namespace helpers { return y * (s * ((T) 1.0f - s)); }; - input->applyPairwiseLambda(epsilon, functor, output); + input->applyPairwiseLambda(*epsilon, functor, *output); } void sigmoidDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { @@ -151,7 +151,7 @@ namespace helpers { return y * simdOps::HardSigmoidDerivative::op(x, nullptr); }; - input->applyPairwiseLambda(epsilon, functor, output); + input->applyPairwiseLambda(*epsilon, functor, *output); } void hardSigmoidDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { @@ -162,24 +162,24 @@ namespace helpers { template linkage void logSumExp_(NDArray* input, NDArray* axis, NDArray* output) { // reduce along axis with - std::unique_ptr tempInput(input->dup()); - input->applyTransform(transform::Exp, tempInput.get()); + NDArray tempInput = input->dup(); + input->applyTransform(transform::Exp, tempInput); std::vector axisVector; if (axis != nullptr) { axisVector.resize(axis->lengthOf()); for (size_t i = 0; i < axisVector.size(); ++i) axisVector[i] = axis->e(i); } - tempInput->reduceAlongDimension(reduce::Sum, output, axisVector); - output->applyTransform(transform::Log, nullptr, nullptr); + tempInput.reduceAlongDimension(reduce::Sum, *output, axisVector); + output->applyTransform(transform::Log, *output); } template linkage void logSumExp_(NDArray* input, NDArray* subtrah, NDArray* axis, NDArray* output) { // reduce along axis with - std::unique_ptr tempInput(input->dup()); - input->applyPairwiseTransform(pairwise::Subtract, subtrah, tempInput.get()); - tempInput->applyTransform(transform::Exp, nullptr, nullptr); + NDArray tempInput = input->dup(); + input->applyPairwiseTransform(pairwise::Subtract, *subtrah, tempInput); + tempInput.applyTransform(transform::Exp, tempInput); std::vector axisVector; if (axis != nullptr) { @@ -187,8 +187,8 @@ namespace helpers { for (size_t i = 0; i < axisVector.size(); ++i) axisVector[i] = axis->e(i); } - tempInput->reduceAlongDimension(reduce::Sum, output, axisVector); - output->applyTransform(transform::Log, nullptr, nullptr); + tempInput.reduceAlongDimension(reduce::Sum, *output, axisVector); + output->applyTransform(transform::Log, *output); } //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// @@ -223,16 +223,16 @@ namespace helpers { if (weights->isScalar()) { - const_cast(input)->applyPairwiseLambda(const_cast(targets), mainRoutineT1, output); + const_cast(input)->applyPairwiseLambda(const_cast(*targets), mainRoutineT1, *output); } else { std::unique_ptr targetVector(new NDArray(*weights)); - targetVector->applyScalar(scalar::Add, -1.f); + targetVector->applyScalar(scalar::Add, -1.f, *targetVector); std::unique_ptr targetTensor(new NDArray(*targets)); *targetTensor = (*targetVector * *targetTensor) + T(1.f); - const_cast(input)->applyTriplewiseLambda(const_cast(targets), targetTensor.get(), mainRoutineT2, output); + const_cast(input)->applyTriplewiseLambda(const_cast(*targets), *targetTensor.get(), mainRoutineT2, *output); } } //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/libnd4j/include/ops/declarable/helpers/cuda/lgamma.cu b/libnd4j/include/ops/declarable/helpers/cuda/lgamma.cu new file mode 100644 index 000000000..ea9901f0a --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cuda/lgamma.cu @@ -0,0 +1,54 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author George A. Shulinok +// + +#include +//#include +//#include + +namespace nd4j { +namespace ops { +namespace helpers { + +////////////////////////////////////////////////////////////////////////// +// calculate digamma function for array elements +template +void lgamma_(NDArray& x, NDArray& z) { + //auto dtype = x.dataType(); + auto lgammaProc = LAMBDA_T(x_, dtype) { + return T(DataTypeUtils::fromT() == DataType::DOUBLE?::lgamma(x_): ::lgammaf(x_)); //math::nd4j_log(math::nd4j_gamma(x)); + }; + + x.applyLambda(lgammaProc, z); +} + +void lgamma(nd4j::LaunchContext* context, NDArray& x, NDArray& z) { + + BUILD_SINGLE_SELECTOR(x.dataType(), lgamma_, (x, z), FLOAT_TYPES); +} + +BUILD_SINGLE_TEMPLATE(template void lgamma_, (NDArray& x, NDArray& z), FLOAT_TYPES); + + + +} +} +} + diff --git a/libnd4j/include/ops/declarable/helpers/cuda/lstm.cu b/libnd4j/include/ops/declarable/helpers/cuda/lstm.cu index 3fc7ef0b7..2204c9189 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/lstm.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/lstm.cu @@ -85,7 +85,7 @@ void lstmCell(nd4j::LaunchContext * context, const NDArray* xt, const NDArray* h // if clipping value is provided then cell state is clipped by this value prior to the cell output activation if(clippingCellValue > 0.0) - ct->applyScalar(scalar::LstmClip, clippingCellValue); + ct->applyScalar(scalar::LstmClip, clippingCellValue, *ct); if(peephole) zot += (*ct) * (*Wc)({{2*nOut, 3*nOut}}); // add peephole connections to output gate zot + ct*Wc @@ -98,7 +98,7 @@ void lstmCell(nd4j::LaunchContext * context, const NDArray* xt, const NDArray* h ht->assign( mmul(htNoPeepHole, *Wp) ); // [bS x nOut] * [ nOut x numProj] = [bS x numProj] // if clipping projection is provided then projected cell output state is clipped by this value if(clippingProjValue != 0.) - ht->applyScalar(scalar::LstmClip, clippingProjValue); + ht->applyScalar(scalar::LstmClip, clippingProjValue, *ht); } else ht->assign(&htNoPeepHole); @@ -165,30 +165,30 @@ void lstmBlockCell(const NDArray* xt, const NDArray* cLast, const NDArray* yLast if(forgetBias != 0.0) zf += forgetBias; - zz.applyTransform(transform::Tanh, z); //z = tanh(zz) - zi.applyTransform(transform::Sigmoid, i); //i = sigmoid(zi) - zf.applyTransform(transform::Sigmoid, f); //f = sigmoid(zf); + zz.applyTransform(transform::Tanh, *z); //z = tanh(zz) + zi.applyTransform(transform::Sigmoid, *i); //i = sigmoid(zi) + zf.applyTransform(transform::Sigmoid, *f); //f = sigmoid(zf); //cell state = blockInput .* inputGate + prevCellState .* forgetGate - z->applyPairwiseTransform(pairwise::Multiply, i, c, nullptr); //c = z * i + z->applyPairwiseTransform(pairwise::Multiply, *i, *c); //c = z * i auto temp = (*f) * (*cLast); *c += temp; //c = (i * z) + (zf * (*cLast)) - c->applyTransform(transform::Tanh, h); //h = tanh(c) + c->applyTransform(transform::Tanh, *h); //h = tanh(c) // if clipping value is provided then cell state is clipped by this value prior to the cell output activation if(clippingCellValue > 0.0) - c->applyScalar(scalar::LstmClip, clippingCellValue); + c->applyScalar(scalar::LstmClip, clippingCellValue, *c); if(peephole) { // add peephole connections to output gate zot + ct*Wc auto prod = *c * (*Wco); zo += prod; } - zo.applyTransform(transform::Sigmoid, o); // o = sigmoid(zo) + zo.applyTransform(transform::Sigmoid, *o); // o = sigmoid(zo) // current cell output = ot*tanh(ct) - c->applyTransform(transform::Tanh, h); //h = tanh(c) - o->applyPairwiseTransform(pairwise::Multiply, h, y, nullptr); //y = o * h + c->applyTransform(transform::Tanh, *h); //h = tanh(c) + o->applyPairwiseTransform(pairwise::Multiply, *h, *y); //y = o * h } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/lup.cu b/libnd4j/include/ops/declarable/helpers/cuda/lup.cu index 568b9a9bc..3e8def28a 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/lup.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/lup.cu @@ -24,6 +24,7 @@ #include #include #include +//#include #include #include @@ -336,7 +337,7 @@ namespace helpers { // // input - A matrix nxn // compound - C matrix L + U - I, or main diagonal and lower - L matrix, from the 2nd diagonal - U matrix - template + template static void lup_(LaunchContext *context, NDArray *input, NDArray *compound, NDArray *permutation) { auto stream = context->getCudaStream(); auto n = input->rows(); @@ -383,7 +384,7 @@ namespace helpers { err); } - if (permutation == nullptr) + if (permutation == nullptr) { status = cusolverDnDgetrf( cusolverH, n, @@ -393,9 +394,15 @@ namespace helpers { d_work, nullptr, d_info); + + if (status != CUSOLVER_STATUS_SUCCESS) { + throw cuda_exception::build("helpers::lup_: LU factorization is failed due ", + status); + } + } else { NDArray permutVector('c', {n}, nd4j::DataType::INT32, context); - int *permutationBuf = reinterpret_cast(permutVector.specialBuffer()); + int* permutationBuf = permutVector.dataBuffer()->specialAsT(); status = cusolverDnDgetrf( cusolverH, n, @@ -405,9 +412,21 @@ namespace helpers { d_work, permutationBuf, d_info); - fillUpPermutation << < n, n, 1024, *stream >> > - (permutation->specialBuffer(), permutation->specialShapeInfo(), permutationBuf, n); - permutation->tickWriteDevice(); + if (status != CUSOLVER_STATUS_SUCCESS) { + throw cuda_exception::build("helpers::lup_: LU factorization is failed due ", + status); + } + + if (permutation->rankOf() == 2) { + fillUpPermutation <<< n, n, 1024, *stream >>> + (permutation->specialBuffer(), permutation->specialShapeInfo(), permutationBuf, n); + } + else { + permutVector.tickWriteDevice(); + input->tickWriteDevice(); + compound->assign(input); + permutation->assign(permutVector); + } } err = cudaFree(d_work); if (err) { @@ -448,7 +467,7 @@ namespace helpers { nullptr, d_info); else { - NDArray permutVector('c', {n}, nd4j::DataType::INT32, context); + NDArray permutVector('c', {n}, DataType::INT32, context); int *permutationBuf = reinterpret_cast(permutVector.specialBuffer()); status = cusolverDnSgetrf( cusolverH, @@ -459,9 +478,16 @@ namespace helpers { d_work, permutationBuf, d_info); - fillUpPermutation <<< n, n, 128, *stream >> > - (permutation->specialBuffer(), permutation->specialShapeInfo(), permutationBuf, n); - permutation->tickWriteDevice(); + if (permutation->rankOf() == 2) { + fillUpPermutation <<< n, n, 128, *stream >>> + (permutation->specialBuffer(), permutation->specialShapeInfo(), permutationBuf, n); + permutation->tickWriteDevice(); + } + else { + input->tickWriteDevice(); + compound->assign(input); + permutation->assign(permutVector); + } } err = cudaFree(d_work); if (err) { @@ -484,8 +510,115 @@ namespace helpers { } // ------------------------------------------------------------------------------------------------------------------ // - BUILD_SINGLE_TEMPLATE(template void lup_,(LaunchContext * context, NDArray * input, NDArray * output, NDArray * permutation), FLOAT_NATIVE); + BUILD_DOUBLE_TEMPLATE(template void lup_,(LaunchContext * context, NDArray * input, NDArray * output, NDArray * permutation), FLOAT_NATIVE, INDEXING_TYPES); + template + static __device__ void swapRows(T* matrix, Nd4jLong* shape, Nd4jLong theFirst, Nd4jLong theSecond, Nd4jLong n) { + if (theFirst != theSecond) { + for (auto i = 0; i < n; i++) { + Nd4jLong theFirstPos[] = {theFirst, i}; + Nd4jLong theSecondPos[] = {theSecond, i}; + auto theFirstIndex = shape::getOffset(shape, theFirstPos, 0); + auto theSecondIndex = shape::getOffset(shape, theSecondPos, 0); + math::nd4j_swap(matrix[theFirstIndex], matrix[theSecondIndex]); + } + } + } + + template + static __device__ void processColumns(Nd4jLong currentRow, Nd4jLong rowNum, T* compoundBuf, Nd4jLong* compoundShape) { + Nd4jLong xDiag[] = {currentRow, currentRow}; + auto diagIndex = shape::getOffset(compoundShape, xDiag, 0); + for (auto j = currentRow + 1; j < rowNum; j++) { + Nd4jLong xRow[] = {j, currentRow}; + auto rowIndex = shape::getOffset(compoundShape, xRow, 0); + compoundBuf[rowIndex] /= compoundBuf[diagIndex]; //output->t(i, i); + for (auto k = currentRow + 1; k < rowNum; k++) { + Nd4jLong yRow[] = {j, k}; + Nd4jLong yCol[] = {currentRow, k}; + auto rowIndexY = shape::getOffset(compoundShape, yRow, 0); + auto colIndex = shape::getOffset(compoundShape, yCol, 0); + compoundBuf[rowIndexY] -= compoundBuf[rowIndex] * compoundBuf[colIndex]; + } + } + } + + template + __device__ Nd4jLong argmaxCol(Nd4jLong column, T* compoundBuffer, Nd4jLong* compoundShape) { + auto rowNum = shape::sizeAt(compoundShape, 0); + Nd4jLong xInitial[] = {column, column}; + auto xInitialIndex = shape::getOffset(compoundShape, xInitial, 0); + auto maxValue = T(0); //nd4j::math::nd4j_abs(compoundBuffer[xInitialIndex]); + auto result = -1LL; + + for (auto rowCounter = column; rowCounter < rowNum; rowCounter++) { + Nd4jLong xPos[] = {rowCounter, column}; + auto xIndex = shape::getOffset(compoundShape, xPos, 0); + if (nd4j::math::nd4j_abs(compoundBuffer[xIndex]) > maxValue) { + maxValue = nd4j::math::nd4j_max(maxValue, nd4j::math::nd4j_abs(compoundBuffer[xIndex])); + result = rowCounter; + } + } + return result; + } + + template + static __device__ int luNN(T* matrix, Nd4jLong* shape, I* permutation, Nd4jLong* permuShape, Nd4jLong n) { + + for (auto i = 0; i < n - 1; i++) { + auto pivotIndex = argmaxCol(i, matrix, shape); + if (pivotIndex < 0) { + return -1;//throw std::runtime_error("helpers::luNN_: input matrix is singular."); + } + math::nd4j_swap(permutation[shape::getIndexOffset(i, permuShape)], permutation[shape::getIndexOffset(pivotIndex, permuShape)]); + swapRows(matrix, shape, (Nd4jLong)i, pivotIndex, n); + + processColumns(i, n, matrix, shape); + } + return 0; + } + + template + static __global__ void luBatchedKernel(T* outputBuf, Nd4jLong* outputShape, I* permutations, Nd4jLong* permuShape, + Nd4jLong* outputTadShape, Nd4jLong* outputTadOffsets, Nd4jLong* permuTadShape, Nd4jLong* permuTadOffsets, + Nd4jLong batchNum) { + + auto start = blockIdx.x * blockDim.x + threadIdx.x; + auto step = blockDim.x * gridDim.x; + + for (auto b = start; b < batchNum; b += step) { + T* matrix = outputBuf + outputTadOffsets[b]; + I* permutation = permutations + permuTadOffsets[b]; + + if (0 != luNN(matrix, outputTadShape, permutation, permuTadShape, shape::length(permuTadShape))) break; + } + } + + template + static void lu_(LaunchContext * context, NDArray* input, NDArray* output, NDArray* permutationVectors) { + auto n = input->sizeAt(-1); + auto stream = context->getCudaStream(); + NDArray iota('c', {n}, permutationVectors->dataType());// = NDArrayFactory::create(); // ('c', {n}); + iota.linspace(0); iota.syncToDevice(); + + output->assign(input); // fill up output tensor with zeros +// output->tickWriteDevice(); + permutationVectors->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Assign(), iota, *permutationVectors, true, nullptr); +// permutationVectors->tickWriteDevice(); + auto tads = ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), {-2, -1}); + auto permutaionTads = ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), {-1}); + auto batchNum = tads.numberOfTads(); + luBatchedKernel<<>>(reinterpret_cast(output->platformBuffer()), + output->specialShapeInfo(), reinterpret_cast(permutationVectors->platformBuffer()), + permutationVectors->specialShapeInfo(), tads.specialShapeInfo(), tads.specialOffsets(), + permutaionTads.specialShapeInfo(), permutaionTads.specialOffsets(), batchNum); + } + + void lu(LaunchContext* context, NDArray* input, NDArray* output, NDArray* permutations) { + NDArray::prepareSpecialUse({output, permutations}, {input}); + BUILD_DOUBLE_SELECTOR(input->dataType(), permutations->dataType(), lu_, (context, input, output, permutations), FLOAT_NATIVE, INDEXING_TYPES); + NDArray::registerSpecialUse({output, permutations}, {input}); + } // ------------------------------------------------------------------------------------------------------------------ // template static int determinant_(nd4j::LaunchContext *context, NDArray *input, NDArray *output) { @@ -509,7 +642,7 @@ namespace helpers { fillMatrix<<>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n); // else // fillMatrix<<>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n); - lup_(context, &matrix, nullptr, nullptr); + lup_(context, &matrix, nullptr, nullptr); // else // lup_(context, &matrix, nullptr, nullptr); auto offset = shape::getIndexOffset(e, output->shapeInfo()); @@ -557,7 +690,7 @@ namespace helpers { // fillMatrix<<>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n); // if (matrix.dataType() == input->dataType()) - lup_(context, &matrix, nullptr, nullptr); + lup_(context, &matrix, nullptr, nullptr); // else // lup_(context, &matrix, nullptr, nullptr); auto offset = shape::getIndexOffset(e, output->shapeInfo()); @@ -638,7 +771,7 @@ namespace helpers { matrix.tickWriteDevice(); //compound.assign(matrix); // if (matrix.dataType() == input->dataType()) - lup_(context, &matrix, nullptr, nullptr); + lup_(context, &matrix, nullptr, nullptr); fillLowerUpperKernel<<>>(lower.specialBuffer(), lower.specialShapeInfo(), upper.specialBuffer(), upper.specialShapeInfo(), matrix.specialBuffer(), matrix.specialShapeInfo(), n); lower.tickWriteDevice(); upper.tickWriteDevice(); @@ -705,7 +838,7 @@ namespace helpers { int cholesky__(LaunchContext *context, NDArray *input, NDArray *output, bool inplace) { if (!inplace) output->assign(input); - std::unique_ptr tempOutput(output->dup()); + auto tempOutput =output->dup(); cusolverDnHandle_t handle = nullptr; auto n = input->sizeAt(-1); auto n2 = n * n; @@ -715,9 +848,9 @@ namespace helpers { throw cuda_exception::build("helpers::cholesky_: Cannot create solver handle", status); } F **dArrayBatch = nullptr; - auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(tempOutput->getShapeInfo(), - {tempOutput->rankOf() - 2, - tempOutput->rankOf() - 1}); + auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(tempOutput.getShapeInfo(), + {tempOutput.rankOf() - 2, + tempOutput.rankOf() - 1}); const Nd4jLong batchSize = packX.numberOfTads(); int *dInfoArray = nullptr; auto err = cudaMalloc((void **) &dArrayBatch, sizeof(F *) * batchSize); @@ -731,7 +864,7 @@ namespace helpers { } auto stream = context->getCudaStream(); fillBatchKernel << < 1, batchSize, 128, *stream >> > - (dArrayBatch, reinterpret_cast(tempOutput->specialBuffer()), packX.specialOffsets(), batchSize); + (dArrayBatch, reinterpret_cast(tempOutput.specialBuffer()), packX.specialOffsets(), batchSize); status = cusolverDnSetStream(handle, *stream); if (CUSOLVER_STATUS_SUCCESS != status) { @@ -761,7 +894,7 @@ namespace helpers { throw cuda_exception::build("helpers::cholesky_: Cholesky factorization failed for batch", status); } adjustResultsKernel << < batchSize, n2, 128, *stream >> > - (reinterpret_cast(tempOutput->specialBuffer()), packX.specialShapeInfo(), packX.specialOffsets(), batchSize, n); + (reinterpret_cast(tempOutput.specialBuffer()), packX.specialShapeInfo(), packX.specialOffsets(), batchSize, n); err = cudaFree(dArrayBatch); if (err) { @@ -774,9 +907,9 @@ namespace helpers { } if (!inplace) - output->assign(tempOutput.get()); + output->assign(tempOutput); else - input->assign(tempOutput.get()); + input->assign(tempOutput); NDArray::registerSpecialUse({output}, {input}); return Status::OK(); @@ -844,12 +977,12 @@ namespace helpers { cholesky(context, input, &tempOutput, false); auto outputBuf = output->dataBuffer()->specialAsT(); //reinterpret_cast(output->specialBuffer()); // + e * n2; // + e * n2; - auto inputBuf = tempOutput.dataBuffer()->specialAsT(); //reinterpret_cast(tempOutput->specialBuffer()); + auto inputBuf = tempOutput.dataBuffer()->specialAsT(); //reinterpret_cast(tempOutput.specialBuffer()); output->nullify(); auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(tempOutput.getShapeInfo(), {tempOutput.rankOf() - 2, tempOutput.rankOf() - 1}); - logDetKernel <<< 128, 512, 256, *stream >>>(inputBuf, tempOutput.specialShapeInfo(), + logDetKernel <<<128, 512, 256, *stream>>>(inputBuf, tempOutput.specialShapeInfo(), packX.numberOfTads(), packX.specialShapeInfo(), packX.specialOffsets(), outputBuf, output->specialShapeInfo()); output->tickWriteDevice(); @@ -861,6 +994,14 @@ namespace helpers { BUILD_SINGLE_SELECTOR(output->dataType(), return logdetFunctor_, (context, input, output), FLOAT_NATIVE); } + /* + * lup - batched input, batched outputs + * */ + int lup(LaunchContext *context, NDArray *input, NDArray *compound, NDArray *permutation) { + BUILD_DOUBLE_SELECTOR(input->dataType(), permutation->dataType(), lup_,(context, input, compound, permutation), FLOAT_NATIVE, INDEXING_TYPES); + return Status::OK(); + } + // BUILD_SINGLE_TEMPLATE(template int logdetFunctor_, // (nd4j::LaunchContext * context, NDArray * input, NDArray * output), FLOAT_NATIVE); } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/matrix_diag_part.cu b/libnd4j/include/ops/declarable/helpers/cuda/matrix_diag_part.cu index ca24b3466..5a95eeb83 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/matrix_diag_part.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/matrix_diag_part.cu @@ -59,7 +59,7 @@ namespace helpers { auto listOut = output->allTensorsAlongDimension({output->rankOf() - 1}); auto listDiag = input->allTensorsAlongDimension({input->rankOf() - 2, input->rankOf() - 1}); - if (listOut->size() != listDiag->size()) { + if (listOut.size() != listDiag.size()) { nd4j_printf("matrix_diag_part: Input matrix has wrong shape.", ""); return ND4J_STATUS_VALIDATION; } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/maximum.cu b/libnd4j/include/ops/declarable/helpers/cuda/maximum.cu index a2aec252e..b93563de2 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/maximum.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/maximum.cu @@ -43,10 +43,10 @@ namespace nd4j { // PWT case case // X gradient - epsNext->applyTriplewiseLambda(x, y, lambdaX, gradX); + epsNext->applyTriplewiseLambda(*x, *y, lambdaX, *gradX); // Y gradient - epsNext->applyTriplewiseLambda(x, y, lambdaY, gradY); + epsNext->applyTriplewiseLambda(*x, *y, lambdaY, *gradY); } else if (y->isScalar()) { T s = y->e(0); @@ -61,7 +61,7 @@ namespace nd4j { else gradY->assign(0.0f); - epsNext->applyPairwiseLambda(x, lambdaS, gradX); + epsNext->applyPairwiseLambda(*x, lambdaS, *gradX); } else { // broadcast case @@ -71,8 +71,8 @@ namespace nd4j { auto targetShape = epsNext->getShapeAsVector(); - preX->tileToShape(targetShape); - preY->tileToShape(targetShape); + preX.tileToShape(targetShape, preX); + preY.tileToShape(targetShape, preY); epsNext->applyTriplewiseLambda(preX, preY, lambdaX, preX); epsNext->applyTriplewiseLambda(preX, preY, lambdaY, preY); @@ -81,22 +81,16 @@ namespace nd4j { auto axisY = ShapeUtils::evalBroadcastBackwardAxis(y->shapeInfo(), epsNext->shapeInfo()); if (axisX.size() > 0) { - auto sum = preX->reduceAlongDimension(reduce::Sum, axisX); + auto sum = preX.reduceAlongDimension(reduce::Sum, axisX); gradX->assign(sum); - delete sum; } else gradX->assign(preX); if (axisY.size() > 0) { - auto sum = preY->reduceAlongDimension(reduce::Sum, axisY); + auto sum = preY.reduceAlongDimension(reduce::Sum, axisY); gradY->assign(sum); - delete sum; } else gradY->assign(preY); - - - delete preX; - delete preY; } } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/minimum.cu b/libnd4j/include/ops/declarable/helpers/cuda/minimum.cu index 75c73f96b..90142091f 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/minimum.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/minimum.cu @@ -43,10 +43,10 @@ namespace nd4j { // PWT case case // X gradient - epsNext->applyTriplewiseLambda(x, y, lambdaX, gradX); + epsNext->applyTriplewiseLambda(*x, *y, lambdaX, *gradX); // Y gradient - epsNext->applyTriplewiseLambda(x, y, lambdaY, gradY); + epsNext->applyTriplewiseLambda(*x, *y, lambdaY, *gradY); } else if (y->isScalar()) { T s = y->e(0); @@ -61,7 +61,7 @@ namespace nd4j { else gradY->assign(0.0f); - epsNext->applyPairwiseLambda(x, lambdaS, gradX); + epsNext->applyPairwiseLambda(*x, lambdaS, *gradX); } else { // broadcast case @@ -71,8 +71,8 @@ namespace nd4j { auto targetShape = epsNext->getShapeAsVector(); - preX->tileToShape(targetShape); - preY->tileToShape(targetShape); + preX.tileToShape(targetShape, preX); + preY.tileToShape(targetShape, preY); epsNext->applyTriplewiseLambda(preX, preY, lambdaX, preX); epsNext->applyTriplewiseLambda(preX, preY, lambdaY, preY); @@ -81,22 +81,16 @@ namespace nd4j { auto axisY = ShapeUtils::evalBroadcastBackwardAxis(y->shapeInfo(), epsNext->shapeInfo()); if (axisX.size() > 0) { - auto sum = preX->reduceAlongDimension(reduce::Sum, axisX); + auto sum = preX.reduceAlongDimension(reduce::Sum, axisX); gradX->assign(sum); - delete sum; } else gradX->assign(preX); if (axisY.size() > 0) { - auto sum = preY->reduceAlongDimension(reduce::Sum, axisY); + auto sum = preY.reduceAlongDimension(reduce::Sum, axisY); gradY->assign(sum); - delete sum; } else gradY->assign(preY); - - - delete preX; - delete preY; } } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/percentile.cu b/libnd4j/include/ops/declarable/helpers/cuda/percentile.cu index ccfbbf943..79c9024f5 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/percentile.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/percentile.cu @@ -94,7 +94,7 @@ namespace helpers { shape::checkDimensions(inputRank, axis); auto tempArray = input.dup(input.ordering()); - auto packX = ConstantTadHelper::getInstance()->tadForDimensions(tempArray->getShapeInfo(), axis); + auto packX = ConstantTadHelper::getInstance()->tadForDimensions(tempArray.getShapeInfo(), axis); auto tadLength = shape::length(packX.primaryShapeInfo()); @@ -114,11 +114,9 @@ namespace helpers { } position = tadLength - position - 1; - percentileKernel<<<256, 512, 1024, *context->getCudaStream()>>>(tempArray->specialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), packX.numberOfTads(), tadLength, output.specialBuffer(), output.specialShapeInfo(), output.lengthOf(), position); + percentileKernel<<<256, 512, 1024, *context->getCudaStream()>>>(tempArray.specialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), packX.numberOfTads(), tadLength, output.specialBuffer(), output.specialShapeInfo(), output.lengthOf(), position); nd4j::DebugHelper::checkErrorCode(context->getCudaStream(), "percentile"); - - delete tempArray; } void percentile(nd4j::LaunchContext * context, const NDArray& input, NDArray& output, std::vector& axises, const float q, const int interpolation) { diff --git a/libnd4j/include/ops/declarable/helpers/cuda/print_variable.cu b/libnd4j/include/ops/declarable/helpers/cuda/print_variable.cu new file mode 100644 index 000000000..88d2b5937 --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cuda/print_variable.cu @@ -0,0 +1,61 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include +#include + +namespace nd4j { + namespace ops { + namespace helpers { + template + static _CUDA_G void print_device(const void *special, const Nd4jLong *shapeInfo) { + auto length = shape::length(shapeInfo); + auto x = reinterpret_cast(special); + + // TODO: add formatting here + printf("["); + + for (uint64_t e = 0; e < length; e++) { + printf("%f", (float) x[shape::getIndexOffset(e, shapeInfo)]); + + if (e < length - 1) + printf(", "); + } + + printf("]\n"); + } + + template + static _CUDA_H void exec_print_device(LaunchContext &ctx, const void *special, const Nd4jLong *shapeInfo) { + print_device<<<1, 1, 1024, *ctx.getCudaStream()>>>(special, shapeInfo); + } + + void print_special(LaunchContext &ctx, const NDArray &array, const std::string &message) { + NDArray::prepareSpecialUse({}, {&array}); + + PointersManager pm(&ctx, "print_device"); + BUILD_SINGLE_SELECTOR(array.dataType(), exec_print_device, (ctx, array.getSpecialBuffer(), array.getSpecialShapeInfo()), LIBND4J_TYPES) + pm.synchronize(); + + NDArray::registerSpecialUse({}, {&array}); + } + } + } +} diff --git a/libnd4j/include/ops/declarable/helpers/cuda/qr.cu b/libnd4j/include/ops/declarable/helpers/cuda/qr.cu new file mode 100644 index 000000000..29d8924db --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cuda/qr.cu @@ -0,0 +1,180 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author George A. Shulinok +// +#include +#include +#include + +namespace nd4j { +namespace ops { +namespace helpers { + + template + static __global__ void matrixMinorKernel(T* outBuffer, Nd4jLong* outShape, T* inBuffer, Nd4jLong* inShape, Nd4jLong column, Nd4jLong rows, Nd4jLong columns) { +// auto tid = threadIdx.x + blockDim.x * blockIdx.x; +// auto step = blockDim.x * gridDim.x; +// if (threadIdx.x == 0) { +// for (auto i = tid; i < column; i += step) { +// Nd4jLong diagPos[] = {i, i}; +// auto zIndex = shape::getOffset(outShape, diagPos); +// outBuffer[zIndex] = T(1.f); +// } +// } +// __syncthreads(); + + for (auto i = blockIdx.x; i < rows; i += gridDim.x) + for (auto j = threadIdx.x; j < columns; j += blockDim.x) { + Nd4jLong pos[] = {i,j}; + auto zIndex = shape::getOffset(outShape, pos); + auto xIndex = shape::getOffset(inShape, pos); + if (i < column || j < column) { + outBuffer[zIndex] = i != j?T(0.f):T(1.f); + } + else + outBuffer[zIndex] = inBuffer[xIndex]; //m.t(i,j) = in.t(i,j); + } + + + } + + template + NDArray matrixMinor(LaunchContext* context, NDArray& in, Nd4jLong col) { + NDArray m = in.ulike(); + m.setIdentity(); + m({col, m.rows(), col, m.columns()}).assign(in({col, m.rows(), col, m.columns()})); + +// auto stream = context->getCudaStream(); +// matrixMinorKernel<<<128, 128, 256, *stream>>>(m.dataBuffer()->specialAsT(), m.specialShapeInfo(), +// matrixMinorKernel<<<128, 128, 256, *stream>>>(m.dataBuffer()->specialAsT(), m.specialShapeInfo(), +// reinterpret_cast(in.specialBuffer()), in.specialShapeInfo(), col, in.rows(), in.columns()); +// + m.tickWriteDevice(); + return m; + } + +/* m = I - v v^T */ + template + static __global__ void vmulKernel(T* resBuf, Nd4jLong* resShape, T const* vBuff, Nd4jLong const* vShape, Nd4jLong n) { + for (auto i = blockIdx.x; i < n; i += gridDim.x) + for (auto j = threadIdx.x; j < n; j += blockDim.x) { + Nd4jLong posR[] = {i, j}; + auto indexR = shape::getOffset(resShape, posR); + auto indexX = shape::getIndexOffset(i, vShape); + auto indexY = shape::getIndexOffset(j, vShape); + + resBuf[indexR] = T(-2.f) * vBuff[indexX] * vBuff[indexY] + (i != j?T(0.f):T(1.f)); + } + } + + template + NDArray vmul(LaunchContext* context, NDArray const& v, int n) + { + NDArray res('c', {n,n}, v.dataType(), context); // x = matrix_new(n, n); + + auto stream = context->getCudaStream(); + vmulKernel<<<128, 128, 128, *stream>>>(res.dataBuffer()->specialAsT(), res.specialShapeInfo(), + reinterpret_cast(v.getSpecialBuffer()), v.getSpecialShapeInfo(), n); + return res; + } + + template + static bool diagonalIsPositive(NDArray* matrix, Nd4jLong k) { + T hVal; + Nd4jLong pos[] = {k, k}; + auto shift = shape::getOffset(matrix->shapeInfo(), pos); + cudaMemcpy(&hVal, matrix->specialBuffer(), sizeof(T), cudaMemcpyDeviceToHost); + return hVal > T(0.f); + } + + template + void qrSingle(LaunchContext* context, NDArray* matrix, NDArray* Q, NDArray* R, bool const fullMatricies) { + Nd4jLong M = matrix->sizeAt(0); + Nd4jLong N = matrix->sizeAt(1); + auto resQ = fullMatricies?Q->ulike():NDArrayFactory::create(matrix->ordering(), {M,M}, Q->getContext()); + auto resR = fullMatricies?R->ulike():matrix->ulike(); + std::vector q(M); + NDArray z = *matrix; + NDArray e('c', {M}, DataTypeUtils::fromT()); // two internal buffers and scalar for squared norm + for (auto k = 0; k < N && k < M - 1; k++) { // loop for columns, but not further then row number + e.nullify(); + z = matrixMinor(context, z, k); // minor computing for current column with given matrix z (initally is a input matrix) + + auto currentColumn = z({0, 0, k, k + 1}); // retrieve k column from z to x buffer + auto norm = currentColumn.reduceAlongDimension(reduce::Norm2, {0}); + if (diagonalIsPositive(matrix, k)) //matrix->t(k,k) > T(0.f)) // negate on positive matrix diagonal element + norm.applyTransform(transform::Neg, norm); // *= -1.f;//-norm.t(0); + + e.p(k, norm); // e - is filled by 0 vector except diagonal element (filled by 1) + e += currentColumn; // e[i] = x[i] + a * e[i] for each i from 0 to n - 1 + auto normE = e.reduceAlongDimension(reduce::Norm2, {0}); + e /= normE; + q[k] = vmul(context, e, M); + auto qQ = z.ulike(); + MmulHelper::matmul(&q[k], &z, &qQ, false, false); + z = std::move(qQ); + } + resQ.assign(q[0]); // +// MmulHelper::matmul(&q[0], matrix, &resR, false, false); + for (int i = 1; i < N && i < M - 1; i++) { + auto tempResQ = resQ; + MmulHelper::matmul(&q[i], &resQ, &tempResQ, false, false); + resQ = std::move(tempResQ); + } + MmulHelper::matmul(&resQ, matrix, &resR, false, false); + // resR *= -1.f; + resQ.transposei(); + + if (fullMatricies) { + Q->assign(resQ); + R->assign(resR); + } + else { + Q->assign(resQ({0, 0, 0, N})); + R->assign(resR({0, N, 0, 0})); + } + } + + template + void qr_(LaunchContext* context, NDArray* input, NDArray* outputQ, NDArray* outputR, bool const fullMatricies) { + Nd4jLong lastDim = input->rankOf() - 1; + Nd4jLong preLastDim = input->rankOf() - 2; + + NDArray::prepareSpecialUse({outputQ, outputR}, {input}); + ResultSet listOutQ(outputQ->allTensorsAlongDimension({(int)preLastDim, (int)lastDim})); + ResultSet listOutR(outputR->allTensorsAlongDimension({(int)preLastDim, (int)lastDim})); + ResultSet listInput(input->allTensorsAlongDimension({(int)preLastDim, (int)lastDim})); + auto start = 0; + auto stop = listInput.size(); + auto increment = 1; + + for (auto batch = start; batch < stop; batch += increment) { + //qr here + qrSingle(context, listInput.at(batch), listOutQ.at(batch), listOutR.at(batch), fullMatricies); + } + NDArray::registerSpecialUse({outputQ, outputR}, {input}); + } + + void qr(nd4j::LaunchContext* context, NDArray* input, NDArray* outputQ, NDArray* outputR, bool const fullMatricies) { + BUILD_SINGLE_SELECTOR(input->dataType(), qr_, (context, input, outputQ, outputR, fullMatricies), FLOAT_TYPES); + } + +} +} +} + diff --git a/libnd4j/include/ops/declarable/helpers/cuda/random.cu b/libnd4j/include/ops/declarable/helpers/cuda/random.cu index 1c28b8f24..1e290bc56 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/random.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/random.cu @@ -27,6 +27,8 @@ #include #include #include +#include +#include namespace nd4j { namespace ops { @@ -82,8 +84,8 @@ namespace helpers { NDArray alphaBroadcasted(broadcasted, alpha->dataType(), true, context); NDArray betaBroadcasted(broadcasted, beta->dataType(), true, context); - copyAlpha = (alphaBroadcasted.applyTrueBroadcast(BroadcastOpsTuple::Assign(), alpha)); - copyBeta = (betaBroadcasted.applyTrueBroadcast(BroadcastOpsTuple::Assign(), beta)); + copyAlpha = new NDArray(alphaBroadcasted.applyTrueBroadcast(BroadcastOpsTuple::Assign(), *alpha)); + copyBeta = new NDArray(betaBroadcasted.applyTrueBroadcast(BroadcastOpsTuple::Assign(), *beta)); copyAlpha->tickWriteDevice(); copyBeta->tickWriteDevice(); } @@ -248,6 +250,116 @@ namespace helpers { BUILD_SINGLE_TEMPLATE(template void fillRandomUniform_, (LaunchContext* context, graph::RandomGenerator& rng, NDArray* min, NDArray* max, NDArray* output), NUMERIC_TYPES); +/////////////////////////////////////////////////////////////////// +// used https://en.wikipedia.org/wiki/Categorical_distribution +// methods: gumbel trick + softmax + argmax +template +__global__ static void fillMultiNomialCuda_(graph::RandomGenerator* devRng, const void* vx, const Nd4jLong* xShapeInfo, + void* vz, const Nd4jLong* zShapeInfo, const Nd4jLong batchValue, + const Nd4jLong numOfSamples, const Nd4jLong numOfClassX, + const Nd4jLong dimA, const X minVal, const X maxVal) { + + + const X* x = reinterpret_cast(vx); + Z* z = reinterpret_cast(vz); + + __shared__ Nd4jLong xDimAstride, zDimAstride, xDimCstride, zDimCstride, dimC; + + if (0 == threadIdx.x) { + dimC = (0 == dimA) ? 1 : 0; + zDimAstride = shape::stride(zShapeInfo)[dimA]; + xDimAstride = shape::stride(xShapeInfo)[dimA]; + zDimCstride = shape::stride(zShapeInfo)[dimC]; + xDimCstride = shape::stride(xShapeInfo)[dimC]; + } + __syncthreads(); + + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + + for (Nd4jLong index = tid; index < batchValue*numOfSamples; index += gridDim.x * blockDim.x) { + + Nd4jLong nBatchIndex = index / numOfSamples; + Nd4jLong nSampleIndexInBatch = index - (nBatchIndex * numOfSamples); + + const X* xTad = x + (nBatchIndex * xDimCstride); + Z* zTad = z + (nBatchIndex * zDimCstride); + Z& arg = zTad[nSampleIndexInBatch * zDimAstride]; + + X Max = -minVal; + Nd4jLong nSamplesPerBatch = nBatchIndex * numOfClassX * numOfSamples; + Nd4jLong nClassPerSamples = nSampleIndexInBatch * numOfClassX; + + for (Nd4jLong nClass = 0; nClass < numOfClassX; nClass++) { + Nd4jLong nIndex = nSamplesPerBatch + nClassPerSamples + nClass; + X tValue = (xTad[nClass * xDimAstride] - nd4j::math::nd4j_log(-nd4j::math::nd4j_log(devRng->relativeT(nIndex, minVal, maxVal)))); + if (tValue > Max) { + Max = tValue; + arg = nClass; + } + } + } +} + +////////////////////////////////////////////////////////////////////////// +template +__host__ static void fillMultiNomialCudaLauncher( + const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t* stream, + graph::RandomGenerator* devRng, const void* vx, const Nd4jLong* xShapeInfo, + void* vz, const Nd4jLong* zShapeInfo, + const Nd4jLong batchValue, const Nd4jLong numOfSamples, + const Nd4jLong numOfClassX, const Nd4jLong dimA){ + + const X minVal = DataTypeUtils::min(); + const X maxVal = 1.0; + + fillMultiNomialCuda_ <<< blocksPerGrid, threadsPerBlock, 256, * stream >>> ( + devRng, vx, xShapeInfo, vz, zShapeInfo, batchValue, + numOfSamples, numOfClassX, dimA, minVal, maxVal); +} + +/////////////////////////////////////////////////////////////////// +void fillRandomMultiNomial(LaunchContext* context, graph::RandomGenerator& rng, NDArray& input, NDArray& output, const Nd4jLong numOfSamples, const int dimC) { + + Nd4jLong dimA = (0 == dimC) ? 1 : 0; + + const Nd4jLong batchValue = output.sizeAt(dimC); + const Nd4jLong numOfClassX = input.sizeAt(dimA); + + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = (batchValue * numOfSamples + threadsPerBlock - 1) / threadsPerBlock; + + PointersManager manager(context, "fillMultinomial"); + graph::RandomGenerator *devRng; + + auto err = cudaMalloc(&devRng, sizeof(graph::RandomGenerator)); + if (err != 0) { + cuda_exception::build("fillRandomMultiNomial: Cannot allocate device memory for random generator due error", err); + } + err = cudaStreamSynchronize(*context->getCudaStream()); + if (err != 0) { + cuda_exception::build("fillRandomMultiNomial: Cannot synchronize stream for random generator due error", err); + } + err = cudaMemcpyAsync(devRng, &rng, sizeof(graph::RandomGenerator), cudaMemcpyHostToDevice, *context->getCudaStream()); + if (err != 0) { + cuda_exception::build("fillRandomMultiNomial: Cannot copy random generator to device", err); + } + + NDArray::prepareSpecialUse({ &output }, { &input }); + BUILD_DOUBLE_SELECTOR(input.dataType(), output.dataType(), fillMultiNomialCudaLauncher, + (blocksPerGrid, threadsPerBlock, context->getCudaStream(), devRng, input.getSpecialBuffer(), + input.getSpecialShapeInfo(), output.specialBuffer(), + output.specialShapeInfo(), batchValue, numOfSamples, + numOfClassX, dimA), FLOAT_TYPES, INDEXING_TYPES); + NDArray::registerSpecialUse({ &output }, { &input }); + manager.synchronize(); + + err = cudaFree(devRng); + if (err != 0) { + cuda_exception::build("fillRandomMultiNomial: Cannot deallocate device memory for random generator", err); + } + rng.rewindH(output.lengthOf() * numOfClassX); + } + } } } \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/reverse.cu b/libnd4j/include/ops/declarable/helpers/cuda/reverse.cu index 90e15b21f..15335d57e 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/reverse.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/reverse.cu @@ -181,27 +181,21 @@ namespace helpers { auto inSubArrsSet = input->allTensorsAlongDimension(dimensions); auto outSubArrsSet = output->allTensorsAlongDimension(dimensions); - for(int i = 0; i < inSubArrsSet->size(); ++i) { + for(int i = 0; i < inSubArrsSet.size(); ++i) { int numOfElemsToReverse = seqLengths->e(i); if(numOfElemsToReverse == 0 || numOfElemsToReverse == 1) { - outSubArrsSet->at(i)->assign(inSubArrsSet->at(i)); + outSubArrsSet.at(i)->assign(inSubArrsSet.at(i)); } else { - auto inInnerSet = inSubArrsSet->at(i)->allTensorsAlongDimension({seqDim}); - auto outInnerSet = outSubArrsSet->at(i)->allTensorsAlongDimension({seqDim}); - for(int j = 0; j < inInnerSet->size(); ++j) - reverseArray(context, inInnerSet->at(j), outInnerSet->at(j), numOfElemsToReverse); - - delete inInnerSet; - delete outInnerSet; + auto inInnerSet = inSubArrsSet.at(i)->allTensorsAlongDimension({seqDim}); + auto outInnerSet = outSubArrsSet.at(i)->allTensorsAlongDimension({seqDim}); + for(int j = 0; j < inInnerSet.size(); ++j) + reverseArray(context, inInnerSet.at(j), outInnerSet.at(j), numOfElemsToReverse); } } - delete inSubArrsSet; - delete outSubArrsSet; } - } void reverseSequence(nd4j::LaunchContext * context, const NDArray* input, const NDArray* seqLengths, NDArray* output, int seqDim, const int batchDim) { diff --git a/libnd4j/include/ops/declarable/helpers/cuda/roll.cu b/libnd4j/include/ops/declarable/helpers/cuda/roll.cu index d843feeff..bc53946d3 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/roll.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/roll.cu @@ -235,9 +235,9 @@ namespace helpers { for (size_t i = 0; i < axes.size(); i++) { int axe = axes[i]; if (axe == input->rankOf() - 1) { // last dimension - std::unique_ptr listOfTensors(output->allTensorsAlongDimension({axe})); - std::unique_ptr listOfOutTensors(output->allTensorsAlongDimension({axe})); - int fullLen = listOfTensors->size(); + ResultSet listOfTensors = output->allTensorsAlongDimension({axe}); + ResultSet listOfOutTensors = output->allTensorsAlongDimension({axe}); + int fullLen = listOfTensors.size(); int theShift = shifts[i]; // if (theShift > 0) { // theShift %= fullLen; @@ -246,7 +246,7 @@ namespace helpers { // theShift -= fullLen * (theShift / fullLen - 1); // } for (int k = 0; k < fullLen; k++) { - rollFunctorLinear(output->getContext(), listOfTensors->at(k), listOfOutTensors->at(k), theShift, true); + rollFunctorLinear(output->getContext(), listOfTensors.at(k), listOfOutTensors.at(k), theShift, true); } } else { std::vector dims(input->rankOf() - axe - 1); diff --git a/libnd4j/include/ops/declarable/helpers/cuda/segment_max.cu b/libnd4j/include/ops/declarable/helpers/cuda/segment_max.cu index cab6e50e7..9585642dd 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/segment_max.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/segment_max.cu @@ -212,7 +212,7 @@ namespace nd4j { NDArray classesRangesBegs = NDArrayFactory::create('c', {numOfClasses}); NDArray classesRangesLens = NDArrayFactory::create('c', {numOfClasses}); // NDArray row = NDArrayFactory::create('c', {1, 2}, {(int)indices->lengthOf(), (int)0}); -// classes.applyTrueBroadcast(nd4j::BroadcastOpsTuple::Assign(), &row, &classes); +// classes.applyTrueBroadcast(nd4j::BroadcastOpsTuple::Assign(), row, classes); classesRangesBegs.assign(indices->lengthOf()); classesRangesLens.assign(0); dim3 dims(numOfClasses, indices->lengthOf(), numOfClasses * 32 + 32); diff --git a/libnd4j/include/ops/declarable/helpers/cuda/shift.cu b/libnd4j/include/ops/declarable/helpers/cuda/shift.cu index 49d388b2a..8ba3d40ce 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/shift.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/shift.cu @@ -29,7 +29,7 @@ namespace nd4j { return x >> shift; }; - input.applyLambda(lambda, &output); + input.applyLambda(lambda, output); } void rshift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift) { @@ -42,7 +42,7 @@ namespace nd4j { return x << shift; }; - input.applyLambda(lambda, &output); + input.applyLambda(lambda, output); } void shift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift) { @@ -56,7 +56,7 @@ namespace nd4j { return x >> shift | x << step; }; - input.applyLambda(lambda, &output); + input.applyLambda(lambda, output); } void cyclic_rshift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift) { @@ -70,7 +70,7 @@ namespace nd4j { return x << shift | x >> step; }; - input.applyLambda(lambda, &output); + input.applyLambda(lambda, output); } void cyclic_shift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift) { diff --git a/libnd4j/include/ops/declarable/helpers/cuda/sru.cu b/libnd4j/include/ops/declarable/helpers/cuda/sru.cu index 5ce883a59..76530269c 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/sru.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/sru.cu @@ -33,7 +33,7 @@ namespace helpers { static FORCEINLINE NDArray activation(const NDArray& arr) { // return (const_cast&>(arr)).template transform>(); auto result = NDArray(&arr, false, arr.getContext()); - (const_cast(arr)).applyTransform(transform::Tanh, &result); + (const_cast(arr)).applyTransform(transform::Tanh, result); return result; } @@ -236,7 +236,7 @@ void sruBI(nd4j::LaunchContext * context, NDArray* x, const NDArray* w, const ND // x = x * mask if(mask) - x->applyBroadcast(broadcast::Multiply, {1, 2}, mask, x, nullptr); // apply mask + x->applyBroadcast(broadcast::Multiply, {1, 2}, *mask, *x); // apply mask // U = x * w NDArray wi = mmul(*x, *w); // U [time x bS x 6*K] @@ -497,7 +497,7 @@ void sruBIBP(nd4j::LaunchContext* context, NDArray* x, const NDArray* w, const N // x = x * mask if(mask) - x->applyBroadcast(broadcast::Multiply, {1, 2}, mask, x, nullptr); // apply mask + x->applyBroadcast(broadcast::Multiply, {1, 2}, *mask, *x); // apply mask // U = x * w NDArray wi = mmul(*x, *w); // U [time x bS x 6*K] @@ -522,7 +522,7 @@ void sruBIBP(nd4j::LaunchContext* context, NDArray* x, const NDArray* w, const N manager.synchronize(); // gradB - gradBias.reduceAlongDimension(reduce::Sum, gradB, {0}); // [4*K] + gradBias.reduceAlongDimension(reduce::Sum, *gradB, {0}); // [4*K] // gradW x->permutei({0, 2, 1}); // [time, bS, 2*K] -> [time, 2*K, bS] diff --git a/libnd4j/include/ops/declarable/helpers/cuda/svd.cu b/libnd4j/include/ops/declarable/helpers/cuda/svd.cu index b39ebf81b..4d1b18eef 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/svd.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/svd.cu @@ -148,24 +148,24 @@ static void svdQR(nd4j::LaunchContext* context, const NDArray* A, NDArray* S, ND std::vector toDelete; if(pA->ews() != 1 || pA->ordering() == 'c') { - pA = A->dup('f'); + pA = new NDArray(A->dup('f')); toDelete.push_back(pA); } if(S->ews() != 1) { - pS = S->dup('f'); + pS = new NDArray(S->dup('f')); toDelete.push_back(pS); } if(calcUV) { if(pU->ews() != 1 || pU->ordering() == 'c') { - pU = U->dup('f'); + pU = new NDArray(U->dup('f')); toDelete.push_back(pU); } if(pVT->ews() != 1 || pVT->ordering() == 'c') { - pVT = VT->dup('f'); + pVT = new NDArray(VT->dup('f')); toDelete.push_back(pVT); } } @@ -276,8 +276,8 @@ static void svdJcb(nd4j::LaunchContext* context, const NDArray* A, NDArray* S, N if(A->rankOf() != 2) throw std::runtime_error("svdJcb: rank of A array is not equal 2 !"); - auto m = A->sizeAt(0); - auto n = A->sizeAt(1); + int m = A->sizeAt(0); + int n = A->sizeAt(1); const int minDim = m < n ? m : n; if(ShapeUtils::shapeAsString({minDim}) != ShapeUtils::shapeAsString(S)) @@ -297,33 +297,53 @@ static void svdJcb(nd4j::LaunchContext* context, const NDArray* A, NDArray* S, N } NDArray* pA = const_cast(A); - NDArray* pS = S; - NDArray* pU = U; - NDArray* pV = V; + + const bool aForder = m == 1 || A->strideAt(0) == 1; + const bool aCorder = n == 1 || A->strideAt(1) == 1; + + const bool transA = !aForder && aCorder; + const bool dupA = !aForder && !aCorder; std::vector toDelete; - if(pA->ews() != 1 || pA->ordering() == 'c') { - pA = A->dup('f'); + if(dupA) { + pA = new NDArray(A->dup('f')); toDelete.push_back(pA); } + NDArray* pS = S; + if(S->ews() != 1) { - pS = S->dup('f'); + pS = new NDArray(S->dup('f')); toDelete.push_back(pS); } + NDArray *pU(nullptr), *pV(nullptr); + + int lda = transA ? pA->strideAt(0) : pA->strideAt(1); + int ldu(transA ? n : m), ldv(transA ? m : n); + bool uForder(true), vForder(true); + if(calcUV) { - if(pU->ews() != 1 || pU->ordering() == 'c') { - pU = U->dup('f'); + pU = transA ? V : U; + pV = transA ? U : V; + + uForder = pU->sizeAt(0) == 1 || pU->strideAt(0) == 1; + vForder = pV->sizeAt(0) == 1 || pV->strideAt(0) == 1; + + if(!uForder) { + pU = new NDArray(pU->dup('f')); toDelete.push_back(pU); } - if(pV->ews() != 1 || pV->ordering() == 'c') { - pV = V->dup('f'); + if(!vForder) { + pV = new NDArray(pV->dup('f')); toDelete.push_back(pV); } + + ldu = pU->strideAt(1); + ldv = pV->strideAt(1); } // create cusolverDn handle @@ -353,19 +373,27 @@ static void svdJcb(nd4j::LaunchContext* context, const NDArray* A, NDArray* S, N const cusolverEigMode_t jobz = calcUV ? CUSOLVER_EIG_MODE_VECTOR : CUSOLVER_EIG_MODE_NOVECTOR; const int econ = !fullUV; - int lda(m), ldu(m), ldv(m); + if(transA) + math::nd4j_swap(m, n); - if(calcUV) { - ldu = pU->sizeAt(0); - ldv = pV->sizeAt(0); + // *** avoid bug in cuda API *** + void* nullPtr = nullptr; + NDArray* arrToAvoidBugInAPI = nullptr; + if(!calcUV && m != n) { + int maxDim = m > n ? m : n; + arrToAvoidBugInAPI = new NDArray('c', {maxDim, maxDim}, pA->dataType(), context); + nullPtr = arrToAvoidBugInAPI->getSpecialBuffer(); } + // ****************** + + NDArray::prepareSpecialUse({pS, pU, pV}, {pA}); // query working space of SVD int lwork = 0; if(A->dataType() == DataType::DOUBLE) - status = cusolverDnDgesvdj_bufferSize(handle, jobz, econ, m, n, reinterpret_cast(pA->getSpecialBuffer()), lda, reinterpret_cast(pS->getSpecialBuffer()), calcUV ? reinterpret_cast(pU->getSpecialBuffer()) : nullptr, ldu, calcUV ? reinterpret_cast(pV->getSpecialBuffer()) : nullptr, ldv, &lwork, gesvdjParams); + status = cusolverDnDgesvdj_bufferSize(handle, jobz, econ, m, n, reinterpret_cast(pA->getSpecialBuffer()), lda, reinterpret_cast(pS->getSpecialBuffer()), calcUV ? reinterpret_cast(pU->getSpecialBuffer()) : reinterpret_cast(nullPtr), ldu, calcUV ? reinterpret_cast(pV->getSpecialBuffer()) : reinterpret_cast(nullPtr), ldv, &lwork, gesvdjParams); else if(A->dataType() == DataType::FLOAT32) - status = cusolverDnSgesvdj_bufferSize(handle, jobz, econ, m, n, reinterpret_cast(pA->getSpecialBuffer()), lda, reinterpret_cast(pS->getSpecialBuffer()), calcUV ? reinterpret_cast(pU->getSpecialBuffer()) : nullptr, ldu, calcUV ? reinterpret_cast(pV->getSpecialBuffer()) : nullptr, ldv, &lwork, gesvdjParams); + status = cusolverDnSgesvdj_bufferSize(handle, jobz, econ, m, n, reinterpret_cast(pA->getSpecialBuffer()), lda, reinterpret_cast(pS->getSpecialBuffer()), calcUV ? reinterpret_cast(pU->getSpecialBuffer()) : reinterpret_cast(nullPtr), ldu, calcUV ? reinterpret_cast(pV->getSpecialBuffer()) : reinterpret_cast(nullPtr), ldv, &lwork, gesvdjParams); else throw std::invalid_argument("svdJcb: given data type is unsupported !"); @@ -380,14 +408,12 @@ static void svdJcb(nd4j::LaunchContext* context, const NDArray* A, NDArray* S, N PointersManager manager(context, "svdJcb"); - NDArray::prepareSpecialUse({pS, pU, pV}, {pA}); - // choose appropriate cuda gemm api depending on data types if(A->dataType() == DataType::DOUBLE) { - status = cusolverDnDgesvdj(handle, jobz, econ, m, n, reinterpret_cast(pA->getSpecialBuffer()), lda, reinterpret_cast(pS->getSpecialBuffer()), calcUV ? reinterpret_cast(pU->getSpecialBuffer()) : nullptr, ldu, calcUV ? reinterpret_cast(pV->getSpecialBuffer()) : nullptr, ldv, reinterpret_cast(dWork), lwork, devInfo, gesvdjParams); + status = cusolverDnDgesvdj(handle, jobz, econ, m, n, reinterpret_cast(pA->getSpecialBuffer()), lda, reinterpret_cast(pS->getSpecialBuffer()), calcUV ? reinterpret_cast(pU->getSpecialBuffer()) : reinterpret_cast(nullPtr), ldu, calcUV ? reinterpret_cast(pV->getSpecialBuffer()) : reinterpret_cast(nullPtr), ldv, reinterpret_cast(dWork), lwork, devInfo, gesvdjParams); } else if(A->dataType() == DataType::FLOAT32) { - status = cusolverDnSgesvdj(handle, jobz, econ, m, n, reinterpret_cast(pA->getSpecialBuffer()), lda, reinterpret_cast(pS->getSpecialBuffer()), calcUV ? reinterpret_cast(pU->getSpecialBuffer()) : nullptr, ldu, calcUV ? reinterpret_cast(pV->getSpecialBuffer()) : nullptr, ldv, reinterpret_cast(dWork), lwork, devInfo, gesvdjParams); + status = cusolverDnSgesvdj(handle, jobz, econ, m, n, reinterpret_cast(pA->getSpecialBuffer()), lda, reinterpret_cast(pS->getSpecialBuffer()), calcUV ? reinterpret_cast(pU->getSpecialBuffer()) : reinterpret_cast(nullPtr), ldu, calcUV ? reinterpret_cast(pV->getSpecialBuffer()) : reinterpret_cast(nullPtr), ldv, reinterpret_cast(dWork), lwork, devInfo, gesvdjParams); } else throw std::invalid_argument("svdJcb: given data type is unsupported !"); @@ -399,13 +425,20 @@ static void svdJcb(nd4j::LaunchContext* context, const NDArray* A, NDArray* S, N NDArray::registerSpecialUse({pS, pU, pV}, {pA}); - S->assign(pS); + if(S->ews() != 1) + S->assign(pS); if(calcUV) { - U->assign(pU); - V->assign(pV); + + if(!uForder) + U->assign(transA ? pV : pU); + if(!vForder) + V->assign(transA ? pU : pV); } + if(!calcUV && m != n) + delete arrToAvoidBugInAPI; + for (int i = toDelete.size() - 1; i >= 0; --i) delete toDelete[i]; @@ -465,24 +498,24 @@ static void svdBatched(nd4j::LaunchContext* context, const NDArray* A, NDArray* std::vector toDelete; if(pA->ews() != 1 || pA->ordering() == 'c') { - pA = A->dup('f'); + pA = new NDArray(A->dup('f')); toDelete.push_back(pA); } if(S->ews() != 1) { - pS = S->dup('f'); + pS = new NDArray(S->dup('f')); toDelete.push_back(pS); } if(calcUV) { if(pU->ews() != 1 || pU->ordering() == 'c') { - pU = U->dup('f'); + pU = new NDArray(U->dup('f')); toDelete.push_back(pU); } if(pV->ews() != 1 || pV->ordering() == 'c') { - pV = V->dup('f'); + pV = new NDArray(V->dup('f')); toDelete.push_back(pV); } } @@ -618,15 +651,12 @@ void svd(nd4j::LaunchContext* context, const NDArray* x, const std::vectorallTensorsAlongDimension({S->rankOf() - 1}); if(calcUV) { - tadsU = U->allTensorsAlongDimension({U->rankOf() - 2, U->rankOf() - 1}); - tadsV = V->allTensorsAlongDimension({V->rankOf() - 2, V->rankOf() - 1}); + tadsU = new ResultSet(U->allTensorsAlongDimension({U->rankOf() - 2, U->rankOf() - 1})); + tadsV = new ResultSet(V->allTensorsAlongDimension({V->rankOf() - 2, V->rankOf() - 1})); } - for (int i = 0; i < tadsX->size(); ++i) - svdJcb(context, tadsX->at(i), tadsS->at(i), calcUV ? tadsU->at(i) : nullptr, calcUV ? tadsV->at(i) : nullptr, fullUV, calcUV); - - delete tadsX; - delete tadsS; + for (int i = 0; i < tadsX.size(); ++i) + svdJcb(context, tadsX.at(i), tadsS.at(i), calcUV ? tadsU->at(i) : nullptr, calcUV ? tadsV->at(i) : nullptr, fullUV, calcUV); if(calcUV) { delete tadsU; diff --git a/libnd4j/include/ops/declarable/helpers/cuda/toggle_bits.cu b/libnd4j/include/ops/declarable/helpers/cuda/toggle_bits.cu index 8c67cbf1b..bc1171efe 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/toggle_bits.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/toggle_bits.cu @@ -30,7 +30,7 @@ namespace nd4j { return ~_x;//eUtils::flip_bits(_x); }; - in.applyLambda(lambda, &out); + in.applyLambda(lambda, out); } BUILD_SINGLE_TEMPLATE(template void toggle_bits__, (NDArray &in, NDArray &out), INTEGER_TYPES); diff --git a/libnd4j/include/ops/declarable/helpers/cuda/top_k.cu b/libnd4j/include/ops/declarable/helpers/cuda/top_k.cu index 972013835..520a6115d 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/top_k.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/top_k.cu @@ -251,7 +251,7 @@ int inTopKFunctor(nd4j::LaunchContext * context, const NDArray* predictions, con // we get top K values first if (k == 1) { - input->applyIndexReduce(indexreduce::IndexMax, indices, {input->rankOf() - 1}); + input->applyIndexReduce(indexreduce::IndexMax, *indices, {input->rankOf() - 1}); // copy values on specified indices topValuesMover<<<256, 256, 1024, *context->getCudaStream()>>>(input->getSpecialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), indices->specialBuffer(), packI.platformShapeInfo(), packI.platformOffsets(), values->specialBuffer(), packZ.platformShapeInfo(), packZ.platformOffsets(), tadLength, packX.numberOfTads(), k); diff --git a/libnd4j/include/ops/declarable/helpers/cuda/transforms.cu b/libnd4j/include/ops/declarable/helpers/cuda/transforms.cu index 1a5a255ee..764b6abbf 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/transforms.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/transforms.cu @@ -649,7 +649,7 @@ void clipByNormBP(nd4j::LaunchContext* context, const NDArray& input, const NDAr int r = rng.relativeInt(i) % i; if(i != r) - subArrsListIn->at(i)->swapUnsafe(*subArrsListIn->at(r)); + subArrsListIn.at(i)->swapUnsafe(*subArrsListIn.at(r)); } } else { @@ -661,21 +661,19 @@ void clipByNormBP(nd4j::LaunchContext* context, const NDArray& input, const NDAr for(int i = firstDim - 1; i > 0; --i) { int r = rng.relativeInt(i) % i; - subArrsListOut->at(i)->assign(subArrsListIn->at(indices[r])); + subArrsListOut.at(i)->assign(subArrsListIn.at(indices[r])); if(r == 0) isZeroShuffled = true; if(i != r) { - subArrsListOut->at(r)->assign(subArrsListIn->at(indices[i])); + subArrsListOut.at(r)->assign(subArrsListIn.at(indices[i])); math::nd4j_swap(indices[i], indices[r]); } } if(!isZeroShuffled) - subArrsListOut->at(0)->assign(subArrsListIn->at(0)); - delete subArrsListOut; + subArrsListOut.at(0)->assign(subArrsListIn.at(0)); } rng.rewindH(firstDim-1); - delete subArrsListIn; } NDArray::registerSpecialUse({&output}, {&input}); @@ -747,7 +745,7 @@ void clipByNormBP(nd4j::LaunchContext* context, const NDArray& input, const NDAr template static void clipByNorm_(nd4j::LaunchContext * context, NDArray& input, NDArray& output, const std::vector& dimensions, NDArray const& clipNormA, const bool isInplace) { const int rank = input.rankOf(); - auto norm2 = input.reduceAlongDims(reduce::Norm2, dimensions); + auto norm2 = input.reduceAlongDimension(reduce::Norm2, dimensions); clipNormA.syncToHost(); //norm2.printBuffer("Norm2"); T const clipNorm = clipNormA.e(0); @@ -814,10 +812,10 @@ void clipByNormBP(nd4j::LaunchContext* context, const NDArray& input, const NDAr globalNorm += l2norm * l2norm; } - globalNorm.applyTransform(transform::Sqrt, nullptr, nullptr);// = nd4j::math::nd4j_sqrt(globalNorm); + globalNorm.applyTransform(transform::Sqrt, globalNorm); // = nd4j::math::nd4j_sqrt(globalNorm); outputs[inputs.size()]->p(0, globalNorm); globalNorm.syncToHost(); - const T factor = clipNorm / globalNorm.e(0); + const T factor = static_cast(clipNorm) / globalNorm.e(0); for (size_t e = 0; e < inputs.size(); e++) { // all-reduce @@ -830,7 +828,7 @@ void clipByNormBP(nd4j::LaunchContext* context, const NDArray& input, const NDAr else { auto lambda = LAMBDA_T(_x, factor) { return _x * factor; }; - input->applyLambda(lambda, output); + input->applyLambda(lambda, *output); } } } @@ -848,7 +846,7 @@ void clipByNormBP(nd4j::LaunchContext* context, const NDArray& input, const NDAr auto cn = clipNorm.e(0); if (dimensions.size() == 0) { // all-reduce - T n2 = input.reduceNumber(reduce::Norm2).e(0) / input.lengthOf(); + T n2 = input.reduceNumber(reduce::Norm2).e(0) / static_cast(input.lengthOf()); if (n2 <= cn) { if (!isInplace) output.assign(input); @@ -856,28 +854,26 @@ void clipByNormBP(nd4j::LaunchContext* context, const NDArray& input, const NDAr else { const T factor = cn / n2; //auto lambda = LAMBDA_T(_x, factor) { return _x * factor; }; - //input.applyLambda(lambda, &output); + //input.applyLambda(lambda, output); output.assign(input * factor); } } else { // along dimension - auto norm2 = input.reduceAlongDims(reduce::Norm2, dimensions, false); + auto norm2 = input.reduceAlongDimension(reduce::Norm2, dimensions, false); if (!isInplace) output.assign(input); auto tads = output.allTensorsAlongDimension(dimensions); auto outTads = output.allTensorsAlongDimension(dimensions); // TODO: make this CUDA-compliant somehow - for (int e = 0; e < tads->size(); e++) { - T n2 = norm2.e(e) / tads->at(e)->lengthOf(); + for (int e = 0; e < tads.size(); e++) { + T n2 = norm2.e(e) / static_cast(tads.at(e)->lengthOf()); const T factor = cn / n2; if (n2 > cn) { //auto lambda = LAMBDA_T(_x, factor) {return _x * factor;}; - tads->at(e)->applyScalar(scalar::Multiply, factor, outTads->at(e));//applyLambda(lambda, &output); + tads.at(e)->applyScalar(scalar::Multiply, factor, *outTads.at(e));//applyLambda(lambda, &output); } } - delete tads; - delete outTads; } } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/triangular_solve.cu b/libnd4j/include/ops/declarable/helpers/cuda/triangular_solve.cu new file mode 100644 index 000000000..8846be45c --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cuda/triangular_solve.cu @@ -0,0 +1,227 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit, K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author GS +// + +#include +#include +#include +#include +#include "../triangular_solve.h" + +namespace nd4j { + namespace ops { + namespace helpers { + /* + * lower triangular process for system of linear equations + * x_1 = b_1/a_1,1 + * x_2 = (b_2 - a_2,1 * x_1) / a_2,2 + * x_3 = (b_3 - a_3,1 * x_1 - a_3,2 * x_2) / a_3,3 + * ... + * x_M = (b_M - a_M,1 * x_1 - ... a_M,M-1 * x_M-1)/ a_M,M + * + * output == x + * a == leftInput + * b == rightInput + * + * */ + template + static __device__ void lowerTriangularSolve(T const* leftInput, Nd4jLong const* leftInputShape, + T const* rightInput, Nd4jLong const* rightInputShape, + bool const adjoint, T* output, Nd4jLong* outputShape, + Nd4jLong rows) { + + for (auto r = 0; r < rows; r++) { + Nd4jLong posY[] = {r, 0}; + Nd4jLong posX[] = {r, r}; + auto xIndex = shape::getOffset(leftInputShape, posX, 0); + auto yIndex = shape::getOffset(rightInputShape, posY, 0); + auto zIndex = shape::getOffset(outputShape, posY, 0); + + auto sum = rightInput[yIndex]; + for (auto c = 0; c < r; c++) { + Nd4jLong posZ[] = {c, 0}; + Nd4jLong pos[] = {r, c}; + auto xcIndex = shape::getOffset(leftInputShape, pos, 0); + auto zcIndex = shape::getOffset(outputShape, posZ, 0); + sum -= leftInput[xcIndex] * output[zcIndex]; + } + output[zIndex] = sum / leftInput[xIndex]; + } + } + + /* + * upper triangular process for system of linear equations + * x_M = b_M/a_M,M + * x_M-1 = (b_M-1 - a_M-1,M-2 * x_M) / a_M-1,M-1 + * x_M-2 = (b_M-2 - a_M-2,M-3 * x_M-2 - a_M-2,M-1 * x_M) / a_3,3 + * ... + * x_1 = (b_1 - a_1,2 * x_2 - ... a_1,M * x_M)/ a_1,1 + * + * output == x + * a == leftInput + * b == rightInput + * + * */ + + template + static __device__ void upperTriangularSolve(T const* leftInput, Nd4jLong const* leftInputShape, + T const* rightInput, Nd4jLong const* rightInputShape, bool const adjoint, T* output, + Nd4jLong* outputShape, Nd4jLong rows) { + + for (auto r = rows; r > 0; r--) { + Nd4jLong posY[] = {r - 1, 0}; + Nd4jLong posX[] = {r - 1, r - 1}; + auto xIndex = shape::getOffset(leftInputShape, posX, 0); + auto yIndex = shape::getOffset(rightInputShape, posY, 0); + auto zIndex = shape::getOffset(outputShape, posY, 0); + auto sum = rightInput[yIndex]; + for (auto c = r; c < rows; c++) { + Nd4jLong posZ[] = {c, 0}; + Nd4jLong pos[] = {r - 1, c}; + auto zcIndex = shape::getOffset(outputShape, posZ, 0); + auto xcIndex = shape::getOffset(leftInputShape, pos, 0); + sum -= leftInput[xcIndex] * output[zcIndex]; + } + output[zIndex] = sum / leftInput[xIndex]; + } + } + + template + static __global__ void triangularSolveKernel(T const* leftInput, Nd4jLong const* leftPartShape, + T const* rightInput, Nd4jLong const* rightPartShape, bool const lower, bool const adjoint, T* output, + Nd4jLong* outputShape, Nd4jLong* tadLeftShape, Nd4jLong* tadLeftOffset, Nd4jLong* tadRightShape, + Nd4jLong* tadRightOffset, Nd4jLong* tadOutputShape, Nd4jLong* tadOutputOffset, Nd4jLong batchNum) { + + __shared__ Nd4jLong rows; + if (threadIdx.x == 0) { + rows = shape::sizeAt(leftPartShape, -2); + } + __syncthreads(); + + auto start = blockIdx.x * blockDim.x + threadIdx.x; + auto stop = batchNum; + auto increment = blockDim.x * gridDim.x; + + for (auto i = start; i < stop; i += increment) { + auto pLeftPart = leftInput + tadLeftOffset[i]; + auto pRightPart = rightInput + tadRightOffset[i]; + auto pOutputPart = output + tadOutputOffset[i]; + if (lower) { + lowerTriangularSolve(pLeftPart, tadLeftShape, pRightPart, tadRightShape, adjoint, pOutputPart, tadOutputShape, rows); + } else { + upperTriangularSolve(pLeftPart, tadLeftShape, pRightPart, tadRightShape, adjoint, pOutputPart, tadOutputShape, rows); + } + } + } + + template + static int triangularSolveFunctor_(nd4j::LaunchContext * context, NDArray* leftInput, NDArray* rightInput, + bool lower, bool adjoint, NDArray* output) { + NDArray::prepareSpecialUse({output}, {leftInput, rightInput}); + auto leftTads = ConstantTadHelper::getInstance()->tadForDimensions(leftInput->getShapeInfo(), {-2, -1}); + auto rightTads = ConstantTadHelper::getInstance()->tadForDimensions(rightInput->getShapeInfo(), {-2, -1}); + auto outputTads = ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), {-2, -1}); + + auto stream = context->getCudaStream(); + T const* leftBuf = reinterpret_cast(leftInput->getSpecialBuffer()); + T const* rightBuf = reinterpret_cast(rightInput->getSpecialBuffer()); + T* outputBuf = reinterpret_cast(output->specialBuffer()); + triangularSolveKernel<<<128, 128, 256, *stream>>>(leftBuf, leftInput->getSpecialShapeInfo(), + rightBuf, rightInput->getSpecialShapeInfo(), lower, adjoint, outputBuf, output->specialShapeInfo(), + leftTads.specialShapeInfo(), leftTads.specialOffsets(), rightTads.specialShapeInfo(), + rightTads.specialOffsets(), outputTads.specialShapeInfo(), outputTads.specialOffsets(), + leftTads.numberOfTads()); + + NDArray::registerSpecialUse({output}, {leftInput, rightInput}); + + return Status::OK(); + + } + + int triangularSolveFunctor(nd4j::LaunchContext * context, NDArray* leftInput, NDArray* rightInput, bool lower, bool adjoint, NDArray* output) { + BUILD_SINGLE_SELECTOR(leftInput->dataType(), return triangularSolveFunctor_, (context, leftInput, rightInput, lower, adjoint, output), FLOAT_NATIVE); + } + + template + static __global__ void upperAdjointKernel(T const* input, T* output, + Nd4jLong batchSize, Nd4jLong rows, Nd4jLong columns, + Nd4jLong* inputTads, Nd4jLong* inputOffsets, Nd4jLong* outputTads, Nd4jLong* outputOffsets) { + + for (auto b = blockIdx.x; b < batchSize; b += gridDim.x) { + auto inputPart = input + inputOffsets[b]; + auto outputPart = output + outputOffsets[b]; + for (auto r = threadIdx.x; r < rows; r += blockDim.x) { + for (auto c = threadIdx.y; c <= r; c += blockDim.y) { + Nd4jLong zPos[] = {r, c}; + Nd4jLong xPos[] = {c, r}; + auto zIndex = shape::getOffset(outputTads, zPos); + auto xIndex = shape::getOffset(inputTads, xPos); + outputPart[zIndex] = inputPart[xIndex]; + } + } + } + + } + + template + static __global__ void lowerAdjointKernel(T const* input, T* output, + Nd4jLong batchSize, Nd4jLong rows, Nd4jLong columns, + Nd4jLong* inputTads, Nd4jLong* inputOffsets, Nd4jLong* outputTads, Nd4jLong* outputOffsets) { + + for (auto b = blockIdx.x; b < batchSize; b += gridDim.x) { + auto inputPart = input + inputOffsets[b]; + auto outputPart = output + outputOffsets[b]; + for (auto r = threadIdx.x; r < rows; r += blockDim.x) { + for (auto c = r + threadIdx.y; c < columns; c += blockDim.y) { + Nd4jLong zPos[] = {r, c}; + Nd4jLong xPos[] = {c, r}; + auto zIndex = shape::getOffset(outputTads, zPos); + auto xIndex = shape::getOffset(inputTads, xPos); + outputPart[zIndex] = inputPart[xIndex]; + } + } + } + } + + template + static void adjointTriangularMatrix_(nd4j::LaunchContext* context, NDArray const* input, bool const lower, + NDArray* output) { + + auto inputTads = ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), {-2, -1}); + auto outputTads = ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), {-2, -1}); + auto stream = context->getCudaStream(); + auto inputBuf = reinterpret_cast(input->getSpecialBuffer()); + auto outputBuf = reinterpret_cast(output->specialBuffer()); + auto rows = input->sizeAt(-2); + auto columns = input->sizeAt(-1); + + if (lower) { + lowerAdjointKernel<<<128, 256, 256, *stream>>>(inputBuf, outputBuf, outputTads.numberOfTads(), rows, columns, inputTads.specialShapeInfo(), inputTads.specialOffsets(), outputTads.specialShapeInfo(), outputTads.specialOffsets()); + } else { + upperAdjointKernel<<<128, 256, 256, *stream>>>(inputBuf, outputBuf, outputTads.numberOfTads(), rows, columns, inputTads.specialShapeInfo(), inputTads.specialOffsets(), outputTads.specialShapeInfo(), outputTads.specialOffsets()); + } + } + + void adjointMatrix(nd4j::LaunchContext* context, NDArray const* input, bool const lower, NDArray* output) { + BUILD_SINGLE_SELECTOR(input->dataType(), adjointTriangularMatrix_, (context, input, lower, output), FLOAT_NATIVE); + } + + } + } +} diff --git a/libnd4j/include/ops/declarable/helpers/helpers.h b/libnd4j/include/ops/declarable/helpers/helpers.h index f2e19063e..f3aebc7b7 100644 --- a/libnd4j/include/ops/declarable/helpers/helpers.h +++ b/libnd4j/include/ops/declarable/helpers/helpers.h @@ -41,6 +41,9 @@ #include #include #include + +#include + #endif // CUDACC #endif // LIBND4J_HELPERS_H diff --git a/libnd4j/include/ops/declarable/helpers/image_resize.h b/libnd4j/include/ops/declarable/helpers/image_resize.h index d52fd74f7..047b2cf70 100644 --- a/libnd4j/include/ops/declarable/helpers/image_resize.h +++ b/libnd4j/include/ops/declarable/helpers/image_resize.h @@ -45,6 +45,9 @@ namespace helpers { bool preserveAspectRatio, bool antialias, NDArray* output); int resizeBicubicFunctorA(nd4j::LaunchContext * context, NDArray const* image, int const width, int const height, bool const alignCorners, bool const halfPixelAlign, NDArray* output); + int resizeAreaFunctor(nd4j::LaunchContext * context, NDArray const* image, int const width, int const height, + bool const alignCorners, NDArray* output); + int resizeFunctor(nd4j::LaunchContext * context, NDArray const* image, int const width, int const height, ImageResizeMethods method, bool preserveAspectRatio, bool antialias, NDArray* output); diff --git a/libnd4j/include/ops/declarable/helpers/imagesHelpers.h b/libnd4j/include/ops/declarable/helpers/imagesHelpers.h new file mode 100644 index 000000000..0ae8ba072 --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/imagesHelpers.h @@ -0,0 +1,50 @@ +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Oleh Semeniv (oleg.semeniv@gmail.com) +// +// +// @author AbdelRauf (rauf@konduit.ai) +// + +#ifndef LIBND4J_HELPERS_IMAGES_H +#define LIBND4J_HELPERS_IMAGES_H + +#include +#include +#include + +namespace nd4j { +namespace ops { +namespace helpers { + + void transformRgbGrs(nd4j::LaunchContext* context, const NDArray& input, NDArray& output, const int dimC); + + void transformHsvRgb(nd4j::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC); + + void transformRgbHsv(nd4j::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC); + void transformYuvRgb(nd4j::LaunchContext* context, const NDArray& input, NDArray& output, const int dimC); + void transformRgbYuv(nd4j::LaunchContext* context, const NDArray& input, NDArray& output, const int dimC); + + void transformYiqRgb(nd4j::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC); + + void transformRgbYiq(nd4j::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC); +} +} +} + +#endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/impl/choose.cpp b/libnd4j/include/ops/declarable/helpers/impl/choose.cpp index 4fb32e2f8..a75298af6 100644 --- a/libnd4j/include/ops/declarable/helpers/impl/choose.cpp +++ b/libnd4j/include/ops/declarable/helpers/impl/choose.cpp @@ -46,7 +46,7 @@ namespace helpers { // nd4j::NDArray comp1 = *comp; for (Nd4jLong i = 0; i < arg->lengthOf(); i++) { T result2 = processElementCondition(mode, arg->e(i), comp->e(0)); - if(result2 > 0) { + if(result2 > static_cast(0)) { if (output != nullptr) output->p(numResults, arg->e(i)); numResults++; @@ -59,7 +59,7 @@ namespace helpers { nd4j::NDArray arg1 = *arg; for (Nd4jLong i = 0; i < arg->lengthOf(); i++) { T result2 = processElementCondition(mode, arg->e(i), comp->e(i)); - if(result2 > 0) { + if(result2 > static_cast(0)) { if (output != nullptr) output->p(numResults, arg->e(i)); numResults++; @@ -74,7 +74,7 @@ namespace helpers { //for comparison for (Nd4jLong i = 0; i < arg->lengthOf(); i++) { T result2 = processElementCondition(mode, arg->e(i), compScalar.e(0)); - if(result2 > 0) { + if(result2 > static_cast(0)) { if (output != nullptr) output->p(numResults, arg->e(i)); numResults++; diff --git a/libnd4j/include/ops/declarable/helpers/impl/lstmLayer.cpp b/libnd4j/include/ops/declarable/helpers/impl/lstmLayer.cpp index 528642bb6..2b65d0c8e 100644 --- a/libnd4j/include/ops/declarable/helpers/impl/lstmLayer.cpp +++ b/libnd4j/include/ops/declarable/helpers/impl/lstmLayer.cpp @@ -130,7 +130,7 @@ void lstmLayerCell(const NDArray* x, const NDArray* Wx, const NDArray* Wr, // if clipping value is non-zero then cell state is clipped by this value prior to the cell output activation if(params[2] != 0) - c->applyScalar(scalar::LstmClip, params[2]); + c->applyScalar(scalar::LstmClip, params[2], *c); // peephole connections for output gate if(Wp != nullptr) @@ -206,22 +206,22 @@ void lstmLayerTimeLoop(const NDArray* x, const NDArray* Wx, const NDArray* Wr, dims = ShapeUtils::evalDimsToExclude(x->rankOf(), {dataFormat < 3 ? dataFormat : 0}); // points on bS and nIn/nOut axes - xSet = x->allTensorsAlongDimension(dims); // sub-arrays with shape [bS, nIn] + xSet = new ResultSet(x->allTensorsAlongDimension(dims)); // sub-arrays with shape [bS, nIn] if(h) - hSet = h->allTensorsAlongDimension(dims); // sub-arrays with shape [bS, nOut] + hSet = new ResultSet(h->allTensorsAlongDimension(dims)); // sub-arrays with shape [bS, nOut] } else { dims = dataFormat == 2 ? std::vector({1}) : std::vector({2}); // points on nIn/nOut axis - xSet = x->allTensorsAlongDimension(dims); // sub-arrays with shape [nIn] - h0Set = h0->allTensorsAlongDimension({1}); // sub-arrays with shape [nOut] - c0Set = c0->allTensorsAlongDimension({1}); // sub-arrays with shape [nOut] - ctSet = ct->allTensorsAlongDimension({1}); // sub-arrays with shape [nOut] + xSet = new ResultSet(x->allTensorsAlongDimension(dims)); // sub-arrays with shape [nIn] + h0Set = new ResultSet(h0->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut] + c0Set = new ResultSet(c0->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut] + ctSet = new ResultSet(ct->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut] if(h) - hSet = h->allTensorsAlongDimension(dims); // sub-arrays with shape [nOut] + hSet = new ResultSet(h->allTensorsAlongDimension(dims)); // sub-arrays with shape [nOut] if(ht) - htSet = ht->allTensorsAlongDimension({1}); // sub-arrays with shape [nOut] + htSet = new ResultSet(ht->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut] } // loops diff --git a/libnd4j/include/ops/declarable/helpers/impl/rnn.cpp b/libnd4j/include/ops/declarable/helpers/impl/rnn.cpp index 179c7efab..3c65f740d 100644 --- a/libnd4j/include/ops/declarable/helpers/impl/rnn.cpp +++ b/libnd4j/include/ops/declarable/helpers/impl/rnn.cpp @@ -42,7 +42,7 @@ void rnnCell(nd4j::LaunchContext * context, const NDArray* xt, const NDArray* Wx // ht is current cell output [bS x nU], that is at current time step t ht->assign(mmul(*xt, *Wx) + (*b)({{0, nU}}) + mmul(*hPrev, *Wh) + (*b)({{nU, 2*nU}})); // [bS x nU] + [nU] + [bS x nU] + [nU] = [bS x nU] - ht->applyTransform(transform::Tanh); + ht->applyTransform(transform::Tanh, *ht); } ////////////////////////////////////////////////////////////////////////// diff --git a/libnd4j/include/ops/declarable/helpers/impl/sparse_to_dense.cpp b/libnd4j/include/ops/declarable/helpers/impl/sparse_to_dense.cpp new file mode 100644 index 000000000..e21499314 --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/impl/sparse_to_dense.cpp @@ -0,0 +1,123 @@ +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + + +#include +#include +#include + +namespace nd4j { + namespace ops { + namespace helpers { + template + static void fill_(const void *vvalues, const void *vindices, void *voutput, const Nd4jLong *zShapeInfo, uint8_t rank, uint64_t length) { + auto values = reinterpret_cast(vvalues); + auto indices = reinterpret_cast(vindices); + auto output = reinterpret_cast(voutput); + + Nd4jLong coords[MAX_RANK]; + uint64_t pos = 0; + for (uint64_t e = 0L; e < length; e++) { + // indices come in blocks + for (uint8_t p = 0; p < rank; p++) { + coords[p] = indices[pos++]; + } + + // fill output at given coords with sparse value + output[shape::getOffset(zShapeInfo, coords)] = values[e]; + } + + } + + void compat_sparse_to_dense(const NDArray &values, const NDArray &indices, NDArray *def, NDArray &output) { + // make sure host buffer is updated + values.syncToHost(); + indices.syncToHost(); + + auto rank = output.rankOf(); + + if (output.isS()) { + // string case is not so trivial, since elements might, and probably will, have different sizes + auto numValues = values.lengthOf(); + auto numElements = output.lengthOf(); + + // first of all we calculate final buffer sizes and offsets + auto defaultLength = def == nullptr ? 0 : StringUtils::byteLength(*def); + auto valuesLength = StringUtils::byteLength(values); + auto bufferLength = defaultLength * (output.lengthOf() - numValues) + valuesLength; + auto headerLength = ShapeUtils::stringBufferHeaderRequirements(numElements); + + // now we make sure our output buffer can hold results + output.dataBuffer()->expand( bufferLength + headerLength); + + std::vector outputCoords(rank); + std::vector valueCoords(rank); + + auto offsetsBuffer = output.bufferAsT(); + auto dataBuffer = reinterpret_cast(offsetsBuffer + output.lengthOf()); + + offsetsBuffer[0] = 0; + + // getting initial value coords + for (int e = 0; e < rank; e++) + valueCoords[e] = indices.e(e); + + // write results individually + for (uint64_t e = 0; e < numElements; e++) { + auto vIndex = shape::coords2index(output.shapeInfo(), valueCoords.data()); + auto cLength = 0L; + std::string str; + if (vIndex == e) { + // we're writing down sparse value here + str = values.e(e); + } else { + // we're writing down default value if it exists + if (def != nullptr) + str = def->e(0); + else + str = ""; + } + + // TODO: make it unicode compliant + memcpy(&dataBuffer[offsetsBuffer[e]], str.c_str(), str.length()); + + // writing down offset + offsetsBuffer[e+1] = cLength; + } + } else { + // numeric case is trivial, since all elements have equal sizes + + // write out default values, if they are present + if (def != nullptr) { + output.assign(def); + + // make sure output is synced back + output.syncToHost(); + } + + // write out values + BUILD_DOUBLE_SELECTOR(values.dataType(), indices.dataType(), fill_, (values.getBuffer(), indices.getBuffer(), output.buffer(), output.getShapeInfo(), rank, values.lengthOf()), LIBND4J_TYPES, INDEXING_TYPES); + } + // copy back to device, if there's any + output.syncToDevice(); + } + } + } +} \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/lgamma.h b/libnd4j/include/ops/declarable/helpers/lgamma.h new file mode 100644 index 000000000..48bcf1d73 --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/lgamma.h @@ -0,0 +1,40 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author George A. Shulinok +// + +#ifndef __LIBND4J_L_GAMMA__H__ +#define __LIBND4J_L_GAMMA__H__ + +#include +#include "NDArray.h" + +namespace nd4j { +namespace ops { +namespace helpers { + + // calculate the digamma function for each element for array + void lgamma(nd4j::LaunchContext* context, NDArray& x, NDArray& z); + +} +} +} + + +#endif //__LIBND4J_L_GAMMA__H__ diff --git a/libnd4j/include/ops/declarable/helpers/lstm.h b/libnd4j/include/ops/declarable/helpers/lstm.h index 91ca87738..9c0df2fa5 100644 --- a/libnd4j/include/ops/declarable/helpers/lstm.h +++ b/libnd4j/include/ops/declarable/helpers/lstm.h @@ -33,7 +33,7 @@ namespace helpers { } static FORCEINLINE void sigmoidInplace(const NDArray& arr) { - (const_cast(arr)).applyTransform(transform::Sigmoid); + (const_cast(arr)).applyTransform(transform::Sigmoid, const_cast(arr)); } ////////////////////////////////////////////////////////////////////////// @@ -42,7 +42,7 @@ namespace helpers { } static FORCEINLINE void tanhInplace(const NDArray& arr) { - (const_cast(arr)).applyTransform(transform::Tanh); + (const_cast(arr)).applyTransform(transform::Tanh, const_cast(arr)); } ////////////////////////////////////////////////////////////////////////// diff --git a/libnd4j/include/ops/declarable/helpers/lstmLayer.h b/libnd4j/include/ops/declarable/helpers/lstmLayer.h index d0bc16b66..a52c2c0e5 100644 --- a/libnd4j/include/ops/declarable/helpers/lstmLayer.h +++ b/libnd4j/include/ops/declarable/helpers/lstmLayer.h @@ -46,41 +46,41 @@ static FORCEINLINE void applyActivation(NDArray& x, const int opId, const float switch (opId) { case 0: - (const_cast(x)).applyTransform(transform::Tanh, &z); + (const_cast(x)).applyTransform(transform::Tanh, z); break; case 1: - (const_cast(x)).applyScalar(scalar::RELU, 0, &z); + (const_cast(x)).applyScalar(scalar::RELU, 0, z); break; case 2: - (const_cast(x)).applyTransform(transform::Sigmoid, &z); + (const_cast(x)).applyTransform(transform::Sigmoid, z); break; case 3: { ExtraArguments args({ static_cast(alpha), static_cast(beta)}); - (const_cast(x)).applyTransform(transform::Affine, &z, &args); + (const_cast(x)).applyTransform(transform::Affine, z, &args); break; } case 4: - (const_cast(x)).applyScalar(scalar::LeakyRELU, alpha, &z); + (const_cast(x)).applyScalar(scalar::LeakyRELU, alpha, z); break; case 5: helpers::thresholdRelu(x.getContext(), x, alpha, z); break; case 6: { ExtraArguments args({ static_cast(alpha), static_cast(beta)}); - (const_cast(x)).applyTransform(transform::ScaledTanh, &z, &args); + (const_cast(x)).applyTransform(transform::ScaledTanh, z, &args); break; } case 7: - (const_cast(x)).applyTransform(transform::HardSigmoid, &z); + (const_cast(x)).applyTransform(transform::HardSigmoid, z); break; case 8: - (const_cast(x)).applyScalar(scalar::ELU, alpha, &z); + (const_cast(x)).applyScalar(scalar::ELU, alpha, z); break; case 9: - (const_cast(x)).applyTransform(transform::SoftSign, &z); + (const_cast(x)).applyTransform(transform::SoftSign, z); break; case 10: - (const_cast(x)).applyTransform(transform::SoftPlus, &z); + (const_cast(x)).applyTransform(transform::SoftPlus, z); break; default: throw std::invalid_argument("LSTM_LAYER operation: wrong id number of activation !"); diff --git a/libnd4j/include/ops/declarable/helpers/lup.h b/libnd4j/include/ops/declarable/helpers/lup.h index 96ec9bec1..ae10e6136 100644 --- a/libnd4j/include/ops/declarable/helpers/lup.h +++ b/libnd4j/include/ops/declarable/helpers/lup.h @@ -26,9 +26,8 @@ namespace nd4j { namespace ops { namespace helpers { - template - T lup(nd4j::LaunchContext * context, NDArray* input, NDArray* compound, NDArray* permutation); - + int lup(nd4j::LaunchContext* context, NDArray* input, NDArray* lu, NDArray* permutation); + void lu(nd4j::LaunchContext *context, NDArray* input, NDArray* output, NDArray* permutation); int determinant(nd4j::LaunchContext * context, NDArray* input, NDArray* output); int logAbsDeterminant(nd4j::LaunchContext * context, NDArray* input, NDArray* output); diff --git a/nd4j/nd4j-serde/nd4j-camel-routes/nd4j-kafka/src/main/test/java/org/nd4j/kafka/SystemTime.java b/libnd4j/include/ops/declarable/helpers/print_variable.h similarity index 67% rename from nd4j/nd4j-serde/nd4j-camel-routes/nd4j-kafka/src/main/test/java/org/nd4j/kafka/SystemTime.java rename to libnd4j/include/ops/declarable/helpers/print_variable.h index cb2136309..3521e38b9 100644 --- a/nd4j/nd4j-serde/nd4j-camel-routes/nd4j-kafka/src/main/test/java/org/nd4j/kafka/SystemTime.java +++ b/libnd4j/include/ops/declarable/helpers/print_variable.h @@ -14,25 +14,21 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -package org.nd4j.kafka; +// +// @author raver119@gmail.com +// +#ifndef LIBND4J_PRINT_VARIABLE_H +#define LIBND4J_PRINT_VARIABLE_H -import kafka.utils.Time; +#include -class SystemTime implements Time { - public long milliseconds() { - return System.currentTimeMillis(); - } - - public long nanoseconds() { - return System.nanoTime(); - } - - public void sleep(long ms) { - try { - Thread.sleep(ms); - } catch (InterruptedException e) { - // Ignore +namespace nd4j { + namespace ops { + namespace helpers { + void print_special(LaunchContext &ctx, const NDArray &array, const std::string &message = {}); } } } + +#endif //LIBND4J_PRINT_VARIABLE_H diff --git a/libnd4j/include/ops/declarable/helpers/qr.h b/libnd4j/include/ops/declarable/helpers/qr.h new file mode 100644 index 000000000..33649e7c8 --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/qr.h @@ -0,0 +1,35 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author George A. Shulinok +// +#ifndef __QR__H_HELPERS__ +#define __QR__H_HELPERS__ +#include +#include + +namespace nd4j { +namespace ops { +namespace helpers { + + void qr(nd4j::LaunchContext * context, NDArray* input, NDArray* outputQ, NDArray* outputR, bool const fullMatricies); + + +} +} +} +#endif diff --git a/libnd4j/include/ops/declarable/helpers/random.h b/libnd4j/include/ops/declarable/helpers/random.h index db1b8ae53..c97aae118 100644 --- a/libnd4j/include/ops/declarable/helpers/random.h +++ b/libnd4j/include/ops/declarable/helpers/random.h @@ -34,6 +34,7 @@ namespace helpers { void fillRandomGamma(LaunchContext* context, graph::RandomGenerator& rng, NDArray* alpha, NDArray* beta, NDArray* output); void fillRandomPoisson(LaunchContext* context, graph::RandomGenerator& rng, NDArray* lambda, NDArray* output); void fillRandomUniform(LaunchContext* context, graph::RandomGenerator& rng, NDArray* min, NDArray* max, NDArray* output); + void fillRandomMultiNomial(LaunchContext* context, graph::RandomGenerator& rng, NDArray& input, NDArray& output, const Nd4jLong numOfSamples, const int dimC); } } } diff --git a/libnd4j/include/ops/declarable/helpers/sparse_to_dense.h b/libnd4j/include/ops/declarable/helpers/sparse_to_dense.h new file mode 100644 index 000000000..8d00639de --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/sparse_to_dense.h @@ -0,0 +1,34 @@ +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#ifndef SAMEDIFF_SPARSE_TO_DENSE_H +#define SAMEDIFF_SPARSE_TO_DENSE_H + +#include + +namespace nd4j { + namespace ops { + namespace helpers { + void compat_sparse_to_dense(const NDArray &values, const NDArray &indices, NDArray *def, NDArray &output); + } + } +} + +#endif //SAMEDIFF_SPARSE_TO_DENSE_H diff --git a/libnd4j/include/ops/declarable/helpers/triangular_solve.h b/libnd4j/include/ops/declarable/helpers/triangular_solve.h new file mode 100644 index 000000000..a40a3e144 --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/triangular_solve.h @@ -0,0 +1,34 @@ +/******************************************************************************* + * Copyright (c) Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author GS +// +#ifndef __TRIANGULAR_SOLVE__H_HELPERS__ +#define __TRIANGULAR_SOLVE__H_HELPERS__ +#include +#include + +namespace nd4j { +namespace ops { +namespace helpers { + + int triangularSolveFunctor(nd4j::LaunchContext* context, NDArray* leftInput, NDArray* rightInput, bool lower, bool adjoint, NDArray* output); + void adjointMatrix(nd4j::LaunchContext* context, NDArray const* input, bool const lower, NDArray* output); +} +} +} +#endif diff --git a/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp b/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp index 3aef09bcd..8d5cb90d4 100644 --- a/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp +++ b/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp @@ -189,6 +189,11 @@ namespace nd4j { shapeStart = std::chrono::system_clock::now(); } + // if we override shape function, we'll return size of fastPath + if (ctx.isFastPath() && ctx.shapeFunctionOverride()) { + return (int) ctx.fastpath_out().size(); + } + auto outSha = this->calculateOutputShape(&inSha, ctx); results = outSha->size(); @@ -530,8 +535,8 @@ namespace nd4j { // platform helpers use might be forbidden for various reasons, so we'll check it out first if (block->helpersAllowed() && nd4j::Environment::getInstance()->helpersAllowed()) { // if we have platform-specific helper for this op - invoke it - if (OpRegistrator::getInstance()->hasHelper(this->getOpHash())) { - auto helper = OpRegistrator::getInstance()->getPlatformHelper(this->getOpHash()); + if (OpRegistrator::getInstance()->hasHelper(this->getOpHash(), block->engine())) { + auto helper = OpRegistrator::getInstance()->getPlatformHelper(this->getOpHash(), block->engine()); if (helper->isUsable(*block)) { status = helper->invokeHelper(*block); hasHelper = true; diff --git a/libnd4j/include/ops/declarable/impl/LegacyScalarBoolOp.cpp b/libnd4j/include/ops/declarable/impl/LegacyScalarBoolOp.cpp index 3e35e2c11..040cde77c 100644 --- a/libnd4j/include/ops/declarable/impl/LegacyScalarBoolOp.cpp +++ b/libnd4j/include/ops/declarable/impl/LegacyScalarBoolOp.cpp @@ -38,7 +38,7 @@ namespace nd4j { } LegacyScalarBoolOp::LegacyScalarBoolOp(int opNum, NDArray &scalar) : LegacyOp::LegacyOp(1, opNum){ - _scalar = scalar.dup(scalar.ordering()); + _scalar = new NDArray(scalar.dup(scalar.ordering())); } ShapeList *LegacyScalarBoolOp::calculateOutputShape(ShapeList *inputShape, nd4j::graph::Context &block) { diff --git a/libnd4j/include/ops/declarable/impl/LegacyScalarOp.cpp b/libnd4j/include/ops/declarable/impl/LegacyScalarOp.cpp index 581bdae4c..b1261b37c 100644 --- a/libnd4j/include/ops/declarable/impl/LegacyScalarOp.cpp +++ b/libnd4j/include/ops/declarable/impl/LegacyScalarOp.cpp @@ -38,7 +38,7 @@ namespace nd4j { } LegacyScalarOp::LegacyScalarOp(int opNum, NDArray &scalar) : LegacyOp::LegacyOp(1, opNum){ - _scalar = scalar.dup(scalar.ordering()); + _scalar = new NDArray(scalar.dup(scalar.ordering())); } ShapeList *LegacyScalarOp::calculateOutputShape(ShapeList *inputShape, nd4j::graph::Context &block) { @@ -69,9 +69,9 @@ namespace nd4j { } else if (block.getTArguments()->size() > 0) { auto y = NDArrayFactory::create(x->dataType(), T_ARG(0), block.launchContext()); - NDArray::prepareSpecialUse({z}, {x, &y}); - - NativeOpExecutioner::execScalar(block.launchContext(), opNum, x->getBuffer(), x->getShapeInfo(), x->specialBuffer(), x->specialShapeInfo(), z->getBuffer(), z->getShapeInfo(), z->specialBuffer(), z->specialShapeInfo(), y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), extras.argumentsAsT(z->dataType(), 1)); + x->applyScalarArr(static_cast(opNum), y, *z); + // NDArray::prepareSpecialUse({z}, {x, &y}); + // NativeOpExecutioner::execScalar(block.launchContext(), opNum, x->getBuffer(), x->getShapeInfo(), x->specialBuffer(), x->specialShapeInfo(), z->getBuffer(), z->getShapeInfo(), z->specialBuffer(), z->specialShapeInfo(), y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), extras.argumentsAsT(z->dataType(), 1)); manager.synchronize(); } else { diff --git a/libnd4j/include/ops/declarable/impl/OpRegistrator.cpp b/libnd4j/include/ops/declarable/impl/OpRegistrator.cpp index a42203162..09e4ec58f 100644 --- a/libnd4j/include/ops/declarable/impl/OpRegistrator.cpp +++ b/libnd4j/include/ops/declarable/impl/OpRegistrator.cpp @@ -173,15 +173,18 @@ namespace nd4j { } void OpRegistrator::registerHelper(nd4j::ops::platforms::PlatformHelper* op) { - if (_helpersLH.count(op->hash()) > 0) + std::pair p = {op->hash(), op->engine()}; + if (_helpersLH.count(p) > 0) throw std::runtime_error("Tried to double register PlatformHelper"); _uniqueH.emplace_back(op); - std::pair pair(op->name(), op); + nd4j_debug("Adding helper for op \"%s\": [%lld - %i]\n", op->name().c_str(), op->hash(), (int) op->engine()); + + std::pair, nd4j::ops::platforms::PlatformHelper*> pair({op->name(), op->engine()}, op); _helpersH.insert(pair); - std::pair pair2(op->hash(), op); + std::pair, nd4j::ops::platforms::PlatformHelper*> pair2(p, op); _helpersLH.insert(pair2); } @@ -227,15 +230,17 @@ namespace nd4j { return _declarablesD.at(name); } - nd4j::ops::platforms::PlatformHelper* OpRegistrator::getPlatformHelper(Nd4jLong hash) { - if (_helpersLH.count(hash) == 0) + nd4j::ops::platforms::PlatformHelper* OpRegistrator::getPlatformHelper(Nd4jLong hash, samediff::Engine engine) { + std::pair p = {hash, engine}; + if (_helpersLH.count(p) == 0) throw std::runtime_error("Requested helper can't be found"); - return _helpersLH[hash]; + return _helpersLH[p]; } - bool OpRegistrator::hasHelper(Nd4jLong hash) { - return _helpersLH.count(hash) > 0; + bool OpRegistrator::hasHelper(Nd4jLong hash, samediff::Engine engine) { + std::pair p = {hash, engine}; + return _helpersLH.count(p) > 0; } int OpRegistrator::numberOfOperations() { diff --git a/libnd4j/include/ops/declarable/impl/PlatformHelper.cpp b/libnd4j/include/ops/declarable/impl/PlatformHelper.cpp index 75dc6e2c4..86c84b0fb 100644 --- a/libnd4j/include/ops/declarable/impl/PlatformHelper.cpp +++ b/libnd4j/include/ops/declarable/impl/PlatformHelper.cpp @@ -24,10 +24,11 @@ namespace nd4j { namespace ops { namespace platforms { - PlatformHelper::PlatformHelper(const char *name) { + PlatformHelper::PlatformHelper(const char *name, samediff::Engine engine) { // we just store name/hash of target operation _name = std::string(name); _hash = HashHelper::getInstance()->getLongHash(_name); + _engine = engine; } nd4j::NDArray *PlatformHelper::getZ(graph::Context &ctx, int inputId) { @@ -74,6 +75,10 @@ namespace nd4j { return z; } + samediff::Engine PlatformHelper::engine() { + return _engine; + } + std::string PlatformHelper::name() { return _name; } diff --git a/libnd4j/include/ops/declarable/platform/cudnn/batchnorm.cu b/libnd4j/include/ops/declarable/platform/cudnn/batchnorm.cu new file mode 100644 index 000000000..3bd1357bf --- /dev/null +++ b/libnd4j/include/ops/declarable/platform/cudnn/batchnorm.cu @@ -0,0 +1,275 @@ +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Yurii Shyrma (iuriish@yahoo.com) +// + + +#include "cudnnUtils.h" +#include + +namespace nd4j { +namespace ops { +namespace platforms { + +////////////////////////////////////////////////////////////////////////// +static void batchnormCUDNN(const LaunchContext* context, + const NDArray* input, const NDArray* mean, const NDArray* variance, + const NDArray* gamma, const NDArray* beta, + NDArray* output, + const double epsilon, const bool isSpatialMode) { + + + // input, output -> 4D:nchw, 5D:ncdhw + // mean, variance, gamma, beta -> 1xCx1x1 for 4D and 1xCx1x1x1 for 5D for BATCHNORM_MODE_SPATIAL mode + // -> 1xCxHxW for 4D and 1xCxDxHxW for 5D for BATCHNORM_MODE_PER_ACTIVATION mode + + const cudnnDataType_t dataType = cudnnDataType(input->dataType()); + + const int xRank = input->rankOf(); + + auto handle = reinterpret_cast(context->getCuDnnHandle()); + cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream()); + if (err != 0) throw nd4j::cuda_exception::build("conv2dCUDNN: can't set stream for cuDNN", err); + + const std::vector xShape = input->getShapeAsVectorInt(); // input and output have same shapes + + std::vector paramsShape, paramsStrides; // mean, variance, gamma and beta have same shapes + if(isSpatialMode) { // 1xCx1x1 + const int iC = mean->lengthOf(); + const int stride0 = mean->strideAt(0); + paramsShape = xRank == 4 ? std::vector({1, iC, 1, 1}) : std::vector({1, iC, 1, 1, 1}); + paramsStrides = xRank == 4 ? std::vector({iC*stride0, stride0, 1, 1}) : std::vector({iC*stride0, stride0, 1, 1, 1}); + } + else { + paramsShape = mean->getShapeAsVectorInt(); + paramsStrides = xRank == 4 ? std::vector({(int)mean->strideAt(0), (int)mean->strideAt(1), (int)mean->strideAt(2), (int)mean->strideAt(3)}) : std::vector({(int)mean->strideAt(0), (int)mean->strideAt(1), (int)mean->strideAt(2), (int)mean->strideAt(3), (int)mean->strideAt(4)}); + } + + std::vector xStrides = {(int)input->strideAt(0), (int)input->strideAt(1), (int)input->strideAt(2), (int)input->strideAt(3)}; + std::vector zStrides = {(int)output->strideAt(0), (int)output->strideAt(1), (int)output->strideAt(2), (int)output->strideAt(3)}; + + if(xRank > 4) { // 5D + xStrides.push_back((int)input->strideAt(4)); + zStrides.push_back((int)output->strideAt(4)); + } + + cudnnTensorFormat_t format = CUDNN_TENSOR_NCHW; + + // input descriptor + cudnnTensorDescriptor_t x; + cudnnCreateTensorDescriptor(&x); + if(input->ews() == 1) + err = cudnnSetTensorNdDescriptorEx(x, format, dataType, xRank, xShape.data()); + else + err = cudnnSetTensorNdDescriptor(x, dataType, xRank, xShape.data(), xStrides.data()); + if (err != 0) throw nd4j::cuda_exception::build("batchnormCUDNN: cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for input failed", err); + + // output descriptor + cudnnTensorDescriptor_t z; + cudnnCreateTensorDescriptor(&z); + if(output->ews() == 1) + err = cudnnSetTensorNdDescriptorEx(z, format, dataType, xRank, xShape.data()); + else + err = cudnnSetTensorNdDescriptor(z, dataType, xRank, xShape.data(), zStrides.data()); + if (err != 0) throw nd4j::cuda_exception::build("batchnormCUDNN: cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for output failed", err); + + // mean, variance, gamma and beta descriptor, the same descriptor for all of them + cudnnTensorDescriptor_t params; + cudnnCreateTensorDescriptor(¶ms); + if(mean->ews() == 1) + err = cudnnSetTensorNdDescriptorEx(params, format, dataType, xRank, paramsShape.data()); + else + err = cudnnSetTensorNdDescriptor(params, dataType, xRank, paramsShape.data(), paramsStrides.data()); + if (err != 0) throw nd4j::cuda_exception::build("batchnormCUDNN: cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for mean/variance/gamma/beta failed", err); + + + if (err != 0) throw nd4j::cuda_exception::build("batchnormCUDNN: cudnnSetConvolutionNdDescriptor failed", err); + + // provide scaling parameters + const float alpha32(1), beta32(0); + const double alpha64(1), beta64(0); + const void* ptrAlpha = output->sizeOfT() <= 4 ? reinterpret_cast(&alpha32) : reinterpret_cast(&alpha64); + const void* ptrBeta = output->sizeOfT() <= 4 ? reinterpret_cast(&beta32) : reinterpret_cast(&beta64); + + NDArray::prepareSpecialUse({output}, {input, mean, variance, gamma, beta}); + + // calculations + err = cudnnBatchNormalizationForwardInference(*handle, isSpatialMode ? CUDNN_BATCHNORM_SPATIAL : CUDNN_BATCHNORM_PER_ACTIVATION, + ptrAlpha, ptrBeta, + x, input->getSpecialBuffer(), + z, output->getSpecialBuffer(), + params, + gamma ? gamma->getSpecialBuffer(): nullptr, + beta ? beta->getSpecialBuffer() : nullptr, + mean->getSpecialBuffer(), variance->getSpecialBuffer(), epsilon); + + if (err != 0) throw nd4j::cuda_exception::build("batchnormCUDNN: cudnnBatchNormalizationForwardInference failed", err); + + // cudaErr = cudaStreamSynchronize(*context->getCudaStream()); + // if (cudaErr != 0) + // throw cuda_exception::build("batchnormCUDNN: cudaStreamSynchronize failed !", cudaErr); + + + NDArray::registerSpecialUse({output}, {input, mean, variance, gamma, beta}); +} + + +////////////////////////////////////////////////////////////////////////// +PLATFORM_IMPL(batchnorm, ENGINE_CUDA) { + + auto input = INPUT_VARIABLE(0); + auto mean = INPUT_VARIABLE(1); + auto variance = INPUT_VARIABLE(2); + NDArray* gamma = nullptr; + NDArray* beta = nullptr; + + auto output = OUTPUT_VARIABLE(0); + + const bool applyScale = (bool)INT_ARG(0); + const bool applyOffset = (bool)INT_ARG(1); + const double epsilon = T_ARG(0); + + if(applyScale) + gamma = INPUT_VARIABLE(3); + if(applyOffset) + beta = INPUT_VARIABLE(3 + (int)applyScale); + + const int numOfIntArgs = block.getIArguments()->size(); + const int inRank = input->rankOf(); + + // get axes args to normalize input array over + std::vector axes; + if(numOfIntArgs > 2) + for(int i = 2; i < numOfIntArgs; ++i) + axes.push_back(INT_ARG(i)); + else + axes.push_back(inRank-1); // default dimension to reduce along is last dimension + + const int numOfAxes = axes.size(); + REQUIRE_TRUE(numOfAxes <= inRank, 0, "BATCHNORM CUDNN op: too big number of input axes to normalize over, expected number should be less or equal to rank of input array, but got %i and %i correspondingly !", numOfAxes, inRank); + + // evaluate expected shape for mean, variance and gamma. These 3 arrays should have identical shapes + // for example if input shape is {2,3,4,5,6} and axes = {1,3}, then expected shape would be {1,3,1,5,1}, and if axes = {3}, then expected shape would be {5} + std::vector expShape; + if(numOfAxes == 1) + expShape.push_back(input->sizeAt(axes[0])); + else { // get, for example, something like {1, inputDim1, 1, inputDim3, 1} if axes = {1, 3} + expShape = std::vector(inRank, 1); + for(uint i = 0; i < numOfAxes; ++i) + expShape[axes[i]] = input->sizeAt(axes[i]); + } + + REQUIRE_TRUE(mean->isSameShape(expShape) , 0, "BATCHNORM CUDNN op: wrong shape of mean array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(mean).c_str()); + REQUIRE_TRUE(variance->isSameShape(expShape), 0, "BATCHNORM CUDNN op: wrong shape of variance array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(variance).c_str()); + if(gamma) + REQUIRE_TRUE(gamma->isSameShape(expShape), 0, "BATCHNORM CUDNN op: wrong shape of gamma array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(gamma).c_str()); + if(beta) + REQUIRE_TRUE(beta->isSameShape(expShape), 0, "BATCHNORM CUDNN op: wrong shape of beta array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(beta).c_str()); + + // types of all input arrays should be the same + for(int i = 1; i < block.width(); ++i) + REQUIRE_TRUE(INPUT_VARIABLE(0)->dataType() == INPUT_VARIABLE(i)->dataType(), 0, "BATCHNORM CUDNN op: types of all input arrays should be the same !"); + + // cudnn supports NCHW format only + const bool needPermut = axes.size() == 1 && mean->lengthOf() == input->sizeAt(-1); + + if(needPermut) { // if NHWC + std::vector perm = {0, 3, 1, 2}; // NHWC -> NCHW + input = new NDArray(input->permute(perm)); + output = new NDArray(output->permute(perm)); + } + + // calculations + batchnormCUDNN(block.launchContext(), input, mean, variance, gamma, beta, output, epsilon, axes.size() == 1); + + if(needPermut) { + delete input; + delete output; + } + + return Status::OK(); +} + +////////////////////////////////////////////////////////////////////////// +PLATFORM_CHECK(batchnorm, ENGINE_CUDA) { + + const bool applyScale = (bool)INT_ARG(0); + const bool applyOffset = (bool)INT_ARG(1); + + NDArray* input = INPUT_VARIABLE(0); + NDArray* mean = INPUT_VARIABLE(1); + NDArray* variance = INPUT_VARIABLE(2); + NDArray* gamma = applyScale ? INPUT_VARIABLE(3) : nullptr; + NDArray* beta = applyOffset ? INPUT_VARIABLE(3 + (int)applyScale) : nullptr; + + const int numOfIntArgs = block.getIArguments()->size(); + const int xRank = input->rankOf(); + + // disable cudnn batchnorm so far + return false; + + // *********************************** // + if(xRank != 4 && xRank != 5) + return false; + + // *********************************** // + const bool badType = input->dataType() != DataType::DOUBLE && input->dataType() != DataType::FLOAT32 && input->dataType() != DataType::HALF; + if(badType) + return false; + + // *********************************** // + // get axes args to normalize input array over + std::vector axes; + if(numOfIntArgs > 2) + for(int i = 2; i < numOfIntArgs; ++i) + axes.push_back(INT_ARG(i)); + else + axes.push_back(xRank-1); // default dimension to reduce along is last dimension + + if(axes.size() != 1 && axes.size() != 3 && axes.size() != 4) + return false; + + // *********************************** // + bool allParamsHaveSameShapeAndStrides = shape::haveSameShapeAndStrides(mean->getShapeInfo(), variance->getShapeInfo()); + if(gamma) + allParamsHaveSameShapeAndStrides &= shape::haveSameShapeAndStrides(mean->getShapeInfo(), gamma->getShapeInfo()); + if(beta) + allParamsHaveSameShapeAndStrides &= shape::haveSameShapeAndStrides(mean->getShapeInfo(), beta->getShapeInfo()); + + if(!allParamsHaveSameShapeAndStrides) + return false; + + // *********************************** // + bool isFormatGood = false; + if(axes.size() == 1) + isFormatGood = mean->lengthOf() == input->sizeAt(1) || mean->lengthOf() == input->sizeAt(-1); // mean [C] + else { + auto inputShapeModif = input->getShapeAsVector(); // [dim0,dim1,dim2,dim3] 4D or [dim0,dim1,dim2,dim3,dim4] + inputShapeModif[0] = 1; + isFormatGood = mean->isSameShape(inputShapeModif); // mean [1,dim1,dim2,dim3] 4D or [1,dim1,dim2,dim3,dim4] + } + if(!isFormatGood) + return false; + + return true; +} + + +} +} +} diff --git a/libnd4j/include/ops/declarable/platform/cudnn/conv2d.cu b/libnd4j/include/ops/declarable/platform/cudnn/conv2d.cu new file mode 100644 index 000000000..234dbffb7 --- /dev/null +++ b/libnd4j/include/ops/declarable/platform/cudnn/conv2d.cu @@ -0,0 +1,521 @@ +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// @author Yurii Shyrma (iuriish@yahoo.com) +// + + +#include "cudnnUtils.h" +#include + +namespace nd4j { +namespace ops { +namespace platforms { + +////////////////////////////////////////////////////////////////////////// +static void conv2dCUDNN(const LaunchContext* context, + const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, + const int kH, const int kW, + const int sH, const int sW, + const int pH, const int pW, + const int dH, const int dW, + const int paddingMode, const bool isNCHW) { + + int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); + + auto handle = reinterpret_cast(context->getCuDnnHandle()); + cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream()); + if (err != 0) throw nd4j::cuda_exception::build("conv2dCUDNN: can't set stream for cuDNN", err); + + cudnnTensorFormat_t format = isNCHW ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC; + + // input descriptor + cudnnTensorDescriptor_t x; + cudnnCreateTensorDescriptor(&x); + if(input->ews() == 1) + err = cudnnSetTensor4dDescriptor(x, format, cudnnDataType(input->dataType()), bS, iC, iH, iW); + else + err = cudnnSetTensor4dDescriptorEx(x, cudnnDataType(input->dataType()), bS, iC, iH, iW, input->strideAt(0), input->strideAt(indIOioC), input->strideAt(indIiH), input->strideAt(indIiH + 1)); + if (err != 0) throw nd4j::cuda_exception::build("conv2dCUDNN: cudnnSetTensor4dDescriptor/cudnnSetTensor4dDescriptorEx for input failed", err); + + // weights descriptor + cudnnFilterDescriptor_t w; + cudnnCreateFilterDescriptor(&w); + err = cudnnSetFilter4dDescriptor(w, cudnnDataType(weights->dataType()), CUDNN_TENSOR_NCHW, oC, iC, kH, kW); + if(err != 0) throw nd4j::cuda_exception::build("conv2dCUDNN: cudnnSetFilter4dDescriptor failed", err); + + // output descriptor + cudnnTensorDescriptor_t z; + cudnnCreateTensorDescriptor(&z); + if(output->ews() == 1) + err = cudnnSetTensor4dDescriptor(z, format, cudnnDataType(output->dataType()), bS, oC, oH, oW); + else + err = cudnnSetTensor4dDescriptorEx(z, cudnnDataType(output->dataType()), bS, oC, oH, oW, output->strideAt(0), output->strideAt(indIOioC), output->strideAt(indOoH), output->strideAt(indOoH + 1)); + if (err != 0) throw nd4j::cuda_exception::build("conv2dCUDNN: cudnnSetTensor4dDescriptor/cudnnSetTensor4dDescriptorEx for output failed", err); + + // description of convolution + cudnnConvolutionDescriptor_t conv; + cudnnCreateConvolutionDescriptor(&conv); + err = cudnnSetConvolution2dDescriptor(conv, pH, pW, sH, sW, dH, dW, CUDNN_CROSS_CORRELATION, cudnnDataType(output->dataType())); + if (err != 0) throw nd4j::cuda_exception::build("conv2dCUDNN: cudnnSetConvolution2dDescriptor failed", err); + + // algorithm description + cudnnConvolutionFwdAlgo_t algo; + err = cudnnGetConvolutionForwardAlgorithm(*handle, x, w, conv, z, CUDNN_CONVOLUTION_FWD_PREFER_FASTEST, 0, &algo); + if (err != 0) throw nd4j::cuda_exception::build("conv2dCUDNN: cudnnGetConvolutionForwardAlgorithm failed", err); + + + // allocate auxiliary device memory, abbreviation ws means workspace + size_t wsSize; + err = cudnnGetConvolutionForwardWorkspaceSize(*handle, x, w, conv, z, algo, &wsSize); + if (err != 0) throw nd4j::cuda_exception::build("conv2dCUDNN: cudnnGetConvolutionForwardWorkspaceSize failed", err); + void* wsData; + auto cudaErr = cudaMalloc(&wsData, wsSize); + if (cudaErr != 0) throw nd4j::cuda_exception::build("conv2dCUDNN: cudaMalloc for auxiliary workspace memory failed", cudaErr); + + // provide scaling parameters + const float alpha32(1), beta32(0); + const double alpha64(1), beta64(0); + const void* alpha = output->sizeOfT() <= 4 ? reinterpret_cast(&alpha32) : reinterpret_cast(&alpha64); + const void* beta = output->sizeOfT() <= 4 ? reinterpret_cast(&beta32) : reinterpret_cast(&beta64); + + NDArray::prepareSpecialUse({output}, {input, weights, bias}); + + // run calculation + err = cudnnConvolutionForward(*handle, alpha, x, input->getSpecialBuffer(), w, weights->getSpecialBuffer(), conv, algo, wsData, wsSize, beta, z, output->specialBuffer()); + if (err != 0) throw nd4j::cuda_exception::build("conv2dCUDNN: cudnnConvolutionForward failed", err); + + // add bias if it is present + if (bias != nullptr) { + + cudnnTensorDescriptor_t b; + cudnnCreateTensorDescriptor(&b); + err = cudnnSetTensor4dDescriptor(b, format, cudnnDataType(bias->dataType()), 1, isNCHW ? bias->lengthOf() : 1, 1, isNCHW ? 1: bias->lengthOf()); + if (err != 0) throw nd4j::cuda_exception::build("conv2dCUDNN: cudnnSetTensor4dDescriptor for bias failed", err); + err = cudnnAddTensor(*handle, alpha, b, bias->getSpecialBuffer(), alpha, z, output->specialBuffer()); + if (err != 0) throw nd4j::cuda_exception::build("conv2dCUDNN: cudnnAddTensor bias failed", err); + } + + // cudaErr = cudaStreamSynchronize(*context->getCudaStream()); + // if (cudaErr != 0) + // throw cuda_exception::build("conv2dCUDNN: cudaStreamSynchronize failed !", cudaErr); + + cudaErr = cudaFree(wsData); + if (cudaErr != 0) throw nd4j::cuda_exception::build("conv2dCUDNN: cudaFree for auxiliary workspace memory failed", cudaErr); + + NDArray::registerSpecialUse({output}, {input, weights, bias}); +} + +////////////////////////////////////////////////////////////////////////// +static void conv2dBpCUDNN(const LaunchContext* context, + const NDArray* input, const NDArray* weights, const NDArray* gradO, + NDArray* gradI, NDArray* gradW, NDArray* gradB, + const int kH, const int kW, + const int sH, const int sW, + const int pH, const int pW, + const int dH, const int dW, + const int paddingMode, const bool isNCHW) { + + int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); + + auto handle = reinterpret_cast(context->getCuDnnHandle()); + cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream()); + if (err != 0) throw nd4j::cuda_exception::build("conv2dBpCUDNN: can't set stream for cuDNN", err); + + cudnnTensorFormat_t format = isNCHW ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC; + + // input descriptor + cudnnTensorDescriptor_t x; + cudnnCreateTensorDescriptor(&x); + if(input->ews() == 1) + err = cudnnSetTensor4dDescriptor(x, format, cudnnDataType(input->dataType()), bS, iC, iH, iW); + else + err = cudnnSetTensor4dDescriptorEx(x, cudnnDataType(input->dataType()), bS, iC, iH, iW, input->strideAt(0), input->strideAt(indIOioC), input->strideAt(indIiH), input->strideAt(indIiH + 1)); + if (err != 0) throw nd4j::cuda_exception::build("conv2dBpCUDNN: cudnnSetTensor4dDescriptor/cudnnSetTensor4dDescriptorEx for input failed", err); + + // gradO descriptor + cudnnTensorDescriptor_t dz; + cudnnCreateTensorDescriptor(&dz); + if(gradO->ews() == 1) + err = cudnnSetTensor4dDescriptor(dz, format, cudnnDataType(gradO->dataType()), bS, oC, oH, oW); + else + err = cudnnSetTensor4dDescriptorEx(dz, cudnnDataType(gradO->dataType()), bS, oC, oH, oW, gradO->strideAt(0), gradO->strideAt(indIOioC), gradO->strideAt(indOoH), gradO->strideAt(indOoH + 1)); + if (err != 0) throw nd4j::cuda_exception::build("conv2dBpCUDNN: cudnnSetTensor4dDescriptor/cudnnSetTensor4dDescriptorEx for gradO failed", err); + + // gradI descriptor + cudnnTensorDescriptor_t dx; + cudnnCreateTensorDescriptor(&dx); + if(gradI->ews() == 1) + err = cudnnSetTensor4dDescriptor(dx, format, cudnnDataType(gradI->dataType()), bS, iC, iH, iW); + else + err = cudnnSetTensor4dDescriptorEx(dx, cudnnDataType(gradI->dataType()), bS, iC, iH, iW, gradI->strideAt(0), gradI->strideAt(indIOioC), gradI->strideAt(indIiH), gradI->strideAt(indIiH + 1)); + if (err != 0) throw nd4j::cuda_exception::build("conv2dBpCUDNN: cudnnSetTensor4dDescriptor/cudnnSetTensor4dDescriptorEx for gradI failed", err); + + // gradW descriptor + cudnnFilterDescriptor_t dw; + cudnnCreateFilterDescriptor(&dw); + err = cudnnSetFilter4dDescriptor(dw, cudnnDataType(gradW->dataType()), CUDNN_TENSOR_NCHW, oC, iC, kH, kW); + if(err != 0) throw nd4j::cuda_exception::build("conv2dBpCUDNN: cudnnSetFilter4dDescriptor gradW failed", err); + + // description of convolution + cudnnConvolutionDescriptor_t conv; + cudnnCreateConvolutionDescriptor(&conv); + err = cudnnSetConvolution2dDescriptor(conv, pH, pW, sH, sW, dH, dW, CUDNN_CROSS_CORRELATION, cudnnDataType(gradO->dataType())); + if (err != 0) throw nd4j::cuda_exception::build("conv2dBpCUDNN: cudnnSetConvolution2dDescriptor failed", err); + + // gradW algorithm description + cudnnConvolutionBwdFilterAlgo_t algoGradW; + err = cudnnGetConvolutionBackwardFilterAlgorithm(*handle, x, dz, conv, dw, CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST, 0, &algoGradW); + if (err != 0) throw nd4j::cuda_exception::build("conv2dBpCUDNN: cudnnGetConvolutionBackwardFilterAlgorithm failed", err); + + // gradI algorithm description + cudnnConvolutionBwdDataAlgo_t algoGradI; + err = cudnnGetConvolutionBackwardDataAlgorithm(*handle, dw, dz, conv, x, CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST, 0, &algoGradI); + if (err != 0) throw nd4j::cuda_exception::build("conv2dBpCUDNN: cudnnGetConvolutionBackwardDataAlgorithm failed", err); + + // allocate auxiliary device memory for gradW calculation, abbreviation ws means workspace + size_t wsGradWSize; + err = cudnnGetConvolutionBackwardFilterWorkspaceSize(*handle, x, dz, conv, dw, algoGradW, &wsGradWSize); + if (err != 0) throw nd4j::cuda_exception::build("conv2dBpCUDNN: cudnnGetConvolutionBackwardFilterWorkspaceSize failed", err); + void* wsGradWData; + auto cudaErr = cudaMalloc(&wsGradWData, wsGradWSize); + if (cudaErr != 0) throw nd4j::cuda_exception::build("conv2dBpCUDNN: cudaMalloc for auxiliary workspace memory wsGradWData failed", cudaErr); + + // allocate auxiliary device memory for gradI calculation, abbreviation ws means workspace + size_t wsGradISize; + err = cudnnGetConvolutionBackwardDataWorkspaceSize(*handle, dw, dz, conv, dx, algoGradI, &wsGradISize); + if (err != 0) throw nd4j::cuda_exception::build("conv2dBpCUDNN: cudnnGetConvolutionBackwardDataWorkspaceSize failed", err); + void* wsGradIData; + cudaErr = cudaMalloc(&wsGradIData, wsGradISize); + if (cudaErr != 0) throw nd4j::cuda_exception::build("conv2dBpCUDNN: cudaMalloc for auxiliary workspace memory wsGradIData failed", cudaErr); + + // provide scaling parameters + const float alpha32(1), beta32(0); + const double alpha64(1), beta64(0); + const void* alpha = gradO->sizeOfT() <= 4 ? reinterpret_cast(&alpha32) : reinterpret_cast(&alpha64); + const void* beta = gradO->sizeOfT() <= 4 ? reinterpret_cast(&beta32) : reinterpret_cast(&beta64); + + NDArray::prepareSpecialUse({gradI, gradW, gradB}, {input, weights, gradO}); + + // run calculation for gradB (if not nullptr) + if(gradB != nullptr) { + cudnnTensorDescriptor_t db; + cudnnCreateTensorDescriptor(&db); + err = cudnnSetTensor4dDescriptor(db, format, cudnnDataType(gradB->dataType()), 1, isNCHW ? gradB->lengthOf() : 1, 1, isNCHW ? 1: gradB->lengthOf()); + if (err != 0) throw nd4j::cuda_exception::build("conv2dBpCUDNN: cudnnSetTensor4dDescriptor for gradB failed", err); + + err = cudnnConvolutionBackwardBias(*handle, alpha, dz, gradO->getSpecialBuffer(), beta, db, gradB->getSpecialBuffer()); + if (err != 0) throw nd4j::cuda_exception::build("conv2dBpCUDNN: cudnnConvolutionBackwardBias failed", err); + } + + // run calculation for gradW + err = cudnnConvolutionBackwardFilter(*handle, alpha, x, input->getSpecialBuffer(), dz, gradO->getSpecialBuffer(), conv, algoGradW, wsGradWData, wsGradWSize, beta, dw, gradW->getSpecialBuffer()); + if (err != 0) throw nd4j::cuda_exception::build("conv2dBpCUDNN: cudnnConvolutionBackwardFilter failed", err); + + // run calculation for gradI + err = cudnnConvolutionBackwardData(*handle, alpha, dw, weights->getSpecialBuffer(), dz, gradO->getSpecialBuffer(), conv, algoGradI, wsGradIData, wsGradISize, beta, dx, gradI->getSpecialBuffer()); + if (err != 0) throw nd4j::cuda_exception::build("conv2dBpCUDNN: cudnnConvolutionBackwardData failed", err); + + // cudaErr = cudaStreamSynchronize(*context->getCudaStream()); + // if (cudaErr != 0) + // throw cuda_exception::build("conv2dBpCUDNN: cudaStreamSynchronize failed !", cudaErr); + + cudaErr = cudaFree(wsGradWData); + if (cudaErr != 0) throw nd4j::cuda_exception::build("conv2dBpCUDNN: cudaFree for auxiliary workspace memory wsGradWData failed", cudaErr); + cudaErr = cudaFree(wsGradIData); + if (cudaErr != 0) throw nd4j::cuda_exception::build("conv2dBpCUDNN: cudaFree for auxiliary workspace memory wsGradIData failed", cudaErr); + + NDArray::registerSpecialUse({gradI, gradW, gradB}, {input, weights, gradO}); +} + +////////////////////////////////////////////////////////////////////////// +PLATFORM_IMPL(conv2d, ENGINE_CUDA) { + + auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always + auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] + + auto output = OUTPUT_VARIABLE(0); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) + + int sH = INT_ARG(2); // strides height + int sW = INT_ARG(3); // strides width + int pH = INT_ARG(4); // paddings height + int pW = INT_ARG(5); // paddings width + int dH = INT_ARG(6); // dilations height + int dW = INT_ARG(7); // dilations width + int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME + bool isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC + + int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(weights->sizeAt(0)); // filter(kernel) height + int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(weights->sizeAt(1)); // filter(kernel) width + + REQUIRE_TRUE(input->rankOf() == 4, 0, "CUSTOM CONV2D CUDNN OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf()); + REQUIRE_TRUE(weights->rankOf() == 4, 0, "CUSTOM CONV2D CUDNN OP: rank of weights array must be equal to 4, but got %i instead !", weights->rankOf()); + + int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); + + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); + + std::vector expectedWeightsShape = {kH, kW, iC, oC}; + REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV2D CUDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); + if (bias) { + REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV2D CUDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); + REQUIRE_TRUE((bias->rankOf() == 1 && bias->strideAt(0) == 1) || (bias->rankOf() == 2 && bias->sizeAt(0) == 1 && bias->strideAt(1) == 1) || (bias->rankOf() == 2 && bias->sizeAt(1) == 1 && bias->strideAt(0) == 1), 0, "CUSTOM CONV2D CUDNN OP: bias array should be contiguous in memory !"); + } + + NDArray* newWeights = new NDArray(weights->ordering(), {oC, iC, kH, kW}, weights->dataType(), weights->getContext()); // cudnn support only two formats {oC,iC,kH,kW} and {oC,kH,kW,iC} + newWeights->assign(weights->permute({3,2,0,1})); // permute weights (kH, kW, iC, oC --> oC, iC, kH, kW) + + NDArray* newInput = input; + NDArray* newGradI = nullptr; + if(paddingMode == 1) // in same paddingMode cudnn doesn't support asymmetric left/right top/bottopm paddings + checkConv2dCUDNNPadAsymmetric(newInput, newGradI, iH, iW, oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, isNCHW); + + conv2dCUDNN(block.launchContext(), newInput, newWeights, bias, output, kH,kW,sH,sW,pH,pW,dH,dW, paddingMode, isNCHW); + + if(newInput != input) + delete newInput; + + delete newWeights; + + return Status::OK(); +} + +////////////////////////////////////////////////////////////////////////// +PLATFORM_CHECK(conv2d, ENGINE_CUDA) { + + auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always + auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] + + const int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME, 2-CAUSAL + + const bool badInputType = input->dataType() != DataType::DOUBLE && input->dataType() != DataType::FLOAT32 && input->dataType() != DataType::HALF; + const bool badWeightsType = weights->dataType() != DataType::DOUBLE && weights->dataType() != DataType::FLOAT32 && weights->dataType() != DataType::HALF; + const bool badBiasType = bias == nullptr ? false : (bias->dataType() != DataType::DOUBLE && bias->dataType() != DataType::FLOAT32 && bias->dataType() != DataType::HALF); + + return paddingMode != 2 && !badInputType && !badWeightsType && !badBiasType; +} + +////////////////////////////////////////////////////////////////////////// +PLATFORM_IMPL(conv2d_bp, ENGINE_CUDA) { + + auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always + auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] + auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next + + auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon + auto gradW = OUTPUT_VARIABLE(1); // [kH, kW, iC, oC] always + auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] + + int kH = INT_ARG(0); // filter(kernel) height + int kW = INT_ARG(1); // filter(kernel) width + int sH = INT_ARG(2); // strides height + int sW = INT_ARG(3); // strides width + int pH = INT_ARG(4); // paddings height + int pW = INT_ARG(5); // paddings width + int dH = INT_ARG(6); // dilations height + int dW = INT_ARG(7); // dilations width + int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME + int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC + + REQUIRE_TRUE(input->rankOf() == 4, 0, "CUSTOM CONV2D_BP CUDNN OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf()); + REQUIRE_TRUE(weights->rankOf() == 4, 0, "CUSTOM CONV2D_BP CUDNN OP: rank of weights array must be equal to 4, but got %i instead !", weights->rankOf()); + REQUIRE_TRUE(gradO->rankOf() == 4, 0, "CUSTOM CONV2D_BP CUDNN OP: rank of output's gradients (next epsilon) array must be equal to 4, but got %i instead !", gradO->rankOf()); + + int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); + + int trueoH, trueoW; // true output height, width + ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, paddingMode); + + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); + + std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1}); + std::vector expectedWeightsShape = {kH, kW, iC, oC}; + REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM CONV2D_BP CUDNN OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); + REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV2D_BP CUDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); + if(bias) + REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV2D_BP CUDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); + + NDArray* newGradW = new NDArray(gradW->ordering(), {oC, iC, kH, kW}, gradW->dataType(), gradW->getContext()); // cudnn support only two formats for weights {oC,iC,kH,kW} and {oC,kH,kW,iC} + NDArray* newWeights = new NDArray(weights->ordering(), {oC, iC, kH, kW}, weights->dataType(), weights->getContext()); + + newWeights->assign(weights->permute({3,2,0,1})); // permute weights (kH, kW, iC, oC --> oC, iC, kH, kW) + + NDArray* newInput = input; + NDArray* newGradI = gradI; + if(paddingMode == 1) // in same paddingMode cudnn doesn't support asymmetric left/right top/bottopm paddings + checkConv2dCUDNNPadAsymmetric(newInput, newGradI, iH, iW, oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, isNCHW); + + conv2dBpCUDNN(block.launchContext(), newInput, newWeights, gradO, newGradI, newGradW, gradB, kH,kW,sH,sW,pH,pW,dH,dW,paddingMode,isNCHW); + + newGradW->permutei({2,3,1,0}); // [oC, iC, kH, kW] -> [kH, kW, iC, oC] + gradW->assign(newGradW); + + if(newInput != input) { + + if(isNCHW) + gradI->assign((*newGradI)({0,0, 0,0, 0,gradI->sizeAt(2), 0,gradI->sizeAt(3)})); + else + gradI->assign((*newGradI)({0,0, 0,gradI->sizeAt(1), 0,gradI->sizeAt(2), 0,0})); + + delete newInput; + delete newGradI; + } + + delete newWeights; + delete newGradW; + + return Status::OK(); +} + +PLATFORM_CHECK(conv2d_bp, ENGINE_CUDA) { + + auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always + auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] + auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next + + const int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME, 2-CAUSAL + const int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC + + const bool badInputType = input->dataType() != DataType::DOUBLE && input->dataType() != DataType::FLOAT32 && input->dataType() != DataType::HALF; + const bool badWeightsType = weights->dataType() != DataType::DOUBLE && weights->dataType() != DataType::FLOAT32 && weights->dataType() != DataType::HALF; + const bool badGradOType = gradO->dataType() != DataType::DOUBLE && gradO->dataType() != DataType::FLOAT32 && gradO->dataType() != DataType::HALF; + const bool badBiasType = bias == nullptr ? false : (bias->dataType() != DataType::DOUBLE && bias->dataType() != DataType::FLOAT32 && bias->dataType() != DataType::HALF); + + return isNCHW && paddingMode != 2 && !badInputType && !badWeightsType && !badGradOType && !badBiasType; +} + + + + + + + +// PLATFORM_IMPL(conv2d, ENGINE_CUDA) { + +// auto handle = reinterpret_cast(block.launchContext()->getCuDnnHandle()); +// auto res = cudnnSetStream(*handle, *block.launchContext()->getCudaStream()); +// if (res != 0) +// throw nd4j::cuda_exception::build("Can't set stream for cuDNN", res); + +// auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) +// auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always +// auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] + +// auto output = OUTPUT_VARIABLE(0); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) + +// NDArray::prepareSpecialUse({output}, {input, weights, bias}); + +// int sH = INT_ARG(2); // strides height +// int sW = INT_ARG(3); // strides width +// int pH = INT_ARG(4); // paddings height +// int pW = INT_ARG(5); // paddings width +// int dH = INT_ARG(6); // dilations height +// int dW = INT_ARG(7); // dilations width +// int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME +// bool isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC + +// int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(weights->sizeAt(0)); // filter(kernel) height +// int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(weights->sizeAt(1)); // filter(kernel) width + +// int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; +// int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes +// ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); +// ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, isSameMode); + +// auto dtype = cudnnDataType(input->dataType()); + + +// cudnnTensorDescriptor_t src; +// cudnnCreateTensorDescriptor(&src); +// res = cudnnSetTensor4dDescriptorEx(src, dtype, input->sizeAt(0), input->sizeAt(1), input->sizeAt(2), input->sizeAt(3), input->strideAt(0), input->strideAt(1), input->strideAt(2), input->strideAt(3)); +// if (res != 0) +// throw nd4j::cuda_exception::build("cudnnSetTensor4dDescriptorEx src failed", res); + +// // TODO: we definitely want NHWC here as well +// cudnnFilterDescriptor_t wght; +// cudnnCreateFilterDescriptor(&wght); +// res = cudnnSetFilter4dDescriptor(wght, dtype, CUDNN_TENSOR_NCHW, oC, iC, kH, kW); +// if (res != 0) +// throw nd4j::cuda_exception::build("cudnnSetFilter4dDescriptor failed", res); + +// cudnnConvolutionDescriptor_t cdc; +// cudnnCreateConvolutionDescriptor(&cdc); +// res = cudnnSetConvolution2dDescriptor(cdc, pH, pW, sH, sW, dH, dW, CUDNN_CROSS_CORRELATION, dtype); +// if (res != 0) +// throw nd4j::cuda_exception::build("cudnnSetConvolution2dDescriptor failed", res); + +// cudnnTensorDescriptor_t dst; +// cudnnCreateTensorDescriptor(&dst); +// res = cudnnSetTensor4dDescriptorEx(dst, dtype, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2), output->sizeAt(3), output->strideAt(0), output->strideAt(1), output->strideAt(2), output->strideAt(3)); +// if (res != 0) +// throw nd4j::cuda_exception::build("cudnnSetTensor4dDescriptorEx dst failed", res); + +// // TODO: workspace algorithms are supposed to be faster, so we should use it here if we have enough memory +// cudnnConvolutionFwdAlgo_t algo; +// res = cudnnGetConvolutionForwardAlgorithm(*handle, src, wght, cdc, dst, CUDNN_CONVOLUTION_FWD_NO_WORKSPACE, 0, &algo); +// if (res != 0) +// throw nd4j::cuda_exception::build("cudnnGetConvolutionForwardAlgorithm failed", res); + +// // TODO: should be float if dtype is half/float, and double otherwise +// float alpha = 1.0f; +// float beta = 0.0f; +// res = cudnnConvolutionForward(*handle, &alpha, src, input->specialBuffer(), wght, weights->specialBuffer(), cdc, algo, nullptr, 0, &beta, dst, output->specialBuffer()); +// if (res != 0) +// throw nd4j::cuda_exception::build("cudnnConvolutionForward failed", res); + + +// if (bias != nullptr) { +// cudnnTensorDescriptor_t bs; +// cudnnCreateTensorDescriptor(&bs); +// if (isNCHW) { +// res = cudnnSetTensor4dDescriptor(bs, CUDNN_TENSOR_NCHW, dtype, 1, bias->lengthOf(), 1, 1); +// if (res != 0) +// throw nd4j::cuda_exception::build("cudnnSetTensor4dDescriptorEx bias NHWC failed", res); +// } else { +// res = cudnnSetTensor4dDescriptor(bs, CUDNN_TENSOR_NHWC, dtype, 1, 1, 1, bias->lengthOf()); +// if (res != 0) +// throw nd4j::cuda_exception::build("cudnnSetTensor4dDescriptorEx bias NHWC failed", res); +// } + +// res = cudnnAddTensor(*handle, &alpha, bs, bias->specialBuffer(), &alpha, dst, output->specialBuffer()); +// if (res != 0) +// throw nd4j::cuda_exception::build("cudnnAddTensor failed", res); +// } + + +// NDArray::registerSpecialUse({output}, {input, weights, bias}); + +// return Status::OK(); +// } + + +} +} +} diff --git a/libnd4j/include/ops/declarable/platform/cudnn/conv3d.cu b/libnd4j/include/ops/declarable/platform/cudnn/conv3d.cu new file mode 100644 index 000000000..9d30ff04c --- /dev/null +++ b/libnd4j/include/ops/declarable/platform/cudnn/conv3d.cu @@ -0,0 +1,453 @@ +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// @author Yurii Shyrma (iuriish@yahoo.com) +// + + +#include "cudnnUtils.h" +#include + +namespace nd4j { +namespace ops { +namespace platforms { + +////////////////////////////////////////////////////////////////////////// +static void conv3dCUDNN(const LaunchContext* context, + const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, + const int kD, const int kH, const int kW, + const int sD, const int sH, const int sW, + const int pD, const int pH, const int pW, + const int dD, const int dH, const int dW, + const int paddingMode, const bool isNCDHW) { + + const int numDims = 5; + + int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; + int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); + + auto handle = reinterpret_cast(context->getCuDnnHandle()); + cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream()); + if (err != 0) throw nd4j::cuda_exception::build("conv3dCUDNN: can't set stream for cuDNN", err); + + const std::vector pads = {pD, pH, pW}; + const std::vector filtStrides = {sD, sH, sW}; + const std::vector dilations = {dD, dH, dW}; + + const std::vector xShape = {bS, iC, iD, iH, iW}; + const std::vector zShape = {bS, oC, oD, oH, oW}; + const std::vector wShape = {oC, iC, kD, kH, kW}; + const std::vector bShape = {1, (isNCDHW ? oC : 1), 1, 1, (isNCDHW ? 1 : oC)}; + + const std::vector xStrides = {(int)input->strideAt(0), (int)input->strideAt(1), (int)input->strideAt(2), (int)input->strideAt(3), (int)input->strideAt(4)}; + const std::vector zStrides = {(int)output->strideAt(0), (int)output->strideAt(1), (int)output->strideAt(2), (int)output->strideAt(3), (int)output->strideAt(4)}; + + cudnnTensorFormat_t format = isNCDHW ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC; + + // input descriptor + cudnnTensorDescriptor_t x; + cudnnCreateTensorDescriptor(&x); + if(input->ews() == 1) + err = cudnnSetTensorNdDescriptorEx(x, format, cudnnDataType(input->dataType()), numDims, xShape.data()); + else + err = cudnnSetTensorNdDescriptor(x, cudnnDataType(input->dataType()), numDims, xShape.data(), xStrides.data()); + if (err != 0) throw nd4j::cuda_exception::build("conv3dCUDNN: cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for input failed", err); + + // weights descriptor + cudnnFilterDescriptor_t w; + cudnnCreateFilterDescriptor(&w); + err = cudnnSetFilterNdDescriptor(w, cudnnDataType(weights->dataType()), CUDNN_TENSOR_NCHW, numDims, wShape.data()); + if(err != 0) throw nd4j::cuda_exception::build("conv3dCUDNN: cudnnSetFilterNdDescriptor failed", err); + + // output descriptor + cudnnTensorDescriptor_t z; + cudnnCreateTensorDescriptor(&z); + if(output->ews() == 1) + err = cudnnSetTensorNdDescriptorEx(z, format, cudnnDataType(output->dataType()), numDims, zShape.data()); + else + err = cudnnSetTensorNdDescriptor(z, cudnnDataType(output->dataType()), numDims, zShape.data(), zStrides.data()); + if (err != 0) throw nd4j::cuda_exception::build("conv3dCUDNN: cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for output failed", err); + + // description of convolution + cudnnConvolutionDescriptor_t conv; + cudnnCreateConvolutionDescriptor(&conv); + err = cudnnSetConvolutionNdDescriptor(conv, numDims-2, pads.data(), filtStrides.data(), dilations.data(), CUDNN_CROSS_CORRELATION, cudnnDataType(output->dataType())); + if (err != 0) throw nd4j::cuda_exception::build("conv3dCUDNN: cudnnSetConvolutionNdDescriptor failed", err); + + // algorithm description + cudnnConvolutionFwdAlgo_t algo; + err = cudnnGetConvolutionForwardAlgorithm(*handle, x, w, conv, z, CUDNN_CONVOLUTION_FWD_PREFER_FASTEST, 0, &algo); + if (err != 0) throw nd4j::cuda_exception::build("conv3dCUDNN: cudnnGetConvolutionForwardAlgorithm failed", err); + + // allocate auxiliary device memory, abbreviation ws means workspace + size_t wsSize; + err = cudnnGetConvolutionForwardWorkspaceSize(*handle, x, w, conv, z, algo, &wsSize); + if (err != 0) throw nd4j::cuda_exception::build("conv3dCUDNN: cudnnGetConvolutionForwardWorkspaceSize failed", err); + void* wsData; + auto cudaErr = cudaMalloc(&wsData, wsSize); + if (cudaErr != 0) throw nd4j::cuda_exception::build("conv3dCUDNN: cudaMalloc for auxiliary workspace memory failed", cudaErr); + + // provide scaling parameters + const float alpha32(1), beta32(0); + const double alpha64(1), beta64(0); + const void* alpha = output->sizeOfT() <= 4 ? reinterpret_cast(&alpha32) : reinterpret_cast(&alpha64); + const void* beta = output->sizeOfT() <= 4 ? reinterpret_cast(&beta32) : reinterpret_cast(&beta64); + + NDArray::prepareSpecialUse({output}, {input, weights, bias}); + + // run calculation + err = cudnnConvolutionForward(*handle, alpha, x, input->getSpecialBuffer(), w, weights->getSpecialBuffer(), conv, algo, wsData, wsSize, beta, z, output->specialBuffer()); + if (err != 0) throw nd4j::cuda_exception::build("conv3dCUDNN: cudnnConvolutionForward failed", err); + + // add bias if it is present + if (bias != nullptr) { + + cudnnTensorDescriptor_t b; + cudnnCreateTensorDescriptor(&b); + err = cudnnSetTensorNdDescriptorEx(b, format, cudnnDataType(bias->dataType()), numDims, bShape.data()); + if (err != 0) throw nd4j::cuda_exception::build("conv3dCUDNN: cudnnSetTensorNdDescriptor for bias failed", err); + err = cudnnAddTensor(*handle, alpha, b, bias->getSpecialBuffer(), alpha, z, output->specialBuffer()); + if (err != 0) throw nd4j::cuda_exception::build("conv3dCUDNN: cudnnAddTensor bias failed", err); + } + + // cudaErr = cudaStreamSynchronize(*context->getCudaStream()); + // if (cudaErr != 0) + // throw cuda_exception::build("conv3dCUDNN: cudaStreamSynchronize failed !", cudaErr); + + cudaErr = cudaFree(wsData); + if (cudaErr != 0) throw nd4j::cuda_exception::build("conv3dCUDNN: cudaFree for auxiliary workspace memory failed", cudaErr); + + NDArray::registerSpecialUse({output}, {input, weights, bias}); +} + +////////////////////////////////////////////////////////////////////////// +static void conv3dBpCUDNN(const LaunchContext* context, + const NDArray* input, const NDArray* weights, const NDArray* gradO, + NDArray* gradI, NDArray* gradW, NDArray* gradB, + const int kD, const int kH, const int kW, + const int sD, const int sH, const int sW, + const int pD, const int pH, const int pW, + const int dD, const int dH, const int dW, + const int paddingMode, const bool isNCDHW) { + + const int numDims = 5; + + int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; + int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); + + auto handle = reinterpret_cast(context->getCuDnnHandle()); + cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream()); + if (err != 0) throw nd4j::cuda_exception::build("conv3dBpCUDNN: can't set stream for cuDNN", err); + + const std::vector pads = {pD, pH, pW}; + const std::vector filtStrides = {sD, sH, sW}; + const std::vector dilations = {dD, dH, dW}; + + const std::vector xShape = {bS, iC, iD, iH, iW}; + const std::vector dzShape = {bS, oC, oD, oH, oW}; + const std::vector wShape = {oC, iC, kD, kH, kW}; + const std::vector dbShape = {1, (int)(isNCDHW ? oC : 1), 1, 1, (int)(isNCDHW ? 1 : oC)}; + + const std::vector xStrides = {(int)input->strideAt(0), (int)input->strideAt(1), (int)input->strideAt(2), (int)input->strideAt(3), (int)input->strideAt(4)}; + const std::vector dxStrides = {(int)gradI->strideAt(0), (int)gradI->strideAt(1), (int)gradI->strideAt(2), (int)gradI->strideAt(3), (int)gradI->strideAt(4)}; + const std::vector dzStrides = {(int)gradO->strideAt(0), (int)gradO->strideAt(1), (int)gradO->strideAt(2), (int)gradO->strideAt(3), (int)gradO->strideAt(4)}; + + cudnnTensorFormat_t format = isNCDHW ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC; + + // input descriptor + cudnnTensorDescriptor_t x; + cudnnCreateTensorDescriptor(&x); + if(input->ews() == 1) + err = cudnnSetTensorNdDescriptorEx(x, format, cudnnDataType(input->dataType()), numDims, xShape.data()); + else + err = cudnnSetTensorNdDescriptor(x, cudnnDataType(input->dataType()), numDims, xShape.data(), xStrides.data()); + if (err != 0) throw nd4j::cuda_exception::build("conv3dBpCUDNN: cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for input failed", err); + + // gradO descriptor + cudnnTensorDescriptor_t dz; + cudnnCreateTensorDescriptor(&dz); + if(gradO->ews() == 1) + err = cudnnSetTensorNdDescriptorEx(dz, format, cudnnDataType(gradO->dataType()), numDims, dzShape.data()); + else + err = cudnnSetTensorNdDescriptor(dz, cudnnDataType(gradO->dataType()), numDims, dzShape.data(), dzStrides.data()); + if (err != 0) throw nd4j::cuda_exception::build("conv3dBpCUDNN: cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for gradO failed", err); + + // gradI descriptor + cudnnTensorDescriptor_t dx; + cudnnCreateTensorDescriptor(&dx); + if(gradI->ews() == 1) + err = cudnnSetTensorNdDescriptorEx(dx, format, cudnnDataType(gradI->dataType()), numDims, xShape.data()); + else + err = cudnnSetTensorNdDescriptor(dx, cudnnDataType(gradI->dataType()), numDims, xShape.data(), dxStrides.data()); + if (err != 0) throw nd4j::cuda_exception::build("conv3dBpCUDNN: cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for gradI failed", err); + + // gradW descriptor + cudnnFilterDescriptor_t dw; + cudnnCreateFilterDescriptor(&dw); + err = cudnnSetFilterNdDescriptor(dw, cudnnDataType(gradW->dataType()), CUDNN_TENSOR_NCHW, numDims, wShape.data()); + if(err != 0) throw nd4j::cuda_exception::build("conv3dBpCUDNN: cudnnSetFilterNdDescriptor failed", err); + + // description of convolution + cudnnConvolutionDescriptor_t conv; + cudnnCreateConvolutionDescriptor(&conv); + err = cudnnSetConvolutionNdDescriptor(conv, numDims-2, pads.data(), filtStrides.data(), dilations.data(), CUDNN_CROSS_CORRELATION, cudnnDataType(gradO->dataType())); + if (err != 0) throw nd4j::cuda_exception::build("conv3dBpCUDNN: cudnnSetConvolutionNdDescriptor failed", err); + + // gradW algorithm description + cudnnConvolutionBwdFilterAlgo_t algoGradW; + err = cudnnGetConvolutionBackwardFilterAlgorithm(*handle, x, dz, conv, dw, CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST, 0, &algoGradW); + if (err != 0) throw nd4j::cuda_exception::build("conv3dBpCUDNN: cudnnGetConvolutionBackwardFilterAlgorithm failed", err); + + // gradI algorithm description + cudnnConvolutionBwdDataAlgo_t algoGradI; + err = cudnnGetConvolutionBackwardDataAlgorithm(*handle, dw, dz, conv, x, CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST, 0, &algoGradI); + if (err != 0) throw nd4j::cuda_exception::build("conv3dBpCUDNN: cudnnGetConvolutionBackwardDataAlgorithm failed", err); + + // allocate auxiliary device memory for gradW calculation, abbreviation ws means workspace + size_t wsGradWSize; + err = cudnnGetConvolutionBackwardFilterWorkspaceSize(*handle, x, dz, conv, dw, algoGradW, &wsGradWSize); + if (err != 0) throw nd4j::cuda_exception::build("conv3dBpCUDNN: cudnnGetConvolutionBackwardFilterWorkspaceSize failed", err); + void* wsGradWData; + auto cudaErr = cudaMalloc(&wsGradWData, wsGradWSize); + if (cudaErr != 0) throw nd4j::cuda_exception::build("conv3dBpCUDNN: cudaMalloc for auxiliary workspace memory wsGradWData failed", cudaErr); + + // allocate auxiliary device memory for gradI calculation, abbreviation ws means workspace + size_t wsGradISize; + err = cudnnGetConvolutionBackwardDataWorkspaceSize(*handle, dw, dz, conv, dx, algoGradI, &wsGradISize); + if (err != 0) throw nd4j::cuda_exception::build("conv3dBpCUDNN: cudnnGetConvolutionBackwardDataWorkspaceSize failed", err); + void* wsGradIData; + cudaErr = cudaMalloc(&wsGradIData, wsGradISize); + if (cudaErr != 0) throw nd4j::cuda_exception::build("conv3dBpCUDNN: cudaMalloc for auxiliary workspace memory wsGradIData failed", cudaErr); + + // provide scaling parameters + const float alpha32(1), beta32(0); + const double alpha64(1), beta64(0); + const void* alpha = gradO->sizeOfT() <= 4 ? reinterpret_cast(&alpha32) : reinterpret_cast(&alpha64); + const void* beta = gradO->sizeOfT() <= 4 ? reinterpret_cast(&beta32) : reinterpret_cast(&beta64); + + NDArray::prepareSpecialUse({gradI, gradW, gradB}, {input, weights, gradO}); + + // run calculation for gradB (if not nullptr) + if(gradB != nullptr) { + + cudnnTensorDescriptor_t db; + cudnnCreateTensorDescriptor(&db); + err = cudnnSetTensorNdDescriptorEx(db, format, cudnnDataType(gradB->dataType()), numDims, dbShape.data()); + if (err != 0) throw nd4j::cuda_exception::build("conv3dBpCUDNN: cudnnSetTensorNdDescriptor for gradB failed", err); + + err = cudnnConvolutionBackwardBias(*handle, alpha, dz, gradO->getSpecialBuffer(), beta, db, gradB->getSpecialBuffer()); + if (err != 0) throw nd4j::cuda_exception::build("conv3dBpCUDNN: cudnnConvolutionBackwardBias failed", err); + } + + // run calculation for gradW + err = cudnnConvolutionBackwardFilter(*handle, alpha, x, input->getSpecialBuffer(), dz, gradO->getSpecialBuffer(), conv, algoGradW, wsGradWData, wsGradWSize, beta, dw, gradW->getSpecialBuffer()); + if (err != 0) throw nd4j::cuda_exception::build("conv3dBpCUDNN: cudnnConvolutionBackwardFilter failed", err); + + // run calculation for gradI + err = cudnnConvolutionBackwardData(*handle, alpha, dw, weights->getSpecialBuffer(), dz, gradO->getSpecialBuffer(), conv, algoGradI, wsGradIData, wsGradISize, beta, dx, gradI->getSpecialBuffer()); + if (err != 0) throw nd4j::cuda_exception::build("conv3dBpCUDNN: cudnnConvolutionBackwardData failed", err); + + // cudaErr = cudaStreamSynchronize(*context->getCudaStream()); + // if (cudaErr != 0) + // throw cuda_exception::build("conv3dBpCUDNN: cudaStreamSynchronize failed !", cudaErr); + + cudaErr = cudaFree(wsGradWData); + if (cudaErr != 0) throw nd4j::cuda_exception::build("conv3dBpCUDNN: cudaFree for auxiliary workspace memory wsGradWData failed", cudaErr); + cudaErr = cudaFree(wsGradIData); + if (cudaErr != 0) throw nd4j::cuda_exception::build("conv3dBpCUDNN: cudaFree for auxiliary workspace memory wsGradIData failed", cudaErr); + + NDArray::registerSpecialUse({gradI, gradW, gradB}, {input, weights, gradO}); +} + +////////////////////////////////////////////////////////////////////////// +PLATFORM_IMPL(conv3dnew, ENGINE_CUDA) { + + auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always + auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] + auto output = OUTPUT_VARIABLE(0); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW) + + REQUIRE_TRUE(input->rankOf() == 5, 0, "CONV3D CUDNN OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf()); + REQUIRE_TRUE(weights->rankOf() == 5, 0, "CONV3D CUDNN OP: rank of weights array must be equal to 5, but got %i instead !", weights->rankOf()); + + int kD = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(weights->sizeAt(0));// filter(kernel) depth + int kH = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(weights->sizeAt(1));// filter(kernel) height + int kW = INT_ARG(2) > 0 ? INT_ARG(2) : static_cast(weights->sizeAt(2));// filter(kernel) width + int sD = INT_ARG(3); // strides depth + int sH = INT_ARG(4); // strides height + int sW = INT_ARG(5); // strides width + int pD = INT_ARG(6); // paddings depth + int pH = INT_ARG(7); // paddings height + int pW = INT_ARG(8); // paddings width + int dD = INT_ARG(9); // dilations depth + int dH = INT_ARG(10); // dilations height + int dW = INT_ARG(11); // dilations width + int paddingMode = INT_ARG(12); // 0-SAME, 1-VALID + int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW + + REQUIRE_TRUE(paddingMode < 2, 0, "CONV3D CUDNN OP: causal padding mode (paddingMode = 2) is not allowed for this operation !"); + + int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; + int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); + + ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW, paddingMode); + + std::vector expectedWeightsShape = {kD, kH, kW, iC, oC}; + REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CONV3D CUDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); + if (bias) + REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CONV3D CUDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); + + NDArray* newWeights = new NDArray(weights->ordering(), {oC, iC, kD, kH, kW}, weights->dataType(), weights->getContext()); // cudnn support only two formats {oC,iC,kH,kW} and {oC,kH,kW,iC} + newWeights->assign(weights->permute({4,3,0,1,2})); // permute weights (kD, kH, kW, iC, oC --> oC, iC, kD, kH, kW) + + NDArray* newInput = input; + NDArray* newGradI = nullptr; + if(paddingMode == 1) // in same paddingMode cudnn doesn't support asymmetric left/right top/bottopm paddings + checkConv3dCUDNNPadAsymmetric(newInput, newGradI, iD, iH, iW, oD, oH, oW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isNCDHW); + + conv3dCUDNN(block.launchContext(), newInput, newWeights, bias, output, kD,kH,kW,sD,sH,sW,pD,pH,pW,dD,dH,dW, paddingMode, isNCDHW); + + if(newInput != input) + delete newInput; + + delete newWeights; + + return Status::OK(); +} + +////////////////////////////////////////////////////////////////////////// +PLATFORM_CHECK(conv3dnew, ENGINE_CUDA) { + + auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always + auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] + + int paddingMode = INT_ARG(12); // 0-SAME, 1-VALID + + const bool badInputType = input->dataType() != DataType::DOUBLE && input->dataType() != DataType::FLOAT32 && input->dataType() != DataType::HALF; + const bool badWeightsType = weights->dataType() != DataType::DOUBLE && weights->dataType() != DataType::FLOAT32 && weights->dataType() != DataType::HALF; + const bool badBiasType = bias == nullptr ? false : (bias->dataType() != DataType::DOUBLE && bias->dataType() != DataType::FLOAT32 && bias->dataType() != DataType::HALF); + + return paddingMode != 2 && !badInputType && !badWeightsType && !badBiasType; +} + +////////////////////////////////////////////////////////////////////////// +PLATFORM_IMPL(conv3dnew_bp, ENGINE_CUDA) { + + auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always + auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] + auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next + + auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon + auto gradW = OUTPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always + auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] + + REQUIRE_TRUE(input->rankOf() == 5, 0, "CONV3D_BP CUDNN OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf()); + REQUIRE_TRUE(weights->rankOf() == 5, 0, "CONV3D_BP CUDNN OP: rank of weights array must be equal to 5, but got %i instead !", weights->rankOf()); + REQUIRE_TRUE(gradO->rankOf() == 5, 0, "CONV3D_BP CUDNN OP: rank of output gradients (next epsilon) array must be equal to 5, but got %i instead !", gradO->rankOf()); + + int kD = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(weights->sizeAt(0));// filter(kernel) depth + int kH = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(weights->sizeAt(1));// filter(kernel) height + int kW = INT_ARG(2) > 0 ? INT_ARG(2) : static_cast(weights->sizeAt(2));// filter(kernel) width + int sD = INT_ARG(3); // strides depth + int sH = INT_ARG(4); // strides height + int sW = INT_ARG(5); // strides width + int pD = INT_ARG(6); // paddings depth + int pH = INT_ARG(7); // paddings height + int pW = INT_ARG(8); // paddings width + int dD = INT_ARG(9); // dilations depth + int dH = INT_ARG(10); // dilations height + int dW = INT_ARG(11); // dilations width + int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID + int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW + + int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; + int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); + + int trueoD, trueoH, trueoW; // true output depth/height/width + ConvolutionUtils::calcOutSizePool3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, paddingMode); + + REQUIRE_TRUE(paddingMode < 2, 0, "CONV3D_BP CUDNN OP: causal padding mode (paddingMode = 2) is not allowed for this operation !"); + + std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoD,trueoH,trueoW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); + std::vector expectedWeightsShape = {kD, kH, kW, iC, oC}; + REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CONV3D_BP CUDNN OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); + REQUIRE_TRUE(gradW->isSameShape(expectedWeightsShape), 0, "CONV3D_BP CUDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); + if(bias) + REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CONV3D_BP CUDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); + + ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW, paddingMode); + + NDArray* newGradW = new NDArray(gradW->ordering(), {oC, iC, kD, kH, kW}, gradW->dataType(), gradW->getContext()); // cudnn support only two formats for weights {oC,iC,kH,kW} and {oC,kH,kW,iC} + NDArray* newWeights = new NDArray(weights->ordering(), {oC, iC, kD, kH, kW}, weights->dataType(), weights->getContext()); + + newWeights->assign(weights->permute({4,3,0,1,2})); // permute weights (kD, kH, kW, iC, oC --> oC, iC, kD, kH, kW) + + NDArray* newInput = input; + NDArray* newGradI = gradI; + if(paddingMode == 1) // in same paddingMode cudnn doesn't support asymmetric left/right top/bottopm paddings + checkConv3dCUDNNPadAsymmetric(newInput, newGradI, iD, iH, iW, oD, oH, oW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isNCDHW); + + conv3dBpCUDNN(block.launchContext(), newInput, newWeights, gradO, newGradI, newGradW, gradB, kD,kH,kW,sD,sH,sW,pD,pH,pW,dD,dH,dW,paddingMode,isNCDHW); + + newGradW->permutei({2,3,4,1,0}); // [oC, iC, kD, kH, kW] -> [kD, kH, kW, iC, oC] + gradW->assign(newGradW); + + if(newInput != input) { + + if(isNCDHW) + gradI->assign((*newGradI)({0,0, 0,0, 0,gradI->sizeAt(2), 0,gradI->sizeAt(3), 0,gradI->sizeAt(4)})); + else + gradI->assign((*newGradI)({0,0, 0,gradI->sizeAt(1), 0,gradI->sizeAt(2), 0,gradI->sizeAt(3), 0,0})); + + delete newInput; + delete newGradI; + } + + delete newWeights; + delete newGradW; + + return Status::OK(); +} + +PLATFORM_CHECK(conv3dnew_bp, ENGINE_CUDA) { + + auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always + auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] + auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next + + int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID + int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW + + const bool badInputType = input->dataType() != DataType::DOUBLE && input->dataType() != DataType::FLOAT32 && input->dataType() != DataType::HALF; + const bool badWeightsType = weights->dataType() != DataType::DOUBLE && weights->dataType() != DataType::FLOAT32 && weights->dataType() != DataType::HALF; + const bool badGradOType = gradO->dataType() != DataType::DOUBLE && gradO->dataType() != DataType::FLOAT32 && gradO->dataType() != DataType::HALF; + const bool badBiasType = bias == nullptr ? false : (bias->dataType() != DataType::DOUBLE && bias->dataType() != DataType::FLOAT32 && bias->dataType() != DataType::HALF); + + return isNCDHW && paddingMode != 2 && !badInputType && !badWeightsType && !badGradOType && !badBiasType; +} + +} +} +} diff --git a/libnd4j/include/ops/declarable/platform/cudnn/cudnnUtils.h b/libnd4j/include/ops/declarable/platform/cudnn/cudnnUtils.h new file mode 100644 index 000000000..bdff86e24 --- /dev/null +++ b/libnd4j/include/ops/declarable/platform/cudnn/cudnnUtils.h @@ -0,0 +1,158 @@ +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#ifndef SD_CUDNNUTILS_H +#define SD_CUDNNUTILS_H + +#include +#include +#include +#include +#include +#include + +#include + +namespace nd4j { +namespace ops { +namespace platforms { + + DECLARE_PLATFORM(conv2d, ENGINE_CUDA); + DECLARE_PLATFORM(conv2d_bp, ENGINE_CUDA); + + DECLARE_PLATFORM(conv3dnew, ENGINE_CUDA); + DECLARE_PLATFORM(conv3dnew_bp, ENGINE_CUDA); + + DECLARE_PLATFORM(depthwise_conv2d, ENGINE_CUDA); + DECLARE_PLATFORM(depthwise_conv2d_bp, ENGINE_CUDA); + + DECLARE_PLATFORM(batchnorm, ENGINE_CUDA); + DECLARE_PLATFORM(batchnorm_bp, ENGINE_CUDA); + +////////////////////////////////////////////////////////////////////////// +FORCEINLINE cudnnDataType_t cudnnDataType(nd4j::DataType dataType) { + switch (dataType) { + case nd4j::DataType::FLOAT32: + return CUDNN_DATA_FLOAT; + case nd4j::DataType::DOUBLE: + return CUDNN_DATA_DOUBLE; + case nd4j::DataType::HALF: + return CUDNN_DATA_HALF; + case nd4j::DataType::INT32: + return CUDNN_DATA_INT32; + case nd4j::DataType::INT8: + return CUDNN_DATA_INT8; + default: + throw datatype_exception::build("Unsupported data type", dataType); + } +} + +////////////////////////////////////////////////////////////////////////// +FORCEINLINE void checkConv2dCUDNNPadAsymmetric(NDArray* &input, NDArray* &gradI, + const int iH, const int iW, + const int oH, const int oW, + const int kH, const int kW, + const int sH, const int sW, + const int pH, const int pW, + const int dH, const int dW, + const bool isNCHW) { + + const auto pHsum = ((oH - 1) * sH + ((kH - 1) * dH + 1) - iH); + const auto pWsum = ((oW - 1) * sW + ((kW - 1) * dW + 1) - iW); + + const bool isPHasymm = pH != (pHsum - pH); + const bool isPWasymm = pW != (pWsum - pW); + + if(!isPHasymm && !isPWasymm) + return; + + std::vector newShape = input->getShapeAsVector(); + + const int iHposition = isNCHW ? 2 : 1; + + if(isPHasymm) + newShape[iHposition] += 1; + if(isPWasymm) + newShape[iHposition + 1] += 1; + + NDArray* newInput = new NDArray(input->ordering(), newShape, input->dataType(), input->getContext()); + + if(isNCHW) + (*newInput)({0,0, 0,0, 0,input->sizeAt(2), 0,input->sizeAt(3)}).assign(input); + else + (*newInput)({0,0, 0,input->sizeAt(1), 0,input->sizeAt(2), 0,0}).assign(input); + + input = newInput; + + if(gradI != nullptr) + gradI = new NDArray(gradI->ordering(), newShape, gradI->dataType(), gradI->getContext()); +} + + +////////////////////////////////////////////////////////////////////////// +FORCEINLINE void checkConv3dCUDNNPadAsymmetric(NDArray* &input, NDArray* &gradI, + const int iD, const int iH, const int iW, + const int oD, const int oH, const int oW, + const int kD, const int kH, const int kW, + const int sD, const int sH, const int sW, + const int pD, const int pH, const int pW, + const int dD, const int dH, const int dW, + const bool isNCDHW) { + + const auto pDsum = ((oD - 1) * sD + ((kD - 1) * dD + 1) - iD); + const auto pHsum = ((oH - 1) * sH + ((kH - 1) * dH + 1) - iH); + const auto pWsum = ((oW - 1) * sW + ((kW - 1) * dW + 1) - iW); + + const bool isPDasymm = pD != (pDsum - pD); + const bool isPHasymm = pH != (pHsum - pH); + const bool isPWasymm = pW != (pWsum - pW); + + if(!isPDasymm && !isPHasymm && !isPWasymm) + return; + + std::vector newShape = input->getShapeAsVector(); + + const int iDposition = isNCDHW ? 2 : 1; + + if(isPDasymm) + newShape[iDposition] += 1; + if(isPHasymm) + newShape[iDposition + 1] += 1; + if(isPWasymm) + newShape[iDposition + 2] += 1; + + NDArray* newInput = new NDArray(input->ordering(), newShape, input->dataType(), input->getContext()); + + if(isNCDHW) + (*newInput)({0,0, 0,0, 0,input->sizeAt(2), 0,input->sizeAt(3), 0,input->sizeAt(4)}).assign(input); + else + (*newInput)({0,0, 0,input->sizeAt(1), 0,input->sizeAt(2), 0,input->sizeAt(3), 0,0}).assign(input); + + input = newInput; + + if(gradI != nullptr) + gradI = new NDArray(gradI->ordering(), newShape, gradI->dataType(), gradI->getContext()); +} + +} +} +} + +#endif //SD_CUDNNUTILS_H diff --git a/libnd4j/include/ops/declarable/platform/cudnn/depthwiseConv2d.cu b/libnd4j/include/ops/declarable/platform/cudnn/depthwiseConv2d.cu new file mode 100644 index 000000000..d328fa92b --- /dev/null +++ b/libnd4j/include/ops/declarable/platform/cudnn/depthwiseConv2d.cu @@ -0,0 +1,443 @@ +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Yurii Shyrma (iuriish@yahoo.com) +// + + +#include "cudnnUtils.h" +#include + +namespace nd4j { +namespace ops { +namespace platforms { + + +////////////////////////////////////////////////////////////////////////// +static void depthwiseConv2dCUDNN(const LaunchContext* context, + const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, + const int kH, const int kW, + const int sH, const int sW, + const int pH, const int pW, + const int dH, const int dW, + const int paddingMode, const bool isNCHW) { + + // cudnn supports only following case: mC = 1, oC = iC (groupCount == iC) + + // input [bS, iC, iH, iW] nchw or [bS, iH, iW, iC] nhwc + // weights [iC, mC, kH, kW], mkl doesn't support this format, so we'll make permute + // bias [oC], may be nullptr + // output [bS, oC, oH, oW] nchw or [bS, oH, oW, oC] nhwc + // oC = iC*mC + + int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; + int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); + mC = weights->sizeAt(1); + + auto handle = reinterpret_cast(context->getCuDnnHandle()); + cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream()); + if (err != 0) throw nd4j::cuda_exception::build("depthwiseConv2dCUDNN: can't set stream for cuDNN", err); + + cudnnTensorFormat_t format = isNCHW ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC; + + // input descriptor + cudnnTensorDescriptor_t x; + cudnnCreateTensorDescriptor(&x); + if(input->ews() == 1) + err = cudnnSetTensor4dDescriptor(x, format, cudnnDataType(input->dataType()), bS, iC, iH, iW); + else + err = cudnnSetTensor4dDescriptorEx(x, cudnnDataType(input->dataType()), bS, iC, iH, iW, input->strideAt(0), input->strideAt(indIOioC), input->strideAt(indIiH), input->strideAt(indIiH + 1)); + if (err != 0) throw nd4j::cuda_exception::build("depthwiseConv2dCUDNN: cudnnSetTensor4dDescriptor/cudnnSetTensor4dDescriptorEx for input failed", err); + + // weights descriptor + cudnnFilterDescriptor_t w; + cudnnCreateFilterDescriptor(&w); + err = cudnnSetFilter4dDescriptor(w, cudnnDataType(weights->dataType()), CUDNN_TENSOR_NCHW, iC, mC, kH, kW); + if(err != 0) throw nd4j::cuda_exception::build("depthwiseConv2dCUDNN: cudnnSetFilter4dDescriptor failed", err); + + // output descriptor + cudnnTensorDescriptor_t z; + cudnnCreateTensorDescriptor(&z); + if(output->ews() == 1) + err = cudnnSetTensor4dDescriptor(z, format, cudnnDataType(output->dataType()), bS, oC, oH, oW); + else + err = cudnnSetTensor4dDescriptorEx(z, cudnnDataType(output->dataType()), bS, oC, oH, oW, output->strideAt(0), output->strideAt(indIOioC), output->strideAt(indOoH), output->strideAt(indOoH + 1)); + if (err != 0) throw nd4j::cuda_exception::build("depthwiseConv2dCUDNN: cudnnSetTensor4dDescriptor/cudnnSetTensor4dDescriptorEx for output failed", err); + + // description of convolution + cudnnConvolutionDescriptor_t conv; + cudnnCreateConvolutionDescriptor(&conv); + err = cudnnSetConvolution2dDescriptor(conv, pH, pW, sH, sW, dH, dW, CUDNN_CROSS_CORRELATION, cudnnDataType(output->dataType())); + if (err != 0) throw nd4j::cuda_exception::build("depthwiseConv2dCUDNN: cudnnSetConvolution2dDescriptor failed", err); + err = cudnnSetConvolutionGroupCount(conv, iC); // set number of groups (depthwise mode) in description of convolution, groupCount == iC + if (err != 0) throw nd4j::cuda_exception::build("depthwiseConv2dCUDNN: cudnnSetConvolutionGroupCount failed", err); + + // algorithm description + cudnnConvolutionFwdAlgo_t algo; + err = cudnnGetConvolutionForwardAlgorithm(*handle, x, w, conv, z, CUDNN_CONVOLUTION_FWD_PREFER_FASTEST, 0, &algo); + if (err != 0) throw nd4j::cuda_exception::build("depthwiseConv2dCUDNN: cudnnGetConvolutionForwardAlgorithm failed", err); + + // allocate auxiliary device memory, abbreviation ws means workspace + size_t wsSize; + err = cudnnGetConvolutionForwardWorkspaceSize(*handle, x, w, conv, z, algo, &wsSize); + if (err != 0) throw nd4j::cuda_exception::build("depthwiseConv2dCUDNN: cudnnGetConvolutionForwardWorkspaceSize failed", err); + void* wsData; + auto cudaErr = cudaMalloc(&wsData, wsSize); + if (cudaErr != 0) throw nd4j::cuda_exception::build("depthwiseConv2dCUDNN: cudaMalloc for auxiliary workspace memory failed", cudaErr); + + // provide scaling parameters + const float alpha32(1), beta32(0); + const double alpha64(1), beta64(0); + const void* alpha = output->sizeOfT() <= 4 ? reinterpret_cast(&alpha32) : reinterpret_cast(&alpha64); + const void* beta = output->sizeOfT() <= 4 ? reinterpret_cast(&beta32) : reinterpret_cast(&beta64); + + NDArray::prepareSpecialUse({output}, {input, weights, bias}); + + // run calculation + err = cudnnConvolutionForward(*handle, alpha, x, input->getSpecialBuffer(), w, weights->getSpecialBuffer(), conv, algo, wsData, wsSize, beta, z, output->specialBuffer()); + if (err != 0) throw nd4j::cuda_exception::build("depthwiseConv2dCUDNN: cudnnConvolutionForward failed", err); + + // add bias if it is present + if (bias != nullptr) { + + cudnnTensorDescriptor_t b; + cudnnCreateTensorDescriptor(&b); + err = cudnnSetTensor4dDescriptor(b, format, cudnnDataType(bias->dataType()), 1, isNCHW ? bias->lengthOf() : 1, 1, isNCHW ? 1: bias->lengthOf()); + if (err != 0) throw nd4j::cuda_exception::build("depthwiseConv2dCUDNN: cudnnSetTensor4dDescriptor for bias failed", err); + err = cudnnAddTensor(*handle, alpha, b, bias->getSpecialBuffer(), alpha, z, output->specialBuffer()); + if (err != 0) throw nd4j::cuda_exception::build("depthwiseConv2dCUDNN: cudnnAddTensor bias failed", err); + } + + // cudaErr = cudaStreamSynchronize(*context->getCudaStream()); + // if (cudaErr != 0) + // throw cuda_exception::build("depthwiseConv2dCUDNN: cudaStreamSynchronize failed !", cudaErr); + + cudaErr = cudaFree(wsData); + if (cudaErr != 0) throw nd4j::cuda_exception::build("depthwiseConv2dCUDNN: cudaFree for auxiliary workspace memory failed", cudaErr); + + NDArray::registerSpecialUse({output}, {input, weights, bias}); +} + +////////////////////////////////////////////////////////////////////////// +static void depthwiseConv2dBpCUDNN(const LaunchContext* context, + const NDArray* input, const NDArray* weights, const NDArray* gradO, + NDArray* gradI, NDArray* gradW, NDArray* gradB, + const int kH, const int kW, + const int sH, const int sW, + const int pH, const int pW, + const int dH, const int dW, + const int paddingMode, const bool isNCHW) { + + // cudnn supports only following case: mC = 1, oC = iC (groupCount == iC) + + // input, gradI [bS, iC, iH, iW] nchw or [bS, iH, iW, iC] nhwc + // weights, gradW [iC, mC, kH, kW], mkl doesn't support this format, so we'll make permute + // gradB [oC], may be nullptr + // gradO [bS, oC, oH, oW] nchw or [bS, oH, oW, oC] nhwc + // oC = iC*mC + + int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; + int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); + mC = weights->sizeAt(1); + + auto handle = reinterpret_cast(context->getCuDnnHandle()); + cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream()); + if (err != 0) throw nd4j::cuda_exception::build("depthwiseConv2dBpCUDNN: can't set stream for cuDNN", err); + + cudnnTensorFormat_t format = isNCHW ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC; + + // input descriptor + cudnnTensorDescriptor_t x; + cudnnCreateTensorDescriptor(&x); + if(input->ews() == 1) + err = cudnnSetTensor4dDescriptor(x, format, cudnnDataType(input->dataType()), bS, iC, iH, iW); + else + err = cudnnSetTensor4dDescriptorEx(x, cudnnDataType(input->dataType()), bS, iC, iH, iW, input->strideAt(0), input->strideAt(indIOioC), input->strideAt(indIiH), input->strideAt(indIiH + 1)); + if (err != 0) throw nd4j::cuda_exception::build("depthwiseConv2dBpCUDNN: cudnnSetTensor4dDescriptor/cudnnSetTensor4dDescriptorEx for input failed", err); + + // gradO descriptor + cudnnTensorDescriptor_t dz; + cudnnCreateTensorDescriptor(&dz); + if(gradO->ews() == 1) + err = cudnnSetTensor4dDescriptor(dz, format, cudnnDataType(gradO->dataType()), bS, oC, oH, oW); + else + err = cudnnSetTensor4dDescriptorEx(dz, cudnnDataType(gradO->dataType()), bS, oC, oH, oW, gradO->strideAt(0), gradO->strideAt(indIOioC), gradO->strideAt(indOoH), gradO->strideAt(indOoH + 1)); + if (err != 0) throw nd4j::cuda_exception::build("depthwiseConv2dBpCUDNN: cudnnSetTensor4dDescriptor/cudnnSetTensor4dDescriptorEx for gradO failed", err); + + // gradI descriptor + cudnnTensorDescriptor_t dx; + cudnnCreateTensorDescriptor(&dx); + if(gradI->ews() == 1) + err = cudnnSetTensor4dDescriptor(dx, format, cudnnDataType(gradI->dataType()), bS, iC, iH, iW); + else + err = cudnnSetTensor4dDescriptorEx(dx, cudnnDataType(gradI->dataType()), bS, iC, iH, iW, gradI->strideAt(0), gradI->strideAt(indIOioC), gradI->strideAt(indIiH), gradI->strideAt(indIiH + 1)); + if (err != 0) throw nd4j::cuda_exception::build("depthwiseConv2dBpCUDNN: cudnnSetTensor4dDescriptor/cudnnSetTensor4dDescriptorEx for gradI failed", err); + + // gradW descriptor + cudnnFilterDescriptor_t dw; + cudnnCreateFilterDescriptor(&dw); + err = cudnnSetFilter4dDescriptor(dw, cudnnDataType(gradW->dataType()), CUDNN_TENSOR_NCHW, iC, mC, kH, kW); + if(err != 0) throw nd4j::cuda_exception::build("depthwiseConv2dBpCUDNN: cudnnSetFilter4dDescriptor gradW failed", err); + + // description of convolution + cudnnConvolutionDescriptor_t conv; + cudnnCreateConvolutionDescriptor(&conv); + err = cudnnSetConvolution2dDescriptor(conv, pH, pW, sH, sW, dH, dW, CUDNN_CROSS_CORRELATION, cudnnDataType(gradO->dataType())); + if (err != 0) throw nd4j::cuda_exception::build("depthwiseConv2dBpCUDNN: cudnnSetConvolution2dDescriptor failed", err); + err = cudnnSetConvolutionGroupCount(conv, iC); // set number of groups (depthwise mode) in description of convolution, groupCount == iC + if (err != 0) throw nd4j::cuda_exception::build("depthwiseConv2dBpCUDNN: cudnnSetConvolutionGroupCount failed", err); + + // gradW algorithm description + cudnnConvolutionBwdFilterAlgo_t algoGradW; + err = cudnnGetConvolutionBackwardFilterAlgorithm(*handle, x, dz, conv, dw, CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST, 0, &algoGradW); + if (err != 0) throw nd4j::cuda_exception::build("depthwiseConv2dBpCUDNN: cudnnGetConvolutionBackwardFilterAlgorithm failed", err); + + // gradI algorithm description + cudnnConvolutionBwdDataAlgo_t algoGradI; + err = cudnnGetConvolutionBackwardDataAlgorithm(*handle, dw, dz, conv, x, CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST, 0, &algoGradI); + if (err != 0) throw nd4j::cuda_exception::build("depthwiseConv2dBpCUDNN: cudnnGetConvolutionBackwardDataAlgorithm failed", err); + + // allocate auxiliary device memory for gradW calculation, abbreviation ws means workspace + size_t wsGradWSize; + err = cudnnGetConvolutionBackwardFilterWorkspaceSize(*handle, x, dz, conv, dw, algoGradW, &wsGradWSize); + if (err != 0) throw nd4j::cuda_exception::build("depthwiseConv2dBpCUDNN: cudnnGetConvolutionBackwardFilterWorkspaceSize failed", err); + void* wsGradWData; + auto cudaErr = cudaMalloc(&wsGradWData, wsGradWSize); + if (cudaErr != 0) throw nd4j::cuda_exception::build("depthwiseConv2dBpCUDNN: cudaMalloc for auxiliary workspace memory wsGradWData failed", cudaErr); + + // allocate auxiliary device memory for gradI calculation, abbreviation ws means workspace + size_t wsGradISize; + err = cudnnGetConvolutionBackwardDataWorkspaceSize(*handle, dw, dz, conv, dx, algoGradI, &wsGradISize); + if (err != 0) throw nd4j::cuda_exception::build("depthwiseConv2dBpCUDNN: cudnnGetConvolutionBackwardDataWorkspaceSize failed", err); + void* wsGradIData; + cudaErr = cudaMalloc(&wsGradIData, wsGradISize); + if (cudaErr != 0) throw nd4j::cuda_exception::build("depthwiseConv2dBpCUDNN: cudaMalloc for auxiliary workspace memory wsGradIData failed", cudaErr); + + // provide scaling parameters + const float alpha32(1), beta32(0); + const double alpha64(1), beta64(0); + const void* alpha = gradO->sizeOfT() <= 4 ? reinterpret_cast(&alpha32) : reinterpret_cast(&alpha64); + const void* beta = gradO->sizeOfT() <= 4 ? reinterpret_cast(&beta32) : reinterpret_cast(&beta64); + + NDArray::prepareSpecialUse({gradI, gradW, gradB}, {input, weights, gradO}); + + // run calculation for gradB (if not nullptr) + if(gradB != nullptr) { + cudnnTensorDescriptor_t db; + cudnnCreateTensorDescriptor(&db); + err = cudnnSetTensor4dDescriptor(db, format, cudnnDataType(gradB->dataType()), 1, isNCHW ? gradB->lengthOf() : 1, 1, isNCHW ? 1: gradB->lengthOf()); + if (err != 0) throw nd4j::cuda_exception::build("depthwiseConv2dBpCUDNN: cudnnSetTensor4dDescriptor for gradB failed", err); + + err = cudnnConvolutionBackwardBias(*handle, alpha, dz, gradO->getSpecialBuffer(), beta, db, gradB->getSpecialBuffer()); + if (err != 0) throw nd4j::cuda_exception::build("depthwiseConv2dBpCUDNN: cudnnConvolutionBackwardBias failed", err); + } + + // run calculation for gradW + err = cudnnConvolutionBackwardFilter(*handle, alpha, x, input->getSpecialBuffer(), dz, gradO->getSpecialBuffer(), conv, algoGradW, wsGradWData, wsGradWSize, beta, dw, gradW->getSpecialBuffer()); + if (err != 0) throw nd4j::cuda_exception::build("depthwiseConv2dBpCUDNN: cudnnConvolutionBackwardFilter failed", err); + + // run calculation for gradI + err = cudnnConvolutionBackwardData(*handle, alpha, dw, weights->getSpecialBuffer(), dz, gradO->getSpecialBuffer(), conv, algoGradI, wsGradIData, wsGradISize, beta, dx, gradI->getSpecialBuffer()); + if (err != 0) throw nd4j::cuda_exception::build("depthwiseConv2dBpCUDNN: cudnnConvolutionBackwardData failed", err); + + // cudaErr = cudaStreamSynchronize(*context->getCudaStream()); + // if (cudaErr != 0) + // throw cuda_exception::build("depthwiseConv2dBpCUDNN: cudaStreamSynchronize failed !", cudaErr); + + cudaErr = cudaFree(wsGradWData); + if (cudaErr != 0) throw nd4j::cuda_exception::build("depthwiseConv2dBpCUDNN: cudaFree for auxiliary workspace memory wsGradWData failed", cudaErr); + cudaErr = cudaFree(wsGradIData); + if (cudaErr != 0) throw nd4j::cuda_exception::build("depthwiseConv2dBpCUDNN: cudaFree for auxiliary workspace memory wsGradIData failed", cudaErr); + + NDArray::registerSpecialUse({gradI, gradW, gradB}, {input, weights, gradO}); +} + +////////////////////////////////////////////////////////////////////////// +PLATFORM_IMPL(depthwise_conv2d, ENGINE_CUDA) { + + auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, mC] always + auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] = iC*mC + + auto output = OUTPUT_VARIABLE(0); // [bS, oH, oW, iC*mC] (NHWC) or [bS, iC*mC, oH, oW] (NCHW) + + REQUIRE_TRUE(input->rankOf() == 4, 0, "DEPTHWISECONV2D CUDNN OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf()); + REQUIRE_TRUE(weights->rankOf() == 4, 0, "DEPTHWISECONV2D CUDNN OP: rank of weights array must be equal to 4, but got %i instead !", weights->rankOf()); + + int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(weights->sizeAt(0));// filter(kernel) height + int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(weights->sizeAt(1));// filter(kernel) width + int sH = INT_ARG(2); // strides height + int sW = INT_ARG(3); // strides width + int pH = INT_ARG(4); // paddings height + int pW = INT_ARG(5); // paddings width + int dH = INT_ARG(6); // dilations height + int dW = INT_ARG(7); // dilations width + int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME + int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC + + int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width + int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); + mC = weights->sizeAt(indWmC); // channels multiplier + + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); + + std::vector expectedWeightsShape = {kH, kW, iC, mC}; + REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "DEPTHWISECONV2D CUDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); + REQUIRE_TRUE(output->sizeAt(indIOioC) == iC*mC, 0, "DEPTHWISECONV2D CUDNN OP: the output_channels must be equal to input_channels * channels_multiplier = %i !", iC*mC); + if (bias) + REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "DEPTHWISECONV2D CUDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); + + NDArray* newWeights = new NDArray(weights->ordering(), {iC, mC, kH, kW}, weights->dataType(), weights->getContext()); // cudnn support format {oC, iC/groupCount, kH, kW} + newWeights->assign(weights->permute({2,3,0,1})); // assign permuted weights (kH, kW, iC, mC --> iC, mC, kH, kW) + + NDArray* newInput = input; + NDArray* newGradI = nullptr; + if(paddingMode == 1) // in same paddingMode cudnn doesn't support asymmetric left/right top/bottopm paddings + checkConv2dCUDNNPadAsymmetric(newInput, newGradI, iH, iW, oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, isNCHW); + + depthwiseConv2dCUDNN(block.launchContext(), newInput, newWeights, bias, output, kH,kW,sH,sW,pH,pW,dH,dW, paddingMode, isNCHW); + + if(newInput != input) + delete newInput; + + delete newWeights; + + return Status::OK(); +} + +////////////////////////////////////////////////////////////////////////// +PLATFORM_CHECK(depthwise_conv2d, ENGINE_CUDA) { + + auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, mC] always + auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] = iC*mC + + const int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME, 2-CAUSAL + + const int mC = weights->sizeAt(3); + + const bool badInputType = input->dataType() != DataType::DOUBLE && input->dataType() != DataType::FLOAT32 && input->dataType() != DataType::HALF; + const bool badWeightsType = weights->dataType() != DataType::DOUBLE && weights->dataType() != DataType::FLOAT32 && weights->dataType() != DataType::HALF; + const bool badBiasType = bias == nullptr ? false : (bias->dataType() != DataType::DOUBLE && bias->dataType() != DataType::FLOAT32 && bias->dataType() != DataType::HALF); + + return mC == 1 && paddingMode != 2 && !badInputType && !badWeightsType && !badBiasType; +} + +////////////////////////////////////////////////////////////////////////// +PLATFORM_IMPL(depthwise_conv2d_bp, ENGINE_CUDA) { + + auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW) + auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, mC] always + auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] = [iC*mC] + auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NDHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next + + auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW), epsilon + auto gradW = OUTPUT_VARIABLE(1); // [kH, kW, iC, mC] always + auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] + + REQUIRE_TRUE(input->rankOf() == 4, 0, "DEPTHWISECONV2D_BP CUDNN OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf()); + REQUIRE_TRUE(weights->rankOf() == 4, 0, "DEPTHWISECONV2D_BP CUDNN OP: rank of weights array must be equal to 4, but got %i instead !", weights->rankOf()); + REQUIRE_TRUE(gradO->rankOf() == 4, 0, "DEPTHWISECONV2D_BP CUDNN OP: rank of output gradients (next epsilon) array must be equal to 4, but got %i instead !", gradO->rankOf()); + + int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(weights->sizeAt(0));// filter(kernel) height + int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(weights->sizeAt(1));// filter(kernel) width + int sH = INT_ARG(2); // strides height + int sW = INT_ARG(3); // strides width + int pH = INT_ARG(4); // paddings height + int pW = INT_ARG(5); // paddings width + int dH = INT_ARG(6); // dilations height + int dW = INT_ARG(7); // dilations width + int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME + int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW + + int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width + int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); + mC = weights->sizeAt(indWmC); // channels multiplier + + int trueoH, trueoW; // correct output height, width + ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, paddingMode); + + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); + + std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1}); + std::vector expectedWeightsShape = {kH, kW, iC, mC}; + REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "DEPTHWISECONV2D_BP CUDNN OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); + REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "DEPTHWISECONV2D_BP CUDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); + if(bias) + REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "DEPTHWISECONV2D_BP CUDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); + + + NDArray* newGradW = new NDArray(gradW->ordering(), {iC, mC, kH, kW}, gradW->dataType(), gradW->getContext()); // cudnn support format {oC, iC/groupCount, kH, kW} + NDArray* newWeights = new NDArray(weights->ordering(), {iC, mC, kH, kW}, weights->dataType(), weights->getContext()); + + newWeights->assign(weights->permute({2,3,0,1})); // assign permuted weights (kH, kW, iC, mC --> iC, mC, kH, kW) + + NDArray* newInput = input; + NDArray* newGradI = gradI; + if(paddingMode == 1) // in same paddingMode cudnn doesn't support asymmetric left/right top/bottopm paddings + checkConv2dCUDNNPadAsymmetric(newInput, newGradI, iH, iW, oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, isNCHW); + + depthwiseConv2dBpCUDNN(block.launchContext(), newInput, newWeights, gradO, newGradI, newGradW, gradB, kH,kW,sH,sW,pH,pW,dH,dW,paddingMode,isNCHW); + + newGradW->permutei({2,3,0,1}); // [iC, mC, kH, kW] -> [kH, kW, iC, mC] + gradW->assign(newGradW); + + if(newInput != input) { + + if(isNCHW) + gradI->assign((*newGradI)({0,0, 0,0, 0,gradI->sizeAt(2), 0,gradI->sizeAt(3)})); + else + gradI->assign((*newGradI)({0,0, 0,gradI->sizeAt(1), 0,gradI->sizeAt(2), 0,0})); + + delete newInput; + delete newGradI; + } + + delete newWeights; + delete newGradW; + + return Status::OK(); +} + +PLATFORM_CHECK(depthwise_conv2d_bp, ENGINE_CUDA) { + + auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW) + auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, mC] always + auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] = [iC*mC] + auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NDHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next + + const int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME, 2-CAUSAL + const int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC + + const int mC = weights->sizeAt(3); + + const bool badInputType = input->dataType() != DataType::DOUBLE && input->dataType() != DataType::FLOAT32 && input->dataType() != DataType::HALF; + const bool badWeightsType = weights->dataType() != DataType::DOUBLE && weights->dataType() != DataType::FLOAT32 && weights->dataType() != DataType::HALF; + const bool badGradOType = gradO->dataType() != DataType::DOUBLE && gradO->dataType() != DataType::FLOAT32 && gradO->dataType() != DataType::HALF; + const bool badBiasType = bias == nullptr ? false : (bias->dataType() != DataType::DOUBLE && bias->dataType() != DataType::FLOAT32 && bias->dataType() != DataType::HALF); + + return mC == 1 && isNCHW && paddingMode != 2 && !badInputType && !badWeightsType && !badGradOType && !badBiasType; +} + + +} +} +} diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling2d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling2d.cpp index 9a3b2916b..bf614bfab 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling2d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling2d.cpp @@ -28,11 +28,12 @@ #include using namespace dnnl; +using namespace samediff; namespace nd4j { namespace ops { namespace platforms { - PLATFORM_IMPL(avgpool2d) { + PLATFORM_IMPL(avgpool2d, ENGINE_CPU) { auto input = INPUT_VARIABLE(0); REQUIRE_TRUE(input->rankOf() == 4, 0, "Input should have rank of 4, but got %i instead", @@ -128,7 +129,7 @@ namespace nd4j { return Status::OK(); } - PLATFORM_CHECK(avgpool2d) { + PLATFORM_CHECK(avgpool2d, ENGINE_CPU) { auto input = INPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0); diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling2d_bp.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling2d_bp.cpp index 428bd6042..af1fd04fd 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling2d_bp.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling2d_bp.cpp @@ -32,7 +32,7 @@ using namespace dnnl; namespace nd4j { namespace ops { namespace platforms { - PLATFORM_IMPL(avgpool2d_bp) { + PLATFORM_IMPL(avgpool2d_bp, ENGINE_CPU) { auto input = INPUT_VARIABLE( 0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) auto gradO = INPUT_VARIABLE( @@ -138,7 +138,7 @@ namespace nd4j { return Status::OK(); } - PLATFORM_CHECK(avgpool2d_bp) { + PLATFORM_CHECK(avgpool2d_bp, ENGINE_CPU) { auto input = INPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0); diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling3d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling3d.cpp index 22ace87de..2456625ef 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling3d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling3d.cpp @@ -32,7 +32,7 @@ using namespace dnnl; namespace nd4j { namespace ops { namespace platforms { - PLATFORM_IMPL(avgpool3dnew) { + PLATFORM_IMPL(avgpool3dnew, ENGINE_CPU) { auto input = INPUT_VARIABLE( 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) auto output = OUTPUT_VARIABLE( @@ -130,7 +130,7 @@ namespace nd4j { return Status::OK(); } - PLATFORM_CHECK(avgpool3dnew) { + PLATFORM_CHECK(avgpool3dnew, ENGINE_CPU) { auto input = INPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0); diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling3d_bp.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling3d_bp.cpp index 0c52608a0..3fd8ab293 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling3d_bp.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling3d_bp.cpp @@ -31,7 +31,7 @@ using namespace dnnl; namespace nd4j { namespace ops { namespace platforms { - PLATFORM_IMPL(avgpool3dnew_bp) { + PLATFORM_IMPL(avgpool3dnew_bp, ENGINE_CPU) { auto input = INPUT_VARIABLE( 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) auto gradO = INPUT_VARIABLE( @@ -143,7 +143,7 @@ namespace nd4j { return Status::OK(); } - PLATFORM_CHECK(avgpool3dnew_bp) { + PLATFORM_CHECK(avgpool3dnew_bp, ENGINE_CPU) { auto input = INPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0); diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/batchnorm.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/batchnorm.cpp index e66589b0a..8974cef14 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/batchnorm.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/batchnorm.cpp @@ -339,43 +339,43 @@ static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const // x - mean NDArray xMinusMean(x); // empty array with same shape as x - const_cast(x)->applyBroadcast(nd4j::broadcast::Subtract, axes, mean, &xMinusMean); + const_cast(x)->applyBroadcast(nd4j::broadcast::Subtract, axes, *mean, xMinusMean); // stdInv NDArray stdInv = *variance + epsilon; - stdInv.applyTransform(transform::Reciprocal); // 1 / (variance + epsilon) - stdInv.applyTransform(transform::Sqrt); // 1 / (variance + epsilon)^0.5 + stdInv.applyTransform(transform::Reciprocal, stdInv); // 1 / (variance + epsilon) + stdInv.applyTransform(transform::Sqrt, stdInv); // 1 / (variance + epsilon)^0.5 // dfdm / N - auto dfdm = dLdO->reduceAlongDims(nd4j::reduce::Sum, excludedAxes); + auto dfdm = dLdO->reduceAlongDimension(nd4j::reduce::Sum, excludedAxes); dfdm *= stdInv; dfdm *= -Ninv; // dvdm / 2 NDArray dvdm(mean); // empty array with same shape as mean - xMinusMean.reduceAlongDimension(nd4j::reduce::Sum, &dvdm, excludedAxes); + xMinusMean.reduceAlongDimension(nd4j::reduce::Sum, dvdm, excludedAxes); dvdm *= -Ninv; // (2/N)*dfdv NDArray dfdv(variance); // empty array with same shape as variance - (xMinusMean * *dLdO).reduceAlongDimension(nd4j::reduce::Sum, &dfdv, excludedAxes); + (xMinusMean * *dLdO).reduceAlongDimension(nd4j::reduce::Sum, dfdv, excludedAxes); dfdv *= stdInv*stdInv*stdInv; dfdv *= -Ninv; // dvdm/2 + (x - m) - xMinusMean.applyBroadcast(nd4j::broadcast::Add, axes, &dvdm); + xMinusMean.applyBroadcast(nd4j::broadcast::Add, axes, dvdm, xMinusMean); // dfdv * (dvdm/2 + (x - m)) - xMinusMean.applyBroadcast(nd4j::broadcast::Multiply, axes, &dfdv); + xMinusMean.applyBroadcast(nd4j::broadcast::Multiply, axes, dfdv, xMinusMean); // add dfdm / N - xMinusMean.applyBroadcast(nd4j::broadcast::Add, axes, &dfdm); + xMinusMean.applyBroadcast(nd4j::broadcast::Add, axes, dfdm, xMinusMean); // * gamma auto gamma = (*weights)({0,1, 0,0}); - xMinusMean.applyBroadcast(nd4j::broadcast::Multiply, axes, &gamma); + xMinusMean.applyBroadcast(nd4j::broadcast::Multiply, axes, gamma, xMinusMean); *dLdI += xMinusMean; } -PLATFORM_IMPL(batchnorm) { +PLATFORM_IMPL(batchnorm, ENGINE_CPU) { auto input = INPUT_VARIABLE(0); // 2D:nc, 4D:nchw, 5D:ncdhw auto mean = INPUT_VARIABLE(1); // [c] @@ -455,7 +455,7 @@ PLATFORM_IMPL(batchnorm) { } ////////////////////////////////////////////////////////////////////////// -PLATFORM_CHECK(batchnorm) { +PLATFORM_CHECK(batchnorm, ENGINE_CPU) { // we don't want to use mkldnn if cpu doesn't support avx/avx2 // if (::optimalLevel() < 2) // return false; @@ -632,7 +632,7 @@ PLATFORM_CHECK(batchnorm) { ////////////////////////////////////////////////////////////////////////// -PLATFORM_IMPL(batchnorm_bp) { +PLATFORM_IMPL(batchnorm_bp, ENGINE_CPU) { NDArray* input = INPUT_VARIABLE(0); // 2D:nc, 4D:nchw, 5D:ncdhw NDArray* mean = INPUT_VARIABLE(1); // [c] @@ -735,7 +735,7 @@ PLATFORM_IMPL(batchnorm_bp) { } ////////////////////////////////////////////////////////////////////////// -PLATFORM_CHECK(batchnorm_bp) { +PLATFORM_CHECK(batchnorm_bp, ENGINE_CPU) { NDArray* input = INPUT_VARIABLE(0); // 2D:nc, 4D:nchw, 5D:ncdhw NDArray* mean = INPUT_VARIABLE(1); // [c] NDArray* variance = INPUT_VARIABLE(2); // [c] diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/conv2d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/conv2d.cpp index a01679740..ba1711032 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/conv2d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/conv2d.cpp @@ -113,7 +113,7 @@ static void conv2d_mkldnn(nd4j::graph::Context &block, const NDArray *input, con } ////////////////////////////////////////////////////////////////////// -PLATFORM_IMPL(conv2d) { +PLATFORM_IMPL(conv2d, ENGINE_CPU) { auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] @@ -137,7 +137,7 @@ PLATFORM_IMPL(conv2d) { return Status::OK(); } -PLATFORM_CHECK(conv2d) { +PLATFORM_CHECK(conv2d, ENGINE_CPU) { // we don't want to use mkldnn if cpu doesn't support avx/avx2 if (::optimalLevel() < 2) return false; @@ -151,7 +151,7 @@ PLATFORM_CHECK(conv2d) { } ////////////////////////////////////////////////////////////////////// -PLATFORM_IMPL(conv2d_bp) { +PLATFORM_IMPL(conv2d_bp, ENGINE_CPU) { auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] @@ -328,7 +328,7 @@ PLATFORM_IMPL(conv2d_bp) { return Status::OK(); } -PLATFORM_CHECK(conv2d_bp) { +PLATFORM_CHECK(conv2d_bp, ENGINE_CPU) { auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/conv3d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/conv3d.cpp index 1e28e76a5..0a79df793 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/conv3d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/conv3d.cpp @@ -34,7 +34,7 @@ namespace ops { namespace platforms { ////////////////////////////////////////////////////////////////////// -PLATFORM_IMPL(conv3dnew) { +PLATFORM_IMPL(conv3dnew, ENGINE_CPU) { auto input = INPUT_VARIABLE( 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always @@ -150,7 +150,7 @@ PLATFORM_IMPL(conv3dnew) { return Status::OK(); } -PLATFORM_CHECK(conv3dnew) { +PLATFORM_CHECK(conv3dnew, ENGINE_CPU) { // we don't want to use mkldnn if cpu doesn't support avx/avx2 if (::optimalLevel() < 2) return false; @@ -167,7 +167,7 @@ PLATFORM_CHECK(conv3dnew) { ////////////////////////////////////////////////////////////////////// -PLATFORM_IMPL(conv3dnew_bp) { +PLATFORM_IMPL(conv3dnew_bp, ENGINE_CPU) { auto input = INPUT_VARIABLE( 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) auto weights = INPUT_VARIABLE( @@ -374,7 +374,7 @@ PLATFORM_IMPL(conv3dnew_bp) { return Status::OK(); } -PLATFORM_CHECK(conv3dnew_bp) { +PLATFORM_CHECK(conv3dnew_bp, ENGINE_CPU) { auto input = INPUT_VARIABLE( 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) auto weights = INPUT_VARIABLE( diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d.cpp index ced37aea8..6db569eec 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d.cpp @@ -34,13 +34,13 @@ namespace platforms { ////////////////////////////////////////////////////////////////////////// static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, - const int isSameMode) { + const int paddingMode) { - // input [bS, iH, iW, iC] nchw, mkl doesn't support format nhwc + // input [bS, iC, iH, iW] nchw, mkl doesn't support format nhwc // weights [oC, iC, kH, kW] always, mkl doesn't support weights format [kH, kW, oC, iC] // bias [oC], may be nullptr - // output [bS, oH, oW, oC] nchw, mkl doesn't support format nhwc + // output [bS, oC, oH, oW] nchw, mkl doesn't support format nhwc int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes @@ -179,12 +179,12 @@ static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const N ////////////////////////////////////////////////////////////////////////// static void deconv2dBackPropMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, - const int isSameMode) { + const int paddingMode) { - // input and gradI [bS, iH, iW, iC], mkl doesn't support ndhwc format + // input and gradI [bS, iC, iH, iW], mkl doesn't support ndhwc format // weights and gradW [oC, iC, kH, kW] always, mkl doesn't support weights format [kH, kW, oC, iC] // gradB [oC], may be nullptr - // gradO [bS, oH, oW, oC] + // gradO [bS, oC, oH, oW] int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes @@ -349,7 +349,7 @@ static void deconv2dBackPropMKLDNN(const NDArray* input, const NDArray* weights, ////////////////////////////////////////////////////////////////////////// -PLATFORM_IMPL(deconv2d) { +PLATFORM_IMPL(deconv2d, ENGINE_CPU) { auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) auto weights = INPUT_VARIABLE(1); // [kH, kW, oC, iC] always @@ -368,19 +368,19 @@ PLATFORM_IMPL(deconv2d) { int pW = INT_ARG(5); // paddings width int dH = INT_ARG(6); // dilations height int dW = INT_ARG(7); // dilations width - int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME + int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH); - std::vector expectedWeightsShape = {kH, kW, oC, iC}; + std::vector expectedWeightsShape = {kH, kW, oC, iC}; REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV2D_MKLDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); if (bias) REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DECONV2D_MKLDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); - if(isSameMode){ // SAME + if(paddingMode){ // SAME //Note: we're intentionally swapping iH and oH, to calculated the padding for a"normal" conv (not deconv) forward pass ConvolutionUtils::calcPadding2D(pH, pW, iH, iW, oH, oW, kH, kW, sH, sW, dH, dW); } @@ -394,7 +394,7 @@ PLATFORM_IMPL(deconv2d) { output = new NDArray(output->permute({0,3,1,2})); // [bS, oH, oW, oC] -> [bS, oC, oH, oW] } - deconv2dMKLDNN(input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode); + deconv2dMKLDNN(input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode); delete weights; @@ -406,7 +406,7 @@ PLATFORM_IMPL(deconv2d) { return Status::OK(); } -PLATFORM_CHECK(deconv2d) { +PLATFORM_CHECK(deconv2d, ENGINE_CPU) { // we don't want to use mkldnn if cpu doesn't support avx/avx2 // if (::optimalLevel() < 2) // return false; @@ -419,14 +419,14 @@ PLATFORM_CHECK(deconv2d) { int dH = INT_ARG(6); // dilations height int dW = INT_ARG(7); // dilations width - int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME + int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME const DataType xType = input->dataType(); const DataType wType = weights->dataType(); const DataType zType = output->dataType(); const DataType bType = bias != nullptr ? bias->dataType() : zType; - return block.isUseMKLDNN() && (dH <= 1 && dW <= 1 && !isSameMode) && + return block.isUseMKLDNN() && (dH <= 1 && dW <= 1 && !paddingMode) && ( (xType==DataType::FLOAT32 && wType==DataType::FLOAT32 && bType==DataType::FLOAT32 && zType==DataType::FLOAT32) || ((xType==DataType::UINT8 || xType==DataType::INT8) && wType==DataType::INT8 && (zType==DataType::UINT8 || zType==DataType::INT8 || zType==DataType::INT32 || zType==DataType::FLOAT32) && bType == zType) @@ -435,7 +435,7 @@ PLATFORM_CHECK(deconv2d) { ////////////////////////////////////////////////////////////////////////// -PLATFORM_IMPL(deconv2d_bp) { +PLATFORM_IMPL(deconv2d_bp, ENGINE_CPU) { auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCDHW) auto weights = INPUT_VARIABLE(1); // [kH, kW, oC, iC] always @@ -459,7 +459,7 @@ PLATFORM_IMPL(deconv2d_bp) { int pW = INT_ARG(5); // paddings width int dH = INT_ARG(6); // dilations height int dW = INT_ARG(7); // dilations width - int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME + int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; @@ -467,7 +467,7 @@ PLATFORM_IMPL(deconv2d_bp) { ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH); int trueoH, trueoW; // true output height, width - ConvolutionUtils::calcOutSizeDeconv2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); + ConvolutionUtils::calcOutSizeDeconv2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, paddingMode); std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1}); std::vector expectedWeightsShape = {kH, kW, oC, iC}; @@ -476,7 +476,7 @@ PLATFORM_IMPL(deconv2d_bp) { if(bias) REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DECONV2D_MKLDNN_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); - if(isSameMode){ // SAME + if(paddingMode){ // SAME //Note: we're intentionally swapping iH and oH, to calculated the padding for a"normal" conv (not deconv) forward pass ConvolutionUtils::calcPadding2D(pH, pW, iH, iW, oH, oW, kH, kW, sH, sW, dH, dW); } @@ -492,7 +492,7 @@ PLATFORM_IMPL(deconv2d_bp) { gradO = new NDArray(gradO->permute({0,3,1,2})); // [bS, oH, oW, oC] -> [bS, oC, oH, oW] } - deconv2dBackPropMKLDNN(input, weights, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode); + deconv2dBackPropMKLDNN(input, weights, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode); delete weights; delete gradW; @@ -506,7 +506,7 @@ PLATFORM_IMPL(deconv2d_bp) { return Status::OK(); } -PLATFORM_CHECK(deconv2d_bp) { +PLATFORM_CHECK(deconv2d_bp, ENGINE_CPU) { auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCDHW) auto weights = INPUT_VARIABLE(1); // [kH, kW, oC, iC] always auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] @@ -518,7 +518,7 @@ PLATFORM_CHECK(deconv2d_bp) { int dH = INT_ARG(6); // dilations height int dW = INT_ARG(7); // dilations width - int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME + int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME const DataType xType = input->dataType(); const DataType wType = weights->dataType(); @@ -528,7 +528,7 @@ PLATFORM_CHECK(deconv2d_bp) { const DataType gradWType = gradW->dataType(); const DataType gradBType = gradB != nullptr ? gradB->dataType() : DataType::FLOAT32; - return block.isUseMKLDNN() && (dH <= 1 && dW <= 1 && !isSameMode) && ((xType==DataType::FLOAT32 || xType==DataType::BFLOAT16) && (wType==DataType::FLOAT32 || wType==DataType::BFLOAT16) && (gradOType==DataType::FLOAT32 || gradOType==DataType::BFLOAT16) && (gradIType==DataType::FLOAT32 || gradIType==DataType::BFLOAT16) && (gradWType==DataType::FLOAT32 || gradWType==DataType::BFLOAT16) && (gradBType==DataType::FLOAT32 || gradBType==DataType::BFLOAT16) ); + return block.isUseMKLDNN() && (dH <= 1 && dW <= 1 && !paddingMode) && ((xType==DataType::FLOAT32 || xType==DataType::BFLOAT16) && (wType==DataType::FLOAT32 || wType==DataType::BFLOAT16) && (gradOType==DataType::FLOAT32 || gradOType==DataType::BFLOAT16) && (gradIType==DataType::FLOAT32 || gradIType==DataType::BFLOAT16) && (gradWType==DataType::FLOAT32 || gradWType==DataType::BFLOAT16) && (gradBType==DataType::FLOAT32 || gradBType==DataType::BFLOAT16) ); } diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d_tf.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d_tf.cpp index fac53e877..90ddb828e 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d_tf.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d_tf.cpp @@ -145,7 +145,7 @@ static void deconv2TFdBackPropMKLDNN(const NDArray* weights, const NDArray* grad ////////////////////////////////////////////////////////////////////////// -PLATFORM_IMPL(deconv2d_tf) { +PLATFORM_IMPL(deconv2d_tf, ENGINE_CPU) { auto gradO = INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always @@ -222,7 +222,7 @@ PLATFORM_IMPL(deconv2d_tf) { return Status::OK(); } -PLATFORM_CHECK(deconv2d_tf) { +PLATFORM_CHECK(deconv2d_tf, ENGINE_CPU) { auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always auto gradO = INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCDHW), gradI diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/deconv3d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/deconv3d.cpp index 7259ea0db..a678e0185 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/deconv3d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/deconv3d.cpp @@ -34,8 +34,7 @@ namespace platforms { ////////////////////////////////////////////////////////////////////////// static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, - const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, - const int isSameMode) { + const int pD, const int pH, const int pW, const int dD, const int dH, const int dW) { // input [bS, iD, iH, iW, iC] ncdhw, mkl doesn't support format ndhwc // weights [oC, iC, kD, kH, kW] always, mkl doesn't support weights format [kD, kH, kW, oC, iC] @@ -182,8 +181,10 @@ static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const N ////////////////////////////////////////////////////////////////////////// static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, - const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, - const int isSameMode) { + const int kD, const int kH, const int kW, + const int sD, const int sH, const int sW, + const int pD, const int pH, const int pW, + const int dD, const int dH, const int dW) { // input and gradI [bS, iD, iH, iW, iC], mkl doesn't support ndhwc format // weights and gradW [oC, iC, kD, kH, kW] always, mkl doesn't support weights format [kD, kH, kW, oC, iC] @@ -359,7 +360,7 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights, ////////////////////////////////////////////////////////////////////////// -PLATFORM_IMPL(deconv3d) { +PLATFORM_IMPL(deconv3d, ENGINE_CPU) { auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, oC, iC] always @@ -408,7 +409,7 @@ PLATFORM_IMPL(deconv3d) { output = new NDArray(output->permute({0,4,1,2,3})); // [bS, oD, oH, oW, oC] -> [bS, oC, oD, oH, oW] } - deconv3dMKLDNN(input, weights, bias, output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isSameMode); + deconv3dMKLDNN(input, weights, bias, output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW); delete weights; @@ -420,7 +421,7 @@ PLATFORM_IMPL(deconv3d) { return Status::OK(); } -PLATFORM_CHECK(deconv3d) { +PLATFORM_CHECK(deconv3d, ENGINE_CPU) { // we don't want to use mkldnn if cpu doesn't support avx/avx2 // if (::optimalLevel() < 2) // return false; @@ -450,7 +451,7 @@ PLATFORM_CHECK(deconv3d) { ////////////////////////////////////////////////////////////////////////// -PLATFORM_IMPL(deconv3d_bp) { +PLATFORM_IMPL(deconv3d_bp, ENGINE_CPU) { auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, oC, iC] always @@ -509,7 +510,7 @@ PLATFORM_IMPL(deconv3d_bp) { gradO = new NDArray(gradO->permute({0,4,1,2,3})); // [bS, oD, oH, oW, oC] -> [bS, oC, oD, oH, oW] } - deconv3dBackPropMKLDNN(input, weights, gradO, gradI, gradW, gradB, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isSameMode); + deconv3dBackPropMKLDNN(input, weights, gradO, gradI, gradW, gradB, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW); delete weights; delete gradW; @@ -524,7 +525,7 @@ PLATFORM_IMPL(deconv3d_bp) { } -PLATFORM_CHECK(deconv3d_bp) { +PLATFORM_CHECK(deconv3d_bp, ENGINE_CPU) { auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NHWC) or [bS, iD, iC, iH, iW] (NCDHW) auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, oC, iC] always auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/depthwiseConv2d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/depthwiseConv2d.cpp new file mode 100644 index 000000000..f3b745d09 --- /dev/null +++ b/libnd4j/include/ops/declarable/platform/mkldnn/depthwiseConv2d.cpp @@ -0,0 +1,505 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Yurii Shyrma (iuriish@yahoo.com) +// + +#include +#include +#include +#include +#include +#include "mkldnnUtils.h" + +using namespace dnnl; + +namespace nd4j { +namespace ops { +namespace platforms { + +////////////////////////////////////////////////////////////////////////// +static void depthwiseConv2dMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, + const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, + const int paddingMode, const bool isNCHW) { + + // mkl supports only following case: mC = 1, oC = iC + + // input [bS, iC, iH, iW] nchw or [bS, iH, iW, iC] nhwc, since mkl doesn't support nhwc format we'll permute when nhwc is given + // weights [kH, kW, iC, mC], mkl doesn't support this format, so we'll make permute + // bias [oC], may be nullptr + // output [bS, oC, oH, oW] nchw or [bS, oH, oW, oC] nhwc + // oC = iC*mC + + int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; + int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); + mC = weights->sizeAt(indWmC); // channels multiplier + + const int pWSame = (paddingMode == 2 && dW > 1) ? ((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 : pW; // dH == 1 for causal mode in conv1d + + dnnl::memory::dims strides = { sH, sW }; + dnnl::memory::dims padding = { pH, pW }; + dnnl::memory::dims padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pWSame }; + dnnl::memory::dims dilation = { dH-1, dW-1}; + + // input type + dnnl::memory::data_type xType; + if(input->dataType() == DataType::FLOAT32) + xType = dnnl::memory::data_type::f32; + else if(input->dataType() == DataType::HALF) + xType = dnnl::memory::data_type::f16; + else if(input->dataType() == DataType::UINT8) + xType = dnnl::memory::data_type::u8; + else + xType = dnnl::memory::data_type::s8; + + // weights type + dnnl::memory::data_type wType = xType; + if(xType == dnnl::memory::data_type::u8) + wType = dnnl::memory::data_type::s8; + + // output and bias type (have the same types) + dnnl::memory::data_type zType; + if(output->dataType() == DataType::FLOAT32) + zType = dnnl::memory::data_type::f32; + else if(output->dataType() == DataType::HALF) + zType = dnnl::memory::data_type::f16; + else if(output->dataType() == DataType::UINT8) + zType = dnnl::memory::data_type::u8; + else if(output->dataType() == DataType::INT8) + zType = dnnl::memory::data_type::s8; + else + zType = dnnl::memory::data_type::s32; + + dnnl::memory::format_tag xzFrmat = dnnl::memory::format_tag::nchw; + dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::goihw; + + dnnl::memory::dims xDims = {bS, iC, iH, iW}; + dnnl::memory::dims wDims = {iC, mC, 1, kH, kW}; + dnnl::memory::dims zDims = {bS, oC, oH, oW}; + + // memory descriptors for arrays + + // input + dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any); + dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xzFrmat); + x_user_md.data.format_kind = dnnl_blocked; // overrides format NHWC -> NCHW + x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0); + x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(isNCHW ? 1 : 3); + x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(isNCHW ? 2 : 1); + x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(isNCHW ? 3 : 2); + + // weights, make permute [kH, kW, iC, mC] -> [iC, mC, 1, kH, kW]; + dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any); + dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormat); + w_user_md.data.format_kind = dnnl_blocked; // overrides format + w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(2); // permute + w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(3); + w_user_md.data.format_desc.blocking.strides[2] = 0; + w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(0); + w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(1); + + // bias + dnnl::memory::desc b_mkl_md; + if(bias != nullptr) + b_mkl_md = dnnl::memory::desc({oC}, zType, dnnl::memory::format_tag::x); + + // output + dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, zType, dnnl::memory::format_tag::any); + dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, zType, xzFrmat); + z_user_md.data.format_kind = dnnl_blocked; // overrides format + z_user_md.data.format_desc.blocking.strides[0] = output->strideAt(0); + z_user_md.data.format_desc.blocking.strides[1] = output->strideAt(isNCHW ? 1 : 3); + z_user_md.data.format_desc.blocking.strides[2] = output->strideAt(isNCHW ? 2 : 1); + z_user_md.data.format_desc.blocking.strides[3] = output->strideAt(isNCHW ? 3 : 2); + + auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); + + // operation primitive description + dnnl::convolution_forward::desc op_desc(dnnl::prop_kind::forward_inference, dnnl::algorithm::convolution_auto, + x_mkl_md, w_mkl_md, b_mkl_md, z_mkl_md, strides, dilation, padding, padding_r); + dnnl::convolution_forward::primitive_desc op_prim_desc(op_desc, engine); + + // arguments (memory buffers) necessary for calculations + std::unordered_map args; + + dnnl::stream stream(engine); + + // provide memory buffers and check whether reorder is required + + // input + auto x_user_mem = dnnl::memory(x_user_md, engine, input->getBuffer()); + const bool xReorder = op_prim_desc.src_desc() != x_user_mem.get_desc(); + auto x_mkl_mem = xReorder ? dnnl::memory(op_prim_desc.src_desc(), engine) : x_user_mem; + if (xReorder) + dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem); + args[DNNL_ARG_SRC] = x_mkl_mem; + + // weights + auto w_user_mem = dnnl::memory(w_user_md, engine, weights->getBuffer()); + const bool wReorder = op_prim_desc.weights_desc() != w_user_mem.get_desc(); + auto w_mkl_mem = wReorder ? dnnl::memory(op_prim_desc.weights_desc(), engine) : w_user_mem; + if (wReorder) + dnnl::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem); + args[DNNL_ARG_WEIGHTS] = w_mkl_mem; + + // bias + if(bias != nullptr) { + auto b_mkl_mem = dnnl::memory(b_mkl_md, engine, bias->getBuffer()); + args[DNNL_ARG_BIAS] = b_mkl_mem; + } + + // output + auto z_user_mem = dnnl::memory(z_user_md, engine, output->getBuffer()); + const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc(); + auto z_mkl_mem = zReorder ? dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem; + args[DNNL_ARG_DST] = z_mkl_mem; + + // run calculations + dnnl::convolution_forward(op_prim_desc).execute(stream, args); + + // reorder outputs if necessary + if (zReorder) + dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem); + + stream.wait(); + // shape::printArray(z_mkl_mem.map_data(),8); +} + +////////////////////////////////////////////////////////////////////////// +static void depthwiseConv2dNackPropMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, + const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, + const int paddingMode, const bool isNCHW) { + + // mkl supports only following case: mC = 1, oC = iC + + // input, gradI [bS, iC, iH, iW] nchw or [bS, iH, iW, iC] nhwc, since mkl doesn't support nhwc format we'll permute when nhwc is given + // weights, gradW [kH, kW, iC, mC], mkl doesn't support this format, so we'll make permute + // gradB [oC], may be nullptr + // gradO [bS, oC, oH, oW] nchw or [bS, oH, oW, oC] nhwc + // oC = iC*mC + + int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; + int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); + mC = weights->sizeAt(indWmC); + + const int pWSame = (paddingMode == 2 && dW > 1) ? ((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 : pW; // dH == 1 for causal mode in conv1d + + dnnl::memory::dims strides = { sH, sW }; + dnnl::memory::dims padding = { pH, pW }; + dnnl::memory::dims padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pWSame }; + dnnl::memory::dims dilation = { dH-1, dW-1}; + + // input type + dnnl::memory::data_type xType = input->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16; + // weights type + dnnl::memory::data_type wType = weights->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16; + // gradO type + dnnl::memory::data_type gradOType = gradO->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16; + // gradI type + dnnl::memory::data_type gradIType = gradI->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16; + // gradW type + dnnl::memory::data_type gradWType = gradW->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16; + // gradB type + dnnl::memory::data_type gradBType = gradB != nullptr ? (gradB->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16) : dnnl::memory::data_type::f32; + + dnnl::memory::format_tag xFormat = dnnl::memory::format_tag::nchw; // isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; + dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::goihw; + + dnnl::memory::dims xDims = {bS, iC, iH, iW}; + dnnl::memory::dims wDims = {iC, mC, 1, kH, kW}; + dnnl::memory::dims zDims = {bS, oC, oH, oW}; + + // memory descriptors for arrays + + // input + dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any); + dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormat); + x_user_md.data.format_kind = dnnl_blocked; // overrides format + x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0); + x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(isNCHW ? 1 : 3); + x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(isNCHW ? 2 : 1); + x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(isNCHW ? 3 : 2); + + // weights, make permute [kH, kW, iC, mC] -> [iC, mC, 1, kH, kW]; + dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any); + dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormat); + w_user_md.data.format_kind = dnnl_blocked; // overrides format + w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(2); // permute + w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(3); + w_user_md.data.format_desc.blocking.strides[2] = 0; + w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(0); + w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(1); + + // gradO + dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any); + dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xFormat); + gradO_user_md.data.format_kind = dnnl_blocked; // overrides format + gradO_user_md.data.format_desc.blocking.strides[0] = gradO->strideAt(0); + gradO_user_md.data.format_desc.blocking.strides[1] = gradO->strideAt(isNCHW ? 1 : 3); + gradO_user_md.data.format_desc.blocking.strides[2] = gradO->strideAt(isNCHW ? 2 : 1); + gradO_user_md.data.format_desc.blocking.strides[3] = gradO->strideAt(isNCHW ? 3 : 2); + + // gradI + dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any); + dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xFormat); + gradI_user_md.data.format_kind = dnnl_blocked; // overrides format + gradI_user_md.data.format_desc.blocking.strides[0] = gradI->strideAt(0); + gradI_user_md.data.format_desc.blocking.strides[1] = gradI->strideAt(isNCHW ? 1 : 3); + gradI_user_md.data.format_desc.blocking.strides[2] = gradI->strideAt(isNCHW ? 2 : 1); + gradI_user_md.data.format_desc.blocking.strides[3] = gradI->strideAt(isNCHW ? 3 : 2); + + // gradW, make permute [kH, kW, iC, mC] -> [iC, mC, 1, kH, kW]; + dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, gradWType, dnnl::memory::format_tag::any); + dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, gradWType, wFormat); + gradW_user_md.data.format_kind = dnnl_blocked; // overrides format + gradW_user_md.data.format_desc.blocking.strides[0] = gradW->strideAt(2); // permute + gradW_user_md.data.format_desc.blocking.strides[1] = gradW->strideAt(3); + gradW_user_md.data.format_desc.blocking.strides[2] = 0; + gradW_user_md.data.format_desc.blocking.strides[3] = gradW->strideAt(0); + gradW_user_md.data.format_desc.blocking.strides[4] = gradW->strideAt(1); + + // gradB + dnnl::memory::desc gradB_mkl_md; + if(gradB != nullptr) + gradB_mkl_md = dnnl::memory::desc({oC}, gradBType, dnnl::memory::format_tag::x); + + auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); + + // forward primitive description + dnnl::convolution_forward::desc op_ff_desc(dnnl::prop_kind::forward_inference, dnnl::algorithm::convolution_auto, x_mkl_md, w_mkl_md, gradB_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r); + dnnl::convolution_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine); + + // backward data primitive description + dnnl::convolution_backward_data::desc op_data_bp_desc(dnnl::algorithm::convolution_auto, gradI_mkl_md, w_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r); + dnnl::convolution_backward_data::primitive_desc op_data_bp_prim_desc(op_data_bp_desc, engine, op_ff_prim_desc); + + // backward weights primitive description + dnnl::convolution_backward_weights::desc op_weights_bp_desc(dnnl::algorithm::convolution_auto, x_mkl_md, gradW_mkl_md, gradB_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r); + dnnl::convolution_backward_weights::primitive_desc op_weights_bp_prim_desc(op_weights_bp_desc, engine, op_ff_prim_desc); + + // arguments (memory buffers) necessary for calculations + std::unordered_map args; + + dnnl::stream stream(engine); + + // provide memory buffers and check whether reorder is required + + // input + auto x_user_mem = dnnl::memory(x_user_md, engine, input->getBuffer()); + const bool xReorder = op_weights_bp_prim_desc.src_desc() != x_user_mem.get_desc(); + auto x_mkl_mem = xReorder ? dnnl::memory(op_weights_bp_prim_desc.src_desc(), engine) : x_user_mem; + if (xReorder) + dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem); + args[DNNL_ARG_SRC] = x_mkl_mem; + + // weights + auto w_user_mem = dnnl::memory(w_user_md, engine, weights->getBuffer()); + const bool wReorder = op_data_bp_prim_desc.weights_desc() != w_user_mem.get_desc(); + auto w_mkl_mem = wReorder ? dnnl::memory(op_data_bp_prim_desc.weights_desc(), engine) : w_user_mem; + if (wReorder) + dnnl::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem); + args[DNNL_ARG_WEIGHTS] = w_mkl_mem; + + // gradO + auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer()); + const bool gradOReorder = op_data_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc(); + auto gradO_mkl_mem = gradOReorder ? dnnl::memory(op_data_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem; + if (gradOReorder) + dnnl::reorder(gradO_user_mem, gradO_mkl_mem).execute(stream, gradO_user_mem, gradO_mkl_mem); + args[DNNL_ARG_DIFF_DST] = gradO_mkl_mem; + + // gradI + auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->getBuffer()); + const bool gradIReorder = op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc(); + auto gradI_mkl_mem = gradIReorder ? dnnl::memory(op_data_bp_prim_desc.diff_src_desc(), engine) : gradI_user_mem; + args[DNNL_ARG_DIFF_SRC] = gradI_mkl_mem; + + // gradW + auto gradW_user_mem = dnnl::memory(gradW_user_md, engine, gradW->getBuffer()); + const bool gradWReorder = op_weights_bp_prim_desc.diff_weights_desc() != gradW_user_mem.get_desc(); + auto gradW_mkl_mem = gradWReorder ? dnnl::memory(op_weights_bp_prim_desc.diff_weights_desc(), engine) : gradW_user_mem; + args[DNNL_ARG_DIFF_WEIGHTS] = gradW_mkl_mem; + + // gradB + if(gradB != nullptr) { + auto gradB_mkl_mem = dnnl::memory(gradB_mkl_md, engine, gradB->getBuffer()); + args[DNNL_ARG_DIFF_BIAS] = gradB_mkl_mem; + } + + // run backward data calculations + dnnl::convolution_backward_data(op_data_bp_prim_desc).execute(stream, args); + + // run backward weights calculations + dnnl::convolution_backward_weights(op_weights_bp_prim_desc).execute(stream, args); + + // reorder gradI if necessary + if (gradIReorder) + dnnl::reorder(gradI_mkl_mem, gradI_user_mem).execute(stream, gradI_mkl_mem, gradI_user_mem); + if (gradWReorder) + dnnl::reorder(gradW_mkl_mem, gradW_user_mem).execute(stream, gradW_mkl_mem, gradW_user_mem); + + stream.wait(); + + // shape::printArray(z_mkl_mem.map_data(),8); +} + + +////////////////////////////////////////////////////////////////////// +PLATFORM_IMPL(depthwise_conv2d, ENGINE_CPU) { + + auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, mC] always + auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] = iC*mC + + auto output = OUTPUT_VARIABLE(0); // [bS, oH, oW, iC*mC] (NHWC) or [bS, iC*mC, oH, oW] (NCHW) + + int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(weights->sizeAt(0));// filter(kernel) height + int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(weights->sizeAt(1));// filter(kernel) width + int sH = INT_ARG(2); // strides height + int sW = INT_ARG(3); // strides width + int pH = INT_ARG(4); // paddings height + int pW = INT_ARG(5); // paddings width + int dH = INT_ARG(6); // dilations height + int dW = INT_ARG(7); // dilations width + int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME + int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC + + int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width + int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); + mC = weights->sizeAt(indWmC); // channels multiplier + + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); + + std::vector expectedWeightsShape = {kH, kW, iC, mC}; + REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DEPTHWISECONV2D MKL OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); + REQUIRE_TRUE(output->sizeAt(indIOioC) == iC*mC, 0, "CUSTOM DEPTHWISECONV2D MKL OP: the output_channels must be equal to input_channels * channels_multiplier = %i !", iC*mC); + if (bias) + REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DEPTHWISECONV2D MKL OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); + + depthwiseConv2dMKLDNN(input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW); + + return Status::OK(); +} + +////////////////////////////////////////////////////////////////////// +PLATFORM_CHECK(depthwise_conv2d, ENGINE_CPU) { + // we don't want to use mkldnn if cpu doesn't support avx/avx2 + if (::optimalLevel() < 2) + return false; + + auto input = INPUT_VARIABLE(0); + auto weights = INPUT_VARIABLE(1); + auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; + + auto output = INPUT_VARIABLE(0); + + const DataType xType = input->dataType(); + const DataType wType = weights->dataType(); + const DataType zType = output->dataType(); + const DataType bType = bias != nullptr ? bias->dataType() : zType; + + const int mC = weights->sizeAt(3); + + return block.isUseMKLDNN() && mC == 1 && + ( + (xType==DataType::FLOAT32 && wType==DataType::FLOAT32 && bType==DataType::FLOAT32 && zType==DataType::FLOAT32) || + (xType==DataType::HALF && wType==DataType::HALF && bType==DataType::HALF && zType==DataType::HALF) || + ((xType==DataType::UINT8 || xType==DataType::INT8) && wType==DataType::INT8 && (zType==DataType::UINT8 || zType==DataType::INT8 || zType==DataType::INT32 || zType==DataType::FLOAT32) && bType == zType) + ); +} + +////////////////////////////////////////////////////////////////////////// +PLATFORM_IMPL(depthwise_conv2d_bp, ENGINE_CPU) { + + auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW) + auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, mC] always + auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] = [iC*mC] + auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NDHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next + + auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW), epsilon + auto gradW = OUTPUT_VARIABLE(1); // [kH, kW, iC, mC] always + auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] + + REQUIRE_TRUE(input->rankOf() == 4, 0, "CUSTOM DEPTHWISECONV2D_BP MKL OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf()); + REQUIRE_TRUE(weights->rankOf() == 4, 0, "CUSTOM DEPTHWISECONV2D_BP MKL OP: rank of weights array must be equal to 4, but got %i instead !", weights->rankOf()); + REQUIRE_TRUE(gradO->rankOf() == 4, 0, "CUSTOM DEPTHWISECONV2D_BP MKL OP: rank of output gradients (next epsilon) array must be equal to 4, but got %i instead !", gradO->rankOf()); + + int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(weights->sizeAt(0));// filter(kernel) height + int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(weights->sizeAt(1));// filter(kernel) width + int sH = INT_ARG(2); // strides height + int sW = INT_ARG(3); // strides width + int pH = INT_ARG(4); // paddings height + int pW = INT_ARG(5); // paddings width + int dH = INT_ARG(6); // dilations height + int dW = INT_ARG(7); // dilations width + int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME + int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW + + int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width + int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); + mC = weights->sizeAt(indWmC); // channels multiplier + + int trueoH, trueoW; // correct output height, width + ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, paddingMode); + + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); + + std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1}); + std::vector expectedWeightsShape = {kH, kW, iC, mC}; + REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM DEPTHWISECONV2D_BP MKL OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); + REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DEPTHWISECONV2D_BP MKL OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); + if(bias) + REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DEPTHWISECONV2D_BP MKL OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); + + depthwiseConv2dNackPropMKLDNN(input, weights, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW); + + return Status::OK(); +} + +////////////////////////////////////////////////////////////////////// +PLATFORM_CHECK(depthwise_conv2d_bp, ENGINE_CPU) { + + auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW) + auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, mC] always + auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] = [iC*mC] + auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NDHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next + + auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW), epsilon + auto gradW = OUTPUT_VARIABLE(1); // [kH, kW, iC, mC] always + auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] + + const DataType xType = input->dataType(); + const DataType wType = weights->dataType(); + const DataType gradOType = gradO->dataType(); + + const DataType gradIType = gradI->dataType(); + const DataType gradWType = gradW->dataType(); + const DataType gradBType = gradB != nullptr ? gradB->dataType() : DataType::FLOAT32; + + const int mC = weights->sizeAt(3); + + return block.isUseMKLDNN() && mC == 1 && ((xType==DataType::FLOAT32 || xType==DataType::BFLOAT16) && (wType==DataType::FLOAT32 || wType==DataType::BFLOAT16) && (gradOType==DataType::FLOAT32 || gradOType==DataType::BFLOAT16) && (gradIType==DataType::FLOAT32 || gradIType==DataType::BFLOAT16) && (gradWType==DataType::FLOAT32 || gradWType==DataType::BFLOAT16) && (gradBType==DataType::FLOAT32 || gradBType==DataType::BFLOAT16) ); +} + +} +} +} diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/lrn.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/lrn.cpp index ecd8b4c1a..a0f2f6151 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/lrn.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/lrn.cpp @@ -32,7 +32,7 @@ using namespace dnnl; namespace nd4j { namespace ops { namespace platforms { - PLATFORM_IMPL(lrn) { + PLATFORM_IMPL(lrn, ENGINE_CPU) { auto input = INPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0); @@ -82,7 +82,7 @@ namespace nd4j { return Status::OK(); }; - PLATFORM_CHECK(lrn) { + PLATFORM_CHECK(lrn, ENGINE_CPU) { auto input = INPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0); diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/lstmLayer.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/lstmLayer.cpp index 7417653b3..3371b16ad 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/lstmLayer.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/lstmLayer.cpp @@ -365,7 +365,7 @@ static void lstmLayerMKLDNN(const NDArray* x, const NDArray* Wx, const NDArray* } ////////////////////////////////////////////////////////////////////////// -PLATFORM_IMPL(lstmLayer) { +PLATFORM_IMPL(lstmLayer, ENGINE_CPU) { const auto dataFormat = INT_ARG(0); // for unidirectional: 0 = [sL, bS, nIn], 1 = [bS, sL ,nIn], 2 = [bS, nIn, sL], for bidirectional: 3 = [sL, 2, bS, nOut] (for ONNX) const auto directionMode = INT_ARG(1); // direction: 0 = fwd, 1 = bwd, 2 = bidirectional sum, 3 = bidirectional concat, 4 = bidirectional extra output dim (in conjunction with format dataFormat = 3) @@ -493,7 +493,7 @@ PLATFORM_IMPL(lstmLayer) { return Status::OK(); } -PLATFORM_CHECK(lstmLayer) { +PLATFORM_CHECK(lstmLayer, ENGINE_CPU) { const auto hasBiases = B_ARG(0); // indicates whether biases array is provided const auto hasInitH = B_ARG(2); // indicates whether initial output is provided const auto hasInitC = B_ARG(3); // indicates whether initial cell state is provided diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling2d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling2d.cpp index 03008fbc6..975cf7fe1 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling2d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling2d.cpp @@ -32,7 +32,7 @@ using namespace dnnl; namespace nd4j { namespace ops { namespace platforms { - PLATFORM_IMPL(maxpool2d) { + PLATFORM_IMPL(maxpool2d, ENGINE_CPU) { auto input = INPUT_VARIABLE(0); REQUIRE_TRUE(input->rankOf() == 4, 0, "Input should have rank of 4, but got %i instead", @@ -134,7 +134,7 @@ namespace nd4j { return Status::OK(); } - PLATFORM_CHECK(maxpool2d) { + PLATFORM_CHECK(maxpool2d, ENGINE_CPU) { auto input = INPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0); diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling2d_bp.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling2d_bp.cpp index e50bef362..686bdc7fb 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling2d_bp.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling2d_bp.cpp @@ -32,7 +32,7 @@ using namespace dnnl; namespace nd4j { namespace ops { namespace platforms { - PLATFORM_IMPL(maxpool2d_bp) { + PLATFORM_IMPL(maxpool2d_bp, ENGINE_CPU) { auto input = INPUT_VARIABLE( 0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) auto gradO = INPUT_VARIABLE( @@ -163,7 +163,7 @@ namespace nd4j { return Status::OK(); } - PLATFORM_CHECK(maxpool2d_bp) { + PLATFORM_CHECK(maxpool2d_bp, ENGINE_CPU) { auto input = INPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0); diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling3d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling3d.cpp index 6f132bb56..604bdcb6b 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling3d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling3d.cpp @@ -31,7 +31,7 @@ using namespace dnnl; namespace nd4j { namespace ops { namespace platforms { - PLATFORM_IMPL(maxpool3dnew) { + PLATFORM_IMPL(maxpool3dnew, ENGINE_CPU) { auto input = INPUT_VARIABLE( 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) auto output = OUTPUT_VARIABLE( @@ -140,7 +140,7 @@ namespace nd4j { return Status::OK(); } - PLATFORM_CHECK(maxpool3dnew) { + PLATFORM_CHECK(maxpool3dnew, ENGINE_CPU) { auto input = INPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0); diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling_3d_bp.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling_3d_bp.cpp index 4f51d6633..b684df1bb 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling_3d_bp.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling_3d_bp.cpp @@ -31,7 +31,7 @@ using namespace dnnl; namespace nd4j { namespace ops { namespace platforms { - PLATFORM_IMPL(maxpool3dnew_bp) { + PLATFORM_IMPL(maxpool3dnew_bp, ENGINE_CPU) { auto input = INPUT_VARIABLE( 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) auto gradO = INPUT_VARIABLE( @@ -170,7 +170,7 @@ namespace nd4j { return Status::OK(); } - PLATFORM_CHECK(maxpool3dnew_bp) { + PLATFORM_CHECK(maxpool3dnew_bp, ENGINE_CPU) { auto input = INPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0); diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.h b/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.h index 6274a645f..b55103a02 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.h +++ b/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.h @@ -29,6 +29,8 @@ #include #include +using namespace samediff; + namespace nd4j{ namespace ops { @@ -36,47 +38,51 @@ namespace nd4j{ /** * Here we actually declare our platform helpers */ - DECLARE_PLATFORM(conv2d); + DECLARE_PLATFORM(conv2d, ENGINE_CPU); - DECLARE_PLATFORM(conv2d_bp); + DECLARE_PLATFORM(conv2d_bp, ENGINE_CPU); - DECLARE_PLATFORM(avgpool2d); + DECLARE_PLATFORM(avgpool2d, ENGINE_CPU); - DECLARE_PLATFORM(avgpool2d_bp); + DECLARE_PLATFORM(avgpool2d_bp, ENGINE_CPU); - DECLARE_PLATFORM(maxpool2d); + DECLARE_PLATFORM(maxpool2d, ENGINE_CPU); - DECLARE_PLATFORM(maxpool2d_bp); + DECLARE_PLATFORM(maxpool2d_bp, ENGINE_CPU); - DECLARE_PLATFORM(conv3dnew); + DECLARE_PLATFORM(conv3dnew, ENGINE_CPU); - DECLARE_PLATFORM(conv3dnew_bp); + DECLARE_PLATFORM(conv3dnew_bp, ENGINE_CPU); - DECLARE_PLATFORM(maxpool3dnew); + DECLARE_PLATFORM(maxpool3dnew, ENGINE_CPU); - DECLARE_PLATFORM(maxpool3dnew_bp); + DECLARE_PLATFORM(maxpool3dnew_bp, ENGINE_CPU); - DECLARE_PLATFORM(avgpool3dnew); + DECLARE_PLATFORM(avgpool3dnew, ENGINE_CPU); - DECLARE_PLATFORM(avgpool3dnew_bp); + DECLARE_PLATFORM(avgpool3dnew_bp, ENGINE_CPU); - DECLARE_PLATFORM(lrn); + DECLARE_PLATFORM(lrn, ENGINE_CPU); - DECLARE_PLATFORM(batchnorm); + DECLARE_PLATFORM(batchnorm, ENGINE_CPU); - DECLARE_PLATFORM(batchnorm_bp); + DECLARE_PLATFORM(batchnorm_bp, ENGINE_CPU); - DECLARE_PLATFORM(lstmLayer); + DECLARE_PLATFORM(lstmLayer, ENGINE_CPU); - DECLARE_PLATFORM(deconv2d); + DECLARE_PLATFORM(deconv2d, ENGINE_CPU); - DECLARE_PLATFORM(deconv2d_tf); + DECLARE_PLATFORM(deconv2d_tf, ENGINE_CPU); - DECLARE_PLATFORM(deconv3d); + DECLARE_PLATFORM(deconv3d, ENGINE_CPU); - DECLARE_PLATFORM(deconv2d_bp); + DECLARE_PLATFORM(deconv2d_bp, ENGINE_CPU); - DECLARE_PLATFORM(deconv3d_bp); + DECLARE_PLATFORM(deconv3d_bp, ENGINE_CPU); + + DECLARE_PLATFORM(depthwise_conv2d, ENGINE_CPU); + + DECLARE_PLATFORM(depthwise_conv2d_bp, ENGINE_CPU); } } diff --git a/libnd4j/include/ops/impl/BroadcastOpsTuple.cpp b/libnd4j/include/ops/impl/BroadcastOpsTuple.cpp index 0e9c99636..26cda74a4 100644 --- a/libnd4j/include/ops/impl/BroadcastOpsTuple.cpp +++ b/libnd4j/include/ops/impl/BroadcastOpsTuple.cpp @@ -55,4 +55,12 @@ namespace nd4j { return custom(nd4j::scalar::IGammac, nd4j::pairwise::IGammac, nd4j::broadcast::IGammac); } + + BroadcastOpsTuple BroadcastOpsTuple::Pow() { + return custom(nd4j::scalar::Pow, nd4j::pairwise::Pow, nd4j::broadcast::Pow); + } + BroadcastOpsTuple BroadcastOpsTuple::PowDerivative() { + return custom(nd4j::scalar::PowDerivative, nd4j::pairwise::PowDerivative, nd4j::broadcast::PowDerivative); + } + } diff --git a/libnd4j/include/ops/impl/gemm.cpp b/libnd4j/include/ops/impl/gemm.cpp index 74b832b4a..a81c12818 100644 --- a/libnd4j/include/ops/impl/gemm.cpp +++ b/libnd4j/include/ops/impl/gemm.cpp @@ -100,7 +100,7 @@ namespace nd4j { } if (beta != 0.0) { - C[zIdx] = static_cast(dot + beta * C[zIdx]); + C[zIdx] = static_cast(dot + static_cast(beta) * C[zIdx]); } else { C[zIdx] = static_cast(dot); } @@ -134,8 +134,8 @@ namespace nd4j { int aIdx = linearIndexC(M, N, r, 0); auto aX = aT + aIdx; - auto dot = nd4j::math::nd4j_dot(aX, y, lda) * alpha; - z[r] = beta == 0.0f ? dot : dot + beta * z[r]; + auto dot = nd4j::math::nd4j_dot(aX, y, lda) * static_cast(alpha); + z[r] = beta == 0.0f ? dot : dot + static_cast(beta) * z[r]; } }; samediff::Threads::parallel_for(func, 0, M); diff --git a/libnd4j/include/ops/impl/specials.cpp b/libnd4j/include/ops/impl/specials.cpp index 11cca1b15..ad7f4060d 100644 --- a/libnd4j/include/ops/impl/specials.cpp +++ b/libnd4j/include/ops/impl/specials.cpp @@ -175,13 +175,13 @@ void SpecialMethods::concatCpuGeneric(int dimension, int numArrays, Nd4jPoint PRAGMA_OMP_SIMD for (uint64_t i = 0; i < length; i++) { - z[i] /= n; + z[i] /= static_cast(n); } auto func = PRAGMA_THREADS_FOR { for (auto i = start; i < stop; i += increment) { for (Nd4jLong ar = 1; ar < n; ar++) { - z[i] += x[ar][i] / n; + z[i] += x[ar][i] / static_cast(n); } } }; @@ -201,7 +201,7 @@ void SpecialMethods::concatCpuGeneric(int dimension, int numArrays, Nd4jPoint auto func = PRAGMA_THREADS_FOR { for (auto i = start; i < stop; i += increment) { for (Nd4jLong ar = 0; ar < n; ar++) { - z[i] += x[ar][i] / n; + z[i] += x[ar][i] / static_cast(n); } } }; @@ -365,11 +365,11 @@ PRAGMA_OMP_SINGLE_ARGS(nowait) if (hasBit) { if (hasSign) - dz[(e - 4) * 16 + bitId] -= threshold; + dz[(e - 4) * 16 + bitId] -= static_cast(threshold); else - dz[(e - 4) * 16 + bitId] += threshold; + dz[(e - 4) * 16 + bitId] += static_cast(threshold); } else if (hasSign) { - dz[(e - 4) * 16 + bitId] -= threshold / 2; + dz[(e - 4) * 16 + bitId] -= static_cast(threshold / 2); } } } @@ -423,13 +423,13 @@ PRAGMA_OMP_SINGLE_ARGS(nowait) if (val < (T) 0.0f) { byte |= 1 << (bitId + 16); - dx[e] += threshold; + dx[e] += static_cast(threshold); } else { - dx[e] -= threshold; + dx[e] -= static_cast(threshold); } } else if (abs >= (T) threshold / (T) 2.0f && val < (T) 0.0f) { byte |= 1 << (bitId + 16); - dx[e] += threshold / 2; + dx[e] += static_cast(threshold / 2); retVal++; } diff --git a/libnd4j/include/performance/benchmarking/impl/FullBenchmarkSuit.cpp b/libnd4j/include/performance/benchmarking/impl/FullBenchmarkSuit.cpp index 22bb87103..b4960bc90 100644 --- a/libnd4j/include/performance/benchmarking/impl/FullBenchmarkSuit.cpp +++ b/libnd4j/include/performance/benchmarking/impl/FullBenchmarkSuit.cpp @@ -1298,7 +1298,7 @@ namespace nd4j { strided = arr; } else { IndicesList indices({NDIndex::interval(0,131072), NDIndex::interval(0,1)}); - strided = arr->subarray(indices); //All rows, first column + strided = new NDArray(arr->subarray(indices)); //All rows, first column delete arr; } @@ -1322,7 +1322,7 @@ namespace nd4j { strided = arr; } else { IndicesList indices({NDIndex::interval(0,2*1024,2), NDIndex::all(), NDIndex::interval(0,1)}); - strided = arr->subarray(indices); + strided = new NDArray(arr->subarray(indices)); delete arr; } @@ -1358,7 +1358,7 @@ namespace nd4j { strided = arr; } else { IndicesList indices({NDIndex::all(), NDIndex::interval(0,1)}); - strided = arr->subarray(indices); //All rows, first column + strided = new NDArray(arr->subarray(indices)); //All rows, first column delete arr; } @@ -1393,7 +1393,7 @@ namespace nd4j { strided = arr; } else { IndicesList indices({NDIndex::all(), NDIndex::point(0)}); - strided = arr->subarray(indices); //All rows, first column + strided = new NDArray(arr->subarray(indices)); //All rows, first column delete arr; } @@ -1418,7 +1418,7 @@ namespace nd4j { strided = arr; } else { IndicesList indices({NDIndex::all(), NDIndex::point(0)}); - strided = arr->subarray(indices); //All rows, first column + strided = new NDArray(arr->subarray(indices)); //All rows, first column delete arr; } @@ -1565,7 +1565,7 @@ namespace nd4j { int r = p.getIntParam("rowcol"); auto arr = NDArrayFactory::create_('c', {r, r+1}); IndicesList indices({NDIndex::all(), NDIndex::interval(0,r-1)}); - auto view = arr->subarray(indices); + auto view = new NDArray(arr->subarray(indices)); //nd4j_printf("VIEW ARRAY: rows=%lld, columns=%lld", view->sizeAt(0), view->sizeAt(1)); x.push_back(view); if(p.getIntParam("inplace") == 1){ diff --git a/libnd4j/include/platform_boilerplate.h b/libnd4j/include/platform_boilerplate.h index d3883bcf7..5c73a1b38 100644 --- a/libnd4j/include/platform_boilerplate.h +++ b/libnd4j/include/platform_boilerplate.h @@ -21,25 +21,37 @@ #ifndef SD_PLATFORM_BOILERPLATE_H #define SD_PLATFORM_BOILERPLATE_H - -#define DECLARE_PLATFORM(NAME) class ND4J_EXPORT PLATFORM_##NAME : public PlatformHelper {\ - public: \ - PLATFORM_##NAME() : PlatformHelper(#NAME) { } \ - bool isUsable(graph::Context &context) override; \ - Nd4jStatus invokeHelper(graph::Context &context) override; \ - }; - -#define PLATFORM_IMPL(NAME) struct ND4J_EXPORT __registratorPlatformHelper_##NAME { \ - __registratorPlatformHelper_##NAME() { \ - auto helper = new PLATFORM_##NAME(); \ - OpRegistrator::getInstance()->registerHelper(helper); \ - } \ - }; \ - static __registratorPlatformHelper_##NAME platformHelper_##NAME; \ - Nd4jStatus PLATFORM_##NAME::invokeHelper(nd4j::graph::Context &block) +#include -#define PLATFORM_CHECK(NAME) bool PLATFORM_##NAME::isUsable(graph::Context &block) + +#define CONCATP(A,B) A ##_##B + + +#define DECLARE_PLATFORM_F(NAME, ENGINE, CNAME) class ND4J_EXPORT PLATFORM_##CNAME : public PlatformHelper {\ + public: \ + PLATFORM_##CNAME() : PlatformHelper(#NAME, samediff::Engine::ENGINE) { } \ + bool isUsable(graph::Context &context) override; \ + Nd4jStatus invokeHelper(graph::Context &context) override; \ + }; + +#define DECLARE_PLATFORM(NAME, ENGINE) DECLARE_PLATFORM_F(NAME, ENGINE, NAME ##_## ENGINE) + +#define PLATFORM_IMPL_F(NAME, ENGINE, CNAME) struct ND4J_EXPORT __registratorPlatformHelper_##CNAME { \ + __registratorPlatformHelper_##CNAME() { \ + auto helper = new PLATFORM_##CNAME(); \ + OpRegistrator::getInstance()->registerHelper(helper); \ + } \ + }; \ + static __registratorPlatformHelper_##CNAME platformHelper_##CNAME; \ + Nd4jStatus PLATFORM_##CNAME::invokeHelper(nd4j::graph::Context &block) + + +#define PLATFORM_IMPL(NAME, ENGINE) PLATFORM_IMPL_F(NAME, ENGINE, NAME ##_## ENGINE) + + +#define PLATFORM_CHECK_F(NAME, ENGINE, CNAME) bool PLATFORM_##CNAME::isUsable(graph::Context &block) +#define PLATFORM_CHECK(NAME, ENGINE) PLATFORM_CHECK_F(NAME, ENGINE, NAME ##_## ENGINE) #endif //SD_PLATFORM_BOILERPLATE_H diff --git a/libnd4j/include/play.h b/libnd4j/include/play.h index ecafe84ea..d0fecee82 100644 --- a/libnd4j/include/play.h +++ b/libnd4j/include/play.h @@ -21,8 +21,9 @@ #ifndef LIBND4J_PLAY_H #define LIBND4J_PLAY_H -#include - +//#include +#include +/* #define DATA_TYPES \ (DATA_FLOAT, float) ,\ (DATA_DOUBLE, double) @@ -41,6 +42,9 @@ BUILD_SINGLE_TEMPLATE_TWICE(template class functionName, , DATA_TYPES) + */ + +DECLARE_PLATFORM(conv2d, ENGINE_CPU) //BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functionName, (signature), DATA_TYPES, Y_TYPES); diff --git a/libnd4j/include/type_boilerplate.h b/libnd4j/include/type_boilerplate.h index bd235726a..af0fe369d 100644 --- a/libnd4j/include/type_boilerplate.h +++ b/libnd4j/include/type_boilerplate.h @@ -634,7 +634,7 @@ #define BROADCAST(NAME) nd4j::BroadcastOpsTuple::custom(nd4j::scalar::NAME, nd4j::pairwise::NAME, nd4j::broadcast::NAME) #define BROADCAST_BOOL(NAME) nd4j::BroadcastBoolOpsTuple::custom(nd4j::scalar::NAME, nd4j::pairwise::NAME, nd4j::broadcast::NAME) - +#define ALL_STRINGS nd4j::DataType::UTF8, nd4j::DataType::UTF16, nd4j::DataType::UTF32 #define ALL_INDICES nd4j::DataType::INT32, nd4j::DataType::INT64 #define ALL_INTS nd4j::DataType::INT8, nd4j::DataType::UINT8, nd4j::DataType::INT16, nd4j::DataType::UINT16, nd4j::DataType::INT32, nd4j::DataType::UINT32, nd4j::DataType::INT64, nd4j::DataType::UINT64 #define ALL_FLOATS nd4j::DataType::HALF, nd4j::DataType::FLOAT32, nd4j::DataType::DOUBLE, nd4j::DataType::BFLOAT16 diff --git a/libnd4j/include/types/bfloat16.h b/libnd4j/include/types/bfloat16.h index 9b8081495..847c2ebda 100644 --- a/libnd4j/include/types/bfloat16.h +++ b/libnd4j/include/types/bfloat16.h @@ -47,489 +47,221 @@ //{ struct bfloat16 { - public: - int16_t _data; - /* constexpr */ local_def bfloat16() { _data = 0; } + private: + template + struct isNumericType { static bool const value = std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value; }; + // struct isNumericType { static bool const value = std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::value;; }; - template - local_def /*explicit*/ bfloat16(const T& rhs) { - assign(rhs); - } + public: + int16_t _data; -// local_def bfloat16(float rhs) { -// assign(rhs); -// } -// -// local_def bfloat16(double rhs) { -// assign(rhs); -// } + local_def bfloat16() { + _data = 0; + } - local_def operator float() const { - int32_t temp = this->_data << 16; //((sign << 31) | (exponent << 23) | mantissa); + template ::value>::type> + local_def bfloat16(const T& rhs) { + *this = rhs; + } - return *reinterpret_cast(&temp); - } + local_def operator float() const { + int32_t temp = this->_data << 16; //((sign << 31) | (exponent << 23) | mantissa); + return *reinterpret_cast(&temp); + } - local_def explicit operator double() const { return static_cast(static_cast(*this)); } - local_def explicit operator unsigned long long() const { return static_cast(static_cast(*this)); } - local_def explicit operator int16_t() const { return static_cast(static_cast(*this)); } - local_def explicit operator uint16_t() const { return static_cast(static_cast(*this)); } - local_def explicit operator uint32_t() const { return static_cast(static_cast(*this)); } - local_def explicit operator uint8_t() const { return static_cast(static_cast(*this)); } - local_def explicit operator int8_t() const { return static_cast(static_cast(*this)); } - local_def explicit operator int() const { return static_cast(static_cast(*this)); } - local_def explicit operator Nd4jLong() const { return static_cast(static_cast(*this)); } - local_def explicit operator bool() const { return this->_data == 0 ? false : true; } - local_def explicit operator float16() const { return static_cast(static_cast(*this)); } + local_def explicit operator bool() const { + return this->_data == 0 ? false : true; + } - template - local_def bfloat16& operator=(const T& rhs) { assign(rhs); return *this; } + template ::value>::type> + local_def explicit operator T() const { + return static_cast(static_cast(*this)); + } - local_def void assign(unsigned int rhs) { - // may be a better way ? - assign((float)rhs); - } + local_def bfloat16& operator=(const bool rhs) { + *this = (float)rhs ? 1.f: 0.f; + return *this; + } - local_def void assign(int rhs) { - // may be a better way ? - assign((float)rhs); - } + local_def bfloat16& operator=(const float& rhs) { + #ifdef __CUDACC__ + if(::isnan(rhs)) { + _data = bfloat16::nan(); + return *this; + } + #endif + auto x = *reinterpret_cast(& const_cast(rhs)); + uint32_t lsb = (x >> 16) & 1; + uint32_t rounding_bias = 0x7fff + lsb; + x += rounding_bias; + this->_data = static_cast(x >> 16); - local_def void assign(double rhs) { - assign((float)rhs); - } + return *this; + } - local_def void assign(long long rhs) { - assign((float)rhs); - } + local_def bfloat16& operator=(const bfloat16& rhs) { + _data = rhs._data; + return *this; + } - local_def void assign(long int rhs) { - assign((float)rhs); - } + template ::value>::type> + local_def bfloat16& operator=(const T& rhs) { + *this = (float)rhs; + return *this; + } - local_def void assign(long unsigned int rhs) { - assign((float)rhs); - } + local_def friend bool operator==(const bfloat16& a, const bfloat16& b) { return (a._data == b._data); } + local_def friend bool operator!=(const bfloat16& a, const bfloat16& b) { return !(a == b); } + local_def friend bool operator<(const bfloat16& a, const bfloat16& b) { return (float)a < (float)b; } + local_def friend bool operator>(const bfloat16& a, const bfloat16& b) { return (float)a > (float)b; } + local_def friend bool operator<=(const bfloat16& a, const bfloat16& b) { return (float)a <= (float)b; } + local_def friend bool operator>=(const bfloat16& a, const bfloat16& b) { return (float)a >= (float)b; } - local_def void assign(unsigned short rhs) { - assign((float)rhs); - } + local_def friend bfloat16 operator+(const bfloat16& a, const bfloat16& b) { return bfloat16((float)a + (float)b); } + local_def friend bfloat16 operator-(const bfloat16& a, const bfloat16& b) { return bfloat16((float)a - (float)b); } + local_def friend bfloat16 operator*(const bfloat16& a, const bfloat16& b) { return bfloat16((float)a * (float)b); } + local_def friend bfloat16 operator/(const bfloat16& a, const bfloat16& b) { return bfloat16((float)a / (float)b); } - local_def void assign(float16 rhs) { - assign((float)rhs); - } + template ::value>::type> + local_def friend bfloat16 operator+(const bfloat16& a, const T& b) { return a + static_cast(b); } + template ::value>::type> + local_def friend bfloat16 operator+(const T& a, const bfloat16& b) { return static_cast(a) + b; } - local_def void assign(long long unsigned int rhs) { - assign((float)rhs); - } + template ::value>::type> + local_def friend bfloat16 operator-(const bfloat16& a, const T& b) { return a - static_cast(b); } + template ::value>::type> + local_def friend bfloat16 operator-(const T& a, const bfloat16& b) { return static_cast(a) - b; } - local_def void assign(float rhs) { -#ifdef __CUDACC__ - if(::isnan(rhs)) { - _data = bfloat16::nan(); - return; - } -#endif - auto x = *reinterpret_cast(&rhs); - uint32_t lsb = (x >> 16) & 1; - uint32_t rounding_bias = 0x7fff + lsb; - x += rounding_bias; - this->_data = static_cast(x >> 16); - } + template ::value>::type> + local_def friend bfloat16 operator*(const bfloat16& a, const T& b) { return a * static_cast(b); } + template ::value>::type> + local_def friend bfloat16 operator*(const T& a, const bfloat16& b) { return static_cast(a) * b; } - local_def void assign(const bfloat16& rhs) { - _data = rhs._data; - } + template ::value>::type> + local_def friend bfloat16 operator/(const bfloat16& a, const T& b) { return a / static_cast(b); } + template ::value>::type> + local_def friend bfloat16 operator/(const T& a, const bfloat16& b) { return static_cast(a) / b; } - local_def bfloat16& operator+=(bfloat16 rhs) { assign((float)(*this) + (float)rhs); return *this; } + template ::value>::type> + local_def friend bool operator==(const bfloat16& a, const T& b) { return a == static_cast(b); } + template ::value>::type> + local_def friend bool operator==(const T& a, const bfloat16& b) { return static_cast(a) == b; } - local_def bfloat16& operator-=(bfloat16 rhs) { assign((float)*this - (float)rhs); return *this; } + template ::value>::type> + local_def friend bool operator!=(const bfloat16& a, const T& b) { return a != static_cast(b); } + template ::value>::type> + local_def friend bool operator!=(const T& a, const bfloat16& b) { return static_cast(a) != b; } - local_def bfloat16& operator*=(bfloat16 rhs) { assign((float)*this * (float)rhs); return *this; } + template ::value>::type> + local_def friend bool operator<(const bfloat16& a, const T& b) { return a < static_cast(b); } + template ::value>::type> + local_def friend bool operator<(const T& a, const bfloat16& b) { return static_cast(a) < b; } - local_def bfloat16& operator/=(bfloat16 rhs) { assign((float)*this / (float)rhs); return *this; } + template ::value>::type> + local_def friend bool operator>(const bfloat16& a, const T& b) { return a > static_cast(b); } + template ::value>::type> + local_def friend bool operator>(const T& a, const bfloat16& b) { return static_cast(a) > b; } - local_def bfloat16& operator+=(float rhs) { assign((float)*this + rhs); return *this; } + template ::value>::type> + local_def friend bool operator<=(const bfloat16& a, const T& b) { return a <= static_cast(b); } + template ::value>::type> + local_def friend bool operator<=(const T& a, const bfloat16& b) { return static_cast(a) <= b; } - local_def bfloat16& operator-=(float rhs) { assign((float)*this - rhs); return *this; } + template ::value>::type> + local_def friend bool operator>=(const bfloat16& a, const T& b) { return a >= static_cast(b); } + template ::value>::type> + local_def friend bool operator>=(const T& a, const bfloat16& b) { return static_cast(a) >= b; } - local_def bfloat16& operator*=(float rhs) { assign((float)*this * rhs); return *this; } + local_def bfloat16& operator+=(bfloat16 rhs) { *this = (float)(*this) + (float)rhs; return *this; } - local_def bfloat16& operator/=(float rhs) { assign((float)*this / rhs); return *this; } + local_def bfloat16& operator-=(bfloat16 rhs) { *this = (float)(*this) - (float)rhs; return *this; } - local_def bfloat16& operator++() { *this += 1.f; return *this; } + local_def bfloat16& operator*=(bfloat16 rhs) { *this = (float)(*this) * (float)rhs; return *this; } - local_def bfloat16& operator--() { *this -= 1.f; return *this; } + local_def bfloat16& operator/=(bfloat16 rhs) { *this = (float)(*this) / (float)rhs; return *this; } - local_def bfloat16 operator++(int i) { *this += i; return *this; } + template ::value>::type> + local_def bfloat16& operator+=(const T& rhs) { *this = *this + rhs; return *this; } - local_def bfloat16 operator--(int i) { *this -= i; return *this; } + template ::value>::type> + local_def bfloat16& operator-=(const T& rhs) { *this = *this - rhs; return *this; } - local_def std::ostream& operator<<(std::ostream& os) { - os << static_cast(*this); - return os; - } - local_def static bfloat16 min() { - bfloat16 res; - res._data = 0xFF7F; - return res; - } - local_def static bfloat16 max() { - bfloat16 res; - res._data = 0x7F7F; - return res; + template ::value>::type> + local_def bfloat16& operator*=(const T& rhs) { *this = *this * rhs; return *this; } - } - local_def static bfloat16 eps() { - bfloat16 res; - res._data = 0x3C00; - return res; - } + template ::value>::type> + local_def bfloat16& operator/=(const T& rhs) { *this = *this / rhs; return *this; } - local_def static bfloat16 inf() { - bfloat16 res; - res._data = 0x3C00; - return res; - } + local_def bfloat16& operator++() { *this = (float)*this + (float)1.f; return *this; } - local_def static bfloat16 nan() { - bfloat16 res; - res._data = 0x7FC0; - return res; - } - }; + local_def bfloat16& operator--() { *this = (float)*this - (float)1.f; return *this; } - local_def bool operator==(const bfloat16& a, const bfloat16& b) { return (a._data == b._data); } + local_def bfloat16 operator++(int) { *this = (float)*this + (float)1.f; return *this; } -// template -// local_def bool operator==(const bfloat16& a, const T& b) { return (a == (bfloat16) b); } + local_def bfloat16 operator--(int) { *this = (float)*this - (float)1.f; return *this; } - local_def bool operator!=(const bfloat16& a, const bfloat16& b) { return !(a == b); } -// - local_def bool operator<(const bfloat16& a, const bfloat16& b) { return (float)a < (float)b; } - - local_def bool operator>(const bfloat16& a, const bfloat16& b) { return (float)a > (float)b; } - - template - local_def bool operator>(const bfloat16& a, const T& b) { return (float)a > (float)b; } - - local_def bool operator<=(const bfloat16& a, const bfloat16& b) { return (float)a <= (float)b; } - template - local_def bool operator<=(const bfloat16& a, const T& b) { return (float)a <= (float)b; } - - local_def bool operator>=(const bfloat16& a, const bfloat16& b) { return (float)a >= (float)b; } - - local_def bfloat16 operator+(const bfloat16& a, const bfloat16& b) { return bfloat16((float)a + (float)b); } - local_def bfloat16 operator-(const bfloat16& a, const bfloat16& b) { return bfloat16((float)a - (float)b); } - local_def bfloat16 operator*(const bfloat16& a, const bfloat16& b) { return bfloat16((float)a * (float)b); } - local_def bfloat16 operator/(const bfloat16& a, const bfloat16& b) { return bfloat16((float)a / (float)b); } -// - - local_def bfloat16 operator+(const bfloat16& a, const double& b) { return a + static_cast(b); } - local_def bfloat16 operator+(const bfloat16& a, const float& b) { return a + static_cast(b); } - local_def bfloat16 operator+(const bfloat16& a, const float16& b) { return static_cast(static_cast(a) + static_cast(b)); } - local_def bfloat16 operator+(const float16& a, const bfloat16& b) { return static_cast(static_cast(a) + static_cast(b)); } - local_def bfloat16 operator+(const bfloat16& a, const int& b) { return a + static_cast(b); } - local_def bfloat16 operator+(const bfloat16& a, const unsigned int& b) { return a + static_cast(b); } - local_def bfloat16 operator+(const bfloat16& a, const long long& b) { return a + static_cast(b); } - local_def bfloat16 operator+(const bfloat16& a, const unsigned long long& b) { return a + static_cast(b); } - local_def bfloat16 operator+(const bfloat16& a, const long int& b) { return a + static_cast(b); } - local_def bfloat16 operator+(const bfloat16& a, const bool& b) { return a + static_cast(b); } - local_def bfloat16 operator+(const bfloat16& a, const int8_t& b) { return a + static_cast(b); } - local_def bfloat16 operator+(const bfloat16& a, const uint8_t& b) { return a + static_cast(b); } - local_def bfloat16 operator+(const bfloat16& a, const int16_t& b) { return a + static_cast(b); } - local_def bfloat16 operator+(const bfloat16& a, const uint16_t& b) { return a + static_cast(b); } - local_def bfloat16 operator+(const bfloat16& a, const long unsigned int& b) { return a + static_cast(b); } - local_def bfloat16 operator+(const int8_t& a, const bfloat16& b) { return static_cast(a) + b; } - local_def bfloat16 operator+(const uint8_t& a, const bfloat16& b) { return static_cast(a) + b; } - local_def bfloat16 operator+(const int16_t& a, const bfloat16& b) { return static_cast(a) + b; } - local_def bfloat16 operator+(const uint16_t& a, const bfloat16& b) { return static_cast(a) + b; } - local_def bfloat16 operator+(const bool& a, const bfloat16& b) { return static_cast(a) + b; } - local_def bfloat16 operator+(const int& a, const bfloat16& b) { return static_cast(a) + b; } - local_def bfloat16 operator+(const unsigned int& a, const bfloat16& b) { return static_cast(a) + b; } - local_def bfloat16 operator+(const long long& a, const bfloat16& b) { return static_cast(a) + b; } - local_def bfloat16 operator+(const unsigned long long& a, const bfloat16& b) { return static_cast(a) + b; } - local_def bfloat16 operator+(const long int& a, const bfloat16& b) { return static_cast(a) + b; } - local_def bfloat16 operator+(const float& a, const bfloat16& b) { return static_cast(a) + b; } - local_def bfloat16 operator+(const double& a, const bfloat16& b) { return static_cast(a) + b; } - local_def bfloat16 operator+(const long unsigned int& a, const bfloat16& b) { return static_cast(a) + b; } - - local_def bfloat16 operator-(const bfloat16& a, const double& b) { return a - static_cast(b); } - local_def bfloat16 operator-(const bfloat16& a, const float& b) { return a - static_cast(b); } - local_def bfloat16 operator-(const bfloat16& a, const float16& b) { return static_cast(static_cast(a) - static_cast(b)); } - local_def bfloat16 operator-(const float16& a, const bfloat16& b) { return static_cast(static_cast(a) - static_cast(b)); } - local_def bfloat16 operator-(const bfloat16& a, const int& b) { return a - static_cast(b); } - local_def bfloat16 operator-(const bfloat16& a, const unsigned int& b) { return a - static_cast(b); } - local_def bfloat16 operator-(const bfloat16& a, const long long& b) { return a - static_cast(b); } - local_def bfloat16 operator-(const bfloat16& a, const unsigned long long& b) { return a - static_cast(b); } - local_def bfloat16 operator-(const bfloat16& a, const long int& b) { return a - static_cast(b); } - local_def bfloat16 operator-(const bfloat16& a, const bool& b) { return a - static_cast(b); } - local_def bfloat16 operator-(const bfloat16& a, const int8_t& b) { return a - static_cast(b); } - local_def bfloat16 operator-(const bfloat16& a, const uint8_t& b) { return a - static_cast(b); } - local_def bfloat16 operator-(const bfloat16& a, const int16_t& b) { return a - static_cast(b); } - local_def bfloat16 operator-(const bfloat16& a, const uint16_t& b) { return a - static_cast(b); } - local_def bfloat16 operator-(const bfloat16& a, const long unsigned int& b) { return a - static_cast(b); } - local_def bfloat16 operator-(const int8_t& a, const bfloat16& b) { return static_cast(a) - b; } - local_def bfloat16 operator-(const uint8_t& a, const bfloat16& b) { return static_cast(a) - b; } - local_def bfloat16 operator-(const int16_t& a, const bfloat16& b) { return static_cast(a) - b; } - local_def bfloat16 operator-(const uint16_t& a, const bfloat16& b) { return static_cast(a) - b; } - local_def bfloat16 operator-(const bool& a, const bfloat16& b) { return static_cast(a) - b; } - local_def bfloat16 operator-(const int& a, const bfloat16& b) { return static_cast(a) - b; } - local_def bfloat16 operator-(const unsigned int& a, const bfloat16& b) { return static_cast(a) - b; } - local_def bfloat16 operator-(const long long& a, const bfloat16& b) { return static_cast(a) - b; } - local_def bfloat16 operator-(const unsigned long long& a, const bfloat16& b) { return static_cast(a) - b; } - local_def bfloat16 operator-(const long int& a, const bfloat16& b) { return static_cast(a) - b; } - local_def bfloat16 operator-(const float& a, const bfloat16& b) { return static_cast(a) - b; } - local_def bfloat16 operator-(const double& a, const bfloat16& b) { return static_cast(a) - b; } - local_def bfloat16 operator-(const long unsigned int& a, const bfloat16& b) { return static_cast(a) - b; } - - local_def bfloat16 operator/(const bfloat16& a, const double& b) { return a / static_cast(b); } - local_def bfloat16 operator/(const bfloat16& a, const float& b) { return a / static_cast(b); } - local_def bfloat16 operator/(const bfloat16& a, const float16& b) { return static_cast((float)a / (float)b); } - local_def bfloat16 operator/(const float16& a, const bfloat16& b) { return static_cast((float)a / (float)b); } - local_def bfloat16 operator/(const bfloat16& a, const int& b) { return a / static_cast(b); } - local_def bfloat16 operator/(const bfloat16& a, const unsigned int& b) { return a / static_cast(b); } - local_def bfloat16 operator/(const bfloat16& a, const long long& b) { return a / static_cast(b); } - local_def bfloat16 operator/(const bfloat16& a, const unsigned long long& b) { return a / static_cast(b); } - local_def bfloat16 operator/(const bfloat16& a, const long int& b) { return a / static_cast(b); } - local_def bfloat16 operator/(const bfloat16& a, const bool& b) { return a / static_cast(b); } - local_def bfloat16 operator/(const bfloat16& a, const int8_t& b) { return a / static_cast(b); } - local_def bfloat16 operator/(const bfloat16& a, const uint8_t& b) { return a / static_cast(b); } - local_def bfloat16 operator/(const bfloat16& a, const int16_t& b) { return a / static_cast(b); } - local_def bfloat16 operator/(const bfloat16& a, const uint16_t& b) { return a / static_cast(b); } - local_def bfloat16 operator/(const bfloat16& a, const long unsigned int& b) { return a / static_cast(b); } - local_def bfloat16 operator/(const int8_t& a, const bfloat16& b) { return static_cast(a) / b; } - local_def bfloat16 operator/(const uint8_t& a, const bfloat16& b) { return static_cast(a) / b; } - local_def bfloat16 operator/(const int16_t& a, const bfloat16& b) { return static_cast(a) / b; } - local_def bfloat16 operator/(const uint16_t& a, const bfloat16& b) { return static_cast(a) / b; } - local_def bfloat16 operator/(const bool& a, const bfloat16& b) { return static_cast(a) / b; } - local_def bfloat16 operator/(const int& a, const bfloat16& b) { return static_cast(a) / b; } - local_def bfloat16 operator/(const unsigned int& a, const bfloat16& b) { return static_cast(a) / b; } - local_def bfloat16 operator/(const long long& a, const bfloat16& b) { return static_cast(a) / b; } - local_def bfloat16 operator/(const unsigned long long& a, const bfloat16& b) { return static_cast(a) / b; } - local_def bfloat16 operator/(const long int& a, const bfloat16& b) { return static_cast(a) / b; } - local_def bfloat16 operator/(const float& a, const bfloat16& b) { return static_cast(a) / b; } - local_def bfloat16 operator/(const double& a, const bfloat16& b) { return static_cast(a) / b; } - local_def bfloat16 operator/(const long unsigned int& a, const bfloat16& b) { return static_cast(a) / b; } - - local_def bfloat16 operator*(const bfloat16& a, const double& b) { return a * static_cast(b); } - local_def bfloat16 operator*(const bfloat16& a, const float& b) { return a * static_cast(b); } - local_def bfloat16 operator*(const bfloat16& a, const float16& b) { return static_cast((float)a * (float)b); } - local_def bfloat16 operator*(const float16& a, const bfloat16& b) { return static_cast((float)a * (float)b); } - local_def bfloat16 operator*(const bfloat16& a, const int& b) { return a * static_cast(b); } - local_def bfloat16 operator*(const bfloat16& a, const unsigned int& b) { return a * static_cast(b); } - local_def bfloat16 operator*(const bfloat16& a, const long long& b) { return a * static_cast(b); } - local_def bfloat16 operator*(const bfloat16& a, const unsigned long long& b) { return a * static_cast(b); } - local_def bfloat16 operator*(const bfloat16& a, const long int& b) { return a * static_cast(b); } - local_def bfloat16 operator*(const bfloat16& a, const bool& b) { return a * static_cast(b); } - local_def bfloat16 operator*(const bfloat16& a, const int8_t& b) { return a * static_cast(b); } - local_def bfloat16 operator*(const bfloat16& a, const uint8_t& b) { return a * static_cast(b); } - local_def bfloat16 operator*(const bfloat16& a, const int16_t& b) { return a * static_cast(b); } - local_def bfloat16 operator*(const bfloat16& a, const uint16_t& b) { return a * static_cast(b); } - local_def bfloat16 operator*(const bfloat16& a, const long unsigned int& b) { return a * static_cast(b); } - local_def bfloat16 operator*(const int8_t& a, const bfloat16& b) { return static_cast(a) * b; } - local_def bfloat16 operator*(const uint8_t& a, const bfloat16& b) { return static_cast(a) * b; } - local_def bfloat16 operator*(const int16_t& a, const bfloat16& b) { return static_cast(a) * b; } - local_def bfloat16 operator*(const uint16_t& a, const bfloat16& b) { return static_cast(a) * b; } - local_def bfloat16 operator*(const bool& a, const bfloat16& b) { return static_cast(a) * b; } - local_def bfloat16 operator*(const int& a, const bfloat16& b) { return static_cast(a) * b; } - local_def bfloat16 operator*(const unsigned int& a, const bfloat16& b) { return static_cast(a) * b; } - local_def bfloat16 operator*(const long long& a, const bfloat16& b) { return static_cast(a) * b; } - local_def bfloat16 operator*(const unsigned long long& a, const bfloat16& b) { return static_cast(a) * b; } - local_def bfloat16 operator*(const long int& a, const bfloat16& b) { return static_cast(a) * b; } - local_def bfloat16 operator*(const float& a, const bfloat16& b) { return static_cast(a) * b; } - local_def bfloat16 operator*(const double& a, const bfloat16& b) { return static_cast(a) * b; } - local_def bfloat16 operator*(const long unsigned int& a, const bfloat16& b) { return static_cast(a) * b; } - - local_def bool operator==(const bfloat16& a, const float& b) { return a == static_cast(b); } - local_def bool operator==(const bfloat16& a, const float16& b) { return (float)a == (float)(b); } - local_def bool operator==(const bfloat16& a, const double& b) { return a == static_cast(b); } - local_def bool operator==(const bfloat16& a, const int& b) { return a == static_cast(b); } - local_def bool operator==(const bfloat16& a, const unsigned int& b) { return a == static_cast(b); } - local_def bool operator==(const bfloat16& a, const long long& b) { return a == static_cast(b); } - local_def bool operator==(const bfloat16& a, const unsigned long long& b) { return a == static_cast(b); } - local_def bool operator==(const bfloat16& a, const long int& b) { return a == static_cast(b); } - local_def bool operator==(const bfloat16& a, const int8_t& b) { return a == static_cast(b); } - local_def bool operator==(const bfloat16& a, const uint8_t& b) { return a == static_cast(b); } - local_def bool operator==(const bfloat16& a, const int16_t& b) { return a == static_cast(b); } - local_def bool operator==(const bfloat16& a, const uint16_t& b) { return a == static_cast(b); } - local_def bool operator==(const bfloat16& a, const bool& b) { return a == static_cast(b); } - local_def bool operator==(const bfloat16& a, const long unsigned int& b) { return a == static_cast(b); } - local_def bool operator==(const bool& a, const bfloat16& b) { return static_cast(a) == b; } - local_def bool operator==(const int8_t& a, const bfloat16& b) { return static_cast(a) == b; } - local_def bool operator==(const uint8_t& a, const bfloat16& b) { return static_cast(a) == b; } - local_def bool operator==(const int16_t& a, const bfloat16& b) { return static_cast(a) == b; } - local_def bool operator==(const uint16_t& a, const bfloat16& b) { return static_cast(a) == b; } - local_def bool operator==(const int& a, const bfloat16& b) { return static_cast(a) == b; } - local_def bool operator==(const unsigned int& a, const bfloat16& b) { return static_cast(a) == b; } - local_def bool operator==(const long long& a, const bfloat16& b) { return static_cast(a) == b; } - local_def bool operator==(const unsigned long long& a, const bfloat16& b) { return static_cast(a) == b; } - local_def bool operator==(const long int& a, const bfloat16& b) { return static_cast(a) == b; } - local_def bool operator==(const float& a, const bfloat16& b) { return static_cast(a) == b; } - local_def bool operator==(const double& a, const bfloat16& b) { return static_cast(a) == b; } - local_def bool operator==(const long unsigned int& a, const bfloat16& b) { return static_cast(a) == b; } - - local_def bool operator!=(const bfloat16& a, const float& b) { return a != static_cast(b); } - local_def bool operator!=(const bfloat16& a, const float16& b) { return (float)a != static_cast(b); } - local_def bool operator!=(const bfloat16& a, const double& b) { return a != static_cast(b); } - local_def bool operator!=(const bfloat16& a, const int& b) { return a != static_cast(b); } - local_def bool operator!=(const bfloat16& a, const unsigned int& b) { return a != static_cast(b); } - local_def bool operator!=(const bfloat16& a, const long long& b) { return a != static_cast(b); } - local_def bool operator!=(const bfloat16& a, const unsigned long long& b) { return a != static_cast(b); } - local_def bool operator!=(const bfloat16& a, const long int& b) { return a != static_cast(b); } - local_def bool operator!=(const bfloat16& a, const int8_t& b) { return a != static_cast(b); } - local_def bool operator!=(const bfloat16& a, const uint8_t& b) { return a != static_cast(b); } - local_def bool operator!=(const bfloat16& a, const int16_t& b) { return a != static_cast(b); } - local_def bool operator!=(const bfloat16& a, const uint16_t& b) { return a != static_cast(b); } - local_def bool operator!=(const bfloat16& a, const bool& b) { return a != static_cast(b); } - local_def bool operator!=(const bfloat16& a, const long unsigned int& b) { return a != static_cast(b); } - local_def bool operator!=(const bool& a, const bfloat16& b) { return static_cast(a) != b; } - local_def bool operator!=(const int8_t& a, const bfloat16& b) { return static_cast(a) != b; } - local_def bool operator!=(const uint8_t& a, const bfloat16& b) { return static_cast(a) != b; } - local_def bool operator!=(const int16_t& a, const bfloat16& b) { return static_cast(a) != b; } - local_def bool operator!=(const uint16_t& a, const bfloat16& b) { return static_cast(a) != b; } - local_def bool operator!=(const int& a, const bfloat16& b) { return static_cast(a) != b; } - local_def bool operator!=(const unsigned int& a, const bfloat16& b) { return static_cast(a) != b; } - local_def bool operator!=(const long long& a, const bfloat16& b) { return static_cast(a) != b; } - local_def bool operator!=(const unsigned long long& a, const bfloat16& b) { return static_cast(a) != b; } - local_def bool operator!=(const long int& a, const bfloat16& b) { return static_cast(a) != b; } - local_def bool operator!=(const float& a, const bfloat16& b) { return static_cast(a) != b; } - local_def bool operator!=(const double& a, const bfloat16& b) { return static_cast(a) != b; } - local_def bool operator!=(const long unsigned int& a, const bfloat16& b) { return static_cast(a) != b; } - - local_def bool operator<(const bfloat16& a, const float& b) { return a < static_cast(b); } - local_def bool operator<(const bfloat16& a, const float16& b) { return (float)a < static_cast(b); } - local_def bool operator<(const bfloat16& a, const double& b) { return a < static_cast(b); } - local_def bool operator<(const bfloat16& a, const int& b) { return a < static_cast(b); } - local_def bool operator<(const bfloat16& a, const unsigned int& b) { return a < static_cast(b); } - local_def bool operator<(const bfloat16& a, const long long& b) { return a < static_cast(b); } - local_def bool operator<(const bfloat16& a, const unsigned long long& b) { return a < static_cast(b); } - local_def bool operator<(const bfloat16& a, const long int& b) { return a < static_cast(b); } - local_def bool operator<(const bfloat16& a, const int8_t& b) { return a < static_cast(b); } - local_def bool operator<(const bfloat16& a, const uint8_t& b) { return a < static_cast(b); } - local_def bool operator<(const bfloat16& a, const int16_t& b) { return a < static_cast(b); } - local_def bool operator<(const bfloat16& a, const uint16_t& b) { return a < static_cast(b); } - local_def bool operator<(const bfloat16& a, const bool& b) { return a < static_cast(b); } - local_def bool operator<(const bfloat16& a, const long unsigned int& b) { return a < static_cast(b); } - local_def bool operator<(const bool& a, const bfloat16& b) { return static_cast(a) < b; } - local_def bool operator<(const int8_t& a, const bfloat16& b) { return static_cast(a) < b; } - local_def bool operator<(const uint8_t& a, const bfloat16& b) { return static_cast(a) < b; } - local_def bool operator<(const int16_t& a, const bfloat16& b) { return static_cast(a) < b; } - local_def bool operator<(const uint16_t& a, const bfloat16& b) { return static_cast(a) < b; } - local_def bool operator<(const int& a, const bfloat16& b) { return static_cast(a) < b; } - local_def bool operator<(const unsigned int& a, const bfloat16& b) { return static_cast(a) < b; } - local_def bool operator<(const long long& a, const bfloat16& b) { return static_cast(a) < b; } - local_def bool operator<(const unsigned long long& a, const bfloat16& b) { return static_cast(a) < b; } - local_def bool operator<(const long int& a, const bfloat16& b) { return static_cast(a) < b; } - local_def bool operator<(const float& a, const bfloat16& b) { return static_cast(a) < b; } - local_def bool operator<(const double& a, const bfloat16& b) { return static_cast(a) < b; } - local_def bool operator<(const long unsigned int& a, const bfloat16& b) { return static_cast(a) < b; } - - local_def bool operator>(const bfloat16& a, const float& b) { return a > static_cast(b); } - local_def bool operator>(const bfloat16& a, const float16& b) { return (float)a > static_cast(b); } - local_def bool operator>(const bfloat16& a, const double& b) { return a > static_cast(b); } - local_def bool operator>(const bfloat16& a, const int& b) { return a > static_cast(b); } - local_def bool operator>(const bfloat16& a, const unsigned int& b) { return a > static_cast(b); } - local_def bool operator>(const bfloat16& a, const long long& b) { return a > static_cast(b); } - local_def bool operator>(const bfloat16& a, const unsigned long long& b) { return a > static_cast(b); } - local_def bool operator>(const bfloat16& a, const long int& b) { return a > static_cast(b); } - local_def bool operator>(const bfloat16& a, const int8_t& b) { return a > static_cast(b); } - local_def bool operator>(const bfloat16& a, const uint8_t& b) { return a > static_cast(b); } - local_def bool operator>(const bfloat16& a, const int16_t& b) { return a > static_cast(b); } - local_def bool operator>(const bfloat16& a, const uint16_t& b) { return a > static_cast(b); } - local_def bool operator>(const bfloat16& a, const bool& b) { return a > static_cast(b); } - local_def bool operator>(const bfloat16& a, const long unsigned int& b) { return a > static_cast(b); } - local_def bool operator>(const bool& a, const bfloat16& b) { return static_cast(a) > b; } - local_def bool operator>(const int8_t& a, const bfloat16& b) { return static_cast(a) > b; } - local_def bool operator>(const uint8_t& a, const bfloat16& b) { return static_cast(a) > b; } - local_def bool operator>(const int16_t& a, const bfloat16& b) { return static_cast(a) > b; } - local_def bool operator>(const uint16_t& a, const bfloat16& b) { return static_cast(a) > b; } - local_def bool operator>(const int& a, const bfloat16& b) { return static_cast(a) > b; } - local_def bool operator>(const unsigned int& a, const bfloat16& b) { return static_cast(a) > b; } - local_def bool operator>(const long long& a, const bfloat16& b) { return static_cast(a) > b; } - local_def bool operator>(const unsigned long long& a, const bfloat16& b) { return static_cast(a) > b; } - local_def bool operator>(const long int& a, const bfloat16& b) { return static_cast(a) > b; } - local_def bool operator>(const float& a, const bfloat16& b) { return static_cast(a) > b; } - local_def bool operator>(const double& a, const bfloat16& b) { return static_cast(a) > b; } - local_def bool operator>(const long unsigned int& a, const bfloat16& b) { return static_cast(a) > b; } - - local_def bool operator<=(const bfloat16& a, const float& b) { return a <= static_cast(b); } - local_def bool operator<=(const bfloat16& a, const float16& b) { return (float)a <= static_cast(b); } - local_def bool operator<=(const bfloat16& a, const double& b) { return a <= static_cast(b); } - local_def bool operator<=(const bfloat16& a, const int& b) { return a <= static_cast(b); } - local_def bool operator<=(const bfloat16& a, const unsigned int& b) { return a <= static_cast(b); } - local_def bool operator<=(const bfloat16& a, const long long& b) { return a <= static_cast(b); } - local_def bool operator<=(const bfloat16& a, const unsigned long long& b) { return a <= static_cast(b); } - local_def bool operator<=(const bfloat16& a, const long int& b) { return a <= static_cast(b); } - local_def bool operator<=(const bfloat16& a, const int8_t& b) { return a <= static_cast(b); } - local_def bool operator<=(const bfloat16& a, const uint8_t& b) { return a <= static_cast(b); } - local_def bool operator<=(const bfloat16& a, const int16_t& b) { return a <= static_cast(b); } - local_def bool operator<=(const bfloat16& a, const uint16_t& b) { return a <= static_cast(b); } - local_def bool operator<=(const bfloat16& a, const bool& b) { return a <= static_cast(b); } - local_def bool operator<=(const bfloat16& a, const long unsigned int& b) { return a <= static_cast(b); } - local_def bool operator<=(const bool& a, const bfloat16& b) { return static_cast(a) <= b; } - local_def bool operator<=(const int8_t& a, const bfloat16& b) { return static_cast(a) <= b; } - local_def bool operator<=(const uint8_t& a, const bfloat16& b) { return static_cast(a) <= b; } - local_def bool operator<=(const int16_t& a, const bfloat16& b) { return static_cast(a) <= b; } - local_def bool operator<=(const uint16_t& a, const bfloat16& b) { return static_cast(a) <= b; } - local_def bool operator<=(const int& a, const bfloat16& b) { return static_cast(a) <= b; } - local_def bool operator<=(const unsigned int& a, const bfloat16& b) { return static_cast(a) <= b; } - local_def bool operator<=(const long long& a, const bfloat16& b) { return static_cast(a) <= b; } - local_def bool operator<=(const unsigned long long& a, const bfloat16& b) { return static_cast(a) <= b; } - local_def bool operator<=(const long int& a, const bfloat16& b) { return static_cast(a) <= b; } - local_def bool operator<=(const float& a, const bfloat16& b) { return static_cast(a) <= b; } - local_def bool operator<=(const double& a, const bfloat16& b) { return static_cast(a) <= b; } - local_def bool operator<=(const long unsigned int& a, const bfloat16& b) { return static_cast(a) <= b; } - - local_def bool operator>=(const bfloat16& a, const float& b) { return a >= static_cast(b); } - local_def bool operator>=(const bfloat16& a, const float16& b) { return (float)a >= static_cast(b); } - local_def bool operator>=(const bfloat16& a, const double& b) { return a >= static_cast(b); } - local_def bool operator>=(const bfloat16& a, const int& b) { return a >= static_cast(b); } - local_def bool operator>=(const bfloat16& a, const unsigned int& b) { return a >= static_cast(b); } - local_def bool operator>=(const bfloat16& a, const long long& b) { return a >= static_cast(b); } - local_def bool operator>=(const bfloat16& a, const unsigned long long& b) { return a >= static_cast(b); } - local_def bool operator>=(const bfloat16& a, const long int& b) { return a >= static_cast(b); } - local_def bool operator>=(const bfloat16& a, const int8_t& b) { return a >= static_cast(b); } - local_def bool operator>=(const bfloat16& a, const uint8_t& b) { return a >= static_cast(b); } - local_def bool operator>=(const bfloat16& a, const int16_t& b) { return a >= static_cast(b); } - local_def bool operator>=(const bfloat16& a, const uint16_t& b) { return a >= static_cast(b); } - local_def bool operator>=(const bfloat16& a, const bool& b) { return a >= static_cast(b); } - local_def bool operator>=(const bfloat16& a, const long unsigned int& b) { return a >= static_cast(b); } - local_def bool operator>=(const bool& a, const bfloat16& b) { return static_cast(a) >= b; } - local_def bool operator>=(const int8_t& a, const bfloat16& b) { return static_cast(a) >= b; } - local_def bool operator>=(const uint8_t& a, const bfloat16& b) { return static_cast(a) >= b; } - local_def bool operator>=(const int16_t& a, const bfloat16& b) { return static_cast(a) >= b; } - local_def bool operator>=(const uint16_t& a, const bfloat16& b) { return static_cast(a) >= b; } - local_def bool operator>=(const int& a, const bfloat16& b) { return static_cast(a) >= b; } - local_def bool operator>=(const unsigned int& a, const bfloat16& b) { return static_cast(a) >= b; } - local_def bool operator>=(const long long& a, const bfloat16& b) { return static_cast(a) >= b; } - local_def bool operator>=(const unsigned long long& a, const bfloat16& b) { return static_cast(a) >= b; } - local_def bool operator>=(const long int& a, const bfloat16& b) { return static_cast(a) >= b; } - local_def bool operator>=(const float& a, const bfloat16& b) { return static_cast(a) >= b; } - local_def bool operator>=(const double& a, const bfloat16& b) { return static_cast(a) >= b; } - local_def bool operator>=(const long unsigned int& a, const bfloat16& b) { return static_cast(a) >= b; } - - - local_def std::ostream& operator<<(std::ostream &os, const bfloat16 &f) { - os << static_cast(f); - return os; - } + local_def bfloat16 operator-() const { + return 0.f - (float)*this; + } - local_def bfloat16 /* constexpr */ operator+(const bfloat16& h) { return h; } - local_def bfloat16 operator - (const bfloat16& h) { - auto temp = h._data; - temp ^= 0x8000; - bfloat16 t; - t._data = temp; - return t; -} + // local_def std::ostream& operator<<(std::ostream& os) { + // os << static_cast(*this); + // return os; + // } + local_def static bfloat16 min() { + bfloat16 res; + res._data = 0xFF7F; + return res; + } + local_def static bfloat16 max() { + bfloat16 res; + res._data = 0x7F7F; + return res; + + } + local_def static bfloat16 eps() { + bfloat16 res; + res._data = 0x3C00; + return res; + } + + local_def static bfloat16 inf() { + bfloat16 res; + res._data = 0x3C00; + return res; + } + + local_def static bfloat16 nan() { + bfloat16 res; + res._data = 0x7FC0; + return res; + } +}; + + + +// local_def std::ostream& operator<<(std::ostream &os, const bfloat16 &f) { +// os << static_cast(f); +// return os; +// } + + +// local_def bfloat16 /* constexpr */ operator+(const bfloat16& h) { return h; } + +// local_def bfloat16 operator - (const bfloat16& h) { +// auto temp = h._data; +// temp ^= 0x8000; +// bfloat16 t; +// t._data = temp; +// return t; +// } // WARNING: this implementation only for avoid cyclic references between float16 and bfloat16 types. -local_def void float16::assign(const bfloat16& rhs) { - assign((float)rhs); -} +// local_def void float16::assign(const bfloat16& rhs) { +// assign((float)rhs); +// } //} // namespace diff --git a/libnd4j/include/types/float16.h b/libnd4j/include/types/float16.h index 0cc75daed..4aa0d5d66 100644 --- a/libnd4j/include/types/float16.h +++ b/libnd4j/include/types/float16.h @@ -25,7 +25,6 @@ #include #endif - struct bfloat16; #ifdef __CUDACC__ @@ -224,505 +223,258 @@ local_def ihalf cpu_float2ihalf_rn(float f) } #endif - struct float16 - { - public: - ihalf data; - local_def float16() { *data.getXP() = 0; } +struct float16 { - template - local_def float16(const T& rhs) { - assign(rhs); - } + private: + template + struct isNumericType { static bool const value = std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value; };// || std::is_same::value; }; + // struct isNumericType { static bool const value = std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::value;; }; - local_def operator float() const { -#ifdef __CUDA_ARCH__ - return __half2float(data); -#else - return cpu_ihalf2float(data); -#endif - } + public: + ihalf data; + local_def float16() { *data.getXP() = 0; } - local_def explicit operator double() const { - return static_cast(static_cast(*this)); - } + template ::value || std::is_same::value>::type> + local_def float16(const T& rhs) { + *this = rhs; + } - local_def explicit operator Nd4jLong() const { - return static_cast(static_cast(*this)); - } - - local_def explicit operator int() const { - return static_cast(static_cast(*this)); - } - - local_def explicit operator bool() const { - return static_cast(*this) > 0.0f; - } - - local_def explicit operator int16_t() const { - return static_cast(static_cast(*this)); - } - - local_def explicit operator uint16_t() const { - return static_cast(static_cast(*this)); - } - - local_def explicit operator uint8_t() const { - return static_cast(static_cast(*this)); - } - - local_def explicit operator int8_t() const { - return static_cast(static_cast(*this)); - } - - local_def operator half() const { return data; } - - template - local_def float16& operator=(const T& rhs) { assign(rhs); return *this; } - - local_def void assign(unsigned int rhs) { - assign((float)rhs); - } - - local_def void assign(int rhs) { - assign((float)rhs); - } - - local_def void assign(double rhs) { - assign((float)rhs); - } - - local_def void assign(long long rhs) { - assign((float)rhs); - } - - local_def void assign(long int rhs) { - assign((float)rhs); - } - - local_def void assign(const bool rhs) { - assign(rhs ? 1.0f : 0.0f); - } - - local_def void assign(long unsigned int rhs) { - assign((float)rhs); - } - - local_def void assign(unsigned short rhs) { - *data.getXP() = rhs; - } - - local_def void assign(long long unsigned int rhs) { - assign((float)rhs); - } - - local_def void assign(float rhs) { -#ifdef __CUDA_ARCH__ - auto t = __float2half_rn(rhs); - auto b = *(data.getXP()); - -#ifdef CUDA_8 - *(data.getXP()) = t; -#else - data.assign(t); -#endif - -#else - data = cpu_float2ihalf_rn(rhs); -#endif - } - - local_def void assign(const ihalf& rhs) { - *data.getXP() = ((ihalf) rhs).getX(); - } - - local_def void assign(const bfloat16& rhs); - -#ifdef __CUDACC__ - local_def void assign(const half& rhs) { - data.assign(rhs); - } -#endif - - local_def void assign(const float16& rhs) { - data = rhs.data; - } - - local_def float16& operator+=(float16 rhs) { assign((float)*this + rhs); return *this; } - - local_def float16& operator-=(float16 rhs) { assign((float)*this - rhs); return *this; } - - local_def float16& operator*=(float16 rhs) { assign((float)*this * rhs); return *this; } - - local_def float16& operator/=(float16 rhs) { assign((float)*this / rhs); return *this; } - - local_def float16& operator+=(float rhs) { assign((float)*this + rhs); return *this; } - - local_def float16& operator-=(float rhs) { assign((float)*this - rhs); return *this; } - - local_def float16& operator*=(float rhs) { assign((float)*this * rhs); return *this; } - - local_def float16& operator/=(float rhs) { assign((float)*this / rhs); return *this; } - - local_def float16& operator++() { assign(*this + 1.f); return *this; } - - local_def float16& operator--() { assign(*this - 1.f); return *this; } - - local_def float16 operator++(int i) { assign(*this + (float)i); return *this; } - - local_def float16 operator--(int i) { assign(*this - (float)i); return *this; } - - local_def std::ostream& operator<<(std::ostream& os) { - os << static_cast(*this); - return os; - } - }; - - -#ifdef NATIVE_HALFS - local_def bool operator==(const float16& a, const float16& b) { return __hequ(a.data, b.data); } -#else - local_def bool operator==(const float16& a, const float16& b) { return ishequ_(((ihalf) a.data).getX(), ((ihalf)b.data).getX()); } -#endif - -#ifdef NATIVE_HALFS - local_def bool operator!=(const float16& a, const float16& b) { return !(__hequ(a.data, b.data)); } -#else - local_def bool operator!=(const float16& a, const float16& b) { return !(a == b); } -#endif - -#ifdef NATIVE_HALFS - local_def bool operator<(const float16& a, const float16& b) { return __hlt(a.data, b.data); } -#else - local_def bool operator<(const float16& a, const float16& b) { return (float)a < (float)b; } -#endif - -#ifdef NATIVE_HALFS - local_def bool operator>(const float16& a, const float16& b) { return __hgt(a.data, b.data); } -#else - local_def bool operator>(const float16& a, const float16& b) { return (float)a > (float)b; } -#endif - - template - local_def bool operator>(const float16& a, const T& b) { return (float)a > (float)b; } - -#ifdef NATIVE_HALFS - local_def bool operator<=(const float16& a, const float16& b) { return __hle(a.data, b.data); } -#else - local_def bool operator<=(const float16& a, const float16& b) { return (float)a <= (float)b; } -#endif - template - local_def bool operator<=(const float16& a, const T& b) { return (float)a <= (float)b; } - -#ifdef NATIVE_HALFS - local_def bool operator>=(const float16& a, const float16& b) { return __hge(a.data, b.data); } -#else - local_def bool operator>=(const float16& a, const float16& b) { return (float)a >= (float)b; } -#endif - -#ifdef NATIVE_HALFS - local_def float16 operator+(const float16& a, const float16& b) { return __hadd(a.data, b.data); } - - local_def float16 operator-(const float16& a, const float16& b) { return __hsub(a.data, b.data); } - - local_def float16 operator*(const float16& a, const float16& b) { return __hmul(a.data, b.data); } - - local_def float16 operator/(const float16& a, const float16& b) { - #ifdef CUDA_8 - return hdiv(a.data, b.data); - #else - return __hdiv(a.data, b.data); + local_def float16(const half& rhs) { + #ifdef __CUDACC__ + data.assign(rhs); #endif - } -#else - local_def float16 operator+(const float16& a, const float16& b) { return float16((float)a + (float)b); } - local_def float16 operator-(const float16& a, const float16& b) { return float16((float)a - (float)b); } - local_def float16 operator*(const float16& a, const float16& b) { return float16((float)a * (float)b); } - local_def float16 operator/(const float16& a, const float16& b) { return float16((float)a / (float)b); } -#endif + } - local_def float16 operator+(const float16& a, const double& b) { return a + static_cast(b); } - local_def float16 operator+(const float16& a, const float& b) { return a + static_cast(b); } - local_def float16 operator+(const float16& a, const int& b) { return a + static_cast(b); } - local_def float16 operator+(const float16& a, const unsigned int& b) { return a + static_cast(b); } - local_def float16 operator+(const float16& a, const long long& b) { return a + static_cast(b); } - local_def float16 operator+(const float16& a, const unsigned long long& b) { return a + static_cast(b); } - local_def float16 operator+(const float16& a, const long int& b) { return a + static_cast(b); } - local_def float16 operator+(const float16& a, const bool& b) { return a + static_cast(b); } - local_def float16 operator+(const float16& a, const int8_t& b) { return a + static_cast(b); } - local_def float16 operator+(const float16& a, const uint8_t& b) { return a + static_cast(b); } - local_def float16 operator+(const float16& a, const int16_t& b) { return a + static_cast(b); } - local_def float16 operator+(const float16& a, const uint16_t& b) { return a + static_cast(b); } - local_def float16 operator+(const float16& a, const long unsigned int& b) { return a + static_cast(b); } - local_def float16 operator+(const int8_t& a, const float16& b) { return static_cast(a) + b; } - local_def float16 operator+(const uint8_t& a, const float16& b) { return static_cast(a) + b; } - local_def float16 operator+(const int16_t& a, const float16& b) { return static_cast(a) + b; } - local_def float16 operator+(const uint16_t& a, const float16& b) { return static_cast(a) + b; } - local_def float16 operator+(const bool& a, const float16& b) { return static_cast(a) + b; } - local_def float16 operator+(const int& a, const float16& b) { return static_cast(a) + b; } - local_def float16 operator+(const unsigned int& a, const float16& b) { return static_cast(a) + b; } - local_def float16 operator+(const long long& a, const float16& b) { return static_cast(a) + b; } - local_def float16 operator+(const unsigned long long& a, const float16& b) { return static_cast(a) + b; } - local_def float16 operator+(const long int& a, const float16& b) { return static_cast(a) + b; } - local_def float16 operator+(const float& a, const float16& b) { return static_cast(a) + b; } - local_def float16 operator+(const double& a, const float16& b) { return static_cast(a) + b; } - local_def float16 operator+(const long unsigned int& a, const float16& b) { return static_cast(a) + b; } + local_def operator float() const { + #ifdef __CUDA_ARCH__ + return __half2float(data); + #else + return cpu_ihalf2float(data); + #endif + } - local_def float16 operator-(const float16& a, const double& b) { return a - static_cast(b); } - local_def float16 operator-(const float16& a, const float& b) { return a - static_cast(b); } - local_def float16 operator-(const float16& a, const int& b) { return a - static_cast(b); } - local_def float16 operator-(const float16& a, const unsigned int& b) { return a - static_cast(b); } - local_def float16 operator-(const float16& a, const long long& b) { return a - static_cast(b); } - local_def float16 operator-(const float16& a, const unsigned long long& b) { return a - static_cast(b); } - local_def float16 operator-(const float16& a, const long int& b) { return a - static_cast(b); } - local_def float16 operator-(const float16& a, const bool& b) { return a - static_cast(b); } - local_def float16 operator-(const float16& a, const int8_t& b) { return a - static_cast(b); } - local_def float16 operator-(const float16& a, const uint8_t& b) { return a - static_cast(b); } - local_def float16 operator-(const float16& a, const int16_t& b) { return a - static_cast(b); } - local_def float16 operator-(const float16& a, const uint16_t& b) { return a - static_cast(b); } - local_def float16 operator-(const float16& a, const long unsigned int& b) { return a - static_cast(b); } - local_def float16 operator-(const int8_t& a, const float16& b) { return static_cast(a) - b; } - local_def float16 operator-(const uint8_t& a, const float16& b) { return static_cast(a) - b; } - local_def float16 operator-(const uint16_t& a, const float16& b) { return static_cast(a) - b; } - local_def float16 operator-(const int16_t& a, const float16& b) { return static_cast(a) - b; } - local_def float16 operator-(const bool& a, const float16& b) { return static_cast(a) - b; } - local_def float16 operator-(const int& a, const float16& b) { return static_cast(a) - b; } - local_def float16 operator-(const unsigned int& a, const float16& b) { return static_cast(a) - b; } - local_def float16 operator-(const long long& a, const float16& b) { return static_cast(a) - b; } - local_def float16 operator-(const unsigned long long& a, const float16& b) { return static_cast(a) - b; } - local_def float16 operator-(const long int& a, const float16& b) { return static_cast(a) - b; } - local_def float16 operator-(const float& a, const float16& b) { return static_cast(a) - b; } - local_def float16 operator-(const double& a, const float16& b) { return static_cast(a) - b; } - local_def float16 operator-(const long unsigned int& a, const float16& b) { return static_cast(a) - b; } + local_def explicit operator bool() const { + return static_cast(*this) != 0.0f; + } - local_def float16 operator/(const float16& a, const double& b) { return a / static_cast(b); } - local_def float16 operator/(const float16& a, const float& b) { return a / static_cast(b); } - local_def float16 operator/(const float16& a, const int& b) { return a / static_cast(b); } - local_def float16 operator/(const float16& a, const unsigned int& b) { return a / static_cast(b); } - local_def float16 operator/(const float16& a, const long long& b) { return a / static_cast(b); } - local_def float16 operator/(const float16& a, const unsigned long long& b) { return a / static_cast(b); } - local_def float16 operator/(const float16& a, const long int& b) { return a / static_cast(b); } - local_def float16 operator/(const float16& a, const bool& b) { return a / static_cast(b); } - local_def float16 operator/(const float16& a, const int8_t& b) { return a / static_cast(b); } - local_def float16 operator/(const float16& a, const uint8_t& b) { return a / static_cast(b); } - local_def float16 operator/(const float16& a, const int16_t& b) { return a / static_cast(b); } - local_def float16 operator/(const float16& a, const uint16_t& b) { return a / static_cast(b); } - local_def float16 operator/(const float16& a, const long unsigned int& b) { return a / static_cast(b); } - local_def float16 operator/(const int8_t& a, const float16& b) { return static_cast(a) / b; } - local_def float16 operator/(const uint8_t& a, const float16& b) { return static_cast(a) / b; } - local_def float16 operator/(const uint16_t& a, const float16& b) { return static_cast(a) / b; } - local_def float16 operator/(const int16_t& a, const float16& b) { return static_cast(a) / b; } - local_def float16 operator/(const bool& a, const float16& b) { return static_cast(a) / b; } - local_def float16 operator/(const int& a, const float16& b) { return static_cast(a) / b; } - local_def float16 operator/(const unsigned int& a, const float16& b) { return static_cast(a) / b; } - local_def float16 operator/(const long long& a, const float16& b) { return static_cast(a) / b; } - local_def float16 operator/(const unsigned long long& a, const float16& b) { return static_cast(a) / b; } - local_def float16 operator/(const long int& a, const float16& b) { return static_cast(a) / b; } - local_def float16 operator/(const float& a, const float16& b) { return static_cast(a) / b; } - local_def float16 operator/(const double& a, const float16& b) { return static_cast(a) / b; } - local_def float16 operator/(const long unsigned int& a, const float16& b) { return static_cast(a) / b; } - - local_def float16 operator*(const float16& a, const double& b) { return a * static_cast(b); } - local_def float16 operator*(const float16& a, const float& b) { return a * static_cast(b); } - local_def float16 operator*(const float16& a, const int& b) { return a * static_cast(b); } - local_def float16 operator*(const float16& a, const unsigned int& b) { return a * static_cast(b); } - local_def float16 operator*(const float16& a, const long long& b) { return a * static_cast(b); } - local_def float16 operator*(const float16& a, const unsigned long long& b) { return a * static_cast(b); } - local_def float16 operator*(const float16& a, const long int& b) { return a * static_cast(b); } - local_def float16 operator*(const float16& a, const bool& b) { return a * static_cast(b); } - local_def float16 operator*(const float16& a, const int8_t& b) { return a * static_cast(b); } - local_def float16 operator*(const float16& a, const uint8_t& b) { return a * static_cast(b); } - local_def float16 operator*(const float16& a, const int16_t& b) { return a * static_cast(b); } - local_def float16 operator*(const float16& a, const uint16_t& b) { return a * static_cast(b); } - local_def float16 operator*(const float16& a, const long unsigned int& b) { return a * static_cast(b); } - local_def float16 operator*(const int8_t& a, const float16& b) { return static_cast(a) * b; } - local_def float16 operator*(const uint8_t& a, const float16& b) { return static_cast(a) * b; } - local_def float16 operator*(const uint16_t& a, const float16& b) { return static_cast(a) * b; } - local_def float16 operator*(const int16_t& a, const float16& b) { return static_cast(a) * b; } - local_def float16 operator*(const bool& a, const float16& b) { return static_cast(a) * b; } - local_def float16 operator*(const int& a, const float16& b) { return static_cast(a) * b; } - local_def float16 operator*(const unsigned int& a, const float16& b) { return static_cast(a) * b; } - local_def float16 operator*(const long long& a, const float16& b) { return static_cast(a) * b; } - local_def float16 operator*(const unsigned long long& a, const float16& b) { return static_cast(a) * b; } - local_def float16 operator*(const long int& a, const float16& b) { return static_cast(a) * b; } - local_def float16 operator*(const float& a, const float16& b) { return static_cast(a) * b; } - local_def float16 operator*(const double& a, const float16& b) { return static_cast(a) * b; } - local_def float16 operator*(const long unsigned int& a, const float16& b) { return static_cast(a) * b; } + local_def explicit operator half() const { + return data; + } - local_def bool operator==(const float16& a, const float& b) { return a == static_cast(b); } - local_def bool operator==(const float16& a, const double& b) { return a == static_cast(b); } - local_def bool operator==(const float16& a, const int& b) { return a == static_cast(b); } - local_def bool operator==(const float16& a, const unsigned int& b) { return a == static_cast(b); } - local_def bool operator==(const float16& a, const long long& b) { return a == static_cast(b); } - local_def bool operator==(const float16& a, const unsigned long long& b) { return a == static_cast(b); } - local_def bool operator==(const float16& a, const long int& b) { return a == static_cast(b); } - local_def bool operator==(const float16& a, const int8_t& b) { return a == static_cast(b); } - local_def bool operator==(const float16& a, const uint8_t& b) { return a == static_cast(b); } - local_def bool operator==(const float16& a, const int16_t& b) { return a == static_cast(b); } - local_def bool operator==(const float16& a, const uint16_t& b) { return a == static_cast(b); } - local_def bool operator==(const float16& a, const bool& b) { return a == static_cast(b); } - local_def bool operator==(const float16& a, const long unsigned int& b) { return a == static_cast(b); } - local_def bool operator==(const bool& a, const float16& b) { return static_cast(a) == b; } - local_def bool operator==(const int8_t& a, const float16& b) { return static_cast(a) == b; } - local_def bool operator==(const uint8_t& a, const float16& b) { return static_cast(a) == b; } - local_def bool operator==(const uint16_t& a, const float16& b) { return static_cast(a) == b; } - local_def bool operator==(const int16_t& a, const float16& b) { return static_cast(a) == b; } - local_def bool operator==(const int& a, const float16& b) { return static_cast(a) == b; } - local_def bool operator==(const unsigned int& a, const float16& b) { return static_cast(a) == b; } - local_def bool operator==(const long long& a, const float16& b) { return static_cast(a) == b; } - local_def bool operator==(const unsigned long long& a, const float16& b) { return static_cast(a) == b; } - local_def bool operator==(const long int& a, const float16& b) { return static_cast(a) == b; } - local_def bool operator==(const float& a, const float16& b) { return static_cast(a) == b; } - local_def bool operator==(const double& a, const float16& b) { return static_cast(a) == b; } - local_def bool operator==(const long unsigned int& a, const float16& b) { return static_cast(a) == b; } + template ::value>::type> + local_def explicit operator T() const { + return static_cast(static_cast(*this)); + } - local_def bool operator!=(const float16& a, const float& b) { return a != static_cast(b); } - local_def bool operator!=(const float16& a, const double& b) { return a != static_cast(b); } - local_def bool operator!=(const float16& a, const int& b) { return a != static_cast(b); } - local_def bool operator!=(const float16& a, const unsigned int& b) { return a != static_cast(b); } - local_def bool operator!=(const float16& a, const long long& b) { return a != static_cast(b); } - local_def bool operator!=(const float16& a, const unsigned long long& b) { return a != static_cast(b); } - local_def bool operator!=(const float16& a, const long int& b) { return a != static_cast(b); } - local_def bool operator!=(const float16& a, const int8_t& b) { return a != static_cast(b); } - local_def bool operator!=(const float16& a, const uint8_t& b) { return a != static_cast(b); } - local_def bool operator!=(const float16& a, const int16_t& b) { return a != static_cast(b); } - local_def bool operator!=(const float16& a, const uint16_t& b) { return a != static_cast(b); } - local_def bool operator!=(const float16& a, const bool& b) { return a != static_cast(b); } - local_def bool operator!=(const float16& a, const long unsigned int& b) { return a != static_cast(b); } - local_def bool operator!=(const bool& a, const float16& b) { return static_cast(a) != b; } - local_def bool operator!=(const int8_t& a, const float16& b) { return static_cast(a) != b; } - local_def bool operator!=(const uint8_t& a, const float16& b) { return static_cast(a) != b; } - local_def bool operator!=(const int16_t& a, const float16& b) { return static_cast(a) != b; } - local_def bool operator!=(const uint16_t& a, const float16& b) { return static_cast(a) != b; } - local_def bool operator!=(const int& a, const float16& b) { return static_cast(a) != b; } - local_def bool operator!=(const unsigned int& a, const float16& b) { return static_cast(a) != b; } - local_def bool operator!=(const long long& a, const float16& b) { return static_cast(a) != b; } - local_def bool operator!=(const unsigned long long& a, const float16& b) { return static_cast(a) != b; } - local_def bool operator!=(const long int& a, const float16& b) { return static_cast(a) != b; } - local_def bool operator!=(const float& a, const float16& b) { return static_cast(a) != b; } - local_def bool operator!=(const double& a, const float16& b) { return static_cast(a) != b; } - local_def bool operator!=(const long unsigned int& a, const float16& b) { return static_cast(a) != b; } + local_def float16& operator=(const float& rhs) { + #ifdef __CUDA_ARCH__ + auto t = __float2half_rn(rhs); + auto b = *(data.getXP()); - local_def bool operator<(const float16& a, const float& b) { return a < static_cast(b); } - local_def bool operator<(const float16& a, const double& b) { return a < static_cast(b); } - local_def bool operator<(const float16& a, const int& b) { return a < static_cast(b); } - local_def bool operator<(const float16& a, const unsigned int& b) { return a < static_cast(b); } - local_def bool operator<(const float16& a, const long long& b) { return a < static_cast(b); } - local_def bool operator<(const float16& a, const unsigned long long& b) { return a < static_cast(b); } - local_def bool operator<(const float16& a, const long int& b) { return a < static_cast(b); } - local_def bool operator<(const float16& a, const int8_t& b) { return a < static_cast(b); } - local_def bool operator<(const float16& a, const uint8_t& b) { return a < static_cast(b); } - local_def bool operator<(const float16& a, const int16_t& b) { return a < static_cast(b); } - local_def bool operator<(const float16& a, const uint16_t& b) { return a < static_cast(b); } - local_def bool operator<(const float16& a, const bool& b) { return a < static_cast(b); } - local_def bool operator<(const float16& a, const long unsigned int& b) { return a < static_cast(b); } - local_def bool operator<(const bool& a, const float16& b) { return static_cast(a) < b; } - local_def bool operator<(const int8_t& a, const float16& b) { return static_cast(a) < b; } - local_def bool operator<(const uint8_t& a, const float16& b) { return static_cast(a) < b; } - local_def bool operator<(const int16_t& a, const float16& b) { return static_cast(a) < b; } - local_def bool operator<(const uint16_t& a, const float16& b) { return static_cast(a) < b; } - local_def bool operator<(const int& a, const float16& b) { return static_cast(a) < b; } - local_def bool operator<(const unsigned int& a, const float16& b) { return static_cast(a) < b; } - local_def bool operator<(const long long& a, const float16& b) { return static_cast(a) < b; } - local_def bool operator<(const unsigned long long& a, const float16& b) { return static_cast(a) < b; } - local_def bool operator<(const long int& a, const float16& b) { return static_cast(a) < b; } - local_def bool operator<(const float& a, const float16& b) { return static_cast(a) < b; } - local_def bool operator<(const double& a, const float16& b) { return static_cast(a) < b; } - local_def bool operator<(const long unsigned int& a, const float16& b) { return static_cast(a) < b; } + #ifdef CUDA_8 + *(data.getXP()) = t; + #else + data.assign(t); + #endif - local_def bool operator>(const float16& a, const float& b) { return a > static_cast(b); } - local_def bool operator>(const float16& a, const double& b) { return a > static_cast(b); } - local_def bool operator>(const float16& a, const int& b) { return a > static_cast(b); } - local_def bool operator>(const float16& a, const unsigned int& b) { return a > static_cast(b); } - local_def bool operator>(const float16& a, const long long& b) { return a > static_cast(b); } - local_def bool operator>(const float16& a, const unsigned long long& b) { return a > static_cast(b); } - local_def bool operator>(const float16& a, const long int& b) { return a > static_cast(b); } - local_def bool operator>(const float16& a, const int8_t& b) { return a > static_cast(b); } - local_def bool operator>(const float16& a, const uint8_t& b) { return a > static_cast(b); } - local_def bool operator>(const float16& a, const int16_t& b) { return a > static_cast(b); } - local_def bool operator>(const float16& a, const uint16_t& b) { return a > static_cast(b); } - local_def bool operator>(const float16& a, const bool& b) { return a > static_cast(b); } - local_def bool operator>(const float16& a, const long unsigned int& b) { return a > static_cast(b); } - local_def bool operator>(const bool& a, const float16& b) { return static_cast(a) > b; } - local_def bool operator>(const int8_t& a, const float16& b) { return static_cast(a) > b; } - local_def bool operator>(const uint8_t& a, const float16& b) { return static_cast(a) > b; } - local_def bool operator>(const int16_t& a, const float16& b) { return static_cast(a) > b; } - local_def bool operator>(const uint16_t& a, const float16& b) { return static_cast(a) > b; } - local_def bool operator>(const int& a, const float16& b) { return static_cast(a) > b; } - local_def bool operator>(const unsigned int& a, const float16& b) { return static_cast(a) > b; } - local_def bool operator>(const long long& a, const float16& b) { return static_cast(a) > b; } - local_def bool operator>(const unsigned long long& a, const float16& b) { return static_cast(a) > b; } - local_def bool operator>(const long int& a, const float16& b) { return static_cast(a) > b; } - local_def bool operator>(const float& a, const float16& b) { return static_cast(a) > b; } - local_def bool operator>(const double& a, const float16& b) { return static_cast(a) > b; } - local_def bool operator>(const long unsigned int& a, const float16& b) { return static_cast(a) > b; } - - local_def bool operator<=(const float16& a, const float& b) { return a <= static_cast(b); } - local_def bool operator<=(const float16& a, const double& b) { return a <= static_cast(b); } - local_def bool operator<=(const float16& a, const int& b) { return a <= static_cast(b); } - local_def bool operator<=(const float16& a, const unsigned int& b) { return a <= static_cast(b); } - local_def bool operator<=(const float16& a, const long long& b) { return a <= static_cast(b); } - local_def bool operator<=(const float16& a, const unsigned long long& b) { return a <= static_cast(b); } - local_def bool operator<=(const float16& a, const long int& b) { return a <= static_cast(b); } - local_def bool operator<=(const float16& a, const int8_t& b) { return a <= static_cast(b); } - local_def bool operator<=(const float16& a, const uint8_t& b) { return a <= static_cast(b); } - local_def bool operator<=(const float16& a, const int16_t& b) { return a <= static_cast(b); } - local_def bool operator<=(const float16& a, const uint16_t& b) { return a <= static_cast(b); } - local_def bool operator<=(const float16& a, const bool& b) { return a <= static_cast(b); } - local_def bool operator<=(const float16& a, const long unsigned int& b) { return a <= static_cast(b); } - local_def bool operator<=(const bool& a, const float16& b) { return static_cast(a) <= b; } - local_def bool operator<=(const int8_t& a, const float16& b) { return static_cast(a) <= b; } - local_def bool operator<=(const uint8_t& a, const float16& b) { return static_cast(a) <= b; } - local_def bool operator<=(const int16_t& a, const float16& b) { return static_cast(a) <= b; } - local_def bool operator<=(const uint16_t& a, const float16& b) { return static_cast(a) <= b; } - local_def bool operator<=(const int& a, const float16& b) { return static_cast(a) <= b; } - local_def bool operator<=(const unsigned int& a, const float16& b) { return static_cast(a) <= b; } - local_def bool operator<=(const long long& a, const float16& b) { return static_cast(a) <= b; } - local_def bool operator<=(const unsigned long long& a, const float16& b) { return static_cast(a) <= b; } - local_def bool operator<=(const long int& a, const float16& b) { return static_cast(a) <= b; } - local_def bool operator<=(const float& a, const float16& b) { return static_cast(a) <= b; } - local_def bool operator<=(const double& a, const float16& b) { return static_cast(a) <= b; } - local_def bool operator<=(const long unsigned int& a, const float16& b) { return static_cast(a) <= b; } + #else + data = cpu_float2ihalf_rn(rhs); + #endif - local_def bool operator>=(const float16& a, const float& b) { return a >= static_cast(b); } - local_def bool operator>=(const float16& a, const double& b) { return a >= static_cast(b); } - local_def bool operator>=(const float16& a, const int& b) { return a >= static_cast(b); } - local_def bool operator>=(const float16& a, const unsigned int& b) { return a >= static_cast(b); } - local_def bool operator>=(const float16& a, const long long& b) { return a >= static_cast(b); } - local_def bool operator>=(const float16& a, const unsigned long long& b) { return a >= static_cast(b); } - local_def bool operator>=(const float16& a, const long int& b) { return a >= static_cast(b); } - local_def bool operator>=(const float16& a, const int8_t& b) { return a >= static_cast(b); } - local_def bool operator>=(const float16& a, const uint8_t& b) { return a >= static_cast(b); } - local_def bool operator>=(const float16& a, const int16_t& b) { return a >= static_cast(b); } - local_def bool operator>=(const float16& a, const uint16_t& b) { return a >= static_cast(b); } - local_def bool operator>=(const float16& a, const bool& b) { return a >= static_cast(b); } - local_def bool operator>=(const float16& a, const long unsigned int& b) { return a >= static_cast(b); } - local_def bool operator>=(const bool& a, const float16& b) { return static_cast(a) >= b; } - local_def bool operator>=(const int8_t& a, const float16& b) { return static_cast(a) >= b; } - local_def bool operator>=(const uint8_t& a, const float16& b) { return static_cast(a) >= b; } - local_def bool operator>=(const int16_t& a, const float16& b) { return static_cast(a) >= b; } - local_def bool operator>=(const uint16_t& a, const float16& b) { return static_cast(a) >= b; } - local_def bool operator>=(const int& a, const float16& b) { return static_cast(a) >= b; } - local_def bool operator>=(const unsigned int& a, const float16& b) { return static_cast(a) >= b; } - local_def bool operator>=(const long long& a, const float16& b) { return static_cast(a) >= b; } - local_def bool operator>=(const unsigned long long& a, const float16& b) { return static_cast(a) >= b; } - local_def bool operator>=(const long int& a, const float16& b) { return static_cast(a) >= b; } - local_def bool operator>=(const float& a, const float16& b) { return static_cast(a) >= b; } - local_def bool operator>=(const double& a, const float16& b) { return static_cast(a) >= b; } - local_def bool operator>=(const long unsigned int& a, const float16& b) { return static_cast(a) >= b; } - + return *this; + } - local_def std::ostream& operator<<(std::ostream &os, const float16 &f) { - os << static_cast(f); - return os; - } + local_def float16& operator=(const unsigned short rhs) { + *data.getXP() = rhs; + return *this; + } - local_def float16 operator+(const float16& h) { return h; } + local_def float16& operator=(const bool rhs) { + *this = (float)rhs ? 1.f: 0.f; + return *this; + } - local_def float16 operator - (const float16& h) { - const ihalf * tmp = &h.data; - return float16(hneg(tmp->getX())); -} + local_def float16& operator=(const ihalf& rhs) { + *data.getXP() = ((ihalf) rhs).getX(); + return *this; + } + + #ifdef __CUDACC__ + local_def float16& operator=(const half& rhs) { + data.assign(rhs); + return *this; + } + #endif + + local_def float16& operator=(const float16& rhs) { + data = rhs.data; + return *this; + } + + template ::value || std::is_same::value>::type> + local_def float16& operator=(const T& rhs) { + *this = (float)rhs; + return *this; + } + + #ifdef NATIVE_HALFS + local_def friend bool operator==(const float16& a, const float16& b) { return __hequ(a.data, b.data); } + #else + local_def friend bool operator==(const float16& a, const float16& b) { return ishequ_(((ihalf) a.data).getX(), ((ihalf)b.data).getX()); } + #endif + + #ifdef NATIVE_HALFS + local_def friend bool operator!=(const float16& a, const float16& b) { return !(__hequ(a.data, b.data)); } + #else + local_def friend bool operator!=(const float16& a, const float16& b) { return !(a == b); } + #endif + + #ifdef NATIVE_HALFS + local_def friend bool operator<(const float16& a, const float16& b) { return __hlt(a.data, b.data); } + #else + local_def friend bool operator<(const float16& a, const float16& b) { return (float)a < (float)b; } + #endif + + #ifdef NATIVE_HALFS + local_def friend bool operator>(const float16& a, const float16& b) { return __hgt(a.data, b.data); } + #else + local_def friend bool operator>(const float16& a, const float16& b) { return (float)a > (float)b; } + #endif + + #ifdef NATIVE_HALFS + local_def friend bool operator<=(const float16& a, const float16& b) { return __hle(a.data, b.data); } + #else + local_def friend bool operator<=(const float16& a, const float16& b) { return (float)a <= (float)b; } + #endif + + #ifdef NATIVE_HALFS + local_def friend bool operator>=(const float16& a, const float16& b) { return __hge(a.data, b.data); } + #else + local_def friend bool operator>=(const float16& a, const float16& b) { return (float)a >= (float)b; } + #endif + + #ifdef NATIVE_HALFS + local_def friend float16 operator+(const float16& a, const float16& b) { return __hadd(a.data, b.data); } + + local_def friend float16 operator-(const float16& a, const float16& b) { return __hsub(a.data, b.data); } + + local_def friend float16 operator*(const float16& a, const float16& b) { return __hmul(a.data, b.data); } + + local_def friend float16 operator/(const float16& a, const float16& b) { + #ifdef CUDA_8 + return hdiv(a.data, b.data); + #else + return __hdiv(a.data, b.data); + #endif + } + #else + local_def friend float16 operator+(const float16& a, const float16& b) { return float16((float)a + (float)b); } + local_def friend float16 operator-(const float16& a, const float16& b) { return float16((float)a - (float)b); } + local_def friend float16 operator*(const float16& a, const float16& b) { return float16((float)a * (float)b); } + local_def friend float16 operator/(const float16& a, const float16& b) { return float16((float)a / (float)b); } + #endif + + template ::value>::type> + local_def friend float16 operator+(const float16& a, const T& b) { return a + static_cast(b); } + template ::value>::type> + local_def friend float16 operator+(const T& a, const float16& b) { return static_cast(a) + b; } + + template ::value>::type> + local_def friend float16 operator-(const float16& a, const T& b) { return a - static_cast(b); } + template ::value>::type> + local_def friend float16 operator-(const T& a, const float16& b) { return static_cast(a) - b; } + + template ::value>::type> + local_def friend float16 operator*(const float16& a, const T& b) { return a * static_cast(b); } + template ::value>::type> + local_def friend float16 operator*(const T& a, const float16& b) { return static_cast(a) * b; } + + template ::value>::type> + local_def friend float16 operator/(const float16& a, const T& b) { return a / static_cast(b); } + template ::value>::type> + local_def friend float16 operator/(const T& a, const float16& b) { return static_cast(a) / b; } + + template ::value>::type> + local_def friend bool operator==(const float16& a, const T& b) { return a == static_cast(b); } + template ::value>::type> + local_def friend bool operator==(const T& a, const float16& b) { return static_cast(a) == b; } + + template ::value>::type> + local_def friend bool operator!=(const float16& a, const T& b) { return a != static_cast(b); } + template ::value>::type> + local_def friend bool operator!=(const T& a, const float16& b) { return static_cast(a) != b; } + + template ::value>::type> + local_def friend bool operator<(const float16& a, const T& b) { return a < static_cast(b); } + template ::value>::type> + local_def friend bool operator<(const T& a, const float16& b) { return static_cast(a) < b; } + + template ::value>::type> + local_def friend bool operator>(const float16& a, const T& b) { return a > static_cast(b); } + template ::value>::type> + local_def friend bool operator>(const T& a, const float16& b) { return static_cast(a) > b; } + + template ::value>::type> + local_def friend bool operator<=(const float16& a, const T& b) { return a <= static_cast(b); } + template ::value>::type> + local_def friend bool operator<=(const T& a, const float16& b) { return static_cast(a) <= b; } + + template ::value>::type> + local_def friend bool operator>=(const float16& a, const T& b) { return a >= static_cast(b); } + template ::value>::type> + local_def friend bool operator>=(const T& a, const float16& b) { return static_cast(a) >= b; } + + local_def float16& operator+=(float16 rhs) { *this = (float)*this + (float)rhs; return *this; } + + local_def float16& operator-=(float16 rhs) { *this = (float)*this - (float)rhs; return *this; } + + local_def float16& operator*=(float16 rhs) { *this = (float)*this * (float)rhs; return *this; } + + local_def float16& operator/=(float16 rhs) { *this = (float)*this / (float)rhs; return *this; } + + template ::value>::type> + local_def float16& operator+=(const T& rhs) { *this = *this + rhs; return *this; } + + template ::value>::type> + local_def float16& operator-=(const T& rhs) { *this = *this - rhs; return *this; } + + template ::value>::type> + local_def float16& operator*=(const T& rhs) { *this = *this * rhs; return *this; } + + template ::value>::type> + local_def float16& operator/=(const T& rhs) { *this = *this / rhs; return *this; } + + local_def float16& operator++() { *this = *this + (float16)1.f; return *this; } + + local_def float16& operator--() { *this = *this - (float16)1.f; return *this; } + + local_def float16 operator++(int) { *this = *this + (float16)1.f; return *this; } + + local_def float16 operator--(int) { *this = *this - (float16)1.f; return *this; } + + local_def float16 operator-() const { + return 0.f - (float)*this; + } + + // local_def std::ostream& operator<<(std::ostream& os) { + // os << static_cast(*this); + // return os; + // } +}; + + + + // local_def std::ostream& operator<<(std::ostream &os, const float16 &f) { + // os << static_cast(f); + // return os; + // } + + // local_def float16 operator+(const float16& h) { return h; } + + // local_def float16 operator - (const float16& h) { + // const ihalf * tmp = &h.data; + // return float16(hneg(tmp->getX())); + // } #ifdef __CUDACC__ local_def int isnan(const float16& h) { return ishnan_(((ihalf)h.data).getX()); } @@ -730,6 +482,6 @@ local_def ihalf cpu_float2ihalf_rn(float f) local_def int isinf(const float16& h) { return ishinf_(((ihalf)h.data).getX()); } #endif - std::ostream& operator << (std::ostream& s, const float16&); + // std::ostream& operator << (std::ostream& s, const float16&); #endif diff --git a/libnd4j/pom.xml b/libnd4j/pom.xml index 374bc5640..20b9d6562 100644 --- a/libnd4j/pom.xml +++ b/libnd4j/pom.xml @@ -326,6 +326,8 @@ --compute ${libnd4j.compute} ${libnd4j.tests} + -j + ${libnd4j.buildthreads} ${project.basedir} diff --git a/libnd4j/tests_cpu/CMakeLists.txt b/libnd4j/tests_cpu/CMakeLists.txt index 3d58617b1..5de17a2d1 100644 --- a/libnd4j/tests_cpu/CMakeLists.txt +++ b/libnd4j/tests_cpu/CMakeLists.txt @@ -1,4 +1,4 @@ -cmake_minimum_required(VERSION 3.6) +cmake_minimum_required(VERSION 3.15) project(tests_cpu) # Download and unpack googletest at configure time diff --git a/libnd4j/tests_cpu/CMakeLists.txt.in b/libnd4j/tests_cpu/CMakeLists.txt.in index 8bc138871..a3cba4d27 100644 --- a/libnd4j/tests_cpu/CMakeLists.txt.in +++ b/libnd4j/tests_cpu/CMakeLists.txt.in @@ -5,9 +5,10 @@ project(googletest-download NONE) include(ExternalProject) ExternalProject_Add(googletest GIT_REPOSITORY https://github.com/google/googletest.git - GIT_TAG release-1.8.1 + GIT_TAG release-1.10.0 SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/googletest-src" BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/googletest-build" + CMAKE_ARGS "" CONFIGURE_COMMAND "" BUILD_COMMAND "" INSTALL_COMMAND "" diff --git a/libnd4j/tests_cpu/layers_tests/BroadcastableOpsTests.cpp b/libnd4j/tests_cpu/layers_tests/BroadcastableOpsTests.cpp index 238c2f15d..655683687 100644 --- a/libnd4j/tests_cpu/layers_tests/BroadcastableOpsTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/BroadcastableOpsTests.cpp @@ -43,7 +43,7 @@ TEST_F(BroadcastableOpsTests, Test_Add_1) { //exp.printIndexedBuffer("E B"); - exp.applyBroadcast(broadcast::Add, {1}, &y); + exp.applyBroadcast(broadcast::Add, {1}, y, exp); nd4j::ops::add op; auto result = op.execute({&x, &y}, {}, {}, {}); @@ -70,7 +70,7 @@ TEST_F(BroadcastableOpsTests, Test_Multiply_1) { y.linspace(1); exp.linspace(1); - exp.applyBroadcast(broadcast::Multiply, {1}, &y); + exp.applyBroadcast(broadcast::Multiply, {1}, y, exp); nd4j::ops::multiply op; auto result = op.execute({&x, &y}, {}, {}, {}); @@ -94,7 +94,7 @@ TEST_F(BroadcastableOpsTests, Test_SquaredSubtract_1) { y.linspace(1); exp.linspace(1); - exp.applyBroadcast(broadcast::SquaredSubtract, {1}, &y); + exp.applyBroadcast(broadcast::SquaredSubtract, {1}, y, exp); nd4j::ops::squaredsubtract op; @@ -856,7 +856,7 @@ TEST_F(BroadcastableOpsTests, test_bert_multiply_1) { z.printIndexedBuffer(); */ - x.applyTrueBroadcast(BroadcastOpsTuple::Multiply(), &y, &z); + x.applyTrueBroadcast(BroadcastOpsTuple::Multiply(), y, z); //z.printIndexedBuffer(); @@ -874,7 +874,7 @@ TEST_F(BroadcastableOpsTests, test_bert_multiply_2) { z.assign(119.f); e.assign(2.f); - x.applyTrueBroadcast(BroadcastOpsTuple::Multiply(), &y, &z); + x.applyTrueBroadcast(BroadcastOpsTuple::Multiply(), y, z); ASSERT_EQ(e, z); } diff --git a/libnd4j/tests_cpu/layers_tests/CMakeLists.txt b/libnd4j/tests_cpu/layers_tests/CMakeLists.txt index 52fa0ca17..f538eb9cd 100644 --- a/libnd4j/tests_cpu/layers_tests/CMakeLists.txt +++ b/libnd4j/tests_cpu/layers_tests/CMakeLists.txt @@ -30,31 +30,30 @@ if (CUDA_BLAS) if(WIN32) message("CUDA on Windows: enabling /EHsc") SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /EHsc /FS") - SET_TARGET_PROPERTIES(${LIBND4J_NAME} PROPERTIES COMPILER_FLAGS "/EHsc") endif() if ("${COMPUTE}" STREQUAL "all") - list(APPEND CUDA_NVCC_FLAGS -DCUDA_10 ${EXPM} -w --cudart=static -O3 --expt-extended-lambda -gencode arch=compute_30,code=sm_30 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_70,code=sm_70) + set(CMAKE_CUDA_FLAGS " -DCUDA_10 ${EXPM} -w --cudart=static -O3 --expt-extended-lambda -gencode arch=compute_30,code=sm_30 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_70,code=sm_70") else() - list(APPEND CUDA_NVCC_FLAGS -DCUDA_10 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -arch=compute_${COMPUTE} -code=sm_${COMPUTE}) + set(CMAKE_CUDA_FLAGS " -DCUDA_10 ${EXPM} -w -G -g --expt-extended-lambda -arch=compute_${COMPUTE} -code=sm_${COMPUTE}") endif() endif() # -fsanitize=address # -fsanitize=leak if (APPLE) - set(CMAKE_CXX_FLAGS " -fPIC -std=c++11 -fmax-errors=2 -D__APPLE_OS__=true") + set(CMAKE_CXX_FLAGS " -fPIC -fmax-errors=2 -D__APPLE_OS__=true") elseif(WIN32) if (CPU_BLAS) set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -fPIC -march=native -mtune=native -O3") endif() if (CPU_BLAS AND LINUX) - set(CMAKE_CXX_FLAGS " -fPIC -std=c++11 -fmax-errors=2") + set(CMAKE_CXX_FLAGS " -fPIC -fmax-errors=2") endif() else() set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -O3") - set(CMAKE_CXX_FLAGS " -fPIC -std=c++11 -fmax-errors=2") + set(CMAKE_CXX_FLAGS " -fPIC -fmax-errors=2") if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64*") set(CMAKE_CXX_FLAGS " ${CMAKE_CXX_FLAGS} -mcpu=native") else() @@ -68,14 +67,6 @@ else() endif() endif() -# TODO: get rid of this once problem confirmed solved -#if (APPLE) -# if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") -# if ("${CMAKE_C_COMPILER_VERSION}" VERSION_GREATER 6.0 OR "${CMAKE_C_COMPILER_VERSION}" VERSION_EQUAL 6.0) -# SET( CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wa,-mavx512f -fmax-errors=1") -# endif() -# endif() -#endif() # tests are always compiled with all ops included SET( CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DLIBND4J_ALL_OPS=true -DBUILD_TESTS=true") @@ -141,6 +132,21 @@ if (CPU_BLAS) add_executable(runtests ${TEST_SOURCES}) target_link_libraries(runtests ${LIBND4J_NAME}static ${MKLDNN_LIBRARIES} ${OPENBLAS_LIBRARIES} ${MKLDNN} ${BLAS_LIBRARIES} ${CPU_FEATURES} gtest gtest_main) elseif(CUDA_BLAS) - CUDA_ADD_EXECUTABLE(runtests ${TEST_SOURCES}) - target_link_libraries(runtests ${LIBND4J_NAME} ${CUDA_LIBRARIES} gtest gtest_main) + + add_executable(runtests ${TEST_SOURCES}) + + if (WIN32) + message("MSVC runtime for tests: ${MSVC_RT_LIB}") + endif() + + # applies to windows only + set_property(TARGET runtests PROPERTY MSVC_RUNTIME_LIBRARY "${MSVC_RT_LIB}$<$:Debug>") + set_property(TARGET gtest PROPERTY MSVC_RUNTIME_LIBRARY "${MSVC_RT_LIB}$<$:Debug>") + set_property(TARGET gtest_main PROPERTY MSVC_RUNTIME_LIBRARY "${MSVC_RT_LIB}$<$:Debug>") + + if (HAVE_CUDNN) + message("CUDNN library: ${CUDNN}") + endif() + + target_link_libraries(runtests ${LIBND4J_NAME}static ${CUDA_LIBRARIES} ${CUDA_CUBLAS_LIBRARIES} ${CUDA_cusolver_LIBRARY} ${CUDNN} ${MKLDNN} gtest gtest_main) endif() \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/ConstantShapeHelperTests.cpp b/libnd4j/tests_cpu/layers_tests/ConstantShapeHelperTests.cpp index 9134ef0a4..e025aaead 100644 --- a/libnd4j/tests_cpu/layers_tests/ConstantShapeHelperTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/ConstantShapeHelperTests.cpp @@ -120,7 +120,7 @@ TEST_F(ConstantShapeHelperTests, basic_test_3) { TEST_F(ConstantShapeHelperTests, basic_test_4) { auto array = NDArrayFactory::create_('c', {128, 256}); - auto dup = array->dup('f'); + auto dup = new NDArray(array->dup('f')); ASSERT_TRUE(dup->shapeInfo() != nullptr); @@ -165,12 +165,11 @@ TEST_F(ConstantShapeHelperTests, basic_test_7) { IndicesList indices({NDIndex::all(), NDIndex::interval(0,1)}); auto strided = array->subarray(indices); - strided->assign(1.0f); + strided.assign(1.0f); //strided->printIndexedBuffer("column"); delete array; - delete strided; } TEST_F(ConstantHelperTests, basic_test_1) { diff --git a/libnd4j/tests_cpu/layers_tests/ContextTests.cpp b/libnd4j/tests_cpu/layers_tests/ContextTests.cpp index 42e141f46..13316fe8d 100644 --- a/libnd4j/tests_cpu/layers_tests/ContextTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/ContextTests.cpp @@ -130,7 +130,7 @@ TEST_F(ContextTests, Basic_Test_5) { auto _20 = NDArrayFactory::create_('c', {2, 2}); _20->linspace(1); - auto exp = _20->dup(); + auto exp = new NDArray(_20->dup()); ctx.pushNDArrayToVariableSpace(1, 1, _20); diff --git a/libnd4j/tests_cpu/layers_tests/ConvolutionTests1.cpp b/libnd4j/tests_cpu/layers_tests/ConvolutionTests1.cpp index eccb73c6c..9ed9f0ee6 100644 --- a/libnd4j/tests_cpu/layers_tests/ConvolutionTests1.cpp +++ b/libnd4j/tests_cpu/layers_tests/ConvolutionTests1.cpp @@ -308,7 +308,7 @@ TEST_F(ConvolutionTests1, conv2d_8) { auto results = op.execute({&input, &weights, &bias}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); auto output = results->at(0); - // output->printIndexedBuffer(); + // output->printBuffer(); ASSERT_EQ(Status::OK(), results->status()); @@ -422,8 +422,8 @@ TYPED_TEST(TypedConvolutionTests1, sconv2d_2) { TypeParam _expBFF[] = {108.9405008f, 109.5920008f, 110.2435008f, 110.8950008f, 111.5465008f, 112.1980008f, 115.4555008f, 116.1070008f, 116.7585008f, 117.410000f, 118.061500f, 118.7130009f, 121.9705009f, 122.6220009f, 123.2735009f, 123.9250009f, 124.5765009f, 125.2280009f, 128.4855009f, 129.1370009f, 129.7885009f, 130.4400009f, 131.09150f, 131.74300f, 135.0005010f, 135.6520010f, 136.3035010f, 136.9550010f, 137.6065010f, 138.2580010f, 141.5155010f, 142.1670010f, 142.8185010f, 143.4700010f, 144.1215010f, 144.7730010f, 248.9617514f, 250.670751f, 252.3797515f, 254.0887515f, 255.7977515f, 257.5067515f, 266.0517515f, 267.7607515f, 269.469751f, 271.1787516f, 272.8877516f, 274.5967516f, 283.1417516f, 284.8507516f, 286.5597516f, 288.268751f, 289.9777517f, 291.6867517f, 300.2317517f, 301.9407517f, 303.6497517f, 305.3587517f, 307.067751f, 308.7767518f, 317.3217518f, 319.0307518f, 320.7397518f, 322.4487518f, 324.157751f, 325.866751f, 334.4117519f, 336.1207519f, 337.8297519f, 339.5387519f, 341.2477519f, 342.95675f, 388.9829964f, 391.7494964f, 394.5159964f, 397.2824964f, 400.048996f, 402.8154963f, 416.647996f, 419.4144962f, 422.1809962f, 424.9474962f, 427.7139962f, 430.4804962f, 444.3129961f, 447.0794961f, 449.8459961f, 452.6124960f, 455.3789960f, 458.1454960f, 471.9779959f, 474.7444959f, 477.5109959f, 480.2774959f, 483.0439959f, 485.8104958f, 499.6429958f, 502.4094957f, 505.1759957f, 507.9424957f, 510.7089957f, 513.4754957f, 527.3079956f, 530.0744956f, 532.8409956f, 535.607495f, 538.3739955f, 541.1404955f, 529.0042487f, 532.8282487f, 536.6522487f, 540.4762487f, 544.3002487f, 548.1242487f, 567.2442487f, 571.068248f, 574.892248f, 578.716248f, 582.540248f, 586.3642486f, 605.4842486f, 609.3082486f, 613.1322486f, 616.9562486f, 620.7802486f, 624.6042486f, 643.7242486f, 647.5482486f, 651.3722486f, 655.1962486f, 659.0202486f, 662.8442486f, 681.9642486f, 685.7882486f, 689.6122486f, 693.4362486f, 697.2602486f, 701.0842486f, 720.2042486f, 724.0282486f, 727.852248f, 731.676248f, 735.500248f, 739.324248f, 669.0255044f, 673.9070044f, 678.7885044f, 683.6700044f, 688.5515044f, 693.4330044f, - 717.8405044f, 722.7220044f, 727.6035044f, 732.4850044f, 737.3665044f, 742.2480044f, 766.6555043f, 771.5370043f, 776.4185043f, 781.3000043f, 786.1815043f, 791.0630043f, 815.4705043f, 820.3520043f, 825.2335043f, 830.1150043f, 834.9965043f, 839.8780043f, 864.2855042f, 869.1670042f, 874.0485042f, 878.9300042f, 883.8115042f, 888.6930042f, 913.1005042f, 917.9820042f, 922.8635042f, 927.7450042f, 932.6265042f, 937.5080042f, 809.0467424f, 814.9857424f, 820.9247424f, 826.8637423f, 832.8027423f, 838.7417423f, 868.4367421f, 874.3757421f, 880.3147420f, 886.2537420f, 892.1927420f, 898.13174f, 927.8267418f, 933.7657418f, 939.7047417f, 945.6437417f, 951.5827417f, 957.5217416f, 987.2167415f, 993.155741f, - 999.0947414f, 1005.0337414f, 1010.972741f, 1016.9117413f, 1046.6067412f, 1052.5457411f, 1058.4847411f, 1064.4237411f, 1070.3627410f, 1076.3017410f, 1105.996740f, 1111.9357408f, 1117.8747408f, 1123.8137408f, 1129.7527407f, 1135.6917407f, 949.0679815f, 956.0644814f, 963.060981f, 970.0574813f, 977.0539812f, 984.0504811f, 1019.0329807f, 1026.0294807f, 1033.0259806f, 1040.0224805f, 1047.0189804f, 1054.0154804f, 1088.9979800f, 1095.9944799f, 1102.9909798f, 1109.987479f, 1116.9839797f, 1123.9804796f, 1158.9629792f, 1165.9594791f, 1172.9559791f, 1179.9524790f, 1186.9489789f, 1193.9454788f, 1228.9279785f, 1235.9244784f, 1242.9209783f, 1249.9174782f, 1256.913978f, 1263.9104781f, 1298.8929777f, 1305.8894776f, 1312.8859775f, 1319.8824775f, 1326.8789774f, 1333.8754773f, 1089.0892560f, 1097.1432561f, 1105.1972562f, 1113.251256f, 1121.3052563f, 1129.3592564f, 1169.6292568f, 1177.6832568f, 1185.7372569f, 1193.7912570f, 1201.845257f, 1209.8992571f, 1250.1692575f, 1258.2232576f, 1266.2772576f, 1274.3312577f, 1282.3852578f, 1290.4392579f, 1330.7092582f, 1338.7632583f, 1346.8172584f, 1354.8712584f, 1362.9252585f, 1370.9792586f, 1411.24925f, 1419.3032590f, 1427.3572591f, 1435.4112592f, 1443.465259f, 1451.5192593f, 1491.7892597f, 1499.8432598f, 1507.8972598f, 1515.9512599f, 1524.0052600f, 1532.059260f, 1229.1105073f, 1238.2220073f, 1247.3335073f, 1256.4450073f, 1265.5565073f, 1274.668007f, 1320.2255074f, 1329.3370074f, 1338.4485074f, 1347.5600075f, 1356.6715075f, 1365.7830075f, 1411.340507f, 1420.4520076f, 1429.5635076f, 1438.6750076f, 1447.7865076f, 1456.8980076f, 1502.4555077f, 1511.5670077f, 1520.6785077f, 1529.7900077f, 1538.9015077f, 1548.013007f, 1593.5705078f, 1602.6820078f, 1611.793507f, 1620.9050079f, 1630.0165079f, 1639.1280079f, 1684.6855080f, 1693.7970080f, 1702.9085080f, 1712.0200080f, 1721.1315080f, 1730.2430080f, 1369.1317613f, 1379.3007614f, 1389.4697614f, 1399.6387615f, 1409.8077615f, 1419.976761f, 1470.8217618f, 1480.9907618f, 1491.159761f, 1501.3287619f, 1511.4977619f, 1521.6667620f, 1572.5117622f, 1582.6807622f, 1592.8497623f, 1603.0187623f, 1613.1877624f, 1623.3567624f, 1674.2017626f, 1684.3707627f, 1694.5397627f, 1704.7087628f, 1714.8777628f, 1725.046762f, 1775.8917631f, 1786.0607631f, 1796.229763f, 1806.3987632f, 1816.5677632f, 1826.7367633f, 1877.5817635f, 1887.7507635f, 1897.9197636f, 1908.0887636f, 1918.2577637f, 1928.4267637f, 304.3905022f, 305.0420022f, 305.6935022f, 306.3450022f, 306.9965022f, 307.6480022f, 310.9055022f, 311.5570022f, 312.208502f, 312.860002f, 313.5115023f, 314.1630023f, 317.4205023f, 318.0720023f, 318.7235023f, 319.3750023f, 320.0265023f, 320.6780023f, 323.9355023f, 324.5870023f, 325.2385023f, 325.8900023f, 326.541502f, 327.193002f, 330.4505024f, 331.1020024f, 331.7535024f, 332.4050024f, 333.0565024f, 333.7080024f, 336.9655024f, 337.6170024f, 338.2685024f, 338.9200024f, 339.5715024f, 340.223002f, 761.6617542f, 763.3707542f, 765.0797542f, 766.7887542f, 768.4977542f, 770.206754f, 778.7517543f, 780.4607543f, 782.1697543f, 783.8787543f, 785.5877543f, 787.2967543f, 795.8417544f, 797.5507544f, 799.2597544f, 800.9687544f, 802.6777544f, 804.3867544f, 812.9317545f, 814.6407545f, 816.3497545f, 818.0587545f, 819.7677545f, 821.4767545f, 830.0217546f, 831.7307546f, 833.4397546f, 835.1487546f, 836.8577546f, 838.5667546f, 847.1117547f, 848.8207547f, 850.5297547f, 852.2387547f, 853.9477547f, 855.6567547f, 1218.9329915f, 1221.6994915f, 1224.4659915f, 1227.232491f, 1229.9989914f, 1232.7654914f, 1246.5979913f, 1249.3644913f, 1252.1309913f, 1254.8974913f, 1257.6639913f, 1260.430491f, 1274.2629912f, 1277.029491f, 1279.7959911f, 1282.5624911f, 1285.3289911f, 1288.0954911f, 1301.9279910f, 1304.6944910f, 1307.4609910f, 1310.22749f, 1312.9939909f, 1315.7604909f, 1329.5929908f, 1332.3594908f, 1335.1259908f, 1337.8924908f, 1340.6589908f, 1343.4254908f, 1357.2579907f, + 717.8405044f, 722.7220044f, 727.6035044f, 732.4850044f, 737.3665044f, 742.2480044f, 766.6555043f, 771.5370043f, 776.4185043f, 781.3000043f, 786.1815043f, 791.0630043f, 815.4705043f, 820.3520043f, 825.2335043f, 830.1150043f, 834.9965043f, 839.8780043f, 864.2855042f, 869.1670042f, 874.0485042f, 878.9300042f, 883.8115042f, 888.6930042f, 913.1005042f, 917.9820042f, 922.8635042f, 927.7450042f, 932.6265042f, 937.5080042f, 809.0467424f, 814.9857424f, 820.9247424f, 826.8637423f, 832.8027423f, 838.7417423f, 868.4367421f, 874.3757421f, 880.3147420f, 886.2537420f, 892.1927420f, 898.13174f, 927.8267418f, 933.7657418f, 939.7047417f, 945.6437417f, 951.5827417f, 957.5217416f, 987.2167415f, 993.155741f, + 999.0947414f, 1005.0337414f, 1010.972741f, 1016.9117413f, 1046.6067412f, 1052.5457411f, 1058.4847411f, 1064.4237411f, 1070.3627410f, 1076.3017410f, 1105.996740f, 1111.9357408f, 1117.8747408f, 1123.8137408f, 1129.7527407f, 1135.6917407f, 949.0679815f, 956.0644814f, 963.060981f, 970.0574813f, 977.0539812f, 984.0504811f, 1019.0329807f, 1026.0294807f, 1033.0259806f, 1040.0224805f, 1047.0189804f, 1054.0154804f, 1088.9979800f, 1095.9944799f, 1102.9909798f, 1109.987479f, 1116.9839797f, 1123.9804796f, 1158.9629792f, 1165.9594791f, 1172.9559791f, 1179.9524790f, 1186.9489789f, 1193.9454788f, 1228.9279785f, 1235.9244784f, 1242.9209783f, 1249.9174782f, 1256.913978f, 1263.9104781f, 1298.8929777f, 1305.8894776f, 1312.8859775f, 1319.8824775f, 1326.8789774f, 1333.8754773f, 1089.0892560f, 1097.1432561f, 1105.1972562f, 1113.251256f, 1121.3052563f, 1129.3592564f, 1169.6292568f, 1177.6832568f, 1185.7372569f, 1193.7912570f, 1201.845257f, 1209.8992571f, 1250.1692575f, 1258.2232576f, 1266.2772576f, 1274.3312577f, 1282.3852578f, 1290.4392579f, 1330.7092582f, 1338.7632583f, 1346.8172584f, 1354.8712584f, 1362.9252585f, 1370.9792586f, 1411.24925f, 1419.3032590f, 1427.3572591f, 1435.4112592f, 1443.465259f, 1451.5192593f, 1491.7892597f, 1499.8432598f, 1507.8972598f, 1515.9512599f, 1524.0052600f, 1532.059260f, 1229.1105073f, 1238.2220073f, 1247.3335073f, 1256.4450073f, 1265.5565073f, 1274.668007f, 1320.2255074f, 1329.3370074f, 1338.4485074f, 1347.5600075f, 1356.6715075f, 1365.7830075f, 1411.340507f, 1420.4520076f, 1429.5635076f, 1438.6750076f, 1447.7865076f, 1456.8980076f, 1502.4555077f, 1511.5670077f, 1520.6785077f, 1529.7900077f, 1538.9015077f, 1548.013007f, 1593.5705078f, 1602.6820078f, 1611.793507f, 1620.9050079f, 1630.0165079f, 1639.1280079f, 1684.6855080f, 1693.7970080f, 1702.9085080f, 1712.0200080f, 1721.1315080f, 1730.2430080f, 1369.1317613f, 1379.3007614f, 1389.4697614f, 1399.6387615f, 1409.8077615f, 1419.976761f, 1470.8217618f, 1480.9907618f, 1491.159761f, 1501.3287619f, 1511.4977619f, 1521.6667620f, 1572.5117622f, 1582.6807622f, 1592.8497623f, 1603.0187623f, 1613.1877624f, 1623.3567624f, 1674.2017626f, 1684.3707627f, 1694.5397627f, 1704.7087628f, 1714.8777628f, 1725.046762f, 1775.8917631f, 1786.0607631f, 1796.229763f, 1806.3987632f, 1816.5677632f, 1826.7367633f, 1877.5817635f, 1887.7507635f, 1897.9197636f, 1908.0887636f, 1918.2577637f, 1928.4267637f, 304.3905022f, 305.0420022f, 305.6935022f, 306.3450022f, 306.9965022f, 307.6480022f, 310.9055022f, 311.5570022f, 312.208502f, 312.860002f, 313.5115023f, 314.1630023f, 317.4205023f, 318.0720023f, 318.7235023f, 319.3750023f, 320.0265023f, 320.6780023f, 323.9355023f, 324.5870023f, 325.2385023f, 325.8900023f, 326.541502f, 327.193002f, 330.4505024f, 331.1020024f, 331.7535024f, 332.4050024f, 333.0565024f, 333.7080024f, 336.9655024f, 337.6170024f, 338.2685024f, 338.9200024f, 339.5715024f, 340.223002f, 761.6617542f, 763.3707542f, 765.0797542f, 766.7887542f, 768.4977542f, 770.206754f, 778.7517543f, 780.4607543f, 782.1697543f, 783.8787543f, 785.5877543f, 787.2967543f, 795.8417544f, 797.5507544f, 799.2597544f, 800.9687544f, 802.6777544f, 804.3867544f, 812.9317545f, 814.6407545f, 816.3497545f, 818.0587545f, 819.7677545f, 821.4767545f, 830.0217546f, 831.7307546f, 833.4397546f, 835.1487546f, 836.8577546f, 838.5667546f, 847.1117547f, 848.8207547f, 850.5297547f, 852.2387547f, 853.9477547f, 855.6567547f, 1218.9329915f, 1221.6994915f, 1224.4659915f, 1227.232491f, 1229.9989914f, 1232.7654914f, 1246.5979913f, 1249.3644913f, 1252.1309913f, 1254.8974913f, 1257.6639913f, 1260.430491f, 1274.2629912f, 1277.029491f, 1279.7959911f, 1282.5624911f, 1285.3289911f, 1288.0954911f, 1301.9279910f, 1304.6944910f, 1307.4609910f, 1310.22749f, 1312.9939909f, 1315.7604909f, 1329.5929908f, 1332.3594908f, 1335.1259908f, 1337.8924908f, 1340.6589908f, 1343.4254908f, 1357.2579907f, 1360.0244907f, 1362.7909906f, 1365.5574906f, 1368.3239906f, 1371.0904906f, 1676.2042479f, 1680.0282479f, 1683.8522479f, 1687.6762479f, 1691.5002479f, 1695.3242479f, 1714.4442479f, 1718.2682479f, 1722.0922479f, 1725.9162479f, 1729.7402479f, 1733.5642479f, 1752.6842479f, 1756.5082479f, 1760.3322479f, 1764.1562479f, 1767.9802479f, 1771.8042479f, 1790.9242479f, 1794.7482479f, 1798.5722479f, 1802.3962479f, 1806.2202479f, 1810.044247f, 1829.1642478f, 1832.9882478f, 1836.8122478f, 1840.6362478f, 1844.4602478f, 1848.2842478f, 1867.4042478f, 1871.2282478f, 1875.0522478f, 1878.8762478f, 1882.7002478f, 1886.5242478f, 2133.4755029f, 2138.3570029f, 2143.2385029f, 2148.1200029f, 2153.0015029f, 2157.8830029f, 2182.2905028f, 2187.1720028f, 2192.0535028f, 2196.9350028f, 2201.8165028f, 2206.6980028f, 2231.1055028f, 2235.9870028f, 2240.8685028f, 2245.7500028f, 2250.6315028f, 2255.5130028f, 2279.9205027f, 2284.8020027f, 2289.6835027f, 2294.5650027f, 2299.4465027f, 2304.3280027f, 2328.7355027f, 2333.6170027f, 2338.4985027f, 2343.3800027f, 2348.2615027f, 2353.1430027f, 2377.5505026f, 2382.4320026f, 2387.3135026f, 2392.1950026f, 2397.0765026f, 2401.9580026f, 2590.7467330f, 2596.6857330f, 2602.6247329f, 2608.5637329f, 2614.5027329f, 2620.441732f, 2650.1367327f, 2656.0757327f, 2662.0147326f, 2667.9537326f, 2673.8927326f, 2679.8317325f, 2709.5267324f, 2715.465732f, 2721.4047323f, 2727.3437323f, 2733.282732f, 2739.2217322f, 2768.9167321f, 2774.8557320f, 2780.7947320f, 2786.7337320f, 2792.6727319f, 2798.6117319f, 2828.306731f, 2834.2457317f, 2840.1847317f, 2846.1237317f, 2852.0627316f, 2858.0017316f, 2887.6967314f, 2893.6357314f, 2899.5747314f, 2905.5137313f, 2911.4527313f, 2917.3917313f, 3048.0179587f, 3055.0144586f, 3062.0109585f, 3069.0074584f, 3076.0039584f, 3083.0004583f, 3117.9829579f, 3124.9794578f, 3131.9759578f, 3138.9724577f, 3145.9689576f, 3152.9654575f, 3187.947957f, 3194.9444571f, 3201.9409570f, 3208.9374569f, 3215.933956f, 3222.9304568f, 3257.9129564f, 3264.9094563f, 3271.9059562f, 3278.9024562f, 3285.8989561f, 3292.8954560f, 3327.8779556f, 3334.874455f, 3341.8709555f, 3348.8674554f, 3355.8639553f, 3362.860455f, 3397.8429549f, 3404.8394548f, 3411.8359547f, 3418.8324546f, 3425.8289546f, 3432.8254545f, 3505.28927f, 3513.3432780f, 3521.3972781f, 3529.4512782f, 3537.5052782f, 3545.5592783f, 3585.8292787f, 3593.8832788f, 3601.9372788f, 3609.9912789f, 3618.0452790f, 3626.099279f, 3666.3692794f, 3674.4232795f, 3682.4772796f, 3690.5312796f, 3698.5852797f, 3706.6392798f, 3746.9092801f, 3754.9632802f, 3763.0172803f, 3771.0712804f, 3779.1252804f, 3787.1792805f, 3827.4492809f, 3835.50328f, 3843.5572810f, 3851.6112811f, 3859.6652812f, 3867.7192812f, 3907.9892816f, 3916.0432817f, 3924.097281f, @@ -443,9 +443,9 @@ TYPED_TEST(TypedConvolutionTests1, sconv2d_2) { weightsD.permutei({2,3,1,0}); weightsP.permutei({2,3,1,0}); - input.applyScalar(scalar::Divide, 100.0); - weightsD.applyScalar(scalar::Divide, 100.0); - weightsP.applyScalar(scalar::Divide, 100.0); + input.applyScalar(scalar::Divide, 100.0, input); + weightsD.applyScalar(scalar::Divide, 100.0, weightsD); + weightsP.applyScalar(scalar::Divide, 100.0, weightsP); nd4j::ops::sconv2d op; @@ -635,25 +635,63 @@ TYPED_TEST(TypedConvolutionTests1, conv2D_BP_NoBias_1) { } TYPED_TEST(TypedConvolutionTests1, sconv2d_conv2d_1) { - TypeParam _expBFF[] = {10025.0f, 10350.0f, 10675.0f, 11000.0f, 11325.0f, 11650.0f, 13275.0f, 13600.0f, 13925.0f, 14250.0f, 14575.0f, 14900.0f, 16525.0f, 16850.0f, 17175.0f, 17500.0f, 17825.0f, 18150.0f, 19775.0f, 20100.0f, 20425.0f, 20750.0f, 21075.0f, 21400.0f, 23025.0f, 23350.0f, 23675.0f, 24000.0f, 24325.0f, 24650.0f, 26275.0f, 26600.0f, 26925.0f, 27250.0f, 27575.0f, 27900.0f, 53150.0f, 55350.0f, 57550.0f, 59750.0f, 61950.0f, 64150.0f, 75150.0f, 77350.0f, 79550.0f, 81750.0f, 83950.0f, 86150.0f, 97150.0f, 99350.0f, 101550.0f, 103750.0f, 105950.0f, 108150.0f, 119150.0f, 121350.0f, 123550.0f, 125750.0f, 127950.0f, 130150.0f, 141150.0f, 143350.0f, 145550.0f, 147750.0f, 149950.0f, 152150.0f, 163150.0f, 165350.0f, 167550.0f, 169750.0f, 171950.0f, 174150.0f, 119400.0f, 120350.0f, 121300.0f, 122250.0f, 123200.0f, 124150.0f, 128900.0f, 129850.0f, 130800.0f, 131750.0f, 132700.0f, 133650.0f, 138400.0f, 139350.0f, 140300.0f, 141250.0f, 142200.0f, 143150.0f, 147900.0f, 148850.0f, 149800.0f, 150750.0f, 151700.0f, 152650.0f, 157400.0f, 158350.0f, 159300.0f, 160250.0f, 161200.0f, 162150.0f, 166900.0f, 167850.0f, 168800.0f, 169750.0f, 170700.0f, 171650.0f, 350025.0f, 352850.0f, 355675.0f, 358500.0f, 361325.0f, 364150.0f, 378275.0f, 381100.0f, 383925.0f, 386750.0f, 389575.0f, 392400.0f, 406525.0f, 409350.0f, 412175.0f, 415000.0f, 417825.0f, 420650.0f, 434775.0f, 437600.0f, 440425.0f, 443250.0f, 446075.0f, 448900.0f, 463025.0f, 465850.0f, 468675.0f, 471500.0f, 474325.0f, 477150.0f, 491275.0f, 494100.0f, 496925.0f, 499750.0f, 502575.0f, 505400.0f, 353775.0f, 355350.0f, 356925.0f, 358500.0f, 360075.0f, 361650.0f, 369525.0f, 371100.0f, 372675.0f, 374250.0f, 375825.0f, 377400.0f, 385275.0f, 386850.0f, 388425.0f, 390000.0f, 391575.0f, 393150.0f, 401025.0f, 402600.0f, 404175.0f, 405750.0f, 407325.0f, 408900.0f, 416775.0f, 418350.0f, 419925.0f, 421500.0f, 423075.0f, 424650.0f, 432525.0f, 434100.0f, 435675.0f, 437250.0f, 438825.0f, 440400.0f, 771900.0f, 775350.0f, 778800.0f, 782250.0f, 785700.0f, 789150.0f, 806400.0f, 809850.0f, 813300.0f, 816750.0f, 820200.0f, 823650.0f, 840900.0f, 844350.0f, 847800.0f, 851250.0f, 854700.0f, 858150.0f, 875400.0f, 878850.0f, 882300.0f, 885750.0f, 889200.0f, 892650.0f, 909900.0f, 913350.0f, 916800.0f, 920250.0f, 923700.0f, 927150.0f, 944400.0f, 947850.0f, 951300.0f, 954750.0f, 958200.0f, 961650.0f, 107525.0f, 107850.0f, 108175.0f, 108500.0f, 108825.0f, 109150.0f, 110775.0f, 111100.0f, 111425.0f, 111750.0f, 112075.0f, 112400.0f, 114025.0f, 114350.0f, 114675.0f, 115000.0f, 115325.0f, 115650.0f, 117275.0f, 117600.0f, 117925.0f, 118250.0f, 118575.0f, 118900.0f, 120525.0f, 120850.0f, 121175.0f, 121500.0f, 121825.0f, 122150.0f, 123775.0f, 124100.0f, 124425.0f, 124750.0f, 125075.0f, 125400.0f, 713150.0f, 715350.0f, 717550.0f, 719750.0f, 721950.0f, 724150.0f, 735150.0f, 737350.0f, 739550.0f, 741750.0f, 743950.0f, 746150.0f, 757150.0f, 759350.0f, 761550.0f, 763750.0f, 765950.0f, 768150.0f, 779150.0f, 781350.0f, 783550.0f, 785750.0f, 787950.0f, 790150.0f, 801150.0f, 803350.0f, 805550.0f, 807750.0f, 809950.0f, 812150.0f, 823150.0f, 825350.0f, 827550.0f, 829750.0f, 831950.0f, 834150.0f, 404400.0f, 405350.0f, 406300.0f, 407250.0f, 408200.0f, 409150.0f, 413900.0f, 414850.0f, 415800.0f, 416750.0f, 417700.0f, 418650.0f, 423400.0f, 424350.0f, 425300.0f, 426250.0f, 427200.0f, 428150.0f, 432900.0f, 433850.0f, 434800.0f, 435750.0f, 436700.0f, 437650.0f, 442400.0f, 443350.0f, 444300.0f, 445250.0f, 446200.0f, 447150.0f, 451900.0f, 452850.0f, 453800.0f, 454750.0f, 455700.0f, 456650.0f, 1197525.0f, 1200350.0f, 1203175.0f, 1206000.0f, 1208825.0f, 1211650.0f, 1225775.0f, 1228600.0f, 1231425.0f, 1234250.0f, 1237075.0f, 1239900.0f, 1254025.0f, 1256850.0f, 1259675.0f, 1262500.0f, 1265325.0f, 1268150.0f, 1282275.0f, 1285100.0f, 1287925.0f, 1290750.0f, 1293575.0f, 1296400.0f, 1310525.0f, 1313350.0f, 1316175.0f, 1319000.0f, 1321825.0f, 1324650.0f, 1338775.0f, 1341600.0f, 1344425.0f, 1347250.0f, 1350075.0f, 1352900.0f, 826275.0f, 827850.0f, 829425.0f, 831000.0f, 832575.0f, 834150.0f, 842025.0f, 843600.0f, 845175.0f, 846750.0f, 848325.0f, 849900.0f, 857775.0f, 859350.0f, 860925.0f, 862500.0f, 864075.0f, 865650.0f, 873525.0f, 875100.0f, 876675.0f, 878250.0f, 879825.0f, 881400.0f, 889275.0f, 890850.0f, 892425.0f, 894000.0f, 895575.0f, 897150.0f, 905025.0f, 906600.0f, 908175.0f, 909750.0f, 911325.0f, 912900.0f, 1806900.0f, 1810350.0f, 1813800.0f, 1817250.0f, 1820700.0f, 1824150.0f, 1841400.0f, 1844850.0f, 1848300.0f, 1851750.0f, 1855200.0f, 1858650.0f, 1875900.0f, 1879350.0f, 1882800.0f, 1886250.0f, 1889700.0f, 1893150.0f, 1910400.0f, 1913850.0f, 1917300.0f, 1920750.0f, 1924200.0f, 1927650.0f, 1944900.0f, 1948350.0f, 1951800.0f, 1955250.0f, 1958700.0f, 1962150.0f, 1979400.0f, 1982850.0f, 1986300.0f, 1989750.0f, 1993200.0f, 1996650.f}; - Nd4jLong _expSFF[] = {4, 2, 6, 6, 6, 216, 36, 6, 1, typeid(TypeParam) == typeid(float) ? 8192 : 16384, 1, 99,}; - NDArray expFF(_expBFF, _expSFF); - TypeParam _exp2BFF[] = {827.4900282f, 832.2350283f, 836.9800284f, 841.725028f, 846.4700287f, 851.2150288f, 874.9400293f, 879.6850294f, 884.4300295f, 889.1750296f, 893.9200297f, 898.665029f, 922.3900304f, 927.1350305f, 931.8800306f, 936.6250307f, 941.3700308f, 946.1150309f, 969.8400315f, 974.5850316f, 979.3300317f, 984.0750318f, 988.8200319f, 993.5650320f, 1017.2900326f, 1022.0350327f, 1026.7800328f, 1031.5250329f, 1036.2700330f, 1041.0150331f, 1064.7400337f, 1069.4850338f, 1074.2300339f, 1078.9750340f, 1083.7200341f, 1088.4650342f, 1822.4550553f, 1833.995055f, 1845.5350558f, 1857.075056f, 1868.6150563f, 1880.1550566f, 1937.8550578f, 1949.3950581f, 1960.9350583f, 1972.4750586f, 1984.015058f, 1995.5550591f, 2053.2550604f, 2064.7950606f, 2076.3350609f, 2087.8750611f, 2099.4150614f, 2110.955061f, 2168.6550629f, 2180.1950632f, 2191.7350634f, 2203.2750637f, 2214.8150639f, 2226.3550642f, 2284.0550655f, 2295.5950657f, 2307.1350660f, 2318.6750662f, 2330.2150665f, 2341.7550667f, 2399.4550680f, 2410.9950683f, 2422.5350685f, 2434.0750688f, 2445.6150690f, 2457.1550693f, 2817.419968f, 2835.7549686f, 2854.0899683f, 2872.4249680f, 2890.7599677f, 2909.0949674f, 3000.7699660f, 3019.104965f, 3037.4399655f, 3055.7749652f, 3074.1099649f, 3092.4449646f, 3184.1199632f, 3202.4549629f, 3220.789962f, 3239.1249624f, 3257.4599621f, 3275.7949618f, 3367.4699604f, 3385.8049601f, 3404.1399598f, 3422.474959f, 3440.8099593f, 3459.1449590f, 3550.8199576f, 3569.1549573f, 3587.4899570f, 3605.8249567f, 3624.1599565f, 3642.4949562f, 3734.1699548f, 3752.5049545f, 3770.8399542f, 3789.1749539f, 3807.5099536f, 3825.8449534f, 3812.385098f, 3837.5150988f, 3862.6450994f, 3887.7751000f, 3912.9051006f, 3938.0351012f, 4063.6851041f, 4088.8151047f, 4113.9451053f, 4139.0751059f, 4164.2051065f, 4189.3351071f, 4314.9851100f, 4340.1151106f, 4365.2451112f, 4390.3751118f, 4415.5051124f, 4440.6351130f, 4566.2851159f, 4591.4151165f, 4616.5451171f, 4641.6751177f, 4666.805118f, 4691.9351188f, 4817.5851218f, 4842.7151224f, 4867.8451230f, 4892.975123f, 4918.1051241f, 4943.2351247f, 5068.8851277f, 5094.0151283f, 5119.1451288f, 5144.2751294f, 5169.4051300f, 5194.5351306f, 4807.3499803f, 4839.2749801f, 4871.1999799f, 4903.1249797f, 4935.0499795f, 4966.9749793f, 5126.5999784f, 5158.5249782f, 5190.4499780f, 5222.3749778f, 5254.2999777f, 5286.2249775f, 5445.8499765f, 5477.774976f, 5509.6999762f, 5541.6249760f, 5573.5499758f, 5605.4749756f, 5765.0999747f, 5797.0249745f, 5828.9499743f, 5860.8749741f, 5892.7999739f, 5924.724973f, 6084.3499728f, 6116.2749726f, 6148.1999724f, 6180.1249723f, 6212.0499721f, 6243.9749719f, 6403.59997f, 6435.5249708f, 6467.4499706f, 6499.3749704f, 6531.2999702f, 6563.2249700f, 5802.3150007f, 5841.0350006f, 5879.7550005f, 5918.4750004f, 5957.195000f, 5995.9150003f, 6189.5149999f, 6228.2349998f, 6266.9549997f, 6305.6749996f, 6344.3949995f, 6383.114999f, 6576.7149990f, 6615.4349990f, 6654.1549989f, 6692.8749988f, 6731.5949987f, 6770.3149986f, 6963.9149982f, 7002.6349981f, 7041.3549981f, 7080.0749980f, 7118.7949979f, 7157.5149978f, 7351.1149974f, 7389.8349973f, 7428.5549972f, 7467.2749972f, 7505.9949971f, 7544.7149970f, 7738.3149966f, 7777.0349965f, 7815.7549964f, 7854.4749963f, 7893.1949963f, 7931.9149962f, 6797.2799488f, 6842.794948f, 6888.3099489f, 6933.8249490f, 6979.3399491f, 7024.8549492f, 7252.4299497f, 7297.9449498f, 7343.4599499f, 7388.9749500f, 7434.489950f, 7480.0049501f, 7707.5799506f, 7753.0949507f, 7798.6099508f, 7844.1249509f, 7889.6399510f, 7935.1549511f, 8162.7299515f, 8208.2449516f, 8253.7599517f, 8299.2749518f, 8344.7899519f, 8390.3049520f, 8617.8799525f, 8663.394952f, 8708.9099526f, 8754.4249527f, 8799.9399528f, 8845.4549529f, 9073.0299534f, 9118.5449535f, 9164.0599536f, 9209.5749537f, 9255.089953f, 9300.604953f, 7792.2451647f, 7844.5551655f, 7896.8651663f, 7949.1751671f, 8001.4851679f, 8053.7951686f, 8315.3451725f, 8367.6551733f, 8419.9651741f, 8472.2751749f, 8524.585175f, 8576.8951764f, 8838.4451803f, 8890.7551811f, 8943.0651819f, 8995.3751827f, 9047.6851834f, 9099.9951842f, 9361.5451881f, 9413.8551889f, 9466.1651897f, 9518.475190f, 9570.7851912f, 9623.0951920f, 9884.6451959f, 9936.9551967f, 9989.2651975f, 10041.5751982f, 10093.8851990f, 10146.1951998f, 10407.7452037f, 10460.0552045f, 10512.3652053f, 10564.6752060f, 10616.9852068f, 10669.2952076f, 8787.210074f, 8846.3150748f, 8905.4200750f, 8964.5250752f, 9023.6300755f, 9082.7350757f, 9378.2600768f, 9437.3650770f, 9496.4700773f, 9555.5750775f, 9614.6800777f, 9673.7850779f, 9969.3100791f, 10028.4150793f, 10087.5200795f, 10146.625079f, 10205.7300800f, 10264.8350802f, 10560.3600813f, 10619.465081f, 10678.5700818f, 10737.6750820f, 10796.7800822f, 10855.8850825f, 11151.4100836f, 11210.5150838f, 11269.6200840f, 11328.7250843f, 11387.8300845f, 11446.9350847f, 11742.4600858f, 11801.5650861f, 11860.6700863f, 11919.7750865f, 11978.880086f, 12037.9850870f, 9782.1750935f, 9848.0750935f, 9913.9750934f, 9979.8750934f, 10045.7750934f, 10111.6750933f, 10441.1750931f, 10507.0750931f, 10572.9750931f, 10638.8750930f, 10704.7750930f, 10770.6750930f, 11100.1750928f, 11166.0750927f, 11231.9750927f, 11297.8750927f, 11363.7750926f, 11429.6750926f, 11759.1750924f, 11825.0750924f, 11890.9750923f, 11956.8750923f, 12022.7750923f, 12088.6750922f, 12418.175092f, 12484.0750920f, 12549.9750920f, 12615.8750919f, 12681.7750919f, 12747.6750919f, 13077.1750917f, 13143.0750916f, 13208.9750916f, 13274.8750916f, 13340.7750915f, 13406.6750915f, 2250.990060f, 2255.7350610f, 2260.4800611f, 2265.2250612f, 2269.9700613f, 2274.7150614f, 2298.4400619f, 2303.185062f, 2307.9300622f, 2312.6750623f, 2317.4200624f, 2322.1650625f, 2345.8900630f, 2350.6350631f, 2355.380063f, 2360.1250634f, 2364.8700635f, 2369.6150636f, 2393.3400641f, 2398.0850642f, 2402.8300643f, 2407.5750644f, 2412.320064f, 2417.0650647f, 2440.7900652f, 2445.5350653f, 2450.2800654f, 2455.0250655f, 2459.7700656f, 2464.515065f, 2488.2400663f, 2492.9850664f, 2497.7300665f, 2502.4750666f, 2507.2200667f, 2511.9650668f, 5284.4551315f, 5295.9951318f, 5307.535132f, 5319.0751323f, 5330.6151326f, 5342.1551328f, 5399.8551341f, 5411.3951343f, 5422.9351346f, 5434.475134f, 5446.0151351f, 5457.5551354f, 5515.2551366f, 5526.7951369f, 5538.3351371f, 5549.8751374f, 5561.4151376f, 5572.9551379f, 5630.6551392f, 5642.1951394f, 5653.7351397f, 5665.2751399f, 5676.8151402f, 5688.3551404f, 5746.0551417f, 5757.5951420f, 5769.1351422f, 5780.6751425f, 5792.2151427f, 5803.7551430f, 5861.455144f, 5872.9951445f, 5884.5351448f, 5896.0751450f, 5907.6151453f, 5919.1551455f, 8317.919884f, 8336.2548841f, 8354.5898838f, 8372.9248835f, 8391.2598832f, 8409.59488f, 8501.2698815f, 8519.6048813f, 8537.9398810f, 8556.2748807f, 8574.6098804f, 8592.9448801f, 8684.6198787f, 8702.9548784f, 8721.2898782f, 8739.6248779f, 8757.9598776f, 8776.2948773f, 8867.9698759f, 8886.3048756f, 8904.6398753f, 8922.9748751f, 8941.3098748f, 8959.6448745f, 9051.3198731f, 9069.6548728f, 9087.9898725f, 9106.3248722f, 9124.6598720f, 9142.9948717f, 9234.6698703f, 9253.0048700f, 9271.3398697f, 9289.6748694f, 9308.0098691f, 9326.3448689f, 11351.3852747f, 11376.5152753f, 11401.6452759f, 11426.7752765f, 11451.9052771f, 11477.0352777f, 11602.6852806f, 11627.8152812f, 11652.9452818f, 11678.0752824f, 11703.2052830f, 11728.335283f, 11853.9852865f, 11879.1152871f, 11904.2452877f, 11929.3752883f, 11954.505288f, 11979.6352894f, 12105.2852924f, 12130.4152930f, 12155.545293f, 12180.6752941f, 12205.8052947f, 12230.9352953f, 12356.5852983f, 12381.715298f, 12406.8452994f, 12431.9753000f, 12457.1053006f, 12482.2353012f, 12607.8853041f, 12633.0153047f, 12658.1453053f, 12683.2753059f, 12708.4053065f, 12733.5353071f, 14384.8499244f, 14416.7749242f, 14448.6999240f, 14480.6249238f, 14512.549923f, 14544.4749235f, 14704.0999225f, 14736.024922f, 14767.9499222f, 14799.8749220f, 14831.7999218f, 14863.7249216f, 15023.3499207f, 15055.2749205f, 15087.1999203f, 15119.1249201f, 15151.0499199f, 15182.9749197f, 15342.5999188f, 15374.5249186f, 15406.4499184f, 15438.374918f, 15470.2999181f, 15502.2249179f, 15661.84991f, 15693.7749168f, 15725.6999166f, 15757.6249164f, 15789.5499162f, 15821.4749160f, 15981.0999151f, 16013.0249149f, 16044.9499147f, 16076.8749145f, 16108.7999143f, 16140.7249142f, 17418.314976f, 17457.0349761f, 17495.7549760f, 17534.4749759f, 17573.1949758f, 17611.9149757f, 17805.5149753f, 17844.234975f, 17882.9549752f, 17921.6749751f, 17960.3949750f, 17999.1149749f, 18192.7149745f, 18231.4349744f, 18270.154974f, 18308.8749743f, 18347.5949742f, 18386.3149741f, 18579.9149737f, 18618.6349736f, 18657.3549735f, 18696.074973f, 18734.7949734f, 18773.5149733f, 18967.1149729f, 19005.8349728f, 19044.5549727f, 19083.2749726f, 19121.994972f, 19160.7149725f, 19354.3149721f, 19393.0349720f, 19431.7549719f, 19470.4749718f, 19509.1949717f, 19547.914971f, 20451.7799765f, 20497.2949766f, 20542.8099767f, 20588.3249768f, 20633.8399769f, 20679.3549770f, 20906.929977f, 20952.4449775f, 20997.9599776f, 21043.4749777f, 21088.9899778f, 21134.5049779f, 21362.0799784f, 21407.5949785f, 21453.1099786f, 21498.624978f, 21544.139978f, 21589.6549788f, 21817.2299793f, 21862.7449794f, 21908.2599795f, 21953.7749796f, 21999.2899797f, 22044.8049798f, 22272.3799802f, 22317.8949803f, 22363.4099804f, 22408.9249805f, 22454.4399806f, 22499.9549807f, 22727.529981f, 22773.044981f, 22818.5599813f, 22864.0749814f, 22909.5899815f, 22955.1049816f, 23485.2453985f, 23537.555399f, 23589.8654000f, 23642.1754008f, 23694.4854016f, 23746.7954024f, 24008.3454063f, 24060.655407f, 24112.9654078f, 24165.2754086f, 24217.5854094f, 24269.8954102f, 24531.4454141f, 24583.7554148f, 24636.0654156f, 24688.3754164f, 24740.6854172f, 24792.99541f, 25054.545421f, 25106.8554226f, 25159.1654234f, 25211.4754242f, 25263.7854250f, 25316.0954257f, 25577.6454296f, 25629.9554304f, 25682.2654312f, 25734.5754320f, 25786.8854328f, 25839.1954335f, 26100.7454374f, 26153.0554382f, 26205.3654390f, 26257.6754398f, 26309.985440f, 26362.2954413f, 26518.7101423f, 26577.8151425f, 26636.920142f, 26696.0251430f, 26755.1301432f, 26814.2351434f, 27109.7601446f, 27168.8651448f, 27227.9701450f, 27287.0751452f, 27346.1801455f, 27405.2851457f, 27700.8101468f, 27759.9151470f, 27819.0201473f, 27878.1251475f, 27937.2301477f, 27996.33514f, 28291.8601491f, 28350.9651493f, 28410.0701495f, 28469.175149f, 28528.2801500f, 28587.3851502f, 28882.9101513f, 28942.0151516f, 29001.1201518f, 29060.2251520f, 29119.3301522f, 29178.4351525f, 29473.9601536f, 29533.0651538f, 29592.1701540f, 29651.2751543f, 29710.3801545f, 29769.4851547f, 29552.1750826f, 29618.0750825f, 29683.9750825f, 29749.8750825f, 29815.7750824f, 29881.6750824f, 30211.1750822f, 30277.0750822f, 30342.9750821f, 30408.8750821f, 30474.7750821f, 30540.6750820f, 30870.175081f, 30936.0750818f, 31001.9750818f, 31067.8750817f, 31133.7750817f, 31199.6750817f, 31529.1750815f, 31595.075081f, 31660.9750814f, 31726.8750814f, 31792.7750813f, 31858.6750813f, 32188.1750811f, 32254.0750811f, 32319.975081f, 32385.8750810f, 32451.7750810f, 32517.6750809f, 32847.1750808f, 32913.0750807f, 32978.9750807f, 33044.875080f, 33110.7750806f, 33176.67508062f}; - Nd4jLong _exp2SFF[] = {4, 2, 10, 6, 6, 360, 36, 6, 1, typeid(TypeParam) == typeid(float) ? 8192 : 16384, 1, 99}; - NDArray exp2FF(_exp2BFF, _exp2SFF); auto input = NDArrayFactory::create('c', {2, 3, 10, 10}); - auto weightsD = NDArrayFactory::create('c', {2, 3, 5, 5}); - auto weightsP = NDArrayFactory::create('c', {10, 6, 1, 1}); + auto weightsD = NDArrayFactory::create('c', {5, 5, 3, 2}, {1.f, 76.f, 26.f, 101.f, 51.f, 126.f, 2.f, 77.f, 27.f, 102.f, 52.f, 127.f, 3.f, 78.f, 28.f, 103.f, 53.f, 128.f, 4.f, 79.f, 29.f, 104.f, 54.f, 129.f, 5.f, 80.f, 30.f, 105.f, 55.f, 130.f, + 6.f, 81.f, 31.f, 106.f, 56.f, 131.f, 7.f, 82.f, 32.f, 107.f, 57.f, 132.f, 8.f, 83.f, 33.f, 108.f, 58.f, 133.f, 9.f, 84.f, 34.f, 109.f, 59.f, 134.f, 10.f, 85.f, 35.f, 110.f, 60.f, 135.f, + 11.f, 86.f, 36.f, 111.f, 61.f, 136.f, 12.f, 87.f, 37.f, 112.f, 62.f, 137.f, 13.f, 88.f, 38.f, 113.f, 63.f, 138.f, 14.f, 89.f, 39.f, 114.f, 64.f, 139.f, 15.f, 90.f, 40.f, 115.f, 65.f, 140.f, + 16.f, 91.f, 41.f, 116.f, 66.f, 141.f, 17.f, 92.f, 42.f, 117.f, 67.f, 142.f, 18.f, 93.f, 43.f, 118.f, 68.f, 143.f, 19.f, 94.f, 44.f, 119.f, 69.f, 144.f, 20.f, 95.f, 45.f, 120.f, 70.f, 145.f, + 21.f, 96.f, 46.f, 121.f, 71.f, 146.f, 22.f, 97.f, 47.f, 122.f, 72.f, 147.f, 23.f, 98.f, 48.f, 123.f, 73.f, 148.f, 24.f, 99.f, 49.f, 124.f, 74.f, 149.f, 25.f, 100.f, 50.f, 125.f, 75.f, 150.f}); + auto weightsP = NDArrayFactory::create('c', {1, 1, 6, 10}, {0.0001f, 0.0007f, 0.0013f, 0.0019f, 0.0025f, 0.0031f, 0.0037f, 0.0043f, 0.0049f, 0.0055f,0.0002f, 0.0008f, 0.0014f, 0.0020f, 0.0026f, 0.0032f, 0.0038f, 0.0044f, 0.0050f, 0.0056f, + 0.0003f, 0.0009f, 0.0015f, 0.0021f, 0.0027f, 0.0033f, 0.0039f, 0.0045f, 0.0051f, 0.0057f,0.0004f, 0.0010f, 0.0016f, 0.0022f, 0.0028f, 0.0034f, 0.0040f, 0.0046f, 0.0052f, 0.0058f, + 0.0005f, 0.0011f, 0.0017f, 0.0023f, 0.0029f, 0.0035f, 0.0041f, 0.0047f, 0.0053f, 0.0059f,0.0006f, 0.0012f, 0.0018f, 0.0024f, 0.0030f, 0.0036f, 0.0042f, 0.0048f, 0.0054f, 0.0060f}); + auto expFF = NDArrayFactory::create('c', {2, 6, 6, 6}, {10025.0f,10350.0f,10675.0f,11000.0f,11325.0f,11650.0f,13275.0f,13600.0f,13925.0f,14250.0f,14575.0f,14900.0f,16525.0f,16850.0f, + 17175.0f,17500.0f,17825.0f,18150.0f,19775.0f,20100.0f,20425.0f,20750.0f,21075.0f,21400.0f,23025.0f,23350.0f,23675.0f,24000.0f, + 24325.0f,24650.0f,26275.0f,26600.0f,26925.0f,27250.0f,27575.0f,27900.0f,53150.0f,55350.0f,57550.0f,59750.0f,61950.0f,64150.0f, + 75150.0f,77350.0f,79550.0f,81750.0f,83950.0f,86150.0f,97150.0f,99350.0f,101550.0f,103750.0f,105950.0f,108150.0f,119150.0f, + 121350.0f,123550.0f,125750.0f,127950.0f,130150.0f,141150.0f,143350.0f,145550.0f,147750.0f,149950.0f,152150.0f,163150.0f, + 165350.0f,167550.0f,169750.0f,171950.0f,174150.0f,119400.0f,120350.0f,121300.0f,122250.0f,123200.0f,124150.0f,128900.0f, + 129850.0f,130800.0f,131750.0f,132700.0f,133650.0f,138400.0f,139350.0f,140300.0f,141250.0f,142200.0f,143150.0f,147900.0f, + 148850.0f,149800.0f,150750.0f,151700.0f,152650.0f,157400.0f,158350.0f,159300.0f,160250.0f,161200.0f,162150.0f,166900.0f, + 167850.0f,168800.0f,169750.0f,170700.0f,171650.0f,350025.0f,352850.0f,355675.0f,358500.0f,361325.0f,364150.0f,378275.0f, + 381100.0f,383925.0f,386750.0f,389575.0f,392400.0f,406525.0f,409350.0f,412175.0f,415000.0f,417825.0f,420650.0f,434775.0f, + 437600.0f,440425.0f,443250.0f,446075.0f,448900.0f,463025.0f,465850.0f,468675.0f,471500.0f,474325.0f,477150.0f,491275.0f, + 494100.0f,496925.0f,499750.0f,502575.0f,505400.0f,353775.0f,355350.0f,356925.0f,358500.0f,360075.0f,361650.0f,369525.0f, + 371100.0f,372675.0f,374250.0f,375825.0f,377400.0f,385275.0f,386850.0f,388425.0f,390000.0f,391575.0f,393150.0f,401025.0f, + 402600.0f,404175.0f,405750.0f,407325.0f,408900.0f,416775.0f,418350.0f,419925.0f,421500.0f,423075.0f,424650.0f,432525.0f, + 434100.0f,435675.0f,437250.0f,438825.0f,440400.0f,771900.0f,775350.0f,778800.0f,782250.0f,785700.0f,789150.0f,806400.0f, + 809850.0f,813300.0f,816750.0f,820200.0f,823650.0f,840900.0f,844350.0f,847800.0f,851250.0f,854700.0f,858150.0f,875400.0f, + 878850.0f,882300.0f,885750.0f,889200.0f,892650.0f,909900.0f,913350.0f,916800.0f,920250.0f,923700.0f,927150.0f,944400.0f, + 947850.0f,951300.0f,954750.0f,958200.0f,961650.0f,107525.0f,107850.0f,108175.0f,108500.0f,108825.0f,109150.0f,110775.0f, + 111100.0f,111425.0f,111750.0f,112075.0f,112400.0f,114025.0f,114350.0f,114675.0f,115000.0f,115325.0f,115650.0f,117275.0f, + 117600.0f,117925.0f,118250.0f,118575.0f,118900.0f,120525.0f,120850.0f,121175.0f,121500.0f,121825.0f,122150.0f,123775.0f, + 124100.0f,124425.0f,124750.0f,125075.0f,125400.0f,713150.0f,715350.0f,717550.0f,719750.0f,721950.0f,724150.0f,735150.0f, + 737350.0f,739550.0f,741750.0f,743950.0f,746150.0f,757150.0f,759350.0f,761550.0f,763750.0f,765950.0f,768150.0f,779150.0f, + 781350.0f,783550.0f,785750.0f,787950.0f,790150.0f,801150.0f,803350.0f,805550.0f,807750.0f,809950.0f,812150.0f,823150.0f, + 825350.0f,827550.0f,829750.0f,831950.0f,834150.0f,404400.0f,405350.0f,406300.0f,407250.0f,408200.0f,409150.0f,413900.0f, + 414850.0f,415800.0f,416750.0f,417700.0f,418650.0f,423400.0f,424350.0f,425300.0f,426250.0f,427200.0f,428150.0f,432900.0f,433850.0f,434800.0f,435750.0f,436700.0f,437650.0f,442400.0f,443350.0f,444300.0f,445250.0f,446200.0f,447150.0f,451900.0f,452850.0f,453800.0f,454750.0f,455700.0f,456650.0f,1197525.0f,1200350.0f,1203175.0f,1206000.0f,1208825.0f,1211650.0f,1225775.0f,1228600.0f,1231425.0f,1234250.0f,1237075.0f,1239900.0f,1254025.0f,1256850.0f,1259675.0f,1262500.0f,1265325.0f,1268150.0f,1282275.0f,1285100.0f,1287925.0f,1290750.0f,1293575.0f,1296400.0f,1310525.0f,1313350.0f,1316175.0f,1319000.0f,1321825.0f,1324650.0f,1338775.0f,1341600.0f,1344425.0f,1347250.0f,1350075.0f,1352900.0f,826275.0f,827850.0f,829425.0f,831000.0f,832575.0f,834150.0f,842025.0f,843600.0f,845175.0f,846750.0f,848325.0f,849900.0f,857775.0f,859350.0f,860925.0f,862500.0f,864075.0f,865650.0f,873525.0f,875100.0f,876675.0f,878250.0f,879825.0f,881400.0f,889275.0f,890850.0f,892425.0f,894000.0f,895575.0f,897150.0f,905025.0f,906600.0f,908175.0f,909750.0f,911325.0f,912900.0f,1806900.0f,1810350.0f,1813800.0f,1817250.0f,1820700.0f,1824150.0f,1841400.0f,1844850.0f,1848300.0f,1851750.0f,1855200.0f,1858650.0f,1875900.0f,1879350.0f,1882800.0f,1886250.0f,1889700.0f,1893150.0f,1910400.0f,1913850.0f,1917300.0f,1920750.0f,1924200.0f,1927650.0f,1944900.0f,1948350.0f,1951800.0f,1955250.0f,1958700.0f,1962150.0f,1979400.0f,1982850.0f,1986300.0f,1989750.0f,1993200.0f,1996650.f}); + auto exp2FF = NDArrayFactory::create('c', {2, 10, 6, 6}, {827.4900282f,832.2350283f,836.9800284f,841.725028f,846.4700287f,851.2150288f,874.9400293f,879.6850294f,884.4300295f,889.1750296f,893.9200297f,898.665029f, + 922.3900304f,927.1350305f,931.8800306f,936.6250307f,941.3700308f,946.1150309f,969.8400315f,974.5850316f,979.3300317f,984.0750318f,988.8200319f,993.5650320f, + 1017.2900326f,1022.0350327f,1026.7800328f,1031.5250329f,1036.2700330f,1041.0150331f,1064.7400337f,1069.4850338f,1074.2300339f,1078.9750340f,1083.7200341f, + 1088.4650342f,1822.4550553f,1833.995055f,1845.5350558f,1857.075056f,1868.6150563f,1880.1550566f,1937.8550578f,1949.3950581f,1960.9350583f,1972.4750586f, + 1984.015058f,1995.5550591f,2053.2550604f,2064.7950606f,2076.3350609f,2087.8750611f,2099.4150614f,2110.955061f,2168.6550629f,2180.1950632f,2191.7350634f, + 2203.2750637f,2214.8150639f,2226.3550642f,2284.0550655f,2295.5950657f,2307.1350660f,2318.6750662f,2330.2150665f,2341.7550667f,2399.4550680f,2410.9950683f, + 2422.5350685f,2434.0750688f,2445.6150690f,2457.1550693f,2817.419968f,2835.7549686f,2854.0899683f,2872.4249680f,2890.7599677f,2909.0949674f,3000.7699660f, + 3019.104965f,3037.4399655f,3055.7749652f,3074.1099649f,3092.4449646f,3184.1199632f,3202.4549629f,3220.789962f,3239.1249624f,3257.4599621f,3275.7949618f, + 3367.4699604f,3385.8049601f,3404.1399598f,3422.474959f,3440.8099593f,3459.1449590f,3550.8199576f,3569.1549573f,3587.4899570f,3605.8249567f,3624.1599565f, + 3642.4949562f,3734.1699548f,3752.5049545f,3770.8399542f,3789.1749539f,3807.5099536f,3825.8449534f,3812.385098f,3837.5150988f,3862.6450994f,3887.7751000f, + 3912.9051006f,3938.0351012f,4063.6851041f,4088.8151047f,4113.9451053f,4139.0751059f,4164.2051065f,4189.3351071f,4314.9851100f,4340.1151106f,4365.2451112f, + 4390.3751118f,4415.5051124f,4440.6351130f,4566.2851159f,4591.4151165f,4616.5451171f,4641.6751177f,4666.805118f,4691.9351188f,4817.5851218f,4842.7151224f, + 4867.8451230f,4892.975123f,4918.1051241f,4943.2351247f,5068.8851277f,5094.0151283f,5119.1451288f,5144.2751294f,5169.4051300f,5194.5351306f,4807.3499803f, + 4839.2749801f,4871.1999799f,4903.1249797f,4935.0499795f,4966.9749793f,5126.5999784f,5158.5249782f,5190.4499780f,5222.3749778f,5254.2999777f,5286.2249775f, + 5445.8499765f,5477.774976f,5509.6999762f,5541.6249760f,5573.5499758f,5605.4749756f,5765.0999747f,5797.0249745f,5828.9499743f,5860.8749741f,5892.7999739f, + 5924.724973f,6084.3499728f,6116.2749726f,6148.1999724f,6180.1249723f,6212.0499721f,6243.9749719f,6403.59997f,6435.5249708f,6467.4499706f,6499.3749704f, + 6531.2999702f,6563.2249700f,5802.3150007f,5841.0350006f,5879.7550005f,5918.4750004f,5957.195000f,5995.9150003f,6189.5149999f,6228.2349998f,6266.9549997f, + 6305.6749996f,6344.3949995f,6383.114999f,6576.7149990f,6615.4349990f,6654.1549989f,6692.8749988f,6731.5949987f,6770.3149986f,6963.9149982f,7002.6349981f, + 7041.3549981f,7080.0749980f,7118.7949979f,7157.5149978f,7351.1149974f,7389.8349973f,7428.5549972f,7467.2749972f,7505.9949971f,7544.7149970f,7738.3149966f,7777.0349965f,7815.7549964f,7854.4749963f,7893.1949963f,7931.9149962f,6797.2799488f,6842.794948f,6888.3099489f,6933.8249490f,6979.3399491f,7024.8549492f,7252.4299497f,7297.9449498f,7343.4599499f,7388.9749500f,7434.489950f,7480.0049501f,7707.5799506f,7753.0949507f,7798.6099508f,7844.1249509f,7889.6399510f,7935.1549511f,8162.7299515f,8208.2449516f,8253.7599517f,8299.2749518f,8344.7899519f,8390.3049520f,8617.8799525f,8663.394952f,8708.9099526f,8754.4249527f,8799.9399528f,8845.4549529f,9073.0299534f,9118.5449535f,9164.0599536f,9209.5749537f,9255.089953f,9300.604953f,7792.2451647f,7844.5551655f,7896.8651663f,7949.1751671f,8001.4851679f,8053.7951686f,8315.3451725f,8367.6551733f,8419.9651741f,8472.2751749f,8524.585175f,8576.8951764f,8838.4451803f,8890.7551811f,8943.0651819f,8995.3751827f,9047.6851834f,9099.9951842f,9361.5451881f,9413.8551889f,9466.1651897f,9518.475190f,9570.7851912f,9623.0951920f,9884.6451959f,9936.9551967f,9989.2651975f,10041.5751982f,10093.8851990f,10146.1951998f,10407.7452037f,10460.0552045f,10512.3652053f,10564.6752060f,10616.9852068f,10669.2952076f,8787.210074f,8846.3150748f,8905.4200750f,8964.5250752f,9023.6300755f,9082.7350757f,9378.2600768f,9437.3650770f,9496.4700773f,9555.5750775f,9614.6800777f,9673.7850779f,9969.3100791f,10028.4150793f,10087.5200795f,10146.625079f,10205.7300800f,10264.8350802f,10560.3600813f,10619.465081f,10678.5700818f,10737.6750820f,10796.7800822f,10855.8850825f,11151.4100836f,11210.5150838f,11269.6200840f,11328.7250843f,11387.8300845f,11446.9350847f,11742.4600858f,11801.5650861f,11860.6700863f,11919.7750865f,11978.880086f,12037.9850870f,9782.1750935f,9848.0750935f,9913.9750934f,9979.8750934f,10045.7750934f,10111.6750933f,10441.1750931f,10507.0750931f,10572.9750931f,10638.8750930f,10704.7750930f,10770.6750930f,11100.1750928f,11166.0750927f,11231.9750927f,11297.8750927f,11363.7750926f,11429.6750926f,11759.1750924f,11825.0750924f,11890.9750923f,11956.8750923f,12022.7750923f,12088.6750922f,12418.175092f,12484.0750920f,12549.9750920f,12615.8750919f,12681.7750919f,12747.6750919f,13077.1750917f,13143.0750916f,13208.9750916f,13274.8750916f,13340.7750915f,13406.6750915f,2250.990060f,2255.7350610f,2260.4800611f,2265.2250612f,2269.9700613f,2274.7150614f,2298.4400619f,2303.185062f,2307.9300622f,2312.6750623f,2317.4200624f,2322.1650625f,2345.8900630f,2350.6350631f,2355.380063f,2360.1250634f,2364.8700635f,2369.6150636f,2393.3400641f,2398.0850642f,2402.8300643f,2407.5750644f,2412.320064f,2417.0650647f,2440.7900652f,2445.5350653f,2450.2800654f,2455.0250655f,2459.7700656f,2464.515065f,2488.2400663f,2492.9850664f,2497.7300665f,2502.4750666f,2507.2200667f,2511.9650668f,5284.4551315f,5295.9951318f,5307.535132f,5319.0751323f,5330.6151326f,5342.1551328f,5399.8551341f,5411.3951343f,5422.9351346f,5434.475134f,5446.0151351f,5457.5551354f,5515.2551366f,5526.7951369f,5538.3351371f,5549.8751374f,5561.4151376f,5572.9551379f,5630.6551392f,5642.1951394f,5653.7351397f,5665.2751399f,5676.8151402f,5688.3551404f,5746.0551417f,5757.5951420f,5769.1351422f,5780.6751425f,5792.2151427f,5803.7551430f,5861.455144f,5872.9951445f,5884.5351448f,5896.0751450f,5907.6151453f,5919.1551455f,8317.919884f,8336.2548841f,8354.5898838f,8372.9248835f,8391.2598832f,8409.59488f,8501.2698815f,8519.6048813f,8537.9398810f,8556.2748807f,8574.6098804f,8592.9448801f,8684.6198787f,8702.9548784f,8721.2898782f,8739.6248779f,8757.9598776f,8776.2948773f,8867.9698759f,8886.3048756f,8904.6398753f,8922.9748751f,8941.3098748f,8959.6448745f,9051.3198731f,9069.6548728f,9087.9898725f,9106.3248722f,9124.6598720f,9142.9948717f,9234.6698703f,9253.0048700f,9271.3398697f,9289.6748694f,9308.0098691f,9326.3448689f,11351.3852747f,11376.5152753f,11401.6452759f,11426.7752765f,11451.9052771f,11477.0352777f,11602.6852806f,11627.8152812f,11652.9452818f,11678.0752824f,11703.2052830f,11728.335283f,11853.9852865f,11879.1152871f,11904.2452877f,11929.3752883f,11954.505288f,11979.6352894f,12105.2852924f,12130.4152930f,12155.545293f,12180.6752941f,12205.8052947f,12230.9352953f,12356.5852983f,12381.715298f,12406.8452994f,12431.9753000f,12457.1053006f,12482.2353012f,12607.8853041f,12633.0153047f,12658.1453053f,12683.2753059f,12708.4053065f,12733.5353071f,14384.8499244f,14416.7749242f,14448.6999240f,14480.6249238f,14512.549923f,14544.4749235f,14704.0999225f,14736.024922f,14767.9499222f,14799.8749220f,14831.7999218f,14863.7249216f,15023.3499207f,15055.2749205f,15087.1999203f,15119.1249201f,15151.0499199f,15182.9749197f,15342.5999188f,15374.5249186f,15406.4499184f,15438.374918f,15470.2999181f,15502.2249179f,15661.84991f,15693.7749168f,15725.6999166f,15757.6249164f,15789.5499162f,15821.4749160f,15981.0999151f,16013.0249149f,16044.9499147f,16076.8749145f,16108.7999143f,16140.7249142f,17418.314976f,17457.0349761f,17495.7549760f,17534.4749759f,17573.1949758f,17611.9149757f,17805.5149753f,17844.234975f,17882.9549752f,17921.6749751f,17960.3949750f,17999.1149749f,18192.7149745f,18231.4349744f,18270.154974f,18308.8749743f,18347.5949742f,18386.3149741f,18579.9149737f,18618.6349736f,18657.3549735f,18696.074973f,18734.7949734f,18773.5149733f,18967.1149729f,19005.8349728f,19044.5549727f,19083.2749726f,19121.994972f,19160.7149725f,19354.3149721f,19393.0349720f,19431.7549719f,19470.4749718f,19509.1949717f,19547.914971f,20451.7799765f,20497.2949766f,20542.8099767f,20588.3249768f,20633.8399769f,20679.3549770f,20906.929977f,20952.4449775f,20997.9599776f,21043.4749777f,21088.9899778f,21134.5049779f,21362.0799784f,21407.5949785f,21453.1099786f,21498.624978f,21544.139978f,21589.6549788f,21817.2299793f,21862.7449794f,21908.2599795f,21953.7749796f,21999.2899797f,22044.8049798f,22272.3799802f,22317.8949803f,22363.4099804f,22408.9249805f,22454.4399806f,22499.9549807f,22727.529981f,22773.044981f,22818.5599813f,22864.0749814f,22909.5899815f,22955.1049816f,23485.2453985f,23537.555399f,23589.8654000f,23642.1754008f,23694.4854016f,23746.7954024f,24008.3454063f,24060.655407f,24112.9654078f,24165.2754086f,24217.5854094f,24269.8954102f,24531.4454141f,24583.7554148f,24636.0654156f,24688.3754164f,24740.6854172f,24792.99541f,25054.545421f,25106.8554226f,25159.1654234f,25211.4754242f,25263.7854250f,25316.0954257f,25577.6454296f,25629.9554304f,25682.2654312f,25734.5754320f,25786.8854328f,25839.1954335f,26100.7454374f,26153.0554382f,26205.3654390f,26257.6754398f,26309.985440f,26362.2954413f,26518.7101423f,26577.8151425f,26636.920142f,26696.0251430f,26755.1301432f,26814.2351434f,27109.7601446f,27168.8651448f,27227.9701450f,27287.0751452f,27346.1801455f,27405.2851457f,27700.8101468f,27759.9151470f,27819.0201473f,27878.1251475f,27937.2301477f,27996.33514f,28291.8601491f,28350.9651493f,28410.0701495f,28469.175149f,28528.2801500f,28587.3851502f,28882.9101513f,28942.0151516f,29001.1201518f,29060.2251520f,29119.3301522f,29178.4351525f,29473.9601536f,29533.0651538f,29592.1701540f,29651.2751543f,29710.3801545f,29769.4851547f,29552.1750826f,29618.0750825f,29683.9750825f,29749.8750825f,29815.7750824f,29881.6750824f,30211.1750822f,30277.0750822f,30342.9750821f,30408.8750821f,30474.7750821f,30540.6750820f,30870.175081f,30936.0750818f,31001.9750818f,31067.8750817f,31133.7750817f,31199.6750817f,31529.1750815f,31595.075081f,31660.9750814f,31726.8750814f,31792.7750813f,31858.6750813f,32188.1750811f,32254.0750811f,32319.975081f,32385.8750810f,32451.7750810f,32517.6750809f,32847.1750808f,32913.0750807f,32978.9750807f,33044.875080f,33110.7750806f,33176.67508062f}); input.linspace(1); - weightsD.linspace(1); - weightsP.linspace(1); - weightsD.permutei({2,3,1,0}); - weightsP.permutei({2,3,1,0}); - - weightsP.applyScalar(scalar::Divide, 10000.0); nd4j::ops::sconv2d op; auto resultFF = op.execute({&input, &weightsD}, {}, {5, 5, 1, 1, 0, 0, 1, 1, 0}, {}); @@ -669,6 +707,7 @@ TYPED_TEST(TypedConvolutionTests1, sconv2d_conv2d_1) { auto result2D = op2d.execute({z, &weightsP}, {}, {1, 1, 1, 1, 0, 0, 1, 1, 0, 0}, {}); auto z2d = result2D->at(0); + // z2d->printBuffer(); ASSERT_TRUE(z2d->isSameShape(&exp2FF)); ASSERT_TRUE(z2d->equalsTo(&exp2FF)); @@ -793,7 +832,7 @@ TYPED_TEST(TypedConvolutionTests1, Test_Conv1D_ff_1) { nd4j::ops::conv1d_bp op_bp; - auto epsilonNxt = z->dup(); + auto epsilonNxt = new NDArray(z->dup()); epsilonNxt->linspace(1); auto result_BP = op_bp.execute({&input, &weights, &bias, epsilonNxt}, {}, {2, 1, 0, 1, 0, 0}); @@ -1469,223 +1508,6 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_bp_test3) { delete results; } -////////////////////////////////////////////////////////////////////// -TYPED_TEST(TypedConvolutionTests1, depthwise_conv2d_1) { - - int bS=2, iH=4,iW=3, iC=2,mC=2, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oC=iC*mC; - int oH=4,oW=3; - int paddingMode = 1; // 1-SAME, 0-VALID; - int dataFormat = 1; // 1-NHWC, 0-NCHW - - auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); - auto weights = NDArrayFactory::create('c', {kH, kW, iC, mC}); - - - auto expOutput = NDArrayFactory::create('c', {bS, oH, oW, oC},{12.f, 12.8f, 13.6f, 14.4f, 12.f, 12.8f, 13.6f, 14.4f, 5.2f, 5.6f, 6.f, 6.4f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 5.4f, 6.f, 6.6f, 7.2f, - 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 5.4f, 6.f, 6.6f, 7.2f, 5.6f, 6.4f, 7.2f, 8.f, 5.6f, 6.4f, 7.2f, 8.f, 2.f, 2.4f, 2.8f, 3.2f, - 12.f, 12.8f, 13.6f, 14.4f, 12.f, 12.8f, 13.6f, 14.4f, 5.2f, 5.6f, 6.f, 6.4f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 5.4f, 6.f, 6.6f, 7.2f, - 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 5.4f, 6.f, 6.6f, 7.2f, 5.6f, 6.4f, 7.2f, 8.f, 5.6f, 6.4f, 7.2f, 8.f, 2.f, 2.4f, 2.8f, 3.2f}); - input = 2.; - weights.linspace(0.1, 0.1); - - nd4j::ops::depthwise_conv2d op; - auto results = op.execute({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - auto* output = results->at(0); - - ASSERT_EQ(Status::OK(), results->status()); - - ASSERT_TRUE(expOutput.isSameShape(output)); - ASSERT_TRUE(expOutput.equalsTo(output)); - - delete results; -} - -////////////////////////////////////////////////////////////////////// -TEST_F(ConvolutionTests1, depthwise_conv2d_2) { - - int bS=2, iH=4,iW=3, iC=2,mC=2, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oC=iC*mC; - int oH=2,oW=2; - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 1; // 1-NHWC, 0-NCHW - - auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); - auto weights = NDArrayFactory::create('c', {kH, kW, iC, mC}); - - - auto expOutput = NDArrayFactory::create('c', {bS, oH, oW, oC},{13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, - 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f}); - input = 2.; - weights.linspace(0.1, 0.1); - - nd4j::ops::depthwise_conv2d op; - auto results = op.execute({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - auto* output = results->at(0); - - ASSERT_EQ(Status::OK(), results->status()); - - ASSERT_TRUE(expOutput.isSameShape(output)); - ASSERT_TRUE(expOutput.equalsTo(output)); - - delete results; -} - - -////////////////////////////////////////////////////////////////////// -TEST_F(ConvolutionTests1, depthwise_conv2d_3) { - - int bS=2, iH=4,iW=3, iC=2,mC=2, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oC=iC*mC; - int oH=2,oW=2; - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 0; // 1-NHWC, 0-NCHW - - auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); - auto weights = NDArrayFactory::create('c', {mC, iC, kH, kW}); - auto biases = NDArrayFactory::create('c', {iC*mC}, {1,2,3,4}); - - - auto expOutput = NDArrayFactory::create('c', {bS, oC, oH, oW},{5.2, 5.2, 5.2, 5.2,20.6,20.6,20.6,20.6,14.4,14.4,14.4,14.4,29.8,29.8,29.8,29.8, 5.2, 5.2, 5.2, 5.2,20.6,20.6,20.6,20.6,14.4,14.4,14.4,14.4,29.8,29.8,29.8,29.8}); - input = 2.; - weights.linspace(0.1, 0.1); - weights.permutei({2,3,1,0}); - - nd4j::ops::depthwise_conv2d op; - auto results = op.execute({&input, &weights, &biases}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - auto* output = results->at(0); - - // output->printIndexedBuffer(); - - ASSERT_EQ(Status::OK(), results->status()); - - ASSERT_TRUE(expOutput.isSameShape(output)); - ASSERT_TRUE(expOutput.equalsTo(output)); - - delete results; -} - -////////////////////////////////////////////////////////////////////// -TEST_F(ConvolutionTests1, depthwise_conv2d_4) { - - int bS=1, iH=111,iW=111, iC=32,mC=1, kH=7,kW=7, sH=2,sW=2, pH=0,pW=0, dH=1,dW=1; - int oC=iC*mC; - int oH=56,oW=56; - - int paddingMode = 1; // 1-SAME, 0-VALID; - int dataFormat = 1; // 1-NHWC, 0-NCHW - - const float unique = -1000000; - - NDArray input('c', {bS, iH, iW, iC}, nd4j::DataType::FLOAT32); - NDArray weights('c', {kH, kW, iC, mC}, nd4j::DataType::FLOAT32); - NDArray output('c', {bS, oH, oW, oC}, nd4j::DataType::FLOAT32); - input.linspace(0.1, 0.0001); - weights = 0.5; - output = unique; - - nd4j::ops::depthwise_conv2d op; - Nd4jStatus status = op.execute({&input, &weights}, {&output} , {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}, {}); - - ASSERT_EQ(Status::OK(), status); - - for(Nd4jLong i=output.lengthOf()/1.5; i < output.lengthOf(); ++i) - ASSERT_EQ(output.e(i) != unique, true); -} - -////////////////////////////////////////////////////////////////////// -TEST_F(ConvolutionTests1, depthwise_conv2d_5) { - - int bS=1, iH=3,iW=3, iC=2,mC=1, kH=2,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oC=iC*mC; - int oH=3,oW=3; - int paddingMode = 1; // 1-SAME, 0-VALID; - int dataFormat = 1; // 1-NHWC, 0-NCHW - - auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); - auto weights = NDArrayFactory::create('c', {kH, kW, iC, mC}); - - - auto expOutput = NDArrayFactory::create('c', {bS, oH, oW, oC}, {20., 24.,28., 32.,16., 18.,44., 48.,52., 56.,28., 30.,28., 30.,32., 34.,17., 18.}); - input.linspace(1.); - weights = 1.; - - nd4j::ops::depthwise_conv2d op; - auto results = op.execute({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - auto output = results->at(0); - // output->printIndexedBuffer(); - - ASSERT_EQ(Status::OK(), results->status()); - - ASSERT_TRUE(expOutput.isSameShape(output)); - ASSERT_TRUE(expOutput.equalsTo(output)); - - delete results; -} - -////////////////////////////////////////////////////////////////////// -TEST_F(ConvolutionTests1, depthwise_conv2d_6) { - - int bS=1, iH=3,iW=3, iC=2,mC=1, kH=2,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oC=iC*mC; - int oH=3,oW=3; - int paddingMode = 1; // 1-SAME, 0-VALID; - int dataFormat = 1; // 1-NHWC, 0-NCHW - - NDArray input('c', {bS, iH, iW, iC}, nd4j::DataType::DOUBLE); - NDArray weights('c', {kH, kW, iC, mC}, nd4j::DataType::DOUBLE); - - NDArray expOutput('c', {bS, oH, oW, oC}, {20., 24.,28., 32.,16., 18.,44., 48.,52., 56.,28., 30.,28., 30.,32., 34.,17., 18.}); - input.linspace(1.); - weights = 1.; - - nd4j::ops::depthwise_conv2d op; - ResultSet* results = op.execute({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - NDArray* output = results->at(0); - // output.printIndexedBuffer(); - - ASSERT_EQ(Status::OK(), results->status()); - - ASSERT_TRUE(expOutput.isSameShape(output)); - ASSERT_TRUE(expOutput.equalsTo(output)); - - delete results; -} - -////////////////////////////////////////////////////////////////////// -TEST_F(ConvolutionTests1, depthwise_conv2d_7) { - - int bS=1, iH=3,iW=3, iC=2,mC=2, kH=1,kW=1, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oC=iC*mC; - int oH=3,oW=3; - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 0; // 1-NHWC, 0-NCHW - - NDArray input('c', {bS, iC, iH, iW}, {0.6793503761291504, 0.35508695244789124, 0.842789351940155, 0.20031332969665527, 0.7014986872673035, 0.3106933832168579, - 0.44793984293937683, 0.9380097389221191, 0.3266739547252655, 0.15187257528305054, 0.3833175301551819, 0.7821229696273804, - 0.19880719482898712, 0.7985635995864868, 0.16326339542865753, 0.14696824550628662, 0.2608966827392578, 0.13505761325359344}); - NDArray weights('c', {kH, kW, iC, mC}, {0.1308445781469345, 0.6442840099334717, 0.5698848366737366, 0.19896849989891052}); - NDArray biases('c', {1,iC*mC}, {0.6123566627502441, 0.37637925148010254, 0.17464971542358398, 0.4270855486392975}); - - NDArray expOutput('c', {bS, oC, oH, oW}, {0.7012459761288241, 0.6588178652487691, 0.722631079971582, 0.6385665758716108, 0.7041439625563628, 0.6530092074102978, - 0.670967162534851, 0.735090151337225, 0.6551001785478623, 0.8140738359624038, 0.6051560970782859, 0.9193749546773375, 0.5054379267801892, 0.8283436386757472, - 0.5765540302788565, 0.6649797296980537, 0.9807239274294943, 0.586850056971322, 0.261199593183985, 0.3930965634902499, 0.6203697362284615, 0.28794692117826504, - 0.6297390019475202, 0.26769104886224415, 0.25840469001015975, 0.3233307788551656, 0.25161700129415276, 0.4573034071191504, 0.5033536625992294, 0.5827033826425385, - 0.4666419179635315, 0.585974550122895, 0.4595698215161401, 0.45632759998045813, 0.4789957702325296, 0.4539577593482922}); - - - nd4j::ops::depthwise_conv2d op; - auto results = op.execute({&input, &weights, &biases}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - auto* output = results->at(0); - - ASSERT_EQ(Status::OK(), results->status()); - - ASSERT_TRUE(expOutput.isSameShape(output)); - ASSERT_TRUE(expOutput.equalsTo(output)); - - delete results; -} - ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests1, depthwise_conv2d_bp_test1) { @@ -1695,15 +1517,15 @@ TEST_F(ConvolutionTests1, depthwise_conv2d_bp_test1) { int paddingMode = 1; // 1-SAME, 0-VALID; int dataFormat = 1; // 1-NHWC, 0-NCHW - auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); - auto weights = NDArrayFactory::create('c', {kH, kW, iC, mC}); - auto bias = NDArrayFactory::create('c', {oC}, {1,2,3,4}); - auto gradO = NDArrayFactory::create('c', {bS, oH, oW, oC}); + auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); + auto weights = NDArrayFactory::create('c', {kH, kW, iC, mC}); + auto bias = NDArrayFactory::create('c', {oC}, {1,2,3,4}); + auto gradO = NDArrayFactory::create('c', {bS, oH, oW, oC}); - auto expGradI = NDArrayFactory::create('c', {bS, iH, iW, iC},{0.07 , 0.19 , 0.348, 0.652, 0.588, 0.956, 0.387, 0.687, 1.326, 2.022, 1.878, 2.67 , 1.071, 1.515, 2.982, 3.966, 3.534, 4.614, 1.606, 1.982, 3.932, 4.748, 4.428, 5.308, - 1.126, 1.63 , 3.228, 4.3 , 3.468, 4.604, 3.123, 3.999, 7.95 , 9.798, 8.502, 10.446, 3.807, 4.827, 9.606, 11.742,10.158, 12.39 , 4.198, 4.958, 9.884, 11.468,10.38 , 12.028}); + NDArray expGradI('c', {bS, iH, iW, iC},{0.07 , 0.19 , 0.348, 0.652, 0.588, 0.956, 0.387, 0.687, 1.326, 2.022, 1.878, 2.67 , 1.071, 1.515, 2.982, 3.966, 3.534, 4.614, 1.606, 1.982, 3.932, 4.748, 4.428, 5.308, + 1.126, 1.63 , 3.228, 4.3 , 3.468, 4.604, 3.123, 3.999, 7.95 , 9.798, 8.502, 10.446, 3.807, 4.827, 9.606, 11.742,10.158, 12.39 , 4.198, 4.958, 9.884, 11.468,10.38 , 12.028}, nd4j::DataType::FLOAT32); - auto expGradW = NDArrayFactory::create('c', {kH, kW, iC, mC},{19.08, 19.44,19.8 , 20.16,12.24, 12.48,12.72, 12.96,22.56, 23.04,23.52, 24. ,14.4 , 14.72,15.04, 15.36,14.76, 15.12,15.48, 15.84, 9.36, 9.6 , 9.84, 10.08}); + NDArray expGradW('c', {kH, kW, iC, mC},{19.08, 19.44,19.8 , 20.16,12.24, 12.48,12.72, 12.96,22.56, 23.04,23.52, 24. ,14.4 , 14.72,15.04, 15.36,14.76, 15.12,15.48, 15.84, 9.36, 9.6 , 9.84, 10.08}, nd4j::DataType::FLOAT32); input = 2.; weights.linspace(0.1, 0.1); @@ -1734,14 +1556,180 @@ TEST_F(ConvolutionTests1, depthwise_conv2d_bp_test2) { int paddingMode = 0; // 1-SAME, 0-VALID; int dataFormat = 1; // 1-NHWC, 0-NCHW - auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); - auto weights = NDArrayFactory::create('c', {kH, kW, iC, mC}); - auto bias = NDArrayFactory::create('c', {oC}, {1,2,3,4}); - auto gradO = NDArrayFactory::create('c', {bS, oH, oW, oC}); + auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); + auto weights = NDArrayFactory::create('c', {kH, kW, iC, mC}); + auto bias = NDArrayFactory::create('c', {oC}, {1,2,3,4}); + auto gradO = NDArrayFactory::create('c', {bS, oH, oW, oC}); - auto expGradI = NDArrayFactory::create('c', {bS, iH, iW, iC},{0.005, 0.025,0.034, 0.106,0.061, 0.113,0.058, 0.162,0.292, 0.564,0.298, 0.466,0.234, 0.402,0.772, 1.172,0.602, 0.834,0.333, 0.449,0.882, 1.146,0.581, 0.729, - 0.053, 0.137,0.258, 0.458,0.237, 0.353,0.41 , 0.642,1.252, 1.78 ,0.906, 1.202,1.098, 1.394,2.756, 3.412,1.722, 2.082,0.893, 1.073,2.13 , 2.522,1.269, 1.481}); - auto expGradW = NDArrayFactory::create('c', {kH, kW, iC, mC},{2.4 , 2.56,2.72, 2.88,2.4 , 2.56,2.72, 2.88,2.4 , 2.56,2.72, 2.88,2.4 , 2.56,2.72, 2.88,2.4 , 2.56,2.72, 2.88,2.4 , 2.56,2.72, 2.88}); + NDArray expGradI('c', {bS, iH, iW, iC},{0.005, 0.025,0.034, 0.106,0.061, 0.113,0.058, 0.162,0.292, 0.564,0.298, 0.466,0.234, 0.402,0.772, 1.172,0.602, 0.834,0.333, 0.449,0.882, 1.146,0.581, 0.729, + 0.053, 0.137,0.258, 0.458,0.237, 0.353,0.41 , 0.642,1.252, 1.78 ,0.906, 1.202,1.098, 1.394,2.756, 3.412,1.722, 2.082,0.893, 1.073,2.13 , 2.522,1.269, 1.481}, nd4j::DataType::FLOAT32); + NDArray expGradW('c', {kH, kW, iC, mC},{2.4 , 2.56,2.72, 2.88,2.4 , 2.56,2.72, 2.88,2.4 , 2.56,2.72, 2.88,2.4 , 2.56,2.72, 2.88,2.4 , 2.56,2.72, 2.88,2.4 , 2.56,2.72, 2.88}, nd4j::DataType::FLOAT32); + + input = 2.; + weights.linspace(0.1, 0.1); + gradO.linspace(0.01, 0.01); + + nd4j::ops::depthwise_conv2d_bp op; + auto results = op.execute({&input, &weights, &bias, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); + auto* gradI = results->at(0); + auto* gradW = results->at(1); + + ASSERT_EQ(Status::OK(), results->status()); + + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + ASSERT_TRUE(expGradW.isSameShape(gradW)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); + + delete results; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, depthwise_conv2d_bp_test3) { + + auto in = NDArrayFactory::create('c', {4, 8, 64, 64}); + auto w = NDArrayFactory::create('c', {2, 2, 8, 2}); + auto b = NDArrayFactory::create('c', {1, 16}); + auto grad = NDArrayFactory::create('c', {4, 16, 64, 64}); + + auto gradI = in.like(); + auto gradW = w.like(); + auto gradB = b.like(); + + nd4j:ops::depthwise_conv2d_bp op; + auto status = op.execute({&in, &w, &b, &grad}, {&gradI, &gradW, &gradB}, {}, {2, 2, 1, 1, 0, 0, 1, 1, 1, 0}, {}); + ASSERT_EQ(Status::OK(), status); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, depthwise_conv2d_bp_test4) { + + int bS=1, iH=10,iW=10, iC=8,mC=1, kH=3,kW=3, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; + int oC=iC*mC; + int oH=10,oW=10; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + + NDArray input('c', {bS, iH, iW, iC}, nd4j::DataType::FLOAT32); + NDArray weights('c', {kH, kW, iC, mC}, nd4j::DataType::FLOAT32); + NDArray gradO('c', {bS, oH, oW, oC}, nd4j::DataType::FLOAT32); + NDArray bias('c', {oC}, nd4j::DataType::FLOAT32); + + input.linspace(-10, 0.1); + weights.linspace(-2, 0.1); + gradO.linspace(10, -0.1); + + + NDArray expGradI('c', {bS, iH, iW, iC},{10.880001, 13.239998, 15.520001, 17.719997, 19.840000, 21.880001, 23.839998, 25.720001, 31.360004, 34.420002, 37.360001, 40.180004, 42.880005, 45.460003, 47.919994, 50.260002, 31.360001, 33.939999, 36.400002, 38.739998, 40.959999, 43.059998, 45.040001, 46.900005, 31.359997, 33.459999, 35.439999, 37.300003, 39.040001, 40.660000, 42.160000, 43.539997, 31.360001, 32.980000, 34.480000, 35.860001, 37.119999, 38.259998, 39.279999, 40.180000, 31.360001, 32.499996, 33.520000, 34.419998, 35.200001, 35.860001, 36.400002, 36.820000, 31.360001, 32.019997, 32.560001, 32.979996, 33.280003, 33.459999, 33.520000, 33.459999, 31.360001, 31.540001, 31.599998, 31.539999, 31.360001, 31.059999, 30.639999, 30.100000, 31.360001, 31.060001, 30.639999, 30.099998, 29.440002, 28.660000, 27.759998, 26.740000, 18.559999, 18.040001, 17.440001, 16.760000, 16.000000, 15.160000, 14.240001, 13.240000, 85.439995, 85.860001, 86.159996, 86.339996, 86.400002, 86.340012, 86.159996, 85.860008, 132.000000, 131.910004, 131.639999, 131.190002, 130.559998, 129.750000, 128.760010, 127.589996, 123.360001, 122.550003, 121.559998, 120.389999, 119.040009, 117.510002, 115.799988, 113.910004, 114.720001, 113.189995, 111.480003, 109.590004, 107.520004, 105.270004, 102.839996, 100.230011, 106.079994, 103.830002, 101.400009, 98.790009, 96.000008, + 93.030006, 89.879990, 86.549988, 97.439995, 94.469994, 91.319992, 87.990005, 84.479996, 80.789993, 76.919998, 72.870003, 88.800003, 85.110001, 81.239998, 77.190002, 72.960007, 68.550003, 63.959999, 59.190002, 80.160004, 75.750000, 71.160004, 66.389999, 61.440002, 56.309994, 51.000000, 45.510002, 71.519997, 66.389999, 61.079998, 55.590000, 49.919998, 44.070000, 38.040001, 31.830002, 31.680000, 27.780003, 23.760000, 19.619999, 15.360001, 10.980000, 6.480000, 1.859999, 47.040001, 42.660004, 38.160000, 33.540001, 28.799999, 23.939999, 18.960001, 13.860001, 45.599998, 38.310001, 30.840000, 23.190002, 15.360001, 7.349998, -0.840002, -9.210003, 36.959999, 28.950003, 20.759998, 12.390001, 3.839998, -4.889999, -13.799999, -22.890003, 28.320002, 19.589998, 10.680000, 1.590002, -7.680002, -17.129999, -26.759998, -36.570007, 19.680002, 10.230003, 0.599998, -9.210001, -19.199999, -29.370003, -39.720001, -50.250008, 11.039999, 0.869999, -9.480000, -20.010002, -30.719994, -41.610001, -52.679996, -63.930008, 2.400005, -8.489998, -19.560005, -30.809998, -42.239998, -53.849991, -65.639992, -77.610001, -6.239998, -17.849998, -29.639988, -41.609985, -53.760002, -66.090004, -78.599991, -91.290009, -14.879990, -27.209995, -39.720009, -52.410007, -65.279999, -78.330002, -91.559998, -104.969986, -45.119995, -53.820000, -62.639999, -71.580002, -80.640007, -89.819992, -99.119995, -108.540009, 8.639999, -0.540001, -9.839996, -19.259998, -28.799995, -38.459999, -48.240002, -58.140003, -40.799999, -55.289997, -69.960007, -84.810013, -99.840004, -115.050011, -130.440018, -146.010010, -49.439991, -64.650009, -80.040009, -95.610016, -111.360008, -127.290001, -143.399994, -159.690018, -58.080009, -74.009987, -90.119995, -106.409988, -122.880005, -139.530014, -156.360001, -173.369995, -66.720001, -83.369995, -100.199997, + -117.209999, -134.399994, -151.769989, -169.319992, -187.049988, -75.360008, -92.729996, -110.279991, -128.009979, -145.920013, -164.009995, -182.279984, -200.729996, -84.000000, -102.089996, -120.360016, -138.809967, -157.440002, -176.249969, -195.240005, -214.410019, -92.639999, -111.449997, -130.440018, -149.610016, -168.960007, -188.489990, -208.200012, -228.090012, -101.279976, -120.809982, -140.519989, -160.410004, -180.480011, -200.730011, -221.160034, -241.770020, -121.920006, -135.420013, -149.040009, -162.779999, -176.640015, -190.619995, -204.719986, -218.940002, -29.760002, -43.739998, -57.840000, -72.059998, -86.400009, -100.860001, -115.439995, -130.140015, -127.199997, -148.890015, -170.760010, -192.809998, -215.040024, -237.450012, -260.039978, -282.809998, -135.839996, -158.250000, -180.840012, -203.610046, -226.559982, -249.690002, -272.999969, -296.489990, -144.479980, -167.609985, -190.920013, -214.410019, -238.080032, -261.929993, -285.959991, -310.169983, -153.119995, -176.969986, -201.000031, -225.210022, -249.599976, -274.170013, -298.920013, -323.849976, -161.760040, -186.330017, -211.079987, -236.009995, -261.120026, -286.410034, -311.879974, -337.530029, -170.400009, -195.689987, -221.159973, -246.809998, -272.639954, -298.650024, -324.840057, -351.209991, -179.039963, -205.050018, -231.240021, -257.609985, -284.160004, -310.890015, -337.799988, -364.890015, -187.680023, -214.410004, -241.319977, -268.410004, -295.679993, -323.130005, -350.760010, -378.570038, -198.720016, -217.019989, -235.440002, -253.979980, -272.640045, -291.419983, -310.319977, -329.339996, -68.159981, -86.939987, -105.840012, -124.860001, -144.000000, -163.260010, -182.639984, -202.140015, -213.600021, -242.489990, -271.559937, -300.809998, -330.239990, -359.849976, -389.639984, + -419.610016, -222.240036, -251.849960, -281.640015, -311.609985, -341.760040, -372.089996, -402.600037, -433.290009, -230.880005, -261.210022, -291.719971, -322.410034, -353.280029, -384.329956, -415.559998, -446.970001, -239.519989, -270.570007, -301.800018, -333.209991, -364.800018, -396.570007, -428.520020, -460.650024, -248.160034, -279.929962, -311.880005, -344.010010, -376.320038, -408.809998, -441.479980, -474.330017, -256.799988, -289.289978, -321.960022, -354.809967, -387.839996, -421.050018, -454.440002, -488.009979, -265.440002, -298.650024, -332.040009, -365.609985, -399.360016, -433.290009, -467.399963, -501.689941, -274.080017, -308.009949, -342.119995, -376.409973, -410.880005, -445.530029, -480.359985, -515.369995, -275.520020, -298.619995, -321.839966, -345.179993, -368.640015, -392.220001, -415.919952, -439.740021, -106.560005, -130.140030, -153.840027, -177.659973, -201.599991, -225.660019, -249.840012, -274.140015, -300.000000, -336.090057, -372.360046, -408.809937, -445.440002, -482.250031, -519.240051, -556.410034, -308.640015, -345.450012, -382.440002, -419.609955, -456.959961, -494.489960, -532.200012, -570.089966, -317.280029, -354.809998, -392.520020, -430.410004, -468.480042, -506.729980, -545.159912, -583.770020, -325.920013, -364.169952, -402.600037, -441.210022, -480.000000, -518.970032, -558.119873, -597.449951, -334.559967, -373.529999, -412.679993, -452.009949, -491.519989, -531.209961, -571.080017, -611.129944, -343.200012, -382.889984, -422.760071, -462.809906, -503.039978, -543.449951, -584.039978, -624.809998, -351.839966, -392.250000, -432.839966, -473.609955, -514.560120, -555.689941, -596.999939, -638.489990, -360.480011, -401.610016, -442.920044, -484.409912, -526.080017, -567.929993, -609.959961, -652.169983, -352.320007, -380.220001, + -408.239990, -436.380005, -464.639984, -493.019989, -521.519958, -550.139954, -144.960022, -173.339996, -201.839996, -230.459976, -259.200043, -288.059998, -317.039978, -346.140015, -386.399963, -429.690002, -473.159912, -516.809937, -560.640076, -604.650024, -648.839966, -693.210022, -395.039978, -439.050018, -483.239929, -527.609985, -572.159973, -616.890015, -661.799988, -706.890015, -403.680023, -448.409973, -493.320007, -538.410034, -583.680054, -629.129944, -674.760010, -720.570068, -412.320007, -457.769897, -503.399963, -549.210083, -595.199951, -641.369995, -687.720093, -734.250000, -420.960052, -467.130035, -513.479980, -560.010010, -606.720093, -653.610046, -700.680054, -747.930115, -429.599976, -476.489990, -523.559998, -570.809937, -618.239990, -665.849976, -713.640015, -761.609985, -438.239990, -485.850037, -533.640015, -581.610046, -629.760010, -678.089966, -726.600037, -775.289917, -446.880035,-495.210052, -543.719971, -592.410034, -641.279968, -690.330017, -739.559937, -788.970093, -429.120026, -461.819946, -494.639984, -527.580017, -560.640015, -593.820007, -627.119995, -660.540039, -183.360016, -216.540009, -249.839996, -283.260040, -316.800018, -350.459961, -384.239990, -418.139984, -472.800049, -523.289917, -573.959961, -624.809998, -675.839966, -727.050049, -778.440063, -830.010010, -481.440002, -532.649963, -584.040100, -635.609985, -687.359924, -739.290039, -791.399963, -843.689941, -490.079987, -542.010010, -594.119995, -646.410034, -698.880005, -751.529968, -804.359985, -857.369995, -498.720032, -551.369995, -604.200012, -657.210022, -710.400024, -763.770081, -817.319946, -871.050049, -507.359955, -560.729919, -614.280029, -668.010010, -721.919983, -776.010010, -830.280029, -884.730042, -515.999939, -570.089966, -624.360046, -678.809937, -733.440002, + -788.250000, -843.239990, -898.410034, -524.639954, -579.449951, -634.440002, -689.609985, -744.960022, -800.489990, -856.200012, -912.090027, -533.280029, -588.810059, -644.520081, -700.409973, -756.480042, -812.730103, -869.159912, -925.769958, -505.920013, -543.420044, -581.040039, -618.780029, -656.640015, -694.620056, -732.719971, -770.940002, -447.359985, -471.559998, -495.840027, -520.200012, -544.640015, -569.159973, -593.760010, -618.440002, -815.359985, -852.140015, -889.040039, -926.059937, -963.200073, -1000.460022, -1037.839966, -1075.339966, -826.879944, -864.139954, -901.519958, -939.019958, -976.640076, -1014.379944, -1052.239990, -1090.219971, -838.400024, -876.140015, -913.999939, -951.979919, -990.080017, -1028.299927, -1066.640015, -1105.099976, -849.919983, -888.140015, -926.479980, -964.939941, -1003.520081, -1042.219971, -1081.040039, -1119.979980, -861.440063, -900.140015, -938.960022,-977.899963, -1016.960022, -1056.140015, -1095.440063, -1134.859985, -872.960022, -912.140015, -951.439941, -990.859985, -1030.400024, -1070.060059, -1109.839844, -1149.739990, -884.479980, -924.140015, -963.919922, -1003.819946, -1043.839966, -1083.979980, -1124.239990, -1164.619995, -896.000000, -936.140015, -976.399963, -1016.780029, -1057.280029, -1097.899902, -1138.640015, -1179.500122, -705.919983, -733.000000, -760.159912, -787.400024, -814.719971, -842.119995, -869.599976, -897.160034}, nd4j::DataType::FLOAT32); + + NDArray expGradW('c', {kH, kW, iC, mC},{-104306.421875, -104786.734375, -105268.687500, -105752.250000, -106237.421875, -106724.242188, -107212.671875, + -107702.734375, -116289.593750, -116823.296875, -117358.781250, -117896.109375, -118435.210938, -118976.109375, -119518.796875, -120063.296875, -104824.789062, + -105305.117188, -105787.070312, -106270.640625, -106755.843750, -107242.640625, -107731.078125, -108221.117188, -126744.000000, -127277.710938, -127813.187500, + -128350.484375, -128889.601562, -129430.515625, -129973.210938, -130517.703125, -140944.000000, -141536.984375, -142131.984375, -142729.000000, -143328.000000, + -143929.015625, -144532.000000, -145137.000000, -126744.000000, -127277.710938, -127813.187500, -128350.484375, -128889.601562, -129430.515625, -129973.210938, -130517.703125, -104824.789062, -105305.117188, -105787.070312, -106270.640625, -106755.843750, -107242.640625, -107731.078125, -108221.117188, -116289.593750, -116823.296875, -117358.781250, -117896.109375, -118435.210938, -118976.109375, -119518.796875, -120063.296875, -104306.421875, -104786.734375, -105268.687500, -105752.250000, -106237.421875, -106724.242188, -107212.671875, -107702.734375}, nd4j::DataType::FLOAT32); + + NDArray expGradB('c', {oC}, {-2960., -2970., -2980., -2990., -3000., -3010., -3020., -3030.}, nd4j::DataType::FLOAT32); + + nd4j::ops::depthwise_conv2d_bp op; + ResultSet* results = op.execute({&input, &weights, &bias, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); + NDArray* gradI = results->at(0); + NDArray* gradW = results->at(1); + NDArray* gradB = results->at(2); + + ASSERT_EQ(Status::OK(), results->status()); + + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + ASSERT_TRUE(expGradW.isSameShape(gradW)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); + + ASSERT_TRUE(expGradB.isSameShape(gradB)); + ASSERT_TRUE(expGradB.equalsTo(gradB)); + + delete results; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, depthwise_conv2d_bp_test5) { + + int bS=1, iH=10,iW=10, iC=8,mC=1, kH=3,kW=3, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; + int oC=iC*mC; + int oH=10,oW=10; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + + NDArray input('c', {bS, iC, iH, iW}, nd4j::DataType::FLOAT32); + NDArray weights('c', {kH, kW, iC, mC}, nd4j::DataType::FLOAT32); + NDArray gradO('c', {bS, oC, oH, oW}, nd4j::DataType::FLOAT32); + NDArray bias('c', {oC}, nd4j::DataType::FLOAT32); + + input.linspace(-10, 0.1); + weights.linspace(-2, 0.1); + gradO.linspace(10, -0.1); + + + NDArray expGradI('c', {bS, iC, iH, iW}, {-12.639999, 3.920004, 3.920000, 3.920000, 3.920002, 3.920000, 3.920000, 3.919998, 3.919998, 16.319998, 52.680004, 111.000015, 109.919991, 108.840004, 107.760002, 106.680008, 105.600006, 104.519997, 103.440018, 87.960007, 47.880001, 100.200005, 99.119995, 98.040001, 96.959999, 95.879990, 94.799995, 93.720001, 92.639999, 78.360001, 43.079998, 89.399994, 88.320007, 87.240005, 86.159996, 85.079994, 84.000000, 82.919998, 81.840004, 68.759995, 38.279999, 78.600006, 77.519997, 76.440010, 75.360001, 74.279999, 73.200005, 72.120003, 71.040001, 59.160004, 33.480000, 67.799995, 66.720009, 65.639999, 64.559998, 63.480000, 62.399994, 61.320007, 60.240002, 49.559998, 28.680004, 57.000004, 55.919998, 54.839993, 53.759998, 52.680000, 51.600002, 50.519997, 49.440002, 39.959999, 23.880001, 46.200001, 45.120003, 44.039997, 42.959999, 41.880001, 40.799999, 39.719994, 38.639999, 30.360001, 19.079998, 35.400002, 34.320000, 33.239998, 32.159996, 31.080000, 29.999998, 28.919998, 27.840000, 20.759998, 14.079999, 24.080000, 22.639997, 21.200001, 19.759998, 18.320002, 16.880001, 15.440001, 14.000000, 9.759999, 3.140000, 3.560000, 3.500000, 3.440000, 3.380000, 3.320000, 3.260000, 3.200000, 3.140000, -0.220000, 4.050000, 2.010000, 0.840000, -0.330000, -1.499999, -2.670000, -3.840000, -5.010000, -6.179998, -9.150000, -1.350000, -9.690001, -10.859999, -12.029998, -13.200001, -14.370001, -15.539999, -16.710001, -17.879999, -19.349998, -6.750000, -21.389997, -22.560003, -23.730003, -24.900002, -26.069998, -27.239998, -28.410007, -29.580002, -29.550003, -12.150001, -33.089996, -34.260002, -35.430000, -36.600002, -37.770000, -38.939995, -40.110001, -41.280003, -39.749996, -17.550003, -44.790005, -45.959991, -47.129993, -48.300003, -49.470001, -50.640003, -51.809990, -52.979996, -49.950001, -22.949999, -56.490005, -57.660000, -58.829998, -60.000000, -61.170002, -62.340004, -63.510002, -64.680000, + -60.149994, -28.349998, -68.189987, -69.360001, -70.529999, -71.700005, -72.870010, -74.039993, -75.209999, -76.379990, -70.349998, -33.749996, -79.889999, -81.059990, -82.229988, -83.399994, -84.570007, -85.740005, -86.910004, -88.079994, -80.549995, -69.340004, -125.080002, -126.580002, -128.080002, -129.580002, -131.080002, -132.580002, -134.080002, -135.580002, -105.979996, 10.919998, -8.799997, -8.919998, -9.040003, -9.160004, -9.279999, -9.400002, -9.520002, -9.640003, -24.760000, -56.580009, -124.980003, -126.240005, -127.499992, -128.759995, -130.020020, -131.279999, -132.540009, -133.800003, -118.260002, -62.580009, -137.580002, -138.840012, -140.099991, -141.360001, -142.620010, -143.879974, -145.139999, -146.399994, -129.060013, -68.580002, -150.179993, -151.439987, -152.699997, -153.959991, -155.219986, -156.480011, -157.740005, -159.000000, -139.860001, -74.579994, -162.779999, -164.040024, -165.300003, -166.560028, -167.819977, -169.080002, -170.339996, -171.599991, -150.660004, -80.580002, -175.379990, -176.639999, -177.899994, -179.160019, -180.419998, -181.679993, -182.940002, -184.199997, -161.459991, -86.580002, -187.979996, -189.240005, -190.499985, -191.759995, -193.020020, -194.279999, -195.540024, -196.800018, -172.260010, -92.580002, -200.579987, -201.839981, -203.100006, -204.359970, -205.620010, -206.880005, -208.139999, -209.399994, -183.060013, -98.580002, -213.180023, -214.440002, -215.700012, -216.959991, -218.220001, -219.480011, -220.739975, -222.000000, -193.860001, -160.760010, -286.239990, -287.799988, -289.360016, -290.920013, -292.480011, -294.040009, -295.599976, -297.160004, -229.719986, 10.700003, -33.160004, -33.339996, -33.519993, -33.700001, + -33.879997, -34.059994, -34.239994, -34.419994, -57.299995, -129.209991, -269.969971, -271.319977, -272.670044, -274.019989, -275.369995, -276.720001, -278.070007, -279.420013, -239.369980, -135.809998, -283.470001, -284.820007, -286.169983, -287.520020, -288.869995, -290.220001, -291.570038, -292.919983, -250.770004, -142.410004, -296.969971, -298.320007, -299.669983, -301.020020, -302.369995, -303.719971, -305.070007, -306.419983, -262.169983, -149.009995, -310.470001, -311.820007, -313.170013, -314.519989, -315.869995, -317.220001, -318.570007, -319.919983, -273.570007, -155.610016, -323.969971, -325.320038, -326.669983, -328.020020, -329.369965, -330.719971, -332.070007, -333.419983, -284.970001, -162.209991, -337.469971, -338.820007, -340.169983, -341.519958, -342.869995, -344.220001, -345.570007, -346.920013, -296.369995, -168.809998, -350.970001, -352.320007, -353.669983, -355.019989, -356.369995, -357.719971, -359.070038, -360.419983, -307.769989, -175.410004, -364.469971, -365.820007, -367.169983, -368.520020, -369.869995, -371.219971, -372.570007, -373.919983, -319.169983, -260.179993, -459.399994, -461.019958, -462.639984, -464.260010, -465.880005, -467.500000, -469.119995, -470.739990, -361.459991, 2.480003, -69.520004, -69.760025, -70.000000, -70.239990, -70.479996, -70.720001, -70.960007, -71.200005, -97.839996, -213.840012, -432.960022, -434.400055, -435.840027, -437.279999, -438.720001, -440.160065, -441.599976, -443.040039, -372.480011, -221.040009, -447.360016, -448.800018, -450.239990, -451.679993, -453.119995, -454.559967, -456.000061, -457.440033, -384.480011, -228.239990, -461.759979, -463.200012, -464.639984, -466.079956, -467.520081, -468.960052, -470.399963, -471.839996, -396.479980, -235.440002, -476.159912, + -477.600006, -479.040039, -480.479980, -481.919952, -483.360046, -484.800079, -486.239990, -408.480042, -242.639999, -490.559967, -491.999969, -493.440063, -494.880035, -496.319946, -497.759979, -499.200012, -500.639984, -420.480011, -249.840012, -504.960052, -506.399963, -507.839996, -509.280029, -510.720001, -512.159973, -513.599976, -515.040039, -432.480011, -257.040009, -519.360046, -520.800049, -522.239990, -523.680054, -525.120056, -526.559998, -527.999939, -529.440002, -444.480011, -264.239990, -533.760010, -535.200012, -536.640015, -538.079956, -539.520020, -540.960022, -542.399963, -543.839966, -456.479980, -367.599976, -644.559998, -646.239929, -647.920044, -649.599976, -651.280029, -652.960022, -654.640076, -656.320007, -501.200043, -13.740002, -117.880005, -118.179993, -118.479996, -118.780014, -119.080002, -119.379990, -119.680008, -119.979996, -146.379990, -310.470001, -613.950012, -615.479980, -617.010071, -618.539978, -620.069946, -621.599976, -623.130005, -624.660034, -517.589966, -318.269958, -629.250000, -630.779968, -632.309937, -633.840027, -635.369995, -636.899902, -638.429993, -639.959961, -530.190063, -326.070038, -644.550049, -646.079956, -647.609985, -649.140015, -650.669922, -652.200012, -653.729980, -655.260010, -542.789978, -333.870026, -659.849976, -661.380005, -662.910034, -664.439941, -665.970093, -667.500000, -669.029968, -670.559937, -555.390015, -341.669983, -675.149902, -676.679993, -678.209961, -679.740051, -681.270020, -682.800049, -684.329956, -685.859985, -567.989990, -349.470001, -690.450012, -691.979980, -693.510010, -695.039978, -696.569946, -698.099976, -699.630005, -701.160034, -580.589966, -357.269958, -705.750000, -707.279968, -708.809937, -710.340027, -711.869995, -713.399902, -714.929993, -716.459961, -593.190002, -365.070038, -721.050049, -722.579956, -724.109985, -725.640015, -727.169922, -728.700012, + -730.229980, -731.760010, -605.789978, -483.019958, -841.719971, -843.460022, -845.200073, -846.939941, -848.680054, -850.419983, -852.159973, -853.899963, -648.940002, -37.960014, -178.240021, -178.599976, -178.959991, -179.320007, -179.679993, -180.039978, -180.399994, -180.759964, -202.919983, -419.099915, -812.939941, -814.559937, -816.179993, -817.800049, -819.419922, -821.040039, -822.660034, -824.279968, -674.699951, -427.500031, -829.140015, -830.759949, -832.380005, -833.999939, -835.619995, -837.240051, -838.859924, -840.479980, -687.899963, -435.899994, -845.339966, -846.959961, -848.579956, -850.200012, -851.819885, -853.439941, -855.059937, -856.679993, -701.100037, -444.299927, -861.540039, -863.160034, -864.779968, -866.399963, -868.020020, -869.640015, -871.259949, -872.880005, -714.299988, -452.700012, -877.740051, -879.359924, -880.979980, -882.599915, -884.219971, -885.839966, -887.459961, -889.079956, -727.500000, -461.099915, -893.939941, -895.559937, -897.179993, -898.800049, -900.419922, -902.040039, -903.660034, -905.279968, -740.700012, -469.499969, -910.140015, -911.759949, -913.380005, -914.999939, -916.620056, -918.239990, -919.860046, -921.479919, -753.899963, -477.899902, -926.339905, -927.959961, -929.579956, -931.200012, -932.819946, -934.439880, -936.059937, -937.679932, -767.100037, -606.439941, -1050.880005, -1052.680054, -1054.479980, -1056.280029, -1058.079956, -1059.880005, -1061.679932, -1063.479980, -804.679993, -70.180008, -250.600006, -251.019958, -251.440033, -251.860001, -252.280029, -252.700043, -253.120026, -253.540039, -267.459991, -539.730042, -1029.929932, -1031.640137, -1033.350098, -1035.060059, -1036.770020, -1038.479980, -1040.190063, -1041.900024, -843.809998, -548.729980, -1047.030029, -1048.740112, -1050.449829, -1052.160034, -1053.870117, -1055.580078, -1057.289917, -1059.000122, -857.609985, -557.729980, + -1064.130005, -1065.840088, -1067.550049, -1069.260010, -1070.969849, -1072.679932, -1074.390137, -1076.100098, -871.410034, -566.729980, -1081.229980, -1082.940063, -1084.650024, -1086.359985, -1088.069946, -1089.780029, -1091.489990, -1093.199951, -885.210022, -575.729980, -1098.329956, -1100.040039, -1101.750122, -1103.460205, -1105.170166, -1106.879883, -1108.589966, -1110.300049, -899.010071, -584.730042, -1115.429932, -1117.140137, -1118.850098, -1120.560059, -1122.270020, -1123.979980, -1125.689941, -1127.400024, -912.810059, -593.730042, -1132.530029, -1134.240234, -1135.949951, -1137.659912, -1139.370117, -1141.079956, -1142.790039, -1144.500122, -926.610046, -602.730042, -1149.629883, -1151.339966, -1153.050049, -1154.760132, -1156.469971, -1158.179810, -1159.890137, -1161.600098, -940.410034, -737.859985, -1272.040039, -1273.899902, -1275.760010, -1277.619995, -1279.479980, -1281.340088, -1283.200195, -1285.060059, -968.420044}, nd4j::DataType::FLOAT32); + + NDArray expGradW('c', {kH, kW, iC, mC}, {-2586.600586, -2505.600098, -18624.595703, -50943.605469, -99462.601562, -164181.609375, -245100.609375, -342219.625000, + -2880.149902, -2790.150146, -20700.152344, -56610.148438, -110520.156250, -182430.156250, -272340.156250, -380250.125000, -2594.701416, -2513.699951, + -18632.699219, -50951.695312, -99470.695312, -164189.703125, -245108.687500, -342227.750000, -3043.501465, -2953.500244, -20863.500000, -56773.492188, + -110683.515625, -182593.515625, -272503.531250, -380413.562500, -3383.499756, -3283.500000, -23183.501953, -63083.500000, -122983.500000, -202883.515625, + -302783.531250, -422683.468750, -3043.501465, -2953.500244, -20863.500000, -56773.492188, -110683.515625, -182593.515625, -272503.531250, -380413.562500, + -2594.701416, -2513.699951, -18632.699219, -50951.695312, -99470.695312, -164189.703125, -245108.687500, -342227.750000, -2880.149902, -2790.150146, -20700.152344, -56610.148438, -110520.156250, -182430.156250, -272340.156250, -380250.125000, -2586.600586, -2505.600098, -18624.595703, -50943.605469, -99462.601562, -164181.609375, -245100.609375, -342219.625000}, nd4j::DataType::FLOAT32); + + NDArray expGradB('c', {oC}, {505., -495., -1495., -2495., -3495., -4494.999512, -5495., -6495.}, nd4j::DataType::FLOAT32); + + nd4j::ops::depthwise_conv2d_bp op; + ResultSet* results = op.execute({&input, &weights, &bias, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); + NDArray* gradI = results->at(0); + NDArray* gradW = results->at(1); + NDArray* gradB = results->at(2); + + ASSERT_EQ(Status::OK(), results->status()); + + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + ASSERT_TRUE(expGradW.isSameShape(gradW)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); + + ASSERT_TRUE(expGradB.isSameShape(gradB)); + ASSERT_TRUE(expGradB.equalsTo(gradB)); + + delete results; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, depthwise_conv2d_bp_test6) { + + int bS=2, iH=4,iW=3, iC=2,mC=1, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; + int oH=2,oW=2; + int oC=iC*mC; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + + auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); + auto weights = NDArrayFactory::create('c', {kH, kW, iC, mC}); + auto bias = NDArrayFactory::create('c', {oC}, {3,4}); + auto gradO = NDArrayFactory::create('c', {bS, oC, oH, oW}); + + auto expGradI = NDArrayFactory::create('c', {bS, iC, iH, iW},{0.001, 0.005, 0.006, 0.008, 0.03, 0.026, 0.024, 0.07, 0.05, 0.027, 0.069, 0.044, 0.01, + 0.032, 0.024, 0.044, 0.12, 0.08, 0.092, 0.224, 0.136, 0.07, 0.164, 0.096, 0.009, 0.037, 0.03, 0.056, 0.158, 0.106, 0.136, + 0.326, 0.194, 0.099, 0.229, 0.132, 0.026, 0.08, 0.056, 0.108, 0.28, 0.176, 0.22, 0.512, 0.296, 0.15, 0.34, 0.192}); + + auto expGradW = NDArrayFactory::create('c', {kH, kW, iC, mC}, {1.04, 1.68, 1.04, 1.68, 1.04, 1.68, 1.04, 1.68, 1.04, 1.68, 1.04, 1.68}); input = 2.; weights.linspace(0.1, 0.1); diff --git a/libnd4j/tests_cpu/layers_tests/ConvolutionTests2.cpp b/libnd4j/tests_cpu/layers_tests/ConvolutionTests2.cpp index de3cdcdba..a16d9cfbd 100644 --- a/libnd4j/tests_cpu/layers_tests/ConvolutionTests2.cpp +++ b/libnd4j/tests_cpu/layers_tests/ConvolutionTests2.cpp @@ -110,13 +110,13 @@ TYPED_TEST(TypedConvolutionTests2, deconv2d_tf_test2) { auto input = NDArrayFactory::create('c', {bS, oH, oW, oC}); auto weights = NDArrayFactory::create('c', {kH, kW, iC, oC}); auto outShape = NDArrayFactory::create('c', {4}, {static_cast(bS), static_cast(iH), static_cast(iW), static_cast(iC)}); - auto exp = NDArrayFactory::create('c', {bS, iH, iW, iC}, {2.75f, 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, - 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, - 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, - 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, - 2.75f, 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, - 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, - 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, + auto exp = NDArrayFactory::create('c', {bS, iH, iW, iC}, {2.75f, 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, + 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, + 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, + 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, + 2.75f, 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, + 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, + 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f}); input = 0.5; weights.linspace(0.1, 0.1); @@ -240,10 +240,10 @@ TYPED_TEST(TypedConvolutionTests2, sconv2d_bp_1) { weightsD.permutei({2,3,1,0}); weightsP.permutei({2,3,1,0}); - input.applyScalar(scalar::Divide, 100.0); - weightsD.applyScalar(scalar::Divide, 100.0); - weightsP.applyScalar(scalar::Divide, 100.0); - epsilonNext.applyScalar(scalar::Divide, 100.0); + input.applyScalar(scalar::Divide, 100.0, input); + weightsD.applyScalar(scalar::Divide, 100.0, weightsD); + weightsP.applyScalar(scalar::Divide, 100.0, weightsP); + epsilonNext.applyScalar(scalar::Divide, 100.0, epsilonNext); nd4j::ops::sconv2d_bp op; auto resultBP = op.execute({&input, &epsilonNext, &weightsD, &weightsP },{}, {5, 5, 1, 1, 0, 0, 1, 1, 0}, {}); @@ -1132,11 +1132,11 @@ TYPED_TEST(TypedConvolutionTests2, avgpool3d_test2) { int dataFormat = 1; // 1-NDHWC, 0-NCDHW auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); - auto expected = NDArrayFactory::create('c', {bS, oD, oH, oW, iC}, { 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 29.5f, 30.5f, 31.5f, 29.5f, 30.5f, 31.5f, 32.5f, 33.5f, 34.5f, 34.f, 35.f, 36.f, 38.5f, 39.5f, 40.5f, 41.5f, 42.5f, 43.5f, 43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 47.5f, 48.5f, 49.5f, - 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 65.5f, 66.5f, 67.5f, 65.5f, 66.5f, 67.5f, 68.5f, 69.5f, 70.5f, 70.f, 71.f, 72.f, 74.5f, 75.5f, 76.5f, 77.5f, 78.5f, 79.5f, 79.f, 80.f, 81.f, 79.f, 80.f, 81.f, 82.f, 83.f, 84.f, 83.5f, 84.5f, 85.5f, - 79.f, 80.f, 81.f, 82.f, 83.f, 84.f, 83.5f, 84.5f, 85.5f, 83.5f, 84.5f, 85.5f, 86.5f, 87.5f, 88.5f, 88.f, 89.f, 90.f, 92.5f, 93.5f, 94.5f, 95.5f, 96.5f, 97.5f, 97.f, 98.f, 99.f, 97.f, 98.f, 99.f, 100.f, 101.f, 102.f, 101.5f, 102.5f, 103.5f, - 133.f, 134.f, 135.f, 136.f, 137.f, 138.f, 137.5f, 138.5f, 139.5f, 137.5f, 138.5f, 139.5f, 140.5f, 141.5f, 142.5f, 142.f, 143.f, 144.f, 146.5f, 147.5f, 148.5f, 149.5f, 150.5f, 151.5f, 151.f, 152.f, 153.f, 151.f, 152.f, 153.f, 154.f, 155.f, 156.f, 155.5f, 156.5f, 157.5f, - 169.f, 170.f, 171.f, 172.f, 173.f, 174.f, 173.5f, 174.5f, 175.5f, 173.5f, 174.5f, 175.5f, 176.5f, 177.5f, 178.5f, 178.f, 179.f, 180.f, 182.5f, 183.5f, 184.5f, 185.5f, 186.5f, 187.5f, 187.f, 188.f, 189.f, 187.f, 188.f, 189.f, 190.f, 191.f, 192.f, 191.5f, 192.5f, 193.5f, + auto expected = NDArrayFactory::create('c', {bS, oD, oH, oW, iC}, { 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 29.5f, 30.5f, 31.5f, 29.5f, 30.5f, 31.5f, 32.5f, 33.5f, 34.5f, 34.f, 35.f, 36.f, 38.5f, 39.5f, 40.5f, 41.5f, 42.5f, 43.5f, 43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 47.5f, 48.5f, 49.5f, + 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 65.5f, 66.5f, 67.5f, 65.5f, 66.5f, 67.5f, 68.5f, 69.5f, 70.5f, 70.f, 71.f, 72.f, 74.5f, 75.5f, 76.5f, 77.5f, 78.5f, 79.5f, 79.f, 80.f, 81.f, 79.f, 80.f, 81.f, 82.f, 83.f, 84.f, 83.5f, 84.5f, 85.5f, + 79.f, 80.f, 81.f, 82.f, 83.f, 84.f, 83.5f, 84.5f, 85.5f, 83.5f, 84.5f, 85.5f, 86.5f, 87.5f, 88.5f, 88.f, 89.f, 90.f, 92.5f, 93.5f, 94.5f, 95.5f, 96.5f, 97.5f, 97.f, 98.f, 99.f, 97.f, 98.f, 99.f, 100.f, 101.f, 102.f, 101.5f, 102.5f, 103.5f, + 133.f, 134.f, 135.f, 136.f, 137.f, 138.f, 137.5f, 138.5f, 139.5f, 137.5f, 138.5f, 139.5f, 140.5f, 141.5f, 142.5f, 142.f, 143.f, 144.f, 146.5f, 147.5f, 148.5f, 149.5f, 150.5f, 151.5f, 151.f, 152.f, 153.f, 151.f, 152.f, 153.f, 154.f, 155.f, 156.f, 155.5f, 156.5f, 157.5f, + 169.f, 170.f, 171.f, 172.f, 173.f, 174.f, 173.5f, 174.5f, 175.5f, 173.5f, 174.5f, 175.5f, 176.5f, 177.5f, 178.5f, 178.f, 179.f, 180.f, 182.5f, 183.5f, 184.5f, 185.5f, 186.5f, 187.5f, 187.f, 188.f, 189.f, 187.f, 188.f, 189.f, 190.f, 191.f, 192.f, 191.5f, 192.5f, 193.5f, 187.f, 188.f, 189.f, 190.f, 191.f, 192.f, 191.5f, 192.5f, 193.5f, 191.5f, 192.5f, 193.5f, 194.5f, 195.5f, 196.5f, 196.f, 197.f, 198.f, 200.5f, 201.5f, 202.5f, 203.5f, 204.5f, 205.5f, 205.f, 206.f, 207.f, 205.f, 206.f, 207.f, 208.f, 209.f, 210.f, 209.5f, 210.5f, 211.5f}); input.linspace(1.); @@ -1160,8 +1160,8 @@ TYPED_TEST(TypedConvolutionTests2, avgpool3d_test3) { int dataFormat = 1; // 1-NDHWC, 0-NCDHW auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); - auto expected = NDArrayFactory::create('c', {bS, oD, oH, oW, iC}, { 29.5f, 30.5f, 31.5f, 32.5f, 33.5f, 34.5f, 38.5f, 39.5f, 40.5f, 41.5f, 42.5f, 43.5f, 65.5f, 66.5f, 67.5f, 68.5f, 69.5f, 70.5f, - 74.5f, 75.5f, 76.5f, 77.5f, 78.5f, 79.5f, 137.5f, 138.5f, 139.5f, 140.5f, 141.5f, 142.5f, 146.5f, 147.5f, 148.5f, 149.5f, 150.5f, 151.5f, + auto expected = NDArrayFactory::create('c', {bS, oD, oH, oW, iC}, { 29.5f, 30.5f, 31.5f, 32.5f, 33.5f, 34.5f, 38.5f, 39.5f, 40.5f, 41.5f, 42.5f, 43.5f, 65.5f, 66.5f, 67.5f, 68.5f, 69.5f, 70.5f, + 74.5f, 75.5f, 76.5f, 77.5f, 78.5f, 79.5f, 137.5f, 138.5f, 139.5f, 140.5f, 141.5f, 142.5f, 146.5f, 147.5f, 148.5f, 149.5f, 150.5f, 151.5f, 173.5f, 174.5f, 175.5f, 176.5f, 177.5f, 178.5f, 182.5f, 183.5f, 184.5f, 185.5f, 186.5f, 187.5f}); input.linspace(1.); @@ -1185,23 +1185,23 @@ TYPED_TEST(TypedConvolutionTests2, avgpool3d_test4) { int dataFormat = 0; // 1-NDHWC, 0-NCDHW auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); - auto expected = NDArrayFactory::create('c', {bS, iC, oD, oH, oW},{0.416667f, 1.00f, 1.333333f, 0.75f, 1.00f, 2.25f, 2.75f, 1.50f, 1.75f, 3.75f, 4.25f, 2.25f, 1.416667f, 3.00f, 3.333333f, 1.75f, 2.833333f, 6.00f, 6.666667f, 3.50f, 5.00f, 10.50f, 11.50f, 6.00f, 6.50f, - 13.50f, 14.50f, 7.50f, 4.833333f, 10.00f, 10.666667f, 5.50f, 6.833333f, 14.00f, 14.666667f, 7.50f, 11.00f, 22.50f, 23.50f, 12.00f, 12.50f, 25.50f, 26.50f, 13.50f, 8.833333f, 18.00f, 18.666666f, 9.50f, - 4.416667f, 9.00f, 9.333333f, 4.75f, 7.00f, 14.25f, 14.75f, 7.50f, 7.75f, 15.75f, 16.25f, 8.25f, 5.416667f, 11.00f, 11.333333f, 5.75f, 6.416667f, 13.00f, 13.333333f, 6.75f, 10.00f, 20.25f, 20.75f, - 10.50f, 10.75f, 21.75f, 22.25f, 11.25f, 7.416667f, 15.00f, 15.333333f, 7.75f, 14.833333f, 30.00f, 30.666666f, 15.50f, 23.00f, 46.50f, 47.50f, 24.00f, 24.50f, 49.50f, 50.50f, 25.50f, 16.833334f, - 34.00f, 34.666668f, 17.50f, 18.833334f, 38.00f, 38.666668f, 19.50f, 29.00f, 58.50f, 59.50f, 30.00f, 30.50f, 61.50f, 62.50f, 31.50f, 20.833334f, 42.00f, 42.666668f, 21.50f, 10.416667f, 21.00f, - 21.333334f, 10.75f, 16.00f, 32.25f, 32.75f, 16.50f, 16.75f, 33.75f, 34.25f, 17.25f, 11.416667f, 23.00f, 23.333334f, 11.75f, 12.416667f, 25.00f, 25.333334f, 12.75f, 19.00f, 38.25f, 38.75f, 19.50f, - 19.75f, 39.75f, 40.25f, 20.25f, 13.416667f, 27.00f, 27.333334f, 13.75f, 26.833334f, 54.00f, 54.666668f, 27.50f, 41.00f, 82.50f, 83.50f, 42.00f, 42.50f, 85.50f, 86.50f, 43.50f, 28.833334f, 58.00f, - 58.666668f, 29.50f, 30.833334f, 62.00f, 62.666668f, 31.50f, 47.00f, 94.50f, 95.50f, 48.00f, 48.50f, 97.50f, 98.50f, 49.50f, 32.833332f, 66.00f, 66.666664f, 33.50f, 16.416666f, 33.00f, 33.333332f, - 16.75f, 25.00f, 50.25f, 50.75f, 25.50f, 25.75f, 51.75f, 52.25f, 26.25f, 17.416666f, 35.00f, 35.333332f, 17.75f, 18.416666f, 37.00f, 37.333332f, 18.75f, 28.00f, 56.25f, 56.75f, 28.50f, 28.75f, - 57.75f, 58.25f, 29.25f, 19.416666f, 39.00f, 39.333332f, 19.75f, 38.833332f, 78.00f, 78.666664f, 39.50f, 59.00f, 118.50f, 119.50f, 60.00f, 60.50f, 121.50f, 122.50f, 61.50f, 40.833332f, 82.00f, - 82.666664f, 41.50f, 42.833332f, 86.00f, 86.666664f, 43.50f, 65.00f, 130.50f, 131.50f, 66.00f, 66.50f, 133.50f, 134.50f, 67.50f, 44.833332f, 90.00f, 90.666664f, 45.50f, 22.416666f, 45.00f, - 45.333332f, 22.75f, 34.00f, 68.25f, 68.75f, 34.50f, 34.75f, 69.75f, 70.25f, 35.25f, 23.416666f, 47.00f, 47.333332f, 23.75f, 24.416666f, 49.00f, 49.333332f, 24.75f, 37.00f, 74.25f, 74.75f, - 37.50f, 37.75f, 75.75f, 76.25f, 38.25f, 25.416666f, 51.00f, 51.333332f, 25.75f, 50.833332f, 102.00f, 102.666664f, 51.50f, 77.00f, 154.50f, 155.50f, 78.00f, 78.50f, 157.50f, 158.50f, 79.50f, - 52.833332f, 106.00f, 106.666664f, 53.50f, 54.833332f, 110.00f, 110.666664f, 55.50f, 83.00f, 166.50f, 167.50f, 84.00f, 84.50f, 169.50f, 170.50f, 85.50f, 56.833332f, 114.00f, 114.666664f, - 57.50f, 28.416666f, 57.00f, 57.333332f, 28.75f, 43.00f, 86.25f, 86.75f, 43.50f, 43.75f, 87.75f, 88.25f, 44.25f, 29.416666f, 59.00f, 59.333332f, 29.75f, 30.416666f, 61.00f, 61.333332f, 30.75f, - 46.00f, 92.25f, 92.75f, 46.50f, 46.75f, 93.75f, 94.25f, 47.25f, 31.416666f, 63.00f, 63.333332f, 31.75f, 62.833332f, 126.00f, 126.666664f, 63.50f, 95.00f, 190.50f, 191.50f, 96.00f, 96.50f, - 193.50f, 194.50f, 97.50f, 64.833336f, 130.00f, 130.666672f, 65.50f, 66.833336f, 134.00f, 134.666672f, 67.50f, 101.00f, 202.50f, 203.50f, 102.00f, 102.50f, 205.50f, 206.50f, 103.50f, + auto expected = NDArrayFactory::create('c', {bS, iC, oD, oH, oW},{0.416667f, 1.00f, 1.333333f, 0.75f, 1.00f, 2.25f, 2.75f, 1.50f, 1.75f, 3.75f, 4.25f, 2.25f, 1.416667f, 3.00f, 3.333333f, 1.75f, 2.833333f, 6.00f, 6.666667f, 3.50f, 5.00f, 10.50f, 11.50f, 6.00f, 6.50f, + 13.50f, 14.50f, 7.50f, 4.833333f, 10.00f, 10.666667f, 5.50f, 6.833333f, 14.00f, 14.666667f, 7.50f, 11.00f, 22.50f, 23.50f, 12.00f, 12.50f, 25.50f, 26.50f, 13.50f, 8.833333f, 18.00f, 18.666666f, 9.50f, + 4.416667f, 9.00f, 9.333333f, 4.75f, 7.00f, 14.25f, 14.75f, 7.50f, 7.75f, 15.75f, 16.25f, 8.25f, 5.416667f, 11.00f, 11.333333f, 5.75f, 6.416667f, 13.00f, 13.333333f, 6.75f, 10.00f, 20.25f, 20.75f, + 10.50f, 10.75f, 21.75f, 22.25f, 11.25f, 7.416667f, 15.00f, 15.333333f, 7.75f, 14.833333f, 30.00f, 30.666666f, 15.50f, 23.00f, 46.50f, 47.50f, 24.00f, 24.50f, 49.50f, 50.50f, 25.50f, 16.833334f, + 34.00f, 34.666668f, 17.50f, 18.833334f, 38.00f, 38.666668f, 19.50f, 29.00f, 58.50f, 59.50f, 30.00f, 30.50f, 61.50f, 62.50f, 31.50f, 20.833334f, 42.00f, 42.666668f, 21.50f, 10.416667f, 21.00f, + 21.333334f, 10.75f, 16.00f, 32.25f, 32.75f, 16.50f, 16.75f, 33.75f, 34.25f, 17.25f, 11.416667f, 23.00f, 23.333334f, 11.75f, 12.416667f, 25.00f, 25.333334f, 12.75f, 19.00f, 38.25f, 38.75f, 19.50f, + 19.75f, 39.75f, 40.25f, 20.25f, 13.416667f, 27.00f, 27.333334f, 13.75f, 26.833334f, 54.00f, 54.666668f, 27.50f, 41.00f, 82.50f, 83.50f, 42.00f, 42.50f, 85.50f, 86.50f, 43.50f, 28.833334f, 58.00f, + 58.666668f, 29.50f, 30.833334f, 62.00f, 62.666668f, 31.50f, 47.00f, 94.50f, 95.50f, 48.00f, 48.50f, 97.50f, 98.50f, 49.50f, 32.833332f, 66.00f, 66.666664f, 33.50f, 16.416666f, 33.00f, 33.333332f, + 16.75f, 25.00f, 50.25f, 50.75f, 25.50f, 25.75f, 51.75f, 52.25f, 26.25f, 17.416666f, 35.00f, 35.333332f, 17.75f, 18.416666f, 37.00f, 37.333332f, 18.75f, 28.00f, 56.25f, 56.75f, 28.50f, 28.75f, + 57.75f, 58.25f, 29.25f, 19.416666f, 39.00f, 39.333332f, 19.75f, 38.833332f, 78.00f, 78.666664f, 39.50f, 59.00f, 118.50f, 119.50f, 60.00f, 60.50f, 121.50f, 122.50f, 61.50f, 40.833332f, 82.00f, + 82.666664f, 41.50f, 42.833332f, 86.00f, 86.666664f, 43.50f, 65.00f, 130.50f, 131.50f, 66.00f, 66.50f, 133.50f, 134.50f, 67.50f, 44.833332f, 90.00f, 90.666664f, 45.50f, 22.416666f, 45.00f, + 45.333332f, 22.75f, 34.00f, 68.25f, 68.75f, 34.50f, 34.75f, 69.75f, 70.25f, 35.25f, 23.416666f, 47.00f, 47.333332f, 23.75f, 24.416666f, 49.00f, 49.333332f, 24.75f, 37.00f, 74.25f, 74.75f, + 37.50f, 37.75f, 75.75f, 76.25f, 38.25f, 25.416666f, 51.00f, 51.333332f, 25.75f, 50.833332f, 102.00f, 102.666664f, 51.50f, 77.00f, 154.50f, 155.50f, 78.00f, 78.50f, 157.50f, 158.50f, 79.50f, + 52.833332f, 106.00f, 106.666664f, 53.50f, 54.833332f, 110.00f, 110.666664f, 55.50f, 83.00f, 166.50f, 167.50f, 84.00f, 84.50f, 169.50f, 170.50f, 85.50f, 56.833332f, 114.00f, 114.666664f, + 57.50f, 28.416666f, 57.00f, 57.333332f, 28.75f, 43.00f, 86.25f, 86.75f, 43.50f, 43.75f, 87.75f, 88.25f, 44.25f, 29.416666f, 59.00f, 59.333332f, 29.75f, 30.416666f, 61.00f, 61.333332f, 30.75f, + 46.00f, 92.25f, 92.75f, 46.50f, 46.75f, 93.75f, 94.25f, 47.25f, 31.416666f, 63.00f, 63.333332f, 31.75f, 62.833332f, 126.00f, 126.666664f, 63.50f, 95.00f, 190.50f, 191.50f, 96.00f, 96.50f, + 193.50f, 194.50f, 97.50f, 64.833336f, 130.00f, 130.666672f, 65.50f, 66.833336f, 134.00f, 134.666672f, 67.50f, 101.00f, 202.50f, 203.50f, 102.00f, 102.50f, 205.50f, 206.50f, 103.50f, 68.833336f, 138.00f, 138.666672f, 69.50f, 34.416668f, 69.00f, 69.333336f, 34.75f, 52.00f, 104.25f, 104.75f, 52.50f, 52.75f, 105.75f, 106.25f, 53.25f, 35.416668f, 71.00f, 71.333336f, 35.75f}); input.linspace(1.); @@ -1225,7 +1225,7 @@ TYPED_TEST(TypedConvolutionTests2, maxpool3d_test1) { int dataFormat = 0; // 1-NDHWC, 0-NCDHW auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); - auto expected = NDArrayFactory::create('c', {bS, iC, oD, oH, oW}, {20.f, 21.f, 23.f, 24.f, 32.f, 33.f, 35.f, 36.f, 56.f, 57.f, 59.f, 60.f, 68.f, 69.f, 71.f, 72.f, 92.f, 93.f, 95.f, 96.f, 104.f, 105.f, 107.f, 108.f, + auto expected = NDArrayFactory::create('c', {bS, iC, oD, oH, oW}, {20.f, 21.f, 23.f, 24.f, 32.f, 33.f, 35.f, 36.f, 56.f, 57.f, 59.f, 60.f, 68.f, 69.f, 71.f, 72.f, 92.f, 93.f, 95.f, 96.f, 104.f, 105.f, 107.f, 108.f, 128.f, 129.f, 131.f, 132.f, 140.f, 141.f, 143.f, 144.f, 164.f, 165.f, 167.f, 168.f, 176.f, 177.f, 179.f, 180.f, 200.f, 201.f, 203.f, 204.f, 212.f, 213.f, 215.f, 216.f}); input.linspace(1.); @@ -1249,11 +1249,11 @@ TYPED_TEST(TypedConvolutionTests2, maxpool3d_test2) { int dataFormat = 1; // 1-NDHWC, 0-NCDHW auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); - auto expected = NDArrayFactory::create('c', {bS, oD, oH, oW, iC}, { 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f, 58.f, 59.f, 60.f, 61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f, - 85.f, 86.f, 87.f, 88.f, 89.f, 90.f, 88.f, 89.f, 90.f, 94.f, 95.f, 96.f, 97.f, 98.f, 99.f, 97.f, 98.f, 99.f, 103.f, 104.f, 105.f, 106.f, 107.f, 108.f, 106.f, 107.f, 108.f, 103.f, 104.f, 105.f, 106.f, 107.f, 108.f, 106.f, 107.f, 108.f, - 85.f, 86.f, 87.f, 88.f, 89.f, 90.f, 88.f, 89.f, 90.f, 94.f, 95.f, 96.f, 97.f, 98.f, 99.f, 97.f, 98.f, 99.f, 103.f, 104.f, 105.f, 106.f, 107.f, 108.f, 106.f, 107.f, 108.f, 103.f, 104.f, 105.f, 106.f, 107.f, 108.f, 106.f, 107.f, 108.f, - 157.f, 158.f, 159.f, 160.f, 161.f, 162.f, 160.f, 161.f, 162.f, 166.f, 167.f, 168.f, 169.f, 170.f, 171.f, 169.f, 170.f, 171.f, 175.f, 176.f, 177.f, 178.f, 179.f, 180.f, 178.f, 179.f, 180.f, 175.f, 176.f, 177.f, 178.f, 179.f, 180.f, 178.f, 179.f, 180.f, - 193.f, 194.f, 195.f, 196.f, 197.f, 198.f, 196.f, 197.f, 198.f, 202.f, 203.f, 204.f, 205.f, 206.f, 207.f, 205.f, 206.f, 207.f, 211.f, 212.f, 213.f, 214.f, 215.f, 216.f, 214.f, 215.f, 216.f, 211.f, 212.f, 213.f, 214.f, 215.f, 216.f, 214.f, 215.f, 216.f, + auto expected = NDArrayFactory::create('c', {bS, oD, oH, oW, iC}, { 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f, 58.f, 59.f, 60.f, 61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f, + 85.f, 86.f, 87.f, 88.f, 89.f, 90.f, 88.f, 89.f, 90.f, 94.f, 95.f, 96.f, 97.f, 98.f, 99.f, 97.f, 98.f, 99.f, 103.f, 104.f, 105.f, 106.f, 107.f, 108.f, 106.f, 107.f, 108.f, 103.f, 104.f, 105.f, 106.f, 107.f, 108.f, 106.f, 107.f, 108.f, + 85.f, 86.f, 87.f, 88.f, 89.f, 90.f, 88.f, 89.f, 90.f, 94.f, 95.f, 96.f, 97.f, 98.f, 99.f, 97.f, 98.f, 99.f, 103.f, 104.f, 105.f, 106.f, 107.f, 108.f, 106.f, 107.f, 108.f, 103.f, 104.f, 105.f, 106.f, 107.f, 108.f, 106.f, 107.f, 108.f, + 157.f, 158.f, 159.f, 160.f, 161.f, 162.f, 160.f, 161.f, 162.f, 166.f, 167.f, 168.f, 169.f, 170.f, 171.f, 169.f, 170.f, 171.f, 175.f, 176.f, 177.f, 178.f, 179.f, 180.f, 178.f, 179.f, 180.f, 175.f, 176.f, 177.f, 178.f, 179.f, 180.f, 178.f, 179.f, 180.f, + 193.f, 194.f, 195.f, 196.f, 197.f, 198.f, 196.f, 197.f, 198.f, 202.f, 203.f, 204.f, 205.f, 206.f, 207.f, 205.f, 206.f, 207.f, 211.f, 212.f, 213.f, 214.f, 215.f, 216.f, 214.f, 215.f, 216.f, 211.f, 212.f, 213.f, 214.f, 215.f, 216.f, 214.f, 215.f, 216.f, 193.f, 194.f, 195.f, 196.f, 197.f, 198.f, 196.f, 197.f, 198.f, 202.f, 203.f, 204.f, 205.f, 206.f, 207.f, 205.f, 206.f, 207.f, 211.f, 212.f, 213.f, 214.f, 215.f, 216.f, 214.f, 215.f, 216.f, 211.f, 212.f, 213.f, 214.f, 215.f, 216.f, 214.f, 215.f, 216.f}); input.linspace(1.); @@ -1277,7 +1277,7 @@ TYPED_TEST(TypedConvolutionTests2, maxpool3d_test3) { int dataFormat = 1; // 1-NDHWC, 0-NCDHW auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); - auto expected = NDArrayFactory::create('c', {bS, oD, oH, oW, iC}, {58.f, 59.f, 60.f, 61.f, 62.f, 63.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 94.f, 95.f, 96.f, 97.f, 98.f, 99.f, 103.f, 104.f, 105.f, 106.f, 107.f, 108.f, + auto expected = NDArrayFactory::create('c', {bS, oD, oH, oW, iC}, {58.f, 59.f, 60.f, 61.f, 62.f, 63.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 94.f, 95.f, 96.f, 97.f, 98.f, 99.f, 103.f, 104.f, 105.f, 106.f, 107.f, 108.f, 166.f, 167.f, 168.f, 169.f, 170.f, 171.f, 175.f, 176.f, 177.f, 178.f, 179.f, 180.f, 202.f, 203.f, 204.f, 205.f, 206.f, 207.f, 211.f, 212.f, 213.f, 214.f, 215.f, 216.f}); input.linspace(1.); @@ -1301,13 +1301,13 @@ TYPED_TEST(TypedConvolutionTests2, maxpool3d_test4) { int dataFormat = 0; // -NDHWC, 0-NCDHW auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); - auto expected = NDArrayFactory::create('c', {bS, iC, oD, oH, oW},{ 4.f, 5.f, 6.f, 6.f, 7.f, 8.f, 9.f, 9.f, 10.f, 11.f, 12.f, 12.f, 10.f, 11.f, 12.f, 12.f, 16.f, 17.f, 18.f, 18.f, 19.f, 20.f, 21.f, 21.f, 22.f, 23.f, 24.f, 24.f, 22.f, 23.f, 24.f, 24.f, 28.f, 29.f, 30.f, 30.f, 31.f, 32.f, 33.f, 33.f, 34.f, 35.f, 36.f, 36.f, 34.f, 35.f, 36.f, 36.f, - 28.f, 29.f, 30.f, 30.f, 31.f, 32.f, 33.f, 33.f, 34.f, 35.f, 36.f, 36.f, 34.f, 35.f, 36.f, 36.f, 40.f, 41.f, 42.f, 42.f, 43.f, 44.f, 45.f, 45.f, 46.f, 47.f, 48.f, 48.f, 46.f, 47.f, 48.f, 48.f, 52.f, 53.f, 54.f, 54.f, 55.f, 56.f, 57.f, 57.f, 58.f, 59.f, 60.f, 60.f, 58.f, 59.f, 60.f, 60.f, - 64.f, 65.f, 66.f, 66.f, 67.f, 68.f, 69.f, 69.f, 70.f, 71.f, 72.f, 72.f, 70.f, 71.f, 72.f, 72.f, 64.f, 65.f, 66.f, 66.f, 67.f, 68.f, 69.f, 69.f, 70.f, 71.f, 72.f, 72.f, 70.f, 71.f, 72.f, 72.f, 76.f, 77.f, 78.f, 78.f, 79.f, 80.f, 81.f, 81.f, 82.f, 83.f, 84.f, 84.f, 82.f, 83.f, 84.f, 84.f, - 88.f, 89.f, 90.f, 90.f, 91.f, 92.f, 93.f, 93.f, 94.f, 95.f, 96.f, 96.f, 94.f, 95.f, 96.f, 96.f, 100.f, 101.f, 102.f, 102.f, 103.f, 104.f, 105.f, 105.f, 106.f, 107.f, 108.f, 108.f, 106.f, 107.f, 108.f, 108.f, 100.f, 101.f, 102.f, 102.f, 103.f, 104.f, 105.f, 105.f, 106.f, 107.f, 108.f, 108.f, 106.f, 107.f, 108.f, 108.f, - 112.f, 113.f, 114.f, 114.f, 115.f, 116.f, 117.f, 117.f, 118.f, 119.f, 120.f, 120.f, 118.f, 119.f, 120.f, 120.f, 124.f, 125.f, 126.f, 126.f, 127.f, 128.f, 129.f, 129.f, 130.f, 131.f, 132.f, 132.f, 130.f, 131.f, 132.f, 132.f, 136.f, 137.f, 138.f, 138.f, 139.f, 140.f, 141.f, 141.f, 142.f, 143.f, 144.f, 144.f, 142.f, 143.f, 144.f, 144.f, - 136.f, 137.f, 138.f, 138.f, 139.f, 140.f, 141.f, 141.f, 142.f, 143.f, 144.f, 144.f, 142.f, 143.f, 144.f, 144.f, 148.f, 149.f, 150.f, 150.f, 151.f, 152.f, 153.f, 153.f, 154.f, 155.f, 156.f, 156.f, 154.f, 155.f, 156.f, 156.f, 160.f, 161.f, 162.f, 162.f, 163.f, 164.f, 165.f, 165.f, 166.f, 167.f, 168.f, 168.f, 166.f, 167.f, 168.f, 168.f, - 172.f, 173.f, 174.f, 174.f, 175.f, 176.f, 177.f, 177.f, 178.f, 179.f, 180.f, 180.f, 178.f, 179.f, 180.f, 180.f, 172.f, 173.f, 174.f, 174.f, 175.f, 176.f, 177.f, 177.f, 178.f, 179.f, 180.f, 180.f, 178.f, 179.f, 180.f, 180.f, 184.f, 185.f, 186.f, 186.f, 187.f, 188.f, 189.f, 189.f, 190.f, 191.f, 192.f, 192.f, 190.f, 191.f, 192.f, 192.f, + auto expected = NDArrayFactory::create('c', {bS, iC, oD, oH, oW},{ 4.f, 5.f, 6.f, 6.f, 7.f, 8.f, 9.f, 9.f, 10.f, 11.f, 12.f, 12.f, 10.f, 11.f, 12.f, 12.f, 16.f, 17.f, 18.f, 18.f, 19.f, 20.f, 21.f, 21.f, 22.f, 23.f, 24.f, 24.f, 22.f, 23.f, 24.f, 24.f, 28.f, 29.f, 30.f, 30.f, 31.f, 32.f, 33.f, 33.f, 34.f, 35.f, 36.f, 36.f, 34.f, 35.f, 36.f, 36.f, + 28.f, 29.f, 30.f, 30.f, 31.f, 32.f, 33.f, 33.f, 34.f, 35.f, 36.f, 36.f, 34.f, 35.f, 36.f, 36.f, 40.f, 41.f, 42.f, 42.f, 43.f, 44.f, 45.f, 45.f, 46.f, 47.f, 48.f, 48.f, 46.f, 47.f, 48.f, 48.f, 52.f, 53.f, 54.f, 54.f, 55.f, 56.f, 57.f, 57.f, 58.f, 59.f, 60.f, 60.f, 58.f, 59.f, 60.f, 60.f, + 64.f, 65.f, 66.f, 66.f, 67.f, 68.f, 69.f, 69.f, 70.f, 71.f, 72.f, 72.f, 70.f, 71.f, 72.f, 72.f, 64.f, 65.f, 66.f, 66.f, 67.f, 68.f, 69.f, 69.f, 70.f, 71.f, 72.f, 72.f, 70.f, 71.f, 72.f, 72.f, 76.f, 77.f, 78.f, 78.f, 79.f, 80.f, 81.f, 81.f, 82.f, 83.f, 84.f, 84.f, 82.f, 83.f, 84.f, 84.f, + 88.f, 89.f, 90.f, 90.f, 91.f, 92.f, 93.f, 93.f, 94.f, 95.f, 96.f, 96.f, 94.f, 95.f, 96.f, 96.f, 100.f, 101.f, 102.f, 102.f, 103.f, 104.f, 105.f, 105.f, 106.f, 107.f, 108.f, 108.f, 106.f, 107.f, 108.f, 108.f, 100.f, 101.f, 102.f, 102.f, 103.f, 104.f, 105.f, 105.f, 106.f, 107.f, 108.f, 108.f, 106.f, 107.f, 108.f, 108.f, + 112.f, 113.f, 114.f, 114.f, 115.f, 116.f, 117.f, 117.f, 118.f, 119.f, 120.f, 120.f, 118.f, 119.f, 120.f, 120.f, 124.f, 125.f, 126.f, 126.f, 127.f, 128.f, 129.f, 129.f, 130.f, 131.f, 132.f, 132.f, 130.f, 131.f, 132.f, 132.f, 136.f, 137.f, 138.f, 138.f, 139.f, 140.f, 141.f, 141.f, 142.f, 143.f, 144.f, 144.f, 142.f, 143.f, 144.f, 144.f, + 136.f, 137.f, 138.f, 138.f, 139.f, 140.f, 141.f, 141.f, 142.f, 143.f, 144.f, 144.f, 142.f, 143.f, 144.f, 144.f, 148.f, 149.f, 150.f, 150.f, 151.f, 152.f, 153.f, 153.f, 154.f, 155.f, 156.f, 156.f, 154.f, 155.f, 156.f, 156.f, 160.f, 161.f, 162.f, 162.f, 163.f, 164.f, 165.f, 165.f, 166.f, 167.f, 168.f, 168.f, 166.f, 167.f, 168.f, 168.f, + 172.f, 173.f, 174.f, 174.f, 175.f, 176.f, 177.f, 177.f, 178.f, 179.f, 180.f, 180.f, 178.f, 179.f, 180.f, 180.f, 172.f, 173.f, 174.f, 174.f, 175.f, 176.f, 177.f, 177.f, 178.f, 179.f, 180.f, 180.f, 178.f, 179.f, 180.f, 180.f, 184.f, 185.f, 186.f, 186.f, 187.f, 188.f, 189.f, 189.f, 190.f, 191.f, 192.f, 192.f, 190.f, 191.f, 192.f, 192.f, 196.f, 197.f, 198.f, 198.f, 199.f, 200.f, 201.f, 201.f, 202.f, 203.f, 204.f, 204.f, 202.f, 203.f, 204.f, 204.f, 208.f, 209.f, 210.f, 210.f, 211.f, 212.f, 213.f, 213.f, 214.f, 215.f, 216.f, 216.f, 214.f, 215.f, 216.f, 216.f, 208.f, 209.f, 210.f, 210.f, 211.f, 212.f, 213.f, 213.f, 214.f, 215.f, 216.f, 216.f, 214.f, 215.f, 216.f, 216.f}); input.linspace(1.); @@ -1332,14 +1332,14 @@ TYPED_TEST(TypedConvolutionTests2, avgpool3d_bp_test1) { auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); auto gradO = NDArrayFactory::create('c', {bS, iC, oD, oH, oW}); - auto expected = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}, {0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.666667f, 1.333333f, 0.666667f, 0.666667f, 1.333333f, 0.666667f, 0.333333f, 0.666667f, 0.333333f, - 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, - 0.333333f, 0.666667f, 0.333333f, 0.666667f, 1.333333f, 0.666667f, 0.666667f, 1.333333f, 0.666667f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, - 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.666667f, 1.333333f, 0.666667f, 0.666667f, 1.333333f, 0.666667f, 0.333333f, 0.666667f, 0.333333f, - 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, - 0.333333f, 0.666667f, 0.333333f, 0.666667f, 1.333333f, 0.666667f, 0.666667f, 1.333333f, 0.666667f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, - 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.666667f, 1.333333f, 0.666667f, 0.666667f, 1.333333f, 0.666667f, 0.333333f, 0.666667f, 0.333333f, - 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, + auto expected = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}, {0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.666667f, 1.333333f, 0.666667f, 0.666667f, 1.333333f, 0.666667f, 0.333333f, 0.666667f, 0.333333f, + 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, + 0.333333f, 0.666667f, 0.333333f, 0.666667f, 1.333333f, 0.666667f, 0.666667f, 1.333333f, 0.666667f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, + 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.666667f, 1.333333f, 0.666667f, 0.666667f, 1.333333f, 0.666667f, 0.333333f, 0.666667f, 0.333333f, + 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, + 0.333333f, 0.666667f, 0.333333f, 0.666667f, 1.333333f, 0.666667f, 0.666667f, 1.333333f, 0.666667f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, + 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.666667f, 1.333333f, 0.666667f, 0.666667f, 1.333333f, 0.666667f, 0.333333f, 0.666667f, 0.333333f, + 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.666667f, 1.333333f, 0.666667f, 0.666667f, 1.333333f, 0.666667f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f}); input.linspace(1.); gradO = 2.; @@ -1366,14 +1366,14 @@ TYPED_TEST(TypedConvolutionTests2, avgpool3d_bp_test2) { auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); auto gradO = NDArrayFactory::create('c', {bS, iC, oD, oH, oW}); - auto expected = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}, {1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, - 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, - 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, - 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, - 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, - 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, - 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, - 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, + auto expected = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}, {1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, + 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, + 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, + 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, + 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, + 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, + 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, + 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f}); input.linspace(1.); gradO = 2.; @@ -1402,13 +1402,13 @@ TYPED_TEST(TypedConvolutionTests2, avgpool3d_bp_test3) { auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); auto gradO = NDArrayFactory::create('c', {bS, oD, oH, oW, iC}); - auto expected = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}, {0.41667f, 0.41667f, 0.41667f, 0.83333f, 0.83333f, 0.83333f, 1.25f, 1.25f, 1.25f, 0.58333f, 0.58333f, 0.58333f, 1.16667f, 1.16667f, 1.16667f, 1.75f, 1.75f, 1.75f, 0.58333f, 0.58333f, 0.58333f, 1.16667f, 1.16667f, 1.16667f, 1.75f, 1.75f, 1.75f, - 0.41667f, 0.41667f, 0.41667f, 0.83333f, 0.83333f, 0.83333f, 1.25f, 1.25f, 1.25f, 0.83333f, 0.83333f, 0.83333f, 1.66667f, 1.66667f, 1.66667f, 2.5f, 2.5f, 2.5f, 1.16667f, 1.16667f, 1.16667f, 2.33333f, 2.33333f, 2.33333f, 3.5f, 3.5f, 3.5f, - 1.16667f, 1.16667f, 1.16667f, 2.33333f, 2.33333f, 2.33333f, 3.5f, 3.5f, 3.5f, 0.83333f, 0.83333f, 0.83333f, 1.66667f, 1.66667f, 1.66667f, 2.5f, 2.5f, 2.5f, 1.25f, 1.25f, 1.25f, 2.5f, 2.5f, 2.5f, 3.75f, 3.75f, 3.75f, - 1.75f, 1.75f, 1.75f, 3.5f, 3.5f, 3.5f, 5.25f, 5.25f, 5.25f, 1.75f, 1.75f, 1.75f, 3.5f, 3.5f, 3.5f, 5.25f, 5.25f, 5.25f, 1.25f, 1.25f, 1.25f, 2.5f, 2.5f, 2.5f, 3.75f, 3.75f, 3.75f, - 0.41667f, 0.41667f, 0.41667f, 0.83333f, 0.83333f, 0.83333f, 1.25f, 1.25f, 1.25f, 0.58333f, 0.58333f, 0.58333f, 1.16667f, 1.16667f, 1.16667f, 1.75f, 1.75f, 1.75f, 0.58333f, 0.58333f, 0.58333f, 1.16667f, 1.16667f, 1.16667f, 1.75f, 1.75f, 1.75f, - 0.41667f, 0.41667f, 0.41667f, 0.83333f, 0.83333f, 0.83333f, 1.25f, 1.25f, 1.25f, 0.83333f, 0.83333f, 0.83333f, 1.66667f, 1.66667f, 1.66667f, 2.5f, 2.5f, 2.5f, 1.16667f, 1.16667f, 1.16667f, 2.33333f, 2.33333f, 2.33333f, 3.5f, 3.5f, 3.5f, - 1.16667f, 1.16667f, 1.16667f, 2.33333f, 2.33333f, 2.33333f, 3.5f, 3.5f, 3.5f, 0.83333f, 0.83333f, 0.83333f, 1.66667f, 1.66667f, 1.66667f, 2.5f, 2.5f, 2.5f, 1.25f, 1.25f, 1.25f, 2.5f, 2.5f, 2.5f, 3.75f, 3.75f, 3.75f, + auto expected = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}, {0.41667f, 0.41667f, 0.41667f, 0.83333f, 0.83333f, 0.83333f, 1.25f, 1.25f, 1.25f, 0.58333f, 0.58333f, 0.58333f, 1.16667f, 1.16667f, 1.16667f, 1.75f, 1.75f, 1.75f, 0.58333f, 0.58333f, 0.58333f, 1.16667f, 1.16667f, 1.16667f, 1.75f, 1.75f, 1.75f, + 0.41667f, 0.41667f, 0.41667f, 0.83333f, 0.83333f, 0.83333f, 1.25f, 1.25f, 1.25f, 0.83333f, 0.83333f, 0.83333f, 1.66667f, 1.66667f, 1.66667f, 2.5f, 2.5f, 2.5f, 1.16667f, 1.16667f, 1.16667f, 2.33333f, 2.33333f, 2.33333f, 3.5f, 3.5f, 3.5f, + 1.16667f, 1.16667f, 1.16667f, 2.33333f, 2.33333f, 2.33333f, 3.5f, 3.5f, 3.5f, 0.83333f, 0.83333f, 0.83333f, 1.66667f, 1.66667f, 1.66667f, 2.5f, 2.5f, 2.5f, 1.25f, 1.25f, 1.25f, 2.5f, 2.5f, 2.5f, 3.75f, 3.75f, 3.75f, + 1.75f, 1.75f, 1.75f, 3.5f, 3.5f, 3.5f, 5.25f, 5.25f, 5.25f, 1.75f, 1.75f, 1.75f, 3.5f, 3.5f, 3.5f, 5.25f, 5.25f, 5.25f, 1.25f, 1.25f, 1.25f, 2.5f, 2.5f, 2.5f, 3.75f, 3.75f, 3.75f, + 0.41667f, 0.41667f, 0.41667f, 0.83333f, 0.83333f, 0.83333f, 1.25f, 1.25f, 1.25f, 0.58333f, 0.58333f, 0.58333f, 1.16667f, 1.16667f, 1.16667f, 1.75f, 1.75f, 1.75f, 0.58333f, 0.58333f, 0.58333f, 1.16667f, 1.16667f, 1.16667f, 1.75f, 1.75f, 1.75f, + 0.41667f, 0.41667f, 0.41667f, 0.83333f, 0.83333f, 0.83333f, 1.25f, 1.25f, 1.25f, 0.83333f, 0.83333f, 0.83333f, 1.66667f, 1.66667f, 1.66667f, 2.5f, 2.5f, 2.5f, 1.16667f, 1.16667f, 1.16667f, 2.33333f, 2.33333f, 2.33333f, 3.5f, 3.5f, 3.5f, + 1.16667f, 1.16667f, 1.16667f, 2.33333f, 2.33333f, 2.33333f, 3.5f, 3.5f, 3.5f, 0.83333f, 0.83333f, 0.83333f, 1.66667f, 1.66667f, 1.66667f, 2.5f, 2.5f, 2.5f, 1.25f, 1.25f, 1.25f, 2.5f, 2.5f, 2.5f, 3.75f, 3.75f, 3.75f, 1.75f, 1.75f, 1.75f, 3.5f, 3.5f, 3.5f, 5.25f, 5.25f, 5.25f, 1.75f, 1.75f, 1.75f, 3.5f, 3.5f, 3.5f, 5.25f, 5.25f, 5.25f, 1.25f, 1.25f, 1.25f, 2.5f, 2.5f, 2.5f, 3.75f, 3.75f, 3.75f}); input.linspace(1.); gradO = 2.; @@ -1434,13 +1434,13 @@ TYPED_TEST(TypedConvolutionTests2, avgpool3d_bp_test4) { auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); auto gradO = NDArrayFactory::create('c', {bS, oD, oH, oW, iC}); - auto expected = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}, {0.16667f, 0.16667f, 0.16667f, 0.33333f, 0.33333f, 0.33333f, 0.5f, 0.5f, 0.5f, 0.33333f, 0.33333f, 0.33333f, 0.66667f, 0.66667f, 0.66667f, 1.f, 1.f, 1.f, 0.58333f, 0.58333f, 0.58333f, 1.16667f, 1.16667f, 1.16667f, 1.75f, 1.75f, 1.75f, - 0.91667f, 0.91667f, 0.91667f, 1.83333f, 1.83333f, 1.83333f, 2.75f, 2.75f, 2.75f, 0.33333f, 0.33333f, 0.33333f, 0.66667f, 0.66667f, 0.66667f, 1.f, 1.f, 1.f, 0.66667f, 0.66667f, 0.66667f, 1.33333f, 1.33333f, 1.33333f, 2.f, 2.f, 2.f, - 1.16667f, 1.16667f, 1.16667f, 2.33333f, 2.33333f, 2.33333f, 3.5f, 3.5f, 3.5f, 1.83333f, 1.83333f, 1.83333f, 3.66667f, 3.66667f, 3.66667f, 5.5f, 5.5f, 5.5f, 0.5f, 0.5f, 0.5f, 1.f, 1.f, 1.f, 1.5f, 1.5f, 1.5f, - 1.f, 1.f, 1.f, 2.f, 2.f, 2.f, 3.f, 3.f, 3.f, 1.75f, 1.75f, 1.75f, 3.5f, 3.5f, 3.5f, 5.25f, 5.25f, 5.25f, 2.75f, 2.75f, 2.75f, 5.5f, 5.5f, 5.5f, 8.25f, 8.25f, 8.25f, - 0.16667f, 0.16667f, 0.16667f, 0.33333f, 0.33333f, 0.33333f, 0.5f, 0.5f, 0.5f, 0.33333f, 0.33333f, 0.33333f, 0.66667f, 0.66667f, 0.66667f, 1.f, 1.f, 1.f, 0.58333f, 0.58333f, 0.58333f, 1.16667f, 1.16667f, 1.16667f, 1.75f, 1.75f, 1.75f, - 0.91667f, 0.91667f, 0.91667f, 1.83333f, 1.83333f, 1.83333f, 2.75f, 2.75f, 2.75f, 0.33333f, 0.33333f, 0.33333f, 0.66667f, 0.66667f, 0.66667f, 1.f, 1.f, 1.f, 0.66667f, 0.66667f, 0.66667f, 1.33333f, 1.33333f, 1.33333f, 2.f, 2.f, 2.f, - 1.16667f, 1.16667f, 1.16667f, 2.33333f, 2.33333f, 2.33333f, 3.5f, 3.5f, 3.5f, 1.83333f, 1.83333f, 1.83333f, 3.66667f, 3.66667f, 3.66667f, 5.5f, 5.5f, 5.5f, 0.5f, 0.5f, 0.5f, 1.f, 1.f, 1.f, 1.5f, 1.5f, 1.5f, + auto expected = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}, {0.16667f, 0.16667f, 0.16667f, 0.33333f, 0.33333f, 0.33333f, 0.5f, 0.5f, 0.5f, 0.33333f, 0.33333f, 0.33333f, 0.66667f, 0.66667f, 0.66667f, 1.f, 1.f, 1.f, 0.58333f, 0.58333f, 0.58333f, 1.16667f, 1.16667f, 1.16667f, 1.75f, 1.75f, 1.75f, + 0.91667f, 0.91667f, 0.91667f, 1.83333f, 1.83333f, 1.83333f, 2.75f, 2.75f, 2.75f, 0.33333f, 0.33333f, 0.33333f, 0.66667f, 0.66667f, 0.66667f, 1.f, 1.f, 1.f, 0.66667f, 0.66667f, 0.66667f, 1.33333f, 1.33333f, 1.33333f, 2.f, 2.f, 2.f, + 1.16667f, 1.16667f, 1.16667f, 2.33333f, 2.33333f, 2.33333f, 3.5f, 3.5f, 3.5f, 1.83333f, 1.83333f, 1.83333f, 3.66667f, 3.66667f, 3.66667f, 5.5f, 5.5f, 5.5f, 0.5f, 0.5f, 0.5f, 1.f, 1.f, 1.f, 1.5f, 1.5f, 1.5f, + 1.f, 1.f, 1.f, 2.f, 2.f, 2.f, 3.f, 3.f, 3.f, 1.75f, 1.75f, 1.75f, 3.5f, 3.5f, 3.5f, 5.25f, 5.25f, 5.25f, 2.75f, 2.75f, 2.75f, 5.5f, 5.5f, 5.5f, 8.25f, 8.25f, 8.25f, + 0.16667f, 0.16667f, 0.16667f, 0.33333f, 0.33333f, 0.33333f, 0.5f, 0.5f, 0.5f, 0.33333f, 0.33333f, 0.33333f, 0.66667f, 0.66667f, 0.66667f, 1.f, 1.f, 1.f, 0.58333f, 0.58333f, 0.58333f, 1.16667f, 1.16667f, 1.16667f, 1.75f, 1.75f, 1.75f, + 0.91667f, 0.91667f, 0.91667f, 1.83333f, 1.83333f, 1.83333f, 2.75f, 2.75f, 2.75f, 0.33333f, 0.33333f, 0.33333f, 0.66667f, 0.66667f, 0.66667f, 1.f, 1.f, 1.f, 0.66667f, 0.66667f, 0.66667f, 1.33333f, 1.33333f, 1.33333f, 2.f, 2.f, 2.f, + 1.16667f, 1.16667f, 1.16667f, 2.33333f, 2.33333f, 2.33333f, 3.5f, 3.5f, 3.5f, 1.83333f, 1.83333f, 1.83333f, 3.66667f, 3.66667f, 3.66667f, 5.5f, 5.5f, 5.5f, 0.5f, 0.5f, 0.5f, 1.f, 1.f, 1.f, 1.5f, 1.5f, 1.5f, 1.f, 1.f, 1.f, 2.f, 2.f, 2.f, 3.f, 3.f, 3.f, 1.75f, 1.75f, 1.75f, 3.5f, 3.5f, 3.5f, 5.25f, 5.25f, 5.25f, 2.75f, 2.75f, 2.75f, 5.5f, 5.5f, 5.5f, 8.25f, 8.25f, 8.25f}); input.linspace(1.); gradO = 2.; @@ -1466,11 +1466,11 @@ TYPED_TEST(TypedConvolutionTests2, maxpool3d_bp_test1) { auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); auto gradO = NDArrayFactory::create('c', {bS, iC, oD, oH, oW}); - auto expected = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.1f, 0.2f, 0.f, 0.3f, 0.4f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.5f, 0.6f, 0.f, 0.7f, 0.8f, - 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.9f, 1.f, 0.f, 1.1f, 1.2f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 1.3f, 1.4f, 0.f, 1.5f, 1.6f, - 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 1.7f, 1.8f, 0.f, 1.9f, 2.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 2.1f, 2.2f, 0.f, 2.3f, 2.4f, - 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 2.5f, 2.6f, 0.f, 2.7f, 2.8f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 2.9f, 3.f, 0.f, 3.1f, 3.2f, - 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 3.3f, 3.4f, 0.f, 3.5f, 3.6f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 3.7f, 3.8f, 0.f, 3.9f, 4.f, + auto expected = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.1f, 0.2f, 0.f, 0.3f, 0.4f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.5f, 0.6f, 0.f, 0.7f, 0.8f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.9f, 1.f, 0.f, 1.1f, 1.2f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 1.3f, 1.4f, 0.f, 1.5f, 1.6f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 1.7f, 1.8f, 0.f, 1.9f, 2.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 2.1f, 2.2f, 0.f, 2.3f, 2.4f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 2.5f, 2.6f, 0.f, 2.7f, 2.8f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 2.9f, 3.f, 0.f, 3.1f, 3.2f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 3.3f, 3.4f, 0.f, 3.5f, 3.6f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 3.7f, 3.8f, 0.f, 3.9f, 4.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 4.1f, 4.2f, 0.f, 4.3f, 4.4f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 4.5f, 4.6f, 0.f, 4.7f, 4.8f}); input.linspace(1.); gradO.linspace(0.1, 0.1); @@ -1496,14 +1496,14 @@ TYPED_TEST(TypedConvolutionTests2, maxpool3d_bp_test2) { auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); auto gradO = NDArrayFactory::create('c', {bS, iC, oD, oH, oW}); - auto expected = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}, {0.000e+00f, 0.000e+00f, 0.000e+00f, 1.000e-01f, 2.000e-01f, 7.000e-01f, 5.000e-01f, 6.000e-01f, 1.500e+00f, 2.200e+00f, 2.400e+00f, 5.400e+00f, 0.000e+00f, 0.000e+00f, 0.000e+00f, 1.700e+00f, 1.800e+00f, 3.900e+00f, 2.100e+00f, 2.200e+00f, 4.700e+00f, 5.400e+00f, 5.600e+00f, 1.180e+01f, - 0.000e+00f, 0.000e+00f, 0.000e+00f, 8.200e+00f, 8.400e+00f, 1.740e+01f, 9.000e+00f, 9.200e+00f, 1.900e+01f, 2.040e+01f, 2.080e+01f, 4.280e+01f, 0.000e+00f, 0.000e+00f, 0.000e+00f, 6.500e+00f, 6.600e+00f, 1.350e+01f, 6.900e+00f, 7.000e+00f, 1.430e+01f, 1.500e+01f, 1.520e+01f, 3.100e+01f, - 0.000e+00f, 0.000e+00f, 0.000e+00f, 8.100e+00f, 8.200e+00f, 1.670e+01f, 8.500e+00f, 8.600e+00f, 1.750e+01f, 1.820e+01f, 1.840e+01f, 3.740e+01f, 0.000e+00f, 0.000e+00f, 0.000e+00f, 2.100e+01f, 2.120e+01f, 4.300e+01f, 2.180e+01f, 2.200e+01f, 4.460e+01f, 4.600e+01f, 4.640e+01f, 9.400e+01f, - 0.000e+00f, 0.000e+00f, 0.000e+00f, 1.290e+01f, 1.300e+01f, 2.630e+01f, 1.330e+01f, 1.340e+01f, 2.710e+01f, 2.780e+01f, 2.800e+01f, 5.660e+01f, 0.000e+00f, 0.000e+00f, 0.000e+00f, 1.450e+01f, 1.460e+01f, 2.950e+01f, 1.490e+01f, 1.500e+01f, 3.030e+01f, 3.100e+01f, 3.120e+01f, 6.300e+01f, - 0.000e+00f, 0.000e+00f, 0.000e+00f, 3.380e+01f, 3.400e+01f, 6.860e+01f, 3.460e+01f, 3.480e+01f, 7.020e+01f, 7.160e+01f, 7.200e+01f, 1.452e+02f, 0.000e+00f, 0.000e+00f, 0.000e+00f, 1.930e+01f, 1.940e+01f, 3.910e+01f, 1.970e+01f, 1.980e+01f, 3.990e+01f, 4.060e+01f, 4.080e+01f, 8.220e+01f, - 0.000e+00f, 0.000e+00f, 0.000e+00f, 2.090e+01f, 2.100e+01f, 4.230e+01f, 2.130e+01f, 2.140e+01f, 4.310e+01f, 4.380e+01f, 4.400e+01f, 8.860e+01f, 0.000e+00f, 0.000e+00f, 0.000e+00f, 4.660e+01f, 4.680e+01f, 9.420e+01f, 4.740e+01f, 4.760e+01f, 9.580e+01f, 9.720e+01f, 9.760e+01f, 1.964e+02f, - 0.000e+00f, 0.000e+00f, 0.000e+00f, 2.570e+01f, 2.580e+01f, 5.190e+01f, 2.610e+01f, 2.620e+01f, 5.270e+01f, 5.340e+01f, 5.360e+01f, 1.078e+02f, 0.000e+00f, 0.000e+00f, 0.000e+00f, 2.730e+01f, 2.740e+01f, 5.510e+01f, 2.770e+01f, 2.780e+01f, 5.590e+01f, 5.660e+01f, 5.680e+01f, 1.142e+02f, - 0.000e+00f, 0.000e+00f, 0.000e+00f, 5.940e+01f, 5.960e+01f, 1.198e+02f, 6.020e+01f, 6.040e+01f, 1.214e+02f, 1.228e+02f, 1.232e+02f, 2.476e+02f, 0.000e+00f, 0.000e+00f, 0.000e+00f, 3.210e+01f, 3.220e+01f, 6.470e+01f, 3.250e+01f, 3.260e+01f, 6.550e+01f, 6.620e+01f, 6.640e+01f, 1.334e+02f, + auto expected = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}, {0.000e+00f, 0.000e+00f, 0.000e+00f, 1.000e-01f, 2.000e-01f, 7.000e-01f, 5.000e-01f, 6.000e-01f, 1.500e+00f, 2.200e+00f, 2.400e+00f, 5.400e+00f, 0.000e+00f, 0.000e+00f, 0.000e+00f, 1.700e+00f, 1.800e+00f, 3.900e+00f, 2.100e+00f, 2.200e+00f, 4.700e+00f, 5.400e+00f, 5.600e+00f, 1.180e+01f, + 0.000e+00f, 0.000e+00f, 0.000e+00f, 8.200e+00f, 8.400e+00f, 1.740e+01f, 9.000e+00f, 9.200e+00f, 1.900e+01f, 2.040e+01f, 2.080e+01f, 4.280e+01f, 0.000e+00f, 0.000e+00f, 0.000e+00f, 6.500e+00f, 6.600e+00f, 1.350e+01f, 6.900e+00f, 7.000e+00f, 1.430e+01f, 1.500e+01f, 1.520e+01f, 3.100e+01f, + 0.000e+00f, 0.000e+00f, 0.000e+00f, 8.100e+00f, 8.200e+00f, 1.670e+01f, 8.500e+00f, 8.600e+00f, 1.750e+01f, 1.820e+01f, 1.840e+01f, 3.740e+01f, 0.000e+00f, 0.000e+00f, 0.000e+00f, 2.100e+01f, 2.120e+01f, 4.300e+01f, 2.180e+01f, 2.200e+01f, 4.460e+01f, 4.600e+01f, 4.640e+01f, 9.400e+01f, + 0.000e+00f, 0.000e+00f, 0.000e+00f, 1.290e+01f, 1.300e+01f, 2.630e+01f, 1.330e+01f, 1.340e+01f, 2.710e+01f, 2.780e+01f, 2.800e+01f, 5.660e+01f, 0.000e+00f, 0.000e+00f, 0.000e+00f, 1.450e+01f, 1.460e+01f, 2.950e+01f, 1.490e+01f, 1.500e+01f, 3.030e+01f, 3.100e+01f, 3.120e+01f, 6.300e+01f, + 0.000e+00f, 0.000e+00f, 0.000e+00f, 3.380e+01f, 3.400e+01f, 6.860e+01f, 3.460e+01f, 3.480e+01f, 7.020e+01f, 7.160e+01f, 7.200e+01f, 1.452e+02f, 0.000e+00f, 0.000e+00f, 0.000e+00f, 1.930e+01f, 1.940e+01f, 3.910e+01f, 1.970e+01f, 1.980e+01f, 3.990e+01f, 4.060e+01f, 4.080e+01f, 8.220e+01f, + 0.000e+00f, 0.000e+00f, 0.000e+00f, 2.090e+01f, 2.100e+01f, 4.230e+01f, 2.130e+01f, 2.140e+01f, 4.310e+01f, 4.380e+01f, 4.400e+01f, 8.860e+01f, 0.000e+00f, 0.000e+00f, 0.000e+00f, 4.660e+01f, 4.680e+01f, 9.420e+01f, 4.740e+01f, 4.760e+01f, 9.580e+01f, 9.720e+01f, 9.760e+01f, 1.964e+02f, + 0.000e+00f, 0.000e+00f, 0.000e+00f, 2.570e+01f, 2.580e+01f, 5.190e+01f, 2.610e+01f, 2.620e+01f, 5.270e+01f, 5.340e+01f, 5.360e+01f, 1.078e+02f, 0.000e+00f, 0.000e+00f, 0.000e+00f, 2.730e+01f, 2.740e+01f, 5.510e+01f, 2.770e+01f, 2.780e+01f, 5.590e+01f, 5.660e+01f, 5.680e+01f, 1.142e+02f, + 0.000e+00f, 0.000e+00f, 0.000e+00f, 5.940e+01f, 5.960e+01f, 1.198e+02f, 6.020e+01f, 6.040e+01f, 1.214e+02f, 1.228e+02f, 1.232e+02f, 2.476e+02f, 0.000e+00f, 0.000e+00f, 0.000e+00f, 3.210e+01f, 3.220e+01f, 6.470e+01f, 3.250e+01f, 3.260e+01f, 6.550e+01f, 6.620e+01f, 6.640e+01f, 1.334e+02f, 0.000e+00f, 0.000e+00f, 0.000e+00f, 3.370e+01f, 3.380e+01f, 6.790e+01f, 3.410e+01f, 3.420e+01f, 6.870e+01f, 6.940e+01f, 6.960e+01f, 1.398e+02f, 0.000e+00f, 0.000e+00f, 0.000e+00f, 7.220e+01f, 7.240e+01f, 1.454e+02f, 7.300e+01f, 7.320e+01f, 1.470e+02f, 1.484e+02f, 1.488e+02f, 2.988e+02f}); input.linspace(1.); gradO.linspace(0.1, 0.1); @@ -1529,13 +1529,13 @@ TYPED_TEST(TypedConvolutionTests2, maxpool3d_bp_test3) { auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); auto gradO = NDArrayFactory::create('c', {bS, oD, oH, oW, iC}); - auto expected = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}, { 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, - 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.1f, 0.2f, 0.3f, 1.1f, 1.3f, 1.5f, - 0.f, 0.f, 0.f, 1.f, 1.1f, 1.2f, 2.9f, 3.1f, 3.3f, 0.f, 0.f, 0.f, 4.7f, 4.9f, 5.1f, 11.2f, 11.6f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, - 0.f, 0.f, 0.f, 11.f, 11.2f, 11.4f, 23.8f, 24.2f, 24.6f, 0.f, 0.f, 0.f, 12.8f, 13.f, 13.2f, 27.4f, 27.8f, 28.2f, 0.f, 0.f, 0.f, 31.f, 31.4f, 31.8f, 65.6f, 66.39999f, 67.2f, - 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, - 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 10.9f, 11.f, 11.1f, 22.7f, 22.9f, 23.1f, - 0.f, 0.f, 0.f, 11.8f, 11.9f, 12.f, 24.5f, 24.7f, 24.9f, 0.f, 0.f, 0.f, 26.3f, 26.5f, 26.7f, 54.4f, 54.8f, 55.2f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + auto expected = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}, { 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.1f, 0.2f, 0.3f, 1.1f, 1.3f, 1.5f, + 0.f, 0.f, 0.f, 1.f, 1.1f, 1.2f, 2.9f, 3.1f, 3.3f, 0.f, 0.f, 0.f, 4.7f, 4.9f, 5.1f, 11.2f, 11.6f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 11.f, 11.2f, 11.4f, 23.8f, 24.2f, 24.6f, 0.f, 0.f, 0.f, 12.8f, 13.f, 13.2f, 27.4f, 27.8f, 28.2f, 0.f, 0.f, 0.f, 31.f, 31.4f, 31.8f, 65.6f, 66.39999f, 67.2f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 10.9f, 11.f, 11.1f, 22.7f, 22.9f, 23.1f, + 0.f, 0.f, 0.f, 11.8f, 11.9f, 12.f, 24.5f, 24.7f, 24.9f, 0.f, 0.f, 0.f, 26.3f, 26.5f, 26.7f, 54.4f, 54.8f, 55.2f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 32.6f, 32.8f, 33.f, 67.f, 67.4f, 67.8f, 0.f, 0.f, 0.f, 34.4f, 34.6f, 34.8f, 70.6f, 71.f, 71.4f, 0.f, 0.f, 0.f, 74.2f, 74.6f, 75.f, 152.f, 152.8f, 153.6f}); input.linspace(1.); gradO.linspace(0.1, 0.1); @@ -1562,12 +1562,12 @@ TYPED_TEST(TypedConvolutionTests2, maxpool3d_bp_test4) { auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); auto gradO = NDArrayFactory::create('c', {bS, oD, oH, oW, iC}); - auto expected = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, - 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.1f, 0.2f, 0.3f, 1.1f, 1.3f, 1.5f, 0.f, 0.f, 0.f, 5.7f, 6.f, 6.3f, + auto expected = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.1f, 0.2f, 0.3f, 1.1f, 1.3f, 1.5f, 0.f, 0.f, 0.f, 5.7f, 6.f, 6.3f, 14.1f, 14.7f, 15.3f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 11.f, 11.2f, 11.4f, 23.8f, 24.2f, - 24.6f, 0.f, 0.f, 0.f, 43.8f, 44.4f, 45.f, 93.f, 94.2f, 95.4f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, - 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, - 10.9f, 11.f, 11.1f, 22.7f, 22.9f, 23.1f, 0.f, 0.f, 0.f, 38.1f, 38.4f, 38.7f, 78.9f, 79.5f, 80.1f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 24.6f, 0.f, 0.f, 0.f, 43.8f, 44.4f, 45.f, 93.f, 94.2f, 95.4f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 10.9f, 11.f, 11.1f, 22.7f, 22.9f, 23.1f, 0.f, 0.f, 0.f, 38.1f, 38.4f, 38.7f, 78.9f, 79.5f, 80.1f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 32.6f, 32.8f, 33.f, 67.f, 67.4f, 67.8f, 0.f, 0.f, 0.f, 108.6f, 109.2f, 109.8f, 222.6f, 223.8f, 225.f}); input.linspace(1.); gradO.linspace(0.1, 0.1); @@ -1651,8 +1651,8 @@ TYPED_TEST(TypedConvolutionTests2, maxpool2d_bp_3) { auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); auto gradO = NDArrayFactory::create('c', {bS, iC, oH, oW}); - auto expected = NDArrayFactory::create('c', {bS, iC, iH, iW}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.1f, 0.2f, 0.f, 0.3f, 0.4f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.5f, 0.6f, 0.f, 0.7f, 0.8f, - 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.9f, 1.f, 0.f, 1.1f, 1.2f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 1.3f, 1.4f, 0.f, 1.5f, 1.6f, + auto expected = NDArrayFactory::create('c', {bS, iC, iH, iW}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.1f, 0.2f, 0.f, 0.3f, 0.4f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.5f, 0.6f, 0.f, 0.7f, 0.8f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.9f, 1.f, 0.f, 1.1f, 1.2f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 1.3f, 1.4f, 0.f, 1.5f, 1.6f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 1.7f, 1.8f, 0.f, 1.9f, 2.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 2.1f, 2.2f, 0.f, 2.3f, 2.4f}); input.linspace(1.); gradO.linspace(0.1, 0.1); @@ -1678,8 +1678,8 @@ TYPED_TEST(TypedConvolutionTests2, maxpool2d_bp_4) { auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); auto gradO = NDArrayFactory::create('c', {bS, iC, oH, oW}); - auto expected = NDArrayFactory::create('c', {bS, iC, iH, iW}, {0.f, 0.f, 0.f, 0.1f, 0.2f, 0.7f, 0.5f, 0.6f, 1.5f, 2.2f, 2.4f, 5.4f, 0.f, 0.f, 0.f, 1.7f, 1.8f, 3.9f, 2.1f, 2.2f, 4.7f, 5.4f, 5.6f, 11.8f, - 0.f, 0.f, 0.f, 3.3f, 3.4f, 7.1f, 3.7f, 3.8f, 7.9f, 8.6f, 8.8f, 18.2f, 0.f, 0.f, 0.f, 4.9f, 5.f, 10.3f, 5.3f, 5.4f, 11.1f, 11.8f, 12.f, 24.6f, + auto expected = NDArrayFactory::create('c', {bS, iC, iH, iW}, {0.f, 0.f, 0.f, 0.1f, 0.2f, 0.7f, 0.5f, 0.6f, 1.5f, 2.2f, 2.4f, 5.4f, 0.f, 0.f, 0.f, 1.7f, 1.8f, 3.9f, 2.1f, 2.2f, 4.7f, 5.4f, 5.6f, 11.8f, + 0.f, 0.f, 0.f, 3.3f, 3.4f, 7.1f, 3.7f, 3.8f, 7.9f, 8.6f, 8.8f, 18.2f, 0.f, 0.f, 0.f, 4.9f, 5.f, 10.3f, 5.3f, 5.4f, 11.1f, 11.8f, 12.f, 24.6f, 0.f, 0.f, 0.f, 6.5f, 6.6f, 13.5f, 6.9f, 7.f, 14.3f, 15.f, 15.2f, 31.f, 0.f, 0.f, 0.f, 8.1f, 8.2f, 16.7f, 8.5f, 8.6f, 17.5f, 18.2f, 18.4f, 37.4f}); input.linspace(1.); gradO.linspace(0.1, 0.1); @@ -1705,8 +1705,8 @@ TYPED_TEST(TypedConvolutionTests2, maxpool2d_bp_5) { auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); auto gradO = NDArrayFactory::create('c', {bS, oH, oW, iC}); - auto expected = NDArrayFactory::create('c', {bS, iH, iW, iC}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.1f, 0.2f, 0.3f, 1.1f, 1.3f, 1.5f, 0.f, 0.f, 0.f, 1.f, 1.1f, 1.2f, 2.9f, 3.1f, 3.3f, - 0.f, 0.f, 0.f, 4.7f, 4.9f, 5.1f, 11.2f, 11.6f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 3.7f, 3.8f, 3.9f, 8.3f, 8.5f, 8.7f, + auto expected = NDArrayFactory::create('c', {bS, iH, iW, iC}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.1f, 0.2f, 0.3f, 1.1f, 1.3f, 1.5f, 0.f, 0.f, 0.f, 1.f, 1.1f, 1.2f, 2.9f, 3.1f, 3.3f, + 0.f, 0.f, 0.f, 4.7f, 4.9f, 5.1f, 11.2f, 11.6f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 3.7f, 3.8f, 3.9f, 8.3f, 8.5f, 8.7f, 0.f, 0.f, 0.f, 4.6f, 4.7f, 4.8f, 10.1f, 10.3f, 10.5f, 0.f, 0.f, 0.f, 11.9f, 12.1f, 12.3f, 25.6f, 26.f, 26.4f}); input.linspace(1.); gradO.linspace(0.1, 0.1); @@ -1732,8 +1732,8 @@ TYPED_TEST(TypedConvolutionTests2, maxpool2d_bp_6) { auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); auto gradO = NDArrayFactory::create('c', {bS, oH, oW, iC}); - auto expected = NDArrayFactory::create('c', {bS, iH, iW, iC}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, - 0.f, 0.f, 0.f, 0.7f, 0.8f, 0.9f, 1.f, 1.1f, 1.2f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + auto expected = NDArrayFactory::create('c', {bS, iH, iW, iC}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, + 0.f, 0.f, 0.f, 0.7f, 0.8f, 0.9f, 1.f, 1.1f, 1.2f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 1.3f, 1.4f, 1.5f, 1.6f, 1.7f, 1.8f, 0.f, 0.f, 0.f, 1.9f, 2.f, 2.1f, 2.2f, 2.3f, 2.4f}); input.linspace(1.); gradO.linspace(0.1, 0.1); @@ -1841,11 +1841,11 @@ TYPED_TEST(TypedConvolutionTests2, avgpool2d_bp_3) { auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); auto gradO = NDArrayFactory::create('c', {bS, iC, oH, oW}); - auto expected = NDArrayFactory::create('c', {bS, iC, iH, iW}, {0.016667f, 0.05f, 0.033333f, 0.066667f, 0.166667f, 0.1f, 0.066667f, 0.166667f, 0.1f, 0.05f, 0.116667f, 0.066667f, - 0.083333f, 0.183333f, 0.1f, 0.2f, 0.433333f, 0.233333f, 0.2f, 0.433333f, 0.233333f, 0.116667f, 0.25f, 0.133333f, - 0.15f, 0.316667f, 0.166667f, 0.333333f, 0.7f, 0.366667f, 0.333333f, 0.7f, 0.366667f, 0.183333f, 0.383333f, 0.2f, - 0.216667f, 0.45f, 0.233333f, 0.466667f, 0.966667f, 0.5f, 0.466667f, 0.966667f, 0.5f, 0.25f, 0.516667f, 0.266667f, - 0.283333f, 0.583333f, 0.3f, 0.6f, 1.233333f, 0.633333f, 0.6f, 1.233333f, 0.633333f, 0.316667f, 0.65f, 0.333333f, + auto expected = NDArrayFactory::create('c', {bS, iC, iH, iW}, {0.016667f, 0.05f, 0.033333f, 0.066667f, 0.166667f, 0.1f, 0.066667f, 0.166667f, 0.1f, 0.05f, 0.116667f, 0.066667f, + 0.083333f, 0.183333f, 0.1f, 0.2f, 0.433333f, 0.233333f, 0.2f, 0.433333f, 0.233333f, 0.116667f, 0.25f, 0.133333f, + 0.15f, 0.316667f, 0.166667f, 0.333333f, 0.7f, 0.366667f, 0.333333f, 0.7f, 0.366667f, 0.183333f, 0.383333f, 0.2f, + 0.216667f, 0.45f, 0.233333f, 0.466667f, 0.966667f, 0.5f, 0.466667f, 0.966667f, 0.5f, 0.25f, 0.516667f, 0.266667f, + 0.283333f, 0.583333f, 0.3f, 0.6f, 1.233333f, 0.633333f, 0.6f, 1.233333f, 0.633333f, 0.316667f, 0.65f, 0.333333f, 0.35f, 0.716667f, 0.366667f, 0.733333f, 1.5f, 0.766667f, 0.733333f, 1.5f, 0.766667f, 0.383333f, 0.783333f, 0.4f }); input.linspace(1.); gradO.linspace(0.1, 0.1); @@ -1872,11 +1872,11 @@ TYPED_TEST(TypedConvolutionTests2, avgpool2d_bp_4) { auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); auto gradO = NDArrayFactory::create('c', {bS, iC, oH, oW}); - auto expected = NDArrayFactory::create('c', {bS, iC, iH, iW}, {0.233333f, 0.3f, 0.366667f, 0.55f, 0.65f, 0.75f, 0.95f, 1.05f, 1.15f, 0.766667f, 0.833333f, 0.9f, - 1.3f, 1.366667f, 1.433333f, 2.15f, 2.25f, 2.35f, 2.55f, 2.65f, 2.75f, 1.833333f, 1.9f, 1.966667f, - 2.366667f, 2.433333f, 2.5f, 3.75f, 3.85f, 3.95f, 4.15f, 4.25f, 4.35f, 2.9f, 2.966667f, 3.033333f, - 3.433333f, 3.5f, 3.566667f, 5.35f, 5.45f, 5.55f, 5.75f, 5.85f, 5.95f, 3.966667f, 4.033333f, 4.1f, - 4.5f, 4.566667f, 4.633333f, 6.95f, 7.05f, 7.15f, 7.35f, 7.45f, 7.55f, 5.033333f, 5.1f, 5.166667f, + auto expected = NDArrayFactory::create('c', {bS, iC, iH, iW}, {0.233333f, 0.3f, 0.366667f, 0.55f, 0.65f, 0.75f, 0.95f, 1.05f, 1.15f, 0.766667f, 0.833333f, 0.9f, + 1.3f, 1.366667f, 1.433333f, 2.15f, 2.25f, 2.35f, 2.55f, 2.65f, 2.75f, 1.833333f, 1.9f, 1.966667f, + 2.366667f, 2.433333f, 2.5f, 3.75f, 3.85f, 3.95f, 4.15f, 4.25f, 4.35f, 2.9f, 2.966667f, 3.033333f, + 3.433333f, 3.5f, 3.566667f, 5.35f, 5.45f, 5.55f, 5.75f, 5.85f, 5.95f, 3.966667f, 4.033333f, 4.1f, + 4.5f, 4.566667f, 4.633333f, 6.95f, 7.05f, 7.15f, 7.35f, 7.45f, 7.55f, 5.033333f, 5.1f, 5.166667f, 5.566667f, 5.633333f, 5.7f, 8.549999f, 8.65f, 8.75f, 8.95f, 9.05f, 9.150001f, 6.1f, 6.166667f, 6.233334f}); input.linspace(1.); gradO.linspace(0.1, 0.1); @@ -1903,9 +1903,9 @@ TYPED_TEST(TypedConvolutionTests2, avgpool2d_bp_5) { auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); auto gradO = NDArrayFactory::create('c', {bS, oH, oW, iC}); - auto expected = NDArrayFactory::create('c', {bS, iH, iW, iC}, {0.19167f, 0.23333f, 0.275f, 0.50833f, 0.59167f, 0.675f, 1.2f, 1.325f, 1.45f, 0.50833f, 0.56667f, 0.625f, 1.19167f, 1.30833f, 1.425f, 2.4f, 2.575f, 2.75f, - 1.18333f, 1.24167f, 1.3f, 2.54167f, 2.65833f, 2.775f, 4.425f, 4.6f, 4.775f, 1.01667f, 1.05833f, 1.1f, 2.15833f, 2.24167f, 2.325f, 3.675f, 3.8f, 3.925f, - 1.69167f, 1.73333f, 1.775f, 3.50833f, 3.59167f, 3.675f, 5.7f, 5.825f, 5.95f, 2.60833f, 2.66667f, 2.725f, 5.39167f, 5.50833f, 5.625f, 8.7f, 8.875f, 9.05f, + auto expected = NDArrayFactory::create('c', {bS, iH, iW, iC}, {0.19167f, 0.23333f, 0.275f, 0.50833f, 0.59167f, 0.675f, 1.2f, 1.325f, 1.45f, 0.50833f, 0.56667f, 0.625f, 1.19167f, 1.30833f, 1.425f, 2.4f, 2.575f, 2.75f, + 1.18333f, 1.24167f, 1.3f, 2.54167f, 2.65833f, 2.775f, 4.425f, 4.6f, 4.775f, 1.01667f, 1.05833f, 1.1f, 2.15833f, 2.24167f, 2.325f, 3.675f, 3.8f, 3.925f, + 1.69167f, 1.73333f, 1.775f, 3.50833f, 3.59167f, 3.675f, 5.7f, 5.825f, 5.95f, 2.60833f, 2.66667f, 2.725f, 5.39167f, 5.50833f, 5.625f, 8.7f, 8.875f, 9.05f, 3.28333f, 3.34167f, 3.4f, 6.74167f, 6.85833f, 6.975f, 10.725f, 10.9f, 11.075f, 2.51667f, 2.55833f, 2.6f, 5.15833f, 5.24167f, 5.325f, 8.175f, 8.3f, 8.425f}); input.linspace(1.); gradO.linspace(0.1, 0.1); @@ -1932,9 +1932,9 @@ TYPED_TEST(TypedConvolutionTests2, avgpool2d_bp_6) { auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); auto gradO = NDArrayFactory::create('c', {bS, oH, oW, iC}); - auto expected = NDArrayFactory::create('c', {bS, iH, iW, iC}, {0.01667f, 0.03333f, 0.05f, 0.08333f, 0.11667f, 0.15f, 0.06667f, 0.08333f, 0.1f, 0.13333f, 0.16667f, 0.2f, 0.36667f, 0.43333f, 0.5f, 0.23333f, 0.26667f, 0.3f, - 0.13333f, 0.16667f, 0.2f, 0.36667f, 0.43333f, 0.5f, 0.23333f, 0.26667f, 0.3f, 0.11667f, 0.13333f, 0.15f, 0.28333f, 0.31667f, 0.35f, 0.16667f, 0.18333f, 0.2f, - 0.21667f, 0.23333f, 0.25f, 0.48333f, 0.51667f, 0.55f, 0.26667f, 0.28333f, 0.3f, 0.53333f, 0.56667f, 0.6f, 1.16667f, 1.23333f, 1.3f, 0.63333f, 0.66667f, 0.7f, + auto expected = NDArrayFactory::create('c', {bS, iH, iW, iC}, {0.01667f, 0.03333f, 0.05f, 0.08333f, 0.11667f, 0.15f, 0.06667f, 0.08333f, 0.1f, 0.13333f, 0.16667f, 0.2f, 0.36667f, 0.43333f, 0.5f, 0.23333f, 0.26667f, 0.3f, + 0.13333f, 0.16667f, 0.2f, 0.36667f, 0.43333f, 0.5f, 0.23333f, 0.26667f, 0.3f, 0.11667f, 0.13333f, 0.15f, 0.28333f, 0.31667f, 0.35f, 0.16667f, 0.18333f, 0.2f, + 0.21667f, 0.23333f, 0.25f, 0.48333f, 0.51667f, 0.55f, 0.26667f, 0.28333f, 0.3f, 0.53333f, 0.56667f, 0.6f, 1.16667f, 1.23333f, 1.3f, 0.63333f, 0.66667f, 0.7f, 0.53333f, 0.56667f, 0.6f, 1.16667f, 1.23333f, 1.3f, 0.63333f, 0.66667f, 0.7f, 0.31667f, 0.33333f, 0.35f, 0.68333f, 0.71667f, 0.75f, 0.36667f, 0.38333f, 0.4f}); input.linspace(1.); gradO.linspace(0.1, 0.1); @@ -1994,11 +1994,11 @@ TYPED_TEST(TypedConvolutionTests2, pnormpool2d_bp_2) { auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); auto gradO = NDArrayFactory::create('c', {bS, iC, oH, oW}); - auto expected = NDArrayFactory::create('c', {bS, iC, iH, iW}, {9.661570e-04f, 9.671602e-03f, 1.306569e-02f, 3.679184e-02f, 1.297220e-01f, 1.040181e-01f, 1.126750e-01f, 3.320884e-01f, 2.340406e-01f, 1.333333e-01f, 3.352886e-01f, 2.070211e-01f, - 8.991618e-02f, 2.160601e-01f, 1.283173e-01f, 2.744226e-01f, 6.364498e-01f, 3.662123e-01f, 3.869788e-01f, 8.808994e-01f, 4.984556e-01f, 2.613189e-01f, 5.818475e-01f, 3.225517e-01f, - 2.065654e-01f, 4.553546e-01f, 2.501175e-01f, 5.190718e-01f, 1.131343e+00f, 6.148388e-01f, 6.362602e-01f, 1.377521e+00f, 7.439550e-01f, 3.833026e-01f, 8.227519e-01f, 4.407146e-01f, - 3.261206e-01f, 6.969233e-01f, 3.717564e-01f, 7.627507e-01f, 1.620991e+00f, 8.600952e-01f, 8.814538e-01f, 1.866888e+00f, 9.873542e-01f, 5.046682e-01f, 1.064004e+00f, 5.602558e-01f, - 4.464697e-01f, 9.389536e-01f, 4.932274e-01f, 1.005908e+00f, 2.108550e+00f, 1.104095e+00f, 1.125322e+00f, 2.354009e+00f, 1.230180e+00f, 6.258913e-01f, 1.305581e+00f, 6.804127e-01f, + auto expected = NDArrayFactory::create('c', {bS, iC, iH, iW}, {9.661570e-04f, 9.671602e-03f, 1.306569e-02f, 3.679184e-02f, 1.297220e-01f, 1.040181e-01f, 1.126750e-01f, 3.320884e-01f, 2.340406e-01f, 1.333333e-01f, 3.352886e-01f, 2.070211e-01f, + 8.991618e-02f, 2.160601e-01f, 1.283173e-01f, 2.744226e-01f, 6.364498e-01f, 3.662123e-01f, 3.869788e-01f, 8.808994e-01f, 4.984556e-01f, 2.613189e-01f, 5.818475e-01f, 3.225517e-01f, + 2.065654e-01f, 4.553546e-01f, 2.501175e-01f, 5.190718e-01f, 1.131343e+00f, 6.148388e-01f, 6.362602e-01f, 1.377521e+00f, 7.439550e-01f, 3.833026e-01f, 8.227519e-01f, 4.407146e-01f, + 3.261206e-01f, 6.969233e-01f, 3.717564e-01f, 7.627507e-01f, 1.620991e+00f, 8.600952e-01f, 8.814538e-01f, 1.866888e+00f, 9.873542e-01f, 5.046682e-01f, 1.064004e+00f, 5.602558e-01f, + 4.464697e-01f, 9.389536e-01f, 4.932274e-01f, 1.005908e+00f, 2.108550e+00f, 1.104095e+00f, 1.125322e+00f, 2.354009e+00f, 1.230180e+00f, 6.258913e-01f, 1.305581e+00f, 6.804127e-01f, 5.671396e-01f, 1.181128e+00f, 6.145977e-01f, 1.248783e+00f, 2.595083e+00f, 1.347494e+00f, 1.368600e+00f, 2.840157e+00f, 1.472778e+00f, 7.470673e-01f, 1.547362e+00f, 8.008900e-01f}); input.linspace(1.); gradO.linspace(0.1, 0.1); @@ -2029,9 +2029,9 @@ TYPED_TEST(TypedConvolutionTests2, pnormpool2d_bp_3) { auto gradO = NDArrayFactory::create('c', {bS, iC, oH, oW}); auto expected = NDArrayFactory::create('c', {bS, iC, iH, iW}, {0.007931f, 0.042891f, 0.040544f, 0.09369f, 0.276841f, 0.191675f, 0.163957f, 0.442946f, 0.287512f, 0.154919f, 0.373153f, 0.221172f, 0.15901f, 0.365232f, 0.207846f, 0.428282f, 0.959455f, 0.534076f, 0.508585f, 1.128771f, 0.623089f, 0.319794f, 0.698063f, 0.379547f, - 0.321068f, 0.692438f, 0.372316f, 0.757521f, 1.620323f, 0.864566f, 0.838684f, 1.787943f, 0.951023f, 0.483194f, 1.023434f, 0.541058f, - 0.483937f, 1.019414f, 0.536145f, 1.085348f, 2.276996f, 1.192917f, 1.166749f, 2.443606f, 1.278126f, 0.646499f, 1.349361f, 0.703463f, - 0.647021f, 1.346249f, 0.699745f, 1.412654f, 2.932174f, 1.520512f, 1.494153f, 3.098146f, 1.604985f, 0.809791f, 1.675544f, 0.866229f, + 0.321068f, 0.692438f, 0.372316f, 0.757521f, 1.620323f, 0.864566f, 0.838684f, 1.787943f, 0.951023f, 0.483194f, 1.023434f, 0.541058f, + 0.483937f, 1.019414f, 0.536145f, 1.085348f, 2.276996f, 1.192917f, 1.166749f, 2.443606f, 1.278126f, 0.646499f, 1.349361f, 0.703463f, + 0.647021f, 1.346249f, 0.699745f, 1.412654f, 2.932174f, 1.520512f, 1.494153f, 3.098146f, 1.604985f, 0.809791f, 1.675544f, 0.866229f, 0.810192f, 1.673009f, 0.863237f, 1.739711f, 3.58665f, 1.847753f, 1.82126f, 3.752188f, 1.931741f, 0.973081f, 2.001861f, 1.029173f}); input.linspace(1.); gradO.linspace(0.1, 0.1); @@ -2128,4 +2128,297 @@ TEST_F(ConvolutionTests2, upsampling2d_bp_3) { delete results; } + +////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedConvolutionTests2, depthwise_conv2d_1) { + + int bS=2, iH=4,iW=3, iC=2,mC=2, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; + int oC=iC*mC; + int oH=4,oW=3; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + + auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); + auto weights = NDArrayFactory::create('c', {kH, kW, iC, mC}); + + + auto expOutput = NDArrayFactory::create('c', {bS, oH, oW, oC},{12.f, 12.8f, 13.6f, 14.4f, 12.f, 12.8f, 13.6f, 14.4f, 5.2f, 5.6f, 6.f, 6.4f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 5.4f, 6.f, 6.6f, 7.2f, + 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 5.4f, 6.f, 6.6f, 7.2f, 5.6f, 6.4f, 7.2f, 8.f, 5.6f, 6.4f, 7.2f, 8.f, 2.f, 2.4f, 2.8f, 3.2f, + 12.f, 12.8f, 13.6f, 14.4f, 12.f, 12.8f, 13.6f, 14.4f, 5.2f, 5.6f, 6.f, 6.4f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 5.4f, 6.f, 6.6f, 7.2f, + 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 5.4f, 6.f, 6.6f, 7.2f, 5.6f, 6.4f, 7.2f, 8.f, 5.6f, 6.4f, 7.2f, 8.f, 2.f, 2.4f, 2.8f, 3.2f}); + input = 2.; + weights.linspace(0.1, 0.1); + + nd4j::ops::depthwise_conv2d op; + auto results = op.execute({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); + auto* output = results->at(0); + + ASSERT_EQ(Status::OK(), results->status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); + + delete results; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests2, depthwise_conv2d_2) { + + int bS=2, iH=4,iW=3, iC=2,mC=2, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; + int oC=iC*mC; + int oH=2,oW=2; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + + auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); + auto weights = NDArrayFactory::create('c', {kH, kW, iC, mC}); + + + auto expOutput = NDArrayFactory::create('c', {bS, oH, oW, oC},{13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, + 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f}); + input = 2.; + weights.linspace(0.1, 0.1); + + nd4j::ops::depthwise_conv2d op; + auto results = op.execute({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); + auto* output = results->at(0); + + ASSERT_EQ(Status::OK(), results->status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); + + delete results; +} + + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests2, depthwise_conv2d_3) { + + int bS=2, iH=4,iW=3, iC=2,mC=2, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; + int oC=iC*mC; + int oH=2,oW=2; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + + auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); + auto weights = NDArrayFactory::create('c', {mC, iC, kH, kW}); + auto biases = NDArrayFactory::create('c', {iC*mC}, {1.f,2.f,3.f,4.f}); + + NDArray expOutput('c', {bS, oC, oH, oW},{5.2, 5.2, 5.2, 5.2,20.6,20.6,20.6,20.6,14.4,14.4,14.4,14.4,29.8,29.8,29.8,29.8, 5.2, 5.2, 5.2, 5.2,20.6,20.6,20.6,20.6,14.4,14.4,14.4,14.4,29.8,29.8,29.8,29.8}, nd4j::DataType::FLOAT32); + + input = 2.; + weights.linspace(0.1, 0.1); + weights.permutei({2,3,1,0}); + + nd4j::ops::depthwise_conv2d op; + auto results = op.execute({&input, &weights, &biases}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); + auto* output = results->at(0); + + ASSERT_EQ(Status::OK(), results->status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); + + delete results; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests2, depthwise_conv2d_4) { + + int bS=1, iH=111,iW=111, iC=32,mC=1, kH=7,kW=7, sH=2,sW=2, pH=0,pW=0, dH=1,dW=1; + int oC=iC*mC; + int oH=56,oW=56; + + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + + const float unique = -1000000; + + NDArray input('c', {bS, iH, iW, iC}, nd4j::DataType::FLOAT32); + NDArray weights('c', {kH, kW, iC, mC}, nd4j::DataType::FLOAT32); + NDArray output('c', {bS, oH, oW, oC}, nd4j::DataType::FLOAT32); + input.linspace(0.1, 0.0001); + weights = 0.5; + output = unique; + + nd4j::ops::depthwise_conv2d op; + Nd4jStatus status = op.execute({&input, &weights}, {&output} , {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}, {}); + + ASSERT_EQ(Status::OK(), status); + + for(Nd4jLong i=output.lengthOf()/1.5; i < output.lengthOf(); ++i) + ASSERT_EQ(output.e(i) != unique, true); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests2, depthwise_conv2d_5) { + + int bS=1, iH=3,iW=3, iC=2,mC=1, kH=2,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; + int oC=iC*mC; + int oH=3,oW=3; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + + auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); + auto weights = NDArrayFactory::create('c', {kH, kW, iC, mC}); + + NDArray expOutput('c', {bS, oH, oW, oC}, {10., 12., 14., 16., 8., 9., 22., 24., 26., 28., 14., 15., 14., 15., 16., 17., 8.5, 9.}, nd4j::DataType::FLOAT32); + + input.linspace(1.); + weights = 0.5; + + nd4j::ops::depthwise_conv2d op; + auto results = op.execute({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); + auto output = results->at(0); + + ASSERT_EQ(Status::OK(), results->status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); + + delete results; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests2, depthwise_conv2d_6) { + + int bS=1, iH=3,iW=3, iC=2,mC=1, kH=2,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; + int oC=iC*mC; + int oH=3,oW=3; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + + NDArray input('c', {bS, iH, iW, iC}, nd4j::DataType::FLOAT32); + NDArray weights('c', {kH, kW, iC, mC}, nd4j::DataType::FLOAT32); + + NDArray expOutput('c', {bS, oH, oW, oC}, {20., 24.,28., 32.,16., 18.,44., 48.,52., 56.,28., 30.,28., 30.,32., 34.,17., 18.}, nd4j::DataType::FLOAT32); + input.linspace(1.); + weights = 1.; + + nd4j::ops::depthwise_conv2d op; + ResultSet* results = op.execute({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); + NDArray* output = results->at(0); + // output.printIndexedBuffer(); + + ASSERT_EQ(Status::OK(), results->status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); + + delete results; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests2, depthwise_conv2d_7) { + + int bS=1, iH=3,iW=3, iC=2,mC=2, kH=1,kW=1, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; + int oC=iC*mC; + int oH=3,oW=3; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + + NDArray input('c', {bS, iC, iH, iW}, {0.6793503761291504, 0.35508695244789124, 0.842789351940155, 0.20031332969665527, 0.7014986872673035, 0.3106933832168579, + 0.44793984293937683, 0.9380097389221191, 0.3266739547252655, 0.15187257528305054, 0.3833175301551819, 0.7821229696273804, + 0.19880719482898712, 0.7985635995864868, 0.16326339542865753, 0.14696824550628662, 0.2608966827392578, 0.13505761325359344}, nd4j::DataType::FLOAT32); + NDArray weights('c', {kH, kW, iC, mC}, {0.1308445781469345, 0.6442840099334717, 0.5698848366737366, 0.19896849989891052}, nd4j::DataType::FLOAT32); + NDArray biases('c', {1,iC*mC}, {0.6123566627502441, 0.37637925148010254, 0.17464971542358398, 0.4270855486392975}, nd4j::DataType::FLOAT32); + + NDArray expOutput('c', {bS, oC, oH, oW}, {0.7012459761288241, 0.6588178652487691, 0.722631079971582, 0.6385665758716108, 0.7041439625563628, 0.6530092074102978, + 0.670967162534851, 0.735090151337225, 0.6551001785478623, 0.8140738359624038, 0.6051560970782859, 0.9193749546773375, 0.5054379267801892, 0.8283436386757472, + 0.5765540302788565, 0.6649797296980537, 0.9807239274294943, 0.586850056971322, 0.261199593183985, 0.3930965634902499, 0.6203697362284615, 0.28794692117826504, + 0.6297390019475202, 0.26769104886224415, 0.25840469001015975, 0.3233307788551656, 0.25161700129415276, 0.4573034071191504, 0.5033536625992294, 0.5827033826425385, + 0.4666419179635315, 0.585974550122895, 0.4595698215161401, 0.45632759998045813, 0.4789957702325296, 0.4539577593482922}, nd4j::DataType::FLOAT32); + + + nd4j::ops::depthwise_conv2d op; + auto results = op.execute({&input, &weights, &biases}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); + auto* output = results->at(0); + + ASSERT_EQ(Status::OK(), results->status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); + + delete results; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests2, depthwise_conv2d_8) { + + int bS=1, iH=10,iW=10, iC=8,mC=1, kH=3,kW=3, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; + int oC=iC*mC; + int oH=10,oW=10; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + + NDArray input('c', {bS, iH, iW, iC}, nd4j::DataType::FLOAT32); + NDArray weights('c', {kH, kW, iC, mC}, nd4j::DataType::FLOAT32); + + NDArray expOutput('c', {bS, oH, oW, oC}, {-42.879997, -43.959999, -44.959999, -45.879997, -46.720005, -47.480003, -48.160000, -48.760002, -43.519997, -45.139999, -46.639996, -48.020000, -49.280003, -50.419998, -51.440006, -52.340000, -31.999998, -33.139999, -34.160000, -35.060001, -35.840004, -36.500004, -37.039997, -37.459999, -20.480000, + -21.139997, -21.680000, -22.100000, -22.399998, -22.579998, -22.639996, -22.580002, -8.960000, -9.139998, -9.200002, -9.140001, -8.960001, -8.660000, -8.240002, -7.700001, 2.560000, 2.860002, 3.279998, 3.820000, 4.480001, 5.260000, 6.160001, 7.180000, 14.080000, 14.860000, 15.759998, 16.779999, 17.920002, 19.180000, 20.560001, 22.059998, + 25.600000, 26.860001, 28.239998, 29.739998, 31.360001, 33.099998, 34.959999, 36.939999, 37.119999, 38.860001, 40.720001, 42.699997, 44.800003, 47.020000, 49.360001, 51.820000, 26.239998, 27.400002, 28.639999, 29.959999, 31.360001, 32.840000, 34.400002, 36.040001, 62.400002, 62.459999, 62.639999, 62.940002, 63.360001, 63.900002, 64.559998, + 65.340004, 106.080002, 106.169998, 106.440002, 106.889999, 107.519997, 108.330002, 109.320000, 110.490005, 114.720001, 115.529999, 116.520004, 117.690002, 119.040009, 120.570000, 122.279999, 124.169998, 123.359985, 124.889999, 126.599998, 128.490005, 130.559998, 132.809998, 135.240005, 137.850006, 132.000000, 134.250000, 136.679993, + 139.290009, 142.080002, 145.049988, 148.199997, 151.529999, 140.639999, 143.610001, 146.760010, 150.089996, 153.600006, 157.290009, 161.160004, 165.209991, 149.279999, 152.970001, 156.839996, 160.889999, 165.120010, 169.529999, 174.119995, 178.889999, 157.919998, 162.330002, 166.919983, 171.690002, 176.639999, 181.769989, 187.079987, + 192.570007, 166.559998, 171.690002, 177.000000, 182.489990, 188.160004, 194.010010, 200.040009, 206.250000, 100.799995, 104.220001, 107.760002, 111.419998, 115.200005, 119.099998, 123.120003, 127.260010, 139.200012, 144.059998, 149.040009, 154.139999, 159.360001, 164.699997, 170.160004, 175.739990, 192.479996, 199.770020, 207.239990, + 214.889999, 222.720001, 230.730011, 238.919998, 247.290009, 201.119995, 209.129990, 217.319992, 225.690002, 234.240005, 242.970001, 251.880005, 260.970001, 209.760010, 218.489990, 227.399994, 236.490005, 245.760010, 255.209991, 264.839996, 274.649994, 218.399994, 227.850006, 237.479996, 247.289993, 257.279999, 267.449982, 277.799988, + 288.330017, 227.040009, 237.209991, 247.559998, 258.089996, 268.800018, 279.690002, 290.760010, 302.010010, 235.679993, 246.570007, 257.639984, 268.889984, 280.320007, 291.929993, 303.720001, 315.690002, 244.320007, 255.929993, 267.720001, 279.690002, 291.839996, 304.169983, 316.679993, 329.369995, 252.959991, 265.290009, 277.799988, + 290.489990, 303.359985, 316.410004, 329.640015, 343.050018, 139.199997, 147.419998, 155.760010, 164.220001, 172.799988, 181.500000, 190.319992, 199.260010, 216.000000, 225.660004, 235.440002, 245.339996, 255.360016, 265.500000, 275.760010, 286.140015, 278.880005, 293.369995, 308.040009, 322.889984, 337.920013, 353.129974, 368.519989, + 384.090027, 287.520020, 302.730011, 318.119995, 333.690002, 349.440002, 365.369995, 381.479980, 397.770020, 296.160004, 312.089996, 328.199982, 344.489990, 360.960022, 377.609985, 394.440002, 411.449982, 304.799988, 321.450012, 338.280029, 355.289978, 372.480011, 389.850006, 407.399994, 425.130005, 313.440002, 330.809998, 348.359985, 366.089996, 384.000000, 402.090027, 420.359985, 438.809998, 322.079987, 340.169983, 358.440002, 376.889984, 395.520020, 414.329987, 433.320007, 452.489990, 330.720001, 349.530029, 368.520020, 387.690002, 407.039978, 426.570007, 446.279999, 466.170013, 339.360016, 358.890015, 378.599976, 398.490021, 418.559998, 438.809998, 459.239990, 479.849976, 177.600006, 190.619995, 203.759995, 217.020004, 230.399994, 243.899994, 257.519989, 271.260010, 292.799988, 307.260010, 321.839996, 336.539978, 351.360016, 366.299988, 381.359985, 396.540009, 365.279999, 386.970001, 408.839996, 430.889984, 453.120026, 475.529968, 498.119995, 520.890015, 373.920013, 396.329987, 418.919983, 441.690002, 464.640015, 487.769958, 511.079987, 534.570007, 382.559998, 405.690002, 429.000000, 452.489990, 476.160004, 500.010010, 524.039978, 548.250000, 391.200012, 415.049988, 439.080017, 463.290009, 487.679993, 512.250000, 537.000000, 561.930054, 399.839996, 424.409973, 449.160034, 474.089966, 499.200012, 524.489990, 549.959961, 575.609985, 408.479980, 433.770020, 459.239990, 484.889954, 510.720032, 536.729980, 562.919983, 589.290039, 417.119995, 443.130005, 469.319977, 495.690002, 522.239990, 548.969971, 575.880005, 602.969971, 425.760010, 452.489990, 479.399994, 506.489990, 533.760010, 561.209961, 588.839966, 616.650024, 216.000000, 233.819992, 251.760010, 269.820007, 288.000000, 306.299988, 324.719971, 343.260010, 369.600006, 388.859985, 408.239990, 427.739990, 447.360016, 467.100006, 486.959961, 506.940002, 451.679993, 480.570007, 509.639984, 538.890015, 568.320007, 597.929993, 627.719971, 657.690002, 460.320007, 489.929993, 519.719971, 549.690002, 579.840027, 610.170044, 640.680054, 671.369995, 468.960022, 499.289978, 529.799988, 560.489990, 591.359985, 622.409973, 653.640015, 685.049988, 477.599976, 508.650024, 539.880005, 571.289978, 602.880005, 634.650024, 666.599976, 698.729980, 486.239990, 518.010010, 549.960022, 582.089966, 614.400024, 646.890015, 679.559937, 712.410034, 494.879974, 527.369995, 560.039978, 592.890015, 625.920044, 659.130005, 692.520020, 726.089966, 503.519989, 536.729980, 570.119995, 603.689941, 637.440063, 671.369995, 705.480042, 739.770020, 512.160034, 546.089966, 580.199951, 614.489990, 648.960022, 683.609985, 718.440002, 753.449951, 254.400009, 277.020020, 299.760010, 322.619995, 345.600006, 368.700012, 391.919983, 415.260010, 446.399994, 470.459961, 494.640015, 518.940002, 543.360046, 567.900024, 592.559998, 617.340027, 538.080017, 574.170044, 610.440002, 646.890015, 683.520020, 720.329956, 757.320007, 794.489990, 546.719971, 583.530029, 620.520020, 657.690002, 695.040039, 732.570007, 770.279968, 808.169983, 555.359985, 592.889954, 630.599976, 668.489990, 706.559998, 744.809998, 783.239990, 821.849976, 564.000000, 602.250000, 640.679993, 679.289978, 718.080017, 757.050049, 796.199951, 835.530029, 572.640015, 611.609985, 650.760010, 690.089966, 729.600037, 769.289978, 809.160034, 849.210083, 581.279968, 620.970032, 660.839966, 700.889954, 741.119995, 781.529968, 822.119995, 862.890015, 589.919983, 630.330017, 670.919983, 711.690002, 752.640015, 793.770020, 835.079956, 876.570007, 598.559998, 639.690002, 681.000000, 722.490051, 764.160034, 806.010010, 848.039978, 890.250061, 292.799988, 320.220001, 347.760010, 375.419983, 403.200012, 431.100006, 459.119995, 487.260010, 523.199951, 552.059998, 581.040039, 610.139954, 639.360046, 668.699951, 698.159973, 727.739990, 624.479980, 667.770020, 711.239990, 754.890015, 798.719971, 842.729980, 886.919983, 931.290039, 633.119995, 677.130005, 721.319946, 765.690002, 810.239990, 854.969971, 899.880005, 944.969971, 641.760010, 686.489990, 731.400024, 776.489990, 821.760010, 867.209961, 912.839966, 958.650024, 650.400024, 695.849976, 741.479980, 787.290039, 833.279968, 879.449951, 925.799927, 972.330017, 659.040039, 705.210022, 751.559998, 798.089966, 844.800049, 891.690002, 938.760010, 986.010010, 667.679993, 714.569946, 761.640015, 808.890015, 856.320007, 903.929993, 951.719971, 999.690063, 676.320007, 723.929993, 771.719971, 819.690002, 867.839966, 916.169922, 964.679932, 1013.369995, 684.959961, 733.290039, 781.800049, 830.489990, 879.359985, 928.410034, 977.640015, 1027.050049, 331.199982, 363.419983, 395.760010, 428.220001, 460.799988, 493.500000, 526.320007, 559.260010, 600.000000, 633.660034, 667.440002, 701.339966, 735.359985, 769.500000, 803.759949, 838.140015, 710.880005, 761.369995, 812.039978, 862.889893, 913.919983, 965.130005, 1016.520020, 1068.090088, 719.520020, 770.729980, 822.119934, 873.689941, 925.440063, 977.369995, 1029.479980, 1081.770020, 728.160034, 780.090088, 832.199951, 884.489990, 936.960022, 989.610046, 1042.439941, 1095.449951, 736.799927, 789.449951, 842.280029, 895.290039, 948.480042, 1001.849976, 1055.399902, 1109.129883, 745.439941, 798.810059, 852.359985, 906.089966, 960.000000, 1014.089966, 1068.359985, 1122.810059, 754.080017, 808.170044, 862.440002, 916.890015, 971.520020, 1026.330078, 1081.319946, 1136.489990, 762.720032, 817.530029, 872.520020, 927.689941, 983.040039, 1038.569946, 1094.280029, 1150.169922, 771.359985, 826.890015, 882.599976, 938.489990, 994.559998, 1050.810059, 1107.239990, 1163.849976, 369.599976, 406.619995, 443.760010, 481.020020, 518.400024, 555.900024, 593.520020, 631.260010, 113.279999, 136.839996, 160.480011, 184.199982, 208.000015, 231.880005, 255.839996, 279.880005, 31.359985, 66.699989, 102.160004, 137.740005, 173.440002, 209.260010, 245.199982, 281.260010, 31.359993, 67.179993, 103.120003, 139.179993, 175.360016, 211.660004, 248.079987, 284.619995, 31.359993, 67.659996, 104.080009, 140.619995, 177.280014, 214.060013, 250.959991, 287.980011, 31.359993, 68.139999, 105.039993, 142.059982, 179.200027, 216.459991, 253.839996, 291.339996, 31.360008, 68.619995, 106.000000, 143.499985, 181.119995, 218.860001, 256.719971, 294.700012, 31.360001, 69.099991, 106.959984, 144.939987, 183.040009, 221.260010, 259.600006, 298.059998, 31.360008, 69.579971, 107.920006, 146.379990, 184.960007, 223.660004, 262.479980, 301.419983, 31.360001, 70.059975, 108.880020, 147.819977, 186.880020, 226.059998, 265.359985, 304.779999, -83.840004, -58.040001, -32.159988, -6.200012, 19.840012, 45.959984, 72.159996, 98.440010}, nd4j::DataType::FLOAT32); + + input.linspace(-10, 0.1); + weights.linspace(-2, 0.1); + + nd4j::ops::depthwise_conv2d op; + ResultSet* results = op.execute({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); + NDArray* output = results->at(0); + // output->printBuffer(); + + ASSERT_EQ(Status::OK(), results->status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); + + delete results; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests2, depthwise_conv2d_9) { + + int bS=1, iH=10,iW=10, iC=8,mC=1, kH=3,kW=3, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; + int oC=iC*mC; + int oH=10,oW=10; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + + NDArray input('c', {bS, iC, iH, iW}, nd4j::DataType::FLOAT32); + NDArray weights('c', {kH, kW, iC, mC}, nd4j::DataType::FLOAT32); + + NDArray expOutput('c', {bS, oC, oH, oW}, {-103.360001, -131.440002, -130.000000, -128.559998, -127.120003, -125.680000, -124.240005, -122.799995, -121.360001, -66.720001,-76.199997, -81.239998, -80.160004, -79.080002, -78.000000, -76.919998, -75.840004, -74.760002, -73.680000, -29.400002, -66.599998, -70.440002, -69.360001, -68.279999, + -67.199997, -66.120003, -65.040001, -63.959999, -62.879997, -24.599997, -57.000000, -59.639999, -58.560005, -57.479996, -56.399998, -55.320000, -54.240002, -53.159996, -52.080002, -19.799997, -47.400002, -48.840000, -47.760002, -46.680000, -45.599998, -44.520000, -43.440002, -42.360001, -41.279999, -15.000000, -37.799999, -38.040001, + -36.959999, -35.879997, -34.799999, -33.720001, -32.639999, -31.560001, -30.479996, -10.199999, -28.200001, -27.240002, -26.160000, -25.080002, -24.000000, -22.919998,-21.840000, -20.759998, -19.679998, -5.400000, -18.599998, -16.439999, -15.360001, -14.280001, -13.200001, -12.120001, -11.040000, -9.960001, -8.880000, -0.600000, + -9.000000, -5.639999, -4.560000, -3.480000, -2.400000, -1.320001, -0.240000, 0.840001, 1.920000, 4.200000, 0.160000, 3.920000, 3.920000, 3.920000, 3.920000, 3.920000,3.920001, 3.920000, 3.920000, 3.520000, 8.860001, 12.920000, 14.420000, 15.920000, 17.420000, 18.920000, 20.420000, 21.920000, 23.420000, 13.820000, 20.430000, 27.750000, + 28.919998, 30.090000, 31.260000, 32.430000, 33.600002, 34.770000, 35.939999, 19.709999, 30.630001, 39.450001, 40.619999, 41.790001, 42.960003, 44.129997, 45.299999, 46.470001, 47.639999, 25.110001, 40.829998, 51.150002, 52.320000, 53.489998, 54.660004, 55.829994, 57.000000, 58.169998, 59.340004, 30.510002, 51.029999, 62.849998, + 64.019997, 65.190002, 66.360001, 67.529999, 68.699997, 69.870003, 71.040001, 35.910000, 61.229996, 74.550003, 75.720001, 76.889999, 78.059998, 79.229996, 80.400002, 81.570000, 82.740005, 41.310001, 71.430000, 86.250000, 87.419998, 88.589996, 89.760002, 90.929993, 92.099991, 93.270004, 94.440002, 46.709999, 81.630005, 97.949997, + 99.120003, 100.290009, 101.459999, 102.630005, 103.800003, 104.970001, 106.139999, 52.110001, 91.830002, 109.649994, 110.820007, 111.990005, 113.159996, 114.330002, 115.500000, 116.669998, 117.839996, 57.509995, 19.580000, 9.079998, 9.139999, 9.199999, 9.259996, 9.320001, 9.379998, 9.440000, 9.500000, -8.740000, 129.080002, 169.279999, + 170.839996, 172.399994, 173.960007, 175.520004, 177.080002, 178.639999, 180.199982, 102.360001, 129.059998, 154.739990, 156.000000, 157.259995, 158.520004, 159.779999, 161.039993, 162.300003, 163.559998, 80.820000, 139.860001, 167.340012, 168.600006, 169.860001, 171.119995, 172.380005, 173.639999, 174.899994, 176.160004, 86.820000, + 150.660004, 179.940002, 181.200012, 182.459991, 183.720001, 184.980011, 186.239990, 187.500000, 188.759995, 92.820007, 161.459991, 192.540009, 193.799988, 195.059998, 196.319992, 197.579987, 198.839996, 200.100006, 201.360001, 98.820000, 172.259995, 205.139999, 206.399994, 207.660004, 208.919983, 210.179993, 211.440002, 212.700012, + 213.959991, 104.819992, 183.059998, 217.739990, 219.000000, 220.259995, 221.519989, 222.779999, 224.039993, 225.300018, 226.559998, 110.819992, 193.860016, 230.339996, 231.600006, 232.860001, 234.119995, 235.380005, 236.639999, 237.900009, 239.160004, 116.820000, 204.660004, 242.940002, 244.199982, 245.459991, 246.720001, 247.980011, + 249.239990, 250.500000, 251.759995, 122.819992, 47.000000, 26.240004, 26.360004, 26.479998, 26.600002, 26.720001, 26.840002, 26.959997, 27.080000, -12.999998, 257.299988, 337.640015, 339.260010, 340.879974, 342.499969, 344.119995, 345.740021, 347.359985, 348.979980, 198.899994, 249.690002, 299.729980, 301.079987, 302.429993, 303.779999, 305.130005, 306.480011, 307.829987, 309.179993, 153.929993, 261.089996, 313.230011, 314.580017, 315.929993, 317.279968, 318.630005, 319.979980, 321.329987, 322.679993, 160.529999, 272.489990, 326.729980, 328.079987, 329.429993, 330.779968, 332.130005, 333.479980, 334.829987, 336.179993, 167.130005, 283.889984, 340.230011, 341.580017, 342.929993, 344.279999, 345.630005, 346.980011, 348.330017, 349.679993, 173.729996, 295.289978, 353.729980, 355.079987, 356.429993, 357.779968, 359.130005, 360.480011, 361.829987, 363.179993, 180.329987, 306.690002, 367.230011, 368.580017, 369.929993, 371.279999, 372.630005, 373.980011, 375.330017, 376.679993, 186.929993, 318.089996, 380.729980, 382.080017, 383.429993, 384.779968, 386.130005, 387.479980, 388.829987, 390.179993, 193.529984, 329.489990, 394.229980, 395.579987, 396.929993, 398.279999, 399.630005, 400.980011, 402.330017, 403.679993, 200.130005, 82.419998, 55.400005, 55.580002, 55.759995, 55.939999, 56.120003, 56.299995, 56.479996, 56.659996, -9.260002, 393.520020, 518.000000, 519.679993, 521.359985, 523.040039, 524.720032, 526.400024, 528.080017, 529.760010, 303.440002, 382.320007, 462.720032, 464.160004, 465.600037, 467.040009, 468.479980, 469.919983, 471.359985, 472.800018, 239.040009, 394.320007, 477.119995, 478.559998, 480.000000, 481.440002, 482.880005, 484.320007, 485.760010, 487.200012, 246.240005, 406.320007, 491.520020, 492.960022, 494.400024, 495.839996, 497.280029, 498.720032, 500.160004, 501.600037, 253.440002, 418.320007, 505.919983, 507.359985, 508.800018, 510.240051, 511.680023, 513.119995, 514.559998, 516.000000, 260.640015, 430.319977, 520.320007, 521.760010, 523.200012, 524.640015, 526.079956, 527.520020, 528.960022, 530.400024, 267.839996, 442.320007, 534.720032, 536.160034, 537.600037, 539.040039, 540.479980, 541.919983, 543.359985, 544.800049, 275.040009, 454.320007, 549.119995, 550.559998, 552.000000, 553.440002, 554.880005, 556.320007, 557.760010, 559.200012, 282.239990, 466.320007, 563.520020, 564.960022, 566.400024, 567.839966, 569.280029, 570.720032, 572.160034, 573.600037, 289.440002, 125.839996, 96.559998, 96.799995, 97.040009, 97.280014, 97.520004, 97.759995, 98.000000, 98.240013, 2.480007, 537.739990, 710.359985, 712.099976, 713.840027, 715.579956, 717.319946, 719.059998, 720.799988, 722.539978, 415.980011, 526.950012, 643.710022, 645.240051, 646.770020, 648.300049, 649.829956, 651.359985, 652.890015, 654.419983, 336.149994, 539.549988, 659.010010, 660.539978, 662.070007, 663.600037, 665.130005, 666.660034, 668.190002, 669.720032, 343.950012, 552.150024, 674.309998, 675.839966, 677.369995, 678.900024, 680.429993, 681.960022, 683.490051, 685.020020, 351.750000, 564.750000, 689.609985, 691.140015, 692.669983, 694.200012, 695.729980, 697.260010, 698.789978, 700.320007, 359.549988, 577.349976, 704.910034, 706.440002, 707.970032, 709.500000, 711.029968, 712.559998, 714.089966, 715.619995, 367.350037, 589.950012, 720.210022, 721.740051, 723.270020, 724.800049, 726.329956, 727.859985, 729.390015, 730.919983, 375.149994, 602.549988, 735.510010, 737.039978, 738.570007, 740.100037, 741.630005, 743.160034, 744.690002, 746.220032, 382.950012, 615.150024, 750.809998, 752.339966, 753.869995, 755.399963, 756.929993, 758.460022, 759.990051, 761.520020, 390.750000, 177.260010, 149.720001, 150.020004, 150.319992, 150.619995, 150.919998, 151.220001, 151.520004, 151.819992, 22.220009, 689.959961, 914.720032, 916.519958, 918.319946, 920.119995, 921.919983, 923.719971, 925.520020, 927.320007, 536.519958, 683.579956, 842.699951, 844.319946, 845.940002, 847.559998, 849.179993, 850.799988, 852.419983, 854.039978, 445.260010, 696.779968, 858.900024, 860.520020, 862.140015, 863.760010, 865.380005, 867.000000, 868.619995, 870.239990, 453.659973, 709.979980, 875.099976, 876.719971, 878.339966, 879.959961, 881.579956, 883.199951, 884.819946, 886.440002, 462.059998, 723.179993, 891.299988, 892.919983, 894.539978, 896.159973, 897.779968, 899.400024, 901.020020, 902.640015, 470.459991, 736.380005, 907.500000, 909.119995, 910.739990, 912.359985, 913.979980, 915.599976, 917.219971, 918.839966, 478.859985, 749.579956, 923.699951, 925.319946, 926.940002, 928.559998, 930.179993, 931.799988, 933.419983, 935.039978, 487.260010, 762.779968, 939.900024, 941.520020, 943.140015, 944.760010, 946.380005, 948.000000, 949.619995, 951.239990, 495.659973, 775.979980, 956.099976, 957.719971, 959.339966, 960.959961, 962.579956, 964.199951, 965.819946, 967.440002, 504.059998, 236.679977, 214.880005, 215.239990, 215.599991, 215.959991, 216.319992, 216.679993, 217.040009, 217.399994, 49.959995, 850.180054, 1131.079956, 1132.939941, 1134.800049, 1136.660034, 1138.520020, 1140.380005, 1142.239990, 1144.100098, 665.060059, 852.209961, 1059.689941, 1061.399902, 1063.110107, 1064.820068, 1066.530029, 1068.239990, 1069.950073, 1071.660034, 566.370056, 866.010010, 1076.790039, 1078.500000, 1080.209961, 1081.920044, 1083.630005, 1085.339966, 1087.050049, 1088.760010, 575.369995, 879.809998, 1093.890015, 1095.599976, 1097.310059, 1099.020020, 1100.729980, 1102.439941, 1104.149902, 1105.859985, 584.369995, 893.609985, 1110.989990, 1112.699951, 1114.410034, 1116.120117, 1117.830078, 1119.540039, 1121.250000, 1122.959961, 593.370056, 907.410034, 1128.089966, 1129.800049, 1131.510010, 1133.220093, 1134.929932, 1136.639893, 1138.349976, 1140.060059, 602.369995, 921.209961, 1145.189941, 1146.900024, 1148.609985, 1150.320068, 1152.030029, 1153.739990, 1155.449951, 1157.160034, 611.370056, 935.010010, 1162.290039, 1164.000000, 1165.709961, 1167.420044, 1169.130005, 1170.839966, 1172.550049, 1174.260010, 620.369995, 948.809998, 1179.390015, 1181.099976, 1182.810059, 1184.520020, 1186.229980, 1187.939941, 1189.650024, 1191.359985, 629.370056, 304.099976, 292.039978, 292.460022, 292.880005, 293.300018, 293.720001, 294.140015, 294.559998, 294.980042, 85.700005}, nd4j::DataType::FLOAT32); + + input.linspace(-10, 0.1); + weights.linspace(-2, 0.1); + + nd4j::ops::depthwise_conv2d op; + ResultSet* results = op.execute({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); + NDArray* output = results->at(0); + + ASSERT_EQ(Status::OK(), results->status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output, 1e-4)); + + delete results; +} + #endif //LIBND4J_CONVOLUTIONTESTS2_H \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/CuDnnTests.cu b/libnd4j/tests_cpu/layers_tests/CuDnnTests.cu new file mode 100644 index 000000000..8809ad894 --- /dev/null +++ b/libnd4j/tests_cpu/layers_tests/CuDnnTests.cu @@ -0,0 +1,128 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + + +// +// @author raver119@gmail.com +// + +#include "testlayers.h" +#include +#include +#include +#include +#include + +#ifdef HAVE_CUDNN + +#include + +#endif + +using namespace nd4j; + +class CuDnnTests : public testing::Test { +public: + +}; + +static void printer(std::initializer_list helpers) { + + for (auto v:helpers) { + nd4j_printf("Initialized [%s]\n", v->name().c_str()); + } +} + + +TEST_F(CuDnnTests, helpers_includer) { + // we need this block, to make sure all helpers are still available within binary, and not optimized out by linker +#ifdef HAVE_CUDNN + nd4j::ops::platforms::PLATFORM_conv2d_ENGINE_CUDA conv2d; + nd4j::ops::platforms::PLATFORM_conv2d_bp_ENGINE_CUDA conv2d_bp; + nd4j::ops::platforms::PLATFORM_conv3dnew_ENGINE_CUDA conv3dnew; + nd4j::ops::platforms::PLATFORM_conv3dnew_bp_ENGINE_CUDA conv3dnew_bp; + nd4j::ops::platforms::PLATFORM_depthwise_conv2d_ENGINE_CUDA depthwise_conv2d; + nd4j::ops::platforms::PLATFORM_depthwise_conv2d_bp_ENGINE_CUDA depthwise_conv2d_bp; + nd4j::ops::platforms::PLATFORM_batchnorm_ENGINE_CUDA batchnorm; + + printer({&conv2d}); + printer({&conv2d_bp}); + printer({&conv3dnew}); + printer({&conv3dnew_bp}); + printer({&depthwise_conv2d}); + printer({&depthwise_conv2d_bp}); + printer({&batchnorm}); +#endif +} + + +TEST_F(CuDnnTests, mixed_helpers_test_1) { +#if defined(HAVE_CUDNN) && defined (HAVE_MKLDNN) + nd4j_printf("Mixed platforms test\n", ""); + + + int bS=2, iH=4,iW=3, iC=4,oC=3, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; + int oH=2,oW=2; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + + auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); + auto weights = NDArrayFactory::create('c', {oC, iC, kH, kW}); + auto bias = NDArrayFactory::create('c', {oC}, {1,2,3}); + + auto expOutput = NDArrayFactory::create('c', {bS, oC, oH, oW}, {61.f, 61.f, 61.f, 61.f, 177.2f, 177.2f, 177.2f, 177.2f, 293.4f, 293.4f, 293.4f, 293.4f, 61.f, 61.f, 61.f, 61.f, 177.2f, 177.2f, 177.2f, 177.2f, 293.4f, 293.4f, 293.4f, 293.4f}); + auto zCUDA = expOutput.like(); + auto zMKL = expOutput.like(); + + input = 2.; + weights.linspace(0.1, 0.1); + weights.permutei({2,3,1,0}); + + input.syncToHost(); + weights.syncToHost(); + bias.syncToHost(); + + nd4j::ops::conv2d op; + + // cuDNN part + Context cuda(1); + cuda.setTargetEngine(samediff::Engine::ENGINE_CUDA); + cuda.setInputArray(0, &input); + cuda.setInputArray(1, &weights); + cuda.setInputArray(2, &bias); + cuda.setOutputArray(0, &zCUDA); + cuda.setIArguments({kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); + auto statusCUDA = op.execute(&cuda); + + ASSERT_EQ(Status::OK(), statusCUDA); + ASSERT_EQ(expOutput, zCUDA); + + // MKL-DNN part + Context mkl(1); + mkl.setTargetEngine(samediff::Engine::ENGINE_CPU); + mkl.setInputArray(0, &input); + mkl.setInputArray(1, &weights); + mkl.setInputArray(2, &bias); + mkl.setOutputArray(0, &zMKL); + mkl.setIArguments({kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); + auto statusMKL = op.execute(&mkl); + + zMKL.tickWriteHost(); + + ASSERT_EQ(Status::OK(), statusMKL); + ASSERT_EQ(expOutput, zMKL); +#endif +} \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/DataBufferTests.cpp b/libnd4j/tests_cpu/layers_tests/DataBufferTests.cpp new file mode 100644 index 000000000..03d7bb38e --- /dev/null +++ b/libnd4j/tests_cpu/layers_tests/DataBufferTests.cpp @@ -0,0 +1,78 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include "testlayers.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace nd4j; +using namespace nd4j::graph; +using namespace nd4j::memory; + +class DataBufferTests : public testing::Test { +public: + +}; + +TEST_F(DataBufferTests, test_alloc_limit_1) { + if (!Environment::getInstance()->isCPU()) + return; + + auto deviceId = AffinityManager::currentDeviceId(); + auto odLimit = MemoryCounter::getInstance()->deviceLimit(deviceId); + auto ogLimit = MemoryCounter::getInstance()->groupLimit(MemoryType::HOST); + auto odUse = MemoryCounter::getInstance()->allocatedDevice(deviceId); + auto ogUse = MemoryCounter::getInstance()->allocatedGroup(MemoryType::HOST); + + auto limitSize = odUse + (150 * 1024 * 1024); + auto allocSize = 100000000; + + MemoryCounter::getInstance()->setDeviceLimit(deviceId, odLimit + limitSize); + MemoryCounter::getInstance()->setGroupLimit(MemoryType::HOST, odLimit + limitSize); + + DataBuffer buffer(allocSize, DataType::INT32); + + // separately testing per-device limits and group limits + ASSERT_EQ(odUse + allocSize, MemoryCounter::getInstance()->allocatedDevice(deviceId)); + ASSERT_EQ(ogUse + allocSize, MemoryCounter::getInstance()->allocatedGroup(MemoryType::HOST)); + + + // setting smaller limits, to make sure next allocation fails with OOM exception + MemoryCounter::getInstance()->setDeviceLimit(deviceId, allocSize - 100); + MemoryCounter::getInstance()->setGroupLimit(MemoryType::HOST, allocSize - 100); + + try { + DataBuffer bufferFailed(allocSize, DataType::INT32); + ASSERT_TRUE(false); + } catch (allocation_exception &e) { + // we expect exception here + } + + // restore original limits, so subsequent tests do not fail + MemoryCounter::getInstance()->setDeviceLimit(deviceId, odLimit); + MemoryCounter::getInstance()->setGroupLimit(MemoryType::HOST, odLimit); +} \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/DataBufferTestsCuda.cu b/libnd4j/tests_cpu/layers_tests/DataBufferTestsCuda.cu new file mode 100644 index 000000000..4f309cff5 --- /dev/null +++ b/libnd4j/tests_cpu/layers_tests/DataBufferTestsCuda.cu @@ -0,0 +1,89 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include "testlayers.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace nd4j; +using namespace nd4j::graph; +using namespace nd4j::memory; + +class DataBufferTestsCuda : public testing::Test { +public: + +}; + +/* +TEST_F(DataBufferTestsCuda, test_alloc_limit_1) { + auto deviceId = AffinityManager::currentDeviceId(); + + auto odLimit = MemoryCounter::getInstance()->deviceLimit(deviceId); + + auto opLimit = MemoryCounter::getInstance()->groupLimit(MemoryType::HOST); + auto osLimit = MemoryCounter::getInstance()->groupLimit(MemoryType::DEVICE); + + auto odUse = MemoryCounter::getInstance()->allocatedDevice(deviceId); + + auto opUse = MemoryCounter::getInstance()->allocatedGroup(MemoryType::HOST); + auto osUse = MemoryCounter::getInstance()->allocatedGroup(MemoryType::DEVICE); + + auto limitSize = odUse + 150000000; + auto allocSize = 100000000; + + MemoryCounter::getInstance()->setDeviceLimit(deviceId, odLimit + limitSize); + MemoryCounter::getInstance()->setGroupLimit(MemoryType::HOST, opLimit + limitSize); + MemoryCounter::getInstance()->setGroupLimit(MemoryType::DEVICE, osLimit + limitSize); + + DataBuffer buffer(allocSize, DataType::INT32, nullptr, true); + + // separately testing per-device limits and group limits + ASSERT_EQ(odUse + allocSize, MemoryCounter::getInstance()->allocatedDevice(deviceId)); + ASSERT_EQ(opUse + allocSize, MemoryCounter::getInstance()->allocatedGroup(MemoryType::HOST)); + ASSERT_EQ(osUse + allocSize, MemoryCounter::getInstance()->allocatedGroup(MemoryType::DEVICE)); + + // setting smaller limits, to make sure next allocation fails with OOM exception + MemoryCounter::getInstance()->setDeviceLimit(deviceId, allocSize - 100); + MemoryCounter::getInstance()->setGroupLimit(MemoryType::DEVICE, allocSize - 100); + + + // this allocation should fail, since we're allocating too much + try { + DataBuffer bufferFailed(allocSize + 1, DataType::INT32); + ASSERT_TRUE(false); + } catch (allocation_exception &e) { + // we expect exception here + } + + // + + // restore original limits, so subsequent tests do not fail + MemoryCounter::getInstance()->setDeviceLimit(deviceId, odLimit); + MemoryCounter::getInstance()->setGroupLimit(MemoryType::HOST, opLimit); + MemoryCounter::getInstance()->setGroupLimit(MemoryType::DEVICE, osLimit); +} + */ \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/DataTypesValidationTests.cpp b/libnd4j/tests_cpu/layers_tests/DataTypesValidationTests.cpp index 45b35eb4e..e87dfa125 100644 --- a/libnd4j/tests_cpu/layers_tests/DataTypesValidationTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/DataTypesValidationTests.cpp @@ -29,8 +29,6 @@ #include #include -using namespace nd4j; - using namespace nd4j; using namespace nd4j::graph; @@ -126,8 +124,8 @@ TEST_F(DataTypesValidationTests, cast_1) { float16 x = static_cast(1.f); float y = static_cast(x); - ASSERT_TRUE(1.f == x); - ASSERT_TRUE(y == x); + ASSERT_TRUE(static_cast(1.f) == x); + ASSERT_TRUE(y == static_cast(x)); } TEST_F(DataTypesValidationTests, test_bits_hamming_distance_1) { diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp index 591746804..1e43081c1 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp @@ -786,7 +786,7 @@ TEST_F(DeclarableOpsTests1, ReverseSubtractTest_2) { x.assign(3.f); y.assign(1.f); exp.assign(-2.f); - x.applyTrueBroadcast(BROADCAST(ReverseSubtract), &y, &z, true); + x.applyTrueBroadcast(BROADCAST(ReverseSubtract), y, z, true); ASSERT_TRUE(exp.equalsTo(&z)); @@ -811,7 +811,7 @@ TEST_F(DeclarableOpsTests1, ReverseSubtractTest_3) { x.assign(1); y.assign(3); exp.assign(2); - x.applyTrueBroadcast(BROADCAST(ReverseSubtract), &y, &z, true); + x.applyTrueBroadcast(BROADCAST(ReverseSubtract), y, z, true); ASSERT_TRUE(z.equalsTo(&exp)); nd4j::ops::reversesubtract subOp; @@ -833,10 +833,10 @@ TEST_F(DeclarableOpsTests1, ReverseModTest_1) { x.assign(2.); y.assign(9.f); exp.assign(1.f); - y.applyTrueBroadcast(BROADCAST(Mod), &x, &z, true); + y.applyTrueBroadcast(BROADCAST(Mod), x, z, true); ASSERT_TRUE(exp.equalsTo(&z)); - x.applyTrueBroadcast(BROADCAST(ReverseMod), &y, &exp, true); + x.applyTrueBroadcast(BROADCAST(ReverseMod), y, exp, true); ASSERT_TRUE(exp.equalsTo(&z)); nd4j::ops::reversemod subOp; @@ -861,9 +861,9 @@ TEST_F(DeclarableOpsTests1, ReverseModTest_2) { x.assign(2.f); y.assign(9.f); exp.assign(1.f); - x.applyTrueBroadcast(BROADCAST(ReverseMod), &y, &z, true); + x.applyTrueBroadcast(BROADCAST(ReverseMod), y, z, true); ASSERT_TRUE(z.equalsTo(&exp)); - x.applyTrueBroadcast(BROADCAST(ReverseMod), &y, &exp, true); + x.applyTrueBroadcast(BROADCAST(ReverseMod), y, exp, true); ASSERT_TRUE(z.equalsTo(&exp)); nd4j::ops::reversemod subOp; @@ -1218,8 +1218,8 @@ TEST_F(DeclarableOpsTests1, BroadcastReverseDivideTest_1) { ASSERT_TRUE(res->at(0)->equalsTo(exp)); auto z(exp); - x.applyTrueBroadcast(BROADCAST(ReverseDivide), &y, &z, true); - y.applyTrueBroadcast(BROADCAST(Divide), &x, &exp, true); + x.applyTrueBroadcast(BROADCAST(ReverseDivide), y, z, true); + y.applyTrueBroadcast(BROADCAST(Divide), x, exp, true); ASSERT_TRUE(z.equalsTo(&exp)); @@ -1759,7 +1759,7 @@ TEST_F(DeclarableOpsTests1, Transpose1) { Nd4jStatus status = transpose.execute(block); ASSERT_EQ(ND4J_STATUS_OK, status); - // ASSERT_TRUE(x.isSameShapeStrict(&exp)); + // ASSERT_TRUE(x.isSameShapeStrict(exp)); for (int e = 0; e < x->rankOf() * 2 + 2; e++) { ASSERT_EQ(x->getShapeInfo()[e], exp->getShapeInfo()[e]); @@ -1790,7 +1790,7 @@ TEST_F(DeclarableOpsTests1, Transpose2) { ASSERT_EQ(ND4J_STATUS_OK, status); auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); - // ASSERT_TRUE(result->isSameShapeStrict(&exp)); + // ASSERT_TRUE(result->isSameShapeStrict(exp)); for (int e = 0; e < result->rankOf() * 2 + 2; e++) { ASSERT_EQ(result->getShapeInfo()[e], exp->getShapeInfo()[e]); } @@ -1828,7 +1828,7 @@ TEST_F(DeclarableOpsTests1, Permute1) { Nd4jStatus status = permute.execute(block); ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(x->isSameShapeStrict(exp)); + ASSERT_TRUE(x->isSameShapeStrict(*exp)); delete exp; delete block; @@ -1863,7 +1863,7 @@ TEST_F(DeclarableOpsTests1, Permute2) { auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(result->isSameShapeStrict(exp)); + ASSERT_TRUE(result->isSameShapeStrict(*exp)); delete block; delete variableSpace; @@ -2468,7 +2468,7 @@ TEST_F(DeclarableOpsTests1, sru_bi_bp_1) { NDArray expGradX('c', {N,bS,2*K}, expGradXBuff); NDArray expGradW('c', {N,2*K,6*K}, expGradWBuff); auto expGradB = NDArrayFactory::create('c', {4*K}); - gradBias.reduceAlongDimension(reduce::Sum, &expGradB, {0}); // [bS, 4K] -> [4K] + gradBias.reduceAlongDimension(reduce::Sum, expGradB, {0}); // [bS, 4K] -> [4K] NDArray expGradInit('c', {bS,2*K}, expGradInitBuff); input.assign(1.5); @@ -2827,7 +2827,7 @@ TEST_F(DeclarableOpsTests1, Stack_1) { auto results = op.execute({&input1, &input2}, {}, {0}); auto output = results->at(0); - ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.equalsTo(output)); delete results; @@ -2855,7 +2855,7 @@ TEST_F(DeclarableOpsTests1, Stack_2) { auto results = op.execute({&input1, &input2}, {}, {1}); auto output = results->at(0); - ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.equalsTo(output)); delete results; @@ -2883,7 +2883,7 @@ TEST_F(DeclarableOpsTests1, Stack_3) { auto results = op.execute({&input1, &input2}, {}, {0}); auto output = results->at(0); - ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.equalsTo(output)); delete results; @@ -2910,7 +2910,7 @@ TEST_F(DeclarableOpsTests1, Stack_4) { auto results = op.execute({&input1, &input2}, {}, {1}); auto output = results->at(0); - ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.equalsTo(output)); delete results; @@ -2937,7 +2937,7 @@ TEST_F(DeclarableOpsTests1, Stack_5) { auto results = op.execute({&input1, &input2}, {}, {0}); auto output = results->at(0); - ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.equalsTo(output)); delete results; @@ -2964,7 +2964,7 @@ TEST_F(DeclarableOpsTests1, Stack_6) { auto results = op.execute({&input1, &input2}, {}, {1}); auto output = results->at(0); - ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.equalsTo(output)); delete results; @@ -2988,7 +2988,7 @@ TEST_F(DeclarableOpsTests1, Stack_7) { auto results = op.execute({&input1, &input1, &input1}, {}, {0}); auto output = results->at(0); - ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.equalsTo(output)); delete results; @@ -3011,7 +3011,7 @@ TEST_F(DeclarableOpsTests1, Stack_8) { auto results = op.execute({&input1, &input1, &input1}, {}, {0}); auto output = results->at(0); - ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.equalsTo(output)); delete results; @@ -3034,7 +3034,7 @@ TEST_F(DeclarableOpsTests1, Stack_9) { auto results = op.execute({&input1, &input1, &input1}, {}, {1}); auto output = results->at(0); - ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.equalsTo(output)); delete results; @@ -3060,7 +3060,7 @@ TEST_F(DeclarableOpsTests1, Stack_10) { //expected.printShapeInfo("exp"); //output->printShapeInfo("out"); - ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.equalsTo(output)); delete results; @@ -3082,7 +3082,7 @@ TEST_F(DeclarableOpsTests1, Stack_11) { auto results = op.execute({&input1, &input1, &input1}, {}, {}); auto output = results->at(0); - ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.equalsTo(output)); delete results; @@ -3370,7 +3370,7 @@ TEST_F(DeclarableOpsTests1, Reverse_1 ) { auto result = results->at(0); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -3395,7 +3395,7 @@ TEST_F(DeclarableOpsTests1, Reverse_2 ) { auto result = results->at(0); - ASSERT_TRUE(expected.isSameShapeStrict(&input)); + ASSERT_TRUE(expected.isSameShapeStrict(input)); ASSERT_TRUE(expected.equalsTo(&input)); delete results; @@ -3421,7 +3421,7 @@ TEST_F(DeclarableOpsTests1, Reverse_3 ) { auto result = results->at(0); // result->printBuffer(); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -3447,7 +3447,7 @@ TEST_F(DeclarableOpsTests1, Reverse_4 ) { auto result = results->at(0); // result->printBuffer(); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -3472,7 +3472,7 @@ TEST_F(DeclarableOpsTests1, Reverse_5 ) { auto result = results->at(0); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -3498,7 +3498,7 @@ TEST_F(DeclarableOpsTests1, Reverse_6 ) { auto result = results->at(0); // result->printBuffer(); - ASSERT_TRUE(expected.isSameShapeStrict(&input)); + ASSERT_TRUE(expected.isSameShapeStrict(input)); ASSERT_TRUE(expected.equalsTo(&input)); delete results; @@ -3526,7 +3526,7 @@ TEST_F(DeclarableOpsTests1, Reverse_7 ) { //expected.printIndexedBuffer("E"); //result->printIndexedBuffer("R"); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -3554,7 +3554,7 @@ TEST_F(DeclarableOpsTests1, Reverse_8 ) { auto result = results->at(0); // result->printBuffer(); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -3579,7 +3579,7 @@ TEST_F(DeclarableOpsTests1, Reverse_9 ) { auto result = results->at(0); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -3618,7 +3618,7 @@ TEST_F(DeclarableOpsTests1, Reverse_11 ) { auto result = results->at(0); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -3640,7 +3640,7 @@ TEST_F(DeclarableOpsTests1, Reverse_12 ) { auto result = results->at(0); //result->printIndexedBuffer("Result reverse"); //expected.printIndexedBuffer("Expected reverse"); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -3661,7 +3661,7 @@ TEST_F(DeclarableOpsTests1, Reverse_13 ) { auto result = results->at(0); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -3682,7 +3682,7 @@ TEST_F(DeclarableOpsTests1, Reverse_14 ) { auto result = results->at(0); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp index 21c18299e..689969543 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp @@ -579,6 +579,29 @@ TEST_F(DeclarableOpsTests10, IGamma_Test2) { delete result; } +////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, LGamma_Test1) { + + auto x = NDArrayFactory::create('c', {3, 3}, {0.1, 0.5, 0.7, 1.5, 1.7, 2.0, 2.5, 2.7, 3.}); + + auto exp = NDArrayFactory::create('c', {3,3}, { + 2.2527127 , 0.5723649 , 0.26086727, + -0.12078223, -0.09580769, 0., + 0.28468287, 0.4348206 , 0.6931472 + }); + + nd4j::ops::lgamma op; + auto result = op.execute({&x}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result->status()); + auto z = result->at(0); +// z->printBuffer("OUtput"); +// exp.printBuffer("EXpect"); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + delete result; +} + ////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, range_test10) { @@ -2356,7 +2379,7 @@ TEST_F(DeclarableOpsTests10, ReduceLogSumExpTest_1) { auto result = results->at(0); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -2378,7 +2401,7 @@ TEST_F(DeclarableOpsTests10, ReduceLogSumExpTest_2) { auto result = results->at(0); // result->printIndexedBuffer("REDUCE_LOGSUMEXP"); // expected.printIndexedBuffer("LSE EXPECTED"); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -2398,7 +2421,7 @@ TEST_F(DeclarableOpsTests10, ReduceLogSumExpTest_3) { auto result = results->at(0); // result->printIndexedBuffer("REDUCE_LOGSUMEXP"); // expected.printIndexedBuffer("LSE EXPECTED"); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -2419,7 +2442,7 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_1) { NDArray* result = results->at(0); //result->printIndexedBuffer("OOOOUUUUTTT"); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -2440,7 +2463,7 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_2) { NDArray* result = results->at(0); // result->printBuffer("NonMaxSuppression OUtput2"); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -2462,7 +2485,7 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_3) { NDArray* result = results->at(0); // result->printBuffer("NonMaxSuppression OUtput3"); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -2485,7 +2508,7 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_4) { NDArray* result = results->at(0); // result->printBuffer("NonMaxSuppression OUtput4"); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -2507,7 +2530,7 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_5) { NDArray* result = results->at(0); // result->printBuffer("NonMaxSuppression OUtput4"); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -2531,7 +2554,7 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_6) { NDArray* result = results->at(0); // result->printBuffer("NonMaxSuppression OUtput6"); // result->printShapeInfo("Ouput6 shape is"); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -2555,7 +2578,7 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_06) { NDArray* result = results->at(0); // result->printBuffer("NonMaxSuppression OUtput06"); // result->printShapeInfo("Ouput06 shape is"); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -2602,7 +2625,7 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressingOverlap_1) { NDArray* result = results->at(0); // result->printBuffer("NonMaxSuppressionOverlap1 Output"); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -2627,7 +2650,7 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressingOverlap_2) { NDArray* result = results->at(0); // result->printBuffer("NonMaxSuppressionOverlap Output"); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -2652,7 +2675,7 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressingOverlap_3) { NDArray* result = results->at(0); // result->printBuffer("NonMaxSuppressionOverlap Output"); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -2677,7 +2700,7 @@ TEST_F(DeclarableOpsTests10, Image_CropAndResize_1) { auto result = results->at(0); // result->printIndexedBuffer("Cropped and Resized"); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -2701,7 +2724,7 @@ TEST_F(DeclarableOpsTests10, Image_CropAndResize_2) { auto result = results->at(0); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -2725,7 +2748,7 @@ TEST_F(DeclarableOpsTests10, Image_CropAndResize_3) { auto result = results->at(0); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -2749,7 +2772,7 @@ TEST_F(DeclarableOpsTests10, Image_CropAndResize_4) { auto result = results->at(0); // result->printIndexedBuffer("Cropped and Resized"); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -2773,7 +2796,7 @@ TEST_F(DeclarableOpsTests10, Image_CropAndResize_5) { auto result = results->at(0); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); //ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -2811,7 +2834,7 @@ TEST_F(DeclarableOpsTests10, Image_DrawBoundingBoxes_1) { result->syncToHost(); // result->printBuffer("Bounded boxes"); // expected.printBuffer("Bounded expec"); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -2844,7 +2867,7 @@ TEST_F(DeclarableOpsTests10, Image_DrawBoundingBoxes_2) { // result->syncToHost(); // result->printBuffer("Bounded boxes 2"); // expected.printBuffer("Bounded expec 2"); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -2899,7 +2922,7 @@ TEST_F(DeclarableOpsTests10, Image_DrawBoundingBoxes_3) { // result->syncToHost(); // result->printBuffer("Bounded boxes 2"); // expected.printBuffer("Bounded expec 2"); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -2921,7 +2944,7 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_1) { auto result = results->at(0); // result->printBuffer("Quantized"); // exp.printBuffer("Expected"); - ASSERT_TRUE(exp.isSameShapeStrict(result)); + ASSERT_TRUE(exp.isSameShapeStrict(*result)); ASSERT_TRUE(exp.equalsTo(result)); delete results; @@ -2941,7 +2964,7 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_2) { auto result = results->at(0); // result->printIndexedBuffer("Quantized2"); - ASSERT_TRUE(exp.isSameShapeStrict(result)); + ASSERT_TRUE(exp.isSameShapeStrict(*result)); ASSERT_TRUE(exp.equalsTo(result)); delete results; @@ -2962,7 +2985,7 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_3) { auto result = results->at(0); // result->printIndexedBuffer("Quantized2"); - ASSERT_TRUE(exp.isSameShapeStrict(result)); + ASSERT_TRUE(exp.isSameShapeStrict(*result)); ASSERT_TRUE(exp.equalsTo(result)); delete results; @@ -2986,7 +3009,7 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_03) { auto result = results->at(0); // result->printIndexedBuffer("Quantized03"); - ASSERT_TRUE(exp.isSameShapeStrict(result)); + ASSERT_TRUE(exp.isSameShapeStrict(*result)); ASSERT_TRUE(exp.equalsTo(result)); delete results; @@ -3009,7 +3032,7 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_03_1) { auto result = results->at(0); // result->printIndexedBuffer("Quantized03_1"); - ASSERT_TRUE(exp.isSameShapeStrict(result)); + ASSERT_TRUE(exp.isSameShapeStrict(*result)); ASSERT_TRUE(exp.equalsTo(result)); delete results; @@ -3033,7 +3056,7 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_03_2) { auto result = results->at(0); result->printIndexedBuffer("Quantized03_2"); - ASSERT_TRUE(exp.isSameShapeStrict(result)); + ASSERT_TRUE(exp.isSameShapeStrict(*result)); ASSERT_TRUE(exp.equalsTo(result)); delete results; @@ -3056,7 +3079,7 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_03_3) { auto result = results->at(0); result->printIndexedBuffer("Quantized03_3"); - ASSERT_TRUE(exp.isSameShapeStrict(result)); + ASSERT_TRUE(exp.isSameShapeStrict(*result)); ASSERT_TRUE(exp.equalsTo(result)); delete results; @@ -3094,7 +3117,7 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_4) { // exp.printBuffer("Quantized per channest E"); // auto diff = *result - exp; // diff.printIndexedBuffer("Difference"); - ASSERT_TRUE(exp.isSameShapeStrict(result)); + ASSERT_TRUE(exp.isSameShapeStrict(*result)); ASSERT_TRUE(exp.equalsTo(result)); delete results; @@ -3148,7 +3171,7 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_5) { // auto diff = *result - exp; // diff.printIndexedBuffer("Difference"); - ASSERT_TRUE(exp.isSameShapeStrict(result)); + ASSERT_TRUE(exp.isSameShapeStrict(*result)); ASSERT_TRUE(exp.equalsTo(result)); delete results; @@ -3182,7 +3205,7 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_6) { // auto diff = *result - exp; // diff.printIndexedBuffer("Difference"); - ASSERT_TRUE(exp.isSameShapeStrict(result)); + ASSERT_TRUE(exp.isSameShapeStrict(*result)); ASSERT_TRUE(exp.equalsTo(result)); delete results; @@ -3225,7 +3248,7 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_7) { auto result = results->at(0); // result->printBuffer("Quantized7"); // exp.printBuffer("Expected 7"); - ASSERT_TRUE(exp.isSameShapeStrict(result)); + ASSERT_TRUE(exp.isSameShapeStrict(*result)); ASSERT_TRUE(exp.equalsTo(result)); delete results; @@ -3251,186 +3274,12 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_8) { // x.printBuffer("SourInput8"); // result->printBuffer("Quantized8"); // exp.printBuffer("Expected 8"); - ASSERT_TRUE(exp.isSameShapeStrict(result)); + ASSERT_TRUE(exp.isSameShapeStrict(*result)); ASSERT_TRUE(exp.equalsTo(result)); delete results; } -//////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests10, batchnorm_test1) { - - NDArray input ('c', {2,4}, nd4j::DataType::FLOAT32); - NDArray mean ('c', {4}, {1.05f, 1.15f, 1.2f, 1.3f}, nd4j::DataType::FLOAT32); - NDArray variance('c', {4}, {0.5f, 0.7f, 0.9f, 1.1f}, nd4j::DataType::FLOAT32); - NDArray gamma ('c', {4}, {-1.2f, 1.3f, -1.4f, 1.5f}, nd4j::DataType::FLOAT32); - NDArray beta ('c', {4}, {10.f, 20.f, -10.f, -20.f}, nd4j::DataType::FLOAT32); - - NDArray expected('c', {2,4}, {11.61218734f, 18.52390321f, -8.67185076f, -21.28716864f, 10.93337162f, 19.14541765f, -9.26213931f, -20.71509369f}, nd4j::DataType::FLOAT32); - - input.linspace(0.1, 0.1); - - nd4j::ops::batchnorm op; - - auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1}); - - ASSERT_EQ(ND4J_STATUS_OK, results->status()); - - auto output = results->at(0); - // output->printBuffer(); - - ASSERT_TRUE(expected.isSameShapeStrict(output)); - ASSERT_TRUE(expected.equalsTo(output)); - - delete results; -} - -//////////////////////////////////////////////////////////////////// -TYPED_TEST(TypedDeclarableOpsTests10, batchnorm_test2) { - - auto input = NDArrayFactory::create('c', {2,3,4}); - auto mean = NDArrayFactory::create('c', {4}); - auto variance = NDArrayFactory::create('c', {4}); - auto gamma = NDArrayFactory::create('c', {4}); - auto beta = NDArrayFactory::create('c', {4}); - - auto expected = NDArrayFactory::create('c', {2,3,4}, {-0.52733537f, -0.35763144f, -0.18792751f, -0.01822358f, 0.15148035f, 0.32118428f, 0.49088821f, 0.66059214f, 0.83029607f, 1.f, 1.16970393f, 1.33940786f, - 1.50911179f, 1.67881572f, 1.84851965f, 2.01822358f, 2.18792751f, 2.35763144f, 2.52733537f, 2.6970393f, 2.86674323f, 3.03644717f, 3.2061511f, 3.37585503f}); - - input.linspace(0.1, 0.1); - mean.assign(1.); - variance.assign(0.5); - gamma.assign(1.2); - beta.assign(1.); - - nd4j::ops::batchnorm op; - - auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1}); - - ASSERT_EQ(ND4J_STATUS_OK, results->status()); - - auto output = results->at(0); - // output->printBuffer(); - - ASSERT_TRUE(expected.isSameShapeStrict(output)); - ASSERT_TRUE(expected.equalsTo(output)); - - delete results; -} - -//////////////////////////////////////////////////////////////////// -TYPED_TEST(TypedDeclarableOpsTests10, batchnorm_test3) { - - auto input = NDArrayFactory::create('c', {2,3,4}); - auto mean = NDArrayFactory::create('c', {3}, {1.05f, 1.1f, 1.15f}); - auto variance = NDArrayFactory::create('c', {3}, {0.5f, 0.6f, 0.7f}); - auto gamma = NDArrayFactory::create('c', {3}, {1.2f, 1.3f, 1.4f}); - auto beta = NDArrayFactory::create('c', {3}, {0.1f, 0.2f, 0.3f}); - - auto expected = NDArrayFactory::create('c', {2,3,4}, {-1.51218734f, -1.34248341f, -1.17277948f, -1.00307555f, -0.80696728f, -0.6391394f, -0.47131152f, -0.30348364f, -0.11832703f, 0.04900378f, 0.21633459f, 0.38366541f, - 0.52425983f, 0.69396376f, 0.86366769f, 1.03337162f, 1.20696728f, 1.37479516f, 1.54262304f, 1.71045092f, 1.8896427f, 2.05697351f, 2.22430432f, 2.39163513f}); - - input.linspace(0.1, 0.1); - - nd4j::ops::batchnorm op; - - auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1,1}); - - ASSERT_EQ(ND4J_STATUS_OK, results->status()); - - auto output = results->at(0); - - ASSERT_TRUE(expected.isSameShapeStrict(output)); - ASSERT_TRUE(expected.equalsTo(output)); - - delete results; -} - -//////////////////////////////////////////////////////////////////// -TYPED_TEST(TypedDeclarableOpsTests10, batchnorm_test4) { - - auto input = NDArrayFactory::create('c', {2,3,4}); - auto mean = NDArrayFactory::create('c', {2,1,4}, {1.05f, 1.1f, 1.15f, 1.2f, 1.25f, 1.3f, 1.35f, 1.4f}); - auto variance = NDArrayFactory::create('c', {2,1,4}, {0.5f, 0.6f, 0.7f, 0.8f, 0.9f, 1.f, 1.1f, 1.2f}); - auto gamma = NDArrayFactory::create('c', {2,1,4}, {1.2f, 1.3f, 1.4f, 1.5f, 1.6f, 1.7f, 1.8f, 1.9f}); - auto beta = NDArrayFactory::create('c', {2,1,4}, {0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.66f, 0.7f, 0.8f}); - - auto expected = NDArrayFactory::create('c', {2,3,4}, {-1.51218734f, -1.31045092f, -1.12231189f, -0.9416324f, -0.83337162f, -0.6391394f, -0.45298865f, -0.2708162f, -0.1545559f, 0.03217212f, 0.21633459f, 0.4f, - 0.58432694f, 0.82999915f, 0.95743373f, 1.14688951f, 1.25894242f, 1.50999575f, 1.64392367f, 1.84066852f, 1.93355791f, 2.18999235f, 2.33041362f, 2.53444754f}); - - input.linspace(0.1, 0.1); - - nd4j::ops::batchnorm op; - - auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1,0,2}); - - ASSERT_EQ(ND4J_STATUS_OK, results->status()); - - auto output = results->at(0); - - ASSERT_TRUE(expected.isSameShapeStrict(output)); - ASSERT_TRUE(expected.equalsTo(output)); - - delete results; -} - -//////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests10, batchnorm_test5) { - - NDArray input ('c', {2,4,2,2}, nd4j::DataType::FLOAT32); - NDArray mean ('c', {4}, {1.05f, 1.15f, 1.2f, 1.3f}, nd4j::DataType::FLOAT32); - NDArray variance('c', {4}, {0.5f, 0.7f, 0.9f, 1.1f}, nd4j::DataType::FLOAT32); - NDArray gamma ('c', {4}, {-1.2f, 1.3f, -1.4f, 1.5f}, nd4j::DataType::FLOAT32); - NDArray beta ('c', {4}, {10.f, 20.f, -10.f, -20.f}, nd4j::DataType::FLOAT32); - - NDArray expected('c', {2,4,2,2}, { 11.612187f, 11.442483f, 11.272779f, 11.103076f, 18.990039f, 19.145418f, 19.300796f, 19.456175f, -9.557284f, -9.704856f, -9.852428f, -10.f, -20.f, - -19.856981f, -19.713963f, -19.570944f, 8.896924f, 8.727221f, 8.557517f, 8.387813f, 21.476097f, 21.631475f, 21.786854f, 21.942233f, -11.918438f, - -12.06601f, -12.213582f, -12.361154f, -17.7117f, -17.568681f, -17.425663f, -17.282644f}, nd4j::DataType::FLOAT32); - input.linspace(0.1, 0.1); - - nd4j::ops::batchnorm op; - - auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1, 1, 1}); - - ASSERT_EQ(ND4J_STATUS_OK, results->status()); - - auto output = results->at(0); - // output->printBuffer(); - - ASSERT_TRUE(expected.isSameShapeStrict(output)); - ASSERT_TRUE(expected.equalsTo(output)); - - delete results; -} - -//////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests10, batchnorm_test6) { - - NDArray input ('c', {2,2,2,4}, nd4j::DataType::FLOAT32); - NDArray mean ('c', {4}, {1.05f, 1.15f, 1.2f, 1.3f}, nd4j::DataType::FLOAT32); - NDArray variance('c', {4}, {0.5f, 0.7f, 0.9, 1.1f}, nd4j::DataType::FLOAT32); - NDArray gamma ('c', {4}, {-1.2f, 1.3f, -1.4f, 1.5f}, nd4j::DataType::FLOAT32); - NDArray beta ('c', {4}, {10.f, 20.f, -10.f, -20.f}, nd4j::DataType::FLOAT32); - - NDArray expected('c', {2,2,2,4}, {11.612187f, 18.523903f, -8.671851f, -21.287169f, 10.933372f, 19.145418f, -9.262139f, -20.715094f, 10.254556f, 19.766932f, -9.852428f, -20.143019f, 9.57574f, - 20.388447f, -10.442716f, -19.570944f, 8.896924f, 21.009961f, -11.033005f, -18.998869f, 8.218109f, 21.631475f, -11.623294f, -18.426794f, 7.539293f, 22.25299f, - -12.213582f, -17.854719f, 6.860477f, 22.874504f, -12.803871f, -17.282644f}, nd4j::DataType::FLOAT32); - input.linspace(0.1, 0.1); - - nd4j::ops::batchnorm op; - - auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1,3}); - - ASSERT_EQ(ND4J_STATUS_OK, results->status()); - - auto output = results->at(0); - - ASSERT_TRUE(expected.isSameShapeStrict(output)); - ASSERT_TRUE(expected.equalsTo(output)); - - delete results; -} - /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, bool_broadcast_test_1) { @@ -3441,7 +3290,7 @@ TEST_F(DeclarableOpsTests10, bool_broadcast_test_1) { NDArray result('c', {2,2,2}, nd4j::DataType::BOOL); - arr1.applyTrueBroadcast(nd4j::BroadcastBoolOpsTuple::custom(scalar::EqualTo, pairwise::EqualTo, broadcast::EqualTo), &arr2, &result, true, nullptr); + arr1.applyTrueBroadcast(nd4j::BroadcastBoolOpsTuple::custom(scalar::EqualTo, pairwise::EqualTo, broadcast::EqualTo), arr2, result, true); // result.printIndexedBuffer(); // expd.printIndexedBuffer(); @@ -3474,7 +3323,7 @@ TEST_F(DeclarableOpsTests10, printIndexedTest_1) { // [[5 6] // [7 8]]] // - ResultSet* lastDims = arr.allTensorsAlongDimension({3}); // last dim + ResultSet lastDims = arr.allTensorsAlongDimension({3}); // last dim size_t k = 0; // k from 0 to lastDims->size() Nd4jLong rank = 4; // in this case printf("["); @@ -3488,15 +3337,13 @@ TEST_F(DeclarableOpsTests10, printIndexedTest_1) { // printf("["); // else // printf(" "); - lastDims->at(k++)->printBuffer(); + lastDims.at(k++)->printBuffer(); //if (k == arr.sizeAt(i)) // printf("]\n"); } printf("]\n"); } printf("]\n"); - delete lastDims; - } diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests11.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests11.cpp index 647f37271..9b1dfc068 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests11.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests11.cpp @@ -975,6 +975,551 @@ TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test6) { delete results; } +TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test7) { + + NDArray input = NDArrayFactory::create('c', {2, 5, 5, 1}, { + 0.2303, 0.7950, 0.8171, 0.0451, 0.3690, 0.6846, 0.2727, 0.2770, 0.2381, 0.9511, + 0.4116, 0.3997, 0.4075, 0.6275, 0.8018, 0.0678, 0.6221, 0.2982, 0.1524, 0.2613, + 0.7425, 0.6036, 0.7926, 0.5838, 0.1361, 0.4154, 0.3634, 0.3741, 0.2088, 0.2989, + 0.3982, 0.5618, 0.7266, 0.1089, 0.2922, 0.3306, 0.2869, 0.6638, 0.3091, 0.9312, + 0.0240, 0.2893, 0.5632, 0.9625, 0.4189, 0.3854, 0.2743, 0.6754, 0.8820, 0.8699}); + + NDArray expected = NDArrayFactory::create('c', {2, 9, 9, 1}, { + 0.2303f, 0.54569f, 0.840649f, 0.92725444f, 0.65660673f, + 0.16641647f, 0.06117659f, 0.33279106f, 0.4023279f, 0.5139505f, + 0.49821317f, 0.4906872f, 0.537642f, 0.4070102f, 0.13030615f, + 0.258801f, 0.65352744f, 0.773368f, 0.69225276f, 0.44177493f, + 0.21910316f, 0.22368976f, 0.24221404f, 0.21399781f, 0.5114972f, + 0.9169859f, 1.0511527f, 0.5608501f, 0.41315168f, 0.2913824f, + 0.2966933f, 0.38585684f, 0.48849702f, 0.71013063f, 0.9086001f, + 0.9794303f, 0.29625386f, 0.39427578f, 0.45971435f, 0.39693952f, + 0.40860707f, 0.51061106f, 0.6181093f, 0.67309624f, 0.69564015f, + 0.06012487f, 0.3863805f, 0.58993465f, 0.40679216f, 0.22607432f, + 0.20093678f, 0.25901243f, 0.3615362f, 0.39371052f, 0.24176767f, + 0.4868709f, 0.650651f, 0.5493148f, 0.3825456f, 0.27788478f, + 0.18927254f, 0.16692996f, 0.15432167f, 0.677519f, 0.6236242f, + 0.61700624f, 0.7214321f, 0.7307374f, 0.6251454f, 0.3924176f, + 0.17802659f, 0.10231908f, 0.81192374f, 0.66878575f, 0.6118803f, + 0.7797006f, 0.8396968f, 0.72889954f, 0.44547448f, 0.16794783f, + 0.07125802f, 0.4154f, 0.38504714f, 0.3623221f, 0.3862173f, + 0.3397379f, 0.23285517f, 0.21876639f, 0.2892362f, 0.30817088f, + 0.41268015f, 0.45587808f, 0.51991886f, 0.60977113f, 0.49489656f, + 0.21313031f, 0.11297428f, 0.2167207f, 0.23940037f, 0.39337245f, + 0.46112412f, 0.583034f, 0.76207364f, 0.6326203f, 0.22189438f, + 0.12071565f, 0.3275853f, 0.3794855f, 0.38497013f, 0.35049653f, + 0.41895086f, 0.671095f, 0.62119365f, 0.22362521f, 0.30189657f, + 0.72530353f, 0.85048175f, 0.2524255f, 0.2182264f, 0.2964637f, + 0.5361996f, 0.6255393f, 0.46424767f, 0.5741281f, 0.8408146f, + 0.92403257f, 0.04648584f, 0.14959256f, 0.32215607f, 0.46194845f, + 0.6642166f, 0.83560026f, 0.7663391f, 0.5284251f, 0.4573109f, + 0.10357999f, 0.17442937f, 0.32116935f, 0.45530772f, 0.7163773f, + 0.9856574f, 0.8976148f, 0.5538923f, 0.45173654f, 0.34958175f, + 0.2680429f, 0.30470955f, 0.51233786f, 0.75128907f, 0.86736864f, + 0.8982046f, 0.83254474f, 0.8168574f, 0.4225865f, 0.2956836f, + 0.29948136f, 0.5276342f, 0.76461166f, 0.8442875f, 0.907862f, + 0.9139262f, 0.92068815f + }); + auto size = NDArrayFactory::create({9, 9}); + nd4j::ops::resize_bicubic op; + auto results = op.execute({&input, &size}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + NDArray* result = results->at(0); + +// result->printBuffer("Resized to 9x9"); +// expected.printBuffer("Expect for 9x9"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); + delete results; +} + +TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test8) { + + NDArray input = NDArrayFactory::create('c', {2, 5, 5, 1}, { + 0.23028551377579154, 0.7949972231516509, 0.8171307820461517, 0.04507309923418412, 0.3689673597428338, + 0.6845757584903018, 0.27268547668219667, 0.2770196372806053, 0.2381478370531429, 0.9511201914609859, + 0.41160882670429033, 0.3997152563642703, 0.4074505147711718, 0.6274595060113246, 0.8017922711300232, + 0.06782045852179475, 0.6220772280691722, 0.2982335327629251, 0.1523603480424196, 0.2612986044295986, + 0.7424762244324299, 0.6036156464824591, 0.7926371071102005, 0.5838270656432538, 0.13607200219168547, + 0.4154002170215956, 0.36340617544852116, 0.37405031188276827, 0.20880251686544882, 0.298919946410666, + 0.39820758164277126, 0.5617728968896589, 0.72660225993937, 0.10888245916813699, 0.29215797784445496, + 0.3305531351746034, 0.28693451964931715, 0.6637635348315494, 0.30913418229827583, 0.9312186188801752, + 0.0239594182399363, 0.2892942758780874, 0.5631691110629038, 0.9625499752246309, 0.4189439089689968, + 0.3854304088214935, 0.27426304203925045, 0.6754051704648238, 0.8820362490795286, 0.8699337744328859}); + + + auto testData = NDArrayFactory::create('c', {2,9,9,1}, { + 0.230286514f, 0.510566354f, 0.794997215f, 0.931386113f, 0.817130804f, 0.402811885f, 0.045073099f, 0.134639814f, 0.368967354f, + 0.483021289f, 0.501266003f, 0.521932304f, 0.572325349f, 0.534847379f, 0.267853439f, 0.105112493f, 0.349290252f, 0.674043298f, + 0.684575737f, 0.478224277f, 0.272685468f, 0.239882097f, 0.27701965f, 0.191148892f, 0.23814784f, 0.590989769f, 0.951120198f, + 0.622912169f, 0.441326082f, 0.266387194f, 0.232538164f, 0.301838756f, 0.356378645f, 0.495445013f, 0.756725252f, 0.981704295f, + 0.411608815f, 0.40493685f, 0.399715245f, 0.381842017f, 0.407450527f, 0.501836538f, 0.627459526f, 0.735251725f, 0.801792264f, + 0.150875032f, 0.357000858f, 0.524536073f, 0.450354964f, 0.318719596f, 0.319606483f, 0.385957927f, 0.46392554f, 0.529285908f, + 0.06782046f, 0.375309169f, 0.622077227f, 0.525792599f, 0.298233539f, 0.184723631f, 0.15236035f, 0.193153858f, 0.261298597f, + + 0.372918189f, 0.512539625f, 0.63369292f, 0.628733814f, 0.535196245f, 0.436597466f, 0.323553175f, 0.215942055f, 0.148014024f, + 0.742476225f, 0.655325174f, 0.603615642f, 0.704684138f, 0.79263711f, 0.747929871f, 0.583827078f, 0.340373576f, 0.136071995f, + 0.415400207f, 0.388405323f, 0.363406181f, 0.379345775f, 0.374050319f, 0.28397581f, 0.208802521f, 0.238369256f, 0.298919946f, + 0.413146496f, 0.444389015f, 0.488355637f, 0.568351328f, 0.556217432f, 0.345546633f, 0.140068889f, 0.148834035f, 0.23562704f, + 0.398207575f, 0.464537472f, 0.561772883f, 0.717433035f, 0.726602256f, 0.416013002f, 0.108882457f, 0.142608985f, 0.292157978f, + 0.391511708f, 0.389470309f, 0.442729384f, 0.651181757f, 0.737665415f, 0.41685915f, 0.138383076f, 0.342548877f, 0.659080088f, + + 0.330553144f, 0.273416102f, 0.286934525f, 0.50450629f, 0.663763523f, 0.463456154f, 0.309134185f, 0.586929917f, 0.931218624f, + 0.137025774f, 0.169145152f, 0.263757467f, 0.436182201f, 0.597053051f, 0.657990932f, 0.662163854f, 0.68354249f, 0.692712903f, + 0.023959421f, 0.130951077f, 0.289294273f, 0.413664877f, 0.563169122f, 0.839498401f, 0.962549984f, 0.728188932f, 0.418943912f, + 0.175951749f, 0.198239252f, 0.281999886f, 0.420836329f, 0.609856486f, 0.863734365f, 0.983550847f, 0.825015843f, 0.596413136f, + 0.385430396f, 0.292239636f, 0.274263054f, 0.445040524f, 0.675405145f, 0.817462444f, 0.882036269f, 0.895356655f, 0.869933784f + }); + + auto size = NDArrayFactory::create({9, 9}); + nd4j::ops::resize_bicubic op; + auto results = op.execute({&input, &size}, {}, {}, {true, false}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + NDArray* result = results->at(0); + +// result->printBuffer("Resized to 9x9"); +// testData.printBuffer("Expect for 9x9"); + ASSERT_TRUE(testData.isSameShape(result)); + ASSERT_TRUE(testData.equalsTo(result)); + delete results; +} + +TEST_F(DeclarableOpsTests11, ImageResizeArea_Test1) { + + NDArray input = NDArrayFactory::create('c', {1, 3, 3, 4}); + NDArray expected = NDArrayFactory::create('c', {1, 6, 6, 4}, { + 1.f, 2.f, 3.f, 4.f, + 1.f, 2.f, 3.f, 4.f, + 5.f, 6.f, 7.f, 8.f, + 5.f, 6.f, 7.f, 8.f, + 9.f, 10.f, 11.f, 12.f, + 9.f, 10.f, 11.f, 12.f, + + 1.f, 2.f, 3.f, 4.f, + 1.f, 2.f, 3.f, 4.f, + 5.f, 6.f, 7.f, 8.f, + 5.f, 6.f, 7.f, 8.f, + 9.f, 10.f, 11.f, 12.f, + 9.f, 10.f, 11.f, 12.f, + + 13.f, 14.f, 15.f, 16.f, + 13.f, 14.f, 15.f, 16.f, + 17.f, 18.f, 19.f, 20.f, + 17.f, 18.f, 19.f, 20.f, + 21.f, 22.f, 23.f, 24.f, + 21.f, 22.f, 23.f, 24.f, + + 13.f, 14.f, 15.f, 16.f, + 13.f, 14.f, 15.f, 16.f, + 17.f, 18.f, 19.f, 20.f, + 17.f, 18.f, 19.f, 20.f, + 21.f, 22.f, 23.f, 24.f, + 21.f, 22.f, 23.f, 24.f, + + 25.f, 26.f, 27.f, 28.f, + 25.f, 26.f, 27.f, 28.f, + 29.f, 30.f, 31.f, 32.f, + 29.f, 30.f, 31.f, 32.f, + 33.f, 34.f, 35.f, 36.f, + 33.f, 34.f, 35.f, 36.f, + + 25.f, 26.f, 27.f, 28.f, + 25.f, 26.f, 27.f, 28.f, + 29.f, 30.f, 31.f, 32.f, + 29.f, 30.f, 31.f, 32.f, + 33.f, 34.f, 35.f, 36.f, + 33.f, 34.f, 35.f, 36.f }); + input.linspace(1); + auto size = NDArrayFactory::create({6, 6}); + nd4j::ops::resize_area op; + auto results = op.execute({&input, &size}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + NDArray* result = results->at(0); + +// result->printBuffer("Area Resized to 6x6"); +// expected.printBuffer("Area Expect for 6x6"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); + delete results; +} + +TEST_F(DeclarableOpsTests11, ImageResizeArea_Test2) { + + NDArray input = NDArrayFactory::create('c', {1, 3, 3, 1}); + NDArray expected = NDArrayFactory::create('c', {1, 6, 6, 1}, { + 1.f, 1.f, 2.f, 2.f, 3.f, 3.f, + 1.f, 1.f, 2.f, 2.f, 3.f, 3.f, + 4.f, 4.f, 5.f, 5.f, 6.f, 6.f, + 4.f, 4.f, 5.f, 5.f, 6.f, 6.f, + 7.f, 7.f, 8.f, 8.f, 9.f, 9.f, + 7.f, 7.f, 8.f, 8.f, 9.f, 9.f + }); + input.linspace(1); + auto size = NDArrayFactory::create({6, 6}); + nd4j::ops::resize_area op; + auto results = op.execute({&input, &size}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + NDArray* result = results->at(0); + +// result->printBuffer("Area Resized to 6x6"); +// expected.printBuffer("Area Expect for 6x6"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); + delete results; +} + + +TEST_F(DeclarableOpsTests11, ImageResizeArea_Test3) { + + NDArray input = NDArrayFactory::create('c', {1, 3, 3, 3}); + NDArray expected = NDArrayFactory::create('c', {1, 6, 6, 3}, { + 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, + 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, + 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, + 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, + 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, + 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f + }); + input.linspace(1); + auto size = NDArrayFactory::create({6, 6}); + nd4j::ops::resize_area op; + auto results = op.execute({&input, &size}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + NDArray* result = results->at(0); + +// result->printBuffer("Area Resized to 6x6"); +// expected.printBuffer("Area Expect for 6x6"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); + delete results; +} + +TEST_F(DeclarableOpsTests11, ImageResizeArea_Test4) { + + NDArray input = NDArrayFactory::create('c', {2, 3, 3, 3}, { + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27 + }); + + NDArray expected = NDArrayFactory::create('c', {2, 6, 6, 3}, { + 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, + 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, + 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, + 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, + 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, + 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, + + 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, + 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, + 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, + 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, + 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, + 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f + }); + //input.linspace(1); + auto size = NDArrayFactory::create({6, 6}); + nd4j::ops::resize_area op; + auto results = op.execute({&input, &size}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + NDArray* result = results->at(0); + +// result->printBuffer("Area Resized to 6x6"); +// expected.printBuffer("Area Expect for 6x6"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); + delete results; +} + +TEST_F(DeclarableOpsTests11, ImageResizeArea_Test5) { + + NDArray input = NDArrayFactory::create('c', {2, 3, 3, 3}, { + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27 + }); + + NDArray expected = NDArrayFactory::create('c', {2, 6, 6, 3}, { + 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, + 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, + 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, + 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, + 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, + 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, + + 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, + 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, + 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, + 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, + 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, + 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f + }); + //input.linspace(1); + auto size = NDArrayFactory::create({6, 6}); + nd4j::ops::resize_area op; + auto results = op.execute({&input, &size}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + NDArray* result = results->at(0); + +// result->printBuffer("Area Resized to 6x6"); +// expected.printBuffer("Area Expect for 6x6"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); + delete results; +} + +TEST_F(DeclarableOpsTests11, ImageResizeArea_Test6) { + + NDArray input = NDArrayFactory::create('c', {2, 3, 3, 1}, { + 1, 2, 3, 4, 5, 6, 7, 8, 9, + 1, 2, 3, 4, 5, 6, 7, 8, 9 + }); + + NDArray expected = NDArrayFactory::create('c', {2, 6, 6, 1}, { + 1.f, 1.f, 1.5f, 2.f, 2.f, 3.f, + 1.f, 1.f, 1.5f, 2.f, 2.f, 3.f, + 2.5f, 2.5f, 3.f, 3.5f, 3.5f, 4.5f, + 4.f, 4.f, 4.5f, 5.f, 5.f, 6.f, + 4.f, 4.f, 4.5f, 5.f, 5.f, 6.f, + 7.f, 7.f, 7.5f, 8.f, 8.f, 9.f, + + 1.f, 1.f, 1.5f, 2.f, 2.f, 3.f, + 1.f, 1.f, 1.5f, 2.f, 2.f, 3.f, + 2.5f, 2.5f, 3.f, 3.5f, 3.5f, 4.5f, + 4.f, 4.f, 4.5f, 5.f, 5.f, 6.f, + 4.f, 4.f, 4.5f, 5.f, 5.f, 6.f, + 7.f, 7.f, 7.5f, 8.f, 8.f, 9.f + }); + //input.linspace(1); + auto size = NDArrayFactory::create({6, 6}); + nd4j::ops::resize_area op; + auto results = op.execute({&input, &size}, {}, {}, {true}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + NDArray* result = results->at(0); + +// result->printBuffer("Area Resized to 6x6"); +// expected.printBuffer("Area Expect for 6x6"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); + delete results; +} + +TEST_F(DeclarableOpsTests11, ImageResizeArea_Test7) { + + NDArray input = NDArrayFactory::create('c', {2, 3, 3, 1}, { + 1, 2, 3, 4, 5, 6, 7, 8, 9, + 1, 2, 3, 4, 5, 6, 7, 8, 9 + }); + + NDArray expected = NDArrayFactory::create('c', {2, 6, 6, 1}, { + 1.f, 1.f, 1.5f, 2.f, 2.f, 3.f, + 1.f, 1.f, 1.5f, 2.f, 2.f, 3.f, + 2.5f, 2.5f, 3.f, 3.5f, 3.5f, 4.5f, + 4.f, 4.f, 4.5f, 5.f, 5.f, 6.f, + 4.f, 4.f, 4.5f, 5.f, 5.f, 6.f, + 7.f, 7.f, 7.5f, 8.f, 8.f, 9.f, + + 1.f, 1.f, 1.5f, 2.f, 2.f, 3.f, + 1.f, 1.f, 1.5f, 2.f, 2.f, 3.f, + 2.5f, 2.5f, 3.f, 3.5f, 3.5f, 4.5f, + 4.f, 4.f, 4.5f, 5.f, 5.f, 6.f, + 4.f, 4.f, 4.5f, 5.f, 5.f, 6.f, + 7.f, 7.f, 7.5f, 8.f, 8.f, 9.f + }); + //input.linspace(1); +// auto size = NDArrayFactory::create({6, 6}); + nd4j::ops::resize_area op; + auto results = op.execute({&input}, {}, {6, 6}, {true}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + NDArray* result = results->at(0); + +// result->printBuffer("Area Resized to 6x6"); +// expected.printBuffer("Area Expect for 6x6"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); + delete results; +} + +TEST_F(DeclarableOpsTests11, ImageResizeArea_Test8) { + + NDArray input = NDArrayFactory::create('c', {1, 3, 3, 1}, { + 1, 2, 3, 4, 5, 6, 7, 8, 9 + }); + + NDArray expected = NDArrayFactory::create('c', {1, 6, 6, 1}, { + 1.f, 1.f, 1.5f, 2.f, 2.f, 3.f, + 1.f, 1.f, 1.5f, 2.f, 2.f, 3.f, + 2.5f, 2.5f, 3.f, 3.5f, 3.5f, 4.5f, + 4.f, 4.f, 4.5f, 5.f, 5.f, 6.f, + 4.f, 4.f, 4.5f, 5.f, 5.f, 6.f, + 7.f, 7.f, 7.5f, 8.f, 8.f, 9.f + }); + //input.linspace(1); +// auto size = NDArrayFactory::create({6, 6}); + nd4j::ops::resize_area op; + auto results = op.execute({&input}, {}, {6, 6}, {true}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + NDArray* result = results->at(0); + +// result->printBuffer("Area Resized to 6x6"); +// expected.printBuffer("Area Expect for 6x6"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); + delete results; +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, ImageResizeArea_Test9) { + + NDArray input = NDArrayFactory::create('c', {1, 2, 3, 4}, { + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,12,13,14,15,16,17,18,19,20,21,22,23,24 + }); + + NDArray expected = NDArrayFactory::create('c', {1, 10, 10, 4}, { + 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 3.666667f, 4.666667f, 5.666667f, 6.666667f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 6.333336f, 7.333336f, 8.333336f, 9.333337f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 8.999998f, 9.999998f, 10.999998f, 11.999998f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 3.666667f, 4.666667f, 5.666667f, 6.666667f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 6.333336f, 7.333336f, 8.333336f, 9.333337f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 8.999998f, 9.999998f, 10.999998f, 11.999998f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 3.666667f, 4.666667f, 5.666667f, 6.666667f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 6.333336f, 7.333336f, 8.333336f, 9.333337f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 8.999998f, 9.999998f, 10.999998f, 11.999998f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 3.666667f, 4.666667f, 5.666667f, 6.666667f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 6.333336f, 7.333336f, 8.333336f, 9.333337f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 8.999998f, 9.999998f, 10.999998f, 11.999998f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 3.666667f, 4.666667f, 5.666667f, 6.666667f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 6.333336f, 7.333336f, 8.333336f, 9.333336f, 8.999999f, 9.999999f, 11.000000f, 11.999999f, 8.999999f, 9.999999f, 11.000000f, 11.999999f, 8.999998f, 9.999997f, 10.999997f, 11.999997f, 13.000003f, 14.000004f, 15.000003f, 16.000004f, 13.000003f, 14.000004f, 15.000003f, 16.000004f, 13.000003f, 14.000004f, 15.000003f, 16.000004f, 15.666671f, 16.666672f, 17.666672f, 18.666672f, 17.000006f, 18.000004f, 19.000006f, 20.000004f, 17.000006f, 18.000004f, 19.000006f, 20.000004f, 18.333344f, 19.333344f, 20.333345f, 21.333344f, 21.000006f, 22.000006f, 23.000006f, 24.000006f, 21.000006f, 22.000006f, 23.000006f, 24.000006f, 21.000002f, 22.000000f, 23.000002f, 24.000000f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 15.666667f, 16.666668f, 17.666668f, 18.666668f, 17.000002f, 18.000000f, 19.000002f, 20.000000f, 17.000002f, 18.000000f, 19.000002f, 20.000000f, 18.333340f, 19.333340f, 20.333342f, 21.333340f, 21.000002f, 22.000000f, 22.999998f, 24.000000f, 21.000002f, 22.000000f, 22.999998f, 24.000000f, 20.999996f, 21.999996f, 22.999994f, 23.999996f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 15.666667f, 16.666668f, 17.666668f, 18.666668f, 17.000002f, 18.000000f, 19.000002f, 20.000000f, 17.000002f, 18.000000f, 19.000002f, 20.000000f, 18.333340f, 19.333340f, 20.333342f, 21.333340f, 21.000002f, 22.000000f, 22.999998f, 24.000000f, 21.000002f, 22.000000f, 22.999998f, 24.000000f, 20.999996f, 21.999996f, 22.999994f, 23.999996f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 15.666667f, 16.666668f, 17.666668f, 18.666668f, 17.000002f, 18.000000f, 19.000002f, 20.000000f, 17.000002f, 18.000000f, 19.000002f, 20.000000f, 18.333340f, 19.333340f, 20.333342f, 21.333340f, 21.000002f, 22.000000f, 22.999998f, 24.000000f, 21.000002f, 22.000000f, 22.999998f, 24.000000f, 20.999996f, 21.999996f, 22.999994f, 23.999996f, 12.999995f, 13.999995f, 14.999994f, 15.999994f, 12.999995f, 13.999995f, 14.999994f, 15.999994f, 12.999995f, 13.999995f, 14.999994f, 15.999994f, 15.666661f, 16.666662f, 17.666660f, 18.666660f, 16.999994f, 17.999994f, 18.999992f, 19.999992f, 16.999994f, 17.999994f, 18.999992f, 19.999992f, 18.333334f, 19.333332f, 20.333334f, 21.333332f, 20.999992f, 21.999992f, 22.999990f, 23.999992f, 20.999992f, 21.999992f, 22.999990f, 23.999992f, 20.999989f, 21.999989f, 22.999987f, 23.999987f + + }); + //input.linspace(1); + auto size = NDArrayFactory::create({10, 10}); + nd4j::ops::resize_area op; + auto results = op.execute({&input, &size}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + NDArray* result = results->at(0); + +// result->printBuffer("Area Resized to 10x10"); + // expected.printBuffer("Area Expect for 6x6"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); + delete results; +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, ImageResizeArea_Test10) { + + NDArray input = NDArrayFactory::create('c', {1, 2, 3, 4}, { + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,12,13,14,15,16,17,18,19,20,21,22,23,24 + }); + + NDArray expected = NDArrayFactory::create('c', {1, 10, 10, 4}, { + 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 3.666667f, 4.666667f, 5.666667f, 6.666667f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 6.333336f, 7.333336f, 8.333336f, 9.333337f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 8.999998f, 9.999998f, 10.999998f, 11.999998f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 3.666667f, 4.666667f, 5.666667f, 6.666667f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 6.333336f, 7.333336f, 8.333336f, 9.333337f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 8.999998f, 9.999998f, 10.999998f, 11.999998f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 3.666667f, 4.666667f, 5.666667f, 6.666667f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 6.333336f, 7.333336f, 8.333336f, 9.333337f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 8.999998f, 9.999998f, 10.999998f, 11.999998f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 3.666667f, 4.666667f, 5.666667f, 6.666667f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 6.333336f, 7.333336f, 8.333336f, 9.333337f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 8.999998f, 9.999998f, 10.999998f, 11.999998f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 3.666667f, 4.666667f, 5.666667f, 6.666667f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 6.333336f, 7.333336f, 8.333336f, 9.333336f, 8.999999f, 9.999999f, 11.000000f, 11.999999f, 8.999999f, 9.999999f, 11.000000f, 11.999999f, 8.999998f, 9.999997f, 10.999997f, 11.999997f, 13.000003f, 14.000004f, 15.000003f, 16.000004f, 13.000003f, 14.000004f, 15.000003f, 16.000004f, 13.000003f, 14.000004f, 15.000003f, 16.000004f, 15.666671f, 16.666672f, 17.666672f, 18.666672f, 17.000006f, 18.000004f, 19.000006f, 20.000004f, 17.000006f, 18.000004f, 19.000006f, 20.000004f, 18.333344f, 19.333344f, 20.333345f, 21.333344f, 21.000006f, 22.000006f, 23.000006f, 24.000006f, 21.000006f, 22.000006f, 23.000006f, 24.000006f, 21.000002f, 22.000000f, 23.000002f, 24.000000f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 15.666667f, 16.666668f, 17.666668f, 18.666668f, 17.000002f, 18.000000f, 19.000002f, 20.000000f, 17.000002f, 18.000000f, 19.000002f, 20.000000f, 18.333340f, 19.333340f, 20.333342f, 21.333340f, 21.000002f, 22.000000f, 22.999998f, 24.000000f, 21.000002f, 22.000000f, 22.999998f, 24.000000f, 20.999996f, 21.999996f, 22.999994f, 23.999996f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 15.666667f, 16.666668f, 17.666668f, 18.666668f, 17.000002f, 18.000000f, 19.000002f, 20.000000f, 17.000002f, 18.000000f, 19.000002f, 20.000000f, 18.333340f, 19.333340f, 20.333342f, 21.333340f, 21.000002f, 22.000000f, 22.999998f, 24.000000f, 21.000002f, 22.000000f, 22.999998f, 24.000000f, 20.999996f, 21.999996f, 22.999994f, 23.999996f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 15.666667f, 16.666668f, 17.666668f, 18.666668f, 17.000002f, 18.000000f, 19.000002f, 20.000000f, 17.000002f, 18.000000f, 19.000002f, 20.000000f, 18.333340f, 19.333340f, 20.333342f, 21.333340f, 21.000002f, 22.000000f, 22.999998f, 24.000000f, 21.000002f, 22.000000f, 22.999998f, 24.000000f, 20.999996f, 21.999996f, 22.999994f, 23.999996f, 12.999995f, 13.999995f, 14.999994f, 15.999994f, 12.999995f, 13.999995f, 14.999994f, 15.999994f, 12.999995f, 13.999995f, 14.999994f, 15.999994f, 15.666661f, 16.666662f, 17.666660f, 18.666660f, 16.999994f, 17.999994f, 18.999992f, 19.999992f, 16.999994f, 17.999994f, 18.999992f, 19.999992f, 18.333334f, 19.333332f, 20.333334f, 21.333332f, 20.999992f, 21.999992f, 22.999990f, 23.999992f, 20.999992f, 21.999992f, 22.999990f, 23.999992f, 20.999989f, 21.999989f, 22.999987f, 23.999987f + + }); + //input.linspace(1); + //auto size = NDArrayFactory::create({10, 10}); + nd4j::ops::resize_area op; + auto results = op.execute({&input}, {}, {10, 10}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + NDArray* result = results->at(0); + +// result->printBuffer("Area Resized to 10x10"); + // expected.printBuffer("Area Expect for 6x6"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); + delete results; +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, ImageResizeArea_Test11) { + + NDArray input = NDArrayFactory::create('c', {1, 2, 3, 4}, { + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,12,13,14,15,16,17,18,19,20,21,22,23,24 + }); + +// NDArray expected = NDArrayFactory::create('c', {1, 6, 9, 4}, { +// 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 3.666667, 4.666667, 5.666667, 6.666667, 5.000000, 6.000000, 7.000000, 8.000000, 5.000000, 6.000000, 7.000000, 8.000000, 6.333336, 7.333336, 8.333336, 9.333337, 9.000000, 10.000000, 11.000000, 12.000000, 9.000000, 10.000000, 11.000000, 12.000000, 8.999998, 9.999998, 10.999998, 11.999998, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 3.666667, 4.666667, 5.666667, 6.666667, 5.000000, 6.000000, 7.000000, 8.000000, 5.000000, 6.000000, 7.000000, 8.000000, 6.333336, 7.333336, 8.333336, 9.333337, 9.000000, 10.000000, 11.000000, 12.000000, 9.000000, 10.000000, 11.000000, 12.000000, 8.999998, 9.999998, 10.999998, 11.999998, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 3.666667, 4.666667, 5.666667, 6.666667, 5.000000, 6.000000, 7.000000, 8.000000, 5.000000, 6.000000, 7.000000, 8.000000, 6.333336, 7.333336, 8.333336, 9.333337, 9.000000, 10.000000, 11.000000, 12.000000, 9.000000, 10.000000, 11.000000, 12.000000, 8.999998, 9.999998, 10.999998, 11.999998, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 3.666667, 4.666667, 5.666667, 6.666667, 5.000000, 6.000000, 7.000000, 8.000000, 5.000000, 6.000000, 7.000000, 8.000000, 6.333336, 7.333336, 8.333336, 9.333337, 9.000000, 10.000000, 11.000000, 12.000000, 9.000000, 10.000000, 11.000000, 12.000000, 8.999998, 9.999998, 10.999998, 11.999998, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 3.666667, 4.666667, 5.666667, 6.666667, 5.000000, 6.000000, 7.000000, 8.000000, 5.000000, 6.000000, 7.000000, 8.000000, 6.333336, 7.333336, 8.333336, 9.333336, 8.999999, 9.999999, 11.000000, 11.999999, 8.999999, 9.999999, 11.000000, 11.999999, 8.999998, 9.999997, 10.999997, 11.999997, 13.000003, 14.000004, 15.000003, 16.000004, 13.000003, 14.000004, 15.000003, 16.000004, 13.000003, 14.000004, 15.000003, 16.000004, 15.666671, 16.666672, 17.666672, 18.666672, 17.000006, 18.000004, 19.000006, 20.000004, 17.000006, 18.000004, 19.000006, 20.000004, 18.333344, 19.333344, 20.333345, 21.333344, 21.000006, 22.000006, 23.000006, 24.000006, 21.000006, 22.000006, 23.000006, 24.000006, 21.000002, 22.000000, 23.000002, 24.000000, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 15.666667, 16.666668, 17.666668, 18.666668, 17.000002, 18.000000, 19.000002, 20.000000, 17.000002, 18.000000, 19.000002, 20.000000, 18.333340, 19.333340, 20.333342, 21.333340, 21.000002, 22.000000, 22.999998, 24.000000, 21.000002, 22.000000, 22.999998, 24.000000, 20.999996, 21.999996, 22.999994, 23.999996, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 15.666667, 16.666668, 17.666668, 18.666668, 17.000002, 18.000000, 19.000002, 20.000000, 17.000002, 18.000000, 19.000002, 20.000000, 18.333340, 19.333340, 20.333342, 21.333340, 21.000002, 22.000000, 22.999998, 24.000000, 21.000002, 22.000000, 22.999998, 24.000000, 20.999996, 21.999996, 22.999994, 23.999996, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 15.666667, 16.666668, 17.666668, 18.666668, 17.000002, 18.000000, 19.000002, 20.000000, 17.000002, 18.000000, 19.000002, 20.000000, 18.333340, 19.333340, 20.333342, 21.333340, 21.000002, 22.000000, 22.999998, 24.000000, 21.000002, 22.000000, 22.999998, 24.000000, 20.999996, 21.999996, 22.999994, 23.999996, 12.999995, 13.999995, 14.999994, 15.999994, 12.999995, 13.999995, 14.999994, 15.999994, 12.999995, 13.999995, 14.999994, 15.999994, 15.666661, 16.666662, 17.666660, 18.666660, 16.999994, 17.999994, 18.999992, 19.999992, 16.999994, 17.999994, 18.999992, 19.999992, 18.333334, 19.333332, 20.333334, 21.333332, 20.999992, 21.999992, 22.999990, 23.999992, 20.999992, 21.999992, 22.999990, 23.999992, 20.999989, 21.999989, 22.999987, 23.999987 +// +// }); + //input.linspace(1); + //auto size = NDArrayFactory::create({10, 10}); + nd4j::ops::resize_area op; + auto results = op.execute({&input}, {}, {6, 9}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + NDArray* result = results->at(0); + +// result->printBuffer("Area Resized to 6x9"); + // expected.printBuffer("Area Expect for 6x6"); +// ASSERT_TRUE(expected.isSameShape(result)); +// ASSERT_TRUE(expected.equalsTo(result)); + delete results; +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, ImageResizeArea_Test12) { + + NDArray input = NDArrayFactory::create('c', {1, 2, 3, 4}, { + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,12,13,14,15,16,17,18,19,20,21,22,23,24 + }); + +// NDArray expected = NDArrayFactory::create('c', {1, 6, 9, 4}, { +// 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 3.666667, 4.666667, 5.666667, 6.666667, 5.000000, 6.000000, 7.000000, 8.000000, 5.000000, 6.000000, 7.000000, 8.000000, 6.333336, 7.333336, 8.333336, 9.333337, 9.000000, 10.000000, 11.000000, 12.000000, 9.000000, 10.000000, 11.000000, 12.000000, 8.999998, 9.999998, 10.999998, 11.999998, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 3.666667, 4.666667, 5.666667, 6.666667, 5.000000, 6.000000, 7.000000, 8.000000, 5.000000, 6.000000, 7.000000, 8.000000, 6.333336, 7.333336, 8.333336, 9.333337, 9.000000, 10.000000, 11.000000, 12.000000, 9.000000, 10.000000, 11.000000, 12.000000, 8.999998, 9.999998, 10.999998, 11.999998, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 3.666667, 4.666667, 5.666667, 6.666667, 5.000000, 6.000000, 7.000000, 8.000000, 5.000000, 6.000000, 7.000000, 8.000000, 6.333336, 7.333336, 8.333336, 9.333337, 9.000000, 10.000000, 11.000000, 12.000000, 9.000000, 10.000000, 11.000000, 12.000000, 8.999998, 9.999998, 10.999998, 11.999998, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 3.666667, 4.666667, 5.666667, 6.666667, 5.000000, 6.000000, 7.000000, 8.000000, 5.000000, 6.000000, 7.000000, 8.000000, 6.333336, 7.333336, 8.333336, 9.333337, 9.000000, 10.000000, 11.000000, 12.000000, 9.000000, 10.000000, 11.000000, 12.000000, 8.999998, 9.999998, 10.999998, 11.999998, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 3.666667, 4.666667, 5.666667, 6.666667, 5.000000, 6.000000, 7.000000, 8.000000, 5.000000, 6.000000, 7.000000, 8.000000, 6.333336, 7.333336, 8.333336, 9.333336, 8.999999, 9.999999, 11.000000, 11.999999, 8.999999, 9.999999, 11.000000, 11.999999, 8.999998, 9.999997, 10.999997, 11.999997, 13.000003, 14.000004, 15.000003, 16.000004, 13.000003, 14.000004, 15.000003, 16.000004, 13.000003, 14.000004, 15.000003, 16.000004, 15.666671, 16.666672, 17.666672, 18.666672, 17.000006, 18.000004, 19.000006, 20.000004, 17.000006, 18.000004, 19.000006, 20.000004, 18.333344, 19.333344, 20.333345, 21.333344, 21.000006, 22.000006, 23.000006, 24.000006, 21.000006, 22.000006, 23.000006, 24.000006, 21.000002, 22.000000, 23.000002, 24.000000, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 15.666667, 16.666668, 17.666668, 18.666668, 17.000002, 18.000000, 19.000002, 20.000000, 17.000002, 18.000000, 19.000002, 20.000000, 18.333340, 19.333340, 20.333342, 21.333340, 21.000002, 22.000000, 22.999998, 24.000000, 21.000002, 22.000000, 22.999998, 24.000000, 20.999996, 21.999996, 22.999994, 23.999996, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 15.666667, 16.666668, 17.666668, 18.666668, 17.000002, 18.000000, 19.000002, 20.000000, 17.000002, 18.000000, 19.000002, 20.000000, 18.333340, 19.333340, 20.333342, 21.333340, 21.000002, 22.000000, 22.999998, 24.000000, 21.000002, 22.000000, 22.999998, 24.000000, 20.999996, 21.999996, 22.999994, 23.999996, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 15.666667, 16.666668, 17.666668, 18.666668, 17.000002, 18.000000, 19.000002, 20.000000, 17.000002, 18.000000, 19.000002, 20.000000, 18.333340, 19.333340, 20.333342, 21.333340, 21.000002, 22.000000, 22.999998, 24.000000, 21.000002, 22.000000, 22.999998, 24.000000, 20.999996, 21.999996, 22.999994, 23.999996, 12.999995, 13.999995, 14.999994, 15.999994, 12.999995, 13.999995, 14.999994, 15.999994, 12.999995, 13.999995, 14.999994, 15.999994, 15.666661, 16.666662, 17.666660, 18.666660, 16.999994, 17.999994, 18.999992, 19.999992, 16.999994, 17.999994, 18.999992, 19.999992, 18.333334, 19.333332, 20.333334, 21.333332, 20.999992, 21.999992, 22.999990, 23.999992, 20.999992, 21.999992, 22.999990, 23.999992, 20.999989, 21.999989, 22.999987, 23.999987 +// +// }); + //input.linspace(1); + //auto size = NDArrayFactory::create({10, 10}); + nd4j::ops::resize_area op; + auto results = op.execute({&input}, {}, {10, 15}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + NDArray* result = results->at(0); + +// result->printBuffer("Area Resized to 6x9"); + // expected.printBuffer("Area Expect for 6x6"); +// ASSERT_TRUE(expected.isSameShape(result)); +// ASSERT_TRUE(expected.equalsTo(result)); + delete results; +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, ImageResizeArea_Test13) { + + NDArray input = NDArrayFactory::create('c', {1, 2, 3, 4}, { + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,12,13,14,15,16,17,18,19,20,21,22,23,24 + }); + +// NDArray expected = NDArrayFactory::create('c', {1, 8, 8, 4}, { +// 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 3.666667, 4.666667, 5.666667, 6.666667, 5.000000, 6.000000, 7.000000, 8.000000, 5.000000, 6.000000, 7.000000, 8.000000, 6.333336, 7.333336, 8.333336, 9.333337, 9.000000, 10.000000, 11.000000, 12.000000, 9.000000, 10.000000, 11.000000, 12.000000, 8.999998, 9.999998, 10.999998, 11.999998, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 3.666667, 4.666667, 5.666667, 6.666667, 5.000000, 6.000000, 7.000000, 8.000000, 5.000000, 6.000000, 7.000000, 8.000000, 6.333336, 7.333336, 8.333336, 9.333337, 9.000000, 10.000000, 11.000000, 12.000000, 9.000000, 10.000000, 11.000000, 12.000000, 8.999998, 9.999998, 10.999998, 11.999998, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 3.666667, 4.666667, 5.666667, 6.666667, 5.000000, 6.000000, 7.000000, 8.000000, 5.000000, 6.000000, 7.000000, 8.000000, 6.333336, 7.333336, 8.333336, 9.333337, 9.000000, 10.000000, 11.000000, 12.000000, 9.000000, 10.000000, 11.000000, 12.000000, 8.999998, 9.999998, 10.999998, 11.999998, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 3.666667, 4.666667, 5.666667, 6.666667, 5.000000, 6.000000, 7.000000, 8.000000, 5.000000, 6.000000, 7.000000, 8.000000, 6.333336, 7.333336, 8.333336, 9.333337, 9.000000, 10.000000, 11.000000, 12.000000, 9.000000, 10.000000, 11.000000, 12.000000, 8.999998, 9.999998, 10.999998, 11.999998, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 3.666667, 4.666667, 5.666667, 6.666667, 5.000000, 6.000000, 7.000000, 8.000000, 5.000000, 6.000000, 7.000000, 8.000000, 6.333336, 7.333336, 8.333336, 9.333336, 8.999999, 9.999999, 11.000000, 11.999999, 8.999999, 9.999999, 11.000000, 11.999999, 8.999998, 9.999997, 10.999997, 11.999997, 13.000003, 14.000004, 15.000003, 16.000004, 13.000003, 14.000004, 15.000003, 16.000004, 13.000003, 14.000004, 15.000003, 16.000004, 15.666671, 16.666672, 17.666672, 18.666672, 17.000006, 18.000004, 19.000006, 20.000004, 17.000006, 18.000004, 19.000006, 20.000004, 18.333344, 19.333344, 20.333345, 21.333344, 21.000006, 22.000006, 23.000006, 24.000006, 21.000006, 22.000006, 23.000006, 24.000006, 21.000002, 22.000000, 23.000002, 24.000000, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 15.666667, 16.666668, 17.666668, 18.666668, 17.000002, 18.000000, 19.000002, 20.000000, 17.000002, 18.000000, 19.000002, 20.000000, 18.333340, 19.333340, 20.333342, 21.333340, 21.000002, 22.000000, 22.999998, 24.000000, 21.000002, 22.000000, 22.999998, 24.000000, 20.999996, 21.999996, 22.999994, 23.999996, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 15.666667, 16.666668, 17.666668, 18.666668, 17.000002, 18.000000, 19.000002, 20.000000, 17.000002, 18.000000, 19.000002, 20.000000, 18.333340, 19.333340, 20.333342, 21.333340, 21.000002, 22.000000, 22.999998, 24.000000, 21.000002, 22.000000, 22.999998, 24.000000, 20.999996, 21.999996, 22.999994, 23.999996, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 15.666667, 16.666668, 17.666668, 18.666668, 17.000002, 18.000000, 19.000002, 20.000000, 17.000002, 18.000000, 19.000002, 20.000000, 18.333340, 19.333340, 20.333342, 21.333340, 21.000002, 22.000000, 22.999998, 24.000000, 21.000002, 22.000000, 22.999998, 24.000000, 20.999996, 21.999996, 22.999994, 23.999996, 12.999995, 13.999995, 14.999994, 15.999994, 12.999995, 13.999995, 14.999994, 15.999994, 12.999995, 13.999995, 14.999994, 15.999994, 15.666661, 16.666662, 17.666660, 18.666660, 16.999994, 17.999994, 18.999992, 19.999992, 16.999994, 17.999994, 18.999992, 19.999992, 18.333334, 19.333332, 20.333334, 21.333332, 20.999992, 21.999992, 22.999990, 23.999992, 20.999992, 21.999992, 22.999990, 23.999992, 20.999989, 21.999989, 22.999987, 23.999987 +// +// }); + //input.linspace(1); + //auto size = NDArrayFactory::create({10, 10}); + nd4j::ops::resize_area op; + auto results = op.execute({&input}, {}, {9, 9}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + NDArray* result = results->at(0); + +// result->printBuffer("Area Resized to 8x8"); + // expected.printBuffer("Area Expect for 6x6"); +// ASSERT_TRUE(expected.isSameShape(result)); +// ASSERT_TRUE(expected.equalsTo(result)); + delete results; +} + /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, summaryStatsData_test1) { @@ -2673,11 +3218,11 @@ TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test8) { TEST_F(DeclarableOpsTests11, SafeDivideMixed_Test1) { NDArray labels('c', {2, 3}, {1.0, 2.0, 3.0, -1.0, 2.0, 1.0}); - auto sumDiff = labels.reduceAlongDims(reduce::Sum, {1}, true); + auto sumDiff = labels.reduceAlongDimension(reduce::Sum, {1}, true); NDArray numOfNonZero(sumDiff.getShapeInfo(), nd4j::DataType::INT64, false); numOfNonZero.assign(1); - sumDiff.applyPairwiseTransform(pairwise::SafeDivide, &numOfNonZero, &sumDiff, nullptr); + sumDiff.applyPairwiseTransform(pairwise::SafeDivide, numOfNonZero, sumDiff); } ///////////////////////////////////////////////////////////////// diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp index 5ca22c95e..c0ce9f1ab 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp @@ -597,7 +597,7 @@ TEST_F(DeclarableOpsTests12, reverse_test15) { TEST_F(DeclarableOpsTests12, mirrorPad_test17) { NDArray x('c', {2,3}, {1,2,3,4,5,6}, nd4j::DataType::DOUBLE); - NDArray padding('c', {2,2}, {1,1,2,2}, nd4j::DataType::INT32); + NDArray padding('c', {2,2}, {1,1,2,2}, nd4j::DataType::INT64); NDArray z('c', {4,7}, nd4j::DataType::DOUBLE); NDArray exp1('c', {4,7}, {6, 5, 4, 5, 6, 5, 4,3, 2, 1, 2, 3, 2, 1,6, 5, 4, 5, 6, 5, 4,3, 2, 1, 2, 3, 2, 1}, nd4j::DataType::DOUBLE); NDArray exp2('c', {4,7}, {2, 1, 1, 2, 3, 3, 2,2, 1, 1, 2, 3, 3, 2,5, 4, 4, 5, 6, 6, 5,5, 4, 4, 5, 6, 6, 5}, nd4j::DataType::DOUBLE); @@ -621,7 +621,7 @@ TEST_F(DeclarableOpsTests12, mirrorPad_test17) { TEST_F(DeclarableOpsTests12, mirrorPad_test18) { NDArray x('c', {3}, {1,2,3}, nd4j::DataType::DOUBLE); - NDArray padding('c', {2}, {1,1}, nd4j::DataType::INT32); + NDArray padding('c', {1, 2}, {1,1}, nd4j::DataType::INT32); NDArray z('c', {5}, nd4j::DataType::DOUBLE); NDArray exp('c', {5}, {2,1,2,3,2}, nd4j::DataType::DOUBLE); @@ -670,7 +670,7 @@ TEST_F(DeclarableOpsTests12, relu_1) { Nd4jStatus status = op.execute({&input}, {&z}, {0}, {}, {}); ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(expected.isSameShapeStrict(&z)); + ASSERT_TRUE(expected.isSameShapeStrict(z)); ASSERT_TRUE(expected.equalsTo(z)); } @@ -810,9 +810,10 @@ TEST_F(DeclarableOpsTests12, pullRows_1) { #ifdef __CUDABLAS__ nativeStart[1] = (x.getContext()->getCudaStream()); #endif - - pullRows(nativeStart, x.buffer(), x.getShapeInfo(), x.getSpecialBuffer(), x.getSpecialShapeInfo(), - z.buffer(), z.getShapeInfo(), z.specialBuffer(), z.specialShapeInfo(), + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer zBuf(z.dataBuffer()); + pullRows(nativeStart, &xBuf, x.getShapeInfo(), x.getSpecialShapeInfo(), + &zBuf, z.getShapeInfo(), z.specialShapeInfo(), 4, pidx, xTadPack.platformShapeInfo(), xTadPack.platformOffsets(), zTadPack.platformShapeInfo(), zTadPack.platformOffsets()); @@ -825,7 +826,7 @@ TEST_F(DeclarableOpsTests12, pullRows_1) { TEST_F(DeclarableOpsTests12, pullRows_2) { NDArray arr('f', {5, 2}, {0,1,2,3,4,5,6,7,8,9}); - NDArray* y = arr.dup('c'); + NDArray* y = new NDArray(arr.dup('c')); NDArray x = (*y)({0,0, 0,1}, true); // view, points on first column of y, shape is {5,1} NDArray z('c', {4, 1}, nd4j::DataType::DOUBLE); @@ -844,8 +845,10 @@ TEST_F(DeclarableOpsTests12, pullRows_2) { #ifdef __CUDABLAS__ nativeStart[1] = (x.getContext()->getCudaStream()); #endif - pullRows(nativeStart, x.buffer(), x.getShapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - z.buffer(), z.getShapeInfo(), z.specialBuffer(), z.specialShapeInfo(), + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer zBuf(z.dataBuffer()); + pullRows(nativeStart, &xBuf, x.getShapeInfo(), x.specialShapeInfo(), + &zBuf, z.getShapeInfo(), z.specialShapeInfo(), 4, pidx, xTadPack.platformShapeInfo(), xTadPack.platformOffsets(), zTadPack.platformShapeInfo(), zTadPack.platformOffsets()); @@ -858,7 +861,7 @@ TEST_F(DeclarableOpsTests12, pullRows_2) { ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, softmax_9) { NDArray arrC('c', {5,2}, {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6, -0.7, 0.8, -0.9, 1}, nd4j::DataType::FLOAT32); - NDArray* arrF = arrC.dup('f'); + NDArray* arrF = new NDArray(arrC.dup('f')); NDArray outCC('c', {5,2}, nd4j::DataType::FLOAT32); NDArray outCF('f', {5,2}, nd4j::DataType::FLOAT32); @@ -1395,7 +1398,7 @@ TEST_F(DeclarableOpsTests12, pad_tests1) { auto result = results->at(0); // result->printIndexedBuffer(); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -1422,7 +1425,7 @@ TEST_F(DeclarableOpsTests12, pad_tests2) { auto result = results->at(0); // result->printIndexedBuffer(); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -1434,11 +1437,11 @@ TEST_F(DeclarableOpsTests12, pad_tests2) { TEST_F(DeclarableOpsTests12, pad_tests3) { float inBuff[] = {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}; - int padBuff[] = {1,1,2,2}; + Nd4jLong padBuff[] = {1,1,2,2}; float expBuff[] = {2.f, 1.f, 1.f, 2.f, 3.f, 3.f, 2.f, 2.f,1.f,1.f,2.f,3.f,3.f,2.f, 5.f,4.f,4.f,5.f,6.f,6.f,5.f, 5.f,4.f,4.f,5.f,6.f,6.f,5.f}; auto input = NDArrayFactory::create(inBuff, 'c', {2,3}); - auto paddings = NDArrayFactory::create(padBuff, 'c', {2,2}); + auto paddings = NDArrayFactory::create(padBuff, 'c', {2,2}); auto expected = NDArrayFactory::create(expBuff, 'c', {4,7}); nd4j::ops::pad op; @@ -1449,7 +1452,7 @@ TEST_F(DeclarableOpsTests12, pad_tests3) { auto result = results->at(0); // result->printIndexedBuffer(); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -1462,10 +1465,10 @@ TEST_F(DeclarableOpsTests12, pad_tests4) { float inBuff[] = {1.f,2.f,3.f,4.f,5.f,6.f,7.f,8.f,9.f,10.f,11.f,12.f,13.f,14.f,15.f,16.f,17.f,18.f}; int padBuff[] = {1,1,2,2,2,2}; - float expBuff[] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, - 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 1.f, 2.f, 3.f, 0.f, 0.f, 0.f, 0.f, 4.f, 5.f, 6.f, 0.f, 0.f, 0.f, 0.f, - 7.f, 8.f, 9.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 10.f, 11.f, 12.f, 0.f, - 0.f, 0.f, 0.f, 13.f, 14.f, 15.f, 0.f, 0.f, 0.f, 0.f, 16.f, 17.f, 18.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + float expBuff[] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 1.f, 2.f, 3.f, 0.f, 0.f, 0.f, 0.f, 4.f, 5.f, 6.f, 0.f, 0.f, 0.f, 0.f, + 7.f, 8.f, 9.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 10.f, 11.f, 12.f, 0.f, + 0.f, 0.f, 0.f, 13.f, 14.f, 15.f, 0.f, 0.f, 0.f, 0.f, 16.f, 17.f, 18.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}; auto input = NDArrayFactory::create(inBuff, 'c', {2,3,3}); @@ -1480,7 +1483,7 @@ TEST_F(DeclarableOpsTests12, pad_tests4) { auto result = results->at(0); // result->printIndexedBuffer(); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); // for(int i = 0; i < expected.lengthOf(); ++i) { @@ -1514,7 +1517,7 @@ TEST_F(DeclarableOpsTests12, pad_tests5) { auto result = results->at(0); // result->printIndexedBuffer(); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -1541,7 +1544,7 @@ TEST_F(DeclarableOpsTests12, pad_tests6) { auto result = results->at(0); // result->printIndexedBuffer(); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -1567,7 +1570,7 @@ TEST_F(DeclarableOpsTests12, pad_tests7) auto *result = results->at(0); // result->printIndexedBuffer(); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -1593,7 +1596,7 @@ TEST_F(DeclarableOpsTests12, pad_tests8) auto *result = results->at(0); // result->printIndexedBuffer(); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -1619,7 +1622,7 @@ TEST_F(DeclarableOpsTests12, pad_tests9) auto *result = results->at(0); // result->printIndexedBuffer(); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -1641,7 +1644,7 @@ TEST_F(DeclarableOpsTests12, pad_tests10) { auto result = results->at(0); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -1663,7 +1666,7 @@ TEST_F(DeclarableOpsTests12, pad_tests11) { auto result = results->at(0); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -1692,7 +1695,7 @@ TEST_F(DeclarableOpsTests12, pad_tests12) { auto result = results->at(0); // result->printIndexedBuffer(); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -1714,7 +1717,7 @@ TEST_F(DeclarableOpsTests12, pad_tests13) { auto result = results->at(0); // result->printIndexedBuffer(); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -1735,7 +1738,7 @@ TEST_F(DeclarableOpsTests12, pad_tests14) { auto result = results->at(0); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -1756,7 +1759,7 @@ TEST_F(DeclarableOpsTests12, pad_tests15) { auto result = results->at(0); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -1777,7 +1780,7 @@ TEST_F(DeclarableOpsTests12, pad_tests16) { auto result = results->at(0); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -1798,7 +1801,7 @@ TEST_F(DeclarableOpsTests12, pad_tests17) { auto result = results->at(0); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -1819,7 +1822,7 @@ TEST_F(DeclarableOpsTests12, pad_tests18) { auto result = results->at(0); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -1840,7 +1843,7 @@ TEST_F(DeclarableOpsTests12, pad_tests19) { auto result = results->at(0); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -1861,7 +1864,7 @@ TEST_F(DeclarableOpsTests12, pad_tests20) { auto result = results->at(0); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -1884,7 +1887,7 @@ TEST_F(DeclarableOpsTests12, pad_tests21) { auto result = results->at(0); // result->printIndexedBuffer(); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -1907,7 +1910,7 @@ TEST_F(DeclarableOpsTests12, pad_tests22) { auto result = results->at(0); // result->printIndexedBuffer(); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -1931,7 +1934,7 @@ TEST_F(DeclarableOpsTests12, pad_tests23) { // result->printShapeInfo("r"); // expected.printShapeInfo("e"); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -1953,7 +1956,7 @@ TEST_F(DeclarableOpsTests12, pad_tests24) { auto result = results->at(0); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -1975,7 +1978,7 @@ TEST_F(DeclarableOpsTests12, pad_tests25) { auto result = results->at(0); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -1997,7 +2000,7 @@ TEST_F(DeclarableOpsTests12, pad_tests26) { auto result = results->at(0); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -2017,7 +2020,7 @@ TEST_F(DeclarableOpsTests12, pad_tests27) { // z.printIndexedBuffer(); ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(exp.isSameShapeStrict(&z)); + ASSERT_TRUE(exp.isSameShapeStrict(z)); ASSERT_TRUE(exp.equalsTo(z)); } @@ -2143,7 +2146,7 @@ TEST_F(DeclarableOpsTests12, pad_tests34) { Nd4jStatus status = op.execute({&input, &paddings}, {&z}, {10}, {0}, {}); // constant ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(expected.isSameShapeStrict(&z)); + ASSERT_TRUE(expected.isSameShapeStrict(z)); ASSERT_TRUE(expected.equalsTo(z)); } @@ -2167,7 +2170,7 @@ TEST_F(DeclarableOpsTests12, Pad_1) { auto result = results->at(0); // result->printIndexedBuffer(); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -2194,7 +2197,7 @@ TEST_F(DeclarableOpsTests12, Pad_2) { auto result = results->at(0); // result->printIndexedBuffer(); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -2221,7 +2224,7 @@ TEST_F(DeclarableOpsTests12, Pad_3) { auto result = results->at(0); // result->printIndexedBuffer(); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -2248,7 +2251,7 @@ TEST_F(DeclarableOpsTests12, Pad_4) { auto result = results->at(0); // result->printIndexedBuffer(); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -2275,7 +2278,7 @@ TEST_F(DeclarableOpsTests12, Pad_5) { auto result = results->at(0); // result->printIndexedBuffer(); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -2302,7 +2305,7 @@ TEST_F(DeclarableOpsTests12, Pad_6) { auto result = results->at(0); // result->printIndexedBuffer(); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -2328,7 +2331,7 @@ TEST_F(DeclarableOpsTests12, Pad_7) auto *result = results->at(0); // result->printIndexedBuffer(); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -2354,7 +2357,7 @@ TEST_F(DeclarableOpsTests12, Pad_8) auto *result = results->at(0); // result->printIndexedBuffer(); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -2380,7 +2383,7 @@ TEST_F(DeclarableOpsTests12, Pad_9) auto *result = results->at(0); // result->printIndexedBuffer(); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -2424,3 +2427,584 @@ TEST_F(DeclarableOpsTests12, Pad_SGO_Test_1) { ASSERT_TRUE(exp.equalsTo(res->at(0))); delete res; } + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, LU_Test_1) { + + auto in = NDArrayFactory::create('c', {3,3}, {1., 2., 3., 0., 2., 3., 0., 0., 7.}); + auto exp = NDArrayFactory::create('c', {3,3}, {1., 2., 3., 0., 2., 3., 0., 0., 7}); + auto pExp = NDArrayFactory::create('c', {3}, {0, 1, 2}); + nd4j::ops::lu op; + + auto res = op.execute({&in}, {}, {}); + ASSERT_EQ(res->status(), ND4J_STATUS_OK); + auto z = res->at(0); + auto p = res->at(1); +// z->printIndexedBuffer("Triangulars"); +// p->printIndexedBuffer("Permutaions"); + + ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(pExp.equalsTo(p)); + + delete res; +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, LU_Test_2) { + auto in = NDArrayFactory::create('c', {3,3}, {1, 0, 0, 2, 3, 0, 4, 5, 6}); + + auto expLU = NDArrayFactory::create('c', {3,3}, {4., 5., 6., 0.25, -1.25, -1.5, 0.5, -0.4, -3.6}); + auto expP = NDArrayFactory::create({2, 0, 1}); + nd4j::ops::lu op; + + auto res = op.execute({&in}, {}, {}); + ASSERT_EQ(res->status(), ND4J_STATUS_OK); + auto z = res->at(0); + auto p = res->at(1); +// z->printIndexedBuffer("Triangulars2"); +// p->printIndexedBuffer("Permutaions2"); + ASSERT_TRUE(expLU.equalsTo(z)); + ASSERT_TRUE(expP.equalsTo(p)); + delete res; +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, LU_Test_3) { + auto in = NDArrayFactory::create('c', {3,3}, {1,2,3,4,7,9, 11, 12, 13}); + + auto expLU = NDArrayFactory::create('c', {3,3}, { + 11., 12., 13., + 0.36363637, 2.6363635, 4.272727, + 0.09090909, 0.3448276, 0.34482753}); + + auto expP = NDArrayFactory::create({2, 1, 0}); + nd4j::ops::lu op; + + auto res = op.execute({&in}, {}, {}); + ASSERT_EQ(res->status(), ND4J_STATUS_OK); + auto z = res->at(0); + auto p = res->at(1); +// z->printIndexedBuffer("Triangulars3"); +// p->printIndexedBuffer("Permutaions3"); + ASSERT_TRUE(expLU.equalsTo(z)); + ASSERT_TRUE(expP.equalsTo(p)); + delete res; +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, LU_Test_4) { + + auto in = NDArrayFactory::create('c', {10,10}, { + 1., 2., 3., 4., 5., 6., 7., 8., 1., 15., + 5., 1., 13., 4., 15., 1., 17., 9., 11., 25., + 1., 9., 1., 4., 5., 2., 13., 10, 21., 15., + 3., 9., 4., 1., 5., 3., 7., 1, 1., 5., + 2., 3., 2., 5., 4., 4., 7., 3, 3., 4., + 0., 1., 3., 3., 5., 1., 3., 1, 31., 15., + 2., 1., 4., 3., 1., 5., 1., 2, 31., 35., + 3., 4., 3., 3., 4., 4., 4., 1., 3., 1., + 1., 1., 1., 1., 5., 6., 5., 4., 3., 2., + 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.}); + + auto expLU = NDArrayFactory::create('c', {10,10}, { + 5.0, 1.0, 13.0, 4.0, 15.0, 1.0, 17.0, 9.0, 11.0, 25.0, + 0.2, 8.8, -1.6, 3.2, 2.0, 1.8, 9.6, 8.2, 18.8, 10.0, + 0.6, 0.386364, -4.181818, -0.636364, -5.772727, 2.704545, -9.909091, -7.568182, -10.863636, -17.863636, + 0.6, 0.954545, 0.543478, -4.108696, -2.771739, -0.788043, -6.978261, -8.114130, -17.641304, -9.836957, + 0.4, 0.068182, 0.260870, -0.328042, -4.539683, 3.513228, -6.158730, -2.846561, 22.365079, 25.751323, + 0.2, 0.090909, 0.347826, -0.031746, -0.823427, 7.563520, -1.118881, 1.485431, 20.725524, 23.196387, + 0.0, 0.113636, -0.760870, -0.523810, 0.236014, 0.213036, -7.593805, -9.585099, 1.663379, -15.900300, + 0.4, 0.295455, 0.652174, -0.698413, 0.167832, 0.021727, -0.001360, -3.321530, -16.392106, - 9.022119, + 0.2, 0.204545, -0.173913, -0.592593, 0.232517, 0.610602, 0.277466, -0.244631, -39.715757, -18.928178, + 0.2, 0.090909, 0.347826, -0.031746, 0.057692, -0.070344, -0.030154, -0.243578, 0.087256, 0.112695 + }); + + auto expP = NDArrayFactory::create({1, 2, 7, 3, 6, 8, 5, 4, 0, 9}); + nd4j::ops::lu op; + + auto res = op.execute({&in}, {}, {}); + ASSERT_EQ(res->status(), ND4J_STATUS_OK); + auto z = res->at(0); + auto p = res->at(1); +// z->printBuffer("Triangulars4"); +// expLU.printBuffer("TriangulExp4"); +// p->printBuffer("Permutaions4"); + + ASSERT_TRUE(expLU.equalsTo(z)); + ASSERT_TRUE(expP.equalsTo(p)); + delete res; +} + +TEST_F(DeclarableOpsTests12, LU_Test_5) { + + auto in = NDArrayFactory::create('c', {2, 10,10}, { + 1., 2., 3., 4., 5., 6., 7., 8., 1., 15., + 5., 1., 13., 4., 15., 1., 17., 9., 11., 25., + 1., 9., 1., 4., 5., 2., 13., 10, 21., 15., + 3., 9., 4., 1., 5., 3., 7., 1, 1., 5., + 2., 3., 2., 5., 4., 4., 7., 3, 3., 4., + 0., 1., 3., 3., 5., 1., 3., 1, 31., 15., + 2., 1., 4., 3., 1., 5., 1., 2, 31., 35., + 3., 4., 3., 3., 4., 4., 4., 1., 3., 1., + 1., 1., 1., 1., 5., 6., 5., 4., 3., 2., + 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., + + 1., 2., 3., 4., 5., 6., 7., 8., 1., 15., + 5., 1., 13., 4., 15., 1., 17., 9., 11., 25., + 1., 9., 1., 4., 5., 2., 13., 10, 21., 15., + 3., 9., 4., 1., 5., 3., 7., 1, 1., 5., + 2., 3., 2., 5., 4., 4., 7., 3, 3., 4., + 0., 1., 3., 3., 5., 1., 3., 1, 31., 15., + 2., 1., 4., 3., 1., 5., 1., 2, 31., 35., + 3., 4., 3., 3., 4., 4., 4., 1., 3., 1., + 1., 1., 1., 1., 5., 6., 5., 4., 3., 2., + 1., 1., 1., 1., 1., 1., 1., 1., 1., 1. + }); + + auto expLU = NDArrayFactory::create('c', {2, 10,10}, { + 5.0, 1.0, 13.0, 4.0, 15.0, 1.0, 17.0, 9.0, 11.0, 25.0, + 0.2, 8.8, -1.6, 3.2, 2.0, 1.8, 9.6, 8.2, 18.8, 10.0, + 0.6, 0.386364, -4.181818, -0.636364, -5.772727, 2.704545, -9.909091, -7.568182, -10.863636, -17.863636, + 0.6, 0.954545, 0.543478, -4.108696, -2.771739, -0.788043, -6.978261, -8.114130, -17.641304, -9.836957, + 0.4, 0.068182, 0.260870, -0.328042, -4.539683, 3.513228, -6.158730, -2.846561, 22.365079, 25.751323, + 0.2, 0.090909, 0.347826, -0.031746, -0.823427, 7.563520, -1.118881, 1.485431, 20.725524, 23.196387, + 0.0, 0.113636, -0.760870, -0.523810, 0.236014, 0.213036, -7.593805, -9.585099, 1.663379, -15.900300, + 0.4, 0.295455, 0.652174, -0.698413, 0.167832, 0.021727, -0.001360, -3.321530, -16.392106, - 9.022119, + 0.2, 0.204545, -0.173913, -0.592593, 0.232517, 0.610602, 0.277466, -0.244631, -39.715757, -18.928178, + 0.2, 0.090909, 0.347826, -0.031746, 0.057692, -0.070344, -0.030154, -0.243578, 0.087256, 0.112695, + + 5.0, 1.0, 13.0, 4.0, 15.0, 1.0, 17.0, 9.0, 11.0, 25.0, + 0.2, 8.8, -1.6, 3.2, 2.0, 1.8, 9.6, 8.2, 18.8, 10.0, + 0.6, 0.386364, -4.181818, -0.636364, -5.772727, 2.704545, -9.909091, -7.568182, -10.863636, -17.863636, + 0.6, 0.954545, 0.543478, -4.108696, -2.771739, -0.788043, -6.978261, -8.114130, -17.641304, -9.836957, + 0.4, 0.068182, 0.260870, -0.328042, -4.539683, 3.513228, -6.158730, -2.846561, 22.365079, 25.751323, + 0.2, 0.090909, 0.347826, -0.031746, -0.823427, 7.563520, -1.118881, 1.485431, 20.725524, 23.196387, + 0.0, 0.113636, -0.760870, -0.523810, 0.236014, 0.213036, -7.593805, -9.585099, 1.663379, -15.900300, + 0.4, 0.295455, 0.652174, -0.698413, 0.167832, 0.021727, -0.001360, -3.321530, -16.392106, - 9.022119, + 0.2, 0.204545, -0.173913, -0.592593, 0.232517, 0.610602, 0.277466, -0.244631, -39.715757, -18.928178, + 0.2, 0.090909, 0.347826, -0.031746, 0.057692, -0.070344, -0.030154, -0.243578, 0.087256, 0.112695 + + }); + + auto expP = NDArrayFactory::create('c', {2, 10}, { + 1, 2, 7, 3, 6, 8, 5, 4, 0, 9, + 1, 2, 7, 3, 6, 8, 5, 4, 0, 9 + }); + nd4j::ops::lu op; + + auto res = op.execute({&in}, {}, {}); + ASSERT_EQ(res->status(), ND4J_STATUS_OK); + auto z = res->at(0); + auto p = res->at(1); +// z->printBuffer("Triangulars5"); +// expLU.printBuffer("TriangulExp5"); +// p->printBuffer("Permutaions5"); + + ASSERT_TRUE(expLU.equalsTo(z)); + ASSERT_TRUE(expP.equalsTo(p)); + delete res; +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, LU_Test_1_2) { + + auto in = NDArrayFactory::create('c', {2, 3,3}, {1., 2., 3., 0., 2., 3., 0., 0., 7.,1., 2., 3., 0., 2., 3., 0., 0., 7.}); + auto exp = NDArrayFactory::create('c', {2, 3,3}, {1., 2., 3., 0., 2., 3., 0., 0., 7, 1., 2., 3., 0., 2., 3., 0., 0., 7.}); + + nd4j::ops::lu op; + + auto res = op.execute({&in}, {}, {}); + ASSERT_EQ(res->status(), ND4J_STATUS_OK); + auto z = res->at(0); + auto p = res->at(1); +// z->printIndexedBuffer("Triangulars (2,3,3)"); +// p->printIndexedBuffer("Permutaions (2,3,3)"); + ASSERT_TRUE(exp.equalsTo(res->at(0))); + delete res; +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, LU_Test_3_2) { + + auto in = NDArrayFactory::create('c', {2, 3,3}, {1,2,3,4,7,9, 11, 12, 13,1,2,3,4,7,9, 11, 12, 13}); + + auto expLU = NDArrayFactory::create('c', {2, 3,3}, { + 11., 12., 13., + 0.36363637, 2.6363635, 4.272727, + 0.09090909, 0.3448276, 0.34482753, + + 11., 12., 13., + 0.36363637, 2.6363635, 4.272727, + 0.09090909, 0.3448276, 0.34482753 + }); + + auto expP = NDArrayFactory::create('c', {2,3}, {2, 1, 0, 2, 1, 0}); + nd4j::ops::lu op; + + auto res = op.execute({&in}, {}, {}); + ASSERT_EQ(res->status(), ND4J_STATUS_OK); + auto z = res->at(0); + auto p = res->at(1); +// z->printIndexedBuffer("Triangulars3_2"); +// p->printIndexedBuffer("Permutaions3_2"); + + ASSERT_TRUE(expLU.equalsTo(z)); + ASSERT_TRUE(expP.equalsTo(p)); + delete res; +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, LU_Test_3_3) { + + auto in = NDArrayFactory::create('c', {2, 3,3}, {1,2,3,4,7,9, 11, 12, 13,13,2,3,4,7,9, 11, 12, 1}); + auto expLU = NDArrayFactory::create('c', {2, 3,3}, { + 11., 12., 13., + 0.36363637, 2.6363635, 4.272727, + 0.09090909, 0.3448276, 0.34482753, + + 13., 2., 3., + 0.84615386, 10.307693, -1.5384617, + 0.30769232, 0.619403, 9.029851}); + + auto expP = NDArrayFactory::create('c', {2,3}, {2, 1, 0, 0, 2, 1}); + nd4j::ops::lu op; + + auto res = op.execute({&in}, {}, {}); + ASSERT_EQ(res->status(), ND4J_STATUS_OK); + auto z = res->at(0); + auto p = res->at(1); +// z->printIndexedBuffer("Triangulars3_3"); +// p->printIndexedBuffer("Permutaions3_3"); + + ASSERT_TRUE(expLU.equalsTo(z)); + ASSERT_TRUE(expP.equalsTo(p)); + delete res; +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, LU_Test_4_1) { + + auto in = NDArrayFactory::create('c', {2, 2,2}, { + 0.7788f, 0.8012f, 0.7244f, 0.2309f, + 0.7271f, 0.1804f, 0.5056f, 0.8925f + }); + + auto expLU = NDArrayFactory::create('c', {2, 2,2}, { + 0.7788f, 0.8012f, 0.930149f, -0.514335f, + 0.7271f, 0.1804f, 0.695365f, 0.767056f + }); + + auto expP = NDArrayFactory::create('c', {2,2}, {0, 1, 0, 1}); + nd4j::ops::lu op; + + auto res = op.execute({&in}, {}, {}); + ASSERT_EQ(res->status(), ND4J_STATUS_OK); + auto z = res->at(0); + auto p = res->at(1); +// z->printIndexedBuffer("Triangulars4_1"); +// p->printIndexedBuffer("Permutaions4_1"); + + ASSERT_TRUE(expLU.equalsTo(z)); + ASSERT_TRUE(expP.equalsTo(p)); + delete res; +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, LU_Test_4_2) { + + auto in = NDArrayFactory::create('c', {2, 2,2}, { + 0.7788f, 0.8012f, 0.7244f, 0.2309f, + 0.7271f, 0.1804f, 0.5056f, 0.8925f + }); + + auto expLU = NDArrayFactory::create('c', {2, 2,2}, { + 0.7788f, 0.8012f, 0.930149f, -0.514335f, + 0.7271f, 0.1804f, 0.695365f, 0.767056f + }); + + auto expP = NDArrayFactory::create('c', {2,2}, {0, 1, 0, 1}); + nd4j::ops::lu op; + + auto res = op.execute({&in}, {}, {nd4j::DataType::INT64}); + ASSERT_EQ(res->status(), ND4J_STATUS_OK); + auto z = res->at(0); + auto p = res->at(1); +// z->printIndexedBuffer("Triangulars4_2"); +// p->printIndexedBuffer("Permutaions4_2"); + + ASSERT_TRUE(expLU.equalsTo(z)); + ASSERT_TRUE(expP.equalsTo(p)); + delete res; +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, QR_Test_1) { + + auto in = NDArrayFactory::create('c', {5,3}, { + 12., -51., 4., 6., 167., -68., -4., 24., -41., -1., 1., 0., 2., 0., 3. + }); + auto expQ = NDArrayFactory::create('c', {5, 5}, { + 0.8464148, 0.3912908, -0.3431241, 0.06613743, -0.09146205, -0.42320737, -0.9040873, 0.02927014, 0.01737854, -0.04861044, 0.28213826, -0.17042054, -0.93285596, -0.02194202, 0.14371186, 0.07053456, -0.01404065, 0.00109937, 0.99740064, 0.00429488, -0.14106913, 0.0166551, 0.10577161, 0.00585613, 0.98417485 + }); + + auto expR = NDArrayFactory::create('c', {5,3}, { + -14.177447, -20.666622, 13.401566, 0., -175.04254, 70.080315, 0., 0., 35.201546, 0., 0., 0., 0., 0., 0. }); + nd4j::ops::qr op; + auto res = op.execute({&in}, {}, {}, {true}); + + ASSERT_EQ(res->status(), ND4J_STATUS_OK); + auto q = res->at(0); + auto r = res->at(1); +// q->printIndexedBuffer("Orthogonal 5x5"); +// expQ.printBuffer("Orthogonal Exp"); +// r->printIndexedBuffer("Upper triangular 5x3"); +// expR.printBuffer("Upper triangular Exp"); +// q->printShapeInfo("Q shape"); +// r->printShapeInfo("R shape"); + nd4j::ops::matmul opMul; + auto res2 = opMul.execute({q, r}, {}, {}); //MmulHelper::matmul(q, r, &in, false, false); + auto exp = res2->at(0);//->printIndexedBuffer("Result as result"); + ASSERT_TRUE(exp->isSameShape(in)); +// ASSERT_TRUE(q->isSameShape(expQ)); + + //ASSERT_TRUE(expQ.equalsTo(q)); + ASSERT_TRUE(exp->equalsTo(in)); + delete res2; + delete res; + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, QR_Test_1_1) { + + auto in = NDArrayFactory::create('c', {4, 5, 3}, { + 12., -51., 4., 6., 167., -68., -4., 24., -41., -1., 1., 0., 2., 0., 3., + 12., -51., 4., 6., 167., -68., -4., 24., -41., -1., 1., 0., 2., 0., 3., + 12., -51., 4., 6., 167., -68., -4., 24., -41., -1., 1., 0., 2., 0., 3., + 12., -51., 4., 6., 167., -68., -4., 24., -41., -1., 1., 0., 2., 0., 3. + }); + auto expQ = NDArrayFactory::create('c', {4, 5, 5}, { + 0.8464148, 0.3912908, -0.3431241, 0.06613743, -0.09146205, -0.42320737, -0.9040873, 0.02927014, 0.01737854, -0.04861044, 0.28213826, -0.17042054, -0.93285596, -0.02194202, 0.14371186, 0.07053456, -0.01404065, 0.00109937, 0.99740064, 0.00429488, -0.14106913, 0.0166551, 0.10577161, 0.00585613, 0.98417485, + 0.8464148, 0.3912908, -0.3431241, 0.06613743, -0.09146205, -0.42320737, -0.9040873, 0.02927014, 0.01737854, -0.04861044, 0.28213826, -0.17042054, -0.93285596, -0.02194202, 0.14371186, 0.07053456, -0.01404065, 0.00109937, 0.99740064, 0.00429488, -0.14106913, 0.0166551, 0.10577161, 0.00585613, 0.98417485, + 0.8464148, 0.3912908, -0.3431241, 0.06613743, -0.09146205, -0.42320737, -0.9040873, 0.02927014, 0.01737854, -0.04861044, 0.28213826, -0.17042054, -0.93285596, -0.02194202, 0.14371186, 0.07053456, -0.01404065, 0.00109937, 0.99740064, 0.00429488, -0.14106913, 0.0166551, 0.10577161, 0.00585613, 0.98417485, + 0.8464148, 0.3912908, -0.3431241, 0.06613743, -0.09146205, -0.42320737, -0.9040873, 0.02927014, 0.01737854, -0.04861044, 0.28213826, -0.17042054, -0.93285596, -0.02194202, 0.14371186, 0.07053456, -0.01404065, 0.00109937, 0.99740064, 0.00429488, -0.14106913, 0.0166551, 0.10577161, 0.00585613, 0.98417485 + }); + + auto expR = NDArrayFactory::create('c', {4, 5,3}, { + -14.177447, -20.666622, 13.401566, 0., -175.04254, 70.080315, 0., 0., 35.201546, 0., 0., 0., 0., 0., 0., + -14.177447, -20.666622, 13.401566, 0., -175.04254, 70.080315, 0., 0., 35.201546, 0., 0., 0., 0., 0., 0., + -14.177447, -20.666622, 13.401566, 0., -175.04254, 70.080315, 0., 0., 35.201546, 0., 0., 0., 0., 0., 0., + -14.177447, -20.666622, 13.401566, 0., -175.04254, 70.080315, 0., 0., 35.201546, 0., 0., 0., 0., 0., 0. + }); + nd4j::ops::qr op; + auto res = op.execute({&in}, {}, {}, {true}); + + ASSERT_EQ(res->status(), ND4J_STATUS_OK); + auto q = res->at(0); + auto r = res->at(1); +// q->printIndexedBuffer("Orthogonal 5x5"); +// expQ.printBuffer("Orthogonal Exp"); +// r->printIndexedBuffer("Upper triangular 5x3"); +// expR.printBuffer("Upper triangular Exp"); +// q->printShapeInfo("Q shape"); +// r->printShapeInfo("R shape"); + nd4j::ops::matmul opMul; + auto res2 = opMul.execute({q, r}, {}, {}); //MmulHelper::matmul(q, r, &in, false, false); + auto exp = res2->at(0);//->printIndexedBuffer("Result as result"); + ASSERT_TRUE(exp->isSameShape(in)); +// ASSERT_TRUE(q->isSameShape(expQ)); + + //ASSERT_TRUE(expQ.equalsTo(q)); + ASSERT_TRUE(exp->equalsTo(in)); + delete res2; + delete res; + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, QR_Test_2) { + + auto in = NDArrayFactory::create('c', {5,3}, { + 12., -51., 4., 6., 167., -68., -4., 24., -41., -1., 1., 0., 2., 0., 3. + }); + auto expQ = NDArrayFactory::create('c', {5, 3}, { + 0.8464148, 0.3912908, -0.3431241, -0.42320737, -0.9040873, 0.02927014, 0.28213826, -0.17042054, -0.93285596, 0.07053456, -0.01404065, 0.00109937, -0.14106913, 0.0166551, 0.10577161 + }); + + auto expR = NDArrayFactory::create('c', {3,3}, { + -14.177447, -20.666622, 13.401566, 0., -175.04254, 70.080315, 0., 0., 35.201546 + }); + + nd4j::ops::qr op; + auto res = op.execute({&in}, {}, {}, {false}); + + ASSERT_EQ(res->status(), ND4J_STATUS_OK); + auto q = res->at(0); + auto r = res->at(1); + ASSERT_TRUE(q->isSameShape(expQ)); + ASSERT_TRUE(r->isSameShape(expR)); +// q->printIndexedBuffer("Orthogonal 5x5"); +// r->printIndexedBuffer("Upper triangular 5x3"); + + nd4j::ops::matmul opMul; + auto res2 = opMul.execute({q, r}, {}, {}); //MmulHelper::matmul(q, r, &in, false, false); + auto exp = res2->at(0);//->printIndexedBuffer("Result as result"); + ASSERT_TRUE(exp->isSameShape(in)); + ASSERT_TRUE(exp->equalsTo(in)); + delete res2; + delete res; +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, TriangularSolve_Test_1) { + + auto a = NDArrayFactory::create('c', {4, 4}, { + 3.f, 0.f, 0.f, 0.f, + 2.f, 1.f, 0.f, 0.f, + 1.f, 0.f, 1.f, 0.f, + 1.f, 1.f, 1.f, 1.f + }); + + auto b = NDArrayFactory::create('c', {4, 1}, { + 4.f, 2.f, 4.f, 2.f + }); + + auto exp = NDArrayFactory::create('c', {4, 1}, { + 1.333333f, -0.6666667f, 2.6666667f, -1.3333333f }); + + nd4j::ops::triangular_solve op; + + auto res = op.execute({&a, &b}, {}, {}); + ASSERT_EQ(res->status(), ND4J_STATUS_OK); + auto z = res->at(0); + +// z->printIndexedBuffer("TriangularSolve"); + + ASSERT_TRUE(exp.equalsTo(z)); + delete res; +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, TriangularSolve_Test_2) { + + auto a = NDArrayFactory::create('c', {4, 4}, { + 1.f, 1.f, 1.f, 1.f, + 0.f, 1.f, 1.f, 0.f, + 0.f, 0.f, 2.f, 1.f, + 0.f, 0.f, 0.f, 3.f, + }); + + auto b = NDArrayFactory::create('c', {4, 1}, { + 2.f, 4.f, 2.f, 4.f + }); + + auto exp = NDArrayFactory::create('c', {4, 1}, { + 2.f, 4.f, 1.f, 1.3333333f }); + + nd4j::ops::triangular_solve op; + + auto res = op.execute({&a, &b}, {}, {}); + ASSERT_EQ(res->status(), ND4J_STATUS_OK); + auto z = res->at(0); + +// z->printIndexedBuffer("TriangularSolve"); + + ASSERT_TRUE(exp.equalsTo(z)); + delete res; +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, TriangularSolve_Test_3) { + + auto a = NDArrayFactory::create('c', {2, 4, 4}, { + 3.f, 0.f, 0.f, 0.f, + 2.f, 1.f, 0.f, 0.f, + 1.f, 0.f, 1.f, 0.f, + 1.f, 1.f, 1.f, 1.f, + + 3.f, 0.f, 0.f, 0.f, + 2.f, 1.f, 0.f, 0.f, + 1.f, 0.f, 1.f, 0.f, + 1.f, 1.f, 1.f, 1.f + }); + + auto b = NDArrayFactory::create('c', {2, 4, 1}, { + 4.f, 2.f, 4.f, 2.f, + 4.f, 2.f, 4.f, 2.f + }); + + auto exp = NDArrayFactory::create('c', {2, 4, 1}, { + 1.333333f, -0.6666667f, 2.6666667f, -1.3333333f, + 1.333333f, -0.6666667f, 2.6666667f, -1.3333333f + }); + + nd4j::ops::triangular_solve op; + + auto res = op.execute({&a, &b}, {}, {}); + ASSERT_EQ(res->status(), ND4J_STATUS_OK); + auto z = res->at(0); + +// z->printIndexedBuffer("TriangularSolve"); + + ASSERT_TRUE(exp.equalsTo(z)); + delete res; +} +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, TriangularSolve_Test_4) { + + auto a = NDArrayFactory::create('c', {4, 4}, { + 1.f, 1.f, 1.f, 1.f, + 0.f, 1.f, 1.f, 0.f, + 0.f, 0.f, 2.f, 1.f, + 0.f, 0.f, 0.f, 3.f, + }); + + auto b = NDArrayFactory::create('c', {4, 1}, { + 2.f, 4.f, 2.f, 4.f + }); + + auto exp = NDArrayFactory::create('c', {4, 1}, { + -3.3333333f, 3.6666666f, 0.333333f, 1.3333333f + }); + + nd4j::ops::triangular_solve op; + + auto res = op.execute({&a, &b}, {}, {}, {false}); + ASSERT_EQ(res->status(), ND4J_STATUS_OK); + auto z = res->at(0); + +// z->printIndexedBuffer("TriangularSolve"); + + ASSERT_TRUE(exp.equalsTo(z)); + delete res; +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, TriangularSolve_Test_5) { + + auto a = NDArrayFactory::create('c', {4, 4}, { + 5.f, 1., -3.f, 3.f, + 0.f, 1.f, 1.f, -1.f, + 0.f, 0.f, 2.f, -9.f, + 0.f, 0.f, 0.f, 4.f + }); + + auto b = NDArrayFactory::create('c', {4, 1}, { + 5.f, 2.f, 0.f, -3.f + }); + + auto exp = NDArrayFactory::create('c', {4, 1}, { + 1.f, 1.f, 1.f, 1.f + }); + + nd4j::ops::triangular_solve op; + + auto res = op.execute({&a, &b}, {}, {}, {false, true}); + ASSERT_EQ(res->status(), ND4J_STATUS_OK); + auto z = res->at(0); + +// z->printIndexedBuffer("TriangularSolve with adjoint"); + + ASSERT_TRUE(exp.equalsTo(z)); + delete res; +} diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp index 91ff89d46..ee569a07c 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp @@ -24,7 +24,7 @@ #include #include #include - +#include using namespace nd4j; @@ -38,6 +38,19 @@ public: } }; +template +class TypedDeclarableOpsTests13 : public testing::Test { +public: + + TypedDeclarableOpsTests13() { + printf("\n"); + fflush(stdout); + } +}; + +typedef ::testing::Types TestingTypes; +TYPED_TEST_CASE(TypedDeclarableOpsTests13, TestingTypes); + TEST_F(DeclarableOpsTests13, test_pow_1) { auto x = NDArrayFactory::create('c', {2, 2}, {2.f, 2.f, 2.f, 2.f}); auto y = NDArrayFactory::create('c', {2}, {3, 3}); @@ -169,7 +182,7 @@ TEST_F(DeclarableOpsTests13, test_or_1) { NDArray z('c', {4}, nd4j::DataType::BOOL); - x.applyPairwiseTransform(pairwise::Or, &y, &z, nullptr); + x.applyPairwiseTransform(pairwise::Or, y, z); ASSERT_EQ(e, z); } @@ -181,7 +194,7 @@ TEST_F(DeclarableOpsTests13, test_and_1) { auto z = NDArrayFactory::create('c', {4}); - x.applyPairwiseTransform(pairwise::And, &y, &z, nullptr); + x.applyPairwiseTransform(pairwise::And, y, z); ASSERT_EQ(e, z); } @@ -193,7 +206,7 @@ TEST_F(DeclarableOpsTests13, test_xor_1) { auto z = NDArrayFactory::create('c', {4}); - x.applyPairwiseTransform(pairwise::Xor, &y, &z, nullptr); + x.applyPairwiseTransform(pairwise::Xor, y, z); ASSERT_EQ(e, z); } @@ -432,7 +445,7 @@ TEST_F(DeclarableOpsTests13, adjustHue_1) { NDArray exp ('c', {2,2,3}, {100,0,44, 208,5,220, 177,230,97, 2,255,244}, nd4j::DataType::FLOAT32); nd4j::ops::adjust_hue op; - auto results = op.execute({&input, &factor}, {}, {2}); + std::unique_ptr results (op.execute({&input, &factor}, {}, {2})); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -442,17 +455,18 @@ TEST_F(DeclarableOpsTests13, adjustHue_1) { ASSERT_TRUE(exp.isSameShape(result)); ASSERT_TRUE(exp.equalsTo(result)); - delete results; + } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, adjustHue_2) { - NDArray input('c', {2,2,3}, {0,100,56, 17,220,5, 150,97,230, 255,2,13}, nd4j::DataType::FLOAT32); - NDArray exp ('c', {2,2,3}, {4,100,0, 146,220,5, 97,123.8,230, 255,2,164.8}, nd4j::DataType::FLOAT32); + NDArray input('c', { 2,2,3 }, { 0.f,100.f / 255.f,56.f / 255.f, 17.f / 255.f,220.f / 255.f,5.f / 255.f, 150.f / 255.f,97.f / 255.f,230.f / 255.f, 255.f / 255.f,2.f / 255.f,13.f / 255.f }, nd4j::DataType::FLOAT32); + NDArray exp('c', { 2,2,3 }, { 4.f / 255.f,100.f / 255.f,0.f, 146.f / 255.f,220.f / 255.f,5.f / 255.f, 97.f / 255.f,123.8f / 255.f,230.f / 255.f, 255.f / 255.f,2.f / 255.f,164.8f / 255.f }, nd4j::DataType::FLOAT32); + nd4j::ops::adjust_hue op; - auto results = op.execute({&input}, {0.9}, {2}); + std::unique_ptr results(op.execute({&input}, {0.9}, {2})); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -461,7 +475,7 @@ TEST_F(DeclarableOpsTests13, adjustHue_2) { ASSERT_TRUE(exp.isSameShape(result)); ASSERT_TRUE(exp.equalsTo(result)); - delete results; + } @@ -472,7 +486,7 @@ TEST_F(DeclarableOpsTests13, adjustHue_3) { NDArray exp ('c', {2,2,3}, {0.,84.,100., 5.,220.,122.0001, 229.8,97.,230., 255.,142.8002,2.}, nd4j::DataType::FLOAT32); nd4j::ops::adjust_hue op; - auto results = op.execute({&input}, {-0.9}, {2}); + std::unique_ptr results(op.execute({&input}, {-0.9}, {2})); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -481,7 +495,7 @@ TEST_F(DeclarableOpsTests13, adjustHue_3) { ASSERT_TRUE(exp.isSameShape(result)); ASSERT_TRUE(exp.equalsTo(result)); - delete results; + } //////////////////////////////////////////////////////////////////// @@ -491,7 +505,7 @@ TEST_F(DeclarableOpsTests13, adjustHue_4) { NDArray exp ('c', {2,3,2}, {100,208, 0,5, 44,220, 177,2, 230,255, 97,244}, nd4j::DataType::FLOAT32); nd4j::ops::adjust_hue op; - auto results = op.execute({&input}, {0.5}, {1}); + std::unique_ptr results(op.execute({&input}, {0.5}, {1})); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -500,7 +514,7 @@ TEST_F(DeclarableOpsTests13, adjustHue_4) { ASSERT_TRUE(exp.isSameShape(result)); ASSERT_TRUE(exp.equalsTo(result)); - delete results; + } //////////////////////////////////////////////////////////////////// @@ -510,7 +524,7 @@ TEST_F(DeclarableOpsTests13, adjustHue_5) { NDArray exp ('c', {3,2,2}, {100,208, 177,2, 0,5, 230,255, 44,220, 97,244}, nd4j::DataType::FLOAT32); nd4j::ops::adjust_hue op; - auto results = op.execute({&input}, {0.5}, {0}); + std::unique_ptr results(op.execute({&input}, {0.5}, {0})); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -519,7 +533,7 @@ TEST_F(DeclarableOpsTests13, adjustHue_5) { ASSERT_TRUE(exp.isSameShape(result)); ASSERT_TRUE(exp.equalsTo(result)); - delete results; + } //////////////////////////////////////////////////////////////////// @@ -1029,10 +1043,10 @@ TEST_F(DeclarableOpsTests13, lstmLayer_1) { std::initializer_list iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; std::initializer_list bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; - auto expH = NDArrayFactory::create('c', {sL, bS, nOut}, {0.57574f, 0.57574f, 0.57574f, 0.58006f, 0.58006f, 0.58006f, 0.58434f, 0.58434f, 0.58434f, - 0.55114f, 0.55114f, 0.55114f, 0.55732f, 0.55732f, 0.55732f, 0.56338f, 0.56338f, 0.56338f, - 0.53763f, 0.53763f, 0.53763f, 0.54534f, 0.54534f, 0.54534f, 0.55287f, 0.55287f, 0.55287f, - 0.53626f, 0.53626f, 0.53626f, 0.54487f, 0.54487f, 0.54487f, 0.55327f, 0.55327f, 0.55327f, + auto expH = NDArrayFactory::create('c', {sL, bS, nOut}, {0.57574f, 0.57574f, 0.57574f, 0.58006f, 0.58006f, 0.58006f, 0.58434f, 0.58434f, 0.58434f, + 0.55114f, 0.55114f, 0.55114f, 0.55732f, 0.55732f, 0.55732f, 0.56338f, 0.56338f, 0.56338f, + 0.53763f, 0.53763f, 0.53763f, 0.54534f, 0.54534f, 0.54534f, 0.55287f, 0.55287f, 0.55287f, + 0.53626f, 0.53626f, 0.53626f, 0.54487f, 0.54487f, 0.54487f, 0.55327f, 0.55327f, 0.55327f, 0.54484f, 0.54484f, 0.54484f, 0.55379f, 0.55379f, 0.55379f, 0.5625f, 0.5625f, 0.5625f}); auto expClast = NDArrayFactory::create('c', {bS, nOut}, {1.1589154f, 1.1589154f, 1.1589154f, 1.1892855f, 1.1892855f, 1.1892855f, 1.219861f, 1.219861f, 1.219861f}); @@ -1947,3 +1961,289 @@ TEST_F(DeclarableOpsTests13, lstmLayer_12) { } +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, batchnorm_test1) { + + NDArray input ('c', {2,4}, nd4j::DataType::FLOAT32); + NDArray mean ('c', {4}, {1.05f, 1.15f, 1.2f, 1.3f}, nd4j::DataType::FLOAT32); + NDArray variance('c', {4}, {0.5f, 0.7f, 0.9f, 1.1f}, nd4j::DataType::FLOAT32); + NDArray gamma ('c', {4}, {-1.2f, 1.3f, -1.4f, 1.5f}, nd4j::DataType::FLOAT32); + NDArray beta ('c', {4}, {10.f, 20.f, -10.f, -20.f}, nd4j::DataType::FLOAT32); + + NDArray expected('c', {2,4}, {11.61218734f, 18.52390321f, -8.67185076f, -21.28716864f, 10.93337162f, 19.14541765f, -9.26213931f, -20.71509369f}, nd4j::DataType::FLOAT32); + + input.linspace(0.1, 0.1); + + nd4j::ops::batchnorm op; + + auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto output = results->at(0); + // output->printBuffer(); + + ASSERT_TRUE(expected.isSameShapeStrict(*output)); + ASSERT_TRUE(expected.equalsTo(output)); + + delete results; +} + +//////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedDeclarableOpsTests13, batchnorm_test2) { + + auto input = NDArrayFactory::create('c', {2,3,4}); + auto mean = NDArrayFactory::create('c', {4}); + auto variance = NDArrayFactory::create('c', {4}); + auto gamma = NDArrayFactory::create('c', {4}); + auto beta = NDArrayFactory::create('c', {4}); + + auto expected = NDArrayFactory::create('c', {2,3,4}, {-0.52733537f, -0.35763144f, -0.18792751f, -0.01822358f, 0.15148035f, 0.32118428f, 0.49088821f, 0.66059214f, 0.83029607f, 1.f, 1.16970393f, 1.33940786f, + 1.50911179f, 1.67881572f, 1.84851965f, 2.01822358f, 2.18792751f, 2.35763144f, 2.52733537f, 2.6970393f, 2.86674323f, 3.03644717f, 3.2061511f, 3.37585503f}); + + input.linspace(0.1, 0.1); + mean.assign(1.); + variance.assign(0.5); + gamma.assign(1.2); + beta.assign(1.); + + nd4j::ops::batchnorm op; + + auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto output = results->at(0); + // output->printBuffer(); + + ASSERT_TRUE(expected.isSameShapeStrict(*output)); + ASSERT_TRUE(expected.equalsTo(output)); + + delete results; +} + +//////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedDeclarableOpsTests13, batchnorm_test3) { + + auto input = NDArrayFactory::create('c', {2,3,4}); + auto mean = NDArrayFactory::create('c', {3}, {1.05f, 1.1f, 1.15f}); + auto variance = NDArrayFactory::create('c', {3}, {0.5f, 0.6f, 0.7f}); + auto gamma = NDArrayFactory::create('c', {3}, {1.2f, 1.3f, 1.4f}); + auto beta = NDArrayFactory::create('c', {3}, {0.1f, 0.2f, 0.3f}); + + auto expected = NDArrayFactory::create('c', {2,3,4}, {-1.51218734f, -1.34248341f, -1.17277948f, -1.00307555f, -0.80696728f, -0.6391394f, -0.47131152f, -0.30348364f, -0.11832703f, 0.04900378f, 0.21633459f, 0.38366541f, + 0.52425983f, 0.69396376f, 0.86366769f, 1.03337162f, 1.20696728f, 1.37479516f, 1.54262304f, 1.71045092f, 1.8896427f, 2.05697351f, 2.22430432f, 2.39163513f}); + + input.linspace(0.1, 0.1); + + nd4j::ops::batchnorm op; + + auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1,1}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto output = results->at(0); + + ASSERT_TRUE(expected.isSameShapeStrict(*output)); + ASSERT_TRUE(expected.equalsTo(output)); + + delete results; +} + +//////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedDeclarableOpsTests13, batchnorm_test4) { + + auto input = NDArrayFactory::create('c', {2,3,4}); + auto mean = NDArrayFactory::create('c', {2,1,4}, {1.05f, 1.1f, 1.15f, 1.2f, 1.25f, 1.3f, 1.35f, 1.4f}); + auto variance = NDArrayFactory::create('c', {2,1,4}, {0.5f, 0.6f, 0.7f, 0.8f, 0.9f, 1.f, 1.1f, 1.2f}); + auto gamma = NDArrayFactory::create('c', {2,1,4}, {1.2f, 1.3f, 1.4f, 1.5f, 1.6f, 1.7f, 1.8f, 1.9f}); + auto beta = NDArrayFactory::create('c', {2,1,4}, {0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.66f, 0.7f, 0.8f}); + + auto expected = NDArrayFactory::create('c', {2,3,4}, {-1.51218734f, -1.31045092f, -1.12231189f, -0.9416324f, -0.83337162f, -0.6391394f, -0.45298865f, -0.2708162f, -0.1545559f, 0.03217212f, 0.21633459f, 0.4f, + 0.58432694f, 0.82999915f, 0.95743373f, 1.14688951f, 1.25894242f, 1.50999575f, 1.64392367f, 1.84066852f, 1.93355791f, 2.18999235f, 2.33041362f, 2.53444754f}); + + input.linspace(0.1, 0.1); + + nd4j::ops::batchnorm op; + + auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1,0,2}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto output = results->at(0); + + ASSERT_TRUE(expected.isSameShapeStrict(*output)); + ASSERT_TRUE(expected.equalsTo(output)); + + delete results; +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, batchnorm_test5) { + + NDArray input ('c', {2,4,2,2}, nd4j::DataType::FLOAT32); + NDArray mean ('c', {4}, {1.05f, 1.15f, 1.2f, 1.3f}, nd4j::DataType::FLOAT32); + NDArray variance('c', {4}, {0.5f, 0.7f, 0.9f, 1.1f}, nd4j::DataType::FLOAT32); + NDArray gamma ('c', {4}, {-1.2f, 1.3f, -1.4f, 1.5f}, nd4j::DataType::FLOAT32); + NDArray beta ('c', {4}, {10.f, 20.f, -10.f, -20.f}, nd4j::DataType::FLOAT32); + + NDArray expected('c', {2,4,2,2}, { 11.612187f, 11.442483f, 11.272779f, 11.103076f, 18.990039f, 19.145418f, 19.300796f, 19.456175f, -9.557284f, -9.704856f, -9.852428f, -10.f, -20.f, + -19.856981f, -19.713963f, -19.570944f, 8.896924f, 8.727221f, 8.557517f, 8.387813f, 21.476097f, 21.631475f, 21.786854f, 21.942233f, -11.918438f, + -12.06601f, -12.213582f, -12.361154f, -17.7117f, -17.568681f, -17.425663f, -17.282644f}, nd4j::DataType::FLOAT32); + input.linspace(0.1, 0.1); + + nd4j::ops::batchnorm op; + + auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1, 1, 1}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto output = results->at(0); + // output->printBuffer(); + + ASSERT_TRUE(expected.isSameShapeStrict(*output)); + ASSERT_TRUE(expected.equalsTo(output)); + + delete results; +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, batchnorm_test6) { + + NDArray input ('c', {2,2,2,4}, nd4j::DataType::FLOAT32); + NDArray mean ('c', {4}, {1.05f, 1.15f, 1.2f, 1.3f}, nd4j::DataType::FLOAT32); + NDArray variance('c', {4}, {0.5f, 0.7f, 0.9, 1.1f}, nd4j::DataType::FLOAT32); + NDArray gamma ('c', {4}, {-1.2f, 1.3f, -1.4f, 1.5f}, nd4j::DataType::FLOAT32); + NDArray beta ('c', {4}, {10.f, 20.f, -10.f, -20.f}, nd4j::DataType::FLOAT32); + + NDArray expected('c', {2,2,2,4}, {11.612187f, 18.523903f, -8.671851f, -21.287169f, 10.933372f, 19.145418f, -9.262139f, -20.715094f, 10.254556f, 19.766932f, -9.852428f, -20.143019f, 9.57574f, + 20.388447f, -10.442716f, -19.570944f, 8.896924f, 21.009961f, -11.033005f, -18.998869f, 8.218109f, 21.631475f, -11.623294f, -18.426794f, 7.539293f, 22.25299f, + -12.213582f, -17.854719f, 6.860477f, 22.874504f, -12.803871f, -17.282644f}, nd4j::DataType::FLOAT32); + input.linspace(0.1, 0.1); + + nd4j::ops::batchnorm op; + + auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1,3}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto output = results->at(0); + + ASSERT_TRUE(expected.isSameShapeStrict(*output)); + ASSERT_TRUE(expected.equalsTo(output)); + + delete results; +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, batchnorm_test7) { + + NDArray input1('c', {3,3,15,15}, nd4j::DataType::FLOAT32); + NDArray input2('c', {3,15,15,3}, nd4j::DataType::FLOAT32); + input2.permutei({0,3,1,2}); + + NDArray mean ('c', {3}, {0, 0, 0}, nd4j::DataType::FLOAT32); + NDArray variance('c', {3}, {1, 1, 1}, nd4j::DataType::FLOAT32); + NDArray gamma ('c', {3}, {1, 1, 1}, nd4j::DataType::FLOAT32); + NDArray beta ('c', {3}, {0, 0, 0}, nd4j::DataType::FLOAT32); + + NDArray out1('c', {3,3,15,15}, nd4j::DataType::FLOAT32); + NDArray out2('c', {3,3,15,15}, nd4j::DataType::FLOAT32); + + input1.linspace(-1012, 1); + input2.assign(input1); + + nd4j::ops::batchnorm op; + + auto res1 = op.execute({&input1, &mean, &variance, &gamma, &beta}, {&out1}, {1e-5}, {1,1,1}, {}); + ASSERT_EQ(ND4J_STATUS_OK, res1); + + auto res2 = op.execute({&input2, &mean, &variance, &gamma, &beta}, {&out2}, {1e-5}, {1,1,1}, {}); + ASSERT_EQ(ND4J_STATUS_OK, res2); + + ASSERT_TRUE(out1.equalsTo(out2)); +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, batchnorm_test8) { + + NDArray input('c', {2,3,4,5}, nd4j::DataType::FLOAT32); + + NDArray mean ('c', {1,3,4,5}, nd4j::DataType::FLOAT32); + NDArray variance('c', {1,3,4,5}, nd4j::DataType::FLOAT32); + NDArray gamma ('c', {1,3,4,5}, nd4j::DataType::FLOAT32); + NDArray beta ('c', {1,3,4,5}, nd4j::DataType::FLOAT32); + + NDArray expected('c', {2,3,4,5}, {-105.019394, -103.322357, -101.625313, -99.928276, -98.231239, -96.534195, -94.837158, -93.140121, -91.443077, -89.746040, -88.049004, -86.351959, -84.654922, + -82.957886, -81.260841, -79.563805, -77.866768, -76.169724, -74.472687, -72.775650, -71.078606, -69.381569, -67.684532, -65.987488, -64.290451, -62.593414, + -60.896374, -59.199333, -57.502296, -55.805256, -54.108215, -52.411179, -50.714138, -49.017097, -47.320061, -45.623020, -43.925980, -42.228943, -40.531902, + -38.834862, -37.137825, -35.440784, -33.743744, -32.046707, -30.349667, -28.652628, -26.955589, -25.258549, -23.561510, -21.864471, -20.167431, -18.470392, + -16.773354, -15.076314, -13.379274, -11.682236, -9.985196, -8.288157, -6.591118, -4.894078, -3.197039, -1.500000, 0.197039, 1.894078, 3.591118, 5.288157, + 6.985196, 8.682236, 10.379274, 12.076314, 13.773354, 15.470392, 17.167431, 18.864471, 20.561510, 22.258549, 23.955589, 25.652628, 27.349667, 29.046707, 30.743744, + 32.440784, 34.137825, 35.834862, 37.531902, 39.228943, 40.925980, 42.623020, 44.320061, 46.017097, 47.714138, 49.411179, 51.108215, 52.805256, 54.502296, 56.199333, + 57.896374, 59.593414, 61.290451, 62.987488, 64.684532, 66.381569, 68.078606, 69.775650, 71.472687, 73.169724, 74.866768, 76.563805, 78.260841, 79.957886, 81.654922, + 83.351959, 85.049004, 86.746040, 88.443077, 90.140121, 91.837158, 93.534195, 95.231239, 96.928276}, nd4j::DataType::FLOAT32); + + input.linspace(-60, 1); + mean.assign(1.); + variance.assign(0.5); + gamma.assign(1.2); + beta.assign(-1.5); + + nd4j::ops::batchnorm op; + + auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1, 1,2,3}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto output = results->at(0); + + ASSERT_TRUE(expected.isSameShape(*output)); + ASSERT_TRUE(expected.equalsTo(output)); + + delete results; +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, batchnorm_test9) { + + NDArray input('c', {2,3,3,3,3}, nd4j::DataType::FLOAT32); + + NDArray mean ('c', {1,3,3,3,3}, nd4j::DataType::FLOAT32); + NDArray variance('c', {1,3,3,3,3}, nd4j::DataType::FLOAT32); + NDArray gamma ('c', {1,3,3,3,3}, nd4j::DataType::FLOAT32); + NDArray beta ('c', {1,3,3,3,3}, nd4j::DataType::FLOAT32); + + NDArray expected('c', {2,3,3,3,3}, {-138.960175, -137.263138, -135.566101, -133.869064, -132.172028, -130.474976, -128.777954, -127.080902, -125.383865, -123.686829, -121.989784, -120.292747, + -118.595711, -116.898666, -115.201630, -113.504593, -111.807549, -110.110512, -108.413475, -106.716431, -105.019394, -103.322357, -101.625313, -99.928276, + -98.231239, -96.534195, -94.837158, -93.140121, -91.443077, -89.746040, -88.049004, -86.351959, -84.654922, -82.957886, -81.260841, -79.563805, -77.866768, + -76.169724, -74.472687, -72.775650, -71.078606, -69.381569, -67.684532, -65.987488, -64.290451, -62.593414, -60.896374, -59.199333, -57.502296, -55.805256, + -54.108215, -52.411179, -50.714138, -49.017097, -47.320061, -45.623020, -43.925980, -42.228943, -40.531902, -38.834862, -37.137825, -35.440784, -33.743744, + -32.046707, -30.349667, -28.652628, -26.955589, -25.258549, -23.561510, -21.864471, -20.167431, -18.470392, -16.773354, -15.076314, -13.379274, -11.682236, + -9.985196, -8.288157, -6.591118, -4.894078, -3.197039, -1.500000, 0.197039, 1.894078, 3.591118, 5.288157, 6.985196, 8.682236, 10.379274, 12.076314, 13.773354, + 15.470392, 17.167431, 18.864471, 20.561510, 22.258549, 23.955589, 25.652628, 27.349667, 29.046707, 30.743744, 32.440784, 34.137825, 35.834862, 37.531902, 39.228943, + 40.925980, 42.623020, 44.320061, 46.017097, 47.714138, 49.411179, 51.108215, 52.805256, 54.502296, 56.199333, 57.896374, 59.593414, 61.290451, 62.987488, 64.684532, + 66.381569, 68.078606, 69.775650, 71.472687, 73.169724, 74.866768, 76.563805, 78.260841, 79.957886, 81.654922, 83.351959, 85.049004, 86.746040, 88.443077, 90.140121, + 91.837158, 93.534195, 95.231239, 96.928276, 98.625313, 100.322357, 102.019394, 103.716431, 105.413475, 107.110512, 108.807549, 110.504593, 112.201630, 113.898666, + 115.595711, 117.292747, 118.989784, 120.686829, 122.383865, 124.080902, 125.777946, 127.474976, 129.172028, 130.869064, 132.566101, 134.263138}, nd4j::DataType::FLOAT32); + + input.linspace(-80, 1); + mean.assign(1.); + variance.assign(0.5); + gamma.assign(1.2); + beta.assign(-1.5); + + nd4j::ops::batchnorm op; + + auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1, 1,2,3,4}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto output = results->at(0); + // output->printBuffer(); + + ASSERT_TRUE(expected.isSameShape(*output)); + ASSERT_TRUE(expected.equalsTo(output)); + + delete results; +} diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp index d87acc439..75db5989c 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp @@ -138,8 +138,8 @@ TEST_F(DeclarableOpsTests15, test_avgpooling_edge_1) { } TEST_F(DeclarableOpsTests15, Test_standarize_1) { - auto x = NDArrayFactory::create('c', {5}, {1, 1, 1, 1, 1}); - auto e = NDArrayFactory::create('c', {5}, {0, 0, 0, 0, 0}); + auto x = NDArrayFactory::create('c', {5}, {1.f, 1.f, 1.f, 1.f, 1.f}); + auto e = NDArrayFactory::create('c', {5}, {0.f, 0.f, 0.f, 0.f, 0.f}); nd4j::ops::standardize op; auto result = op.execute({&x}, {&x}, {}, {0}, {}); @@ -293,268 +293,77 @@ TEST_F(DeclarableOpsTests15, Test_AdjustContrast_6) { .7266f,0.1965f,0.9167f,0.9726f,0.9206f,0.0519f,0.2997f,0.0039f,0.7652f,0.5498f, 0.3794f,0.3791f,0.3528f,0.2873f,0.8082f,0.4732f,0.4399f,0.6606f,0.5991f,0.0034f,0.4874f}); auto e = NDArrayFactory::create('c', {8, 8, 3, 1}, { - 1.0218375f, - 1.0666375f, - 0.9130375f, - - -0.07396251f, - 0.91843754f, - -0.17496246f, - - 0.47543746f, - 1.2492375f, - 0.55643755f, - - 1.3110375f, - -0.36456245f, - 1.0518374f, - - 0.7824375f, - 0.57523745f, - -0.21656245f, - - 0.0816375f, - -0.2261625f, - 0.40323752f, - - 1.4520376f, - 0.6868375f, - 0.81723756f, - - -0.17576247f, - 0.81423753f, - -0.08656245f, - - - -0.36249164f, - 0.45590833f, - 1.1925083f, - - 0.00650835f, - 1.4861084f, - 1.2079083f, - - 0.05270836f, - 0.37350836f, - 0.94130826f, - - 1.0715083f, - 0.6103083f, - 0.9825083f, - - 0.07370833f, - -0.4518917f, - -0.39889166f, - - -0.3354917f, - 1.2213084f, - 1.0345083f, - - -0.3132917f, - 0.78470826f, - 0.23390833f, - - 0.6943083f, - 0.68170834f, - -0.09989169f, - - - 0.8352709f, - 1.3798709f, - 0.15507084f, - - 0.26607084f, - -0.10792917f, - 1.2302709f, - - 0.6448709f, - -0.29992914f, - 1.3534708f, - - 0.86607087f, - 0.37607086f, - 0.04027084f, - - 0.40087086f, - 0.59507084f, - 0.9416709f, - - 0.53127086f, - -0.01712915f, - 1.4610709f, - - -0.17152917f, - -0.13992918f, - 0.6242708f, - - -0.42192918f, - 0.38387084f, - -0.15752912f, - - - 0.3311833f, - 0.00618333f, - 0.17538333f, - - 0.10418332f, - 0.8365834f, - 0.27098334f, - - 1.2421833f, - -0.1114167f, - 1.0153834f, - - 0.9523833f, - 0.8317833f, - 0.9633833f, - - 0.6501833f, - 0.04258335f, - 0.9999833f, - - -0.40181667f, - 0.11418331f, - 0.47938335f, - - 1.1057833f, - -0.29761666f, - 1.0779834f, - - 0.5243833f, - -0.32181668f, - 1.1833833f, - - - 0.73157084f, - 0.4317708f, - 0.7283708f, - - 1.2297708f, - 0.4307708f, - 0.85377085f, - - 0.05977082f, - -0.09282917f, - 0.33957082f, - - 1.0751709f, - 0.2119708f, - 0.51897085f, - - -0.25302917f, - 1.1723708f, - -0.12562919f, - - 1.1993709f, - 0.5257708f, - 0.40517086f, - - 0.53197086f, - 0.8441708f, - 0.02617085f, - - -0.0208292f, - 0.8711709f, - 0.04137081f, - - - 0.74936247f, - 0.6085625f, - 0.8997625f, - - -0.08743751f, - 0.18576252f, - -0.17563748f, - - 0.5991625f, - -0.0038375f, - 0.07576251f, - - 0.42536253f, - -0.22823751f, - 0.36296248f, - - 0.81456256f, - -0.16183749f, - 0.5161625f, - - -0.21183747f, - 0.7429625f, - 0.6217625f, - - 0.17656249f, - 0.02616251f, - -0.17923748f, - - 1.4659625f, - 0.40016252f, - 0.28356248f, - - - 0.4195791f, - 0.8745791f, - 0.36637908f, - - 0.50597906f, - -0.17942089f, - 0.16917908f, - - 1.0235791f, - 1.3699791f, - -0.11382091f, - - -0.0918209f, - 0.7757791f, - 0.09017909f, - - 1.3807791f, - -0.15202093f, - 1.3875791f, - - -0.1712209f, - 1.3989791f, - 0.43777913f, - - 0.7855791f, - 0.1423791f, - 1.4711791f, - - 0.6455791f, - 0.6211791f, - -0.48062086f, - - - 0.10189578f, - 0.5628958f, - 0.68909574f, - - 0.96649575f, - -0.09370419f, - 1.3466958f, - - 1.4584957f, - 1.3544958f, - -0.3829042f, - - 0.11269578f, - -0.47890422f, - 1.0436958f, - - 0.6128957f, - 0.27209583f, - 0.2714958f, - - 0.21889582f, - 0.08789578f, - 1.1296958f, - - 0.4596958f, - 0.39309582f, - 0.8344958f, - - 0.71149576f, - -0.4799042f, - 0.4880958f + 1.0218375f, 1.0666375f, 0.9130375f, + -0.07396251f, 0.91843754f, -0.17496246f, + 0.47543746f, 1.2492375f, 0.55643755f, + 1.3110375f, -0.36456245f, 1.0518374f, + 0.7824375f, 0.57523745f, -0.21656245f, + 0.0816375f, -0.2261625f, 0.40323752f, + 1.4520376f, 0.6868375f, 0.81723756f, + -0.17576247f, 0.81423753f, -0.08656245f, + + -0.36249164f, 0.45590833f, 1.1925083f, + 0.00650835f, 1.4861084f, 1.2079083f, + 0.05270836f, 0.37350836f, 0.94130826f, + 1.0715083f, 0.6103083f, 0.9825083f, + 0.07370833f, -0.4518917f, -0.39889166f, + -0.3354917f, 1.2213084f, 1.0345083f, + -0.3132917f, 0.78470826f, 0.23390833f, + 0.6943083f, 0.68170834f, -0.09989169f, + + 0.8352709f, 1.3798709f, 0.15507084f, + 0.26607084f, -0.10792917f, 1.2302709f, + 0.6448709f, -0.29992914f, 1.3534708f, + 0.86607087f, 0.37607086f, 0.04027084f, + 0.40087086f, 0.59507084f, 0.9416709f, + 0.53127086f, -0.01712915f, 1.4610709f, + -0.17152917f, -0.13992918f, 0.6242708f, + -0.42192918f, 0.38387084f, -0.15752912f, + + 0.3311833f, 0.00618333f, 0.17538333f, + 0.10418332f, 0.8365834f, 0.27098334f, + 1.2421833f, -0.1114167f, 1.0153834f, + 0.9523833f, 0.8317833f, 0.9633833f, + 0.6501833f, 0.04258335f, 0.9999833f, + -0.40181667f, 0.11418331f, 0.47938335f, + 1.1057833f, -0.29761666f, 1.0779834f, + 0.5243833f, -0.32181668f, 1.1833833f, + + 0.73157084f, 0.4317708f, 0.7283708f, + 1.2297708f, 0.4307708f, 0.85377085f, + 0.05977082f, -0.09282917f, 0.33957082f, + 1.0751709f, 0.2119708f, 0.51897085f, + -0.25302917f, 1.1723708f, -0.12562919f, + 1.1993709f, 0.5257708f, 0.40517086f, + 0.53197086f, 0.8441708f, 0.02617085f, + -0.0208292f, 0.8711709f, 0.04137081f, + + 0.74936247f, 0.6085625f, 0.8997625f, + -0.08743751f, 0.18576252f, -0.17563748f, + 0.5991625f, -0.0038375f, 0.07576251f, + 0.42536253f, -0.22823751f, 0.36296248f, + 0.81456256f, -0.16183749f, 0.5161625f, + -0.21183747f, 0.7429625f, 0.6217625f, + 0.17656249f, 0.02616251f, -0.17923748f, + 1.4659625f, 0.40016252f, 0.28356248f, + + 0.4195791f, 0.8745791f, 0.36637908f, + 0.50597906f, -0.17942089f, 0.16917908f, + 1.0235791f, 1.3699791f, -0.11382091f, + -0.0918209f, 0.7757791f, 0.09017909f, + 1.3807791f, -0.15202093f, 1.3875791f, + -0.1712209f, 1.3989791f, 0.43777913f, + 0.7855791f, 0.1423791f, 1.4711791f, + 0.6455791f, 0.6211791f, -0.48062086f, + + 0.10189578f, 0.5628958f, 0.68909574f, + 0.96649575f, -0.09370419f, 1.3466958f, + 1.4584957f, 1.3544958f, -0.3829042f, + 0.11269578f, -0.47890422f, 1.0436958f, + 0.6128957f, 0.27209583f, 0.2714958f, + 0.21889582f, 0.08789578f, 1.1296958f, + 0.4596958f, 0.39309582f, 0.8344958f, + 0.71149576f, -0.4799042f, 0.4880958f }); nd4j::ops::adjust_contrast op; @@ -587,268 +396,79 @@ TEST_F(DeclarableOpsTests15, Test_AdjustContrast_7) { .7266f,0.1965f,0.9167f,0.9726f,0.9206f,0.0519f,0.2997f,0.0039f,0.7652f,0.5498f, 0.3794f,0.3791f,0.3528f,0.2873f,0.8082f,0.4732f,0.4399f,0.6606f,0.5991f,0.0034f,0.4874f}); auto e = NDArrayFactory::create('c', {8, 8, 3, 1}, { - 1.0218375 , - 1.0666375 , - 0.9130375 , - - -0.07396251, - 0.91843754, - -0.17496246, - - 0.47543746, - 1.2492375 , - 0.55643755, - - 1.3110375 , - -0.36456245, - 1.0518374 , - - 0.7824375 , - 0.57523745, - -0.21656245, - - 0.0816375 , - -0.2261625 , - 0.40323752, - - 1.4520376 , - 0.6868375 , - 0.81723756, - - -0.17576247, - 0.81423753, - -0.08656245, - - - -0.36249164, - 0.45590833, - 1.1925083 , - - 0.00650835, - 1.4861084 , - 1.2079083 , - - 0.05270836, - 0.37350836, - 0.94130826, - - 1.0715083 , - 0.6103083 , - 0.9825083 , - - 0.07370833, - -0.4518917 , - -0.39889166, - - -0.3354917 , - 1.2213084 , - 1.0345083 , - - -0.3132917 , - 0.78470826, - 0.23390833, - - 0.6943083 , - 0.68170834, - -0.09989169, - - - 0.8352709 , - 1.3798709 , - 0.15507084, - - 0.26607084, - -0.10792917, - 1.2302709 , - - 0.6448709 , - -0.29992914, - 1.3534708 , - - 0.86607087, - 0.37607086, - 0.04027084, - - 0.40087086, - 0.59507084, - 0.9416709 , - - 0.53127086, - -0.01712915, - 1.4610709 , - - -0.17152917, - -0.13992918, - 0.6242708 , - - -0.42192918, - 0.38387084, - -0.15752912, - - - 0.3311833 , - 0.00618333, - 0.17538333, - - 0.10418332, - 0.8365834 , - 0.27098334, - - 1.2421833 , - -0.1114167 , - 1.0153834 , - - 0.9523833 , - 0.8317833 , - 0.9633833 , - - 0.6501833 , - 0.04258335, - 0.9999833 , - - -0.40181667, - 0.11418331, - 0.47938335, - - 1.1057833 , - -0.29761666, - 1.0779834 , - - 0.5243833 , - -0.32181668, - 1.1833833 , - - - 0.73157084, - 0.4317708 , - 0.7283708 , - - 1.2297708 , - 0.4307708 , - 0.85377085, - - 0.05977082, - -0.09282917, - 0.33957082, - - 1.0751709 , - 0.2119708 , - 0.51897085, - - -0.25302917, - 1.1723708 , - -0.12562919, - - 1.1993709 , - 0.5257708 , - 0.40517086, - - 0.53197086, - 0.8441708 , - 0.02617085, - - -0.0208292 , - 0.8711709 , - 0.04137081, - - - 0.74936247, - 0.6085625 , - 0.8997625 , - - -0.08743751, - 0.18576252, - -0.17563748, - - 0.5991625 , - -0.0038375 , - 0.07576251, - - 0.42536253, - -0.22823751, - 0.36296248, - - 0.81456256, - -0.16183749, - 0.5161625 , - - -0.21183747, - 0.7429625 , - 0.6217625 , - - 0.17656249, - 0.02616251, - -0.17923748, - - 1.4659625 , - 0.40016252, - 0.28356248, - - - 0.4195791 , - 0.8745791 , - 0.36637908, - - 0.50597906, - -0.17942089, - 0.16917908, - - 1.0235791 , - 1.3699791 , - -0.11382091, - - -0.0918209 , - 0.7757791 , - 0.09017909, - - 1.3807791 , - -0.15202093, - 1.3875791 , - - -0.1712209 , - 1.3989791 , - 0.43777913, - - 0.7855791 , - 0.1423791 , - 1.4711791 , - - 0.6455791 , - 0.6211791 , - -0.48062086, - - - 0.10189578, - 0.5628958 , - 0.68909574, - - 0.96649575, - -0.09370419, - 1.3466958 , - - 1.4584957 , - 1.3544958 , - -0.3829042 , - - 0.11269578, - -0.47890422, - 1.0436958 , - - 0.6128957 , - 0.27209583, - 0.2714958 , - - 0.21889582, - 0.08789578, - 1.1296958 , - - 0.4596958 , - 0.39309582, - 0.8344958 , - - 0.71149576, - -0.4799042, - 0.4880958 + 1.0218375, 1.0666375 , 0.9130375 , + -0.07396251, 0.91843754, -0.17496246, + 0.47543746, 1.2492375 , 0.55643755, + 1.3110375 , -0.36456245, 1.0518374 , + 0.7824375 , 0.57523745, -0.21656245, + 0.0816375 , -0.2261625 , 0.40323752, + 1.4520376 , 0.6868375 , 0.81723756, + -0.17576247, 0.81423753, -0.08656245, + + -0.36249164, 0.45590833, 1.1925083 , + 0.00650835, 1.4861084 , 1.2079083 , + 0.05270836, 0.37350836, 0.94130826, + 1.0715083 , 0.6103083 , 0.9825083 , + 0.07370833, -0.4518917 , -0.39889166, + -0.3354917 , 1.2213084 , 1.0345083 , + -0.3132917 , 0.78470826, 0.23390833, + 0.6943083 , 0.68170834, -0.09989169, + + 0.8352709 , 1.3798709 , 0.15507084, + 0.26607084, -0.10792917, 1.2302709 , + 0.6448709 , -0.29992914, 1.3534708 , + 0.86607087, 0.37607086, 0.04027084, + 0.40087086, 0.59507084, 0.9416709 , + 0.53127086, -0.01712915, 1.4610709 , + -0.17152917, -0.13992918, 0.6242708 , + -0.42192918, 0.38387084, -0.15752912, + + + 0.3311833 , 0.00618333, 0.17538333, + 0.10418332, 0.8365834 , 0.27098334, + 1.2421833 , -0.1114167 , 1.0153834 , + 0.9523833 , 0.8317833 , 0.9633833 , + 0.6501833 , 0.04258335, 0.9999833 , + -0.40181667, 0.11418331, 0.47938335, + 1.1057833 , -0.29761666, 1.0779834 , + 0.5243833 , -0.32181668, 1.1833833 , + + 0.73157084, 0.4317708 , 0.7283708 , + 1.2297708 , 0.4307708 , 0.85377085, + 0.05977082, -0.09282917, 0.33957082, + 1.0751709 , 0.2119708 , 0.51897085, + -0.25302917, 1.1723708 , -0.12562919, + 1.1993709 , 0.5257708 , 0.40517086, + 0.53197086, 0.8441708 , 0.02617085, + -0.0208292 , 0.8711709 , 0.04137081, + + 0.74936247, 0.6085625 , 0.8997625 , + -0.08743751, 0.18576252, -0.17563748, + 0.5991625 , -0.0038375 , 0.07576251, + 0.42536253, -0.22823751, 0.36296248, + 0.81456256, -0.16183749, 0.5161625 , + -0.21183747, 0.7429625 , 0.6217625 , + 0.17656249, 0.02616251, -0.17923748, + 1.4659625 , 0.40016252, 0.28356248, + + 0.4195791 , 0.8745791 , 0.36637908, + 0.50597906, -0.17942089, 0.16917908, + 1.0235791 , 1.3699791 , -0.11382091, + -0.0918209 , 0.7757791 , 0.09017909, + 1.3807791 , -0.15202093, 1.3875791 , + -0.1712209 , 1.3989791 , 0.43777913, + 0.7855791 , 0.1423791 , 1.4711791 , + 0.6455791 , 0.6211791 , -0.48062086, + + + 0.10189578, 0.5628958 , 0.68909574, + 0.96649575, -0.09370419, 1.3466958 , + 1.4584957 , 1.3544958 , -0.3829042 , + 0.11269578, -0.47890422, 1.0436958 , + 0.6128957 , 0.27209583, 0.2714958 , + 0.21889582, 0.08789578, 1.1296958 , + 0.4596958 , 0.39309582, 0.8344958 , + 0.71149576, -0.4799042, 0.4880958 }); // x.linspace(1.); nd4j::ops::adjust_contrast_v2 op; @@ -917,6 +537,20 @@ TEST_F(DeclarableOpsTests15, Test_BitCast_4) { } +TEST_F(DeclarableOpsTests15, Test_BitCast_4_1) { + auto x = NDArrayFactory::create('c', {1, 2}); + auto e = NDArrayFactory::create('c', {1, 2}, {4607182418800017408LL, 4611686018427387904LL}); // as TF 4607182418800017408, 4611686018427387904 + x.linspace(1.); + nd4j::ops::bitcast op; + + auto result = op.execute({&x}, {}, {nd4j::DataType::INT64}, {}); + ASSERT_EQ(Status::OK(), result->status()); + // e.printIndexedBuffer("Double to int64"); + auto res = result->at(0); + ASSERT_EQ(*res, e); + delete result; +} + TEST_F(DeclarableOpsTests15, Test_BitCast_5) { auto x = NDArrayFactory::create('c', {4, 4}, { @@ -971,21 +605,6 @@ TEST_F(DeclarableOpsTests15, Test_BitCast_7) { delete result; } -TEST_F(DeclarableOpsTests15, Test_depthwise_bp_1) { - auto in = NDArrayFactory::create('c', {4, 8, 64, 64}); - auto w = NDArrayFactory::create('c', {2, 2, 8, 2}); - auto b = NDArrayFactory::create('c', {1, 16}); - auto grad = NDArrayFactory::create('c', {4, 16, 64, 64}); - - auto gradI = in.like(); - auto gradW = w.like(); - auto gradB = b.like(); - - nd4j:ops::depthwise_conv2d_bp op; - auto status = op.execute({&in, &w, &b, &grad}, {&gradI, &gradW, &gradB}, {}, {2, 2, 1, 1, 0, 0, 1, 1, 1, 0}, {}); - ASSERT_EQ(Status::OK(), status); -} - TEST_F(DeclarableOpsTests15, test_matmul_bp_1) { auto a = NDArrayFactory::create('c', {1, 3}); auto b = NDArrayFactory::create('c', {1, 4}); @@ -1299,3 +918,710 @@ TEST_F(DeclarableOpsTests15, test_empty_decreasing_1) { ASSERT_EQ(true, z.e(0)); } + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests15, test_rgb_to_grs_1) { + // rank 1 + NDArray rgbs('c', { 3 }, { 10, 50, 200 }, nd4j::DataType::INT32); + NDArray expected('c', { 1 }, { 55 }, nd4j::DataType::INT32); + nd4j::ops::rgb_to_grs op; + auto result = op.execute({&rgbs}, {}, {}); + auto output = result->at(0); + + ASSERT_EQ(Status::OK(), result->status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + delete result; +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests15, test_rgb_to_grs_2) { + // rank 1 + auto rgbs = NDArrayFactory::create('f', { 3 }, { 1, 120, -25 }); + auto expected = NDArrayFactory::create('f', { 1 }, { 67 }); + nd4j::ops::rgb_to_grs op; + auto result = op.execute({ &rgbs }, {}, {}); + auto output = result->at(0); + + ASSERT_EQ(Status::OK(), result->status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + delete result; +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests15, test_rgb_to_grs_3) { + // rank 2 + NDArray rgbs('c', { 4, 3 }, { -94, 99, 97, 90, 114, 101, 111, 96, 105, 100, 103, 102 }, nd4j::DataType::INT32); + NDArray expected('c', { 4, 1 }, { 41, 105, 101, 101 }, nd4j::DataType::INT32); + nd4j::ops::rgb_to_grs op; + auto result = op.execute({ &rgbs }, {}, {}); + auto output = result->at(0); + + ASSERT_EQ(Status::OK(), result->status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + delete result; +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests15, test_rgb_to_grs_4) { + + NDArray rgbs('c', { 3, 2 }, {14, 99, 207, 10, 114, 201 }, nd4j::DataType::INT32); + + rgbs.permutei({1,0}); + NDArray expected('c', { 2, 1 }, { 138, 58 }, nd4j::DataType::INT32); + nd4j::ops::rgb_to_grs op; + auto result = op.execute({ &rgbs }, {}, {}); + auto output = result->at(0); + + ASSERT_EQ(Status::OK(), result->status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + delete result; +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests15, test_rgb_to_grs_5) { + // rank 2 + NDArray rgbs('c', { 3, 4 }, { -94, 99, 97, 90, 114, 101, 111, 96, 105, 100, 103, 102 }, nd4j::DataType::INT32); + NDArray expected('c', { 1, 4 }, { 50, 100, 105, 94 }, nd4j::DataType::INT32); + nd4j::ops::rgb_to_grs op; + auto result = op.execute({ &rgbs }, {}, {0}); + auto output = result->at(0); + + ASSERT_EQ(Status::OK(), result->status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + delete result; +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests15, test_rgb_to_grs_6) { + // rank 3 + auto rgbs = NDArrayFactory::create('c', { 5,4,3 }, {1.7750e+01f, -7.1062e+01f, -1.0019e+02f,-2.3406e+01f, 5.2094e+01f, 9.5438e+01f, -6.7461e+00f, 3.8562e+01f, 6.5078e+00f,3.3562e+01f, -5.8844e+01f, 2.2750e+01f, -1.0477e+01f, 7.7344e+00f, 9.5469e+00f,2.1391e+01f, -8.5312e+01f, 7.5830e-01f,2.3125e+01f, 1.8145e+00f, 1.4602e+01f,-4.5859e+00f, 3.9344e+01f, 1.1617e+01f,-8.6562e+01f, 1.0038e+02f, 6.7938e+01f,5.9961e+00f, 6.7812e+01f, 2.9734e+01f,2.9609e+01f, -6.1438e+01f, 1.7750e+01f,6.8562e+01f, -7.4414e+00f, 3.9656e+01f,1.1641e+01f, -2.7516e+01f, 6.7562e+01f,7.8438e+01f, 5.4883e+00f, 2.9438e+01f,-3.1344e+01f, 6.5125e+01f, 1.2695e+01f,4.0531e+01f, -6.1211e+00f, 6.2219e+01f,4.6812e+01f, 5.2250e+01f, -1.1414e+01f,1.5404e-02f, 2.9938e+01f, 5.6719e+00f,-2.0125e+01f, 2.1531e+01f, 6.2500e+01f,7.2188e+01f, 9.3750e+00f, -4.8125e+01f}); + auto expected = NDArrayFactory::create('c', { 5,4,1 }, {-47.82958221f, 34.46305847f, 21.36137581f, -21.91625023f,2.49686432f, -43.59792709f, 9.64180183f, 23.04854202f,40.7946167f, 44.98754883f, -25.19047546f, 20.64586449f,-4.97033119f, 30.0226841f, 30.30688286f, 15.61459541f,43.36166f, 18.22480774f, 13.74833488f, 21.59387016f}); + + nd4j::ops::rgb_to_grs op; + auto result = op.execute({ &rgbs }, {}, {}); + auto output = result->at(0); + + ASSERT_EQ(Status::OK(), result->status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + delete result; +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests15, test_rgb_to_grs_7) { + // rank 3 + auto rgbs = NDArrayFactory::create('c', { 5,3,4 }, { 1.7750e+01f, -7.1062e+01f, -1.0019e+02f,-2.3406e+01f, 5.2094e+01f, 9.5438e+01f, -6.7461e+00f, 3.8562e+01f, 6.5078e+00f,3.3562e+01f, -5.8844e+01f, 2.2750e+01f, -1.0477e+01f, 7.7344e+00f, 9.5469e+00f,2.1391e+01f, -8.5312e+01f, 7.5830e-01f,2.3125e+01f, 1.8145e+00f, 1.4602e+01f,-4.5859e+00f, 3.9344e+01f, 1.1617e+01f,-8.6562e+01f, 1.0038e+02f, 6.7938e+01f,5.9961e+00f, 6.7812e+01f, 2.9734e+01f,2.9609e+01f, -6.1438e+01f, 1.7750e+01f,6.8562e+01f, -7.4414e+00f, 3.9656e+01f,1.1641e+01f, -2.7516e+01f, 6.7562e+01f,7.8438e+01f, 5.4883e+00f, 2.9438e+01f,-3.1344e+01f, 6.5125e+01f, 1.2695e+01f,4.0531e+01f, -6.1211e+00f, 6.2219e+01f,4.6812e+01f, 5.2250e+01f, -1.1414e+01f,1.5404e-02f, 2.9938e+01f, 5.6719e+00f,-2.0125e+01f, 2.1531e+01f, 6.2500e+01f,7.2188e+01f, 9.3750e+00f, -4.8125e+01f}); + auto expected = NDArrayFactory::create('c', { 5,1,4 }, { 36.626545f, 38.607746f, -40.614971f, 18.233341f, -51.545094f,2.234142f, 20.913160f, 8.783220f, 15.955761f, 55.273506f, 36.838833f, -29.751089f, 8.148357f, 13.676106f, 1.097548f, 68.766457f, 38.690712f, 27.176361f, -14.156269f, 7.157052f }); + + nd4j::ops::rgb_to_grs op; + auto result = op.execute({ &rgbs }, {}, {1}); + auto output = result->at(0); + + ASSERT_EQ(Status::OK(), result->status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + delete result; +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests15, test_rgb_to_grs_8) { + // rank 3 + auto rgbs = NDArrayFactory::create('c', { 3,5,4 }, {1.7750e+01f, -7.1062e+01f, -1.0019e+02f,-2.3406e+01f, 5.2094e+01f, 9.5438e+01f, -6.7461e+00f, 3.8562e+01f, 6.5078e+00f,3.3562e+01f, -5.8844e+01f, 2.2750e+01f, -1.0477e+01f, 7.7344e+00f, 9.5469e+00f,2.1391e+01f, -8.5312e+01f, 7.5830e-01f,2.3125e+01f, 1.8145e+00f, 1.4602e+01f,-4.5859e+00f, 3.9344e+01f, 1.1617e+01f,-8.6562e+01f, 1.0038e+02f, 6.7938e+01f,5.9961e+00f, 6.7812e+01f, 2.9734e+01f,2.9609e+01f, -6.1438e+01f, 1.7750e+01f,6.8562e+01f, -7.4414e+00f, 3.9656e+01f,1.1641e+01f, -2.7516e+01f, 6.7562e+01f,7.8438e+01f, 5.4883e+00f, 2.9438e+01f,-3.1344e+01f, 6.5125e+01f, 1.2695e+01f,4.0531e+01f, -6.1211e+00f, 6.2219e+01f,4.6812e+01f, 5.2250e+01f, -1.1414e+01f,1.5404e-02f, 2.9938e+01f, 5.6719e+00f,-2.0125e+01f, 2.1531e+01f, 6.2500e+01f,7.2188e+01f, 9.3750e+00f, -4.8125e+01f}); + try { + nd4j::ops::rgb_to_grs op; + auto result = op.execute({ &rgbs }, {}, {}); + ASSERT_EQ(Status::THROW(), result->status()); + delete result; + } catch (std::exception& e) { + nd4j_printf("Error should be here `%s'. It's OK.\n", e.what()); + } +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests15, test_rgb_to_grs_9) { + // rank 3 + auto rgbs = NDArrayFactory::create('f', { 2, 2, 3 }, { 1.7750e+01f,-7.1062e+01f, -1.0019e+02f, -2.3406e+01f,5.2094e+01f,9.5438e+01f, -6.7461e+00f,3.8562e+01f, 6.5078e+00f, 3.3562e+01f,-5.8844e+01f,2.2750e+01f}); + auto expected = NDArrayFactory::create('f', { 2,2,1 }, { 36.626545f, 38.607746f, -40.614971f, 18.233341f }); + + nd4j::ops::rgb_to_grs op; + auto result = op.execute({ &rgbs }, {}, {}); + auto output = result->at(0); + + ASSERT_EQ(Status::OK(), result->status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + delete result; +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests15, test_rgb_to_yuv_1) { + // rank 1 + NDArray rgbs('f', { 3 }, { 10, 50, 200 }, nd4j::DataType::FLOAT32); + NDArray expected('f', { 3 }, { 55.14 , 71.2872001, -39.6005542 }, nd4j::DataType::FLOAT32); + nd4j::ops::rgb_to_yuv op; + auto result = op.execute({ &rgbs }, {}, {}); + auto output = result->at(0); + + ASSERT_EQ(Status::OK(), result->status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + delete result; +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests15, test_rgb_to_yuv_2) { + + NDArray rgbs('c', { 3, 2 }, { 14., 99., 207., 10., 114., 201. }, nd4j::DataType::FLOAT32); + rgbs.permutei({ 1,0 }); + + NDArray expected('c', { 2, 3 }, { 138.691, -12.150713, -109.38929, 58.385, 70.18241, 35.63085 }, nd4j::DataType::FLOAT32); + nd4j::ops::rgb_to_yuv op; + + auto result = op.execute({ &rgbs }, {}, {}); + auto output = result->at(0); + + ASSERT_EQ(Status::OK(), result->status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + delete result; +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests15, test_rgb_to_yuv_3) { + // rank 2 + NDArray rgbs('c', { 3, 4 }, { -9.4, 9.9, 9.7, 9.0, 1.14, 1.01, 1.11, 9.6, 1.05, 10.0, 1.03, 10.22 }, nd4j::DataType::FLOAT32); + NDArray expected('c', { 3, 4 }, { -2.021720, 4.692970, 3.669290, 9.491281, 1.511627, 2.611648, -1.298824, 0.358612, -6.472839, 4.568039, 5.290639, -0.430992 }, nd4j::DataType::FLOAT32); + + nd4j::ops::rgb_to_yuv op; + auto result = op.execute({ &rgbs }, {}, { 0 }); + auto output = result->at(0); + ASSERT_EQ(Status::OK(), result->status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + delete result; +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests15, test_rgb_to_yuv_4) { + // rank 3 + NDArray rgbs('c', { 5,4,3 }, { 1.7750e+01, 1.4602e+01, 5.4883e+00, 9.5438e+01, 1.0038e+02, 4.0531e+01, -5.8844e+01, 2.9609e+01, -1.1414e+01, 2.1391e+01, 3.9656e+01, 2.1531e+01, -7.1062e+01, -4.5859e+00, 2.9438e+01, -6.7461e+00, 6.7938e+01, -6.1211e+00, 2.2750e+01, -6.1438e+01, 1.5404e-02, -8.5312e+01, 1.1641e+01, 6.2500e+01, -1.0019e+02, 3.9344e+01, -3.1344e+01, 3.8562e+01, 5.9961e+00, 6.2219e+01, -1.0477e+01, 1.7750e+01, 2.9938e+01, 7.5830e-01, -2.7516e+01, 7.2188e+01, -2.3406e+01, 1.1617e+01, 6.5125e+01, 6.5078e+00, 6.7812e+01, 4.6812e+01, 7.7344e+00, 6.8562e+01, 5.6719e+00, 2.3125e+01, 6.7562e+01, 9.3750e+00, 5.2094e+01, -8.6562e+01, 1.2695e+01, 3.3562e+01, 2.9734e+01, 5.2250e+01, 9.5469e+00, -7.4414e+00, -2.0125e+01, 1.8145e+00, 7.8438e+01, -4.8125e+01 }, nd4j::DataType::FLOAT32); + NDArray expected('c', { 5,4,3 }, { 14.5042902, -4.43686799, 2.847406, 92.079556, -25.36761168, 2.94630572, -1.515069, -4.87137291, -50.29369639, 32.128515, -5.21515376, -9.41983935,-20.5835293, 24.61614501, -44.28390394, 37.1647167, -21.30142676, -38.52221293, -29.26009994, 14.40679768, 45.62757638, -11.550021, 36.44083018, -64.71012983,-10.435098, - 10.28950082, - 78.74044941, 22.1427147, 19.72198103, 14.40435988, 10.699559, 9.46744852, - 18.5778351 , -7.6957283, 39.31166179, 7.41657542, 7.245035, 28.48336771, - 26.88963173, 47.0880442, - 0.13584441, - 35.60035823, 43.2050762, - 18.47048906, - 31.11782117, 47.642019, - 18.83162118, - 21.50836396,-33.788558, 22.87507047, 75.34330791, 33.445396, 9.25395257, 0.10229474, -3.8078287, -8.02985955, 11.71587638, 41.0993915, -43.90830496, -34.46396749 }, nd4j::DataType::FLOAT32); + + nd4j::ops::rgb_to_yuv op; + auto result = op.execute({ &rgbs }, {}, {}); + auto output = result->at(0); + + ASSERT_EQ(Status::OK(), result->status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + delete result; +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests15, test_rgb_to_yuv_5) { + // rank 3 + NDArray rgbs('c', { 5,3,4 }, { 1.7750e+01f, -7.1062e+01f, -1.0019e+02f,-2.3406e+01f, 5.2094e+01f, 9.5438e+01f, -6.7461e+00f, 3.8562e+01f, 6.5078e+00f,3.3562e+01f, -5.8844e+01f, 2.2750e+01f, -1.0477e+01f, 7.7344e+00f, 9.5469e+00f,2.1391e+01f, -8.5312e+01f, 7.5830e-01f,2.3125e+01f, 1.8145e+00f, 1.4602e+01f,-4.5859e+00f, 3.9344e+01f, 1.1617e+01f,-8.6562e+01f, 1.0038e+02f, 6.7938e+01f,5.9961e+00f, 6.7812e+01f, 2.9734e+01f,2.9609e+01f, -6.1438e+01f, 1.7750e+01f,6.8562e+01f, -7.4414e+00f, 3.9656e+01f,1.1641e+01f, -2.7516e+01f, 6.7562e+01f,7.8438e+01f, 5.4883e+00f, 2.9438e+01f,-3.1344e+01f, 6.5125e+01f, 1.2695e+01f,4.0531e+01f, -6.1211e+00f, 6.2219e+01f,4.6812e+01f, 5.2250e+01f, -1.1414e+01f,1.5404e-02f, 2.9938e+01f, 5.6719e+00f,-2.0125e+01f, 2.1531e+01f, 6.2500e+01f,7.2188e+01f, 9.3750e+00f, -4.8125e+01f }, nd4j::DataType::FLOAT32); + NDArray expected('c', { 5,3,4 }, { 36.628319, 38.600643,-40.624989, 18.231001, - 14.822637, - 2.479566, - 8.965780, 2.223851, -16.561626,-96.205162,-52.255379,-36.527435,-51.546139,2.234915, 20.914114, 8.785358, 32.552223, -3.356598, 9.069552, 1.393482,36.029255, 4.824605,- 9.972263,11.058715, 15.947105, 55.283543, 36.845627, -29.750486,0.887228, 6.534475, -21.794132,34.155693, -89.929497,39.562351, 27.276817,31.359871, 8.149521, 13.673355, 1.104303, 68.774300, 2.236881, 13.216944, - 3.555702,- 3.225931,3.063015, - 36.134724,58.302204, 8.477802, 38.695396,27.181587, - 14.157411,7.157054, 11.714512, 22.148155, 11.580557, - 27.204905,7.120562, 21.992094, 2.406748, - 6.265247, }, nd4j::DataType::FLOAT32); + + nd4j::ops::rgb_to_yuv op; + auto result = op.execute({ &rgbs }, {}, { 1 }); + auto output = result->at(0); + + ASSERT_EQ(Status::OK(), result->status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + delete result; +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests15, test_rgb_to_yuv_6) { + // rank 3 + NDArray rgbs('c', { 3,5,4 }, { 1.7750e+01f, -7.1062e+01f, -1.0019e+02f,-2.3406e+01f, 5.2094e+01f, 9.5438e+01f, -6.7461e+00f, 3.8562e+01f, 6.5078e+00f,3.3562e+01f, -5.8844e+01f, 2.2750e+01f, -1.0477e+01f, 7.7344e+00f, 9.5469e+00f,2.1391e+01f, -8.5312e+01f, 7.5830e-01f,2.3125e+01f, 1.8145e+00f, 1.4602e+01f,-4.5859e+00f, 3.9344e+01f, 1.1617e+01f,-8.6562e+01f, 1.0038e+02f, 6.7938e+01f,5.9961e+00f, 6.7812e+01f, 2.9734e+01f,2.9609e+01f, -6.1438e+01f, 1.7750e+01f,6.8562e+01f, -7.4414e+00f, 3.9656e+01f,1.1641e+01f, -2.7516e+01f, 6.7562e+01f,7.8438e+01f, 5.4883e+00f, 2.9438e+01f,-3.1344e+01f, 6.5125e+01f, 1.2695e+01f,4.0531e+01f, -6.1211e+00f, 6.2219e+01f,4.6812e+01f, 5.2250e+01f, -1.1414e+01f,1.5404e-02f, 2.9938e+01f, 5.6719e+00f,-2.0125e+01f, 2.1531e+01f, 6.2500e+01f,7.2188e+01f, 9.3750e+00f, -4.8125e+01f }, nd4j::DataType::FLOAT32); + try { + nd4j::ops::rgb_to_yuv op; + auto result = op.execute({ &rgbs }, {}, {}); + ASSERT_EQ(Status::THROW(), result->status()); + delete result; + } + catch (std::exception & e) { + nd4j_printf("Error should be here `%s'. It's OK.\n", e.what()); + } +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests15, test_rgb_to_yuv_7) { + // rank 3 + NDArray rgbs('f', { 2, 2, 3 }, { 1.7750e+01f,-7.1062e+01f, -1.0019e+02f, -2.3406e+01f,5.2094e+01f,9.5438e+01f, -6.7461e+00f,3.8562e+01f, 6.5078e+00f, 3.3562e+01f,-5.8844e+01f,2.2750e+01f }, nd4j::DataType::FLOAT32); + NDArray expected('f', { 2,2,3 }, { 36.628319,38.600643, -40.624989,18.231001, -14.822637,-2.479566, -8.965780, 2.223851, -16.561626,- 96.205162,-52.255379, -36.527435 }, nd4j::DataType::FLOAT32); + + nd4j::ops::rgb_to_yuv op; + auto result = op.execute({ &rgbs }, {}, {}); + auto output = result->at(0); + + ASSERT_EQ(Status::OK(), result->status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + delete result; +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests15, test_yuv_to_rgb_1) { + // rank 1 + NDArray yuv('c', { 3 }, { 55.14 , 71.2872001, -39.6005542 }, nd4j::DataType::FLOAT32); + NDArray expected('c', { 3 }, { 10, 50, 200 }, nd4j::DataType::FLOAT32); + nd4j::ops::yuv_to_rgb op; + auto result = op.execute({ &yuv }, {}, {}); + auto output = result->at(0); + + ASSERT_EQ(Status::OK(), result->status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + delete result; +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests15, test_yuv_to_rgb_2) { + // rank 1 + NDArray yuv('f', { 3 }, { 55.14, 71.2872001, -39.6005542 }, nd4j::DataType::FLOAT32); + NDArray expected('f', { 3 }, { 10, 50, 200 }, nd4j::DataType::FLOAT32); + nd4j::ops::yuv_to_rgb op; + auto result = op.execute({ &yuv }, {}, {}); + auto output = result->at(0); + + ASSERT_EQ(Status::OK(), result->status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + delete result; +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests15, test_yuv_to_rgb_3) { + // rank 2 + NDArray expected('c', { 3, 4 }, { -9.4, 9.9, 9.7, 9.0, 1.14, 1.01, 1.11, 9.6, 1.05, 10.0, 1.03, 10.22 }, nd4j::DataType::FLOAT32); + NDArray yuv('c', { 3, 4 }, { -2.021720, 4.692970, 3.669290, 9.491281, 1.511627, 2.611648, -1.298824, 0.358612, -6.472839, 4.568039, 5.290639, -0.430992 }, nd4j::DataType::FLOAT32); + + nd4j::ops::yuv_to_rgb op; + auto result = op.execute({ &yuv }, {}, { 0 }); + auto output = result->at(0); + ASSERT_EQ(Status::OK(), result->status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + delete result; +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests15, test_yuv_to_rgb_4) { + // rank 3 + NDArray expected('c', { 5,4,3 }, { 1.7750e+01, 1.4602e+01, 5.4883e+00, 9.5438e+01, 1.0038e+02, 4.0531e+01, -5.8844e+01, 2.9609e+01, -1.1414e+01, 2.1391e+01, 3.9656e+01, 2.1531e+01, -7.1062e+01, -4.5859e+00, 2.9438e+01, -6.7461e+00, 6.7938e+01, -6.1211e+00, 2.2750e+01, -6.1438e+01, 1.5404e-02, -8.5312e+01, 1.1641e+01, 6.2500e+01, -1.0019e+02, 3.9344e+01, -3.1344e+01, 3.8562e+01, 5.9961e+00, 6.2219e+01, -1.0477e+01, 1.7750e+01, 2.9938e+01, 7.5830e-01, -2.7516e+01, 7.2188e+01, -2.3406e+01, 1.1617e+01, 6.5125e+01, 6.5078e+00, 6.7812e+01, 4.6812e+01, 7.7344e+00, 6.8562e+01, 5.6719e+00, 2.3125e+01, 6.7562e+01, 9.3750e+00, 5.2094e+01, -8.6562e+01, 1.2695e+01, 3.3562e+01, 2.9734e+01, 5.2250e+01, 9.5469e+00, -7.4414e+00, -2.0125e+01, 1.8145e+00, 7.8438e+01, -4.8125e+01 }, nd4j::DataType::FLOAT32); + NDArray yuv('c', { 5,4,3 }, { 14.5042902, -4.43686799, 2.847406, 92.079556, -25.36761168, 2.94630572, -1.515069, -4.87137291, -50.29369639, 32.128515, -5.21515376, -9.41983935,-20.5835293, 24.61614501, -44.28390394, 37.1647167, -21.30142676, -38.52221293, -29.26009994, 14.40679768, 45.62757638, -11.550021, 36.44083018, -64.71012983,-10.435098, -10.28950082, -78.74044941, 22.1427147, 19.72198103, 14.40435988, 10.699559, 9.46744852, -18.5778351 , -7.6957283, 39.31166179, 7.41657542, 7.245035, 28.48336771, -26.88963173, 47.0880442, -0.13584441, -35.60035823, 43.2050762, -18.47048906, -31.11782117, 47.642019, -18.83162118, -21.50836396,-33.788558, 22.87507047, 75.34330791, 33.445396, 9.25395257, 0.10229474, -3.8078287, -8.02985955, 11.71587638, 41.0993915, -43.90830496, -34.46396749 }, nd4j::DataType::FLOAT32); + + nd4j::ops::yuv_to_rgb op; + auto result = op.execute({ &yuv }, {}, {}); + auto output = result->at(0); + + ASSERT_EQ(Status::OK(), result->status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + delete result; +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests15, test_yuv_to_rgb_5) { + // rank 3 + NDArray expected('c', { 5,3,4 }, { 1.7750e+01f, -7.1062e+01f, -1.0019e+02f,-2.3406e+01f, 5.2094e+01f, 9.5438e+01f, -6.7461e+00f, 3.8562e+01f, 6.5078e+00f,3.3562e+01f, -5.8844e+01f, 2.2750e+01f, -1.0477e+01f, 7.7344e+00f, 9.5469e+00f,2.1391e+01f, -8.5312e+01f, 7.5830e-01f,2.3125e+01f, 1.8145e+00f, 1.4602e+01f,-4.5859e+00f, 3.9344e+01f, 1.1617e+01f,-8.6562e+01f, 1.0038e+02f, 6.7938e+01f,5.9961e+00f, 6.7812e+01f, 2.9734e+01f,2.9609e+01f, -6.1438e+01f, 1.7750e+01f,6.8562e+01f, -7.4414e+00f, 3.9656e+01f,1.1641e+01f, -2.7516e+01f, 6.7562e+01f,7.8438e+01f, 5.4883e+00f, 2.9438e+01f,-3.1344e+01f, 6.5125e+01f, 1.2695e+01f,4.0531e+01f, -6.1211e+00f, 6.2219e+01f,4.6812e+01f, 5.2250e+01f, -1.1414e+01f,1.5404e-02f, 2.9938e+01f, 5.6719e+00f,-2.0125e+01f, 2.1531e+01f, 6.2500e+01f,7.2188e+01f, 9.3750e+00f, -4.8125e+01f }, nd4j::DataType::FLOAT32); + NDArray yuv('c', { 5,3,4 }, { 36.628319, 38.600643,-40.624989, 18.231001, -14.822637, -2.479566, -8.965780, 2.223851, -16.561626,-96.205162,-52.255379,-36.527435,-51.546139,2.234915, 20.914114, 8.785358, 32.552223, -3.356598, 9.069552, 1.393482,36.029255, 4.824605,-9.972263,11.058715, 15.947105, 55.283543, 36.845627, -29.750486,0.887228, 6.534475, -21.794132,34.155693, -89.929497,39.562351, 27.276817,31.359871, 8.149521, 13.673355, 1.104303, 68.774300, 2.236881, 13.216944, -3.555702,-3.225931,3.063015, -36.134724,58.302204, 8.477802, 38.695396,27.181587, -14.157411,7.157054, 11.714512, 22.148155, 11.580557, -27.204905,7.120562, 21.992094, 2.406748, -6.265247, }, nd4j::DataType::FLOAT32); + + nd4j::ops::yuv_to_rgb op; + auto result = op.execute({ &yuv }, {}, { 1 }); + auto output = result->at(0); + + ASSERT_EQ(Status::OK(), result->status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + delete result; +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests15, test_yuv_to_rgb_6) { + // rank 3 + NDArray yuv('c', { 3,5,4 }, { 1.7750e+01f, -7.1062e+01f, -1.0019e+02f,-2.3406e+01f, 5.2094e+01f, 9.5438e+01f, -6.7461e+00f, 3.8562e+01f, 6.5078e+00f,3.3562e+01f, -5.8844e+01f, 2.2750e+01f, -1.0477e+01f, 7.7344e+00f, 9.5469e+00f,2.1391e+01f, -8.5312e+01f, 7.5830e-01f,2.3125e+01f, 1.8145e+00f, 1.4602e+01f,-4.5859e+00f, 3.9344e+01f, 1.1617e+01f,-8.6562e+01f, 1.0038e+02f, 6.7938e+01f,5.9961e+00f, 6.7812e+01f, 2.9734e+01f,2.9609e+01f, -6.1438e+01f, 1.7750e+01f,6.8562e+01f, -7.4414e+00f, 3.9656e+01f,1.1641e+01f, -2.7516e+01f, 6.7562e+01f,7.8438e+01f, 5.4883e+00f, 2.9438e+01f,-3.1344e+01f, 6.5125e+01f, 1.2695e+01f,4.0531e+01f, -6.1211e+00f, 6.2219e+01f,4.6812e+01f, 5.2250e+01f, -1.1414e+01f,1.5404e-02f, 2.9938e+01f, 5.6719e+00f,-2.0125e+01f, 2.1531e+01f, 6.2500e+01f,7.2188e+01f, 9.3750e+00f, -4.8125e+01f }, nd4j::DataType::FLOAT32); + try { + nd4j::ops::yuv_to_rgb op; + auto result = op.execute({ &yuv }, {}, {}); + ASSERT_EQ(Status::THROW(), result->status()); + delete result; + } + catch (std::exception & e) { + nd4j_printf("Error should be here `%s'. It's OK.\n", e.what()); + } +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests15, test_yuv_to_rgb_7) { + // rank 3 + NDArray expected('f', { 2, 2, 3 }, { 1.7750e+01f,-7.1062e+01f, -1.0019e+02f, -2.3406e+01f,5.2094e+01f,9.5438e+01f, -6.7461e+00f,3.8562e+01f, 6.5078e+00f, 3.3562e+01f,-5.8844e+01f,2.2750e+01f }, nd4j::DataType::FLOAT32); + NDArray yuv('f', { 2,2,3 }, { 36.628319, 38.600643, -40.624989, 18.231001, -14.822637, -2.479566, -8.965780, 2.223851, -16.561626, -96.205162, -52.255379, -36.527435 }, nd4j::DataType::FLOAT32); + + nd4j::ops::yuv_to_rgb op; + auto result = op.execute({ &yuv }, {}, {}); + auto output = result->at(0); + + ASSERT_EQ(Status::OK(), result->status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + delete result; +} + +//////////////////////////////////////////////////////////////////////////////////////// + +TEST_F(DeclarableOpsTests15, Pow_BP_Test1) { + + // same shape + NDArray x('c', { 2,2,2 }, { 4,3,2,5,7,8,-9,-12 }, nd4j::DataType::FLOAT32); + NDArray y('c', { 2,2,2 }, { 2,3,-2,4,-1,-4,10,8 }, nd4j::DataType::FLOAT32); + + + NDArray dLdz('c', { 2,2,2 }, nd4j::DataType::FLOAT32); + NDArray dLdxExp('c', { 2,2,2 }, { 8, 27, -0.25, 500, -0.0204082, -0.000122, -3.87420e+09, -2.86654e+08 }, nd4j::DataType::FLOAT32); + NDArray dLdyExp('c', { 2,2,2 }, { 22.18071, 29.66253, 0.17329, 1005.89874, 0.27799, 0.00051, 0, 0 }, nd4j::DataType::FLOAT32); + + dLdz.assign(1.0); + + nd4j::ops::Pow_bp op; + auto results = op.execute({ &x, &y, &dLdz }, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto* dLdx = results->at(0); + auto* dLdy = results->at(1); + + ASSERT_TRUE(dLdxExp.isSameShape(dLdx)); + ASSERT_TRUE(dLdxExp.equalsTo(dLdx)); + ASSERT_TRUE(dLdyExp.isSameShape(dLdy)); + ASSERT_TRUE(dLdyExp.equalsTo(dLdy)); + + delete results; +} + +TEST_F(DeclarableOpsTests15, Pow_BP_Test2) { + + NDArray x('c', { 1,2,3 }, nd4j::DataType::FLOAT32); + NDArray y('c', { 3,2,1 }, nd4j::DataType::FLOAT32); + NDArray dLdz('c', { 3,2,3 }, nd4j::DataType::FLOAT32); + + NDArray dLdxExp('c', { 1,2,3 }, { 16.8, 19.2, 21.6, 24., 26.4, 28.8 }, nd4j::DataType::FLOAT32); + NDArray dLdyExp('c', { 3,2,1 }, { 13.30843, 33.27106, 53.2337, 73.19634, 93.15898, 113.12162 }, nd4j::DataType::FLOAT32); + + x.assign(4.0); + y.assign(2.0); + dLdz.linspace(0.1, 0.1); + + nd4j::ops::Pow_bp op; + auto results = op.execute({ &x, &y, &dLdz }, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto* dLdx = results->at(0); + auto* dLdy = results->at(1); + + ASSERT_TRUE(dLdxExp.isSameShape(dLdx)); + ASSERT_TRUE(dLdxExp.equalsTo(dLdx)); + ASSERT_TRUE(dLdyExp.isSameShape(dLdy)); + ASSERT_TRUE(dLdyExp.equalsTo(dLdy)); + + delete results; + +} + +TEST_F(DeclarableOpsTests15, Pow_BP_Test3) { + + // y - same shape as dLdz + NDArray xY('c', { 1,2,3 }, nd4j::DataType::FLOAT32); + NDArray yY('c', { 3,2,3 }, nd4j::DataType::FLOAT32); + + NDArray dLdxExpY('c', { 1,2,3 }, { 16.8, 19.2, 21.6, 24. , 26.4, 28.8 }, nd4j::DataType::FLOAT32); + NDArray dLdyExpY('c', { 3,2,3 }, { 2.21807, 4.43614, 6.65421, 8.87228, 11.09035, 13.30843, 15.5265 , 17.74457, 19.96264, 22.18071, 24.39878, 26.61685, 28.83492, 31.05299, 33.27106, 35.48914, 37.70721, 39.92528 }, nd4j::DataType::FLOAT32); + NDArray dLdz('c', { 3,2,3 }, nd4j::DataType::FLOAT32); + + xY.assign(4.0); + yY.assign(2.0); + dLdz.linspace(0.1, 0.1); + + nd4j::ops::Pow_bp op; + auto resultsY = op.execute({ &xY, &yY, &dLdz }, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, resultsY->status()); + + auto* dLdxY = resultsY->at(0); + auto* dLdyY = resultsY->at(1); + + ASSERT_TRUE(dLdxExpY.isSameShape(dLdxY)); + ASSERT_TRUE(dLdxExpY.equalsTo(dLdxY)); + ASSERT_TRUE(dLdyExpY.isSameShape(dLdyY)); + ASSERT_TRUE(dLdyExpY.equalsTo(dLdyY)); + + delete resultsY; +} + +TEST_F(DeclarableOpsTests15, Pow_BP_Test4) { + + // x - same shape ad dLdz + NDArray yX('c', { 1,2,3 }, nd4j::DataType::FLOAT32); + NDArray xX('c', { 3,2,3 }, nd4j::DataType::FLOAT32); + + NDArray dLdxExpX('c', { 3,2,3 }, { 3.2, 6.4, 9.6, 12.8, 16. , 19.2, 22.4, 25.6, 28.8, 32. , 35.2, 38.4, 41.6, 44.8, 48., 51.2, 54.4, 57.6 }, nd4j::DataType::FLOAT32); + NDArray dLdyExpX('c', { 1,2,3 }, { 23.28975, 26.61685, 29.94396, 33.27106, 36.59817, 39.92528 }, nd4j::DataType::FLOAT32); + + NDArray dLdz('c', { 3,2,3 }, nd4j::DataType::FLOAT32); + dLdz.linspace(0.1, 0.1); + + nd4j::ops::Pow_bp op; + + xX.assign(2.0); + yX.assign(4.0); + + auto resultsX = op.execute({ &xX, &yX, &dLdz }, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, resultsX->status()); + + auto* dLdxX = resultsX->at(0); + auto* dLdyX = resultsX->at(1); + + ASSERT_TRUE(dLdxExpX.isSameShape(dLdxX)); + ASSERT_TRUE(dLdxExpX.equalsTo(dLdxX)); + ASSERT_TRUE(dLdyExpX.isSameShape(dLdyX)); + ASSERT_TRUE(dLdyExpX.equalsTo(dLdyX)); + + delete resultsX; +} + +TEST_F(DeclarableOpsTests15, Pow_BP_Test5) { + + // both single array + NDArray xConst('c', { 1 }, nd4j::DataType::FLOAT32); + NDArray yConst('c', { 1 }, nd4j::DataType::FLOAT32); + NDArray dLdz('c', { 1 }, nd4j::DataType::FLOAT32); + NDArray dLdxExp('c', { 1 }, nd4j::DataType::FLOAT32); + NDArray dLdyExp('c', { 1 }, nd4j::DataType::FLOAT32); + + xConst.assign(3.0); + yConst.assign(4.0); + dLdz.assign(1.0); + + dLdxExp.assign(4.0 * pow(3, 3)); + dLdyExp.assign(pow(3, 4) * log(3)); + + nd4j::ops::Pow_bp op; + auto results = op.execute({ &xConst, &yConst, &dLdz }, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto* dLdx = results->at(0); + auto* dLdy = results->at(1); + + ASSERT_TRUE(dLdxExp.isSameShape(dLdx)); + ASSERT_TRUE(dLdxExp.equalsTo(dLdx)); + + ASSERT_TRUE(dLdyExp.isSameShape(dLdy)); + ASSERT_TRUE(dLdyExp.equalsTo(dLdy)); + + delete results; +} + +TEST_F(DeclarableOpsTests15, Pow_BP_Test6) { + + // x single array + NDArray xConst('c', { 1 }, nd4j::DataType::FLOAT32); + NDArray y('c', { 2, 2, 2 }, nd4j::DataType::FLOAT32); + NDArray dLdzC('c', { 2, 2, 2 }, nd4j::DataType::FLOAT32); + + xConst.assign(2.0); + y.assign(4.0); + dLdzC.linspace(0.1, 0.1); + + NDArray dLdxExpXC('c', { 1 }, { 115.2 }, nd4j::DataType::FLOAT32); + NDArray dLdyExpXC('c', { 2, 2, 2 }, { 1.10904, 2.21807, 3.32711, 4.43614, 5.54518, 6.65421, 7.76325, 8.87228 }, nd4j::DataType::FLOAT32); + + nd4j::ops::Pow_bp op; + auto resultsXC = op.execute({ &xConst, &y, &dLdzC }, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, resultsXC->status()); + + auto* dLdxXC = resultsXC->at(0); + auto* dLdyXC = resultsXC->at(1); + + ASSERT_TRUE(dLdxExpXC.isSameShape(dLdxXC)); + ASSERT_TRUE(dLdxExpXC.equalsTo(dLdxXC)); + ASSERT_TRUE(dLdyExpXC.isSameShape(dLdyXC)); + ASSERT_TRUE(dLdyExpXC.equalsTo(dLdyXC)); + + delete resultsXC; +} + +TEST_F(DeclarableOpsTests15, Pow_BP_Test7) { + + // Y - scalar + auto Y = NDArrayFactory::create(2.f); + NDArray x('c', { 2, 2, 2 }, nd4j::DataType::FLOAT32); + NDArray dLdzC('c', { 2, 2, 2 }, nd4j::DataType::FLOAT32); + + dLdzC.linspace(0.1, 0.1); + x = 4.f; + + NDArray dLdxExpYs('c', { 2, 2, 2 }, { 0.8, 1.6, 2.4, 3.2, 4., 4.8, 5.6, 6.4 }, nd4j::DataType::FLOAT32); + + auto dLdyExpYs = NDArrayFactory::create(79.85056f); + + nd4j::ops::Pow_bp op; + auto resultsYs = op.execute({ &x, &Y, &dLdzC }, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, resultsYs->status()); + + auto* dLdxY = resultsYs->at(0); + auto* dLdyY = resultsYs->at(1); + + ASSERT_TRUE(dLdxExpYs.isSameShape(dLdxY)); + ASSERT_TRUE(dLdxExpYs.equalsTo(dLdxY)); + ASSERT_TRUE(dLdyExpYs.isSameShape(dLdyY)); + ASSERT_TRUE(dLdyExpYs.equalsTo(dLdyY)); + + delete resultsYs; +} + +TEST_F(DeclarableOpsTests15, Pow_BP_Test8) { + // both scalars + + auto X = NDArrayFactory::create(4.f); + auto Y = NDArrayFactory::create(2.f); + NDArray dLdz = NDArrayFactory::create(0.1f); + + NDArray dLdxExp = NDArrayFactory::create(2.f*4.f*0.1f); + + NDArray dLdyExp = NDArrayFactory::create(pow(4.f, 2.f) * log(4.f) * 0.1f); + + nd4j::ops::Pow_bp op; + auto results = op.execute({ &X, &Y, &dLdz }, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto* dLdx = results->at(0); + auto* dLdy = results->at(1); + + ASSERT_TRUE(dLdxExp.isSameShape(dLdx)); + ASSERT_TRUE(dLdxExp.equalsTo(dLdx)); + ASSERT_TRUE(dLdyExp.isSameShape(dLdy)); + ASSERT_TRUE(dLdyExp.equalsTo(dLdy)); + + delete results; +} + +TEST_F(DeclarableOpsTests15, Pow_BP_Test9) { + + nd4j::ops::Pow_bp op; + // diff shapes + NDArray x('c', { 3,2,1 }, nd4j::DataType::FLOAT32); + NDArray y('c', { 1,2,3 }, nd4j::DataType::FLOAT32); + NDArray dLdz('c', { 3,2,3 }, nd4j::DataType::FLOAT32); + + NDArray dLdxExp('c', { 3,2,1 }, { 4.8, 12., 19.2, 26.4, 33.6, 40.8 }, nd4j::DataType::FLOAT32); + NDArray dLdyExp('c', { 1,2,3 }, { 46.57949, 53.2337 , 59.88792, 66.54213, 73.19634, 79.85056 }, nd4j::DataType::FLOAT32); + + x.assign(4.0); + y.assign(2.0); + dLdz.linspace(0.1, 0.1); + + auto results = op.execute({ &x, &y, &dLdz }, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto* dLdx = results->at(0); + auto* dLdy = results->at(1); + + ASSERT_TRUE(dLdxExp.isSameShape(dLdx)); + ASSERT_TRUE(dLdxExp.equalsTo(dLdx)); + ASSERT_TRUE(dLdyExp.isSameShape(dLdy)); + ASSERT_TRUE(dLdyExp.equalsTo(dLdy)); + + delete results; +} + +TEST_F(DeclarableOpsTests15, Pow_BP_Test10) { + + // diff shapes broadcastable + NDArray yB('c', { 1,2,3,1 }, nd4j::DataType::FLOAT32); + NDArray xB('c', { 2,3,1 }, nd4j::DataType::FLOAT32); + + NDArray dLdyExpB('c', { 1,2,3,1 }, { 2.21807, 4.43614, 6.65421, 8.87228, 11.09035, 13.30843 }, nd4j::DataType::FLOAT32); + NDArray dLdxExpB('c', { 2,3,1 }, { 0.8, 1.6, 2.4, 3.2, 4., 4.8 }, nd4j::DataType::FLOAT32); + NDArray dLdzB('c', { 1,2,3,1 }, nd4j::DataType::FLOAT32); + + dLdzB.linspace(0.1, 0.1); + xB.assign(4.0); + yB.assign(2.0); + + nd4j::ops::Pow_bp op; + auto resultsB = op.execute({ &xB, &yB, &dLdzB }, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, resultsB->status()); + + auto* dLdxB = resultsB->at(0); + auto* dLdyB = resultsB->at(1); + + ASSERT_TRUE(dLdxExpB.isSameShape(dLdxB)); + ASSERT_TRUE(dLdxExpB.equalsTo(dLdxB)); + + ASSERT_TRUE(dLdyExpB.isSameShape(dLdyB)); + ASSERT_TRUE(dLdyExpB.equalsTo(dLdyB)); + + delete resultsB; +} + +TEST_F(DeclarableOpsTests15, Pow_BP_Test11) { + + NDArray xB('c', { 3,2,1 }, { .4, 3, 5, .8, -9, -12 }, nd4j::DataType::FLOAT32); + NDArray yB('c', { 1,2,3 }, { 3, -2, .4, -4, 10, .8 }, nd4j::DataType::FLOAT32); + + NDArray dLdxExpB('c', { 3,2,1 }, { -5.994056, 39366.191406, 7.508829, -2.223537, -std::numeric_limits::quiet_NaN(), -std::numeric_limits::quiet_NaN() }, nd4j::DataType::FLOAT32); + NDArray dLdyExpB('c', { 1,2,3 }, { 20.11211, -1.119612, -std::numeric_limits::quiet_NaN(), -0.1076, 12974.389648, -std::numeric_limits::quiet_NaN() }, nd4j::DataType::FLOAT32); + + NDArray dLdzB('c', { 3,2,3 }, { .1,.2,.3, .1,.2,.3, .1,.4,.1, .2,.1,.1, .3,.1,.5, .1, .7, .1 }, nd4j::DataType::FLOAT32); + + nd4j::ops::Pow_bp op; + auto resultsB = op.execute({ &xB, &yB, &dLdzB }, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, resultsB->status()); + auto* dLdxB = resultsB->at(0); + auto* dLdyB = resultsB->at(1); + + ASSERT_TRUE(dLdxExpB.isSameShape(dLdxB)); + for (int i = 0; i < dLdxB->lengthOf(); ++i) { + if (!nd4j::math::nd4j_isnan(dLdxB->e(i)) && !nd4j::math::nd4j_isnan(dLdxExpB.e(i))) + ASSERT_NEAR(dLdxB->e(i), dLdxExpB.e(i), 0.00001); + } + + ASSERT_TRUE(dLdyExpB.isSameShape(dLdyB)); + for (int i = 0; i < dLdyB->lengthOf(); ++i) { + if (!nd4j::math::nd4j_isnan(dLdyB->e(i)) && !nd4j::math::nd4j_isnan(dLdyExpB.e(i))) + ASSERT_NEAR(dLdyB->e(i), dLdyExpB.e(i), 0.00001); + } + + delete resultsB; +} diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp index f8bf47e53..f05b8f488 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp @@ -15,9 +15,9 @@ ******************************************************************************/ -// -// @author raver119@gmail.com -// + // + // @author raver119@gmail.com + // #include "testlayers.h" #include @@ -40,13 +40,13 @@ public: }; TEST_F(DeclarableOpsTests16, scatter_upd_1) { - auto x = NDArrayFactory::create('c', {3}, {1.f, 1.f, 1.f}); + auto x = NDArrayFactory::create('c', { 3 }, { 1.f, 1.f, 1.f }); auto y = NDArrayFactory::create(0); auto w = NDArrayFactory::create(3.0f); - auto e = NDArrayFactory::create('c', {3}, {3.f, 1.f, 1.f}); + auto e = NDArrayFactory::create('c', { 3 }, { 3.f, 1.f, 1.f }); nd4j::ops::scatter_upd op; - auto result = op.execute({&x, &y, &w}, {}, {}); + auto result = op.execute({ &x, &y, &w }, {}, {}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); @@ -58,15 +58,15 @@ TEST_F(DeclarableOpsTests16, scatter_upd_1) { TEST_F(DeclarableOpsTests16, scatter_upd_2) { - NDArray x('c', {10, 3}, nd4j::DataType::FLOAT32); - NDArray indices('c', {2}, {2,5}, nd4j::DataType::INT32); - NDArray updates('c', {2, 3}, {100,101,102, 200,201,202}, nd4j::DataType::FLOAT32); - NDArray e('c', {10, 3}, {1,2,3, 4,5,6, 100,101,102, 10,11,12, 13,14,15, 200,201,202, 19,20,21, 22,23,24, 25,26,27, 28,29,30}, nd4j::DataType::FLOAT32); + NDArray x('c', { 10, 3 }, nd4j::DataType::FLOAT32); + NDArray indices('c', { 2 }, { 2,5 }, nd4j::DataType::INT32); + NDArray updates('c', { 2, 3 }, { 100,101,102, 200,201,202 }, nd4j::DataType::FLOAT32); + NDArray e('c', { 10, 3 }, { 1,2,3, 4,5,6, 100,101,102, 10,11,12, 13,14,15, 200,201,202, 19,20,21, 22,23,24, 25,26,27, 28,29,30 }, nd4j::DataType::FLOAT32); x.linspace(1); nd4j::ops::scatter_upd op; - auto result = op.execute({&x, &indices, &updates}, {}, {}); + auto result = op.execute({ &x, &indices, &updates }, {}, {}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); @@ -78,22 +78,22 @@ TEST_F(DeclarableOpsTests16, scatter_upd_2) { TEST_F(DeclarableOpsTests16, scatter_upd_3) { - NDArray x('c', {10, 3}, nd4j::DataType::FLOAT32); - NDArray indices('c', {2}, {20,5}, nd4j::DataType::INT32); - NDArray updates('c', {2, 3}, {100,101,102, 200,201,202}, nd4j::DataType::FLOAT32); - NDArray output('c', {10, 3}, nd4j::DataType::FLOAT32); + NDArray x('c', { 10, 3 }, nd4j::DataType::FLOAT32); + NDArray indices('c', { 2 }, { 20,5 }, nd4j::DataType::INT32); + NDArray updates('c', { 2, 3 }, { 100,101,102, 200,201,202 }, nd4j::DataType::FLOAT32); + NDArray output('c', { 10, 3 }, nd4j::DataType::FLOAT32); nd4j::ops::scatter_upd op; - ASSERT_ANY_THROW(op.execute({&x, &indices, &updates}, {&output}, {}, {}, {true, true})); + ASSERT_ANY_THROW(op.execute({ &x, &indices, &updates }, { &output }, {}, {}, { true, true })); } TEST_F(DeclarableOpsTests16, test_size_dtype_1) { - auto x = NDArrayFactory::create('c', {3}, {1, 1, 1}); + auto x = NDArrayFactory::create('c', { 3 }, { 1, 1, 1 }); auto z = NDArrayFactory::create(0.0f); auto e = NDArrayFactory::create(3.0f); nd4j::ops::size op; - auto status = op.execute({&x}, {&z}, {}, {}, {}); + auto status = op.execute({ &x }, { &z }, {}, {}, {}); ASSERT_EQ(Status::OK(), status); ASSERT_EQ(e, z); @@ -103,7 +103,7 @@ TEST_F(DeclarableOpsTests16, test_empty_noop_1) { auto z = NDArrayFactory::empty(); nd4j::ops::noop op; - auto status = op.execute({}, {&z}, {}, {}, {}); + auto status = op.execute({}, { &z }, {}, {}, {}); ASSERT_EQ(Status::OK(), status); } @@ -120,22 +120,22 @@ TEST_F(DeclarableOpsTests16, test_empty_noop_2) { } TEST_F(DeclarableOpsTests16, test_svd_1) { - auto x = NDArrayFactory::create('c', {3, 3}, {0.7787856f, 0.80119777f, 0.72437465f, 0.23089433f, 0.72714126f, 0.18039072f,0.50563407f, 0.89252293f, 0.5461209f}); - auto z = NDArrayFactory::create('c', {3}); + auto x = NDArrayFactory::create('c', { 3, 3 }, { 0.7787856f, 0.80119777f, 0.72437465f, 0.23089433f, 0.72714126f, 0.18039072f,0.50563407f, 0.89252293f, 0.5461209f }); + auto z = NDArrayFactory::create('c', { 3 }); nd4j::ops::svd op; - auto status = op.execute({&x}, {&z}, {}, {0, 0, 16}, {}); + auto status = op.execute({ &x }, { &z }, {}, { 0, 0, 16 }, {}); ASSERT_EQ(Status::OK(), status); } TEST_F(DeclarableOpsTests16, test_hamming_distance_1) { - auto x = NDArrayFactory::create({37, 37, 37}); - auto y = NDArrayFactory::create({8723, 8723, 8723}); + auto x = NDArrayFactory::create({ 37, 37, 37 }); + auto y = NDArrayFactory::create({ 8723, 8723, 8723 }); auto e = NDArrayFactory::create(18); nd4j::ops::bits_hamming_distance op; - auto result = op.execute({&x, &y}, {}, {}); + auto result = op.execute({ &x, &y }, {}, {}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); @@ -146,9 +146,9 @@ TEST_F(DeclarableOpsTests16, test_hamming_distance_1) { } TEST_F(DeclarableOpsTests16, test_knn_mindistance_1) { - auto input = NDArrayFactory::create('c', {512}); - auto low = NDArrayFactory::create('c', {512}); - auto high = NDArrayFactory::create('c', {512}); + auto input = NDArrayFactory::create('c', { 512 }); + auto low = NDArrayFactory::create('c', { 512 }); + auto high = NDArrayFactory::create('c', { 512 }); auto output = NDArrayFactory::create(0.0f); @@ -157,16 +157,16 @@ TEST_F(DeclarableOpsTests16, test_knn_mindistance_1) { high.linspace(1.0); nd4j::ops::knn_mindistance op; - auto result = op.execute({&input, &low, &high}, {&output}, {}, {}, {}); + auto result = op.execute({ &input, &low, &high }, { &output }, {}, {}, {}); ASSERT_EQ(Status::OK(), result); } TEST_F(DeclarableOpsTests16, test_empty_cast_1) { - auto x = NDArrayFactory::create('c', {1, 0, 2}); - auto e = NDArrayFactory::create('c', {1, 0, 2}); + auto x = NDArrayFactory::create('c', { 1, 0, 2 }); + auto e = NDArrayFactory::create('c', { 1, 0, 2 }); nd4j::ops::cast op; - auto result = op.execute({&x}, {}, {10}); + auto result = op.execute({ &x }, {}, { 10 }); ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(e, *result->at(0)); @@ -175,10 +175,10 @@ TEST_F(DeclarableOpsTests16, test_empty_cast_1) { TEST_F(DeclarableOpsTests16, test_range_1) { nd4j::ops::range op; - auto z = NDArrayFactory::create('c', {200}); + auto z = NDArrayFactory::create('c', { 200 }); Context ctx(1); - ctx.setTArguments({-1.0, 1.0, 0.01}); + ctx.setTArguments({ -1.0, 1.0, 0.01 }); ctx.setOutputArray(0, &z); auto status = op.execute(&ctx); @@ -187,9 +187,9 @@ TEST_F(DeclarableOpsTests16, test_range_1) { TEST_F(DeclarableOpsTests16, test_range_2) { nd4j::ops::range op; - auto z = NDArrayFactory::create('c', {200}); + auto z = NDArrayFactory::create('c', { 200 }); - double tArgs[] = {-1.0, 1.0, 0.01}; + double tArgs[] = { -1.0, 1.0, 0.01 }; auto shapes = ::calculateOutputShapes2(nullptr, op.getOpHash(), nullptr, nullptr, 0, tArgs, 3, nullptr, 0, nullptr, 0); shape::printShapeInfoLinear("Result", shapes->at(0)); @@ -199,42 +199,902 @@ TEST_F(DeclarableOpsTests16, test_range_2) { } TEST_F(DeclarableOpsTests16, test_reverse_1) { - std::vector rows = {3, 5, 7, 8, 9, 10, 119, 211}; - std::vector columns = {6, 5, 10, 100, 153, 171, 635}; + std::vector rows = { 3, 5, 7, 8, 9, 10, 119, 211 }; + std::vector columns = { 6, 5, 10, 100, 153, 171, 635 }; for (auto r : rows) { for (auto c : columns) { //nd4j_printf("Trying [%i, %i]\n", r, c); - auto array = NDArrayFactory::create('c', {r, c}); - auto exp = NDArrayFactory::create('c', {r, c}); - auto reversed = NDArrayFactory::create('c', {r, c}); + auto array = NDArrayFactory::create('c', { r, c }); + auto exp = NDArrayFactory::create('c', { r, c }); + auto reversed = NDArrayFactory::create('c', { r, c }); - auto rowOriginal = NDArrayFactory::create('c', {c}); - auto rowReversed = NDArrayFactory::create('c', {c}); + auto rowOriginal = NDArrayFactory::create('c', { c }); + auto rowReversed = NDArrayFactory::create('c', { c }); for (int e = 0; e < c; e++) { - rowOriginal.p(e, (float) e); - rowReversed.p(c - e - 1, (float) e); + rowOriginal.p(e, (float)e); + rowReversed.p(c - e - 1, (float)e); } - auto listI = array.allTensorsAlongDimension({1}); - auto listE = exp.allTensorsAlongDimension({1}); + auto listI = array.allTensorsAlongDimension({ 1 }); + auto listE = exp.allTensorsAlongDimension({ 1 }); for (int e = 0; e < r; e++) { - listI->at(e)->assign(rowOriginal); - listE->at(e)->assign(rowReversed); + listI.at(e)->assign(rowOriginal); + listE.at(e)->assign(rowReversed); } - delete listI; - delete listE; - nd4j::ops::reverse op; - Nd4jLong axis = 1; - auto status = op.execute({&array}, {&reversed}, {}, {axis}, {}); + Nd4jLong axis = 1; + auto status = op.execute({ &array }, { &reversed }, {}, { axis }, {}); ASSERT_EQ(Status::OK(), status); ASSERT_EQ(exp, reversed); } } -} \ No newline at end of file +} + +TEST_F(DeclarableOpsTests16, test_rgb_to_hsv_1) { + /* + test case generated by python colorsys and scaled to suit our needs + from colorsys import * + from random import * + import numpy as np + rgbs = np.random.uniform(0,1, 5*4*3 ).astype('float32').reshape([5,4,3]) + hsvs=np.apply_along_axis(lambda x: np.array(rgb_to_hsv(x[0],x[1],x[2])),2,rgbs) + rgbs.ravel() + hsvs.ravel() + */ + auto rgbs = NDArrayFactory::create('c', { 5, 4, 3 }, { + 0.545678377f, 0.725874603f, 0.413571358f, 0.644941628f, 0.517642438f, + 0.890151322f, 0.461456001f, 0.0869259685f, 0.928968489f, 0.588904262f, + 0.54742825f, 0.684074104f, 0.52110225f, 0.761800349f, 0.486593395f, + 0.753103435f, 0.237176552f, 0.263826847f, 0.913557053f, 0.90049392f, + 0.290193319f, 0.46850124f, 0.965541422f, 0.148351923f, 0.674094439f, + 0.524110138f, 0.216262609f, 0.0361763388f, 0.2204483f, 0.279114306f, + 0.3721793f, 0.632020354f, 0.25007084f, 0.823592246f, 0.637001634f, + 0.30433768f, 0.0448598303f, 0.385092884f, 0.366362303f, 0.586083114f, + 0.218390301f, 0.931746006f, 0.978048146f, 0.762684941f, 0.00208298792f, + 0.91390729f, 0.505838513f, 0.875348926f, 0.428009957f, 0.367065936f, + 0.911922634f, 0.270003974f, 0.164243385f, 0.0581932105f, 0.313204288f, + 0.644775152f, 0.437950462f, 0.775881767f, 0.575452209f, 0.946475744f + }); + auto expected = NDArrayFactory::create('c', { 5, 4, 3 }, { + 0.262831867f, 0.430244058f, 0.725874603f, 0.723622441f, 0.418478161f, + 0.890151322f, 0.740797927f, 0.906427443f, 0.928968489f, 0.717254877f, + 0.199753001f, 0.684074104f, 0.312434604f, 0.361258626f, 0.761800349f, + 0.991390795f, 0.685067773f, 0.753103435f, 0.163174023f, 0.682347894f, + 0.913557053f, 0.268038541f, 0.84635365f, 0.965541422f, 0.112067183f, + 0.679180562f, 0.674094439f, 0.540247589f, 0.870388806f, 0.279114306f, + 0.280050347f, 0.604331017f, 0.632020354f, 0.106776128f, 0.630475283f, + 0.823592246f, 0.490824632f, 0.883509099f, 0.385092884f, 0.75257351f, + 0.765611768f, 0.931746006f, 0.129888852f, 0.997870266f, 0.978048146f, + 0.849081645f, 0.446510047f, 0.91390729f, 0.685308874f, 0.597481251f, + 0.911922634f, 0.0834472676f, 0.784472764f, 0.270003974f, 0.396037966f, + 0.514242649f, 0.644775152f, 0.756701186f, 0.392005324f, 0.946475744f + }); + + + auto actual = NDArrayFactory::create('c', { 5,4,3 }); + + Context ctx(1); + ctx.setInputArray(0, &rgbs); + ctx.setOutputArray(0, &actual); + + nd4j::ops::rgb_to_hsv op; + auto status = op.execute(&ctx); +#if 0 + //visual check + rgbs.printBuffer("rgbs "); + actual.printBuffer("HSV "); + expected.printBuffer("exp"); +#endif + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); + +} + +TEST_F(DeclarableOpsTests16, test_rgb_to_hsv_2) { + /* + swapped_rgbs=rgbs.swapaxes(1,2).ravel() + swapped_hsvs=hsvs.swapaxes(1,2).ravel() + */ + auto rgbs = NDArrayFactory::create('c', { 5, 3, 4 }, { + 0.545678377f, 0.644941628f, 0.461456001f, 0.588904262f, 0.725874603f, + 0.517642438f, 0.0869259685f, 0.54742825f, 0.413571358f, 0.890151322f, + 0.928968489f, 0.684074104f, 0.52110225f, 0.753103435f, 0.913557053f, + 0.46850124f, 0.761800349f, 0.237176552f, 0.90049392f, 0.965541422f, + 0.486593395f, 0.263826847f, 0.290193319f, 0.148351923f, 0.674094439f, + 0.0361763388f, 0.3721793f, 0.823592246f, 0.524110138f, 0.2204483f, + 0.632020354f, 0.637001634f, 0.216262609f, 0.279114306f, 0.25007084f, + 0.30433768f, 0.0448598303f, 0.586083114f, 0.978048146f, 0.91390729f, + 0.385092884f, 0.218390301f, 0.762684941f, 0.505838513f, 0.366362303f, + 0.931746006f, 0.00208298792f, 0.875348926f, 0.428009957f, 0.270003974f, + 0.313204288f, 0.775881767f, 0.367065936f, 0.164243385f, 0.644775152f, + 0.575452209f, 0.911922634f, 0.0581932105f, 0.437950462f, 0.946475744f + }); + auto expected = NDArrayFactory::create('c', { 5, 3, 4 }, { + 0.262831867f, 0.723622441f, 0.740797927f, 0.717254877f, 0.430244058f, + 0.418478161f, 0.906427443f, 0.199753001f, 0.725874603f, 0.890151322f, + 0.928968489f, 0.684074104f, 0.312434604f, 0.991390795f, 0.163174023f, + 0.268038541f, 0.361258626f, 0.685067773f, 0.682347894f, 0.84635365f, + 0.761800349f, 0.753103435f, 0.913557053f, 0.965541422f, 0.112067183f, + 0.540247589f, 0.280050347f, 0.106776128f, 0.679180562f, 0.870388806f, + 0.604331017f, 0.630475283f, 0.674094439f, 0.279114306f, 0.632020354f, + 0.823592246f, 0.490824632f, 0.75257351f, 0.129888852f, 0.849081645f, + 0.883509099f, 0.765611768f, 0.997870266f, 0.446510047f, 0.385092884f, + 0.931746006f, 0.978048146f, 0.91390729f, 0.685308874f, 0.0834472676f, + 0.396037966f, 0.756701186f, 0.597481251f, 0.784472764f, 0.514242649f, + 0.392005324f, 0.911922634f, 0.270003974f, 0.644775152f, 0.946475744f + }); + + + auto actual = NDArrayFactory::create('c', { 5,3,4 }); + + Context ctx(1); + ctx.setInputArray(0, &rgbs); + ctx.setOutputArray(0, &actual); + ctx.setIArguments({ 1 }); + nd4j::ops::rgb_to_hsv op; + auto status = op.execute(&ctx); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); + +} + +TEST_F(DeclarableOpsTests16, test_rgb_to_hsv_3) { + + auto rgbs = NDArrayFactory::create('c', { 4, 3 }, { + 0.545678377f, 0.725874603f, 0.413571358f, 0.644941628f, 0.517642438f, + 0.890151322f, 0.461456001f, 0.0869259685f, 0.928968489f, 0.588904262f, + 0.54742825f, 0.684074104f + }); + auto expected = NDArrayFactory::create('c', { 4, 3 }, { + 0.262831867f, 0.430244058f, 0.725874603f, 0.723622441f, 0.418478161f, + 0.890151322f, 0.740797927f, 0.906427443f, 0.928968489f, 0.717254877f, + 0.199753001f, 0.684074104f + }); + + auto actual = NDArrayFactory::create('c', { 4, 3 }); + + Context ctx(1); + ctx.setInputArray(0, &rgbs); + ctx.setOutputArray(0, &actual); + + nd4j::ops::rgb_to_hsv op; + auto status = op.execute(&ctx); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); + +} + + +TEST_F(DeclarableOpsTests16, test_rgb_to_hsv_4) { + auto rgbs = NDArrayFactory::create('c', { 3, 4 }, { + 0.545678377f, 0.644941628f, 0.461456001f, 0.588904262f, 0.725874603f, + 0.517642438f, 0.0869259685f, 0.54742825f, 0.413571358f, 0.890151322f, + 0.928968489f, 0.684074104f + }); + auto expected = NDArrayFactory::create('c', { 3, 4 }, { + 0.262831867f, 0.723622441f, 0.740797927f, 0.717254877f, 0.430244058f, + 0.418478161f, 0.906427443f, 0.199753001f, 0.725874603f, 0.890151322f, + 0.928968489f, 0.684074104f + }); + + auto actual = NDArrayFactory::create('c', { 3, 4 }); + + Context ctx(1); + ctx.setInputArray(0, &rgbs); + ctx.setOutputArray(0, &actual); + ctx.setIArguments({ 0 }); + nd4j::ops::rgb_to_hsv op; + auto status = op.execute(&ctx); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); + +} + +TEST_F(DeclarableOpsTests16, test_rgb_to_hsv_5) { + auto rgbs = NDArrayFactory::create('c', { 3 }, { + 0.545678377f, 0.725874603f, 0.413571358f + }); + auto expected = NDArrayFactory::create('c', { 3 }, { + 0.262831867f, 0.430244058f, 0.725874603f + }); + + auto actual = NDArrayFactory::create('c', { 3 }); + + Context ctx(1); + ctx.setInputArray(0, &rgbs); + ctx.setOutputArray(0, &actual); + + nd4j::ops::rgb_to_hsv op; + auto status = op.execute(&ctx); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); + +} + + +TEST_F(DeclarableOpsTests16, test_rgb_to_hsv_6) { + auto rgbs = NDArrayFactory::create('c', { 3, 4 }, { + 0.545678377f, 0.644941628f, 0.461456001f, 0.588904262f, 0.725874603f, + 0.517642438f, 0.0869259685f, 0.54742825f, 0.413571358f, 0.890151322f, + 0.928968489f, 0.684074104f + }); + auto hsvs = NDArrayFactory::create('c', { 3, 4 }, { + 0.262831867f, 0.723622441f, 0.740797927f, 0.717254877f, 0.430244058f, + 0.418478161f, 0.906427443f, 0.199753001f, 0.725874603f, 0.890151322f, + 0.928968489f, 0.684074104f + }); + + //get subarray + //get subarray + NDArray subArrRgbs = rgbs.subarray({ NDIndex::all(), NDIndex::point(0) }); + NDArray expected = hsvs.subarray({ NDIndex::all(), NDIndex::point(0) }); + subArrRgbs.reshapei({ 3 }); + expected.reshapei({ 3 }); +#if 0 + //[RANK][SHAPE][STRIDES][OPTIONS][EWS][ORDER] + subArrRgbs.printShapeInfo("subArrRgbs"); +#endif + auto actual = NDArrayFactory::create('c', { 3 }); + + Context ctx(1); + ctx.setInputArray(0, &subArrRgbs); + ctx.setOutputArray(0, &actual); + nd4j::ops::rgb_to_hsv op; + auto status = op.execute(&ctx); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); + +} + +TEST_F(DeclarableOpsTests16, test_hsv_to_rgb_1) { + + auto hsvs = NDArrayFactory::create('c', { 5, 4, 3 }, { + 0.705504596f, 0.793608069f, 0.65870738f, 0.848827183f, 0.920532584f, + 0.887555957f, 0.72317636f, 0.563831031f, 0.773604929f, 0.269532293f, + 0.332347751f, 0.111181192f, 0.239250854f, 0.499201417f, 0.862712979f, + 0.0853395388f, 0.0810681432f, 0.226065159f, 0.851340771f, 0.602043271f, + 0.690895379f, 0.971996486f, 0.273846686f, 0.464318275f, 0.194078103f, + 0.219649255f, 0.616706491f, 0.847525477f, 0.653597355f, 0.700065672f, + 0.0299375951f, 0.184475258f, 0.274936169f, 0.196718201f, 0.179381892f, + 0.934476376f, 0.895766437f, 0.52967906f, 0.675635338f, 0.966644645f, + 0.770889699f, 0.556649387f, 0.13426739f, 0.899450243f, 0.817096591f, + 0.150202557f, 0.763557851f, 0.709604502f, 0.741747797f, 0.657703638f, + 0.167678103f, 0.828556478f, 0.615502477f, 0.478080243f, 0.447288662f, + 0.864299297f, 0.129833668f, 0.66402483f, 0.795475543f, 0.561332941f + }); + auto expected = NDArrayFactory::create('c', { 5, 4, 3 }, { + 0.257768334f, 0.135951888f, 0.65870738f, 0.887555957f, 0.0705317783f, + 0.811602857f, 0.485313689f, 0.337422464f, 0.773604929f, 0.0883753772f, + 0.111181192f, 0.074230373f, 0.675155059f, 0.862712979f, 0.432045438f, + 0.226065159f, 0.21712242f, 0.207738476f, 0.690895379f, 0.274946465f, + 0.645954334f, 0.464318275f, 0.337166255f, 0.358530475f, 0.594427716f, + 0.616706491f, 0.481247369f, 0.700065672f, 0.242504601f, 0.661103036f, + 0.274936169f, 0.233327664f, 0.224217249f, 0.904251479f, 0.934476376f, + 0.766848235f, 0.675635338f, 0.317765447f, 0.54157777f, 0.556649387f, + 0.127534108f, 0.213413864f, 0.817096591f, 0.674227886f, 0.0821588641f, + 0.709604502f, 0.656080596f, 0.167780413f, 0.107076412f, 0.0573956046f, + 0.167678103f, 0.46964643f, 0.183820669f, 0.478080243f, 0.01761852f, + 0.129833668f, 0.0943436049f, 0.114806315f, 0.121884218f, 0.561332941f + }); + + + auto actual = NDArrayFactory::create('c', { 5,4,3 }); + + Context ctx(1); + ctx.setInputArray(0, &hsvs); + ctx.setOutputArray(0, &actual); + + nd4j::ops::hsv_to_rgb op; + auto status = op.execute(&ctx); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); + +} + +TEST_F(DeclarableOpsTests16, test_hsv_to_rgb_2) { + auto hsvs = NDArrayFactory::create('c', { 5, 3, 4 }, { + 0.705504596f, 0.848827183f, 0.72317636f, 0.269532293f, 0.793608069f, + 0.920532584f, 0.563831031f, 0.332347751f, 0.65870738f, 0.887555957f, + 0.773604929f, 0.111181192f, 0.239250854f, 0.0853395388f, 0.851340771f, + 0.971996486f, 0.499201417f, 0.0810681432f, 0.602043271f, 0.273846686f, + 0.862712979f, 0.226065159f, 0.690895379f, 0.464318275f, 0.194078103f, + 0.847525477f, 0.0299375951f, 0.196718201f, 0.219649255f, 0.653597355f, + 0.184475258f, 0.179381892f, 0.616706491f, 0.700065672f, 0.274936169f, + 0.934476376f, 0.895766437f, 0.966644645f, 0.13426739f, 0.150202557f, + 0.52967906f, 0.770889699f, 0.899450243f, 0.763557851f, 0.675635338f, + 0.556649387f, 0.817096591f, 0.709604502f, 0.741747797f, 0.828556478f, + 0.447288662f, 0.66402483f, 0.657703638f, 0.615502477f, 0.864299297f, + 0.795475543f, 0.167678103f, 0.478080243f, 0.129833668f, 0.561332941f + }); + auto expected = NDArrayFactory::create('c', { 5, 3, 4 }, { + 0.257768334f, 0.887555957f, 0.485313689f, 0.0883753772f, 0.135951888f, + 0.0705317783f, 0.337422464f, 0.111181192f, 0.65870738f, 0.811602857f, + 0.773604929f, 0.074230373f, 0.675155059f, 0.226065159f, 0.690895379f, + 0.464318275f, 0.862712979f, 0.21712242f, 0.274946465f, 0.337166255f, + 0.432045438f, 0.207738476f, 0.645954334f, 0.358530475f, 0.594427716f, + 0.700065672f, 0.274936169f, 0.904251479f, 0.616706491f, 0.242504601f, + 0.233327664f, 0.934476376f, 0.481247369f, 0.661103036f, 0.224217249f, + 0.766848235f, 0.675635338f, 0.556649387f, 0.817096591f, 0.709604502f, + 0.317765447f, 0.127534108f, 0.674227886f, 0.656080596f, 0.54157777f, + 0.213413864f, 0.0821588641f, 0.167780413f, 0.107076412f, 0.46964643f, + 0.01761852f, 0.114806315f, 0.0573956046f, 0.183820669f, 0.129833668f, + 0.121884218f, 0.167678103f, 0.478080243f, 0.0943436049f, 0.561332941f + }); + auto actual = NDArrayFactory::create('c', { 5,3,4 }); + + Context ctx(1); + ctx.setInputArray(0, &hsvs); + ctx.setOutputArray(0, &actual); + ctx.setIArguments({ 1 }); + nd4j::ops::hsv_to_rgb op; + auto status = op.execute(&ctx); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); + +} + +TEST_F(DeclarableOpsTests16, test_hsv_to_rgb_3) { + auto hsvs = NDArrayFactory::create('c', { 4, 3 }, { + 0.705504596f, 0.793608069f, 0.65870738f, 0.848827183f, 0.920532584f, + 0.887555957f, 0.72317636f, 0.563831031f, 0.773604929f, 0.269532293f, + 0.332347751f, 0.111181192f + }); + auto expected = NDArrayFactory::create('c', { 4, 3 }, { + 0.257768334f, 0.135951888f, 0.65870738f, 0.887555957f, 0.0705317783f, + 0.811602857f, 0.485313689f, 0.337422464f, 0.773604929f, 0.0883753772f, + 0.111181192f, 0.074230373f + }); + auto actual = NDArrayFactory::create('c', { 4,3 }); + + Context ctx(1); + ctx.setInputArray(0, &hsvs); + ctx.setOutputArray(0, &actual); + + nd4j::ops::hsv_to_rgb op; + auto status = op.execute(&ctx); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); + +} + + +TEST_F(DeclarableOpsTests16, test_hsv_to_rgb_4) { + auto hsvs = NDArrayFactory::create('c', { 3, 4 }, { + 0.705504596f, 0.848827183f, 0.72317636f, 0.269532293f, 0.793608069f, + 0.920532584f, 0.563831031f, 0.332347751f, 0.65870738f, 0.887555957f, + 0.773604929f, 0.111181192f + }); + auto expected = NDArrayFactory::create('c', { 3, 4 }, { + 0.257768334f, 0.887555957f, 0.485313689f, 0.0883753772f, 0.135951888f, + 0.0705317783f, 0.337422464f, 0.111181192f, 0.65870738f, 0.811602857f, + 0.773604929f, 0.074230373f + }); + auto actual = NDArrayFactory::create('c', { 3, 4 }); + + Context ctx(1); + ctx.setInputArray(0, &hsvs); + ctx.setOutputArray(0, &actual); + ctx.setIArguments({ 0 }); + nd4j::ops::hsv_to_rgb op; + auto status = op.execute(&ctx); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); + +} + +TEST_F(DeclarableOpsTests16, test_hsv_to_rgb_5) { + + auto hsvs = NDArrayFactory::create('c', { 3 }, { + 0.705504596f, 0.793608069f, 0.65870738f + }); + auto expected = NDArrayFactory::create('c', { 3 }, { + 0.257768334f, 0.135951888f, 0.65870738f + }); + + auto actual = NDArrayFactory::create('c', { 3 }); + + Context ctx(1); + ctx.setInputArray(0, &hsvs); + ctx.setOutputArray(0, &actual); + + nd4j::ops::hsv_to_rgb op; + auto status = op.execute(&ctx); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); + +} + + +TEST_F(DeclarableOpsTests16, test_hsv_to_rgb_6) { + + auto hsvs = NDArrayFactory::create('c', { 3, 4 }, { + 0.705504596f, 0.848827183f, 0.72317636f, 0.269532293f, 0.793608069f, + 0.920532584f, 0.563831031f, 0.332347751f, 0.65870738f, 0.887555957f, + 0.773604929f, 0.111181192f + }); + auto rgbs = NDArrayFactory::create('c', { 3, 4 }, { + 0.257768334f, 0.887555957f, 0.485313689f, 0.0883753772f, 0.135951888f, + 0.0705317783f, 0.337422464f, 0.111181192f, 0.65870738f, 0.811602857f, + 0.773604929f, 0.074230373f + }); + + auto actual = NDArrayFactory::create('c', { 3 }); + //get subarray + NDArray subArrHsvs = hsvs.subarray({ NDIndex::all(), NDIndex::point(0) }); + subArrHsvs.reshapei({ 3 }); + NDArray expected = rgbs.subarray({ NDIndex::all(), NDIndex::point(0) }); + expected.reshapei({ 3 }); +#if 0 + //[RANK][SHAPE][STRIDES][OPTIONS][EWS][ORDER] + subArrHsvs.printShapeInfo("subArrHsvs"); +#endif + + Context ctx(1); + ctx.setInputArray(0, &subArrHsvs); + ctx.setOutputArray(0, &actual); + nd4j::ops::hsv_to_rgb op; + auto status = op.execute(&ctx); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); + +} + + + +TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_1) { + /** + generated using numpy + _rgb_to_yiq_kernel = np.array([[0.299f, 0.59590059f, 0.2115f], + [0.587f, -0.27455667f, -0.52273617f], + [0.114f, -0.32134392f, 0.31119955f]]) + nnrgbs = np.array([random() for x in range(0,3*4*5)],np.float32).reshape([5,4,3]) + out =np.tensordot(nnrgbs,_rgb_to_yiq_kernel,axes=[[len(nnrgbs.shape)-1],[0]]) + + #alternatively you could use just with apply + out_2=np.apply_along_axis(lambda x: _rgb_to_yiq_kernel.T @ x,len(nnrgbs.shape)-1,nnrgbs) + + */ + auto rgb = NDArrayFactory::create('c', { 5, 4 ,3 }, + { + 0.48055f , 0.80757356f, 0.2564435f , 0.94277316f, 0.17006584f, + 0.33366168f, 0.41727918f, 0.54528666f, 0.48942474f, 0.3305715f , + 0.98633456f, 0.00158441f, 0.97605824f, 0.02462568f, 0.14837205f, + 0.00112842f, 0.99260217f, 0.9585542f , 0.41196227f, 0.3095014f , + 0.6620493f , 0.30888894f, 0.3122602f , 0.7993488f , 0.86656475f, + 0.5997049f , 0.9776477f , 0.72481847f, 0.7835693f , 0.14649455f, + 0.3573504f , 0.33301765f, 0.7853056f , 0.25830218f, 0.59289205f, + 0.41357264f, 0.5934154f , 0.72647524f, 0.6623308f , 0.96197623f, + 0.0720306f , 0.23853847f, 0.1427159f , 0.19581454f, 0.06766324f, + 0.10614152f, 0.26093867f, 0.9584985f , 0.01258832f, 0.8160156f , + 0.56506383f, 0.08418505f, 0.86440504f, 0.6807802f , 0.20662387f, + 0.4153733f , 0.76146203f, 0.50057423f, 0.08274968f, 0.9521758f + }); + + auto expected = NDArrayFactory::create('c', { 5, 4 ,3 }, + { + 0.64696468f, -0.01777124f, -0.24070648f, 0.41975525f, 0.40788622f, + 0.21433232f, 0.50064416f, -0.05832884f, -0.04447775f, 0.67799989f, + -0.07432612f, -0.44518381f, 0.32321111f, 0.52719408f, 0.2397369f , + 0.69227005f, -0.57987869f, -0.22032876f, 0.38032767f, -0.05223263f, + 0.13137188f, 0.3667803f , -0.15853189f, 0.15085728f, 0.72258149f, + 0.03757231f, 0.17403452f, 0.69337627f, 0.16971045f, -0.21071186f, + 0.39185397f, -0.13084008f, 0.145886f , 0.47240727f, -0.1417591f , + -0.12659159f, 0.67937788f, -0.05867803f, -0.04813048f, 0.35710624f, + 0.47681283f, 0.24003804f, 0.1653288f , 0.00953913f, -0.05111816f, + 0.29417614f, -0.31640032f, 0.18433114f, 0.54718234f, -0.39812097f, + -0.24805083f, 0.61018603f, -0.40592682f, -0.22219216f, 0.39241133f, + -0.23560742f, 0.06353694f, 0.3067938f , -0.0304029f , 0.35893188f + }); + + auto actual = NDArrayFactory::create('c', { 5, 4, 3 }); + + Context ctx(1); + ctx.setInputArray(0, &rgb); + ctx.setOutputArray(0, &actual); + + nd4j::ops::rgb_to_yiq op; + auto status = op.execute(&ctx); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); + +} + + + +TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_2) { + + auto rgb = NDArrayFactory::create('c', { 5, 3, 4 }, + { + 0.48055f , 0.94277316f, 0.41727918f, 0.3305715f , 0.80757356f, + 0.17006584f, 0.54528666f, 0.98633456f, 0.2564435f , 0.33366168f, + 0.48942474f, 0.00158441f, 0.97605824f, 0.00112842f, 0.41196227f, + 0.30888894f, 0.02462568f, 0.99260217f, 0.3095014f , 0.3122602f , + 0.14837205f, 0.9585542f , 0.6620493f , 0.7993488f , 0.86656475f, + 0.72481847f, 0.3573504f , 0.25830218f, 0.5997049f , 0.7835693f , + 0.33301765f, 0.59289205f, 0.9776477f , 0.14649455f, 0.7853056f , + 0.41357264f, 0.5934154f , 0.96197623f, 0.1427159f , 0.10614152f, + 0.72647524f, 0.0720306f , 0.19581454f, 0.26093867f, 0.6623308f , + 0.23853847f, 0.06766324f, 0.9584985f , 0.01258832f, 0.08418505f, + 0.20662387f, 0.50057423f, 0.8160156f , 0.86440504f, 0.4153733f , + 0.08274968f, 0.56506383f, 0.6807802f , 0.76146203f, 0.9521758f + }); + + auto expected = NDArrayFactory::create('c', { 5, 3, 4 }, + { + 0.64696468f, 0.41975525f, 0.50064416f, 0.67799989f, -0.01777124f, + 0.40788622f, -0.05832884f, -0.07432612f, -0.24070648f, 0.21433232f, + -0.04447775f, -0.44518381f, 0.32321111f, 0.69227005f, 0.38032767f, + 0.3667803f , 0.52719408f, -0.57987869f, -0.05223263f, -0.15853189f, + 0.2397369f , -0.22032876f, 0.13137188f, 0.15085728f, 0.72258149f, + 0.69337627f, 0.39185397f, 0.47240727f, 0.03757231f, 0.16971045f, + -0.13084008f, -0.1417591f , 0.17403452f, -0.21071186f, 0.145886f , + -0.12659159f, 0.67937788f, 0.35710624f, 0.1653288f , 0.29417614f, + -0.05867803f, 0.47681283f, 0.00953913f, -0.31640032f, -0.04813048f, + 0.24003804f, -0.05111816f, 0.18433114f, 0.54718234f, 0.61018603f, + 0.39241133f, 0.3067938f , -0.39812097f, -0.40592682f, -0.23560742f, + -0.0304029f , -0.24805083f, -0.22219216f, 0.06353694f, 0.35893188f + }); + + auto actual = NDArrayFactory::create('c', { 5, 3, 4 }); + + Context ctx(1); + ctx.setInputArray(0, &rgb); + ctx.setOutputArray(0, &actual); + ctx.setIArguments({ 1 }); + nd4j::ops::rgb_to_yiq op; + auto status = op.execute(&ctx); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); + +} + + +TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_3) { + + auto rgb = NDArrayFactory::create('c', { 4, 3 }, + { + 0.48055f , 0.80757356f, 0.2564435f , 0.94277316f, 0.17006584f, + 0.33366168f, 0.41727918f, 0.54528666f, 0.48942474f, 0.3305715f , + 0.98633456f, 0.00158441f + }); + + auto expected = NDArrayFactory::create('c', { 4, 3 }, + { + 0.64696468f, -0.01777124f, -0.24070648f, 0.41975525f, 0.40788622f, + 0.21433232f, 0.50064416f, -0.05832884f, -0.04447775f, 0.67799989f, + -0.07432612f, -0.44518381f + }); + + auto actual = NDArrayFactory::create('c', { 4, 3 }); + + Context ctx(1); + ctx.setInputArray(0, &rgb); + ctx.setOutputArray(0, &actual); + + nd4j::ops::rgb_to_yiq op; + auto status = op.execute(&ctx); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); + +} + + +TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_4) { + + auto rgb = NDArrayFactory::create('c', { 3, 4 }, + { + 0.48055f , 0.94277316f, 0.41727918f, 0.3305715f , 0.80757356f, + 0.17006584f, 0.54528666f, 0.98633456f, 0.2564435f , 0.33366168f, + 0.48942474f, 0.00158441f + }); + + auto expected = NDArrayFactory::create('c', { 3, 4 }, + { + 0.64696468f, 0.41975525f, 0.50064416f, 0.67799989f, -0.01777124f, + 0.40788622f, -0.05832884f, -0.07432612f, -0.24070648f, 0.21433232f, + -0.04447775f, -0.44518381f + }); + + auto actual = NDArrayFactory::create('c', { 3, 4 }); + + Context ctx(1); + ctx.setInputArray(0, &rgb); + ctx.setOutputArray(0, &actual); + ctx.setIArguments({ 0 }); + nd4j::ops::rgb_to_yiq op; + auto status = op.execute(&ctx); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); + +} + + + +TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_5) { + + auto rgbs = NDArrayFactory::create('c', { 3 }, + { 0.48055f , 0.80757356f, 0.2564435f }); + auto expected = NDArrayFactory::create('c', { 3 }, + { 0.64696468f, -0.01777124f, -0.24070648f, }); + + + auto actual = NDArrayFactory::create('c', { 3 }); + + Context ctx(1); + ctx.setInputArray(0, &rgbs); + ctx.setOutputArray(0, &actual); + + nd4j::ops::rgb_to_yiq op; + auto status = op.execute(&ctx); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); + +} + +TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_6) { + + auto rgbs = NDArrayFactory::create('c', { 3, 4 }, + { + 0.48055f , 0.94277316f, 0.41727918f, 0.3305715f , 0.80757356f, + 0.17006584f, 0.54528666f, 0.98633456f, 0.2564435f , 0.33366168f, + 0.48942474f, 0.00158441f + }); + + auto yiqs = NDArrayFactory::create('c', { 3, 4 }, + { + 0.64696468f, 0.41975525f, 0.50064416f, 0.67799989f, -0.01777124f, + 0.40788622f, -0.05832884f, -0.07432612f, -0.24070648f, 0.21433232f, + -0.04447775f, -0.44518381f + }); + + //get subarray + NDArray subArrRgbs = rgbs.subarray({ NDIndex::all(), NDIndex::point(0) }); + NDArray expected = yiqs.subarray({ NDIndex::all(), NDIndex::point(0) }); + subArrRgbs.reshapei({ 3 }); + expected.reshapei({ 3 }); +#if 0 + //[RANK][SHAPE][STRIDES][OPTIONS][EWS][ORDER] + subArrRgbs.printShapeInfo("subArrRgbs"); +#endif + auto actual = NDArrayFactory::create('c', { 3 }); + + Context ctx(1); + ctx.setInputArray(0, &subArrRgbs); + ctx.setOutputArray(0, &actual); + nd4j::ops::rgb_to_yiq op; + auto status = op.execute(&ctx); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); + +} + + + +TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_1) { + + auto yiqs = NDArrayFactory::create('c', { 5, 4, 3 }, { + 0.775258899f, -0.288912386f, -0.132725924f, 0.0664454922f, -0.212469354f, + 0.455438733f, 0.418221354f, 0.349350512f, 0.145902053f, 0.947576523f, + -0.471601307f, 0.263960421f, 0.700227439f, 0.32434237f, -0.278446227f, + 0.130805135f, -0.438441873f, 0.187127829f, 0.0276055578f, -0.179727226f, + 0.305075705f, 0.716282248f, 0.278215706f, -0.44586885f, 0.76971364f, + 0.131288841f, -0.141177326f, 0.900081575f, -0.0788725987f, 0.14756602f, + 0.387832165f, 0.229834676f, 0.47921446f, 0.632930398f, 0.0443540029f, + -0.268817365f, 0.0977194682f, -0.141669706f, -0.140715122f, 0.946808815f, + -0.52525419f, -0.106209636f, 0.659476519f, 0.391066104f, 0.426448852f, + 0.496989518f, -0.283434421f, -0.177366048f, 0.715208411f, -0.496444523f, + 0.189553142f, 0.616444945f, 0.345852494f, 0.447739422f, 0.224696323f, + 0.451372236f, 0.298027098f, 0.446561724f, -0.187599331f, -0.448159873f + }); + auto expected = NDArrayFactory::create('c', { 5, 4, 3 }, { + 0.416663059f, 0.939747555f, 0.868814286f, 0.146075352f, -0.170521997f, + 1.07776645f, 0.842775284f, 0.228765106f, 0.280231822f, 0.660605291f, + 0.905021825f, 1.91936605f, 0.837427991f, 0.792213732f, -0.133271854f, + -0.17216571f, 0.128957025f, 0.934955336f, 0.0451873479f, -0.120952621f, + 0.746436225f, 0.705446224f, 0.929172217f, -0.351493549f, 0.807577594f, + 0.825371955f, 0.383812296f, 0.916293093f, 0.82603058f, 1.23885956f, + 0.905059196f, 0.015164554f, 0.950156781f, 0.508443732f, 0.794845279f, + 0.12571529f, -0.125074273f, 0.227326869f, 0.0147000261f, 0.378735409f, + 1.15842402f, 1.34712305f, 1.2980804f, 0.277102016f, 0.953435072f, + 0.115916842f, 0.688879376f, 0.508405162f, 0.35829352f, 0.727568094f, + 1.58768577f, 1.22504294f, 0.232589777f, 0.996727258f, 0.841224629f, + -0.0909671176f, 0.233051388f, -0.0110094378f, 0.787642119f, -0.109582274f + }); + auto actual = NDArrayFactory::create('c', { 5, 4, 3 }); + + Context ctx(1); + ctx.setInputArray(0, &yiqs); + ctx.setOutputArray(0, &actual); + + nd4j::ops::yiq_to_rgb op; + auto status = op.execute(&ctx); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); + +} + + + +TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_2) { + + auto yiqs = NDArrayFactory::create('c', { 5, 3, 4 }, { + 0.775258899f, 0.0664454922f, 0.418221354f, 0.947576523f, -0.288912386f, + -0.212469354f, 0.349350512f, -0.471601307f, -0.132725924f, 0.455438733f, + 0.145902053f, 0.263960421f, 0.700227439f, 0.130805135f, 0.0276055578f, + 0.716282248f, 0.32434237f, -0.438441873f, -0.179727226f, 0.278215706f, + -0.278446227f, 0.187127829f, 0.305075705f, -0.44586885f, 0.76971364f, + 0.900081575f, 0.387832165f, 0.632930398f, 0.131288841f, -0.0788725987f, + 0.229834676f, 0.0443540029f, -0.141177326f, 0.14756602f, 0.47921446f, + -0.268817365f, 0.0977194682f, 0.946808815f, 0.659476519f, 0.496989518f, + -0.141669706f, -0.52525419f, 0.391066104f, -0.283434421f, -0.140715122f, + -0.106209636f, 0.426448852f, -0.177366048f, 0.715208411f, 0.616444945f, + 0.224696323f, 0.446561724f, -0.496444523f, 0.345852494f, 0.451372236f, + -0.187599331f, 0.189553142f, 0.447739422f, 0.298027098f, -0.448159873f + }); + auto expected = NDArrayFactory::create('c', { 5, 3, 4 }, { + 0.416663059f, 0.146075352f, 0.842775284f, 0.660605291f, 0.939747555f, + -0.170521997f, 0.228765106f, 0.905021825f, 0.868814286f, 1.07776645f, + 0.280231822f, 1.91936605f, 0.837427991f, -0.17216571f, 0.0451873479f, + 0.705446224f, 0.792213732f, 0.128957025f, -0.120952621f, 0.929172217f, + -0.133271854f, 0.934955336f, 0.746436225f, -0.351493549f, 0.807577594f, + 0.916293093f, 0.905059196f, 0.508443732f, 0.825371955f, 0.82603058f, + 0.015164554f, 0.794845279f, 0.383812296f, 1.23885956f, 0.950156781f, + 0.12571529f, -0.125074273f, 0.378735409f, 1.2980804f, 0.115916842f, + 0.227326869f, 1.15842402f, 0.277102016f, 0.688879376f, 0.0147000261f, + 1.34712305f, 0.953435072f, 0.508405162f, 0.35829352f, 1.22504294f, + 0.841224629f, -0.0110094378f, 0.727568094f, 0.232589777f, -0.0909671176f, + 0.787642119f, 1.58768577f, 0.996727258f, 0.233051388f, -0.109582274f + }); + auto actual = NDArrayFactory::create('c', { 5, 3, 4 }); + + Context ctx(1); + ctx.setInputArray(0, &yiqs); + ctx.setOutputArray(0, &actual); + ctx.setIArguments({ 1 }); + nd4j::ops::yiq_to_rgb op; + auto status = op.execute(&ctx); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); + +} + + +TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_3) { + + auto yiqs = NDArrayFactory::create('c', { 4, 3 }, { + 0.775258899f, -0.288912386f, -0.132725924f, 0.0664454922f, -0.212469354f, + 0.455438733f, 0.418221354f, 0.349350512f, 0.145902053f, 0.947576523f, + -0.471601307f, 0.263960421f + }); + auto expected = NDArrayFactory::create('c', { 4, 3 }, { + 0.416663059f, 0.939747555f, 0.868814286f, 0.146075352f, -0.170521997f, + 1.07776645f, 0.842775284f, 0.228765106f, 0.280231822f, 0.660605291f, + 0.905021825f, 1.91936605f + }); + auto actual = NDArrayFactory::create('c', { 4, 3 }); + + Context ctx(1); + ctx.setInputArray(0, &yiqs); + ctx.setOutputArray(0, &actual); + + nd4j::ops::yiq_to_rgb op; + auto status = op.execute(&ctx); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); + +} + + +TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_4) { + + auto yiqs = NDArrayFactory::create('c', { 3, 4 }, { + 0.775258899f, 0.0664454922f, 0.418221354f, 0.947576523f, -0.288912386f, + -0.212469354f, 0.349350512f, -0.471601307f, -0.132725924f, 0.455438733f, + 0.145902053f, 0.263960421f + }); + auto expected = NDArrayFactory::create('c', { 3, 4 }, { + 0.416663059f, 0.146075352f, 0.842775284f, 0.660605291f, 0.939747555f, + -0.170521997f, 0.228765106f, 0.905021825f, 0.868814286f, 1.07776645f, + 0.280231822f, 1.91936605f + }); + auto actual = NDArrayFactory::create('c', { 3, 4 }); + + Context ctx(1); + ctx.setInputArray(0, &yiqs); + ctx.setOutputArray(0, &actual); + ctx.setIArguments({ 0 }); + nd4j::ops::yiq_to_rgb op; + auto status = op.execute(&ctx); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); + +} + + + +TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_5) { + + auto yiqs = NDArrayFactory::create('c', { 3 }, { + 0.775258899f, -0.288912386f, -0.132725924f + }); + auto expected = NDArrayFactory::create('c', { 3 }, { + 0.416663059f, 0.939747555f, 0.868814286f + }); + auto actual = NDArrayFactory::create('c', { 3 }); + + Context ctx(1); + ctx.setInputArray(0, &yiqs); + ctx.setOutputArray(0, &actual); + + nd4j::ops::yiq_to_rgb op; + auto status = op.execute(&ctx); +#if 0 + actual.printBuffer("actual"); + expected.printBuffer("expected"); +#endif + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); + +} + +TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_6) { + + auto yiqs = NDArrayFactory::create('c', { 3, 4 }, { + 0.775258899f, 0.0664454922f, 0.418221354f, 0.947576523f, -0.288912386f, + -0.212469354f, 0.349350512f, -0.471601307f, -0.132725924f, 0.455438733f, + 0.145902053f, 0.263960421f + }); + auto rgbs = NDArrayFactory::create('c', { 3, 4 }, { + 0.416663059f, 0.146075352f, 0.842775284f, 0.660605291f, 0.939747555f, + -0.170521997f, 0.228765106f, 0.905021825f, 0.868814286f, 1.07776645f, + 0.280231822f, 1.91936605f + }); + + //get subarray + NDArray subArrYiqs = yiqs.subarray({ NDIndex::all(), NDIndex::point(0) }); + NDArray expected = rgbs.subarray({ NDIndex::all(), NDIndex::point(0) }); + subArrYiqs.reshapei({ 3 }); + expected.reshapei({ 3 }); +#if 0 + //[RANK][SHAPE][STRIDES][OPTIONS][EWS][ORDER] + subArrYiqs.printShapeInfo("subArrYiqs"); +#endif + auto actual = NDArrayFactory::create('c', { 3 }); + + Context ctx(1); + ctx.setInputArray(0, &subArrYiqs); + ctx.setOutputArray(0, &actual); + nd4j::ops::yiq_to_rgb op; + auto status = op.execute(&ctx); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); + +} diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests17.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests17.cpp new file mode 100644 index 000000000..543043ebd --- /dev/null +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests17.cpp @@ -0,0 +1,94 @@ +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + + +// +// @author raver119@gmail.com +// + +#include "testlayers.h" +#include +#include +#include +#include +#include + + +using namespace nd4j; + + +class DeclarableOpsTests17 : public testing::Test { +public: + + DeclarableOpsTests17() { + printf("\n"); + fflush(stdout); + } +}; + +TEST_F(DeclarableOpsTests17, test_sparse_to_dense_1) { + auto values = NDArrayFactory::create({1.f, 2.f, 3.f}); + auto shape = NDArrayFactory::create({3, 3}); + auto ranges = NDArrayFactory::create({0,0, 1,1, 2,2}); + auto def = NDArrayFactory::create(0.f); + auto exp = NDArrayFactory::create('c', {3, 3}, {1.f,0.f,0.f, 0.f,2.f,0.f, 0.f,0.f,3.f}); + + + nd4j::ops::compat_sparse_to_dense op; + auto result = op.execute({&ranges, &shape, &values, &def}, {}, {}); + ASSERT_EQ(Status::OK(), result->status()); + + delete result; +} + +TEST_F(DeclarableOpsTests17, test_sparse_to_dense_2) { + auto values = NDArrayFactory::string('c', {3}, {"alpha", "beta", "gamma"}); + auto shape = NDArrayFactory::create({3, 3}); + auto ranges = NDArrayFactory::create({0,0, 1,1, 2,2}); + auto def = NDArrayFactory::string("d"); + auto exp = NDArrayFactory::string('c', {3, 3}, {"alpha","d","d", "d","beta","d", "d","d","gamma"}); + + + nd4j::ops::compat_sparse_to_dense op; + auto result = op.execute({&ranges, &shape, &values, &def}, {}, {}); + ASSERT_EQ(Status::OK(), result->status()); + + delete result; +} + +TEST_F(DeclarableOpsTests17, test_compat_string_split_1) { + auto x = NDArrayFactory::string('c', {2}, {"first string", "second"}); + auto delimiter = NDArrayFactory::string(" "); + + auto exp0 = NDArrayFactory::create({0,0, 0,1, 1,0}); + auto exp1 = NDArrayFactory::string('c', {3}, {"first", "string", "second"}); + + nd4j::ops::compat_string_split op; + auto result = op.execute({&x, &delimiter}, {}, {}); + ASSERT_EQ(Status::OK(), result->status()); + ASSERT_EQ(2, result->size()); + + auto z0 = result->at(0); + auto z1 = result->at(1); + + ASSERT_TRUE(exp0.isSameShape(z0)); + ASSERT_TRUE(exp1.isSameShape(z1)); + + ASSERT_EQ(exp0, *z0); + ASSERT_EQ(exp1, *z1); + + delete result; +} \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests18.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests18.cpp new file mode 100644 index 000000000..93864af8c --- /dev/null +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests18.cpp @@ -0,0 +1,52 @@ +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + + +// +// @author raver119@gmail.com +// + +#include "testlayers.h" +#include +#include +#include +#include +#include + + +using namespace nd4j; + + +class DeclarableOpsTests18 : public testing::Test { +public: + + DeclarableOpsTests18() { + printf("\n"); + fflush(stdout); + } +}; + +TEST_F(DeclarableOpsTests18, test_bitcast_1) { + auto x = NDArrayFactory::create(0.23028551377579154); + auto z = NDArrayFactory::create(0); + auto e = NDArrayFactory::create(4597464930322771456L); + + nd4j::ops::bitcast op; + auto status = op.execute({&x}, {&z}, {}, {(Nd4jLong) nd4j::DataType::INT64}, {}); + ASSERT_EQ(Status::OK(), status); + + ASSERT_EQ(e, z); +} \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests19.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests19.cpp new file mode 100644 index 000000000..871bfe186 --- /dev/null +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests19.cpp @@ -0,0 +1,40 @@ +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + + +// +// @author raver119@gmail.com +// + +#include "testlayers.h" +#include +#include +#include +#include +#include + + +using namespace nd4j; + + +class DeclarableOpsTests19 : public testing::Test { +public: + + DeclarableOpsTests19() { + printf("\n"); + fflush(stdout); + } +}; \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests2.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests2.cpp index a8377b429..e4d0db62c 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests2.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests2.cpp @@ -47,7 +47,7 @@ TEST_F(DeclarableOpsTests2, gather_1) { auto* output = result->at(0); - ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.equalsTo(output)); delete result; @@ -68,7 +68,7 @@ TEST_F(DeclarableOpsTests2, gather_2) { auto* output = result->at(0); - ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.equalsTo(output)); delete result; @@ -90,7 +90,7 @@ TEST_F(DeclarableOpsTests2, gather_3) { auto* output = result->at(0); - ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.equalsTo(output)); delete result; @@ -110,7 +110,7 @@ TEST_F(DeclarableOpsTests2, gather_4) { auto* output = result->at(0); - ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.equalsTo(output)); delete result; @@ -131,7 +131,7 @@ TEST_F(DeclarableOpsTests2, gather_5) { auto* output = result->at(0); - ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.equalsTo(output)); delete result; @@ -153,7 +153,7 @@ TEST_F(DeclarableOpsTests2, gather_6) { auto* output = result->at(0); - ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.equalsTo(output)); delete result; @@ -175,7 +175,7 @@ TEST_F(DeclarableOpsTests2, gather_7) { auto* output = result->at(0); - ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.equalsTo(output)); delete result; @@ -197,7 +197,7 @@ TEST_F(DeclarableOpsTests2, gather_8) { // output->printShapeInfo(); // output->printIndexedBuffer(); - ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.equalsTo(output)); delete result; @@ -300,7 +300,7 @@ TEST_F(DeclarableOpsTests2, gather_13) { auto* output = result->at(0); - ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.equalsTo(output)); delete result; @@ -440,7 +440,7 @@ TEST_F(DeclarableOpsTests2, Test_Squeeze_1) { TEST_F(DeclarableOpsTests2, Test_Squeeze_2) { auto x = NDArrayFactory::create('c', {2, 3, 4}); x.linspace(1); - auto exp = x.dup(); + auto exp = new NDArray(x.dup()); nd4j::ops::squeeze op; auto result = op.execute({&x}, {}, {}); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp index 5322a0a6d..dacfac127 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp @@ -191,7 +191,7 @@ TEST_F(DeclarableOpsTests3, Test_Norm_1) { auto result0 = op.execute({&x}, {0.}, {}); auto z0 = result0->at(0); - auto exp0 = x.reduceAlongDims(reduce::NormFrobenius, empty, false, false); + auto exp0 = x.reduceAlongDimension(reduce::NormFrobenius, empty, false, false); ASSERT_TRUE(exp0.isSameShape(z0)); ASSERT_TRUE(exp0.equalsTo(z0)); @@ -201,7 +201,7 @@ TEST_F(DeclarableOpsTests3, Test_Norm_1) { ASSERT_EQ(result1->status(), ND4J_STATUS_OK); auto z1 = result1->at(0); // z1->printIndexedBuffer("Z1"); - auto exp1 = x.reduceAlongDims(reduce::Norm2, dims, false, false); + auto exp1 = x.reduceAlongDimension(reduce::Norm2, dims, false, false); // exp1.printIndexedBuffer("EXP1"); // z1->printShapeInfo("Z1 shape"); // exp1.printShapeInfo("EXP1 shape"); @@ -213,7 +213,7 @@ TEST_F(DeclarableOpsTests3, Test_Norm_1) { auto result4 = op.execute({&x}, {4.}, {1}); auto z4 = result4->at(0); - auto exp4= x.reduceAlongDims(reduce::NormMax, dims, false, false); + auto exp4= x.reduceAlongDimension(reduce::NormMax, dims, false, false); ASSERT_TRUE(exp4.isSameShape(z4)); ASSERT_TRUE(exp4.equalsTo(z4)); @@ -233,7 +233,7 @@ TEST_F(DeclarableOpsTests3, Test_Norm_2) { auto result0 = op.execute({&x}, {0}, {}); auto z0 = result0->at(0); - auto exp0 = x.reduceAlongDims(reduce::NormFrobenius, empty, false, false); + auto exp0 = x.reduceAlongDimension(reduce::NormFrobenius, empty, false, false); ASSERT_TRUE(exp0.isSameShape(z0)); ASSERT_TRUE(exp0.equalsTo(z0)); @@ -242,7 +242,7 @@ TEST_F(DeclarableOpsTests3, Test_Norm_2) { auto result1 = op.execute({&x, &axis}, {1}, {}); auto z1 = result1->at(0); - auto exp1 = x.reduceAlongDims(reduce::Norm2, dims, false, false); + auto exp1 = x.reduceAlongDimension(reduce::Norm2, dims, false, false); ASSERT_TRUE(exp1.isSameShape(z1)); ASSERT_TRUE(exp1.equalsTo(z1)); @@ -251,7 +251,7 @@ TEST_F(DeclarableOpsTests3, Test_Norm_2) { auto result4 = op.execute({&x, &axis}, {4}, {}); auto z4 = result4->at(0); - auto exp4= x.reduceAlongDims(reduce::NormMax, dims, false, false); + auto exp4= x.reduceAlongDimension(reduce::NormMax, dims, false, false); ASSERT_TRUE(exp4.isSameShape(z4)); ASSERT_TRUE(exp4.equalsTo(z4)); @@ -329,21 +329,21 @@ TEST_F(DeclarableOpsTests3, Test_ClipByNorm_3) { x.linspace(100.); - auto xNorm1 = x.reduceAlongDims(reduce::Norm2, {1}, true); + auto xNorm1 = x.reduceAlongDimension(reduce::Norm2, {1}, true); x /= xNorm1; - xNorm1 = x.reduceAlongDims(reduce::Norm2,{1}, true); + xNorm1 = x.reduceAlongDimension(reduce::Norm2,{1}, true); ASSERT_TRUE(unities.isSameShape(xNorm1)); ASSERT_TRUE(unities.equalsTo(xNorm1)); x *= scale; - xNorm1 = x.reduceAlongDims(reduce::Norm2, {1}, true); + xNorm1 = x.reduceAlongDimension(reduce::Norm2, {1}, true); nd4j::ops::clipbynorm op; auto result = op.execute({&x}, {1.0}, {1}, {}, false, nd4j::DataType::DOUBLE); auto z = result->at(0); - auto zNorm1 = z->reduceAlongDims(reduce::Norm2, {1}, true); + auto zNorm1 = z->reduceAlongDimension(reduce::Norm2, {1}, true); auto exp = NDArrayFactory::create('c', {3, 1}, {1., 1., xNorm1.e(2)}); ASSERT_TRUE(exp.isSameShape(&zNorm1)); @@ -2432,17 +2432,11 @@ TEST_F(DeclarableOpsTests3, svd_test6) { /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, svd_test7) { - auto x= NDArrayFactory::create('c', {2,2,5,5}, {-7. ,17 ,4 ,-10 ,5 ,1 ,-5 ,-19 ,13 ,-8 ,9 ,13 ,19 ,13 ,-2 - ,-8 ,10 ,-9 ,0 ,-20 ,-2 ,14 ,19 ,5 ,-18 ,4 ,-13 ,12 ,-10 - ,5 ,-10 ,-10 ,17 ,-5 ,-2 ,10 ,5 ,-4 ,-11 ,15 ,-3 ,15 ,-17 - ,-20 ,-10 ,-4 ,12 ,-9 ,16 ,13 ,10 ,-19 ,2 ,-9 ,-10 ,8 ,-2 - ,-4 ,3 ,7 ,10 ,-19 ,-11 ,-4 ,-6 ,2 ,-12 ,6 ,-4 ,-14 ,14 - ,16 ,7 ,19 ,-17 ,2 ,-14 ,5 ,-1 ,16 ,19 ,-11 ,-14 ,-16 - ,-19 ,15 ,-18 ,-12 ,-16 ,16 ,1 ,5 ,7 ,8 ,2 ,13 ,-3 ,6 ,2 ,-5}); - auto expS= NDArrayFactory::create('c', {2,2,5}, {40.95395, 31.46869, 24.79993, 12.33768, 1.80031, - 38.18412, 31.52287, 23.52755, 11.79484, 1.90195, - 39.34498, 32.54861, 17.52492, 7.03003, 2.2399, - 44.72126, 32.3164 , 16.60139, 6.88783, 0.78122}); + auto x= NDArrayFactory::create('c', {2,2,5,5}, {-7. ,17 ,4 ,-10 ,5 ,1 ,-5 ,-19 ,13 ,-8 ,9 ,13 ,19 ,13 ,-2,-8 ,10 ,-9 ,0 ,-20 ,-2 ,14 ,19 ,5 ,-18 ,4 ,-13 ,12 ,-10 + ,5 ,-10 ,-10 ,17 ,-5 ,-2 ,10 ,5 ,-4 ,-11 ,15 ,-3 ,15 ,-17,-20 ,-10 ,-4 ,12 ,-9 ,16 ,13 ,10 ,-19 ,2 ,-9 ,-10 ,8 ,-2 + ,-4 ,3 ,7 ,10 ,-19 ,-11 ,-4 ,-6 ,2 ,-12 ,6 ,-4 ,-14 ,14,16 ,7 ,19 ,-17 ,2 ,-14 ,5 ,-1 ,16 ,19 ,-11 ,-14 ,-16,-19 ,15 ,-18 ,-12 ,-16 ,16 ,1 ,5 ,7 ,8 ,2 ,13 ,-3 ,6 ,2 ,-5}); + auto expS= NDArrayFactory::create('c', {2,2,5}, {40.95395, 31.46869, 24.79993, 12.33768, 1.80031,38.18412, 31.52287, 23.52755, 11.79484, 1.90195, + 39.34498, 32.54861, 17.52492, 7.03003, 2.2399,44.72126, 32.3164 , 16.60139, 6.88783, 0.78122}); nd4j::ops::svd op; auto results = op.execute({&x}, {}, {0, 0, 16}); @@ -2623,75 +2617,25 @@ TEST_F(DeclarableOpsTests3, svd_test9) { 38.56369, 29.18881, 19.54565, 10.89746, 2.017 , 44.99108, 34.95059, 26.00453, 15.43898, 7.18752}); - auto expU= NDArrayFactory::create('c', {2,2,5,5}, {-0.73644, -0.10751, 0.10081, -0.00325, 0.66025, - 0.26329, 0.3079 , 0.38582, 0.77696, 0.28872, - 0.03076, 0.03015, -0.9128 , 0.36387, 0.18039, - -0.61335, 0.10076, 0.01381, 0.40922, -0.66783, - -0.10577, 0.93946, -0.0871 , -0.31058, 0.04677, - 0.52823, 0.31163, -0.78777, 0.02322, -0.05234, - -0.23942, -0.45801, -0.34248, 0.71286, 0.32778, - 0.26147, 0.60409, 0.39933, 0.46862, 0.43318, - 0.62118, -0.37993, 0.30992, 0.34537, -0.50444, - 0.45763, -0.42877, 0.08128, -0.3904 , 0.66912, - -0.05428, 0.53632, 0.19774, -0.32198, 0.75276, - -0.21986, -0.8214 , -0.00392, -0.1659 , 0.49944, - -0.79443, 0.1633 , -0.45374, -0.31666, -0.18989, - -0.24459, 0.10463, -0.27652, 0.85595, 0.34657, - 0.50772, 0.00757, -0.82374, -0.18941, 0.16658, - 0.49473, -0.39923, -0.20758, 0.74339, -0.01213, - -0.2024 , -0.80239, -0.35502, -0.3982 , -0.17492, - 0.68875, 0.1822 , -0.08046, -0.39238, -0.57619, - 0.34555, 0.12488, -0.50703, -0.29269, 0.72267, - -0.34713, 0.3847 , -0.7532 , 0.22176, -0.33913}); + auto expU= NDArrayFactory::create('c', {2,2,5,5}, {-0.73644, -0.10751, 0.10081, -0.00325, 0.66025,0.26329, 0.3079 , 0.38582, 0.77696, 0.28872,0.03076, 0.03015, -0.9128 , 0.36387, 0.18039, + -0.61335, 0.10076, 0.01381, 0.40922, -0.66783,-0.10577, 0.93946, -0.0871 , -0.31058, 0.04677,0.52823, 0.31163, -0.78777, 0.02322, -0.05234, + -0.23942, -0.45801, -0.34248, 0.71286, 0.32778,0.26147, 0.60409, 0.39933, 0.46862, 0.43318,0.62118, -0.37993, 0.30992, 0.34537, -0.50444, + 0.45763, -0.42877, 0.08128, -0.3904 , 0.66912,-0.05428, 0.53632, 0.19774, -0.32198, 0.75276,-0.21986, -0.8214 , -0.00392, -0.1659 , 0.49944, + -0.79443, 0.1633 , -0.45374, -0.31666, -0.18989,-0.24459, 0.10463, -0.27652, 0.85595, 0.34657,0.50772, 0.00757, -0.82374, -0.18941, 0.16658, 0.49473, -0.39923, -0.20758, 0.74339, -0.01213, + -0.2024 , -0.80239, -0.35502, -0.3982 , -0.17492,0.68875, 0.1822 , -0.08046, -0.39238, -0.57619,0.34555, 0.12488, -0.50703, -0.29269, 0.72267,-0.34713, 0.3847 , -0.7532 , 0.22176, -0.33913}); - auto expV= NDArrayFactory::create('c', {2,2,6,6}, {-4.15640000e-01, -5.30190000e-01, 5.29200000e-02, -7.15710000e-01, - -1.10690000e-01, 1.37280000e-01, - 2.86620000e-01, 5.88200000e-02, 1.68760000e-01, -2.55000000e-03, - -1.00090000e-01, 9.35890000e-01, - -4.88230000e-01, 4.84470000e-01, -1.09150000e-01, -1.46810000e-01, - 6.70320000e-01, 2.10040000e-01, - 1.00910000e-01, 4.35740000e-01, -6.90500000e-01, -3.61090000e-01, - -4.38680000e-01, 1.83200000e-02, - -5.48440000e-01, -2.86950000e-01, -4.23900000e-01, 5.78540000e-01, - -2.10060000e-01, 2.41550000e-01, - -4.42450000e-01, 4.56640000e-01, 5.48020000e-01, 3.32100000e-02, - -5.40210000e-01, -4.97000000e-02, - -6.36070000e-01, 5.57600000e-02, 3.28740000e-01, 3.81950000e-01, - -4.21850000e-01, 4.00490000e-01, - 1.83740000e-01, -1.36190000e-01, -2.29380000e-01, -5.11090000e-01, - -2.06580000e-01, 7.68890000e-01, - -4.81880000e-01, -6.31100000e-01, 3.40000000e-04, -1.35730000e-01, - 5.88210000e-01, 7.12900000e-02, - 2.25200000e-01, 4.30600000e-02, 9.08510000e-01, -3.08940000e-01, - 1.51570000e-01, 6.02100000e-02, - 1.97510000e-01, -7.26560000e-01, 1.05370000e-01, 1.10600000e-02, - -5.79750000e-01, -2.92870000e-01, - 4.89620000e-01, -2.24300000e-01, 5.31200000e-02, 6.92040000e-01, - 2.72560000e-01, 3.92350000e-01, - -6.84450000e-01, -5.18030000e-01, 2.92000000e-02, -4.96740000e-01, - -1.17970000e-01, -4.08100000e-02, - 4.25340000e-01, -1.65500000e-02, -2.82400000e-02, -5.60180000e-01, - 1.93050000e-01, -6.83340000e-01, - 8.08800000e-02, 4.38260000e-01, -2.48340000e-01, -6.36220000e-01, - 2.37500000e-02, 5.78250000e-01, - -6.10000000e-04, 3.00110000e-01, 1.17290000e-01, -6.92400000e-02, - -9.19220000e-01, -2.15420000e-01, - 5.41330000e-01, -6.61130000e-01, -2.86360000e-01, -2.13500000e-02, - -3.19580000e-01, 2.92020000e-01, - 2.25920000e-01, -1.10170000e-01, 9.17020000e-01, -1.71540000e-01, - 3.39100000e-02, 2.55590000e-01, - -4.86810000e-01, -2.32390000e-01, -4.31500000e-01, 3.75290000e-01, - 4.98470000e-01, -3.65370000e-01, - 6.39700000e-02, -4.04150000e-01, -5.28310000e-01, 8.90000000e-02, - -7.30460000e-01, -1.09390000e-01, - -4.94030000e-01, 1.55540000e-01, -3.46720000e-01, -7.58460000e-01, - 5.20000000e-04, 1.90420000e-01, - 2.55960000e-01, 3.17040000e-01, -3.47800000e-02, -3.01860000e-01, - -3.57600000e-02, -8.60450000e-01, - 1.31650000e-01, 7.57150000e-01, -4.89030000e-01, 3.47710000e-01, - -4.39400000e-02, 2.17750000e-01, - -6.57270000e-01, 2.91000000e-01, 4.17280000e-01, 2.52880000e-01, - -4.63400000e-01, -1.74620000e-01}); + auto expV= NDArrayFactory::create('c', {2,2,6,6}, {-4.15640000e-01, -5.30190000e-01, 5.29200000e-02, -7.15710000e-01,-1.10690000e-01, 1.37280000e-01,2.86620000e-01, 5.88200000e-02, 1.68760000e-01, -2.55000000e-03,-1.00090000e-01, 9.35890000e-01, + -4.88230000e-01, 4.84470000e-01, -1.09150000e-01, -1.46810000e-01,6.70320000e-01, 2.10040000e-01,1.00910000e-01, 4.35740000e-01, -6.90500000e-01, -3.61090000e-01,-4.38680000e-01, 1.83200000e-02, + -5.48440000e-01, -2.86950000e-01, -4.23900000e-01, 5.78540000e-01,-2.10060000e-01, 2.41550000e-01,-4.42450000e-01, 4.56640000e-01, 5.48020000e-01, 3.32100000e-02,-5.40210000e-01, -4.97000000e-02, + -6.36070000e-01, 5.57600000e-02, 3.28740000e-01, 3.81950000e-01,-4.21850000e-01, 4.00490000e-01,1.83740000e-01, -1.36190000e-01, -2.29380000e-01, -5.11090000e-01,-2.06580000e-01, 7.68890000e-01, + -4.81880000e-01, -6.31100000e-01, 3.40000000e-04, -1.35730000e-01,5.88210000e-01, 7.12900000e-02,2.25200000e-01, 4.30600000e-02, 9.08510000e-01, -3.08940000e-01,1.51570000e-01, 6.02100000e-02, + 1.97510000e-01, -7.26560000e-01, 1.05370000e-01, 1.10600000e-02,-5.79750000e-01, -2.92870000e-01,4.89620000e-01, -2.24300000e-01, 5.31200000e-02, 6.92040000e-01,2.72560000e-01, 3.92350000e-01, + -6.84450000e-01, -5.18030000e-01, 2.92000000e-02, -4.96740000e-01,-1.17970000e-01, -4.08100000e-02,4.25340000e-01, -1.65500000e-02, -2.82400000e-02, -5.60180000e-01,1.93050000e-01, -6.83340000e-01, + 8.08800000e-02, 4.38260000e-01, -2.48340000e-01, -6.36220000e-01,2.37500000e-02, 5.78250000e-01,-6.10000000e-04, 3.00110000e-01, 1.17290000e-01, -6.92400000e-02,-9.19220000e-01, -2.15420000e-01, + 5.41330000e-01, -6.61130000e-01, -2.86360000e-01, -2.13500000e-02,-3.19580000e-01, 2.92020000e-01,2.25920000e-01, -1.10170000e-01, 9.17020000e-01, -1.71540000e-01,3.39100000e-02, 2.55590000e-01, + -4.86810000e-01, -2.32390000e-01, -4.31500000e-01, 3.75290000e-01,4.98470000e-01, -3.65370000e-01,6.39700000e-02, -4.04150000e-01, -5.28310000e-01, 8.90000000e-02,-7.30460000e-01, -1.09390000e-01, + -4.94030000e-01, 1.55540000e-01, -3.46720000e-01, -7.58460000e-01,5.20000000e-04, 1.90420000e-01,2.55960000e-01, 3.17040000e-01, -3.47800000e-02, -3.01860000e-01,-3.57600000e-02, -8.60450000e-01, + 1.31650000e-01, 7.57150000e-01, -4.89030000e-01, 3.47710000e-01,-4.39400000e-02, 2.17750000e-01,-6.57270000e-01, 2.91000000e-01, 4.17280000e-01, 2.52880000e-01,-4.63400000e-01, -1.74620000e-01}); nd4j::ops::svd op; auto results = op.execute({&x}, {}, {1, 1, 16}); @@ -2736,75 +2680,21 @@ TEST_F(DeclarableOpsTests3, svd_test10) { 38.56369, 29.18881, 19.54565, 10.89746, 2.017 , 44.99108, 34.95059, 26.00453, 15.43898, 7.18752}); - auto expU= NDArrayFactory::create('c', {2,2,5,5}, {-0.73644, -0.10751, 0.10081, -0.00325, 0.66025, - 0.26329, 0.3079 , 0.38582, 0.77696, 0.28872, - 0.03076, 0.03015, -0.9128 , 0.36387, 0.18039, - -0.61335, 0.10076, 0.01381, 0.40922, -0.66783, - -0.10577, 0.93946, -0.0871 , -0.31058, 0.04677, - 0.52823, 0.31163, -0.78777, 0.02322, -0.05234, - -0.23942, -0.45801, -0.34248, 0.71286, 0.32778, - 0.26147, 0.60409, 0.39933, 0.46862, 0.43318, - 0.62118, -0.37993, 0.30992, 0.34537, -0.50444, - 0.45763, -0.42877, 0.08128, -0.3904 , 0.66912, - -0.05428, 0.53632, 0.19774, -0.32198, 0.75276, - -0.21986, -0.8214 , -0.00392, -0.1659 , 0.49944, - -0.79443, 0.1633 , -0.45374, -0.31666, -0.18989, - -0.24459, 0.10463, -0.27652, 0.85595, 0.34657, - 0.50772, 0.00757, -0.82374, -0.18941, 0.16658, - 0.49473, -0.39923, -0.20758, 0.74339, -0.01213, - -0.2024 , -0.80239, -0.35502, -0.3982 , -0.17492, - 0.68875, 0.1822 , -0.08046, -0.39238, -0.57619, - 0.34555, 0.12488, -0.50703, -0.29269, 0.72267, - -0.34713, 0.3847 , -0.7532 , 0.22176, -0.33913}); + auto expU= NDArrayFactory::create('c', {2,2,5,5}, {-0.73644, -0.10751, 0.10081, -0.00325, 0.66025,0.26329, 0.3079 , 0.38582, 0.77696, 0.28872,0.03076, 0.03015, -0.9128 , 0.36387, 0.18039,-0.61335, 0.10076, 0.01381, 0.40922, -0.66783, + -0.10577, 0.93946, -0.0871 , -0.31058, 0.04677,0.52823, 0.31163, -0.78777, 0.02322, -0.05234,-0.23942, -0.45801, -0.34248, 0.71286, 0.32778,0.26147, 0.60409, 0.39933, 0.46862, 0.43318, + 0.62118, -0.37993, 0.30992, 0.34537, -0.50444,0.45763, -0.42877, 0.08128, -0.3904 , 0.66912,-0.05428, 0.53632, 0.19774, -0.32198, 0.75276,-0.21986, -0.8214 , -0.00392, -0.1659 , 0.49944, + -0.79443, 0.1633 , -0.45374, -0.31666, -0.18989,-0.24459, 0.10463, -0.27652, 0.85595, 0.34657,0.50772, 0.00757, -0.82374, -0.18941, 0.16658,0.49473, -0.39923, -0.20758, 0.74339, -0.01213, + -0.2024 , -0.80239, -0.35502, -0.3982 , -0.17492,0.68875, 0.1822 , -0.08046, -0.39238, -0.57619,0.34555, 0.12488, -0.50703, -0.29269, 0.72267,-0.34713, 0.3847 , -0.7532 , 0.22176, -0.33913}); - auto expV= NDArrayFactory::create('c', {2,2,6,5}, { -4.15640000e-01, -5.30190000e-01, 5.29200000e-02, -7.15710000e-01, - -1.10690000e-01, - 2.86620000e-01, 5.88200000e-02, 1.68760000e-01, -2.55000000e-03, - -1.00090000e-01, - -4.88230000e-01, 4.84470000e-01, -1.09150000e-01, -1.46810000e-01, - 6.70320000e-01, - 1.00910000e-01, 4.35740000e-01, -6.90500000e-01, -3.61090000e-01, - -4.38680000e-01, - -5.48440000e-01, -2.86950000e-01, -4.23900000e-01, 5.78540000e-01, - -2.10060000e-01, - -4.42450000e-01, 4.56640000e-01, 5.48020000e-01, 3.32100000e-02, - -5.40210000e-01, - -6.36070000e-01, 5.57600000e-02, 3.28740000e-01, 3.81950000e-01, - -4.21850000e-01, - 1.83740000e-01, -1.36190000e-01, -2.29380000e-01, -5.11090000e-01, - -2.06580000e-01, - -4.81880000e-01, -6.31100000e-01, 3.40000000e-04, -1.35730000e-01, - 5.88210000e-01, - 2.25200000e-01, 4.30600000e-02, 9.08510000e-01, -3.08940000e-01, - 1.51570000e-01, - 1.97510000e-01, -7.26560000e-01, 1.05370000e-01, 1.10600000e-02, - -5.79750000e-01, - 4.89620000e-01, -2.24300000e-01, 5.31200000e-02, 6.92040000e-01, - 2.72560000e-01, - -6.84450000e-01, -5.18030000e-01, 2.92000000e-02, -4.96740000e-01, - -1.17970000e-01, - 4.25340000e-01, -1.65500000e-02, -2.82400000e-02, -5.60180000e-01, - 1.93050000e-01, - 8.08800000e-02, 4.38260000e-01, -2.48340000e-01, -6.36220000e-01, - 2.37500000e-02, - -6.10000000e-04, 3.00110000e-01, 1.17290000e-01, -6.92400000e-02, - -9.19220000e-01, - 5.41330000e-01, -6.61130000e-01, -2.86360000e-01, -2.13500000e-02, - -3.19580000e-01, - 2.25920000e-01, -1.10170000e-01, 9.17020000e-01, -1.71540000e-01, - 3.39100000e-02, - -4.86810000e-01, -2.32390000e-01, -4.31500000e-01, 3.75290000e-01, - 4.98470000e-01, - 6.39700000e-02, -4.04150000e-01, -5.28310000e-01, 8.90000000e-02, - -7.30460000e-01, - -4.94030000e-01, 1.55540000e-01, -3.46720000e-01, -7.58460000e-01, - 5.20000000e-04, - 2.55960000e-01, 3.17040000e-01, -3.47800000e-02, -3.01860000e-01, - -3.57600000e-02, - 1.31650000e-01, 7.57150000e-01, -4.89030000e-01, 3.47710000e-01, - -4.39400000e-02, - -6.57270000e-01, 2.91000000e-01, 4.17280000e-01, 2.52880000e-01, - -4.63400000e-01}); + auto expV= NDArrayFactory::create('c', {2,2,6,5}, { -4.15640000e-01, -5.30190000e-01, 5.29200000e-02, -7.15710000e-01,-1.10690000e-01,2.86620000e-01, 5.88200000e-02, 1.68760000e-01, -2.55000000e-03,-1.00090000e-01, + -4.88230000e-01, 4.84470000e-01, -1.09150000e-01, -1.46810000e-01,6.70320000e-01,1.00910000e-01, 4.35740000e-01, -6.90500000e-01, -3.61090000e-01,-4.38680000e-01,-5.48440000e-01, -2.86950000e-01, -4.23900000e-01, 5.78540000e-01, + -2.10060000e-01,-4.42450000e-01, 4.56640000e-01, 5.48020000e-01, 3.32100000e-02,-5.40210000e-01,-6.36070000e-01, 5.57600000e-02, 3.28740000e-01, 3.81950000e-01,-4.21850000e-01, + 1.83740000e-01, -1.36190000e-01, -2.29380000e-01, -5.11090000e-01,-2.06580000e-01,-4.81880000e-01, -6.31100000e-01, 3.40000000e-04, -1.35730000e-01,5.88210000e-01,2.25200000e-01, 4.30600000e-02, 9.08510000e-01, -3.08940000e-01, + 1.51570000e-01,1.97510000e-01, -7.26560000e-01, 1.05370000e-01, 1.10600000e-02,-5.79750000e-01,4.89620000e-01, -2.24300000e-01, 5.31200000e-02, 6.92040000e-01,2.72560000e-01, + -6.84450000e-01, -5.18030000e-01, 2.92000000e-02, -4.96740000e-01,-1.17970000e-01,4.25340000e-01, -1.65500000e-02, -2.82400000e-02, -5.60180000e-01,1.93050000e-01,8.08800000e-02, 4.38260000e-01, -2.48340000e-01, -6.36220000e-01,2.37500000e-02,-6.10000000e-04, 3.00110000e-01, 1.17290000e-01, -6.92400000e-02,-9.19220000e-01, + 5.41330000e-01, -6.61130000e-01, -2.86360000e-01, -2.13500000e-02,-3.19580000e-01,2.25920000e-01, -1.10170000e-01, 9.17020000e-01, -1.71540000e-01,3.39100000e-02,-4.86810000e-01, -2.32390000e-01, -4.31500000e-01, 3.75290000e-01,4.98470000e-01,6.39700000e-02, -4.04150000e-01, -5.28310000e-01, 8.90000000e-02,-7.30460000e-01, + -4.94030000e-01, 1.55540000e-01, -3.46720000e-01, -7.58460000e-01,5.20000000e-04,2.55960000e-01, 3.17040000e-01, -3.47800000e-02, -3.01860000e-01,-3.57600000e-02,1.31650000e-01, 7.57150000e-01, -4.89030000e-01, 3.47710000e-01, + -4.39400000e-02,-6.57270000e-01, 2.91000000e-01, 4.17280000e-01, 2.52880000e-01,-4.63400000e-01}); nd4j::ops::svd op; auto results = op.execute({&x}, {}, {0, 1, 16}); @@ -2865,8 +2755,36 @@ TEST_F(DeclarableOpsTests3, svd_test11) { ASSERT_TRUE(expV.isSameShape(v)); ASSERT_TRUE(expS.equalsTo(s)); - ASSERT_TRUE(expU.equalsTo(u)); - ASSERT_TRUE(expV.equalsTo(v)); + + if(nd4j::Environment::getInstance()->isCPU()) { + ASSERT_TRUE(expU.equalsTo(u)); + ASSERT_TRUE(expV.equalsTo(v)); + } + else { + for(uint i = 0; i < expU.lengthOf(); ++i) + ASSERT_NEAR(nd4j::math::nd4j_abs(expU.e(i)), nd4j::math::nd4j_abs(u->e(i)), 1e-5); + for(uint i = 0; i < expV.lengthOf(); ++i) + ASSERT_NEAR(nd4j::math::nd4j_abs(expV.e(i)), nd4j::math::nd4j_abs(v->e(i)), 1e-5); + } + + delete results; +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests3, svd_test12) { + + NDArray x('c', {4,3}, {1.7787856,0.80119777,0.72437465,0.23089433,1.7271413,0.18039072,0.50563407,0.89252293,1.5461209,0.92336726,0.085571885,0.79378015}); + NDArray expS('c', {3}, {3.024703, 1.459483, 1.026371}); + + nd4j::ops::svd op; + auto results = op.execute({&x}, {}, {1, 0, 16}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto s = results->at(0); + + ASSERT_TRUE(expS.equalsTo(s)); + ASSERT_TRUE(expS.isSameShape(s)); delete results; } diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests5.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests5.cpp index c30ad5f89..6d85feec1 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests5.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests5.cpp @@ -423,7 +423,7 @@ TEST_F(DeclarableOpsTests5, Log1p_test1) { // auto eps = NDArrayFactory::create('c', {3, 3}, {1,2,3,4,5,6,7,8,9}); // auto exp = NDArrayFactory::create('c', {3,3}); nd4j::ops::Log1p op; - y.applyTransform(nd4j::transform::Log, nullptr, nullptr); + y.applyTransform(nd4j::transform::Log, y); auto result = op.execute({&matrix}, {}, {}, {}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); @@ -2737,7 +2737,7 @@ TEST_F(DeclarableOpsTests5, ELU_1) { auto exp = NDArrayFactory::create('c', {2, 2, 2}, { -0.63212055, 2. , 1.5, -0.753403, 1., 2., 2., 1.}); auto res = NDArrayFactory::create('c', {2, 2, 2}); - input.applyScalar(nd4j::scalar::ELU, 1.f, &res); + input.applyScalar(nd4j::scalar::ELU, 1.f, res); ASSERT_TRUE(res.equalsTo(&exp)); } diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests6.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests6.cpp index 67cd56d5e..c52191b8a 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests6.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests6.cpp @@ -139,7 +139,7 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_04) { auto ones = onesRes->at(0); *ones *= 10; - auto onesD = ones->dup(); + auto onesD = new NDArray(ones->dup()); auto variableSpace = new VariableSpace(); variableSpace->putVariable(-1, onesD); @@ -1577,31 +1577,31 @@ TEST_F(DeclarableOpsTests6, LogDet_3) { TEST_F(DeclarableOpsTests6, MatrixInverse_1) { auto x = NDArrayFactory::create('c', {2, 5, 5}, { - 2.f, 4.f, 60.f, 8.f, 10.f, - 0.f, 1.f, 2.f, 3.f, 4.f, - 0.f, 0.f, 2.f, 4.f, 6.f, - 0.f, 0.f, 0.f, 1.f, 2.f, - 0.f, 0.f, 0.f, 0.f, 4.f, + 2.f, 4.f, 60.f, 8.f, 10.f, + 0.f, 1.f, 2.f, 3.f, 4.f, + 0.f, 0.f, 2.f, 4.f, 6.f, + 0.f, 0.f, 0.f, 1.f, 2.f, + 0.f, 0.f, 0.f, 0.f, 4.f, - 1.f, 0.f, 0.f, 0.f, 0.f, - 2.f, 1.f, 0.f, 0.f, 0.f, - 30.f, 2.f, 1.f, 0.f, 0.f, - 4.f, 3.f, 2.f, 1.f, 0.f, + 1.f, 0.f, 0.f, 0.f, 0.f, + 2.f, 1.f, 0.f, 0.f, 0.f, + 30.f, 2.f, 1.f, 0.f, 0.f, + 4.f, 3.f, 2.f, 1.f, 0.f, 5.f, 4.f, 3.f, 2.f, 1.f }); auto exp = NDArrayFactory::create('c', {2, 5, 5}, { - 0.5f, -2.0f, -13.0f, 54.0f, -6.75f, - 0.0f, 1.0f, -1.0f, 1.0f, 0.0f, + 0.5f, -2.0f, -13.0f, 54.0f, -6.75f, + 0.0f, 1.0f, -1.0f, 1.0f, 0.0f, 0.f, 0.f, 0.5f, -2.0f, 0.25f, 0.f, 0.f, 0.f, 1.0f, -0.5f, - 0.f, 0.f, 0.f, 0.f, 0.25f, + 0.f, 0.f, 0.f, 0.f, 0.25f, - 1.0f, 0.0f, 0.0f, 0.0f, 0.f, - -2.0f, 1.0f, 0.f, 0.f, 0.f, + 1.0f, 0.0f, 0.0f, 0.0f, 0.f, + -2.0f, 1.0f, 0.f, 0.f, 0.f, -26.0f, -2.0f, 1.f, 0.f, 0.f, 54.0f, 1.0f, -2.0f, 1.f, 0.f, - -27.0f, 0.0f, 1.0f, -2.0f, 1.f, + -27.0f, 0.0f, 1.0f, -2.0f, 1.f, }); nd4j::ops::matrix_inverse op; @@ -1891,10 +1891,8 @@ TEST_F(DeclarableOpsTests6, Test_Reduce3_Edge) { std::vector dims = {0, 1}; - auto z = x.applyReduce3(reduce3::CosineSimilarity, &y, dims, nullptr); - ASSERT_TRUE(z != nullptr); - - delete z; + auto z = x.applyReduce3(reduce3::CosineSimilarity, y, dims); + ASSERT_TRUE(&z != nullptr); } /////////////////////////////////////////////////////////////////// diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp index 220191011..ffb847dbd 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp @@ -3148,14 +3148,7 @@ TEST_F(DeclarableOpsTests7, TestExtractImagePatches_SGO_13) { auto result = op.execute({&x}, {}, {2,2, 1,1, 1,1, 1}); // equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], rates=[1,1,1,1], padding="SAME" ASSERT_EQ(result->status(), Status::OK()); auto output = result->at(0); -// output->printShapeInfo("Output shape"); -// output->printBuffer("Output"); -// exp.printBuffer("Expect"); -// for (Nd4jLong e = 0; e < exp.lengthOf(); e++) -// if (exp.e(e) != output->e(e)) -// printf("%lld ", e); -// printf("\n"); - //result->at(1)->printBuffer("OUtput2"); + ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); @@ -3240,10 +3233,6 @@ auto exp = NDArrayFactory::create('c', {2, 2, 4, 2}, { 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, 21.11, 21.12, 21.21, 21.22, 21.31, 21.32, 21.41, 21.42, 22.11, 22.12 }); - -// 22.21, 22.22, 22.31, 22.32, 22.41, 22.42, 11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42, -// 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, 21.11, 21.12, 21.21, 21.22, 21.31, 21.32 -// 21.41, 21.42, 22.11, 22.12 // ---------------------------------------------------------------- nd4j::ops::roll op; @@ -3269,10 +3258,6 @@ auto exp = NDArrayFactory::create('c', {2, 2, 4, 2}, { 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, 21.11, 21.12, 21.21, 21.22, 21.31, 21.32, 21.41, 21.42, 22.11, 22.12 }); - -// 22.21, 22.22, 22.31, 22.32, 22.41, 22.42, 11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42, -// 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, 21.11, 21.12, 21.21, 21.22, 21.31, 21.32 -// 21.41, 21.42, 22.11, 22.12 // ---------------------------------------------------------------- nd4j::ops::roll op; NDArray* y = nullptr; @@ -3518,6 +3503,27 @@ TEST_F(DeclarableOpsTests7, TestRoll_14) { delete result; } +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestRoll_15) { + auto x = NDArrayFactory::create({0.7788f, 0.8012f, 0.7244f, 0.2309f }); + auto shift = NDArrayFactory::create(2); + auto axis = NDArrayFactory::create(0); + + auto exp = NDArrayFactory::create({0.7244f, 0.2309f, 0.7788f, 0.8012f }); +// ---------------------------------------------------------------- + nd4j::ops::roll op; + + auto result = op.execute({&x, &shift, &axis}, {}, {}, {}, false, nd4j::DataType::FLOAT32); + ASSERT_EQ(result->status(), Status::OK()); + auto out = result->at(0); +// out->printIndexedBuffer("Output 15"); +// exp.printIndexedBuffer("Expect 15"); + + ASSERT_TRUE(exp.equalsTo(out)); + + delete result; +} + //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, percentile_test1) { diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests8.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests8.cpp index ebe1f8e18..ef495142d 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests8.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests8.cpp @@ -3029,13 +3029,13 @@ TEST_F(DeclarableOpsTests8, NormalizeMoments_SGO_1) { // auto expDeviance = NDArrayFactory::create('c', {10, 10}); auto squared = NDArrayFactory::create('c', {10, 10}); - data.applyTransform(transform::Square, &squared, nullptr); + data.applyTransform(transform::Square, squared); auto ssSquared = squared.reduceAlongDimension(reduce::Sum, {0}); // ssSquared->printBuffer("Sum squared"); // squared.printBuffer("Squared"); nd4j::ops::normalize_moments op; - auto results = op.execute({&counts, means, ssSquared}, {0.0}, {0}); - (*means) /= counts; + auto results = op.execute({&counts, &means, &ssSquared}, {0.0}, {0}); + means /= counts; // nd4j::ops::normalize_moments op; // auto results = op.execute({&counts, means, deviance}, {0.0}, {}); @@ -3049,13 +3049,11 @@ TEST_F(DeclarableOpsTests8, NormalizeMoments_SGO_1) { // outputDeviance->printIndexedBuffer("Variance"); // deviance.printIndexedBuffer("Expected"); // means->printIndexedBuffer("Expected means"); - ASSERT_TRUE(means->isSameShape(outputMeans)); - ASSERT_TRUE(means->equalsTo(outputMeans)); + ASSERT_TRUE(means.isSameShape(outputMeans)); + ASSERT_TRUE(means.equalsTo(outputMeans)); ASSERT_TRUE(deviance.isSameShape(outputDeviance)); ASSERT_TRUE(deviance.equalsTo(outputDeviance)); - delete means; //delete deviance; - delete ssSquared; // ASSERT_TRUE(expMeans.isSameShape(outputMeans)); // ASSERT_TRUE(expMeans.equalsTo(outputMeans)); // ASSERT_TRUE(expMeans.isSameShape(outputDeviance)); @@ -3636,60 +3634,60 @@ TYPED_TEST(TypedDeclarableOpsTests8, LrnTest_BP_2) { auto x = NDArrayFactory::create( 'c', {3, 3, 5, 5}); x.linspace(1); - auto eps = NDArrayFactory::create('c', {3, 3, 5, 5}, { 0.2581989f, 0.3592106f, 0.40089184f, 0.53935987f, 0.70014f, - 0.4898979f, 0.46056613f, 0.43971977f, 0.5240002f, 0.6375767f, - 0.5274096f, 0.47771242f, 0.4443308f, 0.5163977f, 0.61701745f, - 0.5424508f, 0.48452914f, 0.44570294f, 0.5123918f, 0.6068971f, - 0.5505386f, 0.4881662f, 0.4462865f, 0.5099462f, 0.60088515f, + auto eps = NDArrayFactory::create('c', {3, 3, 5, 5}, { 0.2581989f, 0.3592106f, 0.40089184f, 0.53935987f, 0.70014f, + 0.4898979f, 0.46056613f, 0.43971977f, 0.5240002f, 0.6375767f, + 0.5274096f, 0.47771242f, 0.4443308f, 0.5163977f, 0.61701745f, + 0.5424508f, 0.48452914f, 0.44570294f, 0.5123918f, 0.6068971f, + 0.5505386f, 0.4881662f, 0.4462865f, 0.5099462f, 0.60088515f, - 0.5555859f, 0.49042296f, 0.44658744f, 0.5083028f, 0.59690416f, - 0.55903524f, 0.4919585f, 0.44676256f, 0.5071239f, 0.59407425f, - 0.5615412f, 0.49307042f, 0.44687328f, 0.50623745f, 0.5919596f, - 0.56344414f, 0.49391258f, 0.4469477f, 0.5055468f, 0.59031945f, - 0.56493837f, 0.49457246f, 0.4470002f, 0.5049936f, 0.5890103f, + 0.5555859f, 0.49042296f, 0.44658744f, 0.5083028f, 0.59690416f, + 0.55903524f, 0.4919585f, 0.44676256f, 0.5071239f, 0.59407425f, + 0.5615412f, 0.49307042f, 0.44687328f, 0.50623745f, 0.5919596f, + 0.56344414f, 0.49391258f, 0.4469477f, 0.5055468f, 0.59031945f, + 0.56493837f, 0.49457246f, 0.4470002f, 0.5049936f, 0.5890103f, - 0.56614274f, 0.49510333f, 0.44703856f, 0.50454074f, 0.5879411f, - 0.567134f, 0.49553978f, 0.4470674f, 0.504163f, 0.5870515f, - 0.5679643f, 0.4959048f, 0.44708967f, 0.5038433f, 0.5862998f, - 0.56866974f, 0.4962146f, 0.44710726f, 0.5035692f, 0.58565617f, - 0.56927663f, 0.49648085f, 0.4471213f, 0.5033315f, 0.5850988f, + 0.56614274f, 0.49510333f, 0.44703856f, 0.50454074f, 0.5879411f, + 0.567134f, 0.49553978f, 0.4470674f, 0.504163f, 0.5870515f, + 0.5679643f, 0.4959048f, 0.44708967f, 0.5038433f, 0.5862998f, + 0.56866974f, 0.4962146f, 0.44710726f, 0.5035692f, 0.58565617f, + 0.56927663f, 0.49648085f, 0.4471213f, 0.5033315f, 0.5850988f, - 0.56980413f, 0.49671215f, 0.44713274f, 0.50312346f, 0.58461165f, - 0.57026696f, 0.49691492f, 0.4471422f, 0.50293994f, 0.58418214f, - 0.5706764f, 0.49709415f, 0.44715008f, 0.5027767f, 0.5838005f, - 0.571041f, 0.4972537f, 0.44715673f, 0.50263065f, 0.58345926f, - 0.57136786f, 0.49739665f, 0.44716236f, 0.5024992f, 0.58315235f, + 0.56980413f, 0.49671215f, 0.44713274f, 0.50312346f, 0.58461165f, + 0.57026696f, 0.49691492f, 0.4471422f, 0.50293994f, 0.58418214f, + 0.5706764f, 0.49709415f, 0.44715008f, 0.5027767f, 0.5838005f, + 0.571041f, 0.4972537f, 0.44715673f, 0.50263065f, 0.58345926f, + 0.57136786f, 0.49739665f, 0.44716236f, 0.5024992f, 0.58315235f, - 0.5716625f, 0.49752548f, 0.4471672f, 0.5023803f, 0.5828747f, - 0.5719295f, 0.49764213f, 0.44717142f, 0.5022721f, 0.5826225f, - 0.57217246f, 0.49774826f, 0.44717506f, 0.5021734f, 0.58239233f, - 0.5723947f, 0.4978453f, 0.44717824f, 0.5020829f, 0.58218133f, - 0.57259864f, 0.49793428f, 0.44718108f, 0.5019997f, 0.5819874f, + 0.5716625f, 0.49752548f, 0.4471672f, 0.5023803f, 0.5828747f, + 0.5719295f, 0.49764213f, 0.44717142f, 0.5022721f, 0.5826225f, + 0.57217246f, 0.49774826f, 0.44717506f, 0.5021734f, 0.58239233f, + 0.5723947f, 0.4978453f, 0.44717824f, 0.5020829f, 0.58218133f, + 0.57259864f, 0.49793428f, 0.44718108f, 0.5019997f, 0.5819874f, - 0.5727864f, 0.49801624f, 0.44718358f, 0.5019227f, 0.5818083f, - 0.57296f, 0.49809194f, 0.44718578f, 0.5018515f, 0.5816426f, - 0.5731208f, 0.49816203f, 0.44718775f, 0.5017854f, 0.58148885f, - 0.57327026f, 0.49822718f, 0.4471895f, 0.5017239f, 0.5813457f, - 0.57340944f, 0.49828786f, 0.44719115f, 0.5016664f, 0.581212f, + 0.5727864f, 0.49801624f, 0.44718358f, 0.5019227f, 0.5818083f, + 0.57296f, 0.49809194f, 0.44718578f, 0.5018515f, 0.5816426f, + 0.5731208f, 0.49816203f, 0.44718775f, 0.5017854f, 0.58148885f, + 0.57327026f, 0.49822718f, 0.4471895f, 0.5017239f, 0.5813457f, + 0.57340944f, 0.49828786f, 0.44719115f, 0.5016664f, 0.581212f, - 0.57353944f, 0.4983446f, 0.44719255f, 0.50161266f, 0.58108705f, - 0.5736612f, 0.49839762f, 0.4471939f, 0.50156236f, 0.5809699f, - 0.5737754f, 0.4984474f, 0.44719502f, 0.501515f, 0.58085984f, - 0.5738828f, 0.49849418f, 0.4471962f, 0.50147045f, 0.5807564f, - 0.5739839f, 0.49853817f, 0.44719717f, 0.5014284f, 0.5806588f, + 0.57353944f, 0.4983446f, 0.44719255f, 0.50161266f, 0.58108705f, + 0.5736612f, 0.49839762f, 0.4471939f, 0.50156236f, 0.5809699f, + 0.5737754f, 0.4984474f, 0.44719502f, 0.501515f, 0.58085984f, + 0.5738828f, 0.49849418f, 0.4471962f, 0.50147045f, 0.5807564f, + 0.5739839f, 0.49853817f, 0.44719717f, 0.5014284f, 0.5806588f, - 0.5740793f, 0.49857965f, 0.4471981f, 0.5013887f, 0.5805666f, - 0.5741694f, 0.49861887f, 0.44719887f, 0.50135124f, 0.58047944f, - 0.57425463f, 0.49865603f, 0.44719967f, 0.5013157f, 0.5803969f, - 0.5743354f, 0.4986912f, 0.44720036f, 0.5012819f, 0.5803186f, - 0.57441217f, 0.49872455f, 0.44720104f, 0.5012499f, 0.58024424f, + 0.5740793f, 0.49857965f, 0.4471981f, 0.5013887f, 0.5805666f, + 0.5741694f, 0.49861887f, 0.44719887f, 0.50135124f, 0.58047944f, + 0.57425463f, 0.49865603f, 0.44719967f, 0.5013157f, 0.5803969f, + 0.5743354f, 0.4986912f, 0.44720036f, 0.5012819f, 0.5803186f, + 0.57441217f, 0.49872455f, 0.44720104f, 0.5012499f, 0.58024424f, - 0.57448506f, 0.4987563f, 0.44720164f, 0.5012194f, 0.58017343f, - 0.57455444f, 0.4987865f, 0.4472022f, 0.5011904f, 0.5801061f, - 0.57462054f, 0.49881527f, 0.44720277f, 0.5011627f, 0.5800419f, - 0.57468355f, 0.49884263f, 0.44720328f, 0.50113624f, 0.5799805f, + 0.57448506f, 0.4987563f, 0.44720164f, 0.5012194f, 0.58017343f, + 0.57455444f, 0.4987865f, 0.4472022f, 0.5011904f, 0.5801061f, + 0.57462054f, 0.49881527f, 0.44720277f, 0.5011627f, 0.5800419f, + 0.57468355f, 0.49884263f, 0.44720328f, 0.50113624f, 0.5799805f, 0.57474375f, 0.49886885f, 0.44720373f, 0.50111103f, 0.5799219f }); // auto exp = NDArrayFactory::create('c', {3,3,5,5}, { diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp index dfbfc90a8..6df52fb54 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp @@ -228,7 +228,7 @@ TEST_F(DeclarableOpsTests9, ScalarOpTest_MixedOrders_1) { auto e = NDArrayFactory::create('c', {2, 2}, {2.0, 3.0, 4.0, 5.0}); auto z = NDArrayFactory::create('c', {2, 2}, {0.0, 0.0, 0.0, 0.0}); - x.applyScalar(scalar::Add, 1.0, &z); + x.applyScalar(scalar::Add, 1.0, z); ASSERT_EQ(e, z); } @@ -634,10 +634,7 @@ TEST_F(DeclarableOpsTests9, concat_test18) { for (int e = 0; e < 2000; e++) { auto row = z.tensorAlongDimension(e, {1}); - - ASSERT_NEAR((float) e, row->e(0), 1e-5f); - - delete row; + ASSERT_NEAR((float) e, row.e(0), 1e-5f); } } @@ -1684,7 +1681,7 @@ TEST_F(DeclarableOpsTests9, test_broadcast_bool_1) { auto z = NDArrayFactory::create('c', {1, 3, 2, 4, 4}); std::vector dims = {0, 2, 3, 4}; - x.applyBroadcast(broadcast::LessThan, dims, &y, &z, nullptr); + x.applyBroadcast(broadcast::LessThan, dims, y, z); } TEST_F(DeclarableOpsTests9, test_broadcast_bool_2) { @@ -1697,7 +1694,7 @@ TEST_F(DeclarableOpsTests9, test_broadcast_bool_2) { auto z = NDArrayFactory::create('c', {1, 3, 2, 4, 4}); std::vector dims = {0, 2, 3, 4}; - x.applyBroadcast(broadcast::LessThan, dims, &y, &z, nullptr); + x.applyBroadcast(broadcast::LessThan, dims, y, z); } @@ -1746,7 +1743,7 @@ TEST_F(DeclarableOpsTests9, clipbynorm_test12) { auto colVect = NDArrayFactory::create('c', {bS, 1}, {0.9, 0.95, 1.00, 1.05, 1.1}); auto expect = NDArrayFactory::create('c', {bS, nOut}); - auto norm2 = x.reduceAlongDims(reduce::Norm2, {axis}, true); // norm2 has shape [1, nOut] + auto norm2 = x.reduceAlongDimension(reduce::Norm2, {axis}, true); // norm2 has shape [1, nOut] auto y = ( (x / norm2) * clip) * colVect ; auto temp = (x / norm2) * clip; @@ -2927,13 +2924,13 @@ TEST_F(DeclarableOpsTests9, batchnorm_bp_test1) { auto dLdG = results->at(3); auto dLdB = results->at(4); - ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI)); + ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI)); ASSERT_TRUE(expdLdI.equalsTo(dLdI)); - ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG)); + ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG)); ASSERT_TRUE(expdLdG.equalsTo(dLdG)); - ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB)); + ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB)); ASSERT_TRUE(expdLdB.equalsTo(dLdB)); delete results; @@ -2970,13 +2967,13 @@ TEST_F(DeclarableOpsTests9, batchnorm_bp_test2) { auto dLdG = results->at(3); auto dLdB = results->at(4); - ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI)); + ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI)); ASSERT_TRUE(expdLdI.equalsTo(dLdI)); - ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG)); + ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG)); ASSERT_TRUE(expdLdG.equalsTo(dLdG)); - ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB)); + ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB)); ASSERT_TRUE(expdLdB.equalsTo(dLdB)); delete results; @@ -3012,13 +3009,13 @@ TEST_F(DeclarableOpsTests9, batchnorm_bp_test3) { auto dLdG = results->at(3); auto dLdB = results->at(4); - ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI)); + ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI)); ASSERT_TRUE(expdLdI.equalsTo(dLdI)); - ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG)); + ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG)); ASSERT_TRUE(expdLdG.equalsTo(dLdG)); - ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB)); + ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB)); ASSERT_TRUE(expdLdB.equalsTo(dLdB)); delete results; @@ -3051,13 +3048,13 @@ TEST_F(DeclarableOpsTests9, batchnorm_bp_test4) { auto dLdG = results->at(3); auto dLdB = results->at(4); - ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI)); + ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI)); ASSERT_TRUE(expdLdI.equalsTo(dLdI)); - ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG)); + ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG)); ASSERT_TRUE(expdLdG.equalsTo(dLdG)); - ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB)); + ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB)); ASSERT_TRUE(expdLdB.equalsTo(dLdB)); delete results; @@ -3092,13 +3089,13 @@ TEST_F(DeclarableOpsTests9, batchnorm_bp_test5) { auto dLdG = results->at(3); auto dLdB = results->at(4); - ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI)); + ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI)); ASSERT_TRUE(expdLdI.equalsTo(dLdI)); - ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG)); + ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG)); ASSERT_TRUE(expdLdG.equalsTo(dLdG)); - ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB)); + ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB)); ASSERT_TRUE(expdLdB.equalsTo(dLdB)); delete results; @@ -3133,13 +3130,13 @@ TEST_F(DeclarableOpsTests9, batchnorm_bp_test6) { auto dLdG = results->at(3); auto dLdB = results->at(4); - ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI)); + ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI)); ASSERT_TRUE(expdLdI.equalsTo(dLdI)); - ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG)); + ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG)); ASSERT_TRUE(expdLdG.equalsTo(dLdG)); - ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB)); + ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB)); ASSERT_TRUE(expdLdB.equalsTo(dLdB)); delete results; @@ -3179,13 +3176,13 @@ TEST_F(DeclarableOpsTests9, batchnorm_bp_test7) { // dLdI->printBuffer(); - ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI)); + ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI)); ASSERT_TRUE(expdLdI.equalsTo(dLdI)); - ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG)); + ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG)); ASSERT_TRUE(expdLdG.equalsTo(dLdG)); - ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB)); + ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB)); ASSERT_TRUE(expdLdB.equalsTo(dLdB)); delete results; @@ -3224,13 +3221,13 @@ TEST_F(DeclarableOpsTests9, batchnorm_bp_test8) { // dLdI->printBuffer(); - ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI)); + ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI)); ASSERT_TRUE(expdLdI.equalsTo(dLdI)); - ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG)); + ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG)); ASSERT_TRUE(expdLdG.equalsTo(dLdG)); - ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB)); + ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB)); ASSERT_TRUE(expdLdB.equalsTo(dLdB)); delete results; diff --git a/libnd4j/tests_cpu/layers_tests/EmptyTests.cpp b/libnd4j/tests_cpu/layers_tests/EmptyTests.cpp index 8ae123260..12069c67e 100644 --- a/libnd4j/tests_cpu/layers_tests/EmptyTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/EmptyTests.cpp @@ -197,7 +197,7 @@ TEST_F(EmptyTests, Test_Reshape_3) { TEST_F(EmptyTests, Test_dup_1) { auto empty = NDArrayFactory::empty(); - auto dup = empty.dup(); + auto dup = new NDArray(empty.dup()); ASSERT_TRUE(dup->isEmpty()); ASSERT_EQ(empty, *dup); @@ -286,4 +286,72 @@ TEST_F(EmptyTests, test_shaped_empty_4) { ASSERT_TRUE(array.isEmpty()); ASSERT_EQ(1, array.rankOf()); ASSERT_EQ(shapeOf, array.getShapeAsVector()); -} \ No newline at end of file +} + +TEST_F(EmptyTests, test_empty_reshape_1) { + /* + INDArray arr0 = Nd4j.create(DataType.FLOAT, 2, 0); + INDArray arr1 = Nd4j.create(DataType.FLOAT, 0, 1, 2); + + INDArray out0 = Nd4j.exec(new Reshape(arr0, Nd4j.createFromArray(2, 0, -1), Nd4j.create(DataType.FLOAT, 2, 0, 0)))[0]; + INDArray out1 = Nd4j.exec(new Reshape(arr1, Nd4j.createFromArray(-1, 1), Nd4j.create(DataType.FLOAT, 0, 1)))[0]; + INDArray out2 = Nd4j.exec(new Reshape(arr1, Nd4j.createFromArray(10, -1), Nd4j.create(DataType.FLOAT, 10, 0)))[0]; + + assertArrayEquals(new long[]{2, 0, 0}, out0.shape()); + assertArrayEquals(new long[]{0, 1}, out1.shape()); + assertArrayEquals(new long[]{10, 0}, out2.shape()); + */ + auto x0 = NDArrayFactory::create('c', {2, 0}); + auto x1 = NDArrayFactory::create('c', {0, 1, 2}); + + auto shape0 = NDArrayFactory::create('c', {3}, {2, 0, -1}); + auto shape1 = NDArrayFactory::create('c', {2}, {-1, 1}); + + auto e0 = NDArrayFactory::create('c', {2, 0, 0}); + auto e1 = NDArrayFactory::create('c', {0, 1}); + + nd4j::ops::reshape op; + auto result0 = op.execute({&x0, &shape0}, {}, {}); + ASSERT_EQ(Status::OK(), result0->status()); + auto z0 = result0->at(0); + ASSERT_EQ(e0, *z0); + + auto result1 = op.execute({&x1, &shape1}, {}, {}); + ASSERT_EQ(Status::OK(), result1->status()); + auto z1 = result1->at(0); + ASSERT_EQ(e1, *z1); + + delete result0; + delete result1; +} + + +TEST_F(EmptyTests, test_empty_matmul_1) { + auto x = NDArrayFactory::create('c', {0, 1}); + auto y = NDArrayFactory::create('c', {1, 0}); + auto e = NDArrayFactory::create('c', {0, 0}); + + nd4j::ops::matmul op; + auto result = op.execute({&x, &y}, {}, {}); + ASSERT_EQ(Status::OK(), result->status()); + + auto z = result->at(0); + ASSERT_EQ(e, *z); + + delete result; +} + +TEST_F(EmptyTests, test_empty_matmul_2) { + auto x = NDArrayFactory::create('c', {1, 0, 4}); + auto y = NDArrayFactory::create('c', {1, 4, 0}); + auto e = NDArrayFactory::create('c', {1, 0, 0}); + + nd4j::ops::matmul op; + auto result = op.execute({&x, &y}, {}, {}); + ASSERT_EQ(Status::OK(), result->status()); + + auto z = result->at(0); + ASSERT_EQ(e, *z); + + delete result; +} diff --git a/libnd4j/tests_cpu/layers_tests/HelpersTests1.cpp b/libnd4j/tests_cpu/layers_tests/HelpersTests1.cpp index 085127e74..0bf9a1eb7 100644 --- a/libnd4j/tests_cpu/layers_tests/HelpersTests1.cpp +++ b/libnd4j/tests_cpu/layers_tests/HelpersTests1.cpp @@ -69,8 +69,7 @@ TEST_F(HelpersTests1, evalHHmatrix_test1) { auto exp = NDArrayFactory::create('c', {4,4}, {-0.629253, -0.764093, -0.13484, -0.0449467, -0.764093, 0.641653, -0.0632377, -0.0210792, -0.13484,-0.0632377, 0.98884,-0.00371987, -0.0449467,-0.0210792,-0.00371987, 0.99876}); auto result = ops::helpers::Householder::evalHHmatrix(x); - - ASSERT_TRUE(result.isSameShapeStrict(&exp)); + ASSERT_TRUE(result.isSameShape(&exp)); ASSERT_TRUE(result.equalsTo(&exp)); } @@ -86,7 +85,7 @@ TEST_F(HelpersTests1, evalHHmatrix_test2) { auto result = ops::helpers::Householder::evalHHmatrix(x); - ASSERT_TRUE(result.isSameShapeStrict(&exp)); + ASSERT_TRUE(result.isSameShape(&exp)); ASSERT_TRUE(result.equalsTo(&exp)); } @@ -109,7 +108,7 @@ TEST_F(HelpersTests1, evalHHmatrixData_test1) { ASSERT_NEAR(normX, normXExpected, 1e-5); ASSERT_NEAR(coeff, coeffExpected, 1e-5); - ASSERT_TRUE(tail.isSameShapeStrict(&expTail)); + ASSERT_TRUE(tail.isSameShapeStrict(expTail)); ASSERT_TRUE(tail.equalsTo(&expTail)); } @@ -128,7 +127,7 @@ TEST_F(HelpersTests1, Householder_mulLeft_test1) { ops::helpers::Householder::mulLeft(x, tail, 0.1); // expTail.printShapeInfo(); - ASSERT_TRUE(x.isSameShapeStrict(&exp)); + ASSERT_TRUE(x.isSameShapeStrict(exp)); ASSERT_TRUE(x.equalsTo(&exp)); } @@ -145,7 +144,7 @@ TEST_F(HelpersTests1, Householder_mulLeft_test2) { ops::helpers::Householder::mulLeft(x, tail, 0.1); - ASSERT_TRUE(x.isSameShapeStrict(&exp)); + ASSERT_TRUE(x.isSameShapeStrict(exp)); ASSERT_TRUE(x.equalsTo(&exp)); } @@ -162,7 +161,7 @@ TEST_F(HelpersTests1, Householder_mulRight_test1) { ops::helpers::Householder::mulRight(x, tail, 0.1); - ASSERT_TRUE(x.isSameShapeStrict(&exp)); + ASSERT_TRUE(x.isSameShapeStrict(exp)); ASSERT_TRUE(x.equalsTo(&exp)); } @@ -181,9 +180,9 @@ TEST_F(HelpersTests1, BiDiagonalizeUp_test1) { ops::helpers::BiDiagonalUp object(matrix); // object._HHmatrix.printBuffer(); - ASSERT_TRUE(hhMatrixExp.isSameShapeStrict(&object._HHmatrix)); + ASSERT_TRUE(hhMatrixExp.isSameShapeStrict(object._HHmatrix)); ASSERT_TRUE(hhMatrixExp.equalsTo(&object._HHmatrix)); - ASSERT_TRUE(hhBidiagExp.isSameShapeStrict(&object._HHbidiag)); + ASSERT_TRUE(hhBidiagExp.isSameShapeStrict(object._HHbidiag)); ASSERT_TRUE(hhBidiagExp.equalsTo(&object._HHbidiag)); } @@ -200,9 +199,9 @@ TEST_F(HelpersTests1, BiDiagonalizeUp_test2) { ops::helpers::BiDiagonalUp object(matrix); // object._HHmatrix.printBuffer(); - ASSERT_TRUE(hhMatrixExp.isSameShapeStrict(&object._HHmatrix)); + ASSERT_TRUE(hhMatrixExp.isSameShapeStrict(object._HHmatrix)); ASSERT_TRUE(hhMatrixExp.equalsTo(&object._HHmatrix)); - ASSERT_TRUE(hhBidiagExp.isSameShapeStrict(&object._HHbidiag)); + ASSERT_TRUE(hhBidiagExp.isSameShapeStrict(object._HHbidiag)); ASSERT_TRUE(hhBidiagExp.equalsTo(&object._HHbidiag)); } @@ -219,9 +218,9 @@ TEST_F(HelpersTests1, BiDiagonalizeUp_test3) { ops::helpers::BiDiagonalUp object(matrix); // object._HHmatrix.printBuffer(); - ASSERT_TRUE(hhMatrixExp.isSameShapeStrict(&object._HHmatrix)); + ASSERT_TRUE(hhMatrixExp.isSameShapeStrict(object._HHmatrix)); ASSERT_TRUE(hhMatrixExp.equalsTo(&object._HHmatrix)); - ASSERT_TRUE(hhBidiagExp.isSameShapeStrict(&object._HHbidiag)); + ASSERT_TRUE(hhBidiagExp.isSameShapeStrict(object._HHbidiag)); ASSERT_TRUE(hhBidiagExp.equalsTo(&object._HHbidiag)); } @@ -241,8 +240,8 @@ TEST_F(HelpersTests1, HHsequence_test1) { ops::helpers::HHsequence uSeq = object.makeHHsequence('u'); ops::helpers::HHsequence vSeq = object.makeHHsequence('v'); - ASSERT_TRUE(uSeq._vectors.isSameShapeStrict(&vectorsUseqExp)); - ASSERT_TRUE(vSeq._vectors.isSameShapeStrict(&vectorsVseqExp)); + ASSERT_TRUE(uSeq._vectors.isSameShapeStrict(vectorsUseqExp)); + ASSERT_TRUE(vSeq._vectors.isSameShapeStrict(vectorsVseqExp)); ASSERT_TRUE(uSeq._vectors.equalsTo(&vectorsUseqExp)); ASSERT_TRUE(vSeq._vectors.equalsTo(&vectorsVseqExp)); @@ -268,8 +267,8 @@ TEST_F(HelpersTests1, HHsequence_test2) { ops::helpers::HHsequence uSeq = object.makeHHsequence('u'); ops::helpers::HHsequence vSeq = object.makeHHsequence('v'); - ASSERT_TRUE(uSeq._vectors.isSameShapeStrict(&vectorsUseqExp)); - ASSERT_TRUE(vSeq._vectors.isSameShapeStrict(&vectorsVseqExp)); + ASSERT_TRUE(uSeq._vectors.isSameShapeStrict(vectorsUseqExp)); + ASSERT_TRUE(vSeq._vectors.isSameShapeStrict(vectorsVseqExp)); ASSERT_TRUE(uSeq._vectors.equalsTo(&vectorsUseqExp)); ASSERT_TRUE(vSeq._vectors.equalsTo(&vectorsVseqExp)); @@ -295,8 +294,8 @@ TEST_F(HelpersTests1, HHsequence_test3) { ops::helpers::HHsequence uSeq = object.makeHHsequence('u'); ops::helpers::HHsequence vSeq = object.makeHHsequence('v'); - ASSERT_TRUE(uSeq._vectors.isSameShapeStrict(&vectorsUseqExp)); - ASSERT_TRUE(vSeq._vectors.isSameShapeStrict(&vectorsVseqExp)); + ASSERT_TRUE(uSeq._vectors.isSameShapeStrict(vectorsUseqExp)); + ASSERT_TRUE(vSeq._vectors.isSameShapeStrict(vectorsVseqExp)); ASSERT_TRUE(uSeq._vectors.equalsTo(&vectorsUseqExp)); ASSERT_TRUE(vSeq._vectors.equalsTo(&vectorsVseqExp)); @@ -870,9 +869,9 @@ TEST_F(HelpersTests1, SVD_test12) { ASSERT_TRUE(expU.equalsTo(&U)); ASSERT_TRUE(expV.equalsTo(&V)); - ASSERT_TRUE(expSingVals.isSameShapeStrict(&singVals)); - ASSERT_TRUE(expU.isSameShapeStrict(&U)); - ASSERT_TRUE(expV.isSameShapeStrict(&V)); + ASSERT_TRUE(expSingVals.isSameShapeStrict(singVals)); + ASSERT_TRUE(expU.isSameShapeStrict(U)); + ASSERT_TRUE(expV.isSameShapeStrict(V)); } /////////////////////////////////////////////////////////////////// @@ -893,9 +892,9 @@ TEST_F(HelpersTests1, SVD_test13) { ASSERT_TRUE(expCoeffs.equalsTo(&qr._coeffs)); ASSERT_TRUE(expPermut.equalsTo(&qr._permut)); - ASSERT_TRUE(expQR.isSameShapeStrict(&qr._qr)); - ASSERT_TRUE(expCoeffs.isSameShapeStrict(&qr._coeffs)); - ASSERT_TRUE(expPermut.isSameShapeStrict(&qr._permut)); + ASSERT_TRUE(expQR.isSameShapeStrict(qr._qr)); + ASSERT_TRUE(expCoeffs.isSameShapeStrict(qr._coeffs)); + ASSERT_TRUE(expPermut.isSameShapeStrict(qr._permut)); } @@ -917,9 +916,9 @@ TEST_F(HelpersTests1, SVD_test14) { ASSERT_TRUE(expCoeffs.equalsTo(&qr._coeffs)); ASSERT_TRUE(expPermut.equalsTo(&qr._permut)); - ASSERT_TRUE(expQR.isSameShapeStrict(&qr._qr)); - ASSERT_TRUE(expCoeffs.isSameShapeStrict(&qr._coeffs)); - ASSERT_TRUE(expPermut.isSameShapeStrict(&qr._permut)); + ASSERT_TRUE(expQR.isSameShapeStrict(qr._qr)); + ASSERT_TRUE(expCoeffs.isSameShapeStrict(qr._coeffs)); + ASSERT_TRUE(expPermut.isSameShapeStrict(qr._permut)); } @@ -941,9 +940,9 @@ TEST_F(HelpersTests1, SVD_test15) { ASSERT_TRUE(expCoeffs.equalsTo(&qr._coeffs)); ASSERT_TRUE(expPermut.equalsTo(&qr._permut)); - ASSERT_TRUE(expQR.isSameShapeStrict(&qr._qr)); - ASSERT_TRUE(expCoeffs.isSameShapeStrict(&qr._coeffs)); - ASSERT_TRUE(expPermut.isSameShapeStrict(&qr._permut)); + ASSERT_TRUE(expQR.isSameShapeStrict(qr._qr)); + ASSERT_TRUE(expCoeffs.isSameShapeStrict(qr._coeffs)); + ASSERT_TRUE(expPermut.isSameShapeStrict(qr._permut)); } @@ -1246,9 +1245,9 @@ TEST_F(HelpersTests1, SVD_test16) { svd.DivideAndConquer(0, 3, 1, 1, 1); // svd._m.printIndexedBuffer(); - ASSERT_TRUE(expM.isSameShapeStrict(&svd._m)); - ASSERT_TRUE(expU.isSameShapeStrict(&svd._u)); - ASSERT_TRUE(expV.isSameShapeStrict(&svd._v)); + ASSERT_TRUE(expM.isSameShapeStrict(svd._m)); + ASSERT_TRUE(expU.isSameShapeStrict(svd._u)); + ASSERT_TRUE(expV.isSameShapeStrict(svd._v)); ASSERT_TRUE(expM.equalsTo(&svd._m)); ASSERT_TRUE(expU.equalsTo(&svd._u)); @@ -1281,9 +1280,9 @@ TEST_F(HelpersTests1, SVD_test17) { ASSERT_TRUE(expU.equalsTo(&svd._u)); ASSERT_TRUE(expV.equalsTo(&svd._v)); - ASSERT_TRUE(expM.isSameShapeStrict(&svd._m)); - ASSERT_TRUE(expU.isSameShapeStrict(&svd._u)); - ASSERT_TRUE(expV.isSameShapeStrict(&svd._v)); + ASSERT_TRUE(expM.isSameShapeStrict(svd._m)); + ASSERT_TRUE(expU.isSameShapeStrict(svd._u)); + ASSERT_TRUE(expV.isSameShapeStrict(svd._v)); } // /////////////////////////////////////////////////////////////////// @@ -1329,9 +1328,9 @@ TEST_F(HelpersTests1, SVD_test17) { // ASSERT_TRUE(expU.equalsTo(&svd._u)); // ASSERT_TRUE(expV.equalsTo(&svd._v)); -// ASSERT_TRUE(expS.isSameShapeStrict(&svd._s)); -// ASSERT_TRUE(expU.isSameShapeStrict(&svd._u)); -// ASSERT_TRUE(expV.isSameShapeStrict(&svd._v)); +// ASSERT_TRUE(expS.isSameShapeStrict(svd._s)); +// ASSERT_TRUE(expU.isSameShapeStrict(svd._u)); +// ASSERT_TRUE(expV.isSameShapeStrict(svd._v)); // } @@ -1378,9 +1377,9 @@ TEST_F(HelpersTests1, SVD_test17) { // ASSERT_TRUE(expU.equalsTo(&svd._u)); // ASSERT_TRUE(expV.equalsTo(&svd._v)); -// ASSERT_TRUE(expS.isSameShapeStrict(&svd._s)); -// ASSERT_TRUE(expU.isSameShapeStrict(&svd._u)); -// ASSERT_TRUE(expV.isSameShapeStrict(&svd._v)); +// ASSERT_TRUE(expS.isSameShapeStrict(svd._s)); +// ASSERT_TRUE(expU.isSameShapeStrict(svd._u)); +// ASSERT_TRUE(expV.isSameShapeStrict(svd._v)); // } @@ -1427,9 +1426,9 @@ TEST_F(HelpersTests1, SVD_test17) { // ASSERT_TRUE(expU.equalsTo(&svd._u)); // ASSERT_TRUE(expV.equalsTo(&svd._v)); -// ASSERT_TRUE(expS.isSameShapeStrict(&svd._s)); -// ASSERT_TRUE(expU.isSameShapeStrict(&svd._u)); -// ASSERT_TRUE(expV.isSameShapeStrict(&svd._v)); +// ASSERT_TRUE(expS.isSameShapeStrict(svd._s)); +// ASSERT_TRUE(expU.isSameShapeStrict(svd._u)); +// ASSERT_TRUE(expV.isSameShapeStrict(svd._v)); // } @@ -1444,7 +1443,7 @@ TEST_F(HelpersTests1, SVD_test17) { // ops::helpers::reverseArray(nd4j::LaunchContext ::defaultContext(), inArr.getBuffer(), inArr.getShapeInfo(), outArr.getBuffer(), outArr.getShapeInfo()); // // ASSERT_TRUE(outArr.equalsTo(&exp)); -// ASSERT_TRUE(outArr.isSameShapeStrict(&exp)); +// ASSERT_TRUE(outArr.isSameShapeStrict(exp)); //} // // @@ -1458,7 +1457,7 @@ TEST_F(HelpersTests1, SVD_test17) { // ops::helpers::reverseArray(nd4j::LaunchContext ::defaultContext(), inArr.getBuffer(), inArr.getShapeInfo(), inArr.getBuffer(), inArr.getShapeInfo()); // // ASSERT_TRUE(inArr.equalsTo(&exp)); -// ASSERT_TRUE(inArr.isSameShapeStrict(&exp)); +// ASSERT_TRUE(inArr.isSameShapeStrict(exp)); //} // // @@ -1472,7 +1471,7 @@ TEST_F(HelpersTests1, SVD_test17) { // ops::helpers::reverseArray(nd4j::LaunchContext ::defaultContext(), inArr.getBuffer(), inArr.getShapeInfo(), outArr.getBuffer(), outArr.getShapeInfo(), 5); // // ASSERT_TRUE(outArr.equalsTo(&exp)); -// ASSERT_TRUE(outArr.isSameShapeStrict(&exp)); +// ASSERT_TRUE(outArr.isSameShapeStrict(exp)); //} /////////////////////////////////////////////////////////////////// diff --git a/libnd4j/tests_cpu/layers_tests/IndexingTests.cpp b/libnd4j/tests_cpu/layers_tests/IndexingTests.cpp index 96c480fd9..790279f74 100644 --- a/libnd4j/tests_cpu/layers_tests/IndexingTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/IndexingTests.cpp @@ -42,18 +42,10 @@ TEST_F(IndexingTests, StridedSlice_1) { auto begin = NDArrayFactory::create({2,2, 0}); auto end = NDArrayFactory::create({3,3,3}); auto strides = NDArrayFactory::create({1,1,1}); - //nd4j_debug("print x->rankOf(): %i", x.rankOf()); - /* - auto tads = x.allTensorsAlongDimension({0}); - nd4j_debug("numTads: %i\n", tads->size()); - for (int e = 0; e < tads->size(); e++) - tads->at(e)->assign((float) e); - */ nd4j::ops::strided_slice op; -// auto result = op.execute({&x}, {}, {0,0,0,0,0, 2,2,0, 3,3,3, 1,1,1}); auto result = op.execute({&x, &begin, &end, &strides}, {}, {0,0,0,0,0}); //, 2,2,0, 3,3,3, 1,1,1}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); @@ -202,8 +194,8 @@ TEST_F(IndexingTests, SimpleSlice_4) { TEST_F(IndexingTests, MaskedSlice_0) { auto matrix = NDArrayFactory::create('c', {3, 5}); auto tads = matrix.allTensorsAlongDimension({1}); - for (int e = 0; e < tads->size(); e++) { - tads->at(e)->assign((float) (e+1)); + for (int e = 0; e < tads.size(); e++) { + tads.at(e)->assign((float) (e+1)); } auto exp = NDArrayFactory::create('c', {1, 5}); @@ -222,15 +214,14 @@ TEST_F(IndexingTests, MaskedSlice_0) { ASSERT_TRUE(exp.equalsTo(z)); delete result; - delete tads; } TEST_F(IndexingTests, MaskedSlice_00) { auto matrix = NDArrayFactory::create('c', {3, 5}); auto tads = matrix.allTensorsAlongDimension({1}); - for (int e = 0; e < tads->size(); e++) { - tads->at(e)->assign((float) (e+1)); + for (int e = 0; e < tads.size(); e++) { + tads.at(e)->assign((float) (e+1)); } auto exp = NDArrayFactory::create('c', {1, 2}, {2, 2}); @@ -243,21 +234,18 @@ TEST_F(IndexingTests, MaskedSlice_00) { auto z = result->at(0); - // z->printShapeInfo("z"); - ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); delete result; - delete tads; } TEST_F(IndexingTests, MaskedSlice_1) { auto matrix = NDArrayFactory::create('c', {3, 5}); auto tads = matrix.allTensorsAlongDimension({1}); - for (int e = 0; e < tads->size(); e++) { - tads->at(e)->assign((float) (e+1)); + for (int e = 0; e < tads.size(); e++) { + tads.at(e)->assign((float) (e+1)); } auto exp = NDArrayFactory::create('c', {5}); @@ -276,7 +264,6 @@ TEST_F(IndexingTests, MaskedSlice_1) { ASSERT_TRUE(exp.equalsTo(z)); delete result; - delete tads; } TEST_F(IndexingTests, MaskedSlice_2) { diff --git a/libnd4j/tests_cpu/layers_tests/JavaInteropTests.cpp b/libnd4j/tests_cpu/layers_tests/JavaInteropTests.cpp index e7f7f7e68..f058d9112 100644 --- a/libnd4j/tests_cpu/layers_tests/JavaInteropTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/JavaInteropTests.cpp @@ -834,12 +834,17 @@ TEST_F(JavaInteropTests, Test_Reduce3_EdgeCase) { auto packY = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(y.getShapeInfo(), {0,1}); NDArray::prepareSpecialUse({&z}, {&x, &y, &dims}); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer zBuf(z.dataBuffer()); + OpaqueDataBuffer dimBuf(dims.dataBuffer()); - execReduce3Tad(extraPointers, 2, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), + execReduce3Tad(extraPointers, 2, &xBuf, x.shapeInfo(), x.specialShapeInfo(), nullptr, - y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), - z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - dims.buffer(), dims.shapeInfo(), dims.specialBuffer(), dims.specialShapeInfo(), packX.platformShapeInfo(), packX.platformOffsets(), packY.platformShapeInfo(), packY.platformOffsets()); + &yBuf, y.shapeInfo(), y.specialShapeInfo(), + &zBuf, z.shapeInfo(), z.specialShapeInfo(), + &dimBuf, dims.shapeInfo(), dims.specialShapeInfo(), packX.platformShapeInfo(), + packX.platformOffsets(), packY.platformShapeInfo(), packY.platformOffsets()); NDArray::registerSpecialUse({&z}, {&x, &y, &dims}); @@ -981,10 +986,14 @@ TEST_F(JavaInteropTests, Test_Mixed_Add_1) { NDArray::prepareSpecialUse({&arrayZ}, {&arrayX, &arrayY}); + OpaqueDataBuffer xBuf(arrayX.dataBuffer()); + OpaqueDataBuffer yBuf(arrayY.dataBuffer()); + OpaqueDataBuffer zBuf(arrayZ.dataBuffer()); + execPairwiseTransform(nullptr, pairwise::Add, - arrayX.buffer(), arrayX.shapeInfo(), arrayX.getSpecialBuffer(), arrayX.getSpecialShapeInfo(), - arrayY.buffer(), arrayY.shapeInfo(), arrayY.getSpecialBuffer(), arrayY.getSpecialShapeInfo(), - arrayZ.buffer(), arrayZ.shapeInfo(), arrayZ.getSpecialBuffer(), arrayZ.getSpecialShapeInfo(), + &xBuf, arrayX.shapeInfo(), arrayX.getSpecialShapeInfo(), + &yBuf, arrayY.shapeInfo(), arrayY.getSpecialShapeInfo(), + &zBuf, arrayZ.shapeInfo(), arrayZ.getSpecialShapeInfo(), nullptr); NDArray::registerSpecialUse({&arrayZ}, {&arrayX, &arrayY}); @@ -1220,28 +1229,28 @@ TEST_F(JavaInteropTests, test_bfloat16_rng) { auto z = NDArrayFactory::create('c', {10}); RandomGenerator rng(119, 323841120L); bfloat16 args[2] = {(bfloat16) 0.0f, (bfloat16) 1.0f}; - execRandom(nullptr, nd4j::random::Ops::UniformDistribution, &rng, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), args); + OpaqueDataBuffer zBuf(z.dataBuffer()); + execRandom(nullptr, nd4j::random::Ops::UniformDistribution, &rng, &zBuf, z.shapeInfo(), z.specialShapeInfo(), args); //z.printIndexedBuffer("z"); - ASSERT_TRUE(z.sumNumber().e(0) > 0); } TEST_F(JavaInteropTests, test_ismax_view) { auto original = NDArrayFactory::create('c', {2, 3, 40}); auto v = original.subarray({NDIndex::all(), NDIndex::all(), NDIndex::interval(0, 40, 2)}); - v->assign(1.0); + v.assign(1.0); - auto e = v->like(); + auto e = v.like(); auto t = e.tensorAlongDimension(0, {0, 1}); - t->assign(1.0); + t.assign(1.0); - auto z = v->ulike(); + auto z = v.ulike(); Nd4jLong iArgs[] = {2L, 0L}; Context ctx(1); - ctx.setInputArray(0, v->buffer(), v->shapeInfo(), v->specialBuffer(), v->specialShapeInfo()); + ctx.setInputArray(0, v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo()); ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo()); ctx.setIArguments(iArgs, 1); @@ -1249,9 +1258,6 @@ TEST_F(JavaInteropTests, test_ismax_view) { op.execute(&ctx); ASSERT_EQ(e, z); - - delete v; - delete t; } TEST_F(JavaInteropTests, test_size_dtype_1) { @@ -1270,6 +1276,64 @@ TEST_F(JavaInteropTests, test_size_dtype_1) { ASSERT_EQ(e, z); } +TEST_F(JavaInteropTests, test_expandable_array_op_1) { + auto x = NDArrayFactory::string('c', {2}, {"first string", "second"}); + auto d = NDArrayFactory::string(" "); + + auto z0 = NDArrayFactory::create('c', {6}); + auto z1 = NDArrayFactory::string('c', {3}, {"", "", ""}); + + auto exp0 = NDArrayFactory::create({0,0, 0,1, 1,0}); + auto exp1 = NDArrayFactory::string('c', {3}, {"first", "string", "second"}); + + InteropDataBuffer iz0(z0.dataBuffer()); + InteropDataBuffer iz1(z1.dataBuffer()); + + Context ctx(1); + ctx.setInputArray(0, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo()); + ctx.setInputArray(1, d.buffer(), d.shapeInfo(), d.specialBuffer(), d.specialShapeInfo()); + ctx.setOutputArray(0, &iz0, z0.shapeInfo(), z0.specialShapeInfo()); + ctx.setOutputArray(1, &iz1, z1.shapeInfo(), z1.specialShapeInfo()); + + nd4j::ops::compat_string_split op; + auto status = op.execute(&ctx); + ASSERT_EQ(Status::OK(), status); + + ASSERT_EQ(exp0, z0); + ASSERT_EQ(exp1, z1); +} + +TEST_F(JavaInteropTests, test_workspace_backed_arrays_1) { + if (!Environment::getInstance()->isCPU()) + return; + + auto x = NDArrayFactory::create('c', {4, 3, 4, 4}); + auto y = NDArrayFactory::create('c', {4, 3, 3, 3}); + auto z = NDArrayFactory::create('c', {4, 3, 4, 4}); + + double buffer[2048]; + + InteropDataBuffer ix(0, DataType::DOUBLE, false); + InteropDataBuffer iy(0, DataType::DOUBLE, false); + InteropDataBuffer iz(0, DataType::DOUBLE, false); + + // we're imitating workspace-managed array here + ix.setPrimary(buffer + 64, x.lengthOf()); + iy.setPrimary(buffer + 64 + x.lengthOf(), y.lengthOf()); + iz.setPrimary(buffer + 64 + x.lengthOf() + y.lengthOf(), z.lengthOf()); + + Context ctx(1); + ctx.setInputArray(0, &ix, x.shapeInfo(), x.specialShapeInfo()); + ctx.setInputArray(1, &iy, y.shapeInfo(), y.specialShapeInfo()); + ctx.setOutputArray(0, &iz, z.shapeInfo(), z.specialShapeInfo()); + + ctx.setIArguments({2, 2, 1, 1, 0, 0, 1, 1, 0, 0, 0}); + + nd4j::ops::maxpool2d_bp op; + auto status = op.execute(&ctx); + ASSERT_EQ(Status::OK(), status); +} + /* TEST_F(JavaInteropTests, Test_Results_Conversion_1) { auto pl = nd4j::graph::readFlatBuffers("./resources/gru_dynamic_mnist.fb"); diff --git a/libnd4j/tests_cpu/layers_tests/LambdaTests.cu b/libnd4j/tests_cpu/layers_tests/LambdaTests.cu index 30244b7dc..5bf8c8b57 100644 --- a/libnd4j/tests_cpu/layers_tests/LambdaTests.cu +++ b/libnd4j/tests_cpu/layers_tests/LambdaTests.cu @@ -75,7 +75,7 @@ void test(NDArray &x) { return x+1.; }; - x.applyLambda(f, &x); + x.applyLambda(f, x); } template @@ -84,7 +84,7 @@ void test2(NDArray &x) { return x+1.; }; - x.applyLambda(f, &x); + x.applyLambda(f, x); } void testPairwise(NDArray &x, NDArray &y) { @@ -92,7 +92,7 @@ void testPairwise(NDArray &x, NDArray &y) { return x + y +1.; }; - x.applyPairwiseLambda(&y, f, &x); + x.applyPairwiseLambda(y, f, x); } void testTriplewise(NDArray &i, NDArray &j, NDArray &k) { @@ -100,7 +100,7 @@ void testTriplewise(NDArray &i, NDArray &j, NDArray &k) { return i + j + k + 2.; }; - i.applyTriplewiseLambda(&j, &k, f, &i); + i.applyTriplewiseLambda(j, k, f, i); } void testIndexed(NDArray &x) { @@ -108,7 +108,7 @@ void testIndexed(NDArray &x) { return _idx + 1.; }; - x.applyIndexedLambda(f, &x); + x.applyIndexedLambda(f, x); } void testIndexedPairwise(NDArray &x, NDArray &y) { @@ -116,7 +116,7 @@ void testIndexedPairwise(NDArray &x, NDArray &y) { return _idx + x + y +1.; }; - x.applyIndexedPairwiseLambda(&y, f, &x); + x.applyIndexedPairwiseLambda(y, f, x); } TEST_F(LambdaTests, test_basic_2) { @@ -197,7 +197,7 @@ void testPairwiseMy(NDArray &x, NDArray &y, NDArray &z) { + nd4j::math::nd4j_exp(-nd4j::math::nd4j_abs(x))); }; - x.applyPairwiseLambda(&y, f, &z); + x.applyPairwiseLambda(y, f, z); } /////////////////////////////////////////////////////////////////// diff --git a/libnd4j/tests_cpu/layers_tests/LegacyOpsTests.cpp b/libnd4j/tests_cpu/layers_tests/LegacyOpsTests.cpp index ffcd5759e..f0b7628ee 100644 --- a/libnd4j/tests_cpu/layers_tests/LegacyOpsTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/LegacyOpsTests.cpp @@ -194,11 +194,10 @@ TEST_F(LegacyOpsTests, ReduceTests_2) { auto exp = x.reduceAlongDimension(reduce::Sum, {1}); - ASSERT_TRUE(exp->isSameShape(z)); - ASSERT_TRUE(exp->equalsTo(z)); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); delete result; - delete exp; } @@ -211,7 +210,7 @@ TEST_F(LegacyOpsTests, ReduceTests_3) { nd4j::ops::LegacyReduceSameOp op(reduce::Sum); auto result = op.execute({&x, &indices}, {}, {}); auto z = result->at(0); - auto exp = x.reduceAlongDims(reduce::Sum,{1}); + auto exp = x.reduceAlongDimension(reduce::Sum,{1}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); @@ -231,7 +230,7 @@ TEST_F(LegacyOpsTests, ReduceTests_4) { nd4j::ops::LegacyReduceSameOp op(reduce::Sum); auto result = op.execute({&x, &indices}, {}, {}, {true}); auto z = result->at(0); - auto exp = x.reduceAlongDims(reduce::Sum, {1}, true); + auto exp = x.reduceAlongDimension(reduce::Sum, {1}, true); // indices.printShapeInfo("Indices shape"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); // z->printIndexedBuffer("Output reduce 4"); @@ -275,11 +274,10 @@ TEST_F(LegacyOpsTests, ReduceTests_6) { auto exp = x.reduceAlongDimension(reduce::Mean, {1}); - ASSERT_TRUE(exp->isSameShape(z)); - ASSERT_TRUE(exp->equalsTo(z)); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); delete result; - delete exp; } @@ -292,7 +290,7 @@ TEST_F(LegacyOpsTests, ReduceTests_7) { nd4j::ops::LegacyReduceFloatOp op(reduce::Mean); auto result = op.execute({&x, &indices}, {}, {}); auto z = result->at(0); - auto exp = x.reduceAlongDims(reduce::Mean,{1}); + auto exp = x.reduceAlongDimension(reduce::Mean,{1}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); @@ -312,7 +310,7 @@ TEST_F(LegacyOpsTests, ReduceTests_8) { nd4j::ops::LegacyReduceFloatOp op(reduce::Mean); auto result = op.execute({&x, &indices}, {}, {}, {true}); auto z = result->at(0); - auto exp = x.reduceAlongDims(reduce::Mean, {1}, true); + auto exp = x.reduceAlongDimension(reduce::Mean, {1}, true); ASSERT_EQ(ND4J_STATUS_OK, result->status()); // z->printIndexedBuffer("Reduce8 output"); @@ -382,10 +380,8 @@ TEST_F(LegacyOpsTests, BroadcastingTests_1) { auto list = x.allTensorsAlongDimension({1}); // x.printIndexedBuffer("Output broadcast"); // list->at(0)->printIndexedBuffer("Column 0:"); - for (int e = 0; e < list->size(); e++) - ASSERT_TRUE(row.equalsTo(list->at(e))); - - delete list; + for (int e = 0; e < list.size(); e++) + ASSERT_TRUE(row.equalsTo(list.at(e))); } TEST_F(LegacyOpsTests, BroadcastingTests_2) { @@ -417,7 +413,7 @@ TEST_F(LegacyOpsTests, PowDerivative_1) { float p = 2.0f; - x.applyScalar(scalar::PowDerivative, p); + x.applyScalar(scalar::PowDerivative, p, x); ASSERT_TRUE(exp.equalsTo(&x)); } @@ -474,12 +470,16 @@ TEST_F(LegacyOpsTests, Reduce3_2) { auto packY = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(y.getShapeInfo(), {1}); NDArray::prepareSpecialUse({&z}, {&x, &y, &dim}); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer zBuf(z.dataBuffer()); + OpaqueDataBuffer dimBuf(dim.dataBuffer()); execReduce3Tad(extraPointers, reduce3::CosineSimilarity, - x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), - z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - dim.buffer(), dim.shapeInfo(), dim.specialBuffer(), dim.specialShapeInfo(), + &xBuf, x.shapeInfo(), x.specialShapeInfo(), + nullptr, &yBuf, y.shapeInfo(), y.specialShapeInfo(), + &zBuf, z.shapeInfo(), z.specialShapeInfo(), + &dimBuf, dim.shapeInfo(), dim.specialShapeInfo(), packX.platformShapeInfo(), packX.platformOffsets(), packY.platformShapeInfo(), packY.platformOffsets()); NDArray::registerSpecialUse({&z}, {&x, &y, &dim}); @@ -510,14 +510,17 @@ TEST_F(LegacyOpsTests, Reduce3_3) { auto packY = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(y.getShapeInfo(), {1}); NDArray::prepareSpecialUse({&z}, {&x, &y, &dim}); - + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer zBuf(z.dataBuffer()); + OpaqueDataBuffer dimBuf(dim.dataBuffer()); execReduce3Tad(extraPointers, reduce3::CosineDistance, - x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), + &xBuf, x.shapeInfo(), x.specialShapeInfo(), nullptr, - y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), - z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - dim.buffer(), dim.shapeInfo(), dim.specialBuffer(), dim.specialShapeInfo(), + &yBuf, y.shapeInfo(), y.specialShapeInfo(), + &zBuf, z.shapeInfo(), z.specialShapeInfo(), + &dimBuf, dim.shapeInfo(), dim.specialShapeInfo(), packX.platformShapeInfo(), packX.platformOffsets(), packY.platformShapeInfo(), packY.platformOffsets()); ASSERT_EQ(e, z); NDArray::registerSpecialUse({&z}, {&x, &y, &dim}); @@ -547,14 +550,17 @@ TEST_F(LegacyOpsTests, Reduce3_4) { auto packY = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(y.getShapeInfo(), {1}); NDArray::prepareSpecialUse({&z}, {&x, &y, &dim}); - + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer zBuf(z.dataBuffer()); + OpaqueDataBuffer dimBuf(dim.dataBuffer()); execReduce3Tad(extraPointers, reduce3::CosineDistance, - x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), + &xBuf, x.shapeInfo(), x.specialShapeInfo(), nullptr, - y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), - z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - dim.buffer(), dim.shapeInfo(), dim.specialBuffer(), dim.specialShapeInfo(), + &yBuf, y.shapeInfo(), y.specialShapeInfo(), + &zBuf, z.shapeInfo(), z.specialShapeInfo(), + &dimBuf, dim.shapeInfo(), dim.specialShapeInfo(), packX.platformShapeInfo(), packX.platformOffsets(), packY.platformShapeInfo(), packY.platformOffsets()); // z.printIndexedBuffer("z"); @@ -587,13 +593,16 @@ TEST_F(LegacyOpsTests, Reduce3_5) { NDArray::prepareSpecialUse({&z}, {&x, &y, &dim}); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer zBuf(z.dataBuffer()); + OpaqueDataBuffer dimBuf(dim.dataBuffer()); execReduce3Tad(extraPointers, reduce3::CosineDistance, - x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), + &xBuf, x.shapeInfo(), x.specialShapeInfo(), nullptr, - y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), - z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - dim.buffer(), dim.shapeInfo(), dim.specialBuffer(), dim.specialShapeInfo(), + &yBuf, y.shapeInfo(), y.specialShapeInfo(), &zBuf, z.shapeInfo(), z.specialShapeInfo(), + &dimBuf, dim.shapeInfo(), dim.specialShapeInfo(), packX.platformShapeInfo(), packX.platformOffsets(), packY.platformShapeInfo(), packY.platformOffsets()); NDArray::registerSpecialUse({&z}, {&x, &y, &dim}); @@ -619,10 +628,15 @@ TEST_F(LegacyOpsTests, test_Reduce3_All_1) { NDArray::prepareSpecialUse({&z}, {&x, &y}); - execReduce3All(extraPointers, reduce3::EuclideanDistance, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), - z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - dim.buffer(), dim.shapeInfo(), dim.specialBuffer(), dim.specialShapeInfo(), + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer zBuf(z.dataBuffer()); + OpaqueDataBuffer dimBuf(dim.dataBuffer()); + + execReduce3All(extraPointers, reduce3::EuclideanDistance, &xBuf, x.shapeInfo(), x.specialShapeInfo(), + nullptr, &yBuf, y.shapeInfo(), y.specialShapeInfo(), + &zBuf, z.shapeInfo(), z.specialShapeInfo(), + &dimBuf, dim.shapeInfo(), dim.specialShapeInfo(), tadPackX.platformShapeInfo(), tadPackX.platformOffsets(), tadPackY.platformShapeInfo(), tadPackY.platformOffsets()); @@ -661,10 +675,10 @@ TEST_F(LegacyOpsTests, test_inverse_broadcast_2) { e.assign(false); auto row = y.tensorAlongDimension(1, {1}); - row->assign(2.0f); + row.assign(2.0f); auto erow = e.tensorAlongDimension(1, {1}); - erow->assign(true); + erow.assign(true); auto tadPackY = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(y.shapeInfo(), 1); @@ -680,9 +694,6 @@ TEST_F(LegacyOpsTests, test_inverse_broadcast_2) { tadPackY.platformShapeInfo(), tadPackY.platformOffsets()); ASSERT_EQ(e, z); - - delete row; - delete erow; } TEST_F(LegacyOpsTests, test_legacy_reduce_empty_1) { @@ -737,13 +748,16 @@ TEST_F(LegacyOpsTests, test_legacy_reduce_empty_4) { auto z = NDArrayFactory::create('c', {0, 2}); auto e = NDArrayFactory::create('c', {0, 2}); + InteropDataBuffer xdb(x.dataBuffer()); + InteropDataBuffer ddb(d.dataBuffer()); + InteropDataBuffer zdb(z.dataBuffer()); ::execReduceSame2(nullptr, reduce::SameOps::Sum, - x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), + &xdb, x.shapeInfo(), x.specialShapeInfo(), nullptr, - z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - d.buffer(), d.shapeInfo(), d.specialBuffer(), d.specialShapeInfo()); + &zdb, z.shapeInfo(), z.specialShapeInfo(), + &ddb, d.shapeInfo(), d.specialShapeInfo()); } diff --git a/libnd4j/tests_cpu/layers_tests/ListOperationsTests.cpp b/libnd4j/tests_cpu/layers_tests/ListOperationsTests.cpp index e506839df..625d9978f 100644 --- a/libnd4j/tests_cpu/layers_tests/ListOperationsTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/ListOperationsTests.cpp @@ -59,7 +59,7 @@ TEST_F(ListOperationsTests, BasicTest_Stack_1) { auto row = NDArrayFactory::create_('c', {100}); row->assign((double) e); list.write(e, row); - tads->at(e)->assign(row); + tads.at(e)->assign(row); } nd4j::ops::stack_list op; @@ -75,7 +75,6 @@ TEST_F(ListOperationsTests, BasicTest_Stack_1) { ASSERT_TRUE(exp.equalsTo(z)); delete result; - delete tads; } TEST_F(ListOperationsTests, BasicTest_UnStackList_1) { @@ -86,7 +85,7 @@ TEST_F(ListOperationsTests, BasicTest_UnStackList_1) { auto row = NDArrayFactory::create_('c', {100}); row->assign((double) e); //list.write(e, row); - tads->at(e)->assign(row); + tads.at(e)->assign(row); delete row; } @@ -103,13 +102,12 @@ TEST_F(ListOperationsTests, BasicTest_UnStackList_1) { // ASSERT_TRUE(exp.equalsTo(z)); for (int e = 0; e < 10; e++) { auto row = list.read(e); - ASSERT_TRUE(row->equalsTo(tads->at(e))); + ASSERT_TRUE(row->equalsTo(tads.at(e))); //list.write(e, row); delete row; } delete result; - delete tads; } //TEST_F(ListOperationsTests, BasicTest_UnStackList_2) { @@ -153,7 +151,7 @@ TEST_F(ListOperationsTests, BasicTest_Read_1) { for (int e = 0; e < 10; e++) { auto row = NDArrayFactory::create_('c', {1, 100}); row->assign((double) e); - list.write(e, row->dup()); + list.write(e, new NDArray(row->dup())); delete row; } @@ -179,16 +177,16 @@ TEST_F(ListOperationsTests, BasicTest_Pick_1) { for (int e = 0; e < 10; e++) { auto row = NDArrayFactory::create_('c', {100}); row->assign((double) e); - list.write(e, row->dup()); + list.write(e, new NDArray(row->dup())); delete row; } auto tads = exp.allTensorsAlongDimension({1}); - tads->at(0)->assign(1.0f); - tads->at(1)->assign(1.0f); - tads->at(2)->assign(3.0f); - tads->at(3)->assign(3.0f); + tads.at(0)->assign(1.0f); + tads.at(1)->assign(1.0f); + tads.at(2)->assign(3.0f); + tads.at(3)->assign(3.0f); nd4j::ops::pick_list op; @@ -202,7 +200,6 @@ TEST_F(ListOperationsTests, BasicTest_Pick_1) { ASSERT_TRUE(exp.equalsTo(z)); delete result; - delete tads; } TEST_F(ListOperationsTests, BasicTest_Size_1) { @@ -211,7 +208,7 @@ TEST_F(ListOperationsTests, BasicTest_Size_1) { for (int e = 0; e < 10; e++) { auto row = NDArrayFactory::create_('c', {100}); row->assign((double) e); - list.write(e, row->dup()); + list.write(e, new NDArray(row->dup())); delete row; } @@ -272,14 +269,14 @@ TEST_F(ListOperationsTests, BasicTest_Split_1) { for (int e = 0; e < 10; e++) { auto row = NDArrayFactory::create_('c', {5}); row->assign((double) e); - tads->at(e)->assign(row); + tads.at(e)->assign(row); if (e < 2) - tads0->at(cnt0++)->assign(row); + tads0.at(cnt0++)->assign(row); else if (e < 5) - tads1->at(cnt1++)->assign(row); + tads1.at(cnt1++)->assign(row); else - tads2->at(cnt2++)->assign(row); + tads2.at(cnt2++)->assign(row); delete row; } @@ -300,10 +297,6 @@ TEST_F(ListOperationsTests, BasicTest_Split_1) { ASSERT_TRUE(exp2.equalsTo(list.readRaw(2))); delete result; - delete tads; - delete tads0; - delete tads1; - delete tads2; } TEST_F(ListOperationsTests, BasicTest_Scatter_1) { @@ -315,7 +308,7 @@ TEST_F(ListOperationsTests, BasicTest_Scatter_1) { for (int e = 0; e < 10; e++) { auto row = NDArrayFactory::create_('c', {1, 5}); row->assign((double) e); - tads->at(e)->assign(row); + tads.at(e)->assign(row); delete row; } @@ -329,15 +322,13 @@ TEST_F(ListOperationsTests, BasicTest_Scatter_1) { ASSERT_EQ(ND4J_STATUS_OK, result->status()); for (int e = 0; e < 10; e++) { - auto row = tads->at(9 - e); + auto row = tads.at(9 - e); auto chunk = list.readRaw(e); ASSERT_TRUE(chunk->isSameShape(row)); ASSERT_TRUE(chunk->equalsTo(row)); } - - delete tads; delete result; } @@ -376,7 +367,7 @@ TEST_F(ListOperationsTests, BasicTest_Gather_1) { for (int e = 0; e < 10; e++) { auto row = NDArrayFactory::create_('c', {3}); row->assign((double) e); - list.write(e, row->dup()); + list.write(e, new NDArray(row->dup())); delete row; } @@ -384,7 +375,7 @@ TEST_F(ListOperationsTests, BasicTest_Gather_1) { auto exp = NDArrayFactory::create('c', {10, 3}); auto tads = exp.allTensorsAlongDimension({1}); for (int e = 0; e < 10; e++) { - auto tad = tads->at(9 - e); + auto tad = tads.at(9 - e); tad->assign(e); } @@ -407,7 +398,6 @@ TEST_F(ListOperationsTests, BasicTest_Gather_1) { ASSERT_TRUE(exp.equalsTo(z)); delete result; - delete tads; } TEST_F(ListOperationsTests, GraphTests_Sequential_1) { @@ -415,17 +405,16 @@ TEST_F(ListOperationsTests, GraphTests_Sequential_1) { auto matrix = NDArrayFactory::create_('c', {3, 3}); auto tads = matrix->allTensorsAlongDimension({1}); - for (int e = 0; e < tads->size(); e++) { - tads->at(e)->assign((float) (e+1)); + for (int e = 0; e < tads.size(); e++) { + tads.at(e)->assign((float) (e+1)); } auto exp = NDArrayFactory::create('c', {3, 3}); auto tadsExp = exp.allTensorsAlongDimension({1}); - tadsExp->at(0)->assign(0.f); - tadsExp->at(1)->assign(-1.f); - tadsExp->at(2)->assign(-2.f); - delete tadsExp; + tadsExp.at(0)->assign(0.f); + tadsExp.at(1)->assign(-1.f); + tadsExp.at(2)->assign(-2.f); auto indices = NDArrayFactory::valueOf({3}, 1, 'c'); //indices->linspace(0); @@ -472,7 +461,7 @@ TEST_F(ListOperationsTests, GraphTests_Sequential_1) { // nodeF1->setCustomOp(&opF); // nodeF2->setCustomOp(&opF); - // now we're stacking chunks back to matrix state + // now we're stacking chunks back to matrix state nd4j::ops::stack_list opG; auto nodeG = new Node(&opG, 20, {2, 15, 16, 17}); //auto nodeG = new Node(OpType_CUSTOM, 0, 20, {2}); @@ -537,8 +526,6 @@ TEST_F(ListOperationsTests, GraphTests_Sequential_1) { ASSERT_TRUE(exp.isSameShape(stack)); ASSERT_TRUE(exp.equalsTo(stack)); - - delete tads; } @@ -548,16 +535,15 @@ TEST_F(ListOperationsTests, GraphTests_Sequential_2) { auto scalar = NDArrayFactory::create_(0.0f); auto matrix = NDArrayFactory::create_('c', {3, 3}); auto tads = matrix->allTensorsAlongDimension({1}); - for (int e = 0; e < tads->size(); e++) { - tads->at(e)->assign((float) (e+1)); + for (int e = 0; e < tads.size(); e++) { + tads.at(e)->assign((float) (e+1)); } - auto exp = NDArrayFactory::create('c', {3, 3}); auto tadsExp = exp.allTensorsAlongDimension({1}); - tadsExp->at(0)->assign(0.f); - tadsExp->at(1)->assign(-1.f); - tadsExp->at(2)->assign(-2.f); + tadsExp.at(0)->assign(0.f); + tadsExp.at(1)->assign(-1.f); + tadsExp.at(2)->assign(-2.f); //auto indices = NDArray::valueOf({1, 3}, 1.0f, 'c'); auto indices = NDArrayFactory::create_('c', {1, 3}); @@ -580,7 +566,7 @@ TEST_F(ListOperationsTests, GraphTests_Sequential_2) { // filling list with matrix nd4j::ops::scatter_list opC; auto nodeC = new Node(&opC, 3, {2, -2, 1, -3}); - + //nodeC->setCustomOp(&opC); nd4j::ops::read_list opD; @@ -608,7 +594,7 @@ TEST_F(ListOperationsTests, GraphTests_Sequential_2) { // nodeF1->setCustomOp(&opF); // nodeF2->setCustomOp(&opF); - // now we're gathering chunks back to matrix state + // now we're gathering chunks back to matrix state nd4j::ops::pick_list opG; auto nodeG = new Node(&opG, 20, {2, -2, 15, 16, 17}); //auto nodeG = new Node(OpType_CUSTOM, 0, 20, {2}); @@ -665,14 +651,11 @@ TEST_F(ListOperationsTests, GraphTests_Sequential_2) { ASSERT_EQ(3, list->elements()); ASSERT_TRUE(variableSpace->hasVariable(20)); - + auto stack = variableSpace->getVariable(20)->getNDArray(); - + ASSERT_TRUE(stack != nullptr); ASSERT_TRUE(exp.isSameShape(stack)); ASSERT_TRUE(exp.equalsTo(stack)); - - delete tadsExp; - delete tads; } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/MklDnnTests.cpp b/libnd4j/tests_cpu/layers_tests/MklDnnTests.cpp index 829117bed..d83e85f67 100644 --- a/libnd4j/tests_cpu/layers_tests/MklDnnTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/MklDnnTests.cpp @@ -45,26 +45,26 @@ static void printer(std::initializer_list TEST_F(MklDnnTests, helpers_includer) { // we need this block, to make sure all helpers are still available within binary, and not optimized out by linker #ifdef HAVE_MKLDNN - nd4j::ops::platforms::PLATFORM_conv2d conv2d; - nd4j::ops::platforms::PLATFORM_conv2d_bp conv2d_bp; + nd4j::ops::platforms::PLATFORM_conv2d_ENGINE_CPU conv2d; + nd4j::ops::platforms::PLATFORM_conv2d_bp_ENGINE_CPU conv2d_bp; - nd4j::ops::platforms::PLATFORM_conv2d conv3d; - nd4j::ops::platforms::PLATFORM_conv2d_bp conv3d_bp; + nd4j::ops::platforms::PLATFORM_conv2d_ENGINE_CPU conv3d; + nd4j::ops::platforms::PLATFORM_conv2d_bp_ENGINE_CPU conv3d_bp; - nd4j::ops::platforms::PLATFORM_avgpool2d avgpool2d; - nd4j::ops::platforms::PLATFORM_avgpool2d_bp avgpool2d_bp; + nd4j::ops::platforms::PLATFORM_avgpool2d_ENGINE_CPU avgpool2d; + nd4j::ops::platforms::PLATFORM_avgpool2d_bp_ENGINE_CPU avgpool2d_bp; - nd4j::ops::platforms::PLATFORM_maxpool2d maxpool2d; - nd4j::ops::platforms::PLATFORM_maxpool2d_bp maxpool2d_bp; + nd4j::ops::platforms::PLATFORM_maxpool2d_ENGINE_CPU maxpool2d; + nd4j::ops::platforms::PLATFORM_maxpool2d_bp_ENGINE_CPU maxpool2d_bp; - nd4j::ops::platforms::PLATFORM_avgpool3dnew avgpool3d; - nd4j::ops::platforms::PLATFORM_avgpool3dnew_bp avgpool3d_bp; + nd4j::ops::platforms::PLATFORM_avgpool3dnew_ENGINE_CPU avgpool3d; + nd4j::ops::platforms::PLATFORM_avgpool3dnew_bp_ENGINE_CPU avgpool3d_bp; - nd4j::ops::platforms::PLATFORM_maxpool3dnew maxpool3d; - nd4j::ops::platforms::PLATFORM_maxpool3dnew_bp maxpool3d_bp; + nd4j::ops::platforms::PLATFORM_maxpool3dnew_ENGINE_CPU maxpool3d; + nd4j::ops::platforms::PLATFORM_maxpool3dnew_bp_ENGINE_CPU maxpool3d_bp; - nd4j::ops::platforms::PLATFORM_lrn lrn; - nd4j::ops::platforms::PLATFORM_batchnorm batchnorm; + nd4j::ops::platforms::PLATFORM_lrn_ENGINE_CPU lrn; + nd4j::ops::platforms::PLATFORM_batchnorm_ENGINE_CPU batchnorm; printer({&conv2d, &conv2d_bp, &conv3d, &conv3d_bp, &avgpool2d, &avgpool2d_bp, &maxpool2d, &maxpool2d_bp, &avgpool3d, &avgpool3d_bp, &maxpool3d, &maxpool3d_bp, &lrn, &batchnorm}); #endif diff --git a/libnd4j/tests_cpu/layers_tests/MultiDataTypeTests.cpp b/libnd4j/tests_cpu/layers_tests/MultiDataTypeTests.cpp index d271048a9..9afc34267 100644 --- a/libnd4j/tests_cpu/layers_tests/MultiDataTypeTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/MultiDataTypeTests.cpp @@ -237,18 +237,14 @@ TEST_F(MultiDataTypeTests, ndarray_reduceAlongDimension_test1) { NDArray exp2('c', {1,1}, {1}, nd4j::DataType::INT64); NDArray exp3('c', {2}, {1,2}, nd4j::DataType::INT64); - auto* scalar1 = x.reduceAlongDimension(nd4j::reduce::CountNonZero, {}/*whole range*/); - ASSERT_EQ(*scalar1, exp1); + auto scalar1 = x.reduceAlongDimension(nd4j::reduce::CountNonZero, {}/*whole range*/); + ASSERT_EQ(scalar1, exp1); - auto* scalar2 = x.reduceAlongDimension(nd4j::reduce::CountZero, {}/*whole range*/, true); - ASSERT_EQ(*scalar2, exp2); + auto scalar2 = x.reduceAlongDimension(nd4j::reduce::CountZero, {}/*whole range*/, true); + ASSERT_EQ(scalar2, exp2); - auto* scalar3 = x.reduceAlongDimension(nd4j::reduce::CountNonZero, {1}); - ASSERT_EQ(*scalar3, exp3); - - delete scalar1; - delete scalar2; - delete scalar3; + auto scalar3 = x.reduceAlongDimension(nd4j::reduce::CountNonZero, {1}); + ASSERT_EQ(scalar3, exp3); } //////////////////////////////////////////////////////////////////////////////// @@ -257,16 +253,13 @@ TEST_F(MultiDataTypeTests, ndarray_reduceAlongDimension_test2) { NDArray exp1('c', {}, {1.5}, nd4j::DataType::FLOAT32); NDArray exp2('c', {2}, {0.5,2.5}, nd4j::DataType::FLOAT32); - auto* scalar1 = x.reduceAlongDimension(nd4j::reduce::Mean, {}/*whole range*/); + auto scalar1 = x.reduceAlongDimension(nd4j::reduce::Mean, {}/*whole range*/); // scalar1->printShapeInfo(); // scalar1->printIndexedBuffer(); - ASSERT_EQ(*scalar1, exp1); + ASSERT_EQ(scalar1, exp1); - auto* scalar2 = x.reduceAlongDimension(nd4j::reduce::Mean, {1}); - ASSERT_EQ(*scalar2, exp2); - - delete scalar1; - delete scalar2; + auto scalar2 = x.reduceAlongDimension(nd4j::reduce::Mean, {1}); + ASSERT_EQ(scalar2, exp2); } //////////////////////////////////////////////////////////////////////////////// @@ -275,10 +268,10 @@ TEST_F(MultiDataTypeTests, ndarray_reduceAlongDimension_test3) { NDArray exp1('c', {}, {8.}, nd4j::DataType::HALF); NDArray exp2('c', {2}, {2.,6.}, nd4j::DataType::HALF); - auto scalar1 = x.reduceAlongDims(nd4j::reduce::Sum, {}/*whole range*/); + auto scalar1 = x.reduceAlongDimension(nd4j::reduce::Sum, {}/*whole range*/); ASSERT_EQ(scalar1, exp1); - auto scalar2 = x.reduceAlongDims(nd4j::reduce::Sum, {1}); + auto scalar2 = x.reduceAlongDimension(nd4j::reduce::Sum, {1}); ASSERT_EQ(scalar2, exp2); } @@ -288,10 +281,10 @@ TEST_F(MultiDataTypeTests, ndarray_reduceAlongDimension_test4) { NDArray exp1('c', {}, {1}, nd4j::DataType::BOOL); NDArray exp2('c', {2}, {1,0}, nd4j::DataType::BOOL); - auto scalar1 = x.reduceAlongDims(nd4j::reduce::IsPositive, {}/*whole range*/); + auto scalar1 = x.reduceAlongDimension(nd4j::reduce::IsPositive, {}/*whole range*/); ASSERT_EQ(scalar1, exp1); - auto scalar2 = x.reduceAlongDims(nd4j::reduce::IsPositive, {1}); + auto scalar2 = x.reduceAlongDimension(nd4j::reduce::IsPositive, {1}); ASSERT_EQ(scalar2, exp2); } @@ -974,22 +967,22 @@ TEST_F(MultiDataTypeTests, ndarray_applyTransformFloat_test1) { NDArray result2('c', {2,2}, nd4j::DataType::DOUBLE); NDArray result3('c', {2,2}, nd4j::DataType::HALF); - x1.applyTransform(nd4j::transform::Sqrt, &result1); + x1.applyTransform(nd4j::transform::Sqrt, result1); ASSERT_EQ(result1, exp1); - x2.applyTransform(nd4j::transform::Sqrt, &result2); + x2.applyTransform(nd4j::transform::Sqrt, result2); ASSERT_EQ(result2, exp2); - x3.applyTransform(nd4j::transform::Sqrt, &result3); + x3.applyTransform(nd4j::transform::Sqrt, result3); ASSERT_EQ(result3, exp3); - x4.applyTransform(nd4j::transform::Sqrt, &result3); + x4.applyTransform(nd4j::transform::Sqrt, result3); ASSERT_EQ(result3, exp4); - x2.applyTransform(nd4j::transform::Sqrt); + x2.applyTransform(nd4j::transform::Sqrt, x2); ASSERT_EQ(x2, exp3); - x3.applyTransform(nd4j::transform::Sqrt); + x3.applyTransform(nd4j::transform::Sqrt, x3); ASSERT_EQ(x3, exp2); } @@ -1016,25 +1009,25 @@ TEST_F(MultiDataTypeTests, ndarray_applyTransformSame_test1) { NDArray result4('c', {2,2}, nd4j::DataType::BOOL); NDArray result5('c', {3,2}, nd4j::DataType::DOUBLE); - x1.applyTransform(nd4j::transform::Square, &result1); + x1.applyTransform(nd4j::transform::Square, result1); ASSERT_EQ(result1, exp1); - x2.applyTransform(nd4j::transform::Square, &result2); + x2.applyTransform(nd4j::transform::Square, result2); ASSERT_EQ(result2, exp2); - x3.applyTransform(nd4j::transform::Square, &result3); + x3.applyTransform(nd4j::transform::Square, result3); ASSERT_EQ(result3, exp3); - x4.applyTransform(nd4j::transform::Square, &result4); + x4.applyTransform(nd4j::transform::Square, result4); ASSERT_EQ(result4, exp4); - x2.applyTransform(nd4j::transform::Square); + x2.applyTransform(nd4j::transform::Square, x2); ASSERT_EQ(x2, exp2); - x3.applyTransform(nd4j::transform::Square); + x3.applyTransform(nd4j::transform::Square, x3); ASSERT_EQ(x3, exp3); - x5.applyTransform(nd4j::transform::Square, &result5); + x5.applyTransform(nd4j::transform::Square, result5); ASSERT_EQ(result5, exp5); } @@ -1057,19 +1050,19 @@ TEST_F(MultiDataTypeTests, ndarray_applyTransformBool_test1) { NDArray result2('c', {3,2}, nd4j::DataType::BOOL); /* - x1.applyTransform(nd4j::transform::IsMax, &result1); + x1.applyTransform(nd4j::transform::IsMax, result1); ASSERT_EQ(result1, exp1); - x2.applyTransform(nd4j::transform::IsMax, &result1); + x2.applyTransform(nd4j::transform::IsMax, result1); ASSERT_EQ(result1, exp1); - x3.applyTransform(nd4j::transform::IsMax, &result1); + x3.applyTransform(nd4j::transform::IsMax, result1); ASSERT_EQ(result1, exp1); - x4.applyTransform(nd4j::transform::IsMax, &result1); + x4.applyTransform(nd4j::transform::IsMax, result1); ASSERT_EQ(result1, exp2); - x5.applyTransform(nd4j::transform::IsMax, &result2); + x5.applyTransform(nd4j::transform::IsMax, result2); ASSERT_EQ(result2, exp3); */ } @@ -1095,28 +1088,28 @@ TEST_F(MultiDataTypeTests, ndarray_applyTransformStrict_test1) { NDArray result3('c', {2,2}, nd4j::DataType::DOUBLE); NDArray result4('c', {3,2}, nd4j::DataType::DOUBLE); - x1.applyTransform(nd4j::transform::CubeDerivative, &result1); + x1.applyTransform(nd4j::transform::CubeDerivative, result1); ASSERT_EQ(result1, exp1); - x2.applyTransform(nd4j::transform::CubeDerivative, &result2); + x2.applyTransform(nd4j::transform::CubeDerivative, result2); ASSERT_EQ(result2, exp2); - x3.applyTransform(nd4j::transform::CubeDerivative, &result3); + x3.applyTransform(nd4j::transform::CubeDerivative, result3); ASSERT_EQ(result3, exp3); - x4.applyTransform(nd4j::transform::CubeDerivative, &result4); + x4.applyTransform(nd4j::transform::CubeDerivative, result4); ASSERT_EQ(result4, exp4); - x1.applyTransform(nd4j::transform::CubeDerivative); + x1.applyTransform(nd4j::transform::CubeDerivative, x1); ASSERT_EQ(x1, exp1); - x2.applyTransform(nd4j::transform::CubeDerivative); + x2.applyTransform(nd4j::transform::CubeDerivative, x2); ASSERT_EQ(x2, exp2); - x3.applyTransform(nd4j::transform::CubeDerivative); + x3.applyTransform(nd4j::transform::CubeDerivative, x3); ASSERT_EQ(x3, exp3); - x4.applyTransform(nd4j::transform::CubeDerivative); + x4.applyTransform(nd4j::transform::CubeDerivative, x4); ASSERT_EQ(x4, exp5); } @@ -1138,19 +1131,19 @@ TEST_F(MultiDataTypeTests, ndarray_applyPairwiseTransform_test1) { NDArray exp4('c', {2,3}, {0.5, 2.5, 4.5, 6.5, 8.5, 5.}, nd4j::DataType::DOUBLE); NDArray exp5('c', {3,2}, {0, 2, 4, 6, 8, 5}, nd4j::DataType::INT32); - x1.applyPairwiseTransform(nd4j::pairwise::Add, &x4, &x5, nullptr); + x1.applyPairwiseTransform(nd4j::pairwise::Add, x4, x5); ASSERT_EQ(x5, exp5); - x1.applyPairwiseTransform(nd4j::pairwise::Add, &x4, &x6, nullptr); + x1.applyPairwiseTransform(nd4j::pairwise::Add, x4, x6); ASSERT_EQ(x6, exp4); - x1.applyPairwiseTransform(nd4j::pairwise::Add, x4, nullptr); + x1.applyPairwiseTransform(nd4j::pairwise::Add, x4); ASSERT_EQ(x1, exp1); - x2.applyPairwiseTransform(nd4j::pairwise::Add, x4, nullptr); + x2.applyPairwiseTransform(nd4j::pairwise::Add, x4); ASSERT_EQ(x2, exp2); - x3.applyPairwiseTransform(nd4j::pairwise::Add, x4, nullptr); + x3.applyPairwiseTransform(nd4j::pairwise::Add, x4); ASSERT_EQ(x3, exp3); } @@ -1173,13 +1166,13 @@ TEST_F(MultiDataTypeTests, ndarray_applyPairwiseTransform_test2) { NDArray exp2('c', {2,3}, {1, 0, 1, 1, 0, 1}, nd4j::DataType::BOOL); NDArray exp3('c', {2,3}, {0, 1, 0, 0, 0, 0}, nd4j::DataType::BOOL); - x1.applyPairwiseTransform(nd4j::pairwise::EqualTo, &x2, &x7, nullptr); + x1.applyPairwiseTransform(nd4j::pairwise::EqualTo, x2, x7); ASSERT_EQ(x7, exp1); - x3.applyPairwiseTransform(nd4j::pairwise::EqualTo, &x4, &x8, nullptr); + x3.applyPairwiseTransform(nd4j::pairwise::EqualTo, x4, x8); ASSERT_EQ(x8, exp2); - x5.applyPairwiseTransform(nd4j::pairwise::EqualTo, &x6, &x8, nullptr); + x5.applyPairwiseTransform(nd4j::pairwise::EqualTo, x6, x8); ASSERT_EQ(x8, exp3); } @@ -1199,13 +1192,13 @@ TEST_F(MultiDataTypeTests, ndarray_applyBroadcast_test1) { NDArray exp2('c', {2,3}, {11, 21, 31, 42, 52, 62}, nd4j::DataType::FLOAT32); NDArray exp3('c', {2,3}, {11, 21, 31, 41, 51, 61}, nd4j::DataType::INT32); - x1.applyBroadcast(nd4j::broadcast::Add, {0}, &x2, &x3); + x1.applyBroadcast(nd4j::broadcast::Add, {0}, x2, x3); ASSERT_EQ(x3, exp1); - x1.applyBroadcast(nd4j::broadcast::Add, {0}, &x4, &x5); + x1.applyBroadcast(nd4j::broadcast::Add, {0}, x4, x5); ASSERT_EQ(x5, exp2); - x1.applyBroadcast(nd4j::broadcast::Add, {0}, &x6, &x3); + x1.applyBroadcast(nd4j::broadcast::Add, {0}, x6, x3); ASSERT_EQ(x3, exp3); } @@ -1222,10 +1215,10 @@ TEST_F(MultiDataTypeTests, ndarray_applyBroadcast_test2) { NDArray exp1('c', {2,3}, {1, 0, 0, 0, 0, 1}, nd4j::DataType::BOOL); NDArray exp2('c', {2,3}, {1, 1, 1, 0, 0, 1}, nd4j::DataType::BOOL); - x1.applyBroadcast(nd4j::broadcast::EqualTo, {0}, &x2, &x3); + x1.applyBroadcast(nd4j::broadcast::EqualTo, {0}, x2, x3); ASSERT_EQ(x3, exp1); - x4.applyBroadcast(nd4j::broadcast::EqualTo, {0}, &x5, &x3); + x4.applyBroadcast(nd4j::broadcast::EqualTo, {0}, x5, x3); ASSERT_EQ(x3, exp2); } @@ -1256,13 +1249,13 @@ TEST_F(MultiDataTypeTests, ndarray_applyTrueBroadcast_test1) { NDArray exp4('c', {0}, {4.5}, nd4j::DataType::DOUBLE); NDArray exp5('c', {2,2}, {11.5, 21.5, 31.5, 41.5}, nd4j::DataType::DOUBLE); - x1.applyTrueBroadcast(nd4j::BroadcastOpsTuple::Add(), &x2, &x3, true); + x1.applyTrueBroadcast(nd4j::BroadcastOpsTuple::Add(), x2, x3); ASSERT_EQ(x3, exp1); - x1.applyTrueBroadcast(nd4j::BroadcastOpsTuple::Add(), &x4, &x5, true); + x1.applyTrueBroadcast(nd4j::BroadcastOpsTuple::Add(), x4, x5); ASSERT_EQ(x5, exp2); - x6.applyTrueBroadcast(nd4j::BroadcastOpsTuple::Add(), &x7, &x8, true); + x6.applyTrueBroadcast(nd4j::BroadcastOpsTuple::Add(), x7, x8); ASSERT_EQ(x8, exp3); auto x9 = x1.applyTrueBroadcast(nd4j::BroadcastOpsTuple::Add(), x2); @@ -1274,17 +1267,16 @@ TEST_F(MultiDataTypeTests, ndarray_applyTrueBroadcast_test1) { auto x11 = x6.applyTrueBroadcast(nd4j::BroadcastOpsTuple::Add(), x7); ASSERT_EQ(x11, exp3); - auto x12 = x1.applyTrueBroadcast(nd4j::BroadcastOpsTuple::Add(), &x2); - ASSERT_EQ(*x12, exp1); - delete x12; + auto x12 = x1.applyTrueBroadcast(nd4j::BroadcastOpsTuple::Add(), x2); + ASSERT_EQ(x12, exp1); - x13.applyTrueBroadcast(nd4j::BroadcastOpsTuple::Add(), &x14, &x15, true); + x13.applyTrueBroadcast(nd4j::BroadcastOpsTuple::Add(), x14, x15); ASSERT_EQ(x15, exp4); - x1.applyTrueBroadcast(nd4j::BroadcastOpsTuple::Add(), &x14, &x16, true); + x1.applyTrueBroadcast(nd4j::BroadcastOpsTuple::Add(), x14, x16); ASSERT_EQ(x16, exp5); - x14.applyTrueBroadcast(nd4j::BroadcastOpsTuple::Add(), &x1, &x16, true); + x14.applyTrueBroadcast(nd4j::BroadcastOpsTuple::Add(), x1, x16); ASSERT_EQ(x16, exp5); } @@ -1305,16 +1297,16 @@ TEST_F(MultiDataTypeTests, ndarray_applyTrueBroadcast_test2) { NDArray exp2('c', {2,2}, {1, 0, 0, 0}, nd4j::DataType::BOOL); NDArray exp3('c', {0}, {0}, nd4j::DataType::BOOL); - x1.applyTrueBroadcast(BroadcastBoolOpsTuple(nd4j::scalar::EqualTo, nd4j::pairwise::EqualTo, nd4j::broadcast::EqualTo), &x2, &x3, true); + x1.applyTrueBroadcast(BroadcastBoolOpsTuple(nd4j::scalar::EqualTo, nd4j::pairwise::EqualTo, nd4j::broadcast::EqualTo), x2, x3); ASSERT_EQ(x3, exp1); - x1.applyTrueBroadcast(BroadcastBoolOpsTuple(nd4j::scalar::EqualTo, nd4j::pairwise::EqualTo, nd4j::broadcast::EqualTo), &x4, &x3, true); + x1.applyTrueBroadcast(BroadcastBoolOpsTuple(nd4j::scalar::EqualTo, nd4j::pairwise::EqualTo, nd4j::broadcast::EqualTo), x4, x3); ASSERT_EQ(x3, exp2); - x4.applyTrueBroadcast(BroadcastBoolOpsTuple(nd4j::scalar::EqualTo, nd4j::pairwise::EqualTo, nd4j::broadcast::EqualTo), &x1, &x3, true); + x4.applyTrueBroadcast(BroadcastBoolOpsTuple(nd4j::scalar::EqualTo, nd4j::pairwise::EqualTo, nd4j::broadcast::EqualTo), x1, x3); ASSERT_EQ(x3, exp2); - x5.applyTrueBroadcast(BroadcastBoolOpsTuple(nd4j::scalar::EqualTo, nd4j::pairwise::EqualTo, nd4j::broadcast::EqualTo), &x4, &x6, true); + x5.applyTrueBroadcast(BroadcastBoolOpsTuple(nd4j::scalar::EqualTo, nd4j::pairwise::EqualTo, nd4j::broadcast::EqualTo), x4, x6); ASSERT_EQ(x6, exp3); } @@ -1334,19 +1326,19 @@ TEST_F(MultiDataTypeTests, ndarray_applyScalar_test1) { NDArray exp4('c', {2,2}, {1.1, 2.1, 1.1, 2.1}, nd4j::DataType::DOUBLE); NDArray exp5('c', {2,2}, {1, 1, 1, 1}, nd4j::DataType::BOOL); - x1.applyScalar(nd4j::scalar::Add, 1); + x1.applyScalar(nd4j::scalar::Add, 1, x1); ASSERT_EQ(x1, exp1); - x1.applyScalar(nd4j::scalar::Add, 0.5, &x3); + x1.applyScalar(nd4j::scalar::Add, 0.5, x3); ASSERT_EQ(x3, exp2); - x2.applyScalar(nd4j::scalar::Add, 0.1); + x2.applyScalar(nd4j::scalar::Add, 0.1, x2); ASSERT_EQ(x2, exp3); - x4.applyScalar(nd4j::scalar::Add, 1.1, &x3); + x4.applyScalar(nd4j::scalar::Add, 1.1, x3); ASSERT_EQ(x3, exp4); - x4.applyScalar(nd4j::scalar::Add, 1); + x4.applyScalar(nd4j::scalar::Add, 1, x4); ASSERT_EQ(x4, exp5); } @@ -1362,13 +1354,13 @@ TEST_F(MultiDataTypeTests, ndarray_applyScalar_test2) { NDArray exp1('c', {2,2}, {0, 1, 0, 0}, nd4j::DataType::BOOL); NDArray exp2('c', {2,2}, {0, 1, 1, 0}, nd4j::DataType::BOOL); - x1.applyScalar(nd4j::scalar::EqualTo, 1, &x4); + x1.applyScalar(nd4j::scalar::EqualTo, 1, x4); ASSERT_EQ(x4, exp1); - x2.applyScalar(nd4j::scalar::EqualTo, 1.5, &x4); + x2.applyScalar(nd4j::scalar::EqualTo, 1.5, x4); ASSERT_EQ(x4, exp1); - x3.applyScalar(nd4j::scalar::EqualTo, true, &x4); + x3.applyScalar(nd4j::scalar::EqualTo, true, x4); ASSERT_EQ(x4, exp2); } @@ -1399,22 +1391,23 @@ TEST_F(MultiDataTypeTests, ndarray_applyLambda_test1) { NDArray exp4('c', {2,2}, {0.1, 1.6, 2.6, 3.6}, nd4j::DataType::FLOAT32); NDArray exp5('c', {2,2}, {1, 0, 0, 0}, nd4j::DataType::BOOL); - x1.applyLambda(func1, &x4); + x1.applyLambda(func1, x4); ASSERT_EQ(x4, exp1); - x2.applyLambda(func1); + x2.applyLambda(func1, x2); ASSERT_EQ(x2, exp2); - x2.applyLambda(func2); + x2.applyLambda(func2, x2); ASSERT_EQ(x2, exp2); - x3.applyLambda(func3); + x3.applyLambda(func3, x3); ASSERT_EQ(x3, exp3); - x5.applyLambda(func4); + x5.applyLambda(func4, x5); + // x5.printBuffer(); ASSERT_EQ(x5, exp4); - x6.applyLambda(func5, &x7); + x6.applyLambda(func5, x7); ASSERT_EQ(x7, exp5); } @@ -1444,22 +1437,22 @@ TEST_F(MultiDataTypeTests, ndarray_applyIndexedLambda_test1) { NDArray exp5('c', {2,2}, {0, 1, 1, 1}, nd4j::DataType::BOOL); NDArray exp6('c', {2,2}, {0, 3, 6, 9}, nd4j::DataType::INT64); - x1.applyIndexedLambda(func1, &x4); + x1.applyIndexedLambda(func1, x4); ASSERT_EQ(x4, exp1); - x2.applyIndexedLambda(func1); + x2.applyIndexedLambda(func1, x2); ASSERT_EQ(x2, exp2); - x2.applyIndexedLambda(func2); + x2.applyIndexedLambda(func2, x2); ASSERT_EQ(x2, exp6); - x3.applyIndexedLambda(func3); + x3.applyIndexedLambda(func3, x3); ASSERT_EQ(x3, exp3); - x5.applyIndexedLambda(func4); + x5.applyIndexedLambda(func4, x5); ASSERT_EQ(x5, exp4); - x6.applyIndexedLambda(func5, &x7); + x6.applyIndexedLambda(func5, x7); ASSERT_EQ(x7, exp5); } @@ -1490,22 +1483,22 @@ TEST_F(MultiDataTypeTests, ndarray_applyPairwiseLambda_test1) { NDArray exp4('c', {2,2}, {0.1, 1.6, 2.6, 3.6}, nd4j::DataType::FLOAT32); NDArray exp5('c', {2,2}, {0, 1, 0, 1}, nd4j::DataType::BOOL); - x1.applyPairwiseLambda(&other2, func1, &x4); + x1.applyPairwiseLambda(other2, func1, x4); ASSERT_EQ(x4, exp1); - x2.applyPairwiseLambda(&other3, func1); + x2.applyPairwiseLambda(other3, func1, x2); ASSERT_EQ(x2, exp2); - x2.applyPairwiseLambda(&other3, func2); + x2.applyPairwiseLambda(other3, func2, x2); ASSERT_EQ(x2, other3); - x3.applyPairwiseLambda(&other1, func3); + x3.applyPairwiseLambda(other1, func3, x3); ASSERT_EQ(x3, exp3); - x5.applyPairwiseLambda(&other1, func4); + x5.applyPairwiseLambda(other1, func4, x5); ASSERT_EQ(x5, exp4); - x6.applyPairwiseLambda(&other4, func5, &x7); + x6.applyPairwiseLambda(other4, func5, x7); ASSERT_EQ(x7, exp5); } @@ -1536,22 +1529,22 @@ TEST_F(MultiDataTypeTests, ndarray_applyIndexedPairwiseLambda_test1) { NDArray exp4('c', {2,2}, {0.1, 2.6, 4.6, 6.6}, nd4j::DataType::FLOAT32); NDArray exp5('c', {2,2}, {0, 1, 1, 1}, nd4j::DataType::BOOL); - x1.applyIndexedPairwiseLambda(&other2, func1, &x4); + x1.applyIndexedPairwiseLambda(other2, func1, x4); ASSERT_EQ(x4, exp1); - x2.applyIndexedPairwiseLambda(&other3, func1); + x2.applyIndexedPairwiseLambda(other3, func1, x2); ASSERT_EQ(x2, exp2); - x2.applyIndexedPairwiseLambda(&other3, func2); + x2.applyIndexedPairwiseLambda(other3, func2, x2); ASSERT_EQ(x2, exp2); - x3.applyIndexedPairwiseLambda(&other1, func3); + x3.applyIndexedPairwiseLambda(other1, func3, x3); ASSERT_EQ(x3, exp3); - x5.applyIndexedPairwiseLambda(&other1, func4); + x5.applyIndexedPairwiseLambda(other1, func4, x5); ASSERT_EQ(x5, exp4); - x6.applyIndexedPairwiseLambda(&other4, func5, &x7); + x6.applyIndexedPairwiseLambda(other4, func5, x7); ASSERT_EQ(x7, exp5); } @@ -1578,16 +1571,16 @@ TEST_F(MultiDataTypeTests, ndarray_applyTriplewiseLambda_test1) { NDArray exp('c', {2,2}, {1, 1, 0, 1}, nd4j::DataType::BOOL); - x1.applyTriplewiseLambda(&x2, &x3, func1, &x4); + x1.applyTriplewiseLambda(x2, x3, func1, x4); ASSERT_EQ(x4, x2); - x1.applyTriplewiseLambda(&x2, &x3, func2); + x1.applyTriplewiseLambda(x2, x3, func2, x1); ASSERT_EQ(x1, x3); - x5.applyTriplewiseLambda(&x6, &x7, func3); + x5.applyTriplewiseLambda(x6, x7, func3, x5); ASSERT_EQ(x5, x7); - x8.applyTriplewiseLambda(&x9, &x10, func4); + x8.applyTriplewiseLambda(x9, x10, func4, x8); ASSERT_EQ(x8, exp); } @@ -1601,18 +1594,14 @@ TEST_F(MultiDataTypeTests, ndarray_applyIndexReduce_test1) { NDArray exp2('c', {2}, {2,2}, nd4j::DataType::INT64); NDArray exp3('c', {3}, {1,1,1}, nd4j::DataType::INT64); - NDArray* scalar = x1.applyIndexReduce(nd4j::indexreduce::IndexMax, {0,1}); - ASSERT_EQ(*scalar, exp1); + NDArray scalar = x1.applyIndexReduce(nd4j::indexreduce::IndexMax, {0,1}); + ASSERT_EQ(scalar, exp1); - NDArray* vec1 = x1.applyIndexReduce(nd4j::indexreduce::IndexMax, {1}); - ASSERT_EQ(*vec1, exp2); + NDArray vec1 = x1.applyIndexReduce(nd4j::indexreduce::IndexMax, {1}); + ASSERT_EQ(vec1, exp2); - NDArray* vec2 = x1.applyIndexReduce(nd4j::indexreduce::IndexMax, {0}); - ASSERT_EQ(*vec2, exp3); - - delete scalar; - delete vec1; - delete vec2; + NDArray vec2 = x1.applyIndexReduce(nd4j::indexreduce::IndexMax, {0}); + ASSERT_EQ(vec2, exp3); } ////////////////////////////////////////////////////////////////////////////// @@ -1626,13 +1615,13 @@ TEST_F(MultiDataTypeTests, ndarray_applyIndexReduce_test2) { NDArray exp2('c', {2}, {2,2}, nd4j::DataType::INT64); NDArray exp3('c', {3}, {1,1,1}, nd4j::DataType::INT64); - x1.applyIndexReduce(nd4j::indexreduce::IndexMax, &scalar, {0,1}); + x1.applyIndexReduce(nd4j::indexreduce::IndexMax, scalar, {0,1}); ASSERT_EQ(scalar, exp1); - x1.applyIndexReduce(nd4j::indexreduce::IndexMax, &vec1, {1}); + x1.applyIndexReduce(nd4j::indexreduce::IndexMax, vec1, {1}); ASSERT_EQ(vec1, exp2); - x1.applyIndexReduce(nd4j::indexreduce::IndexMax, &vec2, {0}); + x1.applyIndexReduce(nd4j::indexreduce::IndexMax, vec2, {0}); ASSERT_EQ(vec2, exp3); } @@ -1646,13 +1635,11 @@ TEST_F(MultiDataTypeTests, applyReduce3_test1) { NDArray exp1('c', {}, {-30}, nd4j::DataType::FLOAT32); NDArray exp2('c', {}, {15}, nd4j::DataType::DOUBLE); - auto result = x1.applyReduce3(reduce3::Dot, &x2); - ASSERT_EQ(*result, exp1); - delete result; + auto result = x1.applyReduce3(reduce3::Dot, x2); + ASSERT_EQ(result, exp1); - result = x3.applyReduce3(reduce3::Dot, &x4); - ASSERT_EQ(*result, exp2); - delete result; + result = x3.applyReduce3(reduce3::Dot, x4); + ASSERT_EQ(result, exp2); } ////////////////////////////////////////////////////////////////////// @@ -1674,29 +1661,23 @@ TEST_F(MultiDataTypeTests, applyReduce3_test2) { NDArray exp5('c', {3}, {7.5,10.5,13.5}, nd4j::DataType::DOUBLE); NDArray exp6('c', {2}, {9,22.5}, nd4j::DataType::DOUBLE); - auto result = x1.applyReduce3(reduce3::Dot, &x2, {0,1}); - ASSERT_EQ(*result, exp1); - delete result; + auto result = x1.applyReduce3(reduce3::Dot, x2, {0,1}); + ASSERT_EQ(result, exp1); - result = x3.applyReduce3(reduce3::Dot, &x4, {0,1}); - ASSERT_EQ(*result, exp2); - delete result; + result = x3.applyReduce3(reduce3::Dot, x4, {0,1}); + ASSERT_EQ(result, exp2); - result = x5.applyReduce3(reduce3::Dot, &x6, std::vector({0})); - ASSERT_EQ(*result, exp3); - delete result; + result = x5.applyReduce3(reduce3::Dot, x6, std::vector({0})); + ASSERT_EQ(result, exp3); - result = x5.applyReduce3(reduce3::Dot, &x6, std::vector({1})); - ASSERT_EQ(*result, exp4); - delete result; + result = x5.applyReduce3(reduce3::Dot, x6, std::vector({1})); + ASSERT_EQ(result, exp4); - result = x8.applyReduce3(reduce3::Dot, &x7, std::vector({0})); - ASSERT_EQ(*result, exp5); - delete result; + result = x8.applyReduce3(reduce3::Dot, x7, std::vector({0})); + ASSERT_EQ(result, exp5); - result = x8.applyReduce3(reduce3::Dot, &x7, std::vector({1})); - ASSERT_EQ(*result, exp6); - delete result; + result = x8.applyReduce3(reduce3::Dot, x7, std::vector({1})); + ASSERT_EQ(result, exp6); } ////////////////////////////////////////////////////////////////////// @@ -1709,13 +1690,11 @@ TEST_F(MultiDataTypeTests, applyAllReduce3_test1) { NDArray exp1('c', {2,3}, {2,-2,2,2,-2,2}, nd4j::DataType::FLOAT32); NDArray exp2('c', {2,3}, {6,6,6,9,9,9}, nd4j::DataType::DOUBLE); - auto result = x1.applyAllReduce3(reduce3::Dot, &x2, {0}); - ASSERT_EQ(*result, exp1); - delete result; + auto result = x1.applyAllReduce3(reduce3::Dot, x2, {0}); + ASSERT_EQ(result, exp1); - result = x4.applyAllReduce3(reduce3::Dot, &x3, {0}); - ASSERT_EQ(*result, exp2); - delete result; + result = x4.applyAllReduce3(reduce3::Dot, x3, {0}); + ASSERT_EQ(result, exp2); } ////////////////////////////////////////////////////////////////////// @@ -1734,16 +1713,16 @@ TEST_F(MultiDataTypeTests, RowCol_test1) { NDArray exp3('c', {2,3}, {1.5,2.5,3.5,4.6,5.6,6.6}, nd4j::DataType::DOUBLE); NDArray exp4('c', {2,3}, {0,1,1,2,3,3}, nd4j::DataType::INT32); - x1.addiRowVector(&x3); + x1.addiRowVector(x3); ASSERT_EQ(x1, exp1); - x1.addiColumnVector(&x2); + x1.addiColumnVector(x2); ASSERT_EQ(x1, exp1); - x4.addiColumnVector(&x2); + x4.addiColumnVector(x2); ASSERT_EQ(x4, exp3); - x5.muliColumnVector(&x2); + x5.muliColumnVector(x2); ASSERT_EQ(x5, exp4); } @@ -1770,22 +1749,22 @@ TEST_F(MultiDataTypeTests, RowCol_test2) { NDArray exp5('c', {2,3}, {1,1,1,4,2.5,2}, nd4j::DataType::DOUBLE); NDArray exp6('c', {2,3}, {1.5,2.5,3.5,4.6,5.6,6.6}, nd4j::DataType::FLOAT32); - x1.addRowVector(&x3, &x4); + x1.addRowVector(x3, x4); ASSERT_EQ(x4, exp1); - x1.addRowVector(&x5, &x6); + x1.addRowVector(x5, x6); ASSERT_EQ(x6, exp2); - x8.subRowVector(&x7, &x4); + x8.subRowVector(x7, x4); ASSERT_EQ(x4, exp3); - x1.mulRowVector(&x9, &x10); + x1.mulRowVector(x9, x10); ASSERT_EQ(x10, exp4); - x1.divRowVector(&x9, &x10); + x1.divRowVector(x9, x10); ASSERT_EQ(x10, exp5); - x1.addColumnVector(&x2, &x4); + x1.addColumnVector(x2, x4); ASSERT_EQ(x4, exp6); } @@ -1826,25 +1805,6 @@ TEST_F(MultiDataTypeTests, tile_test1) { } */ -////////////////////////////////////////////////////////////////////// -TEST_F(MultiDataTypeTests, broadcast_test1) { - - NDArray x1('c', {2,1,3}, nd4j::DataType::INT32); - NDArray x2('c', {2,4,1}, nd4j::DataType::INT64); - NDArray x3('c', {2,4,1}, nd4j::DataType::DOUBLE); - - NDArray exp1('c', {2,4,3}, nd4j::DataType::INT32); - NDArray exp2('c', {2,4,3}, nd4j::DataType::DOUBLE); - - auto result = x1.broadcast(x2); - ASSERT_TRUE(result->isSameShapeStrict(&exp1)); - delete result; - - result = x1.broadcast(x3); - ASSERT_TRUE(result->isSameShapeStrict(&exp2)); - delete result; -} - ////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, asT_test1) { @@ -1853,19 +1813,19 @@ TEST_F(MultiDataTypeTests, asT_test1) { NDArray exp1('c', {2}, {1, 2}, nd4j::DataType::INT32); NDArray exp2('c', {2}, {1.5, 2.5}, nd4j::DataType::DOUBLE); - auto result = x1.asT(); + auto result = new NDArray(x1.asT()); ASSERT_EQ(*result, exp1); delete result; - result = x1.asT(); + result = new NDArray(x1.asT()); ASSERT_EQ(*result, exp2); delete result; - result = x1.asT(nd4j::DataType::INT32); + result = new NDArray(x1.asT(nd4j::DataType::INT32)); ASSERT_EQ(*result, exp1); delete result; - result = x1.asT(nd4j::DataType::DOUBLE); + result = new NDArray(x1.asT(nd4j::DataType::DOUBLE)); ASSERT_EQ(*result, exp2); delete result; } @@ -1904,7 +1864,7 @@ TEST_F(MultiDataTypeTests, Test_Cast_1) { asBool.assign(first); // asBool.printIndexedBuffer("asBool"); - asBool.applyScalar(scalar::Not, false, &_not); + asBool.applyScalar(scalar::Not, false, _not); // _not.printIndexedBuffer("_not"); @@ -1925,7 +1885,7 @@ TEST_F(MultiDataTypeTests, Test_Cast_2) { asBool.assign(first); // asBool.printIndexedBuffer("asBool"); - asBool.applyTransform(transform::Not, &_not); + asBool.applyTransform(transform::Not, _not); // _not.printIndexedBuffer("_not"); @@ -1968,7 +1928,7 @@ TEST_F(MultiDataTypeTests, divide_bool_test1) { } try { - x1.divRowVector(&x4, &x3); + x1.divRowVector(x4, x3); } catch (std::exception& message) { // printf("%s\n", message.what()); @@ -1976,7 +1936,7 @@ TEST_F(MultiDataTypeTests, divide_bool_test1) { } try { - x1.applyBroadcast(nd4j::broadcast::FloorDiv, {1}, &x4, &x3); + x1.applyBroadcast(nd4j::broadcast::FloorDiv, {1}, x4, x3); } catch (std::exception& message) { // printf("%s\n", message.what()); @@ -1984,7 +1944,7 @@ TEST_F(MultiDataTypeTests, divide_bool_test1) { } try { - x1.applyTrueBroadcast(BROADCAST(FloorMod), &x2, &x3, true); + x1.applyTrueBroadcast(BROADCAST(FloorMod), x2, x3); } catch (std::exception& message) { // printf("%s\n", message.what()); diff --git a/libnd4j/tests_cpu/layers_tests/NDArrayCudaBasicsTests.cu b/libnd4j/tests_cpu/layers_tests/NDArrayCudaBasicsTests.cu index 7740cd1ac..c6c0a1bd8 100644 --- a/libnd4j/tests_cpu/layers_tests/NDArrayCudaBasicsTests.cu +++ b/libnd4j/tests_cpu/layers_tests/NDArrayCudaBasicsTests.cu @@ -125,7 +125,7 @@ TEST_F(NDArrayCudaBasicsTests, Test_Registration_03) { ASSERT_FALSE(x->isActualOnHostSide()); NDArray::registerSpecialUse({y}, {x}); - x->applyTransform(transform::Neg, y, nullptr); + x->applyTransform(transform::Neg, *y); //ASSERT_TRUE(x->isActualOnDeviceSide()); //ASSERT_FALSE(x->isActualOnHostSide()); @@ -145,7 +145,7 @@ TEST_F(NDArrayCudaBasicsTests, Test_Cosine_1) { ASSERT_FALSE(x->isActualOnHostSide()); NDArray::registerSpecialUse({y}, {x}); - x->applyTransform(transform::Cosine, y, nullptr); + x->applyTransform(transform::Cosine, *y); //ASSERT_TRUE(x->isActualOnDeviceSide()); //ASSERT_FALSE(x->isActualOnHostSide()); @@ -247,8 +247,10 @@ TEST_F(NDArrayCudaBasicsTests, TestAdd_3) { auto res = cudaStreamSynchronize(*stream); ASSERT_EQ(0, res); //double* localBuffer = ; + z.syncToHost(); cudaMemcpy(z.buffer(), z.specialBuffer(), z.lengthOf() * z.sizeOfT(), cudaMemcpyDeviceToHost); res = cudaStreamSynchronize(*stream); + z.tickWriteHost(); ASSERT_EQ(0, res); // @@ -278,7 +280,7 @@ TEST_F(NDArrayCudaBasicsTests, TestAdd_4) { //ASSERT_EQ(0, res); //res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); //ASSERT_EQ(0, res); - x.applyPairwiseTransform(pairwise::Add, &y, &z, nullptr); + x.applyPairwiseTransform(pairwise::Add, y, z); // // cudaFree(devBufferPtrX); @@ -400,7 +402,7 @@ TEST_F(NDArrayCudaBasicsTests, TestMultiply_1) { //ASSERT_EQ(0, res); //res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); //ASSERT_EQ(0, res); - x.applyPairwiseTransform(pairwise::Multiply, &y, &z, nullptr); + x.applyPairwiseTransform(pairwise::Multiply, y, z); // x.printBuffer("3X = "); // y.printBuffer("3Y = "); // z.printBuffer("3Result out"); @@ -432,7 +434,7 @@ TEST_F(NDArrayCudaBasicsTests, TestMultiply_2) { //ASSERT_EQ(0, res); //res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); //ASSERT_EQ(0, res); - x.applyPairwiseTransform(pairwise::Multiply, &y, &z, nullptr); + x.applyPairwiseTransform(pairwise::Multiply, y, z); // // cudaFree(devBufferPtrX); @@ -461,7 +463,7 @@ TEST_F(NDArrayCudaBasicsTests, TestMultiply_3) { //ASSERT_EQ(0, res); //res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); //ASSERT_EQ(0, res); - x.applyPairwiseTransform(pairwise::Multiply, &y, &z, nullptr); + x.applyPairwiseTransform(pairwise::Multiply, y, z); //x.printBuffer("23X = "); //y.printBuffer("23Y = "); // z.printBuffer("23Result out"); @@ -539,7 +541,7 @@ TEST_F(NDArrayCudaBasicsTests, Test_PrimitiveNeg_2) { ASSERT_TRUE(x.isActualOnDeviceSide()); ASSERT_FALSE(x.isActualOnHostSide()); - x.applyTransform(transform::Neg, &y, nullptr); + x.applyTransform(transform::Neg, y); //ASSERT_TRUE(x->isActualOnDeviceSide()); //ASSERT_FALSE(x->isActualOnHostSide()); @@ -559,7 +561,7 @@ TEST_F(NDArrayCudaBasicsTests, Test_PrimitiveSqrt_1) { // strict ASSERT_TRUE(x.isActualOnDeviceSide()); ASSERT_FALSE(x.isActualOnHostSide()); - x.applyTransform(transform::Sqrt, &y, nullptr); + x.applyTransform(transform::Sqrt, y); //ASSERT_TRUE(x->isActualOnDeviceSide()); //ASSERT_FALSE(x->isActualOnHostSide()); @@ -580,7 +582,7 @@ TEST_F(NDArrayCudaBasicsTests, Test_PrimitiveAssign_1) { // strict //ASSERT_TRUE(x.isActualOnDeviceSide()); //ASSERT_TRUE(x.isActualOnHostSide()); - x.applyTransform(transform::Assign, &y, nullptr); + x.applyTransform(transform::Assign, y); //ASSERT_TRUE(x->isActualOnDeviceSide()); //ASSERT_FALSE(x->isActualOnHostSide()); @@ -606,7 +608,7 @@ TEST_F(NDArrayCudaBasicsTests, Test_PrimitiveCosine_1) { // strict ASSERT_TRUE(x.isActualOnDeviceSide()); ASSERT_FALSE(x.isActualOnHostSide()); - x.applyTransform(transform::Cosine, &y, nullptr); + x.applyTransform(transform::Cosine, y); //ASSERT_TRUE(x->isActualOnDeviceSide()); //ASSERT_FALSE(x->isActualOnHostSide()); @@ -628,7 +630,7 @@ TEST_F(NDArrayCudaBasicsTests, Test_PrimitiveCosine_2) { ASSERT_TRUE(x.isActualOnDeviceSide()); ASSERT_FALSE(x.isActualOnHostSide()); - x.applyTransform(transform::Cosine, &y, nullptr); + x.applyTransform(transform::Cosine, y); //ASSERT_TRUE(x->isActualOnDeviceSide()); //ASSERT_FALSE(x->isActualOnHostSide()); @@ -657,7 +659,7 @@ TEST_F(NDArrayCudaBasicsTests, Test_PrimitiveCosine_3) { ASSERT_TRUE(x.isActualOnDeviceSide()); ASSERT_FALSE(x.isActualOnHostSide()); - x.applyTransform(transform::Cosine, &y, nullptr); + x.applyTransform(transform::Cosine, y); //ASSERT_TRUE(x->isActualOnDeviceSide()); //ASSERT_FALSE(x->isActualOnHostSide()); @@ -857,7 +859,7 @@ TEST_F(NDArrayCudaBasicsTests, TestBroadcastMultiply_01) { //x.applyPairwiseTransform(pairwise::Multiply, &y, &z, nullptr); //x.printBuffer("23X = "); //y.printBuffer("23Y = "); - x.applyTrueBroadcast(BroadcastOpsTuple::Multiply(), &y, &z);// *= y; + x.applyTrueBroadcast(BroadcastOpsTuple::Multiply(), y, z);// *= y; // z.printBuffer("53Result out"); // @@ -890,7 +892,7 @@ TEST_F(NDArrayCudaBasicsTests, TestBroadcastMultiply_02) { //x.applyPairwiseTransform(pairwise::Multiply, &y, &z, nullptr); //x.printBuffer("23X = "); //y.printBuffer("23Y = "); - x.applyTrueBroadcast(BroadcastOpsTuple::Multiply(), &y, &z);// *= y; + x.applyTrueBroadcast(BroadcastOpsTuple::Multiply(), y, z);// *= y; // z.printBuffer("52Result out"); @@ -924,7 +926,7 @@ TEST_F(NDArrayCudaBasicsTests, TestBroadcastMultiply_002) { //x.applyPairwiseTransform(pairwise::Multiply, &y, &z, nullptr); //x.printBuffer("23X = "); //y.printBuffer("23Y = "); - x.applyPairwiseTransform(pairwise::Multiply, &y, &z);// *= y; + x.applyPairwiseTransform(pairwise::Multiply, y, z);// *= y; // z.printBuffer("51Result out"); @@ -1059,7 +1061,7 @@ TEST_F(NDArrayCudaBasicsTests, TestBroadcastMultiply_2) { //x.printBuffer("23X = "); //y.printBuffer("23Y = "); //void NDArray::applyTrueBroadcast(nd4j::BroadcastOpsTuple op, const NDArray* other, NDArray* target, const bool checkTargetShape, ExtraArguments *extraArgs) - x.applyTrueBroadcast(BroadcastOpsTuple::Multiply(), &y, &exp); + x.applyTrueBroadcast(BroadcastOpsTuple::Multiply(), y, exp); // // cudaFree(devBufferPtrX); @@ -1106,10 +1108,7 @@ TEST_F(NDArrayCudaBasicsTests, TestDup1) { ASSERT_TRUE(array.equalsTo(arrF)); ASSERT_TRUE(array.equalsTo(arrC)); - ASSERT_TRUE(arrF->equalsTo(arrC)); - - delete arrC; - delete arrF; + ASSERT_TRUE(arrF.equalsTo(arrC)); } ////////////////////////////////////////////////////////////////////////// @@ -1169,27 +1168,22 @@ TEST_F(NDArrayCudaBasicsTests, applyReduce3_1) { NDArray exp4('c', {4}, {114.f, 117.f, 120.f, 123.f}, nd4j::DataType::FLOAT32); - NDArray* z = x.applyReduce3(nd4j::reduce3::Dot, &y, {0,2}); - ASSERT_TRUE(z->equalsTo(&exp1)); - delete z; + NDArray z = x.applyReduce3(nd4j::reduce3::Dot, y, {0,2}); + ASSERT_TRUE(z.equalsTo(&exp1)); - z = x.applyReduce3(nd4j::reduce3::Dot, &k, {0,1}); - ASSERT_TRUE(z->equalsTo(&exp3)); - delete z; + z = x.applyReduce3(nd4j::reduce3::Dot, k, {0,1}); + ASSERT_TRUE(z.equalsTo(&exp3)); x.permutei({0,2,1}); y.permutei({0,2,1}); - z = y.applyReduce3(nd4j::reduce3::Dot, &x, {1}); - ASSERT_TRUE(z->equalsTo(&exp2)); - // printCudaGlobal<<<1,1,0, *y.getContext()->getCudaStream()>>>(z->specialBuffer(), 6); - delete z; + z = y.applyReduce3(nd4j::reduce3::Dot, x, {1}); + ASSERT_TRUE(z.equalsTo(&exp2)); x2.permutei({1,0,2}); - z = x2.applyReduce3(nd4j::reduce3::Dot, &k2, {0,1}); - ASSERT_TRUE(z->equalsTo(&exp4)); - delete z; + z = x2.applyReduce3(nd4j::reduce3::Dot, k2, {0,1}); + ASSERT_TRUE(z.equalsTo(&exp4)); } //////////////////////////////////////////////////////////////////////////// @@ -1206,27 +1200,22 @@ TEST_F(NDArrayCudaBasicsTests, applyReduce3_2) { NDArray exp3('c', {4}, {39., 42.5, 47., 49.5}, nd4j::DataType::DOUBLE); NDArray exp4('c', {4}, {119., 122.5, 125., 129.5}, nd4j::DataType::DOUBLE); - NDArray* z = x.applyReduce3(nd4j::reduce3::Dot, &y, {0,2}); - ASSERT_TRUE(z->equalsTo(&exp1)); - delete z; + NDArray z = x.applyReduce3(nd4j::reduce3::Dot, y, {0,2}); + ASSERT_TRUE(z.equalsTo(&exp1)); - z = x.applyReduce3(nd4j::reduce3::Dot, &k, {0,1}); - ASSERT_TRUE(z->equalsTo(&exp3)); - delete z; + z = x.applyReduce3(nd4j::reduce3::Dot, k, {0,1}); + ASSERT_TRUE(z.equalsTo(&exp3)); x.permutei({0,2,1}); y.permutei({0,2,1}); - z = y.applyReduce3(nd4j::reduce3::Dot, &x, {1}); - ASSERT_TRUE(z->equalsTo(&exp2)); - // printCudaGlobal<<<1,1,0, *y.getContext()->getCudaStream()>>>(z->specialBuffer(), 6); - delete z; + z = y.applyReduce3(nd4j::reduce3::Dot, x, {1}); + ASSERT_TRUE(z.equalsTo(&exp2)); x2.permutei({1,0,2}); - z = x2.applyReduce3(nd4j::reduce3::Dot, &k2, {0,1}); - ASSERT_TRUE(z->equalsTo(&exp4)); - delete z; + z = x2.applyReduce3(nd4j::reduce3::Dot, k2, {0,1}); + ASSERT_TRUE(z.equalsTo(&exp4)); } //////////////////////////////////////////////////////////////////////////// @@ -1241,26 +1230,22 @@ TEST_F(NDArrayCudaBasicsTests, applyReduce3_3) { NDArray exp2('c', {}, {31.5}, nd4j::DataType::DOUBLE); - auto z = x1.applyReduce3(reduce3::Dot, &x2); - ASSERT_TRUE(z->equalsTo(&exp1)); - delete z; + auto z = x1.applyReduce3(reduce3::Dot, x2); + ASSERT_TRUE(z.equalsTo(&exp1)); - z = x3.applyReduce3(reduce3::Dot, &x4); - ASSERT_TRUE(z->equalsTo(&exp2)); - delete z; + z = x3.applyReduce3(reduce3::Dot, x4); + ASSERT_TRUE(z.equalsTo(&exp2)); x1.permutei({2,1,0}); x2.permutei({2,1,0}); x3.permutei({1,0}); x4.permutei({1,0}); - z = x1.applyReduce3(reduce3::Dot, &x2); - ASSERT_TRUE(z->equalsTo(&exp1)); - delete z; + z = x1.applyReduce3(reduce3::Dot, x2); + ASSERT_TRUE(z.equalsTo(&exp1)); - z = x3.applyReduce3(reduce3::Dot, &x4); - ASSERT_TRUE(z->equalsTo(&exp2)); - delete z; + z = x3.applyReduce3(reduce3::Dot, x4); + ASSERT_TRUE(z.equalsTo(&exp2)); } //////////////////////////////////////////////////////////////////////////// @@ -1278,37 +1263,28 @@ TEST_F(NDArrayCudaBasicsTests, applyAllReduce3_1) { NDArray exp3('c', {1,1}, {31.5}, nd4j::DataType::DOUBLE); NDArray exp4('c', {3,3}, {4.5, 10.5, 16.5,4.5, 10.5, 16.5,4.5, 10.5, 16.5}, nd4j::DataType::DOUBLE); - auto z = x1.applyAllReduce3(reduce3::Dot, &x2, {0,2}); - ASSERT_TRUE(z->equalsTo(&exp1)); - delete z; + auto z = x1.applyAllReduce3(reduce3::Dot, x2, {0,2}); + ASSERT_TRUE(z.equalsTo(&exp1)); - z = x1.applyAllReduce3(reduce3::Dot, &x2, {0}); - ASSERT_TRUE(z->equalsTo(&exp2)); - delete z; + z = x1.applyAllReduce3(reduce3::Dot, x2, {0}); + ASSERT_TRUE(z.equalsTo(&exp2)); - z = x3.applyAllReduce3(reduce3::Dot, &x4, {0,1}); - ASSERT_TRUE(z->equalsTo(&exp3)); - delete z; + z = x3.applyAllReduce3(reduce3::Dot, x4, {0,1}); + ASSERT_TRUE(z.equalsTo(&exp3)); - z = x3.applyAllReduce3(reduce3::Dot, &x4, {1}); - // z->syncToHost(); - // z->printShapeInfo(); - // z->printIndexedBuffer(); - ASSERT_TRUE(z->equalsTo(&exp4)); - delete z; + z = x3.applyAllReduce3(reduce3::Dot, x4, {1}); + ASSERT_TRUE(z.equalsTo(&exp4)); x1.permutei({2,1,0}); x2.permutei({2,1,0}); x3.permutei({1,0}); x4.permutei({1,0}); - z = x1.applyAllReduce3(reduce3::Dot, &x2, {0,2}); - ASSERT_TRUE(z->equalsTo(&exp1)); - delete z; + z = x1.applyAllReduce3(reduce3::Dot, x2, {0,2}); + ASSERT_TRUE(z.equalsTo(&exp1)); - z = x3.applyAllReduce3(reduce3::Dot, &x4, {0}); - ASSERT_TRUE(z->equalsTo(&exp4)); - delete z; + z = x3.applyAllReduce3(reduce3::Dot, x4, {0}); + ASSERT_TRUE(z.equalsTo(&exp4)); } ////////////////////////////////////////////////////////////////////////////// @@ -1328,24 +1304,24 @@ TEST_F(NDArrayCudaBasicsTests, applyIndexReduce_test1) { NDArray exp5('c', {2}, {1,1}, nd4j::DataType::INT64); NDArray exp6('c', {3}, {1,0,0}, nd4j::DataType::INT64); - x.applyIndexReduce(nd4j::indexreduce::IndexMax, &scalar, {0,1}); + x.applyIndexReduce(nd4j::indexreduce::IndexMax, scalar, {0,1}); ASSERT_TRUE(scalar.equalsTo(&exp1)); - x.applyIndexReduce(nd4j::indexreduce::IndexMax, &vec1, {1}); + x.applyIndexReduce(nd4j::indexreduce::IndexMax, vec1, {1}); ASSERT_TRUE(vec1.equalsTo(&exp2)); - x.applyIndexReduce(nd4j::indexreduce::IndexMax, &vec2, {0}); + x.applyIndexReduce(nd4j::indexreduce::IndexMax, vec2, {0}); ASSERT_TRUE(vec2.equalsTo(&exp3)); x.permutei({1,0}); - x.applyIndexReduce(nd4j::indexreduce::IndexMax, &scalar, {0,1}); + x.applyIndexReduce(nd4j::indexreduce::IndexMax, scalar, {0,1}); ASSERT_TRUE(scalar.equalsTo(&exp4)); - x.applyIndexReduce(nd4j::indexreduce::IndexMax, &vec1, {0}); + x.applyIndexReduce(nd4j::indexreduce::IndexMax, vec1, {0}); ASSERT_TRUE(vec1.equalsTo(&exp5)); - x.applyIndexReduce(nd4j::indexreduce::IndexMax, &vec2, {1}); + x.applyIndexReduce(nd4j::indexreduce::IndexMax, vec2, {1}); ASSERT_TRUE(vec2.equalsTo(&exp6)); } @@ -1364,30 +1340,24 @@ TEST_F(NDArrayCudaBasicsTests, applyIndexReduce_test2) { NDArray exp6('c', {3}, {1,0,0}, nd4j::DataType::INT64); auto z = x.applyIndexReduce(nd4j::indexreduce::IndexMax, {0,1}); - ASSERT_TRUE(z->equalsTo(&exp1)); - delete z; + ASSERT_TRUE(z.equalsTo(&exp1)); z = x.applyIndexReduce(nd4j::indexreduce::IndexMax, {1}); - ASSERT_TRUE(z->equalsTo(&exp2)); - delete z; + ASSERT_TRUE(z.equalsTo(&exp2)); z = x.applyIndexReduce(nd4j::indexreduce::IndexMax, {0}); - ASSERT_TRUE(z->equalsTo(&exp3)); - delete z; + ASSERT_TRUE(z.equalsTo(&exp3)); x.permutei({1,0}); z = x.applyIndexReduce(nd4j::indexreduce::IndexMax, {0,1}); - ASSERT_TRUE(z->equalsTo(&exp4)); - delete z; + ASSERT_TRUE(z.equalsTo(&exp4)); z = x.applyIndexReduce(nd4j::indexreduce::IndexMax, {0}); - ASSERT_TRUE(z->equalsTo(&exp5)); - delete z; + ASSERT_TRUE(z.equalsTo(&exp5)); z = x.applyIndexReduce(nd4j::indexreduce::IndexMax, {1}); - ASSERT_TRUE(z->equalsTo(&exp6)); - delete z; + ASSERT_TRUE(z.equalsTo(&exp6)); } //////////////////////////////////////////////////////////////////////////////// @@ -1407,24 +1377,24 @@ TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_float_test1) { NDArray exp4('c', {3,2}, {4,5,1,1,1,1}, nd4j::DataType::FLOAT32); NDArray exp5('c', {2}, {3.5f,0.833333f}, nd4j::DataType::FLOAT32); - x.reduceAlongDimension(nd4j::reduce::Mean, &z1, {0,1,2}); + x.reduceAlongDimension(nd4j::reduce::Mean, z1, {0,1,2}); ASSERT_TRUE(z1.equalsTo(&exp1)); - x.reduceAlongDimension(nd4j::reduce::Mean, &z2, {1}); + x.reduceAlongDimension(nd4j::reduce::Mean, z2, {1}); ASSERT_TRUE(z2.equalsTo(&exp2)); - x.reduceAlongDimension(nd4j::reduce::Mean, &z3, {0,2}); + x.reduceAlongDimension(nd4j::reduce::Mean, z3, {0,2}); ASSERT_TRUE(z3.equalsTo(&exp3)); x.permutei({1,0,2}); // 3x2x2 - x.reduceAlongDimension(nd4j::reduce::Mean, &z1, {0,1,2}); + x.reduceAlongDimension(nd4j::reduce::Mean, z1, {0,1,2}); ASSERT_TRUE(z1.equalsTo(&exp1)); - x.reduceAlongDimension(nd4j::reduce::Mean, &z4, {1}); + x.reduceAlongDimension(nd4j::reduce::Mean, z4, {1}); ASSERT_TRUE(z4.equalsTo(&exp4)); - x.reduceAlongDimension(nd4j::reduce::Mean, &z5, {0,2}); + x.reduceAlongDimension(nd4j::reduce::Mean, z5, {0,2}); ASSERT_TRUE(z5.equalsTo(&exp5)); } @@ -1439,24 +1409,24 @@ TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_float_test2) { NDArray exp4('c', {3,2}, {4,5,1,1,1,1}, nd4j::DataType::DOUBLE); NDArray exp5('c', {2}, {3.5,0.833333}, nd4j::DataType::DOUBLE); - NDArray z1 = x.reduceAlongDims(nd4j::reduce::Mean, {0,1,2}); + NDArray z1 = x.reduceAlongDimension(nd4j::reduce::Mean, {0,1,2}); ASSERT_TRUE(z1.equalsTo(&exp1)); - NDArray z2 = x.reduceAlongDims(nd4j::reduce::Mean, {1}); + NDArray z2 = x.reduceAlongDimension(nd4j::reduce::Mean, {1}); ASSERT_TRUE(z2.equalsTo(&exp2)); - NDArray z3 = x.reduceAlongDims(nd4j::reduce::Mean, {0,2}); + NDArray z3 = x.reduceAlongDimension(nd4j::reduce::Mean, {0,2}); ASSERT_TRUE(z3.equalsTo(&exp3)); x.permutei({1,0,2}); // 3x2x2 - NDArray z4 = x.reduceAlongDims(nd4j::reduce::Mean, {0,1,2}); + NDArray z4 = x.reduceAlongDimension(nd4j::reduce::Mean, {0,1,2}); ASSERT_TRUE(z4.equalsTo(&exp1)); - NDArray z5 = x.reduceAlongDims(nd4j::reduce::Mean, {1}); + NDArray z5 = x.reduceAlongDimension(nd4j::reduce::Mean, {1}); ASSERT_TRUE(z5.equalsTo(&exp4)); - NDArray z6 = x.reduceAlongDims(nd4j::reduce::Mean, {0,2}); + NDArray z6 = x.reduceAlongDimension(nd4j::reduce::Mean, {0,2}); ASSERT_TRUE(z6.equalsTo(&exp5)); } @@ -1519,24 +1489,24 @@ TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_same_test1) { NDArray exp4('c', {3,2}, {9.f,10.f,2.f,2.f,1.5f,2.f}, nd4j::DataType::FLOAT32); NDArray exp5('c', {2}, {21.5f,5.f}, nd4j::DataType::FLOAT32); - x.reduceAlongDimension(nd4j::reduce::Sum, &z1, {0,1,2}); + x.reduceAlongDimension(nd4j::reduce::Sum, z1, {0,1,2}); ASSERT_TRUE(z1.equalsTo(&exp1)); - x.reduceAlongDimension(nd4j::reduce::Sum, &z2, {1}); + x.reduceAlongDimension(nd4j::reduce::Sum, z2, {1}); ASSERT_TRUE(z2.equalsTo(&exp2)); - x.reduceAlongDimension(nd4j::reduce::Sum, &z3, {0,2}); + x.reduceAlongDimension(nd4j::reduce::Sum, z3, {0,2}); ASSERT_TRUE(z3.equalsTo(&exp3)); x.permutei({1,0,2}); // 3x2x2 - x.reduceAlongDimension(nd4j::reduce::Sum, &z1, {0,1,2}); + x.reduceAlongDimension(nd4j::reduce::Sum, z1, {0,1,2}); ASSERT_TRUE(z1.equalsTo(&exp1)); - x.reduceAlongDimension(nd4j::reduce::Sum, &z4, {1}); + x.reduceAlongDimension(nd4j::reduce::Sum, z4, {1}); ASSERT_TRUE(z4.equalsTo(&exp4)); - x.reduceAlongDimension(nd4j::reduce::Sum, &z5, {0,2}); + x.reduceAlongDimension(nd4j::reduce::Sum, z5, {0,2}); ASSERT_TRUE(z5.equalsTo(&exp5)); } @@ -1551,24 +1521,24 @@ TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_same_test2) { NDArray exp4('c', {3,2}, {8,10,2,2,2,2}, nd4j::DataType::INT64); NDArray exp5('c', {2}, {21,5}, nd4j::DataType::INT64); - NDArray z1 = x.reduceAlongDims(nd4j::reduce::Sum, {0,1,2}); + NDArray z1 = x.reduceAlongDimension(nd4j::reduce::Sum, {0,1,2}); ASSERT_TRUE(z1.equalsTo(&exp1)); - NDArray z2 = x.reduceAlongDims(nd4j::reduce::Sum, {1}); + NDArray z2 = x.reduceAlongDimension(nd4j::reduce::Sum, {1}); ASSERT_TRUE(z2.equalsTo(&exp2)); - NDArray z3 = x.reduceAlongDims(nd4j::reduce::Sum, {0,2}); + NDArray z3 = x.reduceAlongDimension(nd4j::reduce::Sum, {0,2}); ASSERT_TRUE(z3.equalsTo(&exp3)); x.permutei({1,0,2}); // 3x2x2 - NDArray z4 = x.reduceAlongDims(nd4j::reduce::Sum, {0,1,2}); + NDArray z4 = x.reduceAlongDimension(nd4j::reduce::Sum, {0,1,2}); ASSERT_TRUE(z4.equalsTo(&exp1)); - NDArray z5 = x.reduceAlongDims(nd4j::reduce::Sum, {1}); + NDArray z5 = x.reduceAlongDimension(nd4j::reduce::Sum, {1}); ASSERT_TRUE(z5.equalsTo(&exp4)); - NDArray z6 = x.reduceAlongDims(nd4j::reduce::Sum, {0,2}); + NDArray z6 = x.reduceAlongDimension(nd4j::reduce::Sum, {0,2}); ASSERT_TRUE(z6.equalsTo(&exp5)); } @@ -1589,24 +1559,24 @@ TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_bool_test1) { NDArray exp4('c', {3,2}, {true,true,true,false,true,true}, nd4j::DataType::BOOL); NDArray exp5('c', {2}, {true,true}, nd4j::DataType::BOOL); - x.reduceAlongDimension(nd4j::reduce::IsPositive, &z1, {0,1,2}); + x.reduceAlongDimension(nd4j::reduce::IsPositive, z1, {0,1,2}); ASSERT_TRUE(z1.equalsTo(&exp1)); - x.reduceAlongDimension(nd4j::reduce::IsPositive, &z2, {1}); + x.reduceAlongDimension(nd4j::reduce::IsPositive, z2, {1}); ASSERT_TRUE(z2.equalsTo(&exp2)); - x.reduceAlongDimension(nd4j::reduce::IsPositive, &z3, {0,2}); + x.reduceAlongDimension(nd4j::reduce::IsPositive, z3, {0,2}); ASSERT_TRUE(z3.equalsTo(&exp3)); x.permutei({1,0,2}); // 3x2x2 - x.reduceAlongDimension(nd4j::reduce::IsPositive, &z1, {0,1,2}); + x.reduceAlongDimension(nd4j::reduce::IsPositive, z1, {0,1,2}); ASSERT_TRUE(z1.equalsTo(&exp1)); - x.reduceAlongDimension(nd4j::reduce::IsPositive, &z4, {1}); + x.reduceAlongDimension(nd4j::reduce::IsPositive, z4, {1}); ASSERT_TRUE(z4.equalsTo(&exp4)); - x.reduceAlongDimension(nd4j::reduce::IsPositive, &z5, {0,2}); + x.reduceAlongDimension(nd4j::reduce::IsPositive, z5, {0,2}); ASSERT_TRUE(z5.equalsTo(&exp5)); } @@ -1621,24 +1591,24 @@ TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_bool_test2) { NDArray exp4('c', {3,2}, {0,1,1,0,1,1}, nd4j::DataType::BOOL); NDArray exp5('c', {2}, {1,1}, nd4j::DataType::BOOL); - NDArray z1 = x.reduceAlongDims(nd4j::reduce::IsPositive, {0,1,2}); + NDArray z1 = x.reduceAlongDimension(nd4j::reduce::IsPositive, {0,1,2}); ASSERT_TRUE(z1.equalsTo(&exp1)); - NDArray z2 = x.reduceAlongDims(nd4j::reduce::IsPositive, {1}); + NDArray z2 = x.reduceAlongDimension(nd4j::reduce::IsPositive, {1}); ASSERT_TRUE(z2.equalsTo(&exp2)); - NDArray z3 = x.reduceAlongDims(nd4j::reduce::IsPositive, {0,2}); + NDArray z3 = x.reduceAlongDimension(nd4j::reduce::IsPositive, {0,2}); ASSERT_TRUE(z3.equalsTo(&exp3)); x.permutei({1,0,2}); // 3x2x2 - NDArray z4 = x.reduceAlongDims(nd4j::reduce::IsPositive, {0,1,2}); + NDArray z4 = x.reduceAlongDimension(nd4j::reduce::IsPositive, {0,1,2}); ASSERT_TRUE(z4.equalsTo(&exp1)); - NDArray z5 = x.reduceAlongDims(nd4j::reduce::IsPositive, {1}); + NDArray z5 = x.reduceAlongDimension(nd4j::reduce::IsPositive, {1}); ASSERT_TRUE(z5.equalsTo(&exp4)); - NDArray z6 = x.reduceAlongDims(nd4j::reduce::IsPositive, {0,2}); + NDArray z6 = x.reduceAlongDimension(nd4j::reduce::IsPositive, {0,2}); ASSERT_TRUE(z6.equalsTo(&exp5)); } @@ -1659,24 +1629,24 @@ TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_long_test1) { NDArray exp4('c', {3,2}, {0,1,0,1,0,0}, nd4j::DataType::INT64); NDArray exp5('c', {2}, {1,1}, nd4j::DataType::INT64); - x.reduceAlongDimension(nd4j::reduce::CountZero, &z1, {0,1,2}); + x.reduceAlongDimension(nd4j::reduce::CountZero, z1, {0,1,2}); ASSERT_TRUE(z1.equalsTo(&exp1)); - x.reduceAlongDimension(nd4j::reduce::CountZero, &z2, {1}); + x.reduceAlongDimension(nd4j::reduce::CountZero, z2, {1}); ASSERT_TRUE(z2.equalsTo(&exp2)); - x.reduceAlongDimension(nd4j::reduce::CountZero, &z3, {0,2}); + x.reduceAlongDimension(nd4j::reduce::CountZero, z3, {0,2}); ASSERT_TRUE(z3.equalsTo(&exp3)); x.permutei({1,0,2}); // 3x2x2 - x.reduceAlongDimension(nd4j::reduce::CountZero, &z1, {0,1,2}); + x.reduceAlongDimension(nd4j::reduce::CountZero, z1, {0,1,2}); ASSERT_TRUE(z1.equalsTo(&exp1)); - x.reduceAlongDimension(nd4j::reduce::CountZero, &z4, {1}); + x.reduceAlongDimension(nd4j::reduce::CountZero, z4, {1}); ASSERT_TRUE(z4.equalsTo(&exp4)); - x.reduceAlongDimension(nd4j::reduce::CountZero, &z5, {0,2}); + x.reduceAlongDimension(nd4j::reduce::CountZero, z5, {0,2}); ASSERT_TRUE(z5.equalsTo(&exp5)); } @@ -1691,24 +1661,24 @@ TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_long_test2) { NDArray exp4('c', {3,2}, {1,1,0,2,0,0}, nd4j::DataType::INT64); NDArray exp5('c', {2}, {2,2}, nd4j::DataType::INT64); - NDArray z1 = x.reduceAlongDims(nd4j::reduce::CountZero, {0,1,2}); + NDArray z1 = x.reduceAlongDimension(nd4j::reduce::CountZero, {0,1,2}); ASSERT_TRUE(z1.equalsTo(&exp1)); - NDArray z2 = x.reduceAlongDims(nd4j::reduce::CountZero, {1}); + NDArray z2 = x.reduceAlongDimension(nd4j::reduce::CountZero, {1}); ASSERT_TRUE(z2.equalsTo(&exp2)); - NDArray z3 = x.reduceAlongDims(nd4j::reduce::CountZero, {0,2}); + NDArray z3 = x.reduceAlongDimension(nd4j::reduce::CountZero, {0,2}); ASSERT_TRUE(z3.equalsTo(&exp3)); x.permutei({1,0,2}); // 3x2x2 - NDArray z4 = x.reduceAlongDims(nd4j::reduce::CountZero, {0,1,2}); + NDArray z4 = x.reduceAlongDimension(nd4j::reduce::CountZero, {0,1,2}); ASSERT_TRUE(z4.equalsTo(&exp1)); - NDArray z5 = x.reduceAlongDims(nd4j::reduce::CountZero, {1}); + NDArray z5 = x.reduceAlongDimension(nd4j::reduce::CountZero, {1}); ASSERT_TRUE(z5.equalsTo(&exp4)); - NDArray z6 = x.reduceAlongDims(nd4j::reduce::CountZero, {0,2}); + NDArray z6 = x.reduceAlongDimension(nd4j::reduce::CountZero, {0,2}); ASSERT_TRUE(z6.equalsTo(&exp5)); } @@ -1722,7 +1692,7 @@ TEST_F(NDArrayCudaBasicsTests, BroadcastOpsTest1) { ASSERT_TRUE(row->equalsTo(&expRow)); - x.applyBroadcast(broadcast::Add, {1}, row, &z, nullptr); + x.applyBroadcast(broadcast::Add, {1}, *row, z); x += *row; ASSERT_TRUE(x.equalsTo(z)); @@ -1740,7 +1710,7 @@ TEST_F(NDArrayCudaBasicsTests, BroadcastOpsTest2) { NDArray exp('c', {5,5}, {1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5}, nd4j::DataType::FLOAT32); ASSERT_TRUE(row->equalsTo(&expRow)); - x.applyBroadcast(broadcast::Add, {1}, row); + x.applyBroadcast(broadcast::Add, {1}, *row, x); ASSERT_TRUE(x.equalsTo(&exp)); } @@ -1753,7 +1723,7 @@ TEST_F(NDArrayCudaBasicsTests, TestBroadcast_1) { auto bias = NDArrayFactory::create('c', {1, 3}); bias.linspace(1); - input.applyBroadcast(broadcast::Add, {1}, &bias); + input.applyBroadcast(broadcast::Add, {1}, bias, input); ASSERT_TRUE(exp.equalsTo(&input)); } @@ -1807,7 +1777,7 @@ TEST_F(NDArrayCudaBasicsTests, Operator_Plus_Test_05) expected = 3.; res2 = 0.f; - x.applyTrueBroadcast(BroadcastOpsTuple::Add(), &y, &res2);// *= y; + x.applyTrueBroadcast(BroadcastOpsTuple::Add(), y, res2);// *= y; ASSERT_TRUE(expected.isSameShape(&res2)); ASSERT_TRUE(expected.equalsTo(&res2)); @@ -2095,20 +2065,20 @@ TEST_F(NDArrayCudaBasicsTests, Test_diagonal_1) { auto exp = NDArrayFactory::create('c', {2, 1}, {1, 5}); auto diag = x.diagonal('c'); - //diag->syncToDevice(); + //diag.syncToDevice(); for (Nd4jLong e = 0; e < exp.lengthOf(); ++e) { - printf("VAL[%ld] = %f\n", e, diag->e(e)); //, exp.e(e), 1.e-5); + printf("VAL[%ld] = %f\n", e, diag.e(e)); //, exp.e(e), 1.e-5); } for (Nd4jLong e = 0; e < exp.lengthOf(); ++e) { - ASSERT_NEAR(diag->e(e), exp.e(e), 1.e-5); + ASSERT_NEAR(diag.e(e), exp.e(e), 1.e-5); } double eps(1.e-5); NDArray tmp(nd4j::DataType::FLOAT32, x.getContext()); // scalar = 0 ExtraArguments extras({eps}); - NativeOpExecutioner::execReduce3Scalar(diag->getContext(), reduce3::EqualsWithEps, diag->getBuffer(), - diag->getShapeInfo(), diag->getSpecialBuffer(), diag->getSpecialShapeInfo(), extras.argumentsAsT(nd4j::DataType::FLOAT32), + NativeOpExecutioner::execReduce3Scalar(diag.getContext(), reduce3::EqualsWithEps, diag.getBuffer(), + diag.getShapeInfo(), diag.getSpecialBuffer(), diag.getSpecialShapeInfo(), extras.argumentsAsT(nd4j::DataType::FLOAT32), exp.getBuffer(), exp.getShapeInfo(), exp.getSpecialBuffer(), exp.getSpecialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo()); cudaStream_t* stream = x.getContext()->getCudaStream(); @@ -2116,8 +2086,6 @@ TEST_F(NDArrayCudaBasicsTests, Test_diagonal_1) { // tmp.printBuffer("Compare result is (expected 0)"); ASSERT_TRUE(exp.isSameShape(diag)); ASSERT_TRUE(exp.equalsTo(diag)); - - delete diag; } TEST_F(NDArrayCudaBasicsTests, Test_PermuteEquality_02) { diff --git a/libnd4j/tests_cpu/layers_tests/NDArrayListTests.cpp b/libnd4j/tests_cpu/layers_tests/NDArrayListTests.cpp index cc8549e81..e57c7e625 100644 --- a/libnd4j/tests_cpu/layers_tests/NDArrayListTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/NDArrayListTests.cpp @@ -36,7 +36,7 @@ TEST_F(NDArrayListTests, BasicTests_1) { auto x = NDArrayFactory::create('c', {1, 10}); auto y = NDArrayFactory::create('c', {1, 10}); - ASSERT_EQ(ND4J_STATUS_OK, list.write(1, x.dup())); + ASSERT_EQ(ND4J_STATUS_OK, list.write(1, new NDArray(x.dup()))); //ASSERT_EQ(ND4J_STATUS_DOUBLE_WRITE, list.write(1, &y)); } @@ -47,7 +47,7 @@ TEST_F(NDArrayListTests, BasicTests_2) { auto x = NDArrayFactory::create('c', {1, 10}); auto y = NDArrayFactory::create('c', {1, 7}); - ASSERT_EQ(ND4J_STATUS_OK, list.write(1, x.dup())); + ASSERT_EQ(ND4J_STATUS_OK, list.write(1, new NDArray(x.dup()))); ASSERT_EQ(ND4J_STATUS_BAD_INPUT, list.write(0, &y)); } @@ -63,7 +63,7 @@ TEST_F(NDArrayListTests, Test_Stack_UnStack_1) { ASSERT_EQ(10, list.elements()); - auto array = list.stack(); + auto array = list.stack(); ASSERT_TRUE(input.isSameShape(array)); diff --git a/libnd4j/tests_cpu/layers_tests/NDArrayTests.cpp b/libnd4j/tests_cpu/layers_tests/NDArrayTests.cpp index d0fb4bf37..fb55b4484 100644 --- a/libnd4j/tests_cpu/layers_tests/NDArrayTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/NDArrayTests.cpp @@ -52,8 +52,8 @@ TEST_F(NDArrayTest, TestDup1) { NDArray array(arr1, shape1); - auto arrC = array.dup('c'); - auto arrF = array.dup('f'); + auto arrC = new NDArray(array.dup('c')); + auto arrF = new NDArray(array.dup('f')); ASSERT_TRUE(array.equalsTo(arrF)); ASSERT_TRUE(array.equalsTo(arrC)); @@ -87,8 +87,8 @@ TEST_F(NDArrayTest, NDArrayOrder1) { auto f = new float[4] {1, 3, 2, 4}; auto arrayC = new NDArray(c, cShape); - auto arrayF = arrayC->dup('f'); - auto arrayC2 = arrayF->dup('c'); + auto arrayF = new NDArray(arrayC->dup('f')); + auto arrayC2 = new NDArray(arrayF->dup('c')); ASSERT_EQ('c', arrayC->ordering()); ASSERT_EQ('f', arrayF->ordering()); @@ -128,7 +128,7 @@ TEST_F(NDArrayTest, TestGetScalar1) { ASSERT_NEAR(3.0f, arrayC->e(1, 0), 1e-5f); ASSERT_NEAR(4.0f, arrayC->e(1, 1), 1e-5f); - auto arrayF = arrayC->dup('f'); + auto arrayF = new NDArray(arrayC->dup('f')); ASSERT_NEAR(3.0f, arrayF->e(1, 0), 1e-5f); ASSERT_NEAR(4.0f, arrayF->e(1, 1), 1e-5f); @@ -199,14 +199,12 @@ TEST_F(NDArrayTest, TestTad1) { auto row2 = array->tensorAlongDimension(1, {1}); - ASSERT_TRUE(row2->isView()); - ASSERT_EQ(3, row2->lengthOf()); + ASSERT_TRUE(row2.isView()); + ASSERT_EQ(3, row2.lengthOf()); - row2->assign(1.0); + row2.assign(1.0); ASSERT_NEAR(3.0f, array->sumNumber().e(0), 1e-5); - - delete row2; delete array; } @@ -225,18 +223,8 @@ TEST_F(NDArrayTest, TestTad3) { auto row2 = array->tensorAlongDimension(1, {1}); - ASSERT_TRUE(row2->isView()); - ASSERT_EQ(3, row2->lengthOf()); - - row2->p(1, 1.0); - - //array->printBuffer(); - - row2->p(2, 1.0); - - //array->printBuffer(); - - delete row2; + ASSERT_TRUE(row2.isView()); + ASSERT_EQ(3, row2.lengthOf()); delete array; } @@ -296,17 +284,14 @@ TEST_F(NDArrayTest, TestRepeat1) { auto rep = array.repeat(0, {2}); - ASSERT_EQ(4, rep->sizeAt(0)); - ASSERT_EQ(2, rep->sizeAt(1)); - - // rep->printIndexedBuffer("Repeated"); + ASSERT_EQ(4, rep.sizeAt(0)); + ASSERT_EQ(2, rep.sizeAt(1)); ASSERT_TRUE(exp->equalsTo(rep)); delete[] eBuffer; delete[] eShape; delete exp; - delete rep; } ////////////////////////////////////////////////////////////////////// @@ -320,7 +305,7 @@ TEST_F(NDArrayTest, TestRepeat2) { //array->printBuffer(); - auto rep = exp->dup(); + auto rep = new NDArray(exp->dup()); rep->assign(0.); array->repeat(0, {2}, *rep); //rep->printIndexedBuffer("Repeated"); @@ -374,7 +359,7 @@ TEST_F(NDArrayTest, TestAddiRowVector) { auto exp = new NDArray(e, cShape); row->assign(1.0f); - array->addiRowVector(row); + array->addiRowVector(*row); ASSERT_TRUE(exp->equalsTo(array)); @@ -397,8 +382,8 @@ TEST_F(NDArrayTest, TestAddiColumnVector) { NDArray column(arr2, shape2); NDArray exp(arr3, shape1); - matrix.addiColumnVector(&column); - ASSERT_TRUE(exp.isSameShapeStrict(&matrix)); + matrix.addiColumnVector(column); + ASSERT_TRUE(exp.isSameShapeStrict(matrix)); ASSERT_TRUE(exp.equalsTo(&matrix)); } @@ -414,9 +399,9 @@ TEST_F(NDArrayTest, TestMuliColumnVector) { NDArray column(arr2, shape2); NDArray exp(arr3, shape1); - matrix.muliColumnVector(&column); + matrix.muliColumnVector(column); - ASSERT_TRUE(exp.isSameShapeStrict(&matrix)); + ASSERT_TRUE(exp.isSameShapeStrict(matrix)); ASSERT_TRUE(exp.equalsTo(&matrix)); } @@ -478,7 +463,7 @@ TEST_F(NDArrayTest, TestSumAlongDimension1) { NDArray array('c', {2,2}, {1,2,3,4}, nd4j::DataType::FLOAT32); - auto res = array.reduceAlongDims(reduce::Sum, {0}); + auto res = array.reduceAlongDimension(reduce::Sum, {0}); ASSERT_EQ(2, res.lengthOf()); @@ -493,14 +478,13 @@ TEST_F(NDArrayTest, TestSumAlongDimension2) { auto res = array->reduceAlongDimension(reduce::Sum, {1}); - ASSERT_EQ(2, res->lengthOf()); + ASSERT_EQ(2, res.lengthOf()); - ASSERT_EQ(3.0f, res->e(0)); - ASSERT_EQ(7.0f, res->e(1)); + ASSERT_EQ(3.0f, res.e(0)); + ASSERT_EQ(7.0f, res.e(1)); delete[] c; delete array; - delete res; } ////////////////////////////////////////////////////////////////////// @@ -508,18 +492,15 @@ TEST_F(NDArrayTest, TestReduceAlongDimension1) { float *c = new float[4] {1, 2, 3, 4}; auto array = new NDArray(c, cShape); - auto exp = array->reduceAlongDimension(reduce::Sum, {1}); auto res = array->reduceAlongDimension(reduce::Sum, {1}); - ASSERT_EQ(2, res->lengthOf()); + ASSERT_EQ(2, res.lengthOf()); - ASSERT_EQ(3.0f, res->e(0)); - ASSERT_EQ(7.0f, res->e(1)); + ASSERT_EQ(3.0f, res.e(0)); + ASSERT_EQ(7.0f, res.e(1)); delete[] c; delete array; - delete exp; - delete res; } ////////////////////////////////////////////////////////////////////// @@ -530,7 +511,7 @@ TEST_F(NDArrayTest, TestTransform1) { float *e = new float[4] {1, 2, 3, 4}; auto exp = new NDArray(e, cShape); - array->applyTransform(transform::Abs, nullptr, nullptr); + array->applyTransform(transform::Abs, *array); ASSERT_TRUE(exp->equalsTo(array)); @@ -579,7 +560,7 @@ TEST_F(NDArrayTest, TestApplyTransform1) { float *e = new float[4] {1, 2, 3, 4}; auto exp = new NDArray(e, cShape); - array->applyTransform(transform::Abs, nullptr, nullptr); + array->applyTransform(transform::Abs, *array); ASSERT_TRUE(exp->equalsTo(array)); @@ -668,20 +649,17 @@ TEST_F(NDArrayTest, TestReductionAny1) { array.syncToDevice(); auto result0 = array.reduceAlongDimension(reduce::Any, {0}); - ASSERT_EQ(2, result0->lengthOf()); + ASSERT_EQ(2, result0.lengthOf()); - ASSERT_NEAR(1.0f, result0->e(0), 1e-5f); - ASSERT_NEAR(1.0f, result0->e(1), 1e-5f); + ASSERT_NEAR(1.0f, result0.e(0), 1e-5f); + ASSERT_NEAR(1.0f, result0.e(1), 1e-5f); auto result1 = array.reduceAlongDimension(reduce::Any, {1}); - ASSERT_EQ(2, result1->lengthOf()); + ASSERT_EQ(2, result1.lengthOf()); - ASSERT_NEAR(1.0f, result1->e(0), 1e-5f); - ASSERT_NEAR(0.0f, result1->e(1), 1e-5f); - - delete result0; - delete result1; + ASSERT_NEAR(1.0f, result1.e(0), 1e-5f); + ASSERT_NEAR(0.0f, result1.e(1), 1e-5f); } TEST_F(NDArrayTest, TestReductionAll1) { @@ -694,17 +672,14 @@ TEST_F(NDArrayTest, TestReductionAll1) { auto result0 = array.reduceAlongDimension(reduce::All, {0}); auto result1 = array.reduceAlongDimension(reduce::All, {1}); - ASSERT_EQ(2, result0->lengthOf()); - ASSERT_EQ(2, result1->lengthOf()); + ASSERT_EQ(2, result0.lengthOf()); + ASSERT_EQ(2, result1.lengthOf()); - ASSERT_FALSE(result0->e(0)); - ASSERT_FALSE(result0->e(1)); + ASSERT_FALSE(result0.e(0)); + ASSERT_FALSE(result0.e(1)); - ASSERT_TRUE(result1->e(0)); - ASSERT_FALSE(result1->e(1)); - - delete result0; - delete result1; + ASSERT_TRUE(result1.e(0)); + ASSERT_FALSE(result1.e(1)); } ////////////////////////////////////////////////////////////////////// @@ -728,7 +703,7 @@ TEST_F(NDArrayTest, TestTile1) { NDArray array1(arr1,shape1); // {2,3} NDArray array2(arr2,shape2); // {2,4,6} - auto expA = array1.dup('c'); + auto expA = new NDArray(array1.dup('c')); auto tiled = array1.tile(tileShape1); @@ -766,7 +741,7 @@ TEST_F(NDArrayTest, TestTile3) { array1.tilei(tileShape1); - ASSERT_TRUE(array1.isSameShapeStrict(&array2)); + ASSERT_TRUE(array1.isSameShapeStrict(array2)); ASSERT_TRUE(array1.equalsTo(&array2)); } @@ -781,7 +756,7 @@ TEST_F(NDArrayTest, TestTile4) { auto result = x.tile({2,1}); - ASSERT_TRUE(result.isSameShapeStrict(&exp)); + ASSERT_TRUE(result.isSameShapeStrict(exp)); ASSERT_TRUE(result.equalsTo(&exp)); } @@ -796,7 +771,7 @@ TEST_F(NDArrayTest, TestTile5) { auto result = x.tile({2,1}); - ASSERT_TRUE(result.isSameShapeStrict(&exp)); + ASSERT_TRUE(result.isSameShapeStrict(exp)); ASSERT_TRUE(result.equalsTo(&exp)); } @@ -881,8 +856,8 @@ TEST_F(NDArrayTest, TestPermuteReshapeMmul2) { for (int e = 0; e < y.lengthOf(); e++) y.p(e, e+1); - auto x_ = x.dup('f'); - auto y_ = y.dup('f'); + auto x_ = new NDArray(x.dup('f')); + auto y_ = new NDArray(y.dup('f')); x_->permutei({1, 0}); y_->permutei({1, 0}); @@ -940,7 +915,7 @@ TEST_F(NDArrayTest, TestPermuteReshapeMmul4) { for (int e = 0; e < y.lengthOf(); e++) y.p(e, e+1); - auto y_ = y.dup('f'); + auto y_ = new NDArray(y.dup('f')); x.permutei({0, 3, 4, 5, 1, 2}); y_->permutei({3, 2, 1, 0}); @@ -1264,7 +1239,7 @@ TEST_F(NDArrayTest, Permute1) { NDArray arr2(shape2,true); auto result = arr1.permute(perm); - ASSERT_TRUE(result.isSameShapeStrict(&arr2)); + ASSERT_TRUE(result.isSameShapeStrict(arr2)); } ////////////////////////////////////////////////////////////////////// @@ -1279,33 +1254,16 @@ TEST_F(NDArrayTest, Permute2) { NDArray arr2(shape2,true); ASSERT_TRUE(arr1.permutei(perm)); - ASSERT_TRUE(arr1.isSameShapeStrict(&arr2)); + ASSERT_TRUE(arr1.isSameShapeStrict(arr2)); } -////////////////////////////////////////////////////////////////////// -TEST_F(NDArrayTest, Broadcast1) { - - Nd4jLong shape1[10] = {3, 5, 1, 10, 10, 10, 1, 8192, 1, 99}; - Nd4jLong shape2[8] = {2, 7, 10, 10, 1, 8192, 1, 99}; - Nd4jLong shape3[10] = {3, 5, 7, 10, 70, 10, 1, 8192, 1, 99}; - - NDArray arr1(shape1); - NDArray arr2(shape2); - NDArray arr3(shape3); - - auto result = arr1.broadcast(arr2); - ASSERT_TRUE(result->isSameShapeStrict(&arr3)); - delete result; -} - - TEST_F(NDArrayTest, RSubScalarTest1) { auto array = NDArrayFactory::create('c', {1, 4}); array.assign(2.0); auto result = NDArrayFactory::create('c', {1, 4}); - array.applyScalar(scalar::ReverseSubtract, 1.0, &result); + array.applyScalar(scalar::ReverseSubtract, 1.0, result); ASSERT_NEAR(-1.0, result.meanNumber().e(0), 1e-5); } @@ -1324,7 +1282,7 @@ TEST_F(NDArrayTest, BroadcastOpsTest1) { ASSERT_TRUE(row->equalsTo(&expRow)); - x.applyBroadcast(broadcast::Add, {1}, row); + x.applyBroadcast(broadcast::Add, {1}, *row, x); //x.printBuffer("Result"); @@ -1374,9 +1332,9 @@ TEST_F(NDArrayTest, TestIndexedPut5) { TEST_F(NDArrayTest, TestAllTensors1) { auto matrix = NDArrayFactory::create('c', {3, 5}); - std::unique_ptr rows(matrix.allTensorsAlongDimension({1})); + ResultSet rows = matrix.allTensorsAlongDimension({1}); - ASSERT_EQ(3, rows->size()); + ASSERT_EQ(3, rows.size()); } @@ -1573,17 +1531,15 @@ TEST_F(NDArrayTest, TestStdDev2) { auto array = NDArrayFactory::create('c', {5, 6}); auto tad = array.tensorAlongDimension(0, {0}); - ASSERT_EQ(5, tad->lengthOf()); + ASSERT_EQ(5, tad.lengthOf()); - for (int e = 0; e < tad->lengthOf(); e++) - tad->p(e, e+1); + for (int e = 0; e < tad.lengthOf(); e++) + tad.p(e, e+1); - ASSERT_NEAR(15, tad->sumNumber().e(0), 1e-5); + ASSERT_NEAR(15, tad.sumNumber().e(0), 1e-5); - auto std = tad->varianceNumber(variance::SummaryStatsStandardDeviation, true).e(0); + auto std = tad.varianceNumber(variance::SummaryStatsStandardDeviation, true).e(0); ASSERT_NEAR(std, 1.58109, 1e-4); - - delete tad; } TEST_F(NDArrayTest, TestStdDev3) { @@ -1654,8 +1610,6 @@ TEST_F(NDArrayTest, TestApplyIndexReduce1) { auto result = x.applyIndexReduce(indexreduce::IndexMax, dim); ASSERT_TRUE(exp.isSameShapeStrict(result)); ASSERT_TRUE(exp.equalsTo(result)); - - delete result; } ////////////////////////////////////////////////////////////////////// @@ -1667,11 +1621,9 @@ TEST_F(NDArrayTest, applyReduce3Dot) { NDArray x(xBuff, xShapeInfo); NDArray y(yBuff, xShapeInfo); - auto result = x.applyReduce3(reduce3::Dot, &y); - ASSERT_TRUE(result->lengthOf() == 1); - ASSERT_NEAR(42, result->e(0), 1e-5); - - delete result; + auto result = x.applyReduce3(reduce3::Dot, y); + ASSERT_TRUE(result.lengthOf() == 1); + ASSERT_NEAR(42, result.e(0), 1e-5); } ////////////////////////////////////////////////////////////////////// @@ -1686,17 +1638,12 @@ TEST_F(NDArrayTest, applyAllReduce3EuclideanDistance) { NDArray y(yBuff, xShapeInfo); auto exp = NDArrayFactory::create('c', {2, 2}, {1.414214f, 1.414214f, 5.385165f, 5.385165f}); - auto result = x.applyAllReduce3(reduce3::EuclideanDistance, &y,{1}); - - // result->printIndexedBuffer("result"); + auto result = x.applyAllReduce3(reduce3::EuclideanDistance, y, {1}); ASSERT_TRUE(exp.isSameShapeStrict(result)); ASSERT_TRUE(exp.equalsTo(result)); - - delete result; } - ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, applyReduce3EuclideanDistance) { float xBuff[] = {1, 2, 3, 4, 5, 6}; @@ -1709,12 +1656,10 @@ TEST_F(NDArrayTest, applyReduce3EuclideanDistance) { NDArray y(yBuff, xShapeInfo); NDArray exp(expBuff, expShapeInfo); - auto result = x.applyAllReduce3(reduce3::EuclideanDistance, &y,{1}); + auto result = x.applyAllReduce3(reduce3::EuclideanDistance, y ,{1}); ASSERT_TRUE(exp.isSameShapeStrict(result)); ASSERT_TRUE(exp.equalsTo(result)); - - delete result; } @@ -1733,8 +1678,6 @@ TEST_F(NDArrayTest, TestVarianceAlongDimension1) { ASSERT_TRUE(exp.isSameShapeStrict(result)); ASSERT_TRUE(exp.equalsTo(result)); - - delete result; } ////////////////////////////////////////////////////////////////////// @@ -1751,8 +1694,6 @@ TEST_F(NDArrayTest, TestVarianceAlongDimension2) { auto result = x.varianceAlongDimension(variance::SummaryStatsVariance, false, {1}); ASSERT_TRUE(exp.isSameShapeStrict(result)); ASSERT_TRUE(exp.equalsTo(result)); - - delete result; } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestVarianceAlongDimension3) { @@ -1765,9 +1706,8 @@ TEST_F(NDArrayTest, TestVarianceAlongDimension3) { auto result = x.varianceAlongDimension(variance::SummaryStatsVariance, false, {0}); ASSERT_TRUE(exp.isSameShapeStrict(result)); ASSERT_TRUE(exp.equalsTo(result)); - - delete result; } + ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestVarianceAlongDimension4) { @@ -1779,8 +1719,6 @@ TEST_F(NDArrayTest, TestVarianceAlongDimension4) { auto result = x.varianceAlongDimension(variance::SummaryStatsVariance, false, {0}); ASSERT_TRUE(exp.isSameShapeStrict(result)); ASSERT_TRUE(exp.equalsTo(result)); - - delete result; } ////////////////////////////////////////////////////////////////////// @@ -1796,9 +1734,9 @@ TEST_F(NDArrayTest, TestSubRowVector1) { NDArray target(x); NDArray exp(expBuff, xShapeInfo); - x.subRowVector(&y,&target); + x.subRowVector(y, target); - ASSERT_TRUE(exp.isSameShapeStrict(&target)); + ASSERT_TRUE(exp.isSameShapeStrict(target)); ASSERT_TRUE(exp.equalsTo(&target)); } @@ -1815,9 +1753,9 @@ TEST_F(NDArrayTest, TestDivRowVector1) { NDArray target(x); NDArray exp(expBuff, xShapeInfo); - x.divRowVector(&y,&target); + x.divRowVector(y, target); - ASSERT_TRUE(exp.isSameShapeStrict(&target)); + ASSERT_TRUE(exp.isSameShapeStrict(target)); ASSERT_TRUE(exp.equalsTo(&target)); } @@ -1834,9 +1772,9 @@ TEST_F(NDArrayTest, TestMulRowVector1) { NDArray target(x); NDArray exp(expBuff, xShapeInfo); - x.mulRowVector(&y,&target); + x.mulRowVector(y, target); - ASSERT_TRUE(exp.isSameShapeStrict(&target)); + ASSERT_TRUE(exp.isSameShapeStrict(target)); ASSERT_TRUE(exp.equalsTo(&target)); } @@ -1895,7 +1833,7 @@ TEST_F(NDArrayTest, TestBroadcast_1) { bias.linspace(1); - input.applyBroadcast(broadcast::Add, {1}, &bias); + input.applyBroadcast(broadcast::Add, {1}, bias, input); //input.printBuffer("result"); ASSERT_TRUE(exp.equalsTo(&input)); @@ -2457,7 +2395,7 @@ TEST_F(NDArrayTest, Test_Lambda_1) { return _val + 3.0f; }; - x.applyLambda(lambda); + x.applyLambda(lambda, x); ASSERT_TRUE(exp.equalsTo(&x)); } @@ -2472,7 +2410,7 @@ TEST_F(NDArrayTest, Test_Lambda_2) { return _x + _y + 1.0f; }; - x.applyPairwiseLambda(&y, lambda); + x.applyPairwiseLambda(y, lambda, x); ASSERT_TRUE(exp.equalsTo(&x)); } @@ -2487,7 +2425,7 @@ TEST_F(NDArrayTest, Test_Lambda_3) { return (_x + _y) * 2; }; - x.applyPairwiseLambda(&y, lambda); + x.applyPairwiseLambda(y, lambda, x); ASSERT_TRUE(exp.equalsTo(&x)); } @@ -2518,8 +2456,6 @@ TEST_F(NDArrayTest, Test_diagonal_1) { ASSERT_TRUE(exp.isSameShape(diag)); ASSERT_TRUE(exp.equalsTo(diag)); - - delete diag; } ////////////////////////////////////////////////////////////////////// @@ -2533,8 +2469,6 @@ TEST_F(NDArrayTest, Test_diagonal_2) { ASSERT_TRUE(exp.isSameShape(diag)); ASSERT_TRUE(exp.equalsTo(diag)); - - delete diag; } ////////////////////////////////////////////////////////////////////// @@ -2548,8 +2482,6 @@ TEST_F(NDArrayTest, Test_diagonal_3) { ASSERT_TRUE(exp.isSameShape(diag)); ASSERT_TRUE(exp.equalsTo(diag)); - - delete diag; } ////////////////////////////////////////////////////////////////////// @@ -2563,8 +2495,6 @@ TEST_F(NDArrayTest, Test_diagonal_4) { ASSERT_TRUE(exp.isSameShape(diag)); ASSERT_TRUE(exp.equalsTo(diag)); - - delete diag; } ////////////////////////////////////////////////////////////////////// @@ -2578,8 +2508,6 @@ TEST_F(NDArrayTest, Test_diagonal_5) { ASSERT_TRUE(exp.isSameShape(diag)); ASSERT_TRUE(exp.equalsTo(diag)); - - delete diag; } ////////////////////////////////////////////////////////////////////// @@ -2593,8 +2521,6 @@ TEST_F(NDArrayTest, Test_diagonal_6) { ASSERT_TRUE(exp.isSameShape(diag)); ASSERT_TRUE(exp.equalsTo(diag)); - - delete diag; } ////////////////////////////////////////////////////////////////////// @@ -2608,8 +2534,6 @@ TEST_F(NDArrayTest, Test_diagonal_7) { ASSERT_TRUE(exp.isSameShape(diag)); ASSERT_TRUE(exp.equalsTo(diag)); - - delete diag; } ////////////////////////////////////////////////////////////////////// @@ -2623,8 +2547,6 @@ TEST_F(NDArrayTest, Test_diagonal_8) { ASSERT_TRUE(exp.isSameShape(diag)); ASSERT_TRUE(exp.equalsTo(diag)); - - delete diag; } ////////////////////////////////////////////////////////////////////// @@ -2638,8 +2560,6 @@ TEST_F(NDArrayTest, Test_diagonal_9) { ASSERT_TRUE(exp.isSameShape(diag)); ASSERT_TRUE(exp.equalsTo(diag)); - - delete diag; } @@ -2654,8 +2574,6 @@ TEST_F(NDArrayTest, Test_diagonal_10) { ASSERT_TRUE(exp.isSameShape(diag)); ASSERT_TRUE(exp.equalsTo(diag)); - - delete diag; } ////////////////////////////////////////////////////////////////////// @@ -2669,8 +2587,6 @@ TEST_F(NDArrayTest, Test_diagonal_11) { ASSERT_TRUE(exp.isSameShape(diag)); ASSERT_TRUE(exp.equalsTo(diag)); - - delete diag; } ////////////////////////////////////////////////////////////////////// @@ -2684,8 +2600,6 @@ TEST_F(NDArrayTest, Test_diagonal_12) { ASSERT_TRUE(exp.isSameShape(diag)); ASSERT_TRUE(exp.equalsTo(diag)); - - delete diag; } //////////////////////////////////////////////////////////////////// @@ -2699,8 +2613,6 @@ TEST_F(NDArrayTest, Test_diagonal_13) { ASSERT_TRUE(exp.isSameShape(diag)); ASSERT_TRUE(exp.equalsTo(diag)); - - delete diag; } //////////////////////////////////////////////////////////////////// @@ -2714,8 +2626,6 @@ TEST_F(NDArrayTest, Test_diagonal_14) { ASSERT_TRUE(exp.isSameShape(diag)); ASSERT_TRUE(exp.equalsTo(diag)); - - delete diag; } ////////////////////////////////////////////////////////////////////// @@ -2729,8 +2639,6 @@ TEST_F(NDArrayTest, Test_diagonal_15) { ASSERT_TRUE(exp.isSameShape(diag)); ASSERT_TRUE(exp.equalsTo(diag)); - - delete diag; } ////////////////////////////////////////////////////////////////////// @@ -2744,8 +2652,6 @@ TEST_F(NDArrayTest, Test_diagonal_16) { ASSERT_TRUE(exp.isSameShape(diag)); ASSERT_TRUE(exp.equalsTo(diag)); - - delete diag; } ////////////////////////////////////////////////////////////////////// @@ -2759,8 +2665,6 @@ TEST_F(NDArrayTest, Test_diagonal_17) { ASSERT_TRUE(exp.isSameShape(diag)); ASSERT_TRUE(exp.equalsTo(diag)); - - delete diag; } ////////////////////////////////////////////////////////////////////// @@ -2774,8 +2678,6 @@ TEST_F(NDArrayTest, Test_diagonal_18) { ASSERT_TRUE(exp.isSameShape(diag)); ASSERT_TRUE(exp.equalsTo(diag)); - - delete diag; } ////////////////////////////////////////////////////////////////////// diff --git a/libnd4j/tests_cpu/layers_tests/NDArrayTests2.cpp b/libnd4j/tests_cpu/layers_tests/NDArrayTests2.cpp index 4f8d38e76..4507086f5 100644 --- a/libnd4j/tests_cpu/layers_tests/NDArrayTests2.cpp +++ b/libnd4j/tests_cpu/layers_tests/NDArrayTests2.cpp @@ -196,12 +196,10 @@ TEST_F(NDArrayTest2, Test_AllReduce3_1) { auto y = NDArrayFactory::create('c', {2, 3}, {2, 3, 4, 2, 3, 4}); auto exp = NDArrayFactory::create('c', {2, 2}, {1.73205, 1.73205, 1.73205, 1.73205}); - auto z = x.applyAllReduce3(reduce3::EuclideanDistance, &y, {1}, nullptr); + auto z = x.applyAllReduce3(reduce3::EuclideanDistance, y, {1}); ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - - delete z; } //////////////////////////////////////////////////////////////////// @@ -210,12 +208,10 @@ TEST_F(NDArrayTest2, Test_AllReduce3_2) { auto y = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 2, 3, 4}); auto exp = NDArrayFactory::create('c', {2, 2}, {0., 1.73205, 1.73205, 0.}); - auto z = x.applyAllReduce3(reduce3::EuclideanDistance, &y, {1}, nullptr); + auto z = x.applyAllReduce3(reduce3::EuclideanDistance, y, {1}); ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - - delete z; } //////////////////////////////////////////////////////////////////// @@ -278,7 +274,7 @@ TEST_F(NDArrayTest2, Test_Streamline_1) { ASSERT_TRUE(x.isSameShape(&y)); ASSERT_TRUE(x.equalsTo(&y)); - ASSERT_FALSE(x.isSameShapeStrict(&y)); + ASSERT_FALSE(x.isSameShapeStrict(y)); } @@ -306,7 +302,7 @@ TEST_F(NDArrayTest2, Test_Enforce_1) { x.enforce({4, 4}, 'c'); - ASSERT_TRUE(exp.isSameShapeStrict(&x)); + ASSERT_TRUE(exp.isSameShapeStrict(x)); ASSERT_TRUE(exp.equalsTo(&x)); } @@ -315,7 +311,7 @@ TEST_F(NDArrayTest2, TestVector_1) { auto row = NDArrayFactory::create('c', {3}, {1, 2, 3}); auto exp = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 1, 2, 3}); - x.addiRowVector(&row); + x.addiRowVector(row); ASSERT_TRUE(exp.equalsTo(&x)); } @@ -359,7 +355,7 @@ TEST_F(NDArrayTest2, tileToShape_test1) { auto x = NDArrayFactory::create('c', {2, 2}, {1,2,3,4}); auto exp = NDArrayFactory::create('c', {2, 2, 2}, {1,2,3,4,1,2,3,4}); - x.tileToShape({2,2,2}); + x.tileToShape({2,2,2}, x); ASSERT_TRUE(x.isSameShape(&exp)); ASSERT_TRUE(x.equalsTo(&exp)); @@ -371,7 +367,7 @@ TEST_F(NDArrayTest2, tileToShape_test2) { auto x = NDArrayFactory::create('c', {2, 1, 2}, {1,2,3,4}); auto exp = NDArrayFactory::create('c', {2, 3, 2}, {1,2,1,2,1,2,3,4,3,4,3,4}); - x.tileToShape({2,3,2}); + x.tileToShape({2,3,2}, x); ASSERT_TRUE(x.isSameShape(&exp)); ASSERT_TRUE(x.equalsTo(&exp)); @@ -384,7 +380,7 @@ TEST_F(NDArrayTest2, tileToShape_test3) { auto result = NDArrayFactory::create('c', {2, 2, 2}); auto exp = NDArrayFactory::create('c', {2, 2, 2}, {1,2,3,4,1,2,3,4}); - x.tileToShape({2,2,2}, &result); + x.tileToShape({2,2,2}, result); // result.printIndexedBuffer(); ASSERT_TRUE(result.isSameShape(&exp)); @@ -398,7 +394,7 @@ TEST_F(NDArrayTest2, tileToShape_test4) { auto result = NDArrayFactory::create('c', {2, 3, 2}); auto exp = NDArrayFactory::create('c', {2, 3, 2}, {1,2,1,2,1,2,3,4,3,4,3,4}); - x.tileToShape({2,3,2}, &result); + x.tileToShape({2,3,2}, result); ASSERT_TRUE(result.isSameShape(&exp)); ASSERT_TRUE(result.equalsTo(&exp)); @@ -418,7 +414,7 @@ TEST_F(NDArrayTest2, Test_TriplewiseLambda_1) { return _t + _u + _v + extra; }; - t.applyTriplewiseLambda(&u, &v, la); + t.applyTriplewiseLambda(u, v, la, t); ASSERT_TRUE(t.equalsTo(&exp)); } @@ -436,7 +432,7 @@ TEST_F(NDArrayTest2, Test_TriplewiseLambda_2) { return _t + _u + _v + extra; }; - t.applyTriplewiseLambda(&u, &v, la); + t.applyTriplewiseLambda(u, v, la, t); ASSERT_TRUE(t.equalsTo(&exp)); } @@ -450,7 +446,7 @@ TEST_F(NDArrayTest2, Test_Indexed_Lambda) { return (float) _idx; }; - x.applyIndexedLambda(lambda); + x.applyIndexedLambda(lambda, x); ASSERT_TRUE(exp.equalsTo(&x)); } @@ -565,7 +561,7 @@ TEST_F(NDArrayTest2, fillAsTriangular_test1) { auto x = NDArrayFactory::create('c', {4, 4}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16}); auto exp = NDArrayFactory::create('c', {4, 4}, {1,0,0,0,5,6,0,0,9,10,11,0 ,13,14,15,16}); - x.fillAsTriangular(0., 0, 0, 'u'); + x.fillAsTriangular(0., 0, 0, x, 'u'); ASSERT_TRUE(exp.isSameShape(&x)); ASSERT_TRUE(exp.equalsTo(&x)); @@ -578,7 +574,7 @@ TEST_F(NDArrayTest2, fillAsTriangular_test2) { auto x = NDArrayFactory::create('c', {4, 4}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16}); auto exp = NDArrayFactory::create('c', {4, 4}, {0,0,0,0,5,0,0,0,9,10,0 ,0 ,13,14,15,0}); - x.fillAsTriangular(0., 0, -1, 'u'); + x.fillAsTriangular(0., 0, -1, x, 'u'); ASSERT_TRUE(exp.isSameShape(&x)); ASSERT_TRUE(exp.equalsTo(&x)); @@ -591,7 +587,7 @@ TEST_F(NDArrayTest2, fillAsTriangular_test3) { auto x = NDArrayFactory::create('c', {4, 4}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16}); auto exp = NDArrayFactory::create('c', {4, 4}, {1,2,3,4,0,6,7,8,0,0 ,11,12,0 ,0 , 0,16}); - x.fillAsTriangular(0., 0, 0, 'l'); + x.fillAsTriangular(0., 0, 0, x, 'l'); ASSERT_TRUE(exp.isSameShape(&x)); ASSERT_TRUE(exp.equalsTo(&x)); @@ -604,7 +600,7 @@ TEST_F(NDArrayTest2, fillAsTriangular_test4) { auto x = NDArrayFactory::create('c', {4, 4}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16}); auto exp = NDArrayFactory::create('c', {4, 4}, {0,2,3,4,0,0,7,8,0,0 , 0,12, 0, 0, 0, 0}); - x.fillAsTriangular(0., 1, 0, 'l'); + x.fillAsTriangular(0., 1, 0, x, 'l'); ASSERT_TRUE(exp.isSameShape(&x)); ASSERT_TRUE(exp.equalsTo(&x)); @@ -616,13 +612,10 @@ TEST_F(NDArrayTest2, Test_DType_Conversion_1) { auto xd = x.template asT(); - auto xf = xd->template asT(); + auto xf = xd.template asT(); ASSERT_TRUE(x.isSameShape(xf)); ASSERT_TRUE(x.equalsTo(xf)); - - delete xf; - delete xd; } //////////////////////////////////////////////////////////////////// @@ -677,7 +670,7 @@ TEST_F(NDArrayTest2, permute_test4) { // arr1P->printShapeInfo(); // ASSERT_TRUE(arr1.isSameShapeStrict(&arr2)); - ASSERT_TRUE(arr1P.isSameShapeStrict(&arr2)); + ASSERT_TRUE(arr1P.isSameShapeStrict(arr2)); delete []arr1Buffer; delete []arr2Buffer; } @@ -773,11 +766,9 @@ TEST_F(NDArrayTest2, allTensorsAlongDimension_test1) { // set->at(0)->printShapeInfo(); // set->at(0)->printIndexedBuffer(); - ASSERT_TRUE(set->size() == 1); - ASSERT_TRUE(exp.isSameShape(set->at(0))); - ASSERT_TRUE(exp.equalsTo(set->at(0))); - - delete set; + ASSERT_TRUE(set.size() == 1); + ASSERT_TRUE(exp.isSameShape(set.at(0))); + ASSERT_TRUE(exp.equalsTo(set.at(0))); } //////////////////////////////////////////////////////////////////// @@ -838,7 +829,7 @@ TEST_F(NDArrayTest2, scalar_set_test2) { TEST_F(NDArrayTest2, big_dup_test) { // auto arr = NDArrayFactory::linspace(1.0f, 10000000.0f, 100000000); auto arr = NDArrayFactory::linspace(1.0f, 1000.0f, 10000); - auto dup = arr->dup('c'); + auto dup = new NDArray(arr->dup('c')); ASSERT_EQ(*arr, *dup); @@ -920,8 +911,7 @@ TEST_F(NDArrayTest2, test_subarray_ews_1) { NDArray x('c', {10, 5}, nd4j::DataType::FLOAT32); auto subArr1 = x.subarray({NDIndex::all(), NDIndex::point(2)}); - ASSERT_EQ(5, subArr1->ews()); - delete subArr1; + ASSERT_EQ(5, subArr1.ews()); } ////////////////////////////////////////////////////////////////////// @@ -930,8 +920,7 @@ TEST_F(NDArrayTest2, test_subarray_ews_2) { NDArray x('f', {10, 5}, nd4j::DataType::FLOAT32); auto subArr1 = x.subarray({NDIndex::all(), NDIndex::point(2)}); - ASSERT_EQ(1, subArr1->ews()); - delete subArr1; + ASSERT_EQ(1, subArr1.ews()); } ////////////////////////////////////////////////////////////////////// @@ -940,8 +929,7 @@ TEST_F(NDArrayTest2, test_subarray_ews_3) { NDArray x('c', {10, 5}, nd4j::DataType::FLOAT32); auto subArr1 = x.subarray({NDIndex::point(2), NDIndex::all()}); - ASSERT_EQ(1, subArr1->ews()); - delete subArr1; + ASSERT_EQ(1, subArr1.ews()); } ////////////////////////////////////////////////////////////////////// @@ -950,8 +938,7 @@ TEST_F(NDArrayTest2, test_subarray_ews_4) { NDArray x('f', {10, 5}, nd4j::DataType::FLOAT32); auto subArr1 = x.subarray({NDIndex::point(2), NDIndex::all()}); - ASSERT_EQ(10, subArr1->ews()); - delete subArr1; + ASSERT_EQ(10, subArr1.ews()); } ////////////////////////////////////////////////////////////////////// @@ -1065,9 +1052,8 @@ TEST_F(NDArrayTest2, test_subarray_interval_1) { NDArray x('f', {10, 10}, nd4j::DataType::FLOAT32); auto subArr1 = x.subarray({NDIndex::all(), NDIndex::interval(0,9)}); - ASSERT_EQ(10, subArr1->sizeAt(0)); - ASSERT_EQ(9, subArr1->sizeAt(1)); - delete subArr1; + ASSERT_EQ(10, subArr1.sizeAt(0)); + ASSERT_EQ(9, subArr1.sizeAt(1)); } TEST_F(NDArrayTest2, test_subarray_interval_2) { @@ -1075,9 +1061,8 @@ TEST_F(NDArrayTest2, test_subarray_interval_2) { NDArray x('c', {10, 10}, nd4j::DataType::FLOAT32); auto subArr1 = x.subarray({NDIndex::all(), NDIndex::interval(0,9)}); - ASSERT_EQ(10, subArr1->sizeAt(0)); - ASSERT_EQ(9, subArr1->sizeAt(1)); - delete subArr1; + ASSERT_EQ(10, subArr1.sizeAt(0)); + ASSERT_EQ(9, subArr1.sizeAt(1)); } TEST_F(NDArrayTest2, test_subarray_3d_cf) { @@ -1117,7 +1102,7 @@ TEST_F(NDArrayTest2, test_broadcast_column_2) { auto e = NDArrayFactory::create('c', {5, 10}); e.assign(1.0f); - x.applyTrueBroadcast(BroadcastOpsTuple::Add(), &y, &x, false); + x.applyTrueBroadcast(BroadcastOpsTuple::Add(), y, x, false); ASSERT_EQ(e, x); } @@ -1128,7 +1113,7 @@ TEST_F(NDArrayTest2, test_broadcast_column_3) { auto e = NDArrayFactory::create('c', {5, 10}); e.assign(1.0f); - x.applyTrueBroadcast(BroadcastOpsTuple::Add(), &y, &x); + x.applyTrueBroadcast(BroadcastOpsTuple::Add(), y, x); ASSERT_EQ(e, x); } @@ -1139,7 +1124,7 @@ TEST_F(NDArrayTest2, test_broadcast_column_4) { auto e = NDArrayFactory::create('f', {10, 5}); e.assign(1.0f); - x.applyTrueBroadcast(BroadcastOpsTuple::Add(), &y, &x); + x.applyTrueBroadcast(BroadcastOpsTuple::Add(), y, x); ASSERT_EQ(e, x); } @@ -1171,7 +1156,7 @@ TEST_F(NDArrayTest2, test_not_tiled_2) { TEST_F(NDArrayTest2, test_long_sum_1) { auto x = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); - auto z = x.reduceAlongDims(reduce::Sum, {0}); + auto z = x.reduceAlongDimension(reduce::Sum, {0}); } ////////////////////////////////////////////////////////////////////// @@ -1216,7 +1201,7 @@ TEST_F(NDArrayTest2, trueBroadcast_1) { NDArray z('c', {2, 3}, nd4j::DataType::DOUBLE); auto exp = x - y; - x.applyTrueBroadcast(nd4j::BroadcastOpsTuple::Subtract(), &y, &z, true); + x.applyTrueBroadcast(nd4j::BroadcastOpsTuple::Subtract(), y, z); // exp.printIndexedBuffer(); // z.printIndexedBuffer(); @@ -1232,7 +1217,7 @@ TEST_F(NDArrayTest2, reduce_1) { arr6.linspace(1); - NDArray* arr6s = arr6.reduceAlongDimension(nd4j::reduce::Sum, {2,3}); + NDArray arr6s = arr6.reduceAlongDimension(nd4j::reduce::Sum, {2,3}); for (int i = 0; i < 4; i++) { for (int j = 0; j < 4; j++) { @@ -1254,8 +1239,6 @@ TEST_F(NDArrayTest2, reduce_1) { // arr6s->printIndexedBuffer(); ASSERT_TRUE(exp.equalsTo(arr6s)); - - delete arr6s; } ////////////////////////////////////////////////////////////////////// @@ -1265,23 +1248,17 @@ TEST_F(NDArrayTest2, reduce3_1) { NDArray y('c', {1,4}, {2,3,4,5}); NDArray exp('c', {4}, {1,1,1,1}); - NDArray* z = x.applyReduce3(nd4j::reduce3::EuclideanDistance, &y, {0}, nullptr); - // z->printShapeInfo(); - // z->printIndexedBuffer(); + NDArray z = x.applyReduce3(nd4j::reduce3::EuclideanDistance, y, {0}, nullptr); ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - - delete z; } TEST_F(NDArrayTest2, all_tads_1) { auto x = NDArrayFactory::create('c', {3, 5}); auto arrays = x.allTensorsAlongDimension({1}); - ASSERT_EQ(3, arrays->size()); - - delete arrays; + ASSERT_EQ(3, arrays.size()); } TEST_F(NDArrayTest2, test_trueBroadcast_empty_1) { diff --git a/libnd4j/tests_cpu/layers_tests/NativeOpsTests.cpp b/libnd4j/tests_cpu/layers_tests/NativeOpsTests.cpp index e426eeb1f..42eb50be0 100644 --- a/libnd4j/tests_cpu/layers_tests/NativeOpsTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/NativeOpsTests.cpp @@ -119,13 +119,15 @@ TEST_F(NativeOpsTests, ExecIndexReduce_1) { #ifdef __CUDABLAS__ printf("Unsupported for cuda now.\n"); #else + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); ::execIndexReduceScalar(nullptr, indexreduce::IndexMax, - x.buffer(), x.shapeInfo(), - nullptr, nullptr, + &xBuf, x.shapeInfo(), nullptr, - exp.buffer(), exp.shapeInfo(), - nullptr, nullptr); + nullptr, + &expBuf, exp.shapeInfo(), + nullptr); ASSERT_TRUE(exp.e(0) == 4LL); #endif @@ -140,15 +142,18 @@ TEST_F(NativeOpsTests, ExecIndexReduce_2) { printf("Unsupported for cuda now.\n"); #else NDArray dimension = NDArrayFactory::create({}); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + OpaqueDataBuffer dimensionBuf(dimension.dataBuffer()); + ::execIndexReduce(nullptr, indexreduce::IndexMax, - x.buffer(), x.shapeInfo(), - nullptr, nullptr, + &xBuf, x.shapeInfo(), nullptr, nullptr, - exp.buffer(), exp.shapeInfo(), - nullptr, nullptr, - dimension.buffer(), dimension.shapeInfo(), - nullptr, nullptr); + &expBuf, exp.shapeInfo(), + nullptr, + &dimensionBuf, dimension.shapeInfo(), + nullptr); ASSERT_TRUE(exp.e(0) == 24LL); #endif @@ -166,16 +171,21 @@ TEST_F(NativeOpsTests, ExecBroadcast_1) { #else auto dimension = NDArrayFactory::create('c', {1}, {1}); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + OpaqueDataBuffer dimBuf(dimension.dataBuffer()); + ::execBroadcast(nullptr, broadcast::Add, - x.buffer(), x.shapeInfo(), - nullptr, nullptr, - y.buffer(), y.shapeInfo(), - nullptr, nullptr, - exp.buffer(), exp.shapeInfo(), - nullptr, nullptr, - dimension.buffer(), dimension.shapeInfo(), - nullptr, nullptr); + &xBuf, x.shapeInfo(), + nullptr, + &yBuf, y.shapeInfo(), + nullptr, + &expBuf, exp.shapeInfo(), + nullptr, + &dimBuf, dimension.shapeInfo(), + nullptr); ASSERT_TRUE(exp.e(0) == 3.); #endif @@ -194,17 +204,18 @@ printf("Unsupported for cuda now.\n"); int dimd = 0; auto dimension = NDArrayFactory::create('c', {1}, {dimd}); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + OpaqueDataBuffer dimBuf(dimension.dataBuffer()); + ::execBroadcastBool(nullptr, broadcast::EqualTo, - x.buffer(), x.shapeInfo(), - nullptr, nullptr, - y.buffer(), y.shapeInfo(), - nullptr, nullptr, - exp.buffer(), exp.shapeInfo(), - nullptr, nullptr, - nullptr, - dimension.buffer(), dimension.shapeInfo(), - nullptr, nullptr); + &xBuf, x.shapeInfo(), nullptr, + &yBuf, y.shapeInfo(), nullptr, + &expBuf, exp.shapeInfo(), nullptr, nullptr, + &dimBuf, dimension.shapeInfo(), + nullptr); ASSERT_TRUE(exp.e(1) && !exp.e(0)); #endif @@ -219,14 +230,15 @@ TEST_F(NativeOpsTests, ExecPairwise_1) { #ifdef __CUDABLAS__ printf("Unsupported for cuda now.\n"); #else + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + ::execPairwiseTransform(nullptr, pairwise::Add, - x.buffer(), x.shapeInfo(), - nullptr, nullptr, - y.buffer(), y.shapeInfo(), - nullptr, nullptr, - exp.buffer(), exp.shapeInfo(), - nullptr, nullptr, + &xBuf, x.shapeInfo(), nullptr, + &yBuf, y.shapeInfo(), nullptr, + &expBuf, exp.shapeInfo(), nullptr, nullptr); ASSERT_TRUE(exp.e(5) == 8.); #endif @@ -243,14 +255,15 @@ TEST_F(NativeOpsTests, ExecPairwise_2) { #ifdef __CUDABLAS__ printf("Unsupported for cuda now.\n"); #else + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + ::execPairwiseTransformBool(nullptr, pairwise::And, - x.buffer(), x.shapeInfo(), - nullptr, nullptr, - y.buffer(), y.shapeInfo(), - nullptr, nullptr, - exp.buffer(), exp.shapeInfo(), - nullptr, nullptr, + &xBuf, x.shapeInfo(), nullptr, + &yBuf, y.shapeInfo(), nullptr, + &expBuf, exp.shapeInfo(), nullptr, nullptr); ASSERT_TRUE(exp.e(5) && !exp.e(4)); #endif @@ -266,14 +279,14 @@ TEST_F(NativeOpsTests, ReduceTest_1) { printf("Unsupported for cuda now.\n"); #else auto dimension = NDArrayFactory::create('c', {1}, {1}); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); ::execReduceFloat(nullptr, reduce::Mean, - x.buffer(), x.shapeInfo(), - nullptr, nullptr, + &xBuf, x.shapeInfo(), nullptr, nullptr, - exp.buffer(), exp.shapeInfo(), - nullptr, nullptr); + &expBuf, exp.shapeInfo(), nullptr); // x.printIndexedBuffer("Input"); // exp.printIndexedBuffer("Reduce Mean"); ASSERT_TRUE(exp.e(0) == 13.); @@ -289,14 +302,14 @@ TEST_F(NativeOpsTests, ReduceTest_2) { #ifdef __CUDABLAS__ printf("Unsupported for cuda now.\n"); #else + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); ::execReduceSame(nullptr, reduce::Sum, - x.buffer(), x.shapeInfo(), - nullptr, nullptr, + &xBuf, x.shapeInfo(), nullptr, nullptr, - exp.buffer(), exp.shapeInfo(), - nullptr, nullptr); + &expBuf, exp.shapeInfo(), nullptr); // x.printIndexedBuffer("Input"); // exp.printIndexedBuffer("Reduce Sum"); ASSERT_TRUE(exp.e(0) == 325.); @@ -312,14 +325,14 @@ TEST_F(NativeOpsTests, ReduceTest_3) { #ifdef __CUDABLAS__ printf("Unsupported for cuda now.\n"); #else + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); ::execReduceBool(nullptr, reduce::All, - x.buffer(), x.shapeInfo(), - nullptr, nullptr, + &xBuf, x.shapeInfo(), nullptr, nullptr, - exp.buffer(), exp.shapeInfo(), - nullptr, nullptr); + &expBuf, exp.shapeInfo(), nullptr); // x.printIndexedBuffer("Input"); // exp.printIndexedBuffer("Reduce All"); ASSERT_TRUE(exp.e(0) == true); @@ -335,14 +348,14 @@ TEST_F(NativeOpsTests, ReduceTest_4) { #ifdef __CUDABLAS__ printf("Unsupported for cuda now.\n"); #else + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); ::execReduceLong(nullptr, reduce::CountNonZero, - x.buffer(), x.shapeInfo(), - nullptr, nullptr, + &xBuf, x.shapeInfo(), nullptr, nullptr, - exp.buffer(), exp.shapeInfo(), - nullptr, nullptr); + &expBuf, exp.shapeInfo(), nullptr); // x.printIndexedBuffer("Input"); // exp.printIndexedBuffer("Reduce CountNonZero"); ASSERT_TRUE(exp.e(0) == 25LL); @@ -359,15 +372,16 @@ TEST_F(NativeOpsTests, ReduceTest_5) { printf("Unsupported for cuda now.\n"); #else auto dimension = NDArrayFactory::create({0, 1}); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + OpaqueDataBuffer dimBuf(dimension.dataBuffer()); + ::execReduceLong2(nullptr, reduce::CountNonZero, - x.buffer(), x.shapeInfo(), - nullptr, nullptr, + &xBuf, x.shapeInfo(), x.specialShapeInfo(), nullptr, - exp.buffer(), exp.shapeInfo(), - nullptr, nullptr, - dimension.buffer(), dimension.shapeInfo(), - dimension.specialBuffer(), dimension.specialShapeInfo()); + &expBuf, exp.shapeInfo(), exp.specialShapeInfo(), + &dimBuf, dimension.shapeInfo(), dimension.specialShapeInfo()); // x.printIndexedBuffer("Input"); // exp.printIndexedBuffer("Reduce CountNonZero"); ASSERT_TRUE(exp.e(0) == 25LL); @@ -389,15 +403,17 @@ TEST_F(NativeOpsTests, ReduceTest_6) { x.p(10, 0); x.p(11, 0); x.p(15, 0); x.p(16, 0); x.p(17, 0); x.p(20, 0); x.p(21, 0); x.p(22, 0); x.p(23, 0); + + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer dimBuf(dimension.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + ::execReduceLong2(nullptr, reduce::CountNonZero, - x.buffer(), x.shapeInfo(), - nullptr, nullptr, + &xBuf, x.shapeInfo(), nullptr, nullptr, - exp.buffer(), exp.shapeInfo(), - nullptr, nullptr, - dimension.buffer(), dimension.shapeInfo(), - dimension.specialBuffer(), dimension.specialShapeInfo()); + &expBuf, exp.shapeInfo(), nullptr, + &dimBuf, dimension.shapeInfo(), dimension.specialShapeInfo()); // x.printIndexedBuffer("Input"); // exp.printIndexedBuffer("Reduce CountNonZero"); ASSERT_TRUE(exp.equalsTo(z)); @@ -421,15 +437,16 @@ TEST_F(NativeOpsTests, ReduceTest_7) { x.linspace(1.0); x.syncToDevice(); dimension.syncToHost(); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer dimBuf(dimension.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + ::execReduceFloat2(extra, reduce::Mean, - x.buffer(), x.shapeInfo(), - x.specialBuffer(), x.specialShapeInfo(), + &xBuf, x.shapeInfo(), x.specialShapeInfo(), nullptr, - exp.buffer(), exp.shapeInfo(), - exp.specialBuffer(), exp.specialShapeInfo(), - dimension.buffer(), dimension.shapeInfo(), - dimension.specialBuffer(), dimension.specialShapeInfo()); + &expBuf, exp.shapeInfo(), exp.specialShapeInfo(), + &dimBuf, dimension.shapeInfo(), dimension.specialShapeInfo()); // x.printIndexedBuffer("Input"); // exp.printIndexedBuffer("Reduce Mean"); ASSERT_TRUE(exp.equalsTo(z)); @@ -453,16 +470,16 @@ TEST_F(NativeOpsTests, ReduceTest_8) { x.syncToDevice(); dimension.syncToHost(); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer dimBuf(dimension.dataBuffer()); + OpaqueDataBuffer zBuf(z.dataBuffer()); ::execReduceSame2(extra, reduce::Sum, - x.buffer(), x.shapeInfo(), - x.specialBuffer(), x.specialShapeInfo(), + &xBuf, x.shapeInfo(), x.specialShapeInfo(), nullptr, - z.buffer(), z.shapeInfo(), - z.specialBuffer(), z.specialShapeInfo(), - dimension.buffer(), dimension.shapeInfo(), - dimension.specialBuffer(), dimension.specialShapeInfo()); + &zBuf, z.shapeInfo(), z.specialShapeInfo(), + &dimBuf, dimension.shapeInfo(), dimension.specialShapeInfo()); // x.printIndexedBuffer("Input"); // exp.printIndexedBuffer("Reduce Sum"); ASSERT_TRUE(exp.equalsTo(z)); @@ -485,15 +502,17 @@ TEST_F(NativeOpsTests, ReduceTest_9) { x.syncToDevice(); dimension.syncToHost(); + + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer dimBuf(dimension.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + ::execReduceBool2(extra, reduce::All, - x.buffer(), x.shapeInfo(), - x.specialBuffer(), x.specialShapeInfo(), + &xBuf, x.shapeInfo(), x.specialShapeInfo(), nullptr, - exp.buffer(), exp.shapeInfo(), - exp.specialBuffer(), exp.specialShapeInfo(), - dimension.buffer(), dimension.shapeInfo(), - dimension.specialBuffer(), dimension.specialShapeInfo()); + &expBuf, exp.shapeInfo(), exp.specialShapeInfo(), + &dimBuf, dimension.shapeInfo(), dimension.specialShapeInfo()); // x.printIndexedBuffer("Input"); // exp.printIndexedBuffer("Reduce All"); ASSERT_TRUE(exp.equalsTo(z)); @@ -518,15 +537,16 @@ TEST_F(NativeOpsTests, Reduce3Test_1) { y.assign(2.); x.syncToDevice(); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + ::execReduce3(extra, reduce3::Dot, - x.buffer(), x.shapeInfo(), - x.specialBuffer(), x.specialShapeInfo(), + &xBuf, x.shapeInfo(), x.specialShapeInfo(), nullptr, - y.buffer(), y.shapeInfo(), - y.specialBuffer(), y.specialShapeInfo(), - exp.buffer(), exp.shapeInfo(), - exp.specialBuffer(), exp.specialShapeInfo()); + &yBuf, y.shapeInfo(), y.specialShapeInfo(), + &expBuf, exp.shapeInfo(), exp.specialShapeInfo()); //z.printIndexedBuffer("Z"); //exp.printIndexedBuffer("Reduce3 Dot"); ASSERT_TRUE(exp.equalsTo(z)); @@ -551,15 +571,16 @@ TEST_F(NativeOpsTests, Reduce3Test_2) { y.assign(2.); x.syncToDevice(); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + ::execReduce3Scalar(extra, reduce3::Dot, - x.buffer(), x.shapeInfo(), - x.specialBuffer(), x.specialShapeInfo(), + &xBuf, x.shapeInfo(), x.specialShapeInfo(), nullptr, - y.buffer(), y.shapeInfo(), - y.specialBuffer(), y.specialShapeInfo(), - exp.buffer(), exp.shapeInfo(), - exp.specialBuffer(), exp.specialShapeInfo()); + &yBuf, y.shapeInfo(), y.specialShapeInfo(), + &expBuf, exp.shapeInfo(), exp.specialShapeInfo()); // x.printIndexedBuffer("Input"); // exp.printIndexedBuffer("Reduce3 Dot"); ASSERT_TRUE(exp.equalsTo(z)); @@ -585,17 +606,18 @@ TEST_F(NativeOpsTests, Reduce3Test_3) { x.syncToDevice(); dimension.syncToHost(); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + OpaqueDataBuffer dimBuf(dimension.dataBuffer()); + ::execReduce3Tad(extra, reduce3::Dot, - x.buffer(), x.shapeInfo(), - x.specialBuffer(), x.specialShapeInfo(), + &xBuf, x.shapeInfo(), x.specialShapeInfo(), nullptr, - y.buffer(), y.shapeInfo(), - y.specialBuffer(), y.specialShapeInfo(), - exp.buffer(), exp.shapeInfo(), - exp.specialBuffer(), exp.specialShapeInfo(), - dimension.buffer(), dimension.shapeInfo(), - dimension.specialBuffer(), dimension.specialShapeInfo(), + &yBuf, y.shapeInfo(), y.specialShapeInfo(), + &expBuf, exp.shapeInfo(), exp.specialShapeInfo(), + &dimBuf, dimension.shapeInfo(), dimension.specialShapeInfo(), nullptr, nullptr, nullptr, nullptr); // x.printIndexedBuffer("Input"); // exp.printIndexedBuffer("Reduce All"); @@ -630,17 +652,18 @@ TEST_F(NativeOpsTests, Reduce3Test_4) { auto hTADShapeInfoY = tadPackY.primaryShapeInfo(); auto hTADOffsetsY = tadPackY.primaryOffsets(); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + OpaqueDataBuffer dimBuf(dimension.dataBuffer()); + ::execReduce3All(extra, reduce3::Dot, - x.buffer(), x.shapeInfo(), - x.specialBuffer(), x.specialShapeInfo(), + &xBuf, x.shapeInfo(), x.specialShapeInfo(), nullptr, - y.buffer(), y.shapeInfo(), - y.specialBuffer(), y.specialShapeInfo(), - exp.buffer(), exp.shapeInfo(), - exp.specialBuffer(), exp.specialShapeInfo(), - dimension.buffer(), dimension.shapeInfo(), - dimension.specialBuffer(), dimension.specialShapeInfo(), + &yBuf, y.shapeInfo(), y.specialShapeInfo(), + &expBuf, exp.shapeInfo(), exp.specialShapeInfo(), + &dimBuf, dimension.shapeInfo(), dimension.specialShapeInfo(), hTADShapeInfoX, hTADOffsetsX, hTADShapeInfoY, hTADOffsetsY); // x.printIndexedBuffer("Input"); // exp.printIndexedBuffer("Reduce All"); @@ -667,14 +690,16 @@ TEST_F(NativeOpsTests, ScalarTest_1) { //y.assign(2.); x.syncToDevice(); z.syncToDevice(); + + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + ::execScalar(extra, scalar::Multiply, - x.buffer(), x.shapeInfo(), - x.specialBuffer(), x.specialShapeInfo(), - exp.buffer(), exp.shapeInfo(), - exp.specialBuffer(), exp.specialShapeInfo(), - y.buffer(), y.shapeInfo(), - y.specialBuffer(), y.specialShapeInfo(), nullptr); + &xBuf, x.shapeInfo(), x.specialShapeInfo(), + &expBuf, exp.shapeInfo(), exp.specialShapeInfo(), + &yBuf, y.shapeInfo(), y.specialShapeInfo(), nullptr); // x.printIndexedBuffer("Input"); // exp.printIndexedBuffer("Reduce All"); ASSERT_TRUE(exp.equalsTo(z)); @@ -700,14 +725,16 @@ TEST_F(NativeOpsTests, ScalarTest_2) { //y.assign(2.); x.syncToDevice(); z.syncToDevice(); + + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + ::execScalarBool(extra, scalar::GreaterThan, - x.buffer(), x.shapeInfo(), - x.specialBuffer(), x.specialShapeInfo(), - exp.buffer(), exp.shapeInfo(), - exp.specialBuffer(), exp.specialShapeInfo(), - y.buffer(), y.shapeInfo(), - y.specialBuffer(), y.specialShapeInfo(), nullptr); + &xBuf, x.shapeInfo(), x.specialShapeInfo(), + &expBuf, exp.shapeInfo(), exp.specialShapeInfo(), + &yBuf, y.shapeInfo(), y.specialShapeInfo(), nullptr); // x.printIndexedBuffer("Input"); // exp.printIndexedBuffer("Reduce All"); ASSERT_TRUE(exp.e(5) == z.e(5) && exp.e(15) != z.e(15)); @@ -726,13 +753,14 @@ TEST_F(NativeOpsTests, SummaryStatsScalarTest_1) { printf("Unsupported for CUDA platform yet.\n"); return; #endif + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + ::execSummaryStatsScalar(extra, variance::SummaryStatsVariance, - x.buffer(), x.shapeInfo(), - x.specialBuffer(), x.specialShapeInfo(), + &xBuf, x.shapeInfo(), x.specialShapeInfo(), nullptr, - exp.buffer(), exp.shapeInfo(), - exp.specialBuffer(), exp.specialShapeInfo(), false); + &expBuf, exp.shapeInfo(), exp.specialShapeInfo(), false); // x.printIndexedBuffer("Input"); // exp.printIndexedBuffer("Standard Variance"); ASSERT_TRUE(exp.equalsTo(z)); @@ -751,13 +779,13 @@ TEST_F(NativeOpsTests, SummaryStatsScalarTest_2) { printf("Unsupported for CUDA platform yet.\n"); return; #endif + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); ::execSummaryStats(extra, variance::SummaryStatsVariance, - x.buffer(), x.shapeInfo(), - x.specialBuffer(), x.specialShapeInfo(), + &xBuf, x.shapeInfo(), x.specialShapeInfo(), nullptr, - exp.buffer(), exp.shapeInfo(), - exp.specialBuffer(), exp.specialShapeInfo(), false); + &expBuf, exp.shapeInfo(), exp.specialShapeInfo(), false); // x.printIndexedBuffer("Input"); // exp.printIndexedBuffer("Standard Variance"); ASSERT_TRUE(exp.equalsTo(z)); @@ -777,15 +805,16 @@ TEST_F(NativeOpsTests, SummaryStatsScalarTest_3) { return; #endif auto dimensions = NDArrayFactory::create({0, 1}); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + OpaqueDataBuffer dimBuf(dimensions.dataBuffer()); + ::execSummaryStatsTad(extra, variance::SummaryStatsVariance, - x.buffer(), x.shapeInfo(), - x.specialBuffer(), x.specialShapeInfo(), + &xBuf, x.shapeInfo(), x.specialShapeInfo(), nullptr, - exp.buffer(), exp.shapeInfo(), - exp.specialBuffer(), exp.specialShapeInfo(), - dimensions.buffer(), dimensions.shapeInfo(), - dimensions.specialBuffer(), dimensions.specialShapeInfo(), + &expBuf, exp.shapeInfo(), exp.specialShapeInfo(), + &dimBuf, dimensions.shapeInfo(), dimensions.specialShapeInfo(), false, nullptr, nullptr); // x.printIndexedBuffer("Input"); @@ -807,13 +836,15 @@ TEST_F(NativeOpsTests, TransformTest_1) { return; #endif z.linspace(1.); + + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer zBuf(z.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + ::execTransformFloat(extra, transform::Sqrt, - x.buffer(), x.shapeInfo(), - x.specialBuffer(), x.specialShapeInfo(), - - exp.buffer(), exp.shapeInfo(), - exp.specialBuffer(), exp.specialShapeInfo(), + &xBuf, x.shapeInfo(), x.specialShapeInfo(), + &expBuf, exp.shapeInfo(), exp.specialShapeInfo(), nullptr); // x.printIndexedBuffer("Input"); // exp.printIndexedBuffer("Sqrt is"); @@ -834,13 +865,15 @@ TEST_F(NativeOpsTests, TransformTest_2) { return; #endif z.linspace(1.); + + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer zBuf(z.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + ::execTransformSame(extra, transform::Square, - z.buffer(), z.shapeInfo(), - z.specialBuffer(), z.specialShapeInfo(), - - exp.buffer(), exp.shapeInfo(), - exp.specialBuffer(), exp.specialShapeInfo(), + &zBuf, z.shapeInfo(), z.specialShapeInfo(), + &expBuf, exp.shapeInfo(), exp.specialShapeInfo(), nullptr); // x.printIndexedBuffer("Input"); // exp.printIndexedBuffer("Square is"); @@ -864,13 +897,14 @@ TEST_F(NativeOpsTests, TransformTest_3) { z.assign(true); x.p(24, -25); z.p(24, false); + + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + ::execTransformBool(extra, transform::IsPositive, - x.buffer(), x.shapeInfo(), - x.specialBuffer(), x.specialShapeInfo(), - - exp.buffer(), exp.shapeInfo(), - exp.specialBuffer(), exp.specialShapeInfo(), + &xBuf, x.shapeInfo(), x.specialShapeInfo(), + &expBuf, exp.shapeInfo(), exp.specialShapeInfo(), nullptr); // x.printIndexedBuffer("Input"); // exp.printIndexedBuffer("IsPositive"); @@ -894,13 +928,13 @@ TEST_F(NativeOpsTests, TransformTest_4) { return; #endif //z.linspace(1.); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + ::execTransformStrict(extra, transform::Cosine, - x.buffer(), x.shapeInfo(), - x.specialBuffer(), x.specialShapeInfo(), - - exp.buffer(), exp.shapeInfo(), - exp.specialBuffer(), exp.specialShapeInfo(), + &xBuf, x.shapeInfo(), x.specialShapeInfo(), + &expBuf, exp.shapeInfo(), exp.specialShapeInfo(), nullptr); // x.printIndexedBuffer("Input"); // exp.printIndexedBuffer("Cosine"); @@ -932,17 +966,18 @@ TEST_F(NativeOpsTests, ScalarTadTest_1) { auto tadPackX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(x.shapeInfo(), dimensions, dimension.lengthOf()); auto tadPackZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(z.shapeInfo(), dimensions, dimension.lengthOf()); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + OpaqueDataBuffer dimBuf(dimension.dataBuffer()); + ::execScalarTad(extra, scalar::Multiply, - x.buffer(), x.shapeInfo(), - x.specialBuffer(), x.specialShapeInfo(), - exp.buffer(), exp.shapeInfo(), - exp.specialBuffer(), exp.specialShapeInfo(), - y.buffer(), y.shapeInfo(), - y.specialBuffer(), y.specialShapeInfo(), + &xBuf, x.shapeInfo(), x.specialShapeInfo(), + &expBuf, exp.shapeInfo(), exp.specialShapeInfo(), + &yBuf, y.shapeInfo(), y.specialShapeInfo(), nullptr, - dimension.buffer(), dimension.shapeInfo(), - dimension.specialBuffer(), dimension.specialShapeInfo(), + &dimBuf, dimension.shapeInfo(), dimension.specialShapeInfo(), tadPackX.primaryShapeInfo(), tadPackX.primaryOffsets(), tadPackZ.primaryShapeInfo(), tadPackZ.primaryOffsets()); // x.printIndexedBuffer("Input"); // exp.printIndexedBuffer("Reduce All"); @@ -977,17 +1012,21 @@ TEST_F(NativeOpsTests, ScalarTadTest_2) { auto tadPackZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(z.shapeInfo(), dimensions, dimension.lengthOf()); z.assign(true); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + OpaqueDataBuffer dimBuf(dimension.dataBuffer()); + ::execScalarBoolTad(extra, scalar::And, - x.buffer(), x.shapeInfo(), - x.specialBuffer(), x.specialShapeInfo(), - exp.buffer(), exp.shapeInfo(), - exp.specialBuffer(), exp.specialShapeInfo(), - y.buffer(), y.shapeInfo(), - y.specialBuffer(), y.specialShapeInfo(), + &xBuf, x.shapeInfo(), x.specialShapeInfo(), + &expBuf, exp.shapeInfo(), + exp.specialShapeInfo(), + &yBuf, y.shapeInfo(), + y.specialShapeInfo(), nullptr, - dimension.buffer(), dimension.shapeInfo(), - dimension.specialBuffer(), dimension.specialShapeInfo(), + &dimBuf, dimension.shapeInfo(), + dimension.specialShapeInfo(), tadPackX.primaryShapeInfo(), tadPackX.primaryOffsets(), tadPackZ.primaryShapeInfo(), tadPackZ.primaryOffsets()); // x.printIndexedBuffer("Input"); // exp.printIndexedBuffer("And"); @@ -1095,9 +1134,11 @@ TEST_F(NativeOpsTests, PullRowsTest_1) { #ifdef __CUDABLAS__ nativeStart[1] = (x.getContext()->getCudaStream()); #endif + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer zBuf(z.dataBuffer()); - pullRows(nativeStart, x.buffer(), x.getShapeInfo(), x.getSpecialBuffer(), x.getSpecialShapeInfo(), - z.buffer(), z.getShapeInfo(), z.specialBuffer(), z.specialShapeInfo(), + pullRows(nativeStart, &xBuf, x.getShapeInfo(), x.getSpecialShapeInfo(), + &zBuf, z.getShapeInfo(), z.specialShapeInfo(), 4, pidx, xTadPack.platformShapeInfo(), xTadPack.platformOffsets(), zTadPack.platformShapeInfo(), zTadPack.platformOffsets()); @@ -1250,7 +1291,9 @@ TEST_F(NativeOpsTests, RandomTest_1) { #endif graph::RandomGenerator rng(1023, 119); double p = 0.5; - ::execRandom(extra, random::BernoulliDistribution, &rng, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), &p); + OpaqueDataBuffer zBuf(z.dataBuffer()); + + ::execRandom(extra, random::BernoulliDistribution, &rng, &zBuf, z.shapeInfo(), z.specialShapeInfo(), &p); } TEST_F(NativeOpsTests, RandomTest_2) { @@ -1264,7 +1307,10 @@ TEST_F(NativeOpsTests, RandomTest_2) { x.linspace(0, 0.01); graph::RandomGenerator rng(1023, 119); double p = 0.5; - ::execRandom2(extra, random::DropOut, &rng, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), &p); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer zBuf(z.dataBuffer()); + + ::execRandom2(extra, random::DropOut, &rng, &xBuf, x.shapeInfo(), x.specialShapeInfo(), &zBuf, z.shapeInfo(), z.specialShapeInfo(), &p); } TEST_F(NativeOpsTests, RandomTest_3) { @@ -1280,7 +1326,12 @@ TEST_F(NativeOpsTests, RandomTest_3) { x.linspace(1, -0.01); graph::RandomGenerator rng(1023, 119); double p = 0.5; - ::execRandom3(extra, random::ProbablisticMerge, &rng, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), &p); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer zBuf(z.dataBuffer()); + + ::execRandom3(extra, random::ProbablisticMerge, &rng, &xBuf, x.shapeInfo(), x.specialShapeInfo(), &yBuf, + y.shapeInfo(), y.specialShapeInfo(), &zBuf, z.shapeInfo(), z.specialShapeInfo(), &p); } TEST_F(NativeOpsTests, RandomTest_4) { @@ -1316,6 +1367,10 @@ TEST_F(NativeOpsTests, SortTests_2) { #ifdef __CUDABLAS__ extras[1] = LaunchContext::defaultContext()->getCudaStream(); #endif +// OpaqueDataBuffer xBuf(x.dataBuffer()); +// OpaqueDataBuffer yBuf(y.dataBuffer()); +// OpaqueDataBuffer expBuf(exp.dataBuffer()); +// OpaqueDataBuffer dimBuf(exp.dataBuffer()); ::sortByKey(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), false); k.tickWriteDevice(); @@ -1541,6 +1596,13 @@ TEST_F(NativeOpsTests, CalculateOutputShapeTests_2) { ::deleteShapeList((Nd4jPointer) shapeList); } + +TEST_F(NativeOpsTests, interop_databuffer_tests_1) { + auto idb = ::allocateDataBuffer(100, 10, false); + auto ptr = ::dbPrimaryBuffer(idb); + ::deleteDataBuffer(idb); +} + //Uncomment when needed only - massive calculations //TEST_F(NativeOpsTests, BenchmarkTests_1) { // diff --git a/libnd4j/tests_cpu/layers_tests/ParityOpsTests.cpp b/libnd4j/tests_cpu/layers_tests/ParityOpsTests.cpp index 6d58e6e41..b62cbceea 100644 --- a/libnd4j/tests_cpu/layers_tests/ParityOpsTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/ParityOpsTests.cpp @@ -92,9 +92,9 @@ TEST_F(ParityOpsTests, TestMinimum1) { TEST_F(ParityOpsTests, TestTear1) { auto input = NDArrayFactory::create('c', {10, 5}); auto tads = input.allTensorsAlongDimension({1}); - for (int e = 0; e < tads->size(); e++) { - ASSERT_EQ(5, tads->at(e)->lengthOf()); - tads->at(e)->assign((float) e + 1); + for (int e = 0; e < tads.size(); e++) { + ASSERT_EQ(5, tads.at(e)->lengthOf()); + tads.at(e)->assign((float) e + 1); } nd4j::ops::tear op; @@ -104,18 +104,17 @@ TEST_F(ParityOpsTests, TestTear1) { ASSERT_EQ(10, result->size()); for (int e = 0; e < result->size(); e++) - ASSERT_TRUE(tads->at(e)->equalsTo(result->at(e))); + ASSERT_TRUE(tads.at(e)->equalsTo(result->at(e))); delete result; - delete tads; } TEST_F(ParityOpsTests, TestUnstack1) { auto input = NDArrayFactory::create('c', {10, 5}); auto tads = input.allTensorsAlongDimension({1}); - for (int e = 0; e < tads->size(); e++) { - ASSERT_EQ(5, tads->at(e)->lengthOf()); - tads->at(e)->assign((float) e + 1); + for (int e = 0; e < tads.size(); e++) { + ASSERT_EQ(5, tads.at(e)->lengthOf()); + tads.at(e)->assign((float) e + 1); } nd4j::ops::unstack op; @@ -124,14 +123,10 @@ TEST_F(ParityOpsTests, TestUnstack1) { ASSERT_EQ(10, result->size()); - // result->at(0)->printShapeInfo("rz"); - // tads->at(0)->printShapeInfo("re"); - for (int e = 0; e < result->size(); e++) - ASSERT_TRUE(tads->at(e)->equalsTo(result->at(e))); + ASSERT_TRUE(tads.at(e)->equalsTo(result->at(e))); delete result; - delete tads; } @@ -139,9 +134,9 @@ TEST_F(ParityOpsTests, TestUnstack1) { TEST_F(ParityOpsTests, TestUnstack2) { auto input = NDArrayFactory::create('c', {5,2,6}); auto tads = input.allTensorsAlongDimension({0,1}); - for (int e = 0; e < tads->size(); e++) { - ASSERT_EQ(10, tads->at(e)->lengthOf()); - tads->at(e)->assign((float) e + 1); + for (int e = 0; e < tads.size(); e++) { + ASSERT_EQ(10, tads.at(e)->lengthOf()); + tads.at(e)->assign((float) e + 1); } nd4j::ops::unstack op; @@ -151,10 +146,9 @@ TEST_F(ParityOpsTests, TestUnstack2) { ASSERT_EQ(6, result->size()); for (int e = 0; e < result->size(); e++) - ASSERT_TRUE(tads->at(e)->equalsTo(result->at(e))); + ASSERT_TRUE(tads.at(e)->equalsTo(result->at(e))); delete result; - delete tads; } TEST_F(ParityOpsTests, TestUnstack3) { @@ -689,11 +683,10 @@ TEST_F(ParityOpsTests, Test_Bias_Add_1) { auto z = result->at(0); auto tads = z->allTensorsAlongDimension({1}); - for (int e = 0; e < tads->size(); e++) { - ASSERT_TRUE(bias.equalsTo(tads->at(e))); + for (int e = 0; e < tads.size(); e++) { + ASSERT_TRUE(bias.equalsTo(tads.at(e))); } - delete tads; delete result; } @@ -833,7 +826,7 @@ TEST_F(ParityOpsTests, Test_Scatter_Add_8) { // z.printBuffer(); ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(expected.isSameShapeStrict(&z)); + ASSERT_TRUE(expected.isSameShapeStrict(z)); ASSERT_TRUE(expected.equalsTo(z)); } @@ -857,7 +850,7 @@ TEST_F(ParityOpsTests, scatterMax_test1) { auto exp = NDArrayFactory::create('c', {2, 2}, {10, 2, 3, 4}); nd4j::ops::scatter_max op; - auto result = op.execute({&matrix, &idc, &updates}, {}, {}, {}); + auto result = op.execute({&matrix, &idc, &updates}, {}, {}, {true}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); auto z = result->at(0); @@ -874,7 +867,7 @@ TEST_F(ParityOpsTests, scatterMax_test2) { auto exp = NDArrayFactory::create('c', {1, 4}, {10, 2, 30, 4}); nd4j::ops::scatter_max op; - auto result = op.execute({&vec, &idc, &updates}, {}, {}); + auto result = op.execute({&vec, &idc, &updates}, {}, {}, {true}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); auto z = result->at(0); @@ -891,7 +884,7 @@ TEST_F(ParityOpsTests, scatterMax_test3) { auto exp = NDArrayFactory::create('c', {2, 2, 2}, {10, 2, 30, 4, 5, 6, 7, 8}); nd4j::ops::scatter_max op; - auto result = op.execute({&matrix, &idc, &updates}, {}, {}); + auto result = op.execute({&matrix, &idc, &updates}, {}, {}, {true}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); auto z = result->at(0); @@ -908,7 +901,7 @@ TEST_F(ParityOpsTests, scatterMax_test4) { auto exp = NDArrayFactory::create('c', {2, 2, 2}, {1, 10, 10, 10, 5, 6, 7, 8}); nd4j::ops::scatter_max op; - auto result = op.execute({&matrix, &idc, &updates}, {}, {true}); + auto result = op.execute({&matrix, &idc, &updates}, {}, {true}, {true}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); auto z = result->at(0); @@ -925,7 +918,7 @@ TEST_F(ParityOpsTests, scatterMax_test5) { auto exp = NDArrayFactory::create('c', {2, 2, 3}, {10, 2, 10, 2, 10, 2, 2, 10, 2, 10, 2, 10}); nd4j::ops::scatter_max op; - auto result = op.execute({&matrix, &idc, &updates}, {}, {}); + auto result = op.execute({&matrix, &idc, &updates}, {}, {}, {true}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); auto z = result->at(0); @@ -942,7 +935,7 @@ TEST_F(ParityOpsTests, scatterMax_test6) { auto exp = NDArrayFactory::create('c', {2, 2, 2}, {2, 1, 2, 1, 1, 2, 1, 2}); nd4j::ops::scatter_max op; - auto result = op.execute({&matrix, &idc, &updates}, {}, {}); + auto result = op.execute({&matrix, &idc, &updates}, {}, {}, {true}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); auto z = result->at(0); @@ -960,7 +953,7 @@ TEST_F(ParityOpsTests, scatterMin_test1) { auto exp = NDArrayFactory::create('c', {2, 2}, {-1, 1, 3, 4}); nd4j::ops::scatter_min op; - auto result = op.execute({&matrix, &idc, &updates}, {}, {}); + auto result = op.execute({&matrix, &idc, &updates}, {}, {}, {true}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); auto z = result->at(0); @@ -977,7 +970,7 @@ TEST_F(ParityOpsTests, scatterMin_test2) { auto exp = NDArrayFactory::create('c', {1, 4}, {1, 1, 3, 1}); nd4j::ops::scatter_min op; - auto result = op.execute({&vec, &idc, &updates}, {}, {}); + auto result = op.execute({&vec, &idc, &updates}, {}, {}, {true}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); auto z = result->at(0); @@ -994,7 +987,7 @@ TEST_F(ParityOpsTests, scatterMin_test3) { auto exp = NDArrayFactory::create('c', {2, 2, 2}, {1, 1, 3, 2, 5, 6, 7, 8}); nd4j::ops::scatter_min op; - auto result = op.execute({&matrix, &idc, &updates}, {}, {}); + auto result = op.execute({&matrix, &idc, &updates}, {}, {}, {true}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); auto z = result->at(0); diff --git a/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp b/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp index 051c65988..970c119ca 100644 --- a/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp @@ -59,6 +59,87 @@ public: fflush(stdout); } }; + +TEST_F(PlaygroundTests, test_avx) { + nd4j_printf("Optimal level: %i; Binary level: %i;\n", ::optimalLevel(), ::binaryLevel()); +} + +/* +TEST_F(PlaygroundTests, test_s_0) { + auto x = NDArrayFactory::create('c', {32, 112, 112, 16}); + auto y = NDArrayFactory::create('c', {16}); + auto z = x.ulike(); + + std::vector values; + Context ctx(1); + ctx.setInputArray(0, &x); + ctx.setInputArray(1, &y); + ctx.setOutputArray(0, &z); + + nd4j::ops::biasadd op; + + + for (int e = 0; e < 10000; e++) { + auto timeStart = std::chrono::system_clock::now(); + + op.execute(&ctx); + + auto timeEnd = std::chrono::system_clock::now(); + auto outerTime = std::chrono::duration_cast (timeEnd - timeStart).count(); + values.emplace_back(outerTime); + } + + std::sort(values.begin(), values.end()); + + nd4j_printf("Time: %lld us;\n", values[values.size() / 2]); +} +*/ +/* +TEST_F(PlaygroundTests, test_s_1) { + auto x0 = NDArrayFactory::create('c', {32, 7, 7, 176}); + auto x1 = x0.ulike(); + auto x2 = x0.ulike(); + auto x3 = x0.ulike(); + auto x4 = x0.ulike(); + auto x5 = x0.ulike(); + + auto y = NDArrayFactory::create(3); + auto z = NDArrayFactory::create('c', {32, 7, 7, 1056}); + + Context ctx(1); + ctx.setInputArray(0, &x0); + ctx.setInputArray(1, &x1); + ctx.setInputArray(2, &x2); + ctx.setInputArray(3, &x3); + ctx.setInputArray(4, &x4); + ctx.setInputArray(5, &x5); + + ctx.setInputArray(6, &y); + ctx.setOutputArray(0, &z); + ctx.setBArguments({true}); + + std::vector values; + + nd4j::ops::concat op; + op.execute(&ctx); + + for (int e = 0; e < 1000; e++) { + auto timeStart = std::chrono::system_clock::now(); + + op.execute(&ctx); + + auto timeEnd = std::chrono::system_clock::now(); + auto outerTime = std::chrono::duration_cast (timeEnd - timeStart).count(); + values.emplace_back(outerTime); + } + + + std::sort(values.begin(), values.end()); + + nd4j_printf("Time: %lld us;\n", values[values.size() / 2]); +} +*/ + /* TEST_F(PlaygroundTests, test_s_1) { auto t = ::runLightBenchmarkSuit(true); @@ -341,4 +422,50 @@ TEST_F(PlaygroundTests, my) { delete variableSpace; } -*/ \ No newline at end of file + +#include + +TEST_F(PlaygroundTests, my) { + + const int N = 10000; + const Nd4jLong dim0(128), dim1(128), dim2(128); + + NDArray input('c', {dim0,dim1,dim2}, nd4j::DataType::DOUBLE); + NDArray mean('c', {dim1}, nd4j::DataType::DOUBLE); + NDArray variance('c', {dim1}, nd4j::DataType::DOUBLE); + NDArray gamma('c', {dim1}, nd4j::DataType::DOUBLE); + NDArray beta ('c', {dim1}, nd4j::DataType::DOUBLE); + + NDArray output('c', {dim0,dim1,dim2}, nd4j::DataType::DOUBLE); + + input.linspace(-100, 0.1); + mean.linspace(-50, 0.15); + variance.linspace(-5, 0.2); + gamma = 1.5; + beta = -2.5; + + // warm up + ops::helpers::batchnorm(&input, &mean, &variance, &gamma, &beta, &output, {1}, 1e-5); + + auto timeStart = std::chrono::system_clock::now(); + for (int i = 0; i < N; ++i) + ops::helpers::batchnorm(&input, &mean, &variance, &gamma, &beta, &output, {1}, 1e-5); + + auto timeEnd = std::chrono::system_clock::now(); + auto time = std::chrono::duration_cast ((timeEnd - timeStart)/N).count(); + + printf("time: %li \n", time); + +} + + +*/ + + + + + + + + + diff --git a/libnd4j/tests_cpu/layers_tests/RNGTests.cpp b/libnd4j/tests_cpu/layers_tests/RNGTests.cpp index 5c3ca340b..0d5572ec6 100644 --- a/libnd4j/tests_cpu/layers_tests/RNGTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/RNGTests.cpp @@ -275,8 +275,8 @@ TEST_F(RNGTests, Test_Gaussian_21) { #ifdef DEBUG_BUILD TEST_F(RNGTests, Test_Gaussian_22) { - auto x0 = NDArrayFactory::create('c', {10000, 1000}); - auto x1 = NDArrayFactory::create('c', {10000, 1000}); + auto x0 = NDArrayFactory::create('c', {1000, 800}); + auto x1 = NDArrayFactory::create('c', {1000, 800}); RandomLauncher::fillGaussian(nd4j::LaunchContext::defaultContext(), _rngA, &x0, 0.0f, 1.0f); RandomLauncher::fillGaussian(LaunchContext::defaultContext(), _rngB, &x1, 0.0f, 1.0f); @@ -304,7 +304,7 @@ TEST_F(RNGTests, Test_Gaussian_22) { } TEST_F(RNGTests, Test_Gaussian_3) { - auto x0 = NDArrayFactory::create('c', {10000000}); + auto x0 = NDArrayFactory::create('c', {800000}); RandomLauncher::fillGaussian(LaunchContext::defaultContext(), _rngA, &x0, 0.0, 1.0); @@ -381,8 +381,8 @@ TEST_F(RNGTests, Test_Truncated_2) { } TEST_F(RNGTests, Test_Truncated_21) { - auto x0 = NDArrayFactory::create('c', {1000, 1000}); - auto x1 = NDArrayFactory::create('c', {1000, 1000}); + auto x0 = NDArrayFactory::create('c', {100, 100}); + auto x1 = NDArrayFactory::create('c', {100, 100}); RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngA, &x0, 1.0f, 2.0f); RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngB, &x1, 1.0f, 2.0f); @@ -428,8 +428,8 @@ TEST_F(RNGTests, Test_Truncated_21) { } TEST_F(RNGTests, Test_Truncated_22) { - auto x0 = NDArrayFactory::create('c', {1000, 1000}); - auto x1 = NDArrayFactory::create('c', {1000, 1000}); + auto x0 = NDArrayFactory::create('c', {100, 100}); + auto x1 = NDArrayFactory::create('c', {100, 100}); RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngA, &x0, 2.0f, 4.0f); RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngB, &x1, 2.0f, 4.0f); @@ -522,27 +522,20 @@ TEST_F(RNGTests, Test_Truncated_23) { } TEST_F(RNGTests, Test_Truncated_3) { - auto x0 = NDArrayFactory::create('c', {10000, 1000}); - auto x1 = NDArrayFactory::create('c', {10000, 1000}); + auto x0 = NDArrayFactory::create('c', {2000, 2000}); + auto x1 = NDArrayFactory::create('c', {2000, 2000}); RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngA, &x0, 1.0f, 2.0f); RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngB, &x1, 1.0f, 2.0f); ASSERT_TRUE(x0.equalsTo(&x1)); - //ASSERT_FALSE(x0.equalsTo(nexp0)); - //ASSERT_FALSE(x0.equalsTo(nexp1)); - //ASSERT_FALSE(x0.equalsTo(nexp2)); - // Check up distribution auto mean = x1.reduceNumber(reduce::Mean); // mean.printIndexedBuffer("Mean 1.0"); //auto sumA = x1 - mean; //.reduceNumber(reduce::Sum); auto deviation = x1.varianceNumber(variance::SummaryStatsStandardDeviation, false); - //deviation /= (double)x1.lengthOf(); - // deviation.printIndexedBuffer("Deviation should be 2.0"); - //x1.printIndexedBuffer("Distribution TN"); ASSERT_NEAR(mean.e(0), 1.f, 0.001); ASSERT_NEAR(deviation.e(0), 2.f, 0.3); } @@ -1009,4 +1002,205 @@ TEST_F(RNGTests, test_uniform_119) { nd4j::ops::randomuniform op; auto status = op.execute({&x}, {&z}, {1.0, 2.0}, {}, {}); ASSERT_EQ(Status::OK(), status); -} \ No newline at end of file +} + +TEST_F(RNGTests, test_multinomial_1) { + + NDArray probs('f', { 3, 3 }, { 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3 }, nd4j::DataType::FLOAT32); + NDArray expected('f', { 3, 3 }, { 0, 1, 2, 2, 0, 0, 1, 2, 1 }, nd4j::DataType::INT64); + NDArray output('f', { 3, 3 }, nd4j::DataType::INT64); + NDArray samples('f', { 1 }, { 3 }, nd4j::DataType::INT32); + + nd4j::ops::random_multinomial op; + RandomGenerator rng(1234, 1234); + ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 0, INT64}, {}, false) ); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + NDArray probsZ('c', { 1, 3 }, { 0.3, 0.3, 0.3 }, nd4j::DataType::FLOAT32); + NDArray expectedZ('c', { 3, 3 }, { 0, 0, 0, 0, 0, 0, 0, 0, 0 }, nd4j::DataType::INT64); + + auto result = op.execute({ &probsZ, &samples }, { }, { 1, INT64 }); + auto outputZ = result->at(0); + + ASSERT_EQ(Status::OK(), result->status()); + ASSERT_TRUE(expectedZ.isSameShape(outputZ)); + ASSERT_TRUE(expectedZ.equalsTo(outputZ)); + delete result; +} + +TEST_F(RNGTests, test_multinomial_2) { + + NDArray samples('c', { 1 }, { 20 }, nd4j::DataType::INT32); + NDArray probs('c', { 3, 5 }, { 0.2, 0.3, 0.5, 0.3, 0.5, 0.2, 0.5, 0.2, 0.3, 0.35, 0.25, 0.3, 0.25, 0.25, 0.5 }, nd4j::DataType::FLOAT32); + NDArray expected('c', { 3, 20 }, { 0, 2, 0, 2, 0, 4, 2, 0, 1, 2, 0, 2, 3, 0, 0, 2, 4, 4, 1, 0, 2, 3, 2, 3, 0, 1, 3, 1, 1, 1, 2, 4, 3, 3, 1, 4, 4, 2, 0, 0, 3, 3, 3, 0, 0, 2, 2, 3, 3, 0, 0, 2, 3, 4, 2, 2, 3, 2, 1, 2 }, nd4j::DataType::INT64); + NDArray output('c', { 3, 20 }, nd4j::DataType::INT64); + + nd4j::ops::random_multinomial op; + RandomGenerator rng(1234, 1234); + ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 0, INT64 }, {}, false)); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + NDArray probs2('c', { 5, 3 }, { 0.2, 0.3, 0.5, 0.3, 0.5, 0.2, 0.5, 0.2, 0.3, 0.35, 0.25, 0.3, 0.25, 0.25, 0.5 }, nd4j::DataType::FLOAT32); + NDArray expected2('c', { 20, 3 }, { 0, 2, 3, 2, 3, 3, 0, 2, 3, 2, 3, 0, 0, 0, 0, 4, 1, 2, 2, 3, 2, 3, 1, 3, 1, 1, 3, 2, 1, 0, 0, 2, 0, 2, 4, 2, 3, 3, 3, 0, 3, 4, 0, 1, 2, 2, 0, 2, 4, 4, 0, 4, 2, 2, 1, 0, 1, 0, 0, 2 }, nd4j::DataType::INT64); + NDArray output2('c', { 20, 3 }, nd4j::DataType::INT64); + + rng.setStates(1234, 1234); + ASSERT_EQ(Status::OK(), op.execute(rng, { &probs2, &samples }, { &output2 }, {}, { 1, INT64 }, {}, false)); + ASSERT_TRUE(expected2.isSameShape(output2)); + ASSERT_TRUE(expected2.equalsTo(output2)); +} + +TEST_F(RNGTests, test_multinomial_3) { + + NDArray probs('c', { 4, 3 }, { 0.3, 0.3, 0.4, 0.3, 0.4, 0.3, 0.3, 0.3, 0.4, 0.4, 0.3, 0.3 }, nd4j::DataType::FLOAT32); + NDArray expected('c', { 4, 5 }, nd4j::DataType::INT64); + NDArray output('c', { 4, 5 }, nd4j::DataType::INT64); + NDArray samples('c', { 1 }, { 5 }, nd4j::DataType::INT32); + RandomGenerator rng(1234, 1234); + + nd4j::ops::random_multinomial op; + ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &expected }, {}, { 0, INT64 }, {}, false)); + + rng.setStates(1234, 1234); + ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 0, INT64 }, {}, false)); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); +} + +TEST_F(RNGTests, test_multinomial_4) { + + NDArray probs('c', { 3, 4 }, { 0.3, 0.3, 0.4, 0.3, 0.4, 0.3, 0.3, 0.3, 0.4, 0.4, 0.3, 0.3 }, nd4j::DataType::FLOAT32); + NDArray expected('c', { 5, 4 }, nd4j::DataType::INT64); + NDArray output('c', { 5, 4 }, nd4j::DataType::INT64); + NDArray samples('c', { 1 }, { 5 }, nd4j::DataType::INT32); + + RandomGenerator rng(1234, 1234); + nd4j::ops::random_multinomial op; + ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &expected }, {}, { 1, INT64 }, {}, false)); + + rng.setStates(1234, 1234); + ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 1, INT64 }, {}, false)); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); +} + +TEST_F(RNGTests, test_multinomial_5) { + // multinomial as binomial if 2 classes used + int batchValue = 1; + int ClassValue = 2; + int Samples = 100000; + + NDArray samples('c', { 1 }, { 1.*Samples }, nd4j::DataType::INT32); + + NDArray probs('c', { ClassValue, batchValue }, { 1.0, 1.0 }, nd4j::DataType::FLOAT32); + + nd4j::ops::random_multinomial op; + + NDArray output('c', { Samples, batchValue }, nd4j::DataType::INT64); + RandomGenerator rng(1234, 1234); + + ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 1 }, {}, false)); + + auto deviation = output.varianceNumber(variance::SummaryStatsStandardDeviation, false); + auto mean = output.meanNumber(); + // printf("Var: %f Mean: %f \n", deviation.e(0), mean.e(0)); + // theoretical values for binomial + ASSERT_NEAR(0.5, deviation.e(0), 4e-3); // 1000000 3e-3); + ASSERT_NEAR(0.5, mean.e(0), 4e-3); // 1000000 3e-3); + + for (int i = 0; i < output.lengthOf(); i++) { + auto value = output.e(i); + ASSERT_TRUE(value >= 0 && value < ClassValue); + } + + auto resultR = op.execute({ &probs, &samples }, { }, { 1 }); + auto outputR = resultR->at(0); + ASSERT_EQ(Status::OK(), resultR->status()); + + deviation = outputR->varianceNumber(variance::SummaryStatsStandardDeviation, false); + mean = outputR->meanNumber(); + // printf("Random seed - Var: %f Mean: %f \n", deviation.e(0), mean.e(0)); + ASSERT_NEAR(0.5, deviation.e(0), 45e-3); // 1000000 35e-3); + ASSERT_NEAR(0.5, mean.e(0), 45e-3); // 1000000 35e-3); + + for (int i = 0; i < outputR->lengthOf(); i++) { + auto value = outputR->e(i); + ASSERT_TRUE(value >= 0 && value < ClassValue); + } + + delete resultR; +} + + +TEST_F(RNGTests, test_multinomial_6) { + + int batchValue = 1; + int ClassValue = 5; + int Samples = 100000; + + NDArray samples('c', { 1 }, { 1. * Samples }, nd4j::DataType::INT32); + + nd4j::ops::random_multinomial op; + NDArray probExpect('c', { ClassValue }, { 0.058, 0.096, 0.1576, 0.2598, 0.4287 }, nd4j::DataType::DOUBLE); + + // without seed + NDArray probsR('c', { batchValue, ClassValue }, { 1., 1.5, 2., 2.5, 3. }, nd4j::DataType::FLOAT32); + + auto resultR = op.execute({ &probsR, &samples }, { }, { 0 }); + auto outputR = resultR->at(0); + ASSERT_EQ(Status::OK(), resultR->status()); + + NDArray countsR('c', { ClassValue }, { 0, 0, 0, 0, 0 }, nd4j::DataType::DOUBLE); + + for (int i = 0; i < outputR->lengthOf(); i++) { + auto value = outputR->e(i); + ASSERT_TRUE(value >= 0 && value < ClassValue); + double* z = countsR.bufferAsT(); + z[value] += 1; + } + + for (int i = 0; i < countsR.lengthOf(); i++) { + auto c = countsR.e(i); + auto p = probExpect.e(i); + // printf("Get freq : %f Expect freq: %f \n", c / Samples, p); + ASSERT_NEAR((c / Samples), p, 45e-3); // 1000000 35e-3); + } + + auto deviation = outputR->varianceNumber(variance::SummaryStatsStandardDeviation, false); + auto mean = outputR->meanNumber(); + // printf("Var: %f Mean: %f \n", deviation.e(0), mean.e(0)); + ASSERT_NEAR(1.2175, deviation.e(0), 45e-3); // 1000000 35e-3); + ASSERT_NEAR(2.906, mean.e(0), 45e-3); // 1000000 35e-3); + + delete resultR; + + RandomGenerator rng(1234, 1234); + NDArray probs('c', { batchValue, ClassValue }, { 1., 1.5, 2., 2.5, 3. }, nd4j::DataType::FLOAT32); + NDArray output('c', { batchValue, Samples }, nd4j::DataType::INT64); + + ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 0, INT64 }, {}, false)); + + NDArray counts('c', { ClassValue }, { 0, 0, 0, 0, 0 }, nd4j::DataType::DOUBLE); + + for (int i = 0; i < output.lengthOf(); i++) { + auto value = output.e(i); + ASSERT_TRUE(value >= 0 && value < ClassValue); + double* z = counts.bufferAsT(); + z[value] += 1; + } + + for (int i = 0; i < counts.lengthOf(); i++) { + auto c = counts.e(i); + auto p = probExpect.e(i); + // printf("Get freq : %f Expect freq: %f \n", c / Samples, p); + ASSERT_NEAR((c / Samples), p, 4e-3); // 1000000 3e-3); + } + + deviation = output.varianceNumber(variance::SummaryStatsStandardDeviation, false); + mean = output.meanNumber(); + // printf("Var: %f Mean: %f \n", deviation.e(0), mean.e(0)); + ASSERT_NEAR(1.2175, deviation.e(0), 5e-3); // 1000000 3e-3); + ASSERT_NEAR(2.906, mean.e(0), 5e-3); // 1000000 3e-3); +} diff --git a/libnd4j/tests_cpu/layers_tests/ResultSetTests.cpp b/libnd4j/tests_cpu/layers_tests/ResultSetTests.cpp index 3cdca2db6..404b95013 100644 --- a/libnd4j/tests_cpu/layers_tests/ResultSetTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/ResultSetTests.cpp @@ -35,17 +35,15 @@ TEST_F(ResultSetTests, basic_test_1) { auto x = NDArrayFactory::create('c', {3, 5}); auto tensors = x.allTensorsAlongDimension({1}); - ASSERT_EQ(3, tensors->size()); + ASSERT_EQ(3, tensors.size()); - ResultSet set = *tensors; - ASSERT_EQ(3, tensors->size()); + ResultSet set = tensors; + ASSERT_EQ(3, tensors.size()); ASSERT_EQ(3, set.size()); for (int e = 0; e < set.size(); e++) ASSERT_EQ(5, set.at(e)->lengthOf()); - for (int e = 0; e < tensors->size(); e++) - ASSERT_EQ(5, tensors->at(e)->lengthOf()); - - delete tensors; + for (int e = 0; e < tensors.size(); e++) + ASSERT_EQ(5, tensors.at(e)->lengthOf()); } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/SessionLocalTests.cpp b/libnd4j/tests_cpu/layers_tests/SessionLocalTests.cpp index 3762d790c..41f8ed2d0 100644 --- a/libnd4j/tests_cpu/layers_tests/SessionLocalTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/SessionLocalTests.cpp @@ -74,7 +74,7 @@ TEST_F(SessionLocalTests, BasicTests_2) { auto varSpace = storage.localVariableSpace(); auto arr = varSpace->getVariable(-1)->getNDArray(); - arr->applyScalar(nd4j::scalar::Add, (float) e+1); + arr->applyScalar(nd4j::scalar::Add, (float) e+1, *arr); } float lastValue = 0.0f; diff --git a/libnd4j/tests_cpu/layers_tests/StringTests.cpp b/libnd4j/tests_cpu/layers_tests/StringTests.cpp index 2ae236210..ec7821f21 100644 --- a/libnd4j/tests_cpu/layers_tests/StringTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/StringTests.cpp @@ -81,7 +81,7 @@ TEST_F(StringTests, Basic_dup_1) { ASSERT_EQ(1, array.lengthOf()); ASSERT_EQ(0, array.rankOf()); - auto dup = array.dup(); + auto dup = new NDArray(array.dup()); auto z0 = array.e(0); auto z1 = dup->e(0); @@ -90,4 +90,26 @@ TEST_F(StringTests, Basic_dup_1) { ASSERT_EQ(f, z1); delete dup; +} + +TEST_F(StringTests, byte_length_test_1) { + std::string f("alpha"); + auto array = NDArrayFactory::string(f); + + ASSERT_EQ(f.length(), StringUtils::byteLength(array)); +} + +TEST_F(StringTests, byte_length_test_2) { + auto array = NDArrayFactory::string('c', {2}, {"alpha", "beta"}); + + ASSERT_EQ(9, StringUtils::byteLength(array)); +} + +TEST_F(StringTests, test_split_1) { + auto split = StringUtils::split("alpha beta gamma", " "); + + ASSERT_EQ(3, split.size()); + ASSERT_EQ(std::string("alpha"), split[0]); + ASSERT_EQ(std::string("beta"), split[1]); + ASSERT_EQ(std::string("gamma"), split[2]); } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/TadTests.cpp b/libnd4j/tests_cpu/layers_tests/TadTests.cpp index b4a631a8c..86e7264e8 100644 --- a/libnd4j/tests_cpu/layers_tests/TadTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/TadTests.cpp @@ -106,7 +106,7 @@ TEST_F(TadTests, TestShapeTad_1) { NDArray tadArr(tadBuff, tadShapeInfo); ASSERT_TRUE(numTads==1); - ASSERT_TRUE(input.isSameShapeStrict(&tadArr)); + ASSERT_TRUE(input.isSameShapeStrict(tadArr)); ASSERT_TRUE(input.equalsTo(&tadArr)); delete[] tadShapeInfo; @@ -133,24 +133,16 @@ TEST_F(TadTests, TadEdgeCase_1) { auto tad = array.tensorAlongDimension(0, {0, 1}); ASSERT_TRUE(exp.isSameShape(tad)); - - delete tad; } TEST_F(TadTests, TestEdgeCase_2) { - auto array = NDArrayFactory::create('f', {2, 3, 1}, {1, 4, 2, 5, 3, 6}); - auto tad1 = array.tensorAlongDimension(1, {2}); + auto array = NDArrayFactory::create('f', {2, 3, 1}, {1, 4, 2, 5, 3, 6}); for (int e = 0 ; e < array.lengthOf(); e++) { auto tad = array.tensorAlongDimension(e, {2}); - - ASSERT_NEAR(tad->e(0), array.e(e), 1e-5); - - delete tad; + ASSERT_NEAR(tad.e(0), array.e(e), 1e-5); } - - delete tad1; } TEST_F(TadTests, TadEdgeCase_2) { @@ -158,10 +150,7 @@ TEST_F(TadTests, TadEdgeCase_2) { auto tad = array.tensorAlongDimension(0, {1}); - // tad->printShapeInfo("TAD shape"); - ASSERT_EQ(3, tad->lengthOf()); - - delete tad; + ASSERT_EQ(3, tad.lengthOf()); } diff --git a/libnd4j/tests_cpu/layers_tests/ThreadsTests.cpp b/libnd4j/tests_cpu/layers_tests/ThreadsTests.cpp index 1139d6076..fa89fbcaa 100644 --- a/libnd4j/tests_cpu/layers_tests/ThreadsTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/ThreadsTests.cpp @@ -32,7 +32,9 @@ using namespace nd4j::graph; class ThreadsTests : public testing::Test { public: - + ThreadsTests() { + nd4j_printf("\n",""); + } }; TEST_F(ThreadsTests, th_test_1) { @@ -84,6 +86,18 @@ TEST_F(ThreadsTests, th_test_3) { ASSERT_EQ(1, ThreadsHelper::pickLoop3d(6, 8, 3, 64)); } +TEST_F(ThreadsTests, th_test_5) { + ASSERT_EQ(6, ThreadsHelper::numberOfThreads3d(6, 32, 112, 112)); + + ASSERT_EQ(1, ThreadsHelper::pickLoop3d(6, 32, 112, 112)); + + for (auto e = 0; e < 6; e++) { + auto span = Span3::build(1, e, 6, 0, 32, 1, 0, 112, 1, 0, 112, 1); + + nd4j_printf("Span start: %lld; stop: %lld\n", span.startX(), span.stopX()); + } +} + TEST_F(ThreadsTests, th_test_4) { // typical conv cases ASSERT_EQ(2, ThreadsHelper::numberOfThreads2d(2, 32, 3)); diff --git a/libnd4j/tests_cpu/layers_tests/VariableTests.cpp b/libnd4j/tests_cpu/layers_tests/VariableTests.cpp index fcdd1db3c..1a1915fdc 100644 --- a/libnd4j/tests_cpu/layers_tests/VariableTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/VariableTests.cpp @@ -166,7 +166,6 @@ TEST_F(VariableTests, Test_FlatVariableDataType_3) { ASSERT_TRUE(floating.equalsTo(conv)); delete rv; - delete conv; } /* diff --git a/libnd4j/tests_cpu/libnd4j_tests/CMakeLists.txt b/libnd4j/tests_cpu/libnd4j_tests/CMakeLists.txt index 07cae9ae3..fbba329e3 100644 --- a/libnd4j/tests_cpu/libnd4j_tests/CMakeLists.txt +++ b/libnd4j/tests_cpu/libnd4j_tests/CMakeLists.txt @@ -150,7 +150,7 @@ if ("${EXPERIMENTAL}" STREQUAL "yes") endif() # tests are always compiled with all ops included -SET( CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DLIBND4J_ALL_OPS=true") +SET( CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DLIBND4J_ALL_OPS=true -DDEFAULT_ENGINE=samediff::ENGINE_CPU") if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang") # using Clang diff --git a/libnd4j/tests_cpu/run_tests.sh b/libnd4j/tests_cpu/run_tests.sh index 2932827d4..9b1271df6 100755 --- a/libnd4j/tests_cpu/run_tests.sh +++ b/libnd4j/tests_cpu/run_tests.sh @@ -39,6 +39,7 @@ do done CHIP="${CHIP:-cpu}" +export GTEST_OUTPUT="xml:../target/surefire-reports/TEST-${CHIP}-results.xml" # On Mac, make sure it can find libraries for GCC export DYLD_LIBRARY_PATH=/usr/local/lib/gcc/8/:/usr/local/lib/gcc/7/:/usr/local/lib/gcc/6/:/usr/local/lib/gcc/5/ @@ -47,8 +48,9 @@ export DYLD_LIBRARY_PATH=/usr/local/lib/gcc/8/:/usr/local/lib/gcc/7/:/usr/local/ if [ -n "$BUILD_PATH" ]; then if which cygpath; then BUILD_PATH=$(cygpath -p $BUILD_PATH) + export GTEST_OUTPUT="xml:'..\target\surefire-reports\TEST-${CHIP}-results.xml'" fi export PATH="$PATH:$BUILD_PATH" fi -../blasbuild/${CHIP}/tests_cpu/layers_tests/runtests --gtest_output="xml:../target/surefire-reports/TEST-${CHIP}-results.xml" +../blasbuild/${CHIP}/tests_cpu/layers_tests/runtests diff --git a/libnd4j/windows.md b/libnd4j/windows.md index 884b2c3ee..57b40cb83 100644 --- a/libnd4j/windows.md +++ b/libnd4j/windows.md @@ -274,3 +274,7 @@ To build libnd4j with MKL: Then build libnd4j as before. You may have to be careful about having multiple BLAS implementations on your path. Ideally, have only MKL on the path while building libnd4j. Note: you may be able to get some additional performance on hyperthreaded processors by setting the system/environment variable MKL_DYNAMIC to have the value 'false'. + + +float16_nhcw +float16_nhwc \ No newline at end of file diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/pom.xml b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/pom.xml index b4a374baf..9a42f6bd0 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/pom.xml +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/pom.xml @@ -217,16 +217,6 @@ commons-net ${commons-net.version} - - org.nd4j - nd4j-buffer - ${project.version} - - - org.nd4j - nd4j-context - ${project.version} - net.ericaro neoitertools @@ -238,6 +228,16 @@ + + org.nd4j + nd4j-common + ${project.version} + + + org.bytedeco + javacpp + ${javacpp.version} + diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java index b0fc00bac..e38af27d4 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java @@ -92,20 +92,7 @@ import org.nd4j.linalg.api.ops.impl.reduce.TensorMmul; import org.nd4j.linalg.api.ops.impl.reduce.ZeroFraction; import org.nd4j.linalg.api.ops.impl.reduce.bool.All; import org.nd4j.linalg.api.ops.impl.reduce.bool.Any; -import org.nd4j.linalg.api.ops.impl.reduce.bp.CumProdBp; -import org.nd4j.linalg.api.ops.impl.reduce.bp.CumSumBp; -import org.nd4j.linalg.api.ops.impl.reduce.bp.DotBp; -import org.nd4j.linalg.api.ops.impl.reduce.bp.MaxBp; -import org.nd4j.linalg.api.ops.impl.reduce.bp.MeanBp; -import org.nd4j.linalg.api.ops.impl.reduce.bp.MinBp; -import org.nd4j.linalg.api.ops.impl.reduce.bp.Norm1Bp; -import org.nd4j.linalg.api.ops.impl.reduce.bp.Norm2Bp; -import org.nd4j.linalg.api.ops.impl.reduce.bp.NormMaxBp; -import org.nd4j.linalg.api.ops.impl.reduce.bp.ProdBp; -import org.nd4j.linalg.api.ops.impl.reduce.bp.SquaredNormBp; -import org.nd4j.linalg.api.ops.impl.reduce.bp.StandardDeviationBp; -import org.nd4j.linalg.api.ops.impl.reduce.bp.SumBp; -import org.nd4j.linalg.api.ops.impl.reduce.bp.VarianceBp; +import org.nd4j.linalg.api.ops.impl.reduce.bp.*; import org.nd4j.linalg.api.ops.impl.reduce.custom.BatchMmul; import org.nd4j.linalg.api.ops.impl.reduce.custom.LogSumExp; import org.nd4j.linalg.api.ops.impl.reduce.floating.AMean; @@ -232,10 +219,7 @@ import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentProdBp; import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentSqrtNBp; import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentSumBp; import org.nd4j.linalg.api.ops.impl.transforms.strict.*; -import org.nd4j.linalg.api.ops.random.custom.DistributionUniform; -import org.nd4j.linalg.api.ops.random.custom.RandomBernoulli; -import org.nd4j.linalg.api.ops.random.custom.RandomExponential; -import org.nd4j.linalg.api.ops.random.custom.RandomNormal; +import org.nd4j.linalg.api.ops.random.custom.*; import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution; import org.nd4j.linalg.api.ops.random.impl.BinomialDistribution; import org.nd4j.linalg.api.ops.random.impl.DropOutInverted; @@ -384,6 +368,18 @@ public class DifferentialFunctionFactory { return new TruncatedNormalDistribution(sameDiff(), mean, stdev, shape).outputVariable(); } + public SDVariable randomGamma(SDVariable shape, SDVariable alpha, SDVariable beta, int... seeds) { + return new RandomGamma(sameDiff(), shape, alpha, beta, seeds).outputVariable(); + } + + public SDVariable randomPoisson(SDVariable shape, SDVariable rate, int... seeds) { + return new RandomPoisson(sameDiff(), shape, rate, seeds).outputVariable(); + } + + public SDVariable randomShuffle(SDVariable values, int... seeds) { + return new RandomShuffle(sameDiff(), values, seeds).outputVariable(); + } + /** * Exponential distribution: P(x) = lambda * exp(-lambda * x) * @@ -1411,6 +1407,10 @@ public class DifferentialFunctionFactory { return new PowDerivative(sameDiff(), iX, false, pow).outputVariable(); } + public SDVariable[] powBp(SDVariable x, SDVariable pow, SDVariable gradient) { + return new PowBp(sameDiff(), x, pow, gradient).outputVariables(); + } + public SDVariable mishDerivative(SDVariable iX) { return new MishDerivative(sameDiff(), iX, false).outputVariable(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/At.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/At.java index 5427b4cd7..e05d067c6 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/At.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/At.java @@ -1,6 +1,7 @@ package org.nd4j.autodiff.listeners; import lombok.*; +import org.nd4j.autodiff.samediff.internal.FrameIter; /** * @@ -20,13 +21,14 @@ public class At { private int iteration; private int trainingThreadNum; private long javaThreadNum; + private FrameIter frameIter; private Operation operation; /** * @return A new instance with everything set to 0, and operation set to INFERENCE */ public static At defaultAt(){ - return new At(0, 0, 0, 0, Operation.INFERENCE); + return new At(0, 0, 0, 0, null, Operation.INFERENCE); } /** @@ -34,7 +36,7 @@ public class At { * @return A new instance with everything set to 0, except for the specified operation */ public static At defaultAt(@NonNull Operation op){ - return new At(0, 0, 0, 0, op); + return new At(0, 0, 0, 0, null, op); } /** @@ -76,7 +78,7 @@ public class At { * @return A copy of the current At instance */ public At copy(){ - return new At(epoch, iteration, trainingThreadNum, javaThreadNum, operation); + return new At(epoch, iteration, trainingThreadNum, javaThreadNum, frameIter, operation); } /** @@ -84,6 +86,6 @@ public class At { * @return A copy of the current instance, but with the specified operation */ public At copy(Operation operation){ - return new At(epoch, iteration, trainingThreadNum, javaThreadNum, operation); + return new At(epoch, iteration, trainingThreadNum, javaThreadNum, frameIter, operation); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/debugging/ExecDebuggingListener.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/debugging/ExecDebuggingListener.java index a9862d253..d0501454c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/debugging/ExecDebuggingListener.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/debugging/ExecDebuggingListener.java @@ -1,5 +1,6 @@ package org.nd4j.autodiff.listeners.debugging; +import lombok.val; import org.nd4j.autodiff.functions.DifferentialFunction; import org.nd4j.autodiff.listeners.At; import org.nd4j.autodiff.listeners.BaseListener; @@ -113,16 +114,16 @@ public class ExecDebuggingListener extends BaseListener { if(co.tArgs() != null && co.tArgs().length > 0) { sb.append("\n\ttArgs=").append(Arrays.toString(co.tArgs())); } - INDArray[] inputs = co.inputArguments(); - INDArray[] outputs = co.outputArguments(); + val inputs = co.inputArguments(); + val outputs = co.outputArguments(); if(inputs != null ) { - for (int i = 0; i < inputs.length; i++) { - sb.append("\n\tInput[").append(i).append("]=").append(inputs[i].shapeInfoToString()); + for (int i = 0; i < inputs.size(); i++) { + sb.append("\n\tInput[").append(i).append("]=").append(inputs.get(i).shapeInfoToString()); } } if(outputs != null ) { - for (int i = 0; i < outputs.length; i++) { - sb.append("\n\tOutputs[").append(i).append("]=").append(outputs[i].shapeInfoToString()); + for (int i = 0; i < outputs.size(); i++) { + sb.append("\n\tOutputs[").append(i).append("]=").append(outputs.get(i).shapeInfoToString()); } } } else { @@ -156,22 +157,22 @@ public class ExecDebuggingListener extends BaseListener { if(co.tArgs() != null && co.tArgs().length > 0 ){ sb.append("op.addTArgument(").append(Arrays.toString(co.tArgs()).replaceAll("[\\[\\]]", "")).append(");\n"); } - INDArray[] inputs = co.inputArguments(); - INDArray[] outputs = co.outputArguments(); + val inputs = co.inputArguments(); + val outputs = co.outputArguments(); if(inputs != null ) { - sb.append("INDArray[] inputs = new INDArray[").append(inputs.length).append("];\n"); - for (int i = 0; i < inputs.length; i++) { + sb.append("INDArray[] inputs = new INDArray[").append(inputs.size()).append("];\n"); + for (int i = 0; i < inputs.size(); i++) { sb.append("inputs[").append(i).append("] = "); - sb.append(createString(inputs[i])) + sb.append(createString(inputs.get(i))) .append(";\n"); } sb.append("op.addInputArgument(inputs);\n"); } if(outputs != null ) { - sb.append("INDArray[] outputs = new INDArray[").append(outputs.length).append("];\n"); - for (int i = 0; i < outputs.length; i++) { + sb.append("INDArray[] outputs = new INDArray[").append(outputs.size()).append("];\n"); + for (int i = 0; i < outputs.size(); i++) { sb.append("outputs[").append(i).append("] = "); - sb.append(createString(outputs[i])) + sb.append(createString(outputs.get(i))) .append(";\n"); } sb.append("op.addOutputArgument(outputs);\n"); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/profiler/ProfilingListener.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/profiler/ProfilingListener.java new file mode 100644 index 000000000..9b92b0412 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/profiler/ProfilingListener.java @@ -0,0 +1,354 @@ +/* ****************************************************************************** + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.autodiff.listeners.profiler; + +import lombok.*; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.ArrayUtils; +import org.nd4j.autodiff.listeners.At; +import org.nd4j.autodiff.listeners.BaseListener; +import org.nd4j.autodiff.listeners.Loss; +import org.nd4j.autodiff.listeners.Operation; +import org.nd4j.autodiff.listeners.profiler.data.Phase; +import org.nd4j.autodiff.listeners.profiler.data.TraceEvent; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.samediff.internal.SameDiffOp; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.Op; +import org.nd4j.linalg.dataset.api.MultiDataSet; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.primitives.AtomicBoolean; +import org.nd4j.linalg.util.ArrayUtil; +import org.nd4j.shade.jackson.databind.DeserializationFeature; +import org.nd4j.shade.jackson.databind.MapperFeature; +import org.nd4j.shade.jackson.databind.ObjectMapper; +import org.nd4j.shade.jackson.databind.SerializationFeature; + +import java.io.*; +import java.lang.management.ManagementFactory; +import java.text.DecimalFormat; +import java.util.*; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.LinkedBlockingDeque; + +/** + * SameDiff profiling listener: for profiling operation execution
+ * Writes profiles to a file in JSON format
+ * Format is Chrome profiler format. The output can be read by Google Chrome browser; open Chrome and go to: + * chrome://tracing and load the output JSON format data + *
+ * At present, only operation execution is profiled, not other aspects such as memory allocation and training-related + * functionality.
+ *
+ * Tracing can be configured in a few different ways via the builder, {@link #builder(File)}:
+ * (a) warmup - don't record traces for the first N iterations
+ * (b) "all" mode (default) - record all-iterations, with no limit (after warmup, if applicable)
+ * (c) "n iterations" mode: record at most the first N iterations (after warmup, if applicable)
+ * (d) "n ms" mod: record for at most N milliseconds since the start of the first op execution (after warmup, if applicable)
+ * + * Note: The Chrome Trace Event format can be found here:
+ * https://docs.google.com/document/d/1CvAClvFfyA5R-PhYUmn5OOQtYMH4h6I0nSsKchNAySU/edit + * SameDiff uses the JSON Array Format, as this can be written in an online/streaming manner.
+ * Conversely, TensorFlow uses the JSON Object Format.
+ *
+ * For summarizing, analyzing and comparing the results (SameDiff or TensorFlow format), see {@link org.nd4j.autodiff.listeners.profiler.comparison.ProfileAnalyzer}
+ * + * @author Alex Black + */ +@Getter +@Slf4j +public class ProfilingListener extends BaseListener { + + private final File outputFile; + private final boolean all; + private final int warmup; + private final int nIter; + private final long nMs; + private final Operation[] operations; + + private final long pid; + private final long tid; + private Long firstOpStart = null; //Used for time termination + private int countTotalIter = 0; + private boolean logActive = false; + private long opStartNano; + + private Writer writer; + private ObjectMapper json; + + private final Thread fileWritingThread; + private final BlockingQueue writeQueue; + private final AtomicBoolean writing = new AtomicBoolean(false); + + protected ProfilingListener(@NonNull File outputFile, boolean all, int warmup, int nIter, long nMs, Operation[] operations) { + Preconditions.checkArgument(!outputFile.exists(), "Output file already exists: %s", outputFile); + this.outputFile = outputFile; + this.all = all; + this.warmup = warmup; + this.nIter = nIter; + this.nMs = nMs; + this.operations = operations; + + this.pid = getProcessId(); + this.tid = Thread.currentThread().getId(); + + try { + this.writer = new BufferedWriter(new FileWriter(outputFile, false)); + this.writer.write("["); //JSON array open (array close is optional for Chrome profiler format) + } catch (IOException e) { + throw new RuntimeException(e); + } + + this.json = jsonMapper(); + + //Set up a queue so file access doesn't add latency to the execution thread + writeQueue = new LinkedBlockingDeque<>(); + fileWritingThread = new Thread(new Runnable() { + @Override + public void run() { + try { + runHelper(); + } catch (Throwable t) { + log.error("Error when attempting to write results to file", t); + } + } + + public void runHelper() throws Exception { + while (true) { + TraceEvent te = writeQueue.take(); //Blocking + writing.set(true); + try { + String j = json.writeValueAsString(te); + writer.append(j); + writer.append(",\n"); + } catch (IOException e) { + throw new RuntimeException(e); + } finally { + writing.set(false); + } + } + } + }); + fileWritingThread.setDaemon(true); + fileWritingThread.start(); + } + + @Override + public boolean isActive(Operation operation) { + return operations == null || ArrayUtils.contains(operations, operation); + } + + @Override + public void operationStart(SameDiff sd, Operation op) { + this.logActive = operations == null || ArrayUtils.contains(operations, op); + } + + @Override + public void operationEnd(SameDiff sd, Operation op) { + if (this.logActive) { + while ((!writeQueue.isEmpty() || writing.get()) && fileWritingThread.isAlive()) { + //Wait for file writing thread to catch up + try { + Thread.sleep(100); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + try { + writer.flush(); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + this.logActive = false; + if (op == Operation.INFERENCE) { + //Increment for inference; iteration done is called only for TRAINING + countTotalIter++; + } + } + + @Override + public void iterationDone(SameDiff sd, At at, MultiDataSet dataSet, Loss loss) { + //Increment for training + if (logActive) { + countTotalIter++; + } + } + + @Override + public void preOpExecution(SameDiff sd, At at, SameDiffOp op) { + if (logActive) { + opStartNano = System.nanoTime(); + + if(!all && nMs > 0 && firstOpStart == null) + firstOpStart = opStartNano; + } + } + + @Override + public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, INDArray[] outputs) { + if (logActive) { + long now = System.nanoTime(); + + if (warmup > 0 && countTotalIter < warmup) { + return; //Skip due to warmup phase + } + + //Iteration termination + int terminationPt = this.nIter > 0 ? this.nIter : Integer.MAX_VALUE; + if (warmup > 0 && this.nIter > 0) + terminationPt += this.warmup; + + if (countTotalIter > terminationPt) { + logActive = false; + return; //Skip due to max number of itertions + } + + //Time termination + if(!all && nMs > 0 && (now - firstOpStart)/1000 > nMs) { + logActive = false; + return; + } + + TraceEvent event = TraceEvent.builder() + .name(op.getOp().opName()) + .categories(Collections.singletonList("Op")) + .ts(opStartNano / 1000) + .dur((now - opStartNano) / 1000) + .pid((int)pid) + .tid(tid) + .ph(Phase.X) + .args(Collections.singletonMap("name", op.getName())) + .build(); + + writeQueue.add(event); + } + } + + + private long getProcessId() { + // Note: may fail in some JVM implementations + // therefore fallback has to be provided + + // something like '@', at least in SUN / Oracle JVMs + final String jvmName = ManagementFactory.getRuntimeMXBean().getName(); + final int index = jvmName.indexOf('@'); + + if (index < 1) { + // part before '@' empty (index = 0) / '@' not found (index = -1) + return 0; + } + + try { + return Long.parseLong(jvmName.substring(0, index)); + } catch (NumberFormatException e) { + // ignore + } + return 0; + } + + /** + * Get a new JSON mapper for use in serializing/deserializing JSON format + */ + public static ObjectMapper jsonMapper() { + ObjectMapper json = new ObjectMapper(); + json.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false); + json.configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false); + json.configure(MapperFeature.SORT_PROPERTIES_ALPHABETICALLY, false); + json.disable(SerializationFeature.INDENT_OUTPUT); //One line + + return json; + } + + /** + * Create a new builder + * @param outputFile Output file. Will be overwritten if file already exists + */ + public static Builder builder(File outputFile) { + return new Builder(outputFile); + } + + public static class Builder { + private final File outputFile; + private boolean all = true; + private int warmup = 0; + private int nIter = -1; + private long nMs = -1; + private Operation[] operations; + + public Builder(@NonNull File outputFile) { + this.outputFile = outputFile; + } + + /** + * If called, all data will be profiled with no limits (other than a warmup, if set) + */ + public Builder recordAll() { + this.all = true; + this.nIter = -1; + this.nMs = -1; + return this; + } + + /** + * Specify the number of warmup iterations - i.e., these will be excluded from profiling results + */ + public Builder warmup(int iterations) { + this.warmup = iterations; + return this; + } + + /** + * Set a limit on the maximum number of iterations to profile (after warmup, if any). + * Any ops executed after the specified number of iterations will not be profiled/recorded + */ + public Builder maxProfileIterations(int iterations) { + this.nIter = iterations; + this.all = false; + return this; + } + + /** + * Set a limit on the maximum duration for profiling, in milliseconds. + * Any ops executed after the specified amount of time since the first (non-warmup) operation start will not be + * profiled/recorded + */ + public Builder maxProfilerMilliseconds(long ms) { + this.nMs = ms; + this.all = false; + return this; + } + + /** + * Specify the operations (training, inference, etc) to profile. + * If not set, all operations are profiled + */ + public Builder operations(Operation... operations) { + this.operations = operations; + return this; + } + + /** + * Create the profiling listener + */ + public ProfilingListener build() { + return new ProfilingListener(outputFile, all, warmup, nIter, nMs, operations); + } + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/profiler/comparison/Config.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/profiler/comparison/Config.java new file mode 100644 index 000000000..602a26f99 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/profiler/comparison/Config.java @@ -0,0 +1,42 @@ +/* ****************************************************************************** + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.autodiff.listeners.profiler.comparison; + +import lombok.Builder; +import lombok.Data; +import lombok.experimental.Accessors; +import org.nd4j.linalg.function.BiFunction; + +import java.io.File; + +@Data +@Accessors(fluent = true) +@Builder +public class Config { + + private String p1Name; + private String p2Name; + private File profile1; + private File profile2; + private boolean profile1IsDir; + private boolean profile2IsDir; + @Builder.Default private ProfileAnalyzer.ProfileFormat profile1Format = ProfileAnalyzer.ProfileFormat.SAMEDIFF; + @Builder.Default private ProfileAnalyzer.ProfileFormat profile2Format = ProfileAnalyzer.ProfileFormat.SAMEDIFF; + @Builder.Default private ProfileAnalyzer.SortBy sortBy = ProfileAnalyzer.SortBy.PROFILE1_PC; + private BiFunction filter; //Return true to keep, false to remove + @Builder.Default private ProfileAnalyzer.OutputFormat format = ProfileAnalyzer.OutputFormat.TEXT; + +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/profiler/comparison/OpStats.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/profiler/comparison/OpStats.java new file mode 100644 index 000000000..0949020af --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/profiler/comparison/OpStats.java @@ -0,0 +1,32 @@ +/* ****************************************************************************** + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.autodiff.listeners.profiler.comparison; + +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NoArgsConstructor; +import org.nd4j.list.NDArrayList; + +@AllArgsConstructor +@NoArgsConstructor +@Data +public class OpStats { + private String opInstanceName; + private String opName; + private int count; + private NDArrayList timesUs; + private Long sumUs; +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/profiler/comparison/ProfileAnalyzer.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/profiler/comparison/ProfileAnalyzer.java new file mode 100644 index 000000000..421c13cb0 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/profiler/comparison/ProfileAnalyzer.java @@ -0,0 +1,571 @@ +/* ****************************************************************************** + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.autodiff.listeners.profiler.comparison; + +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NoArgsConstructor; +import lombok.NonNull; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.io.FileUtils; +import org.nd4j.autodiff.functions.DifferentialFunction; +import org.nd4j.autodiff.listeners.profiler.ProfilingListener; +import org.nd4j.autodiff.listeners.profiler.data.Phase; +import org.nd4j.autodiff.listeners.profiler.data.TraceEvent; +import org.nd4j.autodiff.listeners.profiler.data.TraceEvents; +import org.nd4j.base.Preconditions; +import org.nd4j.imports.converters.DifferentialFunctionClassHolder; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.primitives.Pair; +import org.nd4j.list.NDArrayList; +import org.nd4j.shade.jackson.databind.ObjectMapper; + +import java.io.File; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.*; + +/** + * A profile analyzer, used for analyzing Chrome-format profiler dumps generated by both SameDiff's
+ * {@link ProfilingListener} and TensorFlow's profiler.
+ * Has methods for summarizing profiler statistics, as well as comparing two profiler dumps.
+ *
+ * Also supports analyzing/aggregating multiple JSON files in a directory, via the "...Directory(...)" methods. + *

+ * See {@link ProfilingListener}
+ * See {@link TraceEvent} + * + * @author Alex Black + */ +@Slf4j +public class ProfileAnalyzer { + + /** + * Chrome profiler supports 2 formats:
+ * SameDiff == JSON Array Format
+ * TensorFlow == JSON Object Format
+ */ + public enum ProfileFormat {SAMEDIFF, TENSORFLOW} + + /** + * Only applicable for profile comparisons.
+ * PROFILE1_PC - sort by profile 1 percentage of total time
+ * PROFILE2_PC - sort by profile 2 percentage of total time
+ * RATIO - sort by highest ratio (mean op time profile 1 / mean op time profile 2) + */ + public enum SortBy {PROFILE1_PC, PROFILE2_PC, RATIO} + + /** + * TEXT: Human readable, columns padded for alignment
+ * CSV: CSV format, comma separated + */ + public enum OutputFormat {TEXT,CSV} + + + /** + * Summarize and print to stdout the specified profile file + * + * @param file Profile file + * @param profileFormat Format of the profiler file + */ + public static void summarizeProfile(File file, ProfileFormat profileFormat) { + System.out.println(summarizeProfileStr(file, profileFormat)); + } + + /** + * Summarize and return as a string the specified profile file + * + * @param file Profile file + * @param profileFormat Format of the profiler file + */ + public static String summarizeProfileStr(File file, ProfileFormat profileFormat) { + TraceEvent[] events = getTraceEvents(file, profileFormat); + return summarizeTraceEvents(events); + } + + /** + * Aggregate, summarize and print to stdout all .json profile files in the specified directory (not recursive) + * + * @param dir Directory containing the profiles + * @param profileFormat Profile format + */ + public static void summarizeProfileDirectory(File dir, ProfileFormat profileFormat) { + System.out.println(summarizeProfileDirectoryStr(dir, profileFormat)); + } + + /** + * Aggregate, summarize and return as a String all .json profile files in the specified directory (not recursive) + * + * @param dir Directory containing the profiles + * @param profileFormat Profile format + */ + public static String summarizeProfileDirectoryStr(File dir, ProfileFormat profileFormat) { + return summarizeTraceEvents(getTraceEventsDir(dir, profileFormat)); + } + + /** + * Load, aggregate and return the TraceEvent object from all profiles in the specified directory + * + * @param dir Directory containing the profiles + * @param profileFormat Profile format + */ + public static TraceEvent[] getTraceEventsDir(File dir, ProfileFormat profileFormat) { + File[] files = dir.listFiles(); + Preconditions.checkState(files != null && files.length > 0, "No profiles found in directory: %s", dir); + List l = new ArrayList<>(); + for (File f : files) { + if (!f.getName().endsWith(".json")) { + log.info("Skipping non-JSON file in directory - {}", f.getAbsolutePath()); + continue; + } + TraceEvent[] e = getTraceEvents(f, profileFormat); + Collections.addAll(l, e); + } + return l.toArray(new TraceEvent[0]); + } + + /** + * Load and return the TraceEvent object from the specified profile file + * + * @param file Profile file + * @param profileFormat Profile format + */ + public static TraceEvent[] getTraceEvents(File file, ProfileFormat profileFormat) { + return getTraceEvents(file, profileFormat, true); + } + + public static TraceEvent[] getTraceEvents(File file, ProfileFormat profileFormat, boolean aggregateTFSubOps) { + ObjectMapper json = ProfilingListener.jsonMapper(); + + String content; + try { + content = FileUtils.readFileToString(file, StandardCharsets.UTF_8); + } catch (IOException e) { + throw new RuntimeException(e); + } + + if (!content.matches(".*]\\s*")) { + if (content.endsWith(",")) { + //Has comma, missing ] + content = content.substring(0, content.length() - 1) + "]"; + } else if (content.endsWith(",\n")) { + //Has comma and newline, missing ] + content = content.substring(0, content.length() - 2) + "]"; + } else { + content = content + "]"; + } + } + + TraceEvent[] events; + if (profileFormat == ProfileFormat.SAMEDIFF) { + try { + events = json.readValue(content, TraceEvent[].class); + } catch (IOException e) { + throw new RuntimeException(e); + } + } else { + //TF format + TraceEvents traceEvents; + try { + traceEvents = json.readValue(content, TraceEvents.class); + } catch (IOException e) { + throw new RuntimeException(e); + } + events = traceEvents.getTraceEvents().toArray(new TraceEvent[0]); + + //Clean up TF format - sometimes things like "Softmax" are actually profiled as "_MklSoftmax" + //And we'll align TF names to SameDiff names + for (TraceEvent te : events) { + if (TF_PROFILE_ALIASES.containsKey(te.getName())) { + te.setName(TF_PROFILE_ALIASES.get(te.getName())); + } + + DifferentialFunction df = DifferentialFunctionClassHolder.getInstance().getOpWithTensorflowName(te.getName()); + if (df != null) { + te.setName(df.opName()); + } + } + + + if(aggregateTFSubOps){ + //For CUDA ops, TF will log sub-ops like: + //fire2/e1x1/Conv2D:Conv2D#id=74,device=/job:localhost/replica:0/task:0/device:GPU:0,async=false#@@cudnn::maxwell::gemm::computeOffsetsKernel(cudnn::maxwell::gemm::ComputeOffsetsParams) + //fire2/e1x1/Conv2D:Conv2D#id=74,device=/job:localhost/replica:0/task:0/device:GPU:0,async=false#@@maxwell_scudnn_128x64_relu_interior_nn + //fire2/e1x1/Conv2D:Conv2D#id=74,device=/job:localhost/replica:0/task:0/device:GPU:0,async=false#@@void tensorflow::functor::ShuffleInTensor3Simple(int, float const*, tensorflow::functor::Dimension<3>, float*) + //We'll join these into one op, then strip everything after the ":" to recover the op name + + //Also, TF has multiple sub-ops like this, sequentially, that need to be joined: + //19 = {TraceEvent@3157} "TraceEvent(name=Conv2D#id=80,device=/job, categories=null, ph=X, ts=1576896601259742, dur=466, tts=null, pid=5, tid=0, args={name=conv1/Conv2D, op=Conv2D#id=80,device=/job}, cname=null)" + //20 = {TraceEvent@3181} "TraceEvent(name=Conv2D#id=80,device=/job, categories=null, ph=X, ts=1576896601260229, dur=29, tts=null, pid=5, tid=0, args={name=conv1/Conv2D, op=Conv2D#id=80,device=/job}, cname=null)" + //21 = {TraceEvent@3206} "TraceEvent(name=Conv2D#id=80,device=/job, categories=null, ph=X, ts=1576896601260329, dur=31, tts=null, pid=5, tid=0, args={name=conv1/Conv2D, op=Conv2D#id=80,device=/job}, cname=null)" + //22 = {TraceEvent@3247} "TraceEvent(name=Conv2D#id=80,device=/job, categories=null, ph=X, ts=1576896601260390, dur=4998, tts=null, pid=5, tid=0, args={name=conv1/Conv2D, op=Conv2D#id=80,device=/job}, cname=null)" + + Map map = new HashMap<>(); //Key: Op name with ID + List out = new ArrayList<>(); + TraceEvent last = null; + for(TraceEvent te : events){ + if(last != null && last.getPh() == Phase.X && te.getPh() == Phase.X && + last.getName().equals(te.getName()) && + last.getArgs() != null && te.getArgs() != null && + last.getArgs().get("name").equals(te.getArgs().get("name")) && + last.getArgs().get("op").equals(te.getArgs().get("op"))){ + //Aggregate - same names, ops, etc + last.setDur(last.getDur() + te.getDur()); + continue; + } + + last = te; + if(te.getArgs() == null || te.getArgs().isEmpty()){ + out.add(te); + continue; + } + + + String n = (String) te.getArgs().get("name"); + + //Aggregate by op name... + //"fire2/e1x1/Conv2D:Conv2D#id=74,device=/job:localhost/replica:0/..." -> "fire2/e1x1/Conv2D" + //We're relying on TF's "one iteration per json file" here + if(n.matches("[\\w/_-]+:[\\w/_-]+#id=\\d+.*")) { + int idx = n.indexOf("#"); + String sub1 = n.substring(0, idx); + String sub; + if (sub1.contains(":")) { + sub = sub1.substring(0, sub1.lastIndexOf(":")); + } else { + sub = sub1; + } + if (map.containsKey(sub)) { + TraceEvent t = map.get(sub); + Long dur = t.getDur(); + if (dur == null && te.getDur() == null) + continue; + t.setDur(dur == null ? te.getDur() : dur + (te.getDur() == null ? 0 : te.getDur())); + } else { + map.put(sub, te); + out.add(te); + } + } else { + if(map.containsKey(n)){ + TraceEvent t = map.get(n); + t.setDur(t.getDur() + te.getDur()); + } else { + map.put(n, te); + out.add(te); + } + } + } + + //Strip everything after ":" in "fire2/e1x1/Conv2D:Conv2D#id=74,device=/job:localhost/..." + for( int i=0; i> p = aggregateTraceEvents(events); + final Map stats = p.getSecond(); + long allOpsUs = p.getFirst(); + + //Summarize by op type: + List l = new ArrayList<>(stats.keySet()); + Collections.sort(l, new Comparator() { + @Override + public int compare(String o1, String o2) { + return -Long.compare(stats.get(o1).getSumUs(), stats.get(o2).getSumUs()); + } + }); + + //Work out longest name and op name: + int longestName = 30; + int longestOpName = 30; + for (String s : l) { + longestName = Math.max(longestName, s.length() + 1); + longestOpName = Math.max(longestOpName, stats.get(s).getOpName().length() + 1); + } + + StringBuilder sb = new StringBuilder(); + String headerFormat = "%-" + longestName + "s%-" + longestOpName + "s%-10s%-10s%-10s%-10s%-10s%-10s\n"; + sb.append(String.format(headerFormat, "Op Name", "Op", "Count", "Total uS", "%", "Min", "Max", "Std")); + String format = "%-" + longestName + "s%-" + longestOpName + "s%-10d%-10d%-10.2f%-10d%-10d%-10.2f\n"; + for (String s : l) { + OpStats st = stats.get(s); + double pc = (100.0 * st.getSumUs()) / allOpsUs; + INDArray arr = st.getTimesUs().array(); + long min = arr.minNumber().longValue(); + long max = arr.maxNumber().longValue(); + double std = arr.stdNumber().doubleValue(); + sb.append(String.format(format, s, st.getOpName(), st.getCount(), st.getSumUs(), pc, min, max, std)); + } + + return sb.toString(); + } + + private static Pair> aggregateTraceEvents(TraceEvent[] events) { + //Summarize by op (instance) name: + final Map stats = new HashMap<>(); + for (TraceEvent e : events) { + if (e.getPh() != Phase.X || e.getDur() == null) { + continue; + } + + OpStats s; + String instanceName = (String) e.getArgs().get("name"); + if (stats.containsKey(instanceName)) { + s = stats.get(instanceName); + } else { + s = new OpStats(instanceName, e.getName(), 0, new NDArrayList(DataType.LONG, 0), null); + stats.put(instanceName, s); + } + s.setCount(s.getCount() + 1); + s.getTimesUs().add((double) e.getDur()); + } + + long allOpsUs = 0; + for (OpStats s : stats.values()) { + s.setSumUs( s.getTimesUs().array().sumNumber().longValue()); + allOpsUs += s.getSumUs(); + } + + return new Pair<>(allOpsUs, stats); + } + /** + * Compare the specified profile files, sorted by profile 1 % of total time + * + * @param file1 First profile file + * @param file2 Second profile file + * @param format1 Format of first profile + * @param format2 Format of second profile + * @return Comparison summary as a String + */ + public static String compareProfiles(@NonNull File file1, @NonNull File file2, @NonNull ProfileFormat format1, @NonNull ProfileFormat format2) { + return compareProfiles(file1, file2, format1, format2, false, false, null, null, SortBy.PROFILE1_PC); + } + + /** + * Compare the specified profile files or directory + * + * @param file1 First profile file or directory of profiles + * @param file2 Second profile file or directory of profiles + * @param format1 Format for first profile file/s + * @param format2 Format for second profile file/s + * @param firstIsDir True if the first File object is a directory + * @param secondIsDir True if the second File object is a directory + * @param name1 Name of the first profile (just for display purposes). Optional + * @param name2 Name of the second profile (just for display purposes). Optional + * @param sortBy What to sort the summary results by + * @return Comparison summary as a String + */ + public static String compareProfiles(@NonNull File file1, @NonNull File file2, @NonNull ProfileFormat format1, @NonNull ProfileFormat format2, + boolean firstIsDir, boolean secondIsDir, String name1, String name2, final SortBy sortBy) { + return compareProfiles(Config.builder() + .profile1(file1) + .profile2(file2) + .profile1Format(format1) + .profile2Format(format2) + .profile1IsDir(firstIsDir) + .profile2IsDir(secondIsDir) + .p1Name(name1) + .p2Name(name2) + .sortBy(sortBy) + .build()); + } + + public static String compareProfiles(final Config c){ + TraceEvent[] t1 = c.profile1IsDir() ? getTraceEventsDir(c.profile1(), c.profile1Format()) : getTraceEvents(c.profile1(), c.profile1Format()); + TraceEvent[] t2 = c.profile2IsDir() ? getTraceEventsDir(c.profile2(), c.profile2Format()) : getTraceEvents(c.profile2(), c.profile2Format()); + + final Pair> p1 = aggregateTraceEvents(t1); + final Pair> p2 = aggregateTraceEvents(t2); + + List l = new ArrayList<>(c.sortBy() != SortBy.PROFILE2_PC ? p1.getSecond().keySet() : p2.getSecond().keySet()); + Collections.sort(l, new Comparator() { + @Override + public int compare(String o1, String o2) { + switch (c.sortBy()) { + case PROFILE1_PC: + return -Long.compare(p1.getSecond().get(o1).getSumUs(), p1.getSecond().get(o2).getSumUs()); + case PROFILE2_PC: + return -Long.compare(p2.getSecond().get(o1).getSumUs(), p2.getSecond().get(o2).getSumUs()); + case RATIO: + double m1a = meanTime(p1, o1); + double m1b = meanTime(p1, o2); + double m2a = meanTime(p2, o1); + double m2b = meanTime(p2, o2); + double ratio1 = m1a / m2a; + double ratio2 = m1b / m2b; + return -Double.compare(ratio1, ratio2); + default: + throw new RuntimeException(); + } + } + }); + + Set set = new HashSet<>(l); + + + StringBuilder sb = new StringBuilder(); + sb.append("1 = ").append(c.p1Name() == null ? "Profile 1" : c.p1Name()).append("\n") + .append("2 = ").append(c.p2Name() == null ? "Profile 2" : c.p2Name()).append("\n"); + + //Work out longest name and op name: + int longestName = 30; + int longestOpName = 30; + Map stats = c.sortBy() == SortBy.PROFILE2_PC ? p2.getSecond() : p1.getSecond(); + for (String s : l) { + longestName = Math.max(longestName, s.length() + 1); + longestOpName = Math.max(longestOpName, stats.get(s).getOpName().length() + 1); + } + + String headerFormat; + String format; + if(c.format() == null || c.format() == OutputFormat.TEXT){ + headerFormat = "%-" + longestName + "s%-" + longestOpName + "s%-10s%-10s%-16s%-13s%-13s%-14s%-14s%-12s%-12s%-14s%-14s%-10s%-10s%-10s%-10s\n"; + format = "%-" + longestName + "s%-" + longestOpName + "s%-10d%-10d%-16.2f%-13.2f%-13.2f%-14d%-14d%-12.2f%-12.2f%-14d%-14d%-10d%-10d%-10.2f%-10.2f\n"; + } else { + headerFormat = "%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s\n"; + format = "%s,%s,%d,%d,%.2f,%.2f,%.2f,%d,%d,%.2f,%.2f,%d,%d,%d,%d,%.2f,%.2f\n"; + } + sb.append(String.format(headerFormat, "Op Name", "Op", "Count (1)", "Count (2)", "Mean Ratio 1/2", "Mean (1)", "Mean (2)", "Total uS (1)", "Total uS (2)", "% (1)", "% (2)", "Min (1)", "Min (2)", "Max (1)", "Max (2)", "Std (1)", "Std (2)")); + + + for (String s : l) { + OpStats s1 = p1.getSecond().get(s); + OpStats s2 = p2.getSecond().get(s); + + if(c.filter() != null && !c.filter().apply(s1, s2)) + continue; + + double m1 = s1 == null ? 0 : s1.getTimesUs().array().meanNumber().doubleValue(); + double m2 = s2 == null ? 0 : s2.getTimesUs().array().meanNumber().doubleValue(); + double ratio = m1 / m2; + + double pc1 = s1 == null ? 0 : 100.0 * s1.getSumUs() / p1.getFirst(); + double pc2 = s2 == null ? 0 : 100.0 * s2.getSumUs() / p2.getFirst(); + + sb.append(String.format(format, s, s1 != null ? s1.getOpName() : s2.getOpName(), + s1 != null ? s1.getCount() : 0, + s2 != null ? s2.getCount() : 0, + //Ratio of means, means + ratio, + m1, m2, + //Total us, percent of op total + s1 != null ? s1.getSumUs() : 0, + s2 != null ? s2.getSumUs() : 0, + pc1, pc2, + //Min, max, std + s1 != null ? s1.getTimesUs().array().minNumber().longValue() : 0, + s2 != null ? s2.getTimesUs().array().minNumber().longValue() : 0, + s1 != null ? s1.getTimesUs().array().maxNumber().longValue() : 0, + s2 != null ? s2.getTimesUs().array().maxNumber().longValue() : 0, + s1 != null ? s1.getTimesUs().array().stdNumber().doubleValue() : 0.0, + s2 != null ? s2.getTimesUs().array().stdNumber().doubleValue() : 0.0)); + } + + boolean header = false; + String headerFormat2 = null; + String format3 = null; + List toAppend = null; + for (String s : (c.sortBy() == SortBy.PROFILE2_PC ? p1.getSecond().keySet() : p2.getSecond().keySet())) { + + if (!set.contains(s)) { + Map m = c.sortBy() == SortBy.PROFILE2_PC ? p1.getSecond() : p2.getSecond(); + OpStats st = m.get(s); + if(c.filter() != null){ + OpStats other = c.sortBy() == SortBy.PROFILE2_PC ? p1.getSecond().get(s) : p2.getSecond().get(s); + boolean keep = c.filter().apply(other, st); + if(!keep) + continue; + } + + if (!header) { + toAppend = new ArrayList<>(); + + longestName = 30; + longestOpName = 30; + for(String s2 : m.keySet()){ + longestName = Math.max(longestName, s2.length()+1); + longestOpName = Math.max(longestOpName, m.get(s2).getOpName().length()+1); + } + if(c.format() == null || c.format() == OutputFormat.TEXT) { + headerFormat2 = "%-" + longestName + "s%-" + longestOpName + "s%-10s%-10s%-10s%-10s%-10s%-10s\n"; + format3 = "%-" + longestName + "s%-" + longestOpName + "s%-10d%-10d%-10.2f%-10d%-10d%-10.2f\n"; + } else { + headerFormat2 = "%s,%s,%s,%s,%s,%s,%s,%s\n"; + format3 = "%s,%s,%d,%d,%.2f,%d,%d,%.2f\n"; + } + + sb.append(" *** Operations not in profile ").append(c.sortBy() == SortBy.PROFILE2_PC ? "1" : "2").append(" but in profile ") + .append(c.sortBy() == SortBy.PROFILE2_PC ? "2" : "1").append(" ***\n"); + sb.append(String.format(headerFormat2, "Op Name", "Op", "Count", "Total uS", "%", "Min", "Max", "Std")); + header = true; + } + long allOpsUs = c.sortBy() == SortBy.PROFILE2_PC ? p1.getFirst() : p2.getFirst(); + double pc = (100.0 * st.getTimesUs().array().sumNumber().longValue()) / allOpsUs; + INDArray arr = st.getTimesUs().array(); + long min = arr.minNumber().longValue(); + long max = arr.maxNumber().longValue(); + double std = arr.stdNumber().doubleValue(); + toAppend.add(String.format(format3, s, st.getOpName(), st.getCount(), st.getSumUs(), pc, min, max, std)); + } + } + if(toAppend != null){ + Collections.sort(toAppend); + for(String s : toAppend){ + sb.append(s); + } + } + + return sb.toString(); + } + + private static double meanTime(Pair> p, String name) { + if (!p.getSecond().containsKey(name)) { + return 0.0; + } + return p.getSecond().get(name).getTimesUs().array().meanNumber().doubleValue(); + } + + + private static Map TF_PROFILE_ALIASES = new HashMap<>(); + + static { + TF_PROFILE_ALIASES.put("_MklSoftmax", "Softmax"); + } + +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/profiler/data/ColorName.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/profiler/data/ColorName.java new file mode 100644 index 000000000..0d9c08deb --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/profiler/data/ColorName.java @@ -0,0 +1,19 @@ +/* ****************************************************************************** + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.autodiff.listeners.profiler.data; + +public enum ColorName { +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/profiler/data/Phase.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/profiler/data/Phase.java new file mode 100644 index 000000000..bca7feb39 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/profiler/data/Phase.java @@ -0,0 +1,44 @@ +/* ****************************************************************************** + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.autodiff.listeners.profiler.data; + +/** + * Chrome Profiler phase, for details see: + * + * https://docs.google.com/document/d/1CvAClvFfyA5R-PhYUmn5OOQtYMH4h6I0nSsKchNAySU/edit + */ +public enum Phase { + B, + E, + X, + I, + C, + b, + n, + e, + s, + t, + f, + P, + N, + O, + D, + M, + V, + v, + R, + c +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/profiler/data/TraceEvent.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/profiler/data/TraceEvent.java new file mode 100644 index 000000000..e4270edd1 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/profiler/data/TraceEvent.java @@ -0,0 +1,53 @@ +/* ****************************************************************************** + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.autodiff.listeners.profiler.data; + +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; + +import java.util.List; +import java.util.Map; + +/** + * A TraceEvent, such as an operation execution.
+ * Intended mainly for JSON serialization/deserialization in Chrome profiler format
+ * Profiler format is described here: + * https://docs.google.com/document/d/1CvAClvFfyA5R-PhYUmn5OOQtYMH4h6I0nSsKchNAySU/edit + * See {@link org.nd4j.autodiff.listeners.profiler.ProfilingListener}
+ * See {@link org.nd4j.autodiff.listeners.profiler.comparison.ProfileAnalyzer} + * + * @author Alex Black + */ +@Builder +@Data +@AllArgsConstructor +@NoArgsConstructor +public class TraceEvent { + + private String name; //Name of event (usually op name) + private List categories; //Comma separated list of categories + private Phase ph; //Event type - phase (see table for options) + private long ts; //Timestamp, in microseconds (us) + private Long dur; //Duration, optional + private Long tts; //Optional, thlread timestamp, in microseconds + private long pid; //Process ID + private long tid; //Thread ID + private Map args; //Args + private ColorName cname; //Optional, color name (must be one of reserved color names: https://github.com/catapult-project/catapult/blob/master/tracing/tracing/base/color_scheme.html ) + +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/profiler/data/TraceEvents.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/profiler/data/TraceEvents.java new file mode 100644 index 000000000..b3ebf6d8a --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/profiler/data/TraceEvents.java @@ -0,0 +1,34 @@ +/* ****************************************************************************** + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.autodiff.listeners.profiler.data; + +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NoArgsConstructor; + +import java.util.List; + +/** + * A simple holder for a list of trace events + * + * @author Alex Black + */ +@AllArgsConstructor +@NoArgsConstructor +@Data +public class TraceEvents { + private List traceEvents; +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java index 449c2ef78..ceccdae65 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java @@ -1804,7 +1804,7 @@ public class SameDiff extends SDBaseOps { if (validationData != null && (validationFrequency <= 0 || i % validationFrequency == 0)) { long validationStart = System.currentTimeMillis(); - outputHelper(validationData, new At(at.epoch(), 0, 0, 0, Operation.TRAINING_VALIDATION), + outputHelper(validationData, new At(at.epoch(), 0, 0, 0, null, Operation.TRAINING_VALIDATION), listenersWitHistory); long validationTime = System.currentTimeMillis() - validationStart; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/AbstractSession.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/AbstractSession.java index 1f93dbe94..c95f26b1f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/AbstractSession.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/AbstractSession.java @@ -921,32 +921,6 @@ public abstract class AbstractSession { } } - /** - * FrameIter: Identifies a frame + iteration (but not a specific op or variable).
- * Note that frames can be nested - which generally represents nested loop situations. - */ - @Data - @AllArgsConstructor - public static class FrameIter { - private String frame; - private int iteration; - private FrameIter parentFrame; - - @Override - public String toString() { - return "(\"" + frame + "\"," + iteration + (parentFrame == null ? "" : ",parent=" + parentFrame.toString()) + ")"; - } - - @Override - public FrameIter clone() { - return new FrameIter(frame, iteration, (parentFrame == null ? null : parentFrame.clone())); - } - - public VarId toVarId(String name) { - return new VarId(name, frame, iteration, parentFrame); - } - } - /** * ExecType: Execution type, as used in ExecStep
* OP: Operation execution
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/FrameIter.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/FrameIter.java new file mode 100644 index 000000000..4ca555327 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/FrameIter.java @@ -0,0 +1,46 @@ +/* ****************************************************************************** + * Copyright (c) 2019-2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.autodiff.samediff.internal; + +import lombok.AllArgsConstructor; +import lombok.Data; + +/** + * FrameIter: Identifies a frame + iteration (but not a specific op or variable).
+ * Note that frames can be nested - which generally represents nested loop situations. + */ +@Data +@AllArgsConstructor +public class FrameIter { + private String frame; + private int iteration; + private FrameIter parentFrame; + + @Override + public String toString() { + return "(\"" + frame + "\"," + iteration + (parentFrame == null ? "" : ",parent=" + parentFrame.toString()) + ")"; + } + + @Override + public FrameIter clone() { + return new FrameIter(frame, iteration, (parentFrame == null ? null : parentFrame.clone())); + } + + public AbstractSession.VarId toVarId(String name) { + return new AbstractSession.VarId(name, frame, iteration, parentFrame); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/InferenceSession.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/InferenceSession.java index 4a6a5ce53..9b8d751eb 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/InferenceSession.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/InferenceSession.java @@ -206,6 +206,7 @@ public class InferenceSession extends AbstractSession { @Override public INDArray[] getOutputs(SameDiffOp op, FrameIter outputFrameIter, Set opInputs, Set allIterInputs, Set constAndPhInputs, List listeners, At at, MultiDataSet batch, Set allReqVariables) { + at.setFrameIter(outputFrameIter); if (listeners != null && listeners.size() > 0) { SameDiffOp sdOp = sameDiff.getOps().get(op.getOp().getOwnName()); for (Listener l : listeners) { @@ -363,6 +364,11 @@ public class InferenceSession extends AbstractSession { String[] argNames = s.argNames(); //Order: input, boolean array VarId vidPredicate = outputFrameIter.toVarId(argNames[1]); INDArray predicate = this.nodeOutputs.get(vidPredicate); + if(predicate == null && !constAndPhInputs.isEmpty() && constAndPhInputs.contains(argNames[1])){ + //Constant predicate... + predicate = this.nodeOutputs.get(new VarId(argNames[1], OUTER_FRAME, 0, null)); + } + Preconditions.checkNotNull(predicate, "Error during graph execution: Predicate array was null. VarId=%s", vidPredicate); Preconditions.checkState(predicate.isScalar() && predicate.dataType() == DataType.BOOL, "Expected boolean predicate: got %ndSInfo", predicate); VarId vid = outputFrameIter.toVarId(argNames[0]); if (predicate.getDouble(0) == 0.0) { @@ -477,11 +483,11 @@ public class InferenceSession extends AbstractSession { } throw new IllegalStateException(s); } - return ((Assert) op).outputArguments(); + return ((Assert) op).outputArguments().toArray(new INDArray[0]); } else if (op instanceof CustomOp) { CustomOp c = (CustomOp) op; Nd4j.exec(c); - return c.outputArguments(); + return c.outputArguments().toArray(new INDArray[0]); } else if (op instanceof Op) { Op o = (Op) op; Nd4j.exec(o); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBaseOps.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBaseOps.java index 8d2e9f624..f7d4a85bc 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBaseOps.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBaseOps.java @@ -184,55 +184,7 @@ public abstract class SDBaseOps { public SDVariable argmin(SDVariable in, boolean keepDims, int... dimensions) { return argmin(null, in, keepDims, dimensions); } - - /** - * Assign/copy op: out = x.assign(y). Supports broadcasting - * - * @param x Input variable x - * @param y Input variable y - * @return Output variable - */ - public SDVariable assign(SDVariable x, SDVariable y) { - return assign(null, x, y); - } - - /** - * Assign/copy op: out = x.assign(y). Supports broadcasting - * - * @param name Name of the output variable - * @param x Input variable x - * @param y Input variable y - * @return Output variable - */ - public SDVariable assign(String name, SDVariable x, SDVariable y) { - SDVariable ret = f().assign(x, y); - return updateVariableNameAndReference(ret, name); - } - - /** - * Return an array with equal shape to the input, but all elements set to 'value' - * - * @param in Input variable - * @param value Value to set - * @return Output variable - */ - public SDVariable assign(SDVariable in, Number value) { - return assign(null, in, value); - } - - /** - * Return an array with equal shape to the input, but all elements set to 'value' - * - * @param name Name of the output variable - * @param in Input variable - * @param value Value to set - * @return Output variable - */ - public SDVariable assign(String name, SDVariable in, Number value) { - SDVariable ret = f().assign(in, value); - return updateVariableNameAndReference(ret, name); - } - + /** * Matrix multiply a batch of matrices. matricesA and matricesB have to be arrays of same * length and each pair taken from these sets has to have dimensions (M, N) and (N, K), diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDImage.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDImage.java index bf71a665e..7b662b960 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDImage.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDImage.java @@ -3,10 +3,7 @@ package org.nd4j.autodiff.samediff.ops; import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.linalg.api.ops.custom.AdjustContrast; -import org.nd4j.linalg.api.ops.custom.AdjustHue; -import org.nd4j.linalg.api.ops.custom.AdjustSaturation; -import org.nd4j.linalg.api.ops.custom.RandomCrop; +import org.nd4j.linalg.api.ops.custom.*; import org.nd4j.linalg.api.ops.impl.image.CropAndResize; import org.nd4j.linalg.api.ops.impl.image.ExtractImagePatches; import org.nd4j.linalg.api.ops.impl.image.NonMaxSuppression; @@ -119,4 +116,70 @@ public class SDImage extends SDOps { SDVariable out = new RandomCrop(sd, input, shape).outputVariable(); return updateVariableNameAndReference(out, name); } + + /** + * Converting array from HSV to RGB format + * @param name name + * @param input 3D image + * @return 3D image + */ + public SDVariable rgbToHsv(String name, @NonNull SDVariable input) { + SDVariable out = new RgbToHsv(sd, input).outputVariable(); + return updateVariableNameAndReference(out, name); + } + + /** + * Converting image from HSV to RGB format + * @param name name + * @param input 3D image + * @return 3D image + */ + public SDVariable hsvToRgb(String name, @NonNull SDVariable input) { + SDVariable out = new HsvToRgb(sd, input).outputVariable(); + return updateVariableNameAndReference(out, name); + } + + /** + * Converting array from RGB to YIQ format + * @param name name + * @param input 3D image + * @return 3D image + */ + public SDVariable rgbToYiq(String name, @NonNull SDVariable input) { + SDVariable out = new RgbToYiq(sd, input).outputVariable(); + return updateVariableNameAndReference(out, name); + } + + /** + * Converting image from YIQ to RGB format + * @param name name + * @param input 3D image + * @return 3D image + */ + public SDVariable yiqToRgb(String name, @NonNull SDVariable input) { + SDVariable out = new YiqToRgb(sd, input).outputVariable(); + return updateVariableNameAndReference(out, name); + } + + /** + * Converting array from RGB to YUV format + * @param name name + * @param input 3D image + * @return 3D image + */ + public SDVariable rgbToYuv(String name, @NonNull SDVariable input) { + SDVariable out = new RgbToYuv(sd, input).outputVariable(); + return updateVariableNameAndReference(out, name); + } + + /** + * Converting image from YUV to RGB format + * @param name name + * @param input 3D image + * @return 3D image + */ + public SDVariable yuvToRgb(String name, @NonNull SDVariable input) { + SDVariable out = new YuvToRgb(sd, input).outputVariable(); + return updateVariableNameAndReference(out, name); + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDLoss.java index 70da070b8..f0e94a4e5 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDLoss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDLoss.java @@ -1,4 +1,4 @@ -/******************************************************************************* +/* ***************************************************************************** * Copyright (c) 2015-2019 Skymind, Inc. * * This program and the accompanying materials are made available under the @@ -34,11 +34,20 @@ import static org.nd4j.autodiff.samediff.ops.SDValidation.*; * * @author Alex Black */ +@SuppressWarnings("unused") public class SDLoss extends SDOps { public SDLoss(SameDiff sameDiff) { super(sameDiff); } + /** + * helper to refactor duplicate code + */ + private SDVariable getWeights(SDVariable weights, String name, SDVariable predictions){ + String weightName = (name == null) ? null : name + "/weight"; + return (weights == null) ? sd.constant(weightName, Nd4j.scalar(predictions.dataType(), 1.0)) : weights; + } + /** * See {@link #absoluteDifference(String, SDVariable, SDVariable, SDVariable, LossReduce)}. */ @@ -60,12 +69,7 @@ public class SDLoss extends SDOps { SDVariable weights, @NonNull LossReduce lossReduce) { validateFloatingPoint("absolute difference loss", "predictions", predictions); validateNumerical("absolute difference loss", "labels", label); - if (weights == null) { - String weightName = null; - if(name != null) - weightName = name + "/weight"; - weights = sd.constant(weightName, Nd4j.scalar(predictions.dataType(), 1.0)); - } + weights = getWeights(weights, name, predictions); SDVariable result = f().lossAbsoluteDifference(label, predictions, weights, lossReduce); result = updateVariableNameAndReference(result, name); result.markAsLoss(); @@ -105,12 +109,7 @@ public class SDLoss extends SDOps { SDVariable weights, @NonNull LossReduce lossReduce, int dimension) { validateFloatingPoint("cosine distance loss", "predictions", predictions); validateNumerical("cosine distance loss", "labels", label); - if (weights == null) { - String weightName = null; - if(name != null) - weightName = name + "/weight"; - weights = sd.constant(weightName, Nd4j.scalar(predictions.dataType(), 1.0)); - } + weights = getWeights(weights, name, predictions); SDVariable result = f().lossCosineDistance(label, predictions, weights, lossReduce, dimension); result = updateVariableNameAndReference(result, name); result.markAsLoss(); @@ -192,12 +191,7 @@ public class SDLoss extends SDOps { SDVariable weights, @NonNull LossReduce lossReduce, double delta) { validateFloatingPoint("huber loss", "predictions", predictions); validateNumerical("huber loss", "labels", label); - if (weights == null) { - String weightName = null; - if(name != null) - weightName = name + "/weight"; - weights = sd.constant(weightName, Nd4j.scalar(predictions.dataType(), 1.0)); - } + weights = getWeights(weights, name, predictions); SDVariable result = f().lossHuber(label, predictions, weights, lossReduce, delta); result = updateVariableNameAndReference(result, name); result.markAsLoss(); @@ -258,12 +252,7 @@ public class SDLoss extends SDOps { SDVariable weights, @NonNull LossReduce lossReduce, double epsilon) { validateFloatingPoint("log loss", "predictions", predictions); validateNumerical("log loss", "labels", label); - if (weights == null) { - String weightName = null; - if(name != null) - weightName = name + "/weight"; - weights = sd.constant(weightName, Nd4j.scalar(predictions.dataType(), 1.0)); - } + weights = getWeights(weights, name, predictions); SDVariable result = f().lossLog(label, predictions, weights, lossReduce, epsilon); result = updateVariableNameAndReference(result, name); result.markAsLoss(); @@ -299,12 +288,7 @@ public class SDLoss extends SDOps { SDVariable weights, @NonNull LossReduce lossReduce) { validateFloatingPoint("log poisson loss", "predictions", predictions); validateNumerical("log poisson loss", "labels", label); - if (weights == null) { - String weightName = null; - if(name != null) - weightName = name + "/weight"; - weights = sd.constant(weightName, Nd4j.scalar(predictions.dataType(), 1.0)); - } + weights = getWeights(weights, name, predictions); SDVariable result = f().lossLogPoisson(label, predictions, weights, lossReduce); result = updateVariableNameAndReference(result, name); result.markAsLoss(); @@ -341,12 +325,7 @@ public class SDLoss extends SDOps { SDVariable weights, @NonNull LossReduce lossReduce) { validateFloatingPoint("log poisson (full) loss", "predictions", predictions); validateNumerical("log poisson (full) loss", "labels", label); - if (weights == null) { - String weightName = null; - if(name != null) - weightName = name + "/weight"; - weights = sd.constant(weightName, Nd4j.scalar(predictions.dataType(), 1.0)); - } + weights = getWeights(weights, name, predictions); SDVariable result = f().lossLogPoissonFull(label, predictions, weights, lossReduce); result = updateVariableNameAndReference(result, name); result.markAsLoss(); @@ -382,12 +361,7 @@ public class SDLoss extends SDOps { public SDVariable meanPairwiseSquaredError(String name, @NonNull SDVariable label, @NonNull SDVariable predictions, SDVariable weights, @NonNull LossReduce lossReduce) { validateFloatingPoint("main pairwise squared error loss", "predictions", predictions); validateNumerical("mean pairwise squared error loss", "labels", label); - if (weights == null) { - String weightName = null; - if(name != null) - weightName = name + "/weight"; - weights = sd.constant(weightName, Nd4j.scalar(predictions.dataType(), 1.0)); - } + weights = getWeights(weights, name, predictions); SDVariable result = f().lossMeanPairwiseSquaredError(label, predictions, weights, lossReduce); result = updateVariableNameAndReference(result, name); result.markAsLoss(); @@ -417,12 +391,7 @@ public class SDLoss extends SDOps { SDVariable weights, @NonNull LossReduce lossReduce) { validateFloatingPoint("mean squared error loss", "predictions", predictions); validateNumerical("mean squared error loss", "labels", label); - if (weights == null) { - String weightName = null; - if(name != null) - weightName = name + "/weight"; - weights = sd.constant(weightName, Nd4j.scalar(predictions.dataType(), 1.0)); - } + weights = getWeights(weights, name, predictions); SDVariable result = f().lossMeanSquaredError(label, predictions, weights, lossReduce); result = updateVariableNameAndReference(result, name); result.markAsLoss(); @@ -468,12 +437,7 @@ public class SDLoss extends SDOps { SDVariable weights, @NonNull LossReduce lossReduce, double labelSmoothing) { validateFloatingPoint("sigmoid cross entropy loss", "predictions", predictionLogits); validateNumerical("sigmoid cross entropy loss", "labels", label); - if (weights == null) { - String weightName = null; - if(name != null) - weightName = name + "/weight"; - weights = sd.constant(weightName, Nd4j.scalar(predictionLogits.dataType(), 1.0)); - } + weights = getWeights(weights, name, predictionLogits); SDVariable result = f().lossSigmoidCrossEntropy(label, predictionLogits, weights, lossReduce, labelSmoothing); result = updateVariableNameAndReference(result, name); result.markAsLoss(); @@ -518,12 +482,7 @@ public class SDLoss extends SDOps { SDVariable weights, @NonNull LossReduce lossReduce, double labelSmoothing) { validateFloatingPoint("softmax cross entropy loss", "predictions", logitPredictions); validateNumerical("softmax cross entropy loss", "oneHotLabels", oneHotLabels); - if (weights == null) { - String weightName = null; - if(name != null) - weightName = name + "/weight"; - weights = sd.constant(weightName, Nd4j.scalar(logitPredictions.dataType(), 1.0)); - } + weights = getWeights(weights, name, logitPredictions); SDVariable result = f().lossSoftmaxCrossEntropy(oneHotLabels, logitPredictions, weights, lossReduce, labelSmoothing); result = updateVariableNameAndReference(result, name); result.markAsLoss(); @@ -595,6 +554,4 @@ public class SDLoss extends SDOps { result.markAsLoss(); return result; } - - } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDRandom.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDRandom.java index 66e52c151..879c08cba 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDRandom.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDRandom.java @@ -19,6 +19,7 @@ package org.nd4j.autodiff.samediff.ops; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import static org.nd4j.autodiff.samediff.ops.SDValidation.validateInteger; @@ -295,4 +296,43 @@ public class SDRandom extends SDOps { return updateVariableNameAndReference(ret, name); } + /** + * Generate a new random SDVariable with Gamma distribution + * + * @param name Name of the output variable + * @param alpha distribution parameter + * @param beta distribution parameter + * @param shape Shape of the new variable + * @return new SDVariable + */ + public SDVariable gamma(String name, SDVariable shape, SDVariable alpha, SDVariable beta) { + SDVariable ret = f().randomGamma(alpha, beta, shape); + return updateVariableNameAndReference(ret, name); + } + + /** + * Generate a new random SDVariable with Poission distribution + * + * @param name Name of the output variable + * @param lambda rate distribution parameter + * @param shape Shape of the new variable + * @return new SDVariable + */ + public SDVariable poisson(String name, SDVariable lambda, SDVariable shape, int... seeds) { + SDVariable ret = f().randomPoisson(shape, lambda, seeds); + return updateVariableNameAndReference(ret, name); + } + + /** + * Generate a new random SDVariable by random shuffle + * + * @param name Name of the output variable + * @param value array to shuffle + * @return new SDVariable + */ + public SDVariable shuffle(String name, SDVariable value, int... seeds) { + SDVariable ret = f().randomShuffle(value, seeds); + return updateVariableNameAndReference(ret, name); + } + } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/FlatBuffersMapper.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/FlatBuffersMapper.java index 5b4cb497b..a88a9c84f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/FlatBuffersMapper.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/FlatBuffersMapper.java @@ -274,7 +274,6 @@ public class FlatBuffersMapper { return OpType.TRANSFORM_STRICT; case SPECIAL: return OpType.TRANSFORM_STRICT; - case VARIANCE: case REDUCE_FLOAT: return OpType.REDUCE_FLOAT; case REDUCE_BOOL: @@ -302,6 +301,7 @@ public class FlatBuffersMapper { case PAIRWISE_BOOL: return OpType.PAIRWISE_BOOL; case SUMMARYSTATS: + case VARIANCE: return OpType.SUMMARYSTATS; default: throw new UnsupportedOperationException("Unknown op type passed in: " + type); @@ -799,7 +799,8 @@ public class FlatBuffersMapper { } int[] dims; - if (node.opType() == Op.Type.REDUCE_FLOAT || node.opType() == Op.Type.REDUCE_SAME || node.opType() == Op.Type.REDUCE_BOOL || node.opType() == Op.Type.REDUCE_LONG || node.opType() == Op.Type.INDEXREDUCE || node.opType() == Op.Type.REDUCE3) { + if (node.opType() == Op.Type.REDUCE_FLOAT || node.opType() == Op.Type.REDUCE_SAME || node.opType() == Op.Type.REDUCE_BOOL + || node.opType() == Op.Type.REDUCE_LONG || node.opType() == Op.Type.INDEXREDUCE || node.opType() == Op.Type.REDUCE3) { dims = node.getDimensions(); if (dims == null) dims = new int[0]; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/GradCheckUtil.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/GradCheckUtil.java index 0f1e0bd52..b35369ad4 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/GradCheckUtil.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/GradCheckUtil.java @@ -47,7 +47,7 @@ public class GradCheckUtil { public enum Subset {EVERY_N, RANDOM} - public static final boolean DEFAULT_PRINT = true; + public static final boolean DEFAULT_PRINT = false; public static final boolean DEFAULT_EXIT_FIRST_FAILURE = false; public static final boolean DEFAULT_DEBUG_MODE = false; public static final double DEFAULT_EPS = 1e-5; @@ -330,11 +330,10 @@ public class GradCheckUtil { + "; absolute error = " + absError + " < minAbsoluteError = " + minAbsError); } } else { - if (print) - log.info("Param " + i + " (" + name + strIdx + ") FAILED: grad= " + analyticGrad - + ", numericalGrad= " + numericalGrad + ", relError= " + relError - + ", absError=" + absError - + ", scorePlus=" + scorePlus + ", scoreMinus= " + scoreMinus); + log.info("Param " + i + " (" + name + strIdx + ") FAILED: grad= " + analyticGrad + + ", numericalGrad= " + numericalGrad + ", relError= " + relError + + ", absError=" + absError + + ", scorePlus=" + scorePlus + ", scoreMinus= " + scoreMinus); if (exitOnFirstFailure) return false; totalNFailures++; @@ -347,11 +346,9 @@ public class GradCheckUtil { } } - if (print) { - int nPass = totalCount - totalNFailures; - log.info("GradCheckUtil.checkGradients(): " + totalCount + " params checked, " + nPass + " passed, " - + totalNFailures + " failed. Largest relative error = " + maxError); - } + int nPass = totalCount - totalNFailures; + log.info("GradCheckUtil.checkGradients(): " + totalCount + " params checked, " + nPass + " passed, " + + totalNFailures + " failed. Largest relative error = " + maxError); if(debugMode && !debugBefore){ sd.disableDebugging(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/OpValidation.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/OpValidation.java index fc7572180..d57ab7c97 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/OpValidation.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/OpValidation.java @@ -457,7 +457,7 @@ public class OpValidation { for (int i = 0; i < testCase.testFns().size(); i++) { String error; try { - error = testCase.testFns().get(i).apply(testCase.op().outputArguments()[i]); + error = testCase.testFns().get(i).apply(testCase.op().outputArguments().get(i)); } catch (Throwable t) { throw new IllegalStateException("Exception thrown during op output validation for output " + i, t); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/TestCase.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/TestCase.java index fad760bb3..2767a22f9 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/TestCase.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/TestCase.java @@ -45,7 +45,7 @@ import java.util.*; public class TestCase { public enum TestSerialization {BEFORE_EXEC, AFTER_EXEC, BOTH, NONE}; - public static final boolean GC_DEFAULT_PRINT = true; + public static final boolean GC_DEFAULT_PRINT = false; public static final boolean GC_DEFAULT_EXIT_FIRST_FAILURE = false; public static final boolean GC_DEFAULT_DEBUG_MODE = false; public static final double GC_DEFAULT_EPS = 1e-5; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/listeners/NonInplaceValidationListener.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/listeners/NonInplaceValidationListener.java index 7e7a50ab2..9eee099a5 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/listeners/NonInplaceValidationListener.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/listeners/NonInplaceValidationListener.java @@ -1,6 +1,7 @@ package org.nd4j.autodiff.validation.listeners; import lombok.Getter; +import lombok.val; import org.nd4j.autodiff.listeners.At; import org.nd4j.autodiff.listeners.BaseListener; import org.nd4j.autodiff.listeners.Operation; @@ -50,12 +51,12 @@ public class NonInplaceValidationListener extends BaseListener { opInputs = new INDArray[]{o.x().dup(), o.y().dup()}; } } else if(op.getOp() instanceof DynamicCustomOp){ - INDArray[] arr = ((DynamicCustomOp) op.getOp()).inputArguments(); - opInputs = new INDArray[arr.length]; - opInputsOrig = new INDArray[arr.length]; - for( int i=0; i opNames; + /** The number of times each operation was observed in all graphs */ + private final Map opCounts; /** The (unique) names of all ops that were encountered, and can be imported, in all graphs */ private final Set importSupportedOpNames; /** The (unique) names of all ops that were encountered, and can NOT be imported (lacking import mapping) */ @@ -60,6 +62,11 @@ public class TFImportStatus { Set newOpNames = new HashSet<>(opNames); newOpNames.addAll(other.opNames); + Map newOpCounts = new HashMap<>(opCounts); + for(Map.Entry e : other.opCounts.entrySet()){ + newOpCounts.put(e.getKey(), (newOpCounts.containsKey(e.getKey()) ? newOpCounts.get(e.getKey()) : 0) + e.getValue()); + } + Set newImportSupportedOpNames = new HashSet<>(importSupportedOpNames); newImportSupportedOpNames.addAll(other.importSupportedOpNames); @@ -89,6 +96,7 @@ public class TFImportStatus { totalNumOps + other.totalNumOps, countUnique, newOpNames, + newOpCounts, newImportSupportedOpNames, newUnsupportedOpNames, newUnsupportedOpModels); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/tensorflow/TensorFlowImportValidator.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/tensorflow/TensorFlowImportValidator.java index 39d8e1577..0493099f2 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/tensorflow/TensorFlowImportValidator.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/tensorflow/TensorFlowImportValidator.java @@ -230,6 +230,7 @@ public class TensorFlowImportValidator { try { int opCount = 0; Set opNames = new HashSet<>(); + Map opCounts = new HashMap<>(); try(InputStream bis = new BufferedInputStream(is)) { GraphDef graphDef = GraphDef.parseFrom(bis); @@ -248,6 +249,8 @@ public class TensorFlowImportValidator { String op = nd.getOp(); opNames.add(op); + int soFar = opCounts.containsKey(op) ? opCounts.get(op) : 0; + opCounts.put(op, soFar + 1); opCount++; } } @@ -282,6 +285,7 @@ public class TensorFlowImportValidator { opCount, opNames.size(), opNames, + opCounts, importSupportedOpNames, unsupportedOpNames, unsupportedOpModel); @@ -297,6 +301,7 @@ public class TensorFlowImportValidator { 0, 0, Collections.emptySet(), + Collections.emptyMap(), Collections.emptySet(), Collections.emptySet(), Collections.>emptyMap()); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationPReLU.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationPReLU.java index 73321a1c1..08bbbb997 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationPReLU.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationPReLU.java @@ -1,5 +1,6 @@ /* ***************************************************************************** * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2019-2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -61,16 +62,19 @@ public class ActivationPReLU extends BaseActivationFunction { @Override public Pair backprop(INDArray in, INDArray epsilon) { assertShape(in, epsilon); - INDArray dLdalpha = Nd4j.create(alpha.shape()); + INDArray dLdalpha = alpha.ulike(); + INDArray outTemp = in.ulike(); DynamicCustomOp.DynamicCustomOpsBuilder preluBp = DynamicCustomOp.builder("prelu_bp") - .addInputs(in, alpha, epsilon).addOutputs(in, alpha); + .addInputs(in, alpha, epsilon) + .addOutputs(outTemp, dLdalpha); if (sharedAxes != null) { for (long axis: sharedAxes) { preluBp.addIntegerArguments(axis); } } - Nd4j.getExecutioner().execAndReturn(preluBp.build()); + Nd4j.exec(preluBp.build()); + in.assign(outTemp); return new Pair<>(in, dLdalpha); } @@ -78,4 +82,4 @@ public class ActivationPReLU extends BaseActivationFunction { public String toString() { return "prelu"; } -} +} \ No newline at end of file diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationReLU.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationReLU.java index 2bafd5472..02fe528c2 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationReLU.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationReLU.java @@ -18,11 +18,13 @@ package org.nd4j.linalg.activations.impl; import lombok.EqualsAndHashCode; import lombok.Getter; -import org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinearDerivative; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ops.impl.scalar.*; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUBp; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUDerivative; import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.activations.BaseActivationFunction; import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinear; import org.nd4j.linalg.factory.Nd4j; /** @@ -32,16 +34,72 @@ import org.nd4j.linalg.factory.Nd4j; @Getter public class ActivationReLU extends BaseActivationFunction { + private Double max; + private Double threshold; + private Double negativeSlope; + + public ActivationReLU(){ + this(null, null, null); + } + + public ActivationReLU(Double maxValue, Double threshold, Double negativeSlope){ + this.max = maxValue; + this.threshold = threshold; + this.negativeSlope = negativeSlope; + } + @Override public INDArray getActivation(INDArray in, boolean training) { - Nd4j.getExecutioner().execAndReturn(new RectifiedLinear(in)); + if(negativeSlope != null || threshold != null){ + double t = threshold == null ? 0.0 : threshold; + double ns = negativeSlope == null ? 0.0 : negativeSlope; + if(t == 0.0) { + Nd4j.getExecutioner().execAndReturn(new LeakyReLU(in, ns)); + } else { + //Non-zero threshold, and non-zero slope + //TODO optimize this... but, extremely rare case in practice? + INDArray oneGte = in.gte(t).castTo(in.dataType()); + INDArray oneLt = in.lt(t).castTo(in.dataType()); + INDArray lower = oneLt.muli(ns).muli(in.sub(threshold)); + INDArray upper = oneGte.muli(in); + in.assign(lower.addi(upper)); + } + } else { + Nd4j.getExecutioner().exec(new RectifiedLinear(in, in)); + } + if(max != null){ + Nd4j.exec(new ScalarMin(in, null, in, max)); + } return in; } @Override public Pair backprop(INDArray in, INDArray epsilon) { assertShape(in, epsilon); - INDArray dLdz = Nd4j.exec(new RectifiedLinearDerivative(in, epsilon, in.ulike()))[0]; + + INDArray dLdz; + INDArray maxMask = (max == null || max == 0.0 ? null : in.lt(max)); + if(negativeSlope != null || threshold != null){ + double t = threshold == null ? 0.0 : threshold; + double ns = negativeSlope == null ? 0.0 : negativeSlope; + if(t == 0.0) { + dLdz = Nd4j.getExecutioner().exec(new LeakyReLUBp(in, epsilon, in.ulike(), ns))[0]; + } else { + //Non-zero threshold, and non-zero slope + //TODO optimize this... but, extremely rare case in practice? + INDArray oneGte = in.gte(t).castTo(in.dataType()); + INDArray oneLt = in.lt(t).castTo(in.dataType()); + INDArray lower = oneLt.muli(ns); + INDArray upper = oneGte; + dLdz = in.assign(lower.addi(upper)).muli(epsilon); + } + } else { + dLdz = Nd4j.getExecutioner().exec(new RectifiedLinearDerivative(in, epsilon, in.ulike(), threshold == null ? 0.0 : threshold))[0]; + } + + if(maxMask != null){ + dLdz.muli(maxMask); + } return new Pair<>(dLdz, null); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java index 77b946559..ab622b34f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java @@ -23,7 +23,6 @@ import com.google.flatbuffers.FlatBufferBuilder; import lombok.NonNull; import lombok.extern.slf4j.Slf4j; import lombok.val; -import net.ericaro.neoitertools.Generator; import org.apache.commons.math3.util.FastMath; import org.bytedeco.javacpp.BytePointer; import org.nd4j.autodiff.samediff.serde.FlatBuffersMapper; @@ -998,14 +997,14 @@ public abstract class BaseNDArray implements INDArray, Iterable { } } - Pair tadInfo = - Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(this, dimension); + Pair tadInfo = Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(this, dimension); DataBuffer shapeInfo = tadInfo.getFirst(); - val shape = Shape.shape(shapeInfo); - val stride = Shape.stride(shapeInfo).asLong(); + val jShapeInfo = shapeInfo.asLong(); + val shape = Shape.shape(jShapeInfo); + val stride = Shape.stride(jShapeInfo); long offset = offset() + tadInfo.getSecond().getLong(index); - val ews = shapeInfo.getLong(shapeInfo.getLong(0) * 2 + 2); - char tadOrder = (char) shapeInfo.getInt(shapeInfo.getLong(0) * 2 + 3); + val ews = shapeInfo.getLong(jShapeInfo[0] * 2 + 2); + char tadOrder = (char) shapeInfo.getInt(jShapeInfo[0] * 2 + 3); val toTad = Nd4j.create(data(), shape, stride, offset, ews, tadOrder); return toTad; } @@ -2217,9 +2216,10 @@ public abstract class BaseNDArray implements INDArray, Iterable { if(isEmpty() || isS()) return false; - return Shape.offset(jvmShapeInfo.javaShapeInformation) > 0 - || (length() < data().length() && data.dataType() != DataType.INT) - || data().originalDataBuffer() != null; + val c2 = (length() < data().length() && data.dataType() != DataType.INT); + val c3 = (data().originalDataBuffer() != null && data != data.originalDataBuffer()); + + return c2 || c3; } @Override @@ -3585,6 +3585,7 @@ public abstract class BaseNDArray implements INDArray, Iterable { case DOUBLE: case FLOAT: case HALF: + case BFLOAT16: return getDouble(i); case LONG: case INT: @@ -3592,6 +3593,9 @@ public abstract class BaseNDArray implements INDArray, Iterable { case UBYTE: case BYTE: case BOOL: + case UINT64: + case UINT32: + case UINT16: return getLong(i); case UTF8: case COMPRESSED: @@ -4350,29 +4354,30 @@ public abstract class BaseNDArray implements INDArray, Iterable { //epsilon equals if (isScalar() && n.isScalar()) { - if (data.dataType() == DataType.FLOAT) { - double val = getDouble(0); - double val2 = n.getDouble(0); + if (isZ()) { + val val = getLong(0); + val val2 = n.getLong(0); + + return val == val2; + } else if (isR()) { + val val = getDouble(0); + val val2 = n.getDouble(0); if (Double.isNaN(val) != Double.isNaN(val2)) return false; return Math.abs(val - val2) < eps; - } else { - double val = getDouble(0); - double val2 = n.getDouble(0); + } else if (isB()) { + val val = getInt(0); + val val2 = n.getInt(0); - if (Double.isNaN(val) != Double.isNaN(val2)) - return false; - - return Math.abs(val - val2) < eps; + return val == val2; } } else if (isVector() && n.isVector()) { - - EqualsWithEps op = new EqualsWithEps(this, n, eps); - Nd4j.getExecutioner().exec(op); - double diff = op.z().getDouble(0); + val op = new EqualsWithEps(this, n, eps); + Nd4j.exec(op); + val diff = op.z().getDouble(0); return diff < 0.5; } @@ -4750,8 +4755,8 @@ public abstract class BaseNDArray implements INDArray, Iterable { return this; checkArrangeArray(rearrange); - int[] newShape = doPermuteSwap(shapeOf(), rearrange); - int[] newStride = doPermuteSwap(strideOf(), rearrange); + val newShape = doPermuteSwap(shape(), rearrange); + val newStride = doPermuteSwap(stride(), rearrange); char newOrder = Shape.getOrder(newShape, newStride, 1); @@ -4777,23 +4782,11 @@ public abstract class BaseNDArray implements INDArray, Iterable { return this; checkArrangeArray(rearrange); - val newShape = doPermuteSwap(Shape.shapeOf(shapeInfo), rearrange); - val newStride = doPermuteSwap(Shape.stride(shapeInfo), rearrange); + val newShape = doPermuteSwap(shape(), rearrange); + val newStride = doPermuteSwap(stride(), rearrange); char newOrder = Shape.getOrder(newShape, newStride, 1); - //Set the shape information of this array: shape, stride, order. - //Shape info buffer: [rank, [shape], [stride], offset, elementwiseStride, order] - /*for( int i=0; i outputArguments(); - - - INDArray[] outputArguments(); - - INDArray[] inputArguments(); + List inputArguments(); long[] iArgs(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/DynamicCustomOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/DynamicCustomOp.java index 99e930176..e46dfab4b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/DynamicCustomOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/DynamicCustomOp.java @@ -261,19 +261,13 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp { } @Override - public INDArray[] outputArguments() { - if (!outputArguments.isEmpty()) { - return outputArguments.toArray(new INDArray[0]); - } - return new INDArray[0]; + public List outputArguments() { + return outputArguments; } @Override - public INDArray[] inputArguments() { - if (!inputArguments.isEmpty()) - return inputArguments.toArray(new INDArray[0]); - return new INDArray[0]; - + public List inputArguments() { + return inputArguments; } @Override @@ -367,10 +361,10 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp { for (int i = 0; i < args.length; i++) { // it's possible to get into situation where number of args > number of arrays AT THIS MOMENT - if (i >= arrsSoFar.length) + if (i >= arrsSoFar.size()) continue; - if (!Arrays.equals(args[i].getShape(), arrsSoFar[i].shape())) + if (!Arrays.equals(args[i].getShape(), arrsSoFar.get(i).shape())) throw new ND4JIllegalStateException("Illegal array passed in as argument [" + i + "]. Expected shape " + Arrays.toString(args[i].getShape()) + " and received array with shape " + Arrays.toString(arg[i].shape())); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/OpContext.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/OpContext.java index e66d52f91..dda6aef24 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/OpContext.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/OpContext.java @@ -136,4 +136,10 @@ public interface OpContext extends AutoCloseable { * @param reallyAllow */ void allowHelpers(boolean reallyAllow); + + /** + * This methos allows to disape outputs validation via shape function + * @param reallyOverride + */ + void shapeFunctionOverride(boolean reallyOverride); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/compat/CompatSparseToDense.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/compat/CompatSparseToDense.java new file mode 100644 index 000000000..18293c2ee --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/compat/CompatSparseToDense.java @@ -0,0 +1,51 @@ +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg.api.ops.compat; + +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +/** + * This is a wrapper for SparseToDense op that impelements corresponding TF operation + * + * @author raver119@gmail.com + */ +public class CompatSparseToDense extends DynamicCustomOp { + + public CompatSparseToDense() { + // + } + + public CompatSparseToDense(INDArray indices, INDArray shape, INDArray values) { + Preconditions.checkArgument(shape.isZ() && indices.isZ(), "Shape & indices arrays must have one integer data types"); + inputArguments.add(indices); + inputArguments.add(shape); + inputArguments.add(values); + } + + public CompatSparseToDense(INDArray indices, INDArray shape, INDArray values, INDArray defaultVaule) { + this(indices, shape, values); + Preconditions.checkArgument(defaultVaule.dataType() == values.dataType(), "Values array must have the same data type as defaultValue array"); + inputArguments.add(defaultVaule); + } + + @Override + public String opName() { + return "compat_sparse_to_dense"; + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/compat/CompatStringSplit.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/compat/CompatStringSplit.java new file mode 100644 index 000000000..33b6df4a6 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/compat/CompatStringSplit.java @@ -0,0 +1,51 @@ +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg.api.ops.compat; + +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +/** + * This is a wrapper for StringSplit op that impelements corresponding TF operation + * + * @author raver119@gmail.com + */ +public class CompatStringSplit extends DynamicCustomOp { + + public CompatStringSplit() { + // + } + + public CompatStringSplit(INDArray strings, INDArray delimiter) { + Preconditions.checkArgument(strings.isS() && delimiter.isS(), "Input arrays must have one of UTF types"); + inputArguments.add(strings); + inputArguments.add(delimiter); + } + + public CompatStringSplit(INDArray strings, INDArray delimiter, INDArray indices, INDArray values) { + this(strings, delimiter); + + outputArguments.add(indices); + outputArguments.add(values); + } + + @Override + public String opName() { + return "compat_string_split"; + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/Digamma.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/Digamma.java new file mode 100644 index 000000000..206d0027f --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/Digamma.java @@ -0,0 +1,56 @@ +/* ****************************************************************************** + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.custom; + +import lombok.NoArgsConstructor; +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +import java.util.Collections; +import java.util.List; + +@NoArgsConstructor +public class Digamma extends DynamicCustomOp { + public Digamma(@NonNull INDArray x) { + addInputArgument(x); + } + + public Digamma(@NonNull SameDiff sameDiff, @NonNull SDVariable x) { + super("", sameDiff, new SDVariable[]{x}); + } + + @Override + public String opName() { + return "digamma"; + } + + @Override + public String tensorflowName() { + return "Digamma"; + } + + @Override + public List calculateOutputDataTypes(List inputDataTypes){ + int n = args().length; + Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes); + return Collections.singletonList(inputDataTypes.get(0)); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/FusedBatchNorm.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/FusedBatchNorm.java index 691e5d43f..3bbab11be 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/FusedBatchNorm.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/FusedBatchNorm.java @@ -19,15 +19,23 @@ import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; +import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.tensorflow.framework.AttrValue; +import org.tensorflow.framework.GraphDef; +import org.tensorflow.framework.NodeDef; +import java.util.Arrays; import java.util.Collections; import java.util.List; +import java.util.Map; public class FusedBatchNorm extends DynamicCustomOp { + private DataType outputDataType; + public FusedBatchNorm() {} public FusedBatchNorm(@NonNull INDArray x, @NonNull INDArray scale, @NonNull INDArray offset, @@ -38,6 +46,7 @@ public class FusedBatchNorm extends DynamicCustomOp { if (yOut != null && batchMeanOut != null && batchMeanVar != null) { addOutputArgument(yOut, batchMeanOut, batchMeanVar); } + this.outputDataType = x.dataType(); } public FusedBatchNorm(@NonNull SameDiff sameDiff, @NonNull SDVariable x, @NonNull SDVariable scale, @NonNull SDVariable offset, @@ -51,14 +60,25 @@ public class FusedBatchNorm extends DynamicCustomOp { } @Override - public String tensorflowName() { - return "FusedBatchNormV2"; + public String[] tensorflowNames() { + return new String[]{"FusedBatchNormV2","FusedBatchNormV3"}; + } + + @Override + public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { + boolean isNchw = attributesForNode.containsKey("data_format") && attributesForNode.get("data_format").getS().toStringUtf8().equalsIgnoreCase("NCHW"); + boolean training = !attributesForNode.containsKey("is_training") ? true : attributesForNode.get("is_training").getB(); + addIArgument(isNchw ? 1 : 0); + addIArgument(training ? 1 : 0); + if(attributesForNode.containsKey("T")){ + outputDataType = TFGraphMapper.convertType(attributesForNode.get("T").getType()); + } } @Override public List calculateOutputDataTypes(List inputDataTypes){ int n = args().length; Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes); - return Collections.singletonList(inputDataTypes.get(0)); + return Arrays.asList(outputDataType, DataType.FLOAT, DataType.FLOAT); //Activations may be half, bfloat16, float32; mean/var is always float } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/HsvToRgb.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/HsvToRgb.java new file mode 100644 index 000000000..6b1361376 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/HsvToRgb.java @@ -0,0 +1,57 @@ + +/* ****************************************************************************** + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.custom; + +import lombok.NoArgsConstructor; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +import java.util.Collections; +import java.util.List; + +@NoArgsConstructor +public class HsvToRgb extends DynamicCustomOp { + + public HsvToRgb(INDArray input) { + addInputArgument(input); + } + + public HsvToRgb(SameDiff sameDiff, SDVariable input) { + super(sameDiff, new SDVariable[]{input}); + } + + @Override + public String opName() { + return "hsv_to_rgb"; + } + + @Override + public String tensorflowName() { + return "HSVToRGB"; + } + + @Override + public List calculateOutputDataTypes(List inputDataTypes){ + int n = args().length; + Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes); + return Collections.singletonList(inputDataTypes.get(0)); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/Igamma.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/Igamma.java new file mode 100644 index 000000000..e8efddb12 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/Igamma.java @@ -0,0 +1,65 @@ +/* ****************************************************************************** + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.custom; + +import lombok.NoArgsConstructor; +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +import java.util.Collections; +import java.util.List; + +@NoArgsConstructor +public class Igamma extends DynamicCustomOp { + public Igamma(@NonNull INDArray n, @NonNull INDArray x) { + Preconditions.checkArgument(n.shape() != x.shape(), + "Igamma: n and x must have the same shapes"); + addInputArgument(n,x); + } + + public Igamma(@NonNull INDArray n, @NonNull INDArray x, INDArray output) { + this(n,x); + if (output != null) { + addOutputArgument(output); + } + } + + public Igamma(@NonNull SameDiff sameDiff, @NonNull SDVariable n, @NonNull SDVariable x) { + super("", sameDiff, new SDVariable[]{n ,x}); + } + + @Override + public String opName() { + return "igamma"; + } + + @Override + public String tensorflowName() { + return "Igamma"; + } + + @Override + public List calculateOutputDataTypes(List inputDataTypes){ + int n = args().length; + Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes); + return Collections.singletonList(inputDataTypes.get(0)); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/Igammac.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/Igammac.java new file mode 100644 index 000000000..915a57764 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/Igammac.java @@ -0,0 +1,66 @@ +/* ****************************************************************************** + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.custom; + +import lombok.NoArgsConstructor; +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.DynamicPartitionBp; + +import java.util.Collections; +import java.util.List; + +@NoArgsConstructor +public class Igammac extends DynamicCustomOp { + public Igammac(@NonNull INDArray n, @NonNull INDArray x) { + Preconditions.checkArgument(n.shape() != x.shape(), + "Igamma: n and x must have the same shapes"); + addInputArgument(n,x); + } + + public Igammac(@NonNull INDArray n, @NonNull INDArray x, INDArray output) { + this(n,x); + if (output != null) { + addOutputArgument(output); + } + } + + public Igammac(@NonNull SameDiff sameDiff, @NonNull SDVariable n, @NonNull SDVariable x) { + super("", sameDiff, new SDVariable[]{n ,x}); + } + + @Override + public String opName() { + return "igammac"; + } + + @Override + public String tensorflowName() { + return "Igammac"; + } + + @Override + public List calculateOutputDataTypes(List inputDataTypes){ + int n = args().length; + Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes); + return Collections.singletonList(inputDataTypes.get(0)); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/Lgamma.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/Lgamma.java new file mode 100644 index 000000000..3df488120 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/Lgamma.java @@ -0,0 +1,64 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.custom; + +import lombok.NoArgsConstructor; +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +import java.util.Collections; +import java.util.List; + +@NoArgsConstructor +public class Lgamma extends DynamicCustomOp { + + public Lgamma(@NonNull INDArray x) { + addInputArgument(x); + } + + public Lgamma(@NonNull INDArray x, INDArray output) { + this(x); + if (output != null) { + addOutputArgument(output); + } + } + + public Lgamma(@NonNull SameDiff sameDiff, @NonNull SDVariable x) { + super("", sameDiff, new SDVariable[]{x}); + } + + @Override + public String opName() { + return "lgamma"; + } + + @Override + public String tensorflowName() { + return "Lgamma"; + } + + @Override + public List calculateOutputDataTypes(List inputDataTypes){ + int n = args().length; + Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes); + return Collections.singletonList(inputDataTypes.get(0)); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/Lu.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/Lu.java new file mode 100644 index 000000000..af1bf0155 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/Lu.java @@ -0,0 +1,69 @@ +/* ****************************************************************************** + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.custom; + +import lombok.NoArgsConstructor; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.imports.graphmapper.tf.TFGraphMapper; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.tensorflow.framework.AttrValue; +import org.tensorflow.framework.GraphDef; +import org.tensorflow.framework.NodeDef; + +import java.util.Arrays; +import java.util.List; +import java.util.Map; + +@NoArgsConstructor +public class Lu extends DynamicCustomOp { + private DataType indexDataType; + + public Lu(INDArray input) { + addInputArgument(input); + } + + public Lu(SameDiff sameDiff, SDVariable input) { + super(sameDiff, new SDVariable[]{input}); + } + + @Override + public String opName() { + return "lu"; + } + + @Override + public String tensorflowName() { + return "Lu"; + } + + @Override + public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { + if (attributesForNode.containsKey("output_idx_type")){ + indexDataType = TFGraphMapper.convertType(attributesForNode.get("output_idx_type").getType()); + } + } + + @Override + public List calculateOutputDataTypes(List inputDataTypes){ + int n = args().length; + Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes); + return Arrays.asList(inputDataTypes.get(0), indexDataType); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/RgbToGrayscale.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/RgbToGrayscale.java new file mode 100644 index 000000000..6b71ba17f --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/RgbToGrayscale.java @@ -0,0 +1,44 @@ +/* ****************************************************************************** + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.custom; + +import lombok.NoArgsConstructor; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +@NoArgsConstructor +public class RgbToGrayscale extends DynamicCustomOp { + + public RgbToGrayscale(INDArray image) { + addInputArgument(image); + } + + public RgbToGrayscale(SameDiff sameDiff, SDVariable image) { + super(sameDiff, new SDVariable[]{image}); + } + + @Override + public String opName() { + return "rgb_to_grs"; + } + + @Override + public String tensorflowName() { + return "RgbToGrs"; + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/RgbToHsv.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/RgbToHsv.java new file mode 100644 index 000000000..d96981460 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/RgbToHsv.java @@ -0,0 +1,57 @@ + +/* ****************************************************************************** + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.custom; + +import lombok.NoArgsConstructor; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +import java.util.Collections; +import java.util.List; + +@NoArgsConstructor +public class RgbToHsv extends DynamicCustomOp { + + public RgbToHsv(INDArray input) { + addInputArgument(input); + } + + public RgbToHsv(SameDiff sameDiff, SDVariable input) { + super(sameDiff, new SDVariable[]{input}); + } + + @Override + public String opName() { + return "rgb_to_hsv"; + } + + @Override + public String tensorflowName() { + return "RGBToHSV"; + } + + @Override + public List calculateOutputDataTypes(List inputDataTypes){ + int n = args().length; + Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes); + return Collections.singletonList(inputDataTypes.get(0)); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/RgbToYiq.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/RgbToYiq.java new file mode 100644 index 000000000..1d6a48a4f --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/RgbToYiq.java @@ -0,0 +1,56 @@ +/* ****************************************************************************** + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.custom; + +import lombok.NoArgsConstructor; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +import java.util.Collections; +import java.util.List; + +@NoArgsConstructor +public class RgbToYiq extends DynamicCustomOp { + + public RgbToYiq(INDArray input) { + addInputArgument(input); + } + + public RgbToYiq(SameDiff sameDiff, SDVariable input) { + super(sameDiff, new SDVariable[]{input}); + } + + @Override + public String opName() { + return "rgb_to_yiq"; + } + + @Override + public String tensorflowName() { + return "RgbToYiq"; + } + + @Override + public List calculateOutputDataTypes(List inputDataTypes){ + int n = args().length; + Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes); + return Collections.singletonList(inputDataTypes.get(0)); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/RgbToYuv.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/RgbToYuv.java new file mode 100644 index 000000000..c65c6e777 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/RgbToYuv.java @@ -0,0 +1,56 @@ +/* ****************************************************************************** + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg.api.ops.custom; + +import lombok.NoArgsConstructor; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +import java.util.Collections; +import java.util.List; + +@NoArgsConstructor +public class RgbToYuv extends DynamicCustomOp { + public RgbToYuv(INDArray input) { + addInputArgument(input); + } + + public RgbToYuv(SameDiff sameDiff, SDVariable input) { + super(sameDiff, new SDVariable[]{input}); + } + + @Override + public String opName() { + return "rgb_to_yuv"; + } + + @Override + public String tensorflowName() { + return "RgbToYuv"; + } + + @Override + public List calculateOutputDataTypes(List inputDataTypes){ + int n = args().length; + Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes); + return Collections.singletonList(inputDataTypes.get(0)); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/ScatterUpdate.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/ScatterUpdate.java index cb805a775..83020cb57 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/ScatterUpdate.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/ScatterUpdate.java @@ -107,12 +107,12 @@ public class ScatterUpdate implements CustomOp { } @Override - public INDArray[] outputArguments() { + public List outputArguments() { return op.outputArguments(); } @Override - public INDArray[] inputArguments() { + public List inputArguments() { return op.inputArguments(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/TriangularSolve.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/TriangularSolve.java new file mode 100644 index 000000000..7423d3a91 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/TriangularSolve.java @@ -0,0 +1,43 @@ +package org.nd4j.linalg.api.ops.custom; + +import lombok.NoArgsConstructor; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +import java.util.Collections; +import java.util.List; + +@NoArgsConstructor +public class TriangularSolve extends DynamicCustomOp { + + public TriangularSolve(INDArray matrix, INDArray rhs, boolean lower, boolean adjoint) { + addInputArgument(matrix, rhs); + addBArgument(lower, adjoint); + } + + public TriangularSolve(SameDiff sameDiff, SDVariable matrix, SDVariable rhs, + SDVariable lower, SDVariable adjoint) { + super(sameDiff, new SDVariable[] {matrix, rhs, lower, adjoint}); + } + + @Override + public String opName() { + return "triangular_solve"; + } + + @Override + public String tensorflowName() { + return "MatrixTriangularSolve"; + } + + @Override + public List calculateOutputDataTypes(List dataTypes) { + int n = args().length; + Preconditions.checkState(dataTypes != null && dataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), dataTypes); + return Collections.singletonList(dataTypes.get(0)); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/YiqToRgb.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/YiqToRgb.java new file mode 100644 index 000000000..8126a1803 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/YiqToRgb.java @@ -0,0 +1,55 @@ +/* ****************************************************************************** + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.custom; + +import lombok.NoArgsConstructor; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +import java.util.Collections; +import java.util.List; + +@NoArgsConstructor +public class YiqToRgb extends DynamicCustomOp { + public YiqToRgb(INDArray input) { + addInputArgument(input); + } + + public YiqToRgb(SameDiff sameDiff, SDVariable input) { + super(sameDiff, new SDVariable[]{input}); + } + + @Override + public String opName() { + return "yiq_to_rgb"; + } + + @Override + public String tensorflowName() { + return "YiqToRgb"; + } + + @Override + public List calculateOutputDataTypes(List inputDataTypes){ + int n = args().length; + Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes); + return Collections.singletonList(inputDataTypes.get(0)); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/YuvToRgb.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/YuvToRgb.java new file mode 100644 index 000000000..4643ec3fe --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/YuvToRgb.java @@ -0,0 +1,56 @@ +/* ****************************************************************************** + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg.api.ops.custom; + +import lombok.NoArgsConstructor; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +import java.util.Collections; +import java.util.List; + +@NoArgsConstructor +public class YuvToRgb extends DynamicCustomOp { + public YuvToRgb(INDArray input) { + addInputArgument(input); + } + + public YuvToRgb(SameDiff sameDiff, SDVariable input) { + super(sameDiff, new SDVariable[]{input}); + } + + @Override + public String opName() { + return "yuv_to_rgb"; + } + + @Override + public String tensorflowName() { + return "YuvToRgb"; + } + + @Override + public List calculateOutputDataTypes(List inputDataTypes){ + int n = args().length; + Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes); + return Collections.singletonList(inputDataTypes.get(0)); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/executioner/DefaultOpExecutioner.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/executioner/DefaultOpExecutioner.java index 57606e452..aea251ebd 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/executioner/DefaultOpExecutioner.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/executioner/DefaultOpExecutioner.java @@ -23,7 +23,6 @@ import org.nd4j.autodiff.functions.DifferentialFunction; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.api.buffer.Utf8Buffer; import org.nd4j.linalg.api.environment.Nd4jEnvironment; import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.ndarray.INDArray; @@ -172,7 +171,7 @@ public class DefaultOpExecutioner implements OpExecutioner { @Override public INDArray[] exec(CustomOp op) { - return execAndReturn(op).outputArguments(); + return execAndReturn(op).outputArguments().toArray(new INDArray[0]); } @Override @@ -822,7 +821,7 @@ public class DefaultOpExecutioner implements OpExecutioner { } @Override - public String getString(Utf8Buffer buffer, long index) { + public String getString(DataBuffer buffer, long index) { throw new UnsupportedOperationException(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/executioner/OpExecutioner.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/executioner/OpExecutioner.java index 1be417644..c4af57864 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/executioner/OpExecutioner.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/executioner/OpExecutioner.java @@ -20,7 +20,6 @@ import lombok.NonNull; import org.bytedeco.javacpp.Pointer; import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.api.buffer.Utf8Buffer; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArrayStatistics; import org.nd4j.linalg.api.ops.*; @@ -32,8 +31,6 @@ import org.nd4j.linalg.api.rng.Random; import org.nd4j.linalg.api.shape.LongShapeDescriptor; import org.nd4j.linalg.api.shape.TadPack; import org.nd4j.linalg.cache.TADManager; -import org.nd4j.linalg.primitives.Pair; -import org.nd4j.linalg.profiler.OpProfiler; import org.nd4j.linalg.profiler.ProfilerConfig; import java.util.List; @@ -411,7 +408,7 @@ public interface OpExecutioner { * @param index * @return */ - String getString(Utf8Buffer buffer, long index); + String getString(DataBuffer buffer, long index); /** * Temporary hook diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BiasAdd.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BiasAdd.java index 3487cc216..c80f9acf1 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BiasAdd.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BiasAdd.java @@ -69,6 +69,8 @@ public class BiasAdd extends DynamicCustomOp { super.initFromTensorFlow(nodeDef, initWith, attributesForNode, graph); if(attributesForNode.containsKey("data_format")){ nchw = "NCHW".equalsIgnoreCase(attributesForNode.get("data_format").getS().toStringUtf8()); + } else { + nchw = false; //TF default is NHWC } bArguments.clear(); bArguments.add(nchw); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BroadcastMax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BroadcastMax.java index 252ad2dd7..35181bc23 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BroadcastMax.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BroadcastMax.java @@ -81,11 +81,6 @@ public class BroadcastMax extends BaseBroadcastOp { throw new NoOpNameFoundException("No ONNX op name found for: " + getClass().getName()); } - @Override - public String tensorflowName() { - return "max"; - } - @Override public List doDiff(List f1) { return null; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BroadcastMin.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BroadcastMin.java index 8a7234532..c8cac0b6c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BroadcastMin.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BroadcastMin.java @@ -81,11 +81,6 @@ public class BroadcastMin extends BaseBroadcastOp { throw new NoOpNameFoundException("No ONNX op name found for: " + getClass().getName()); } - @Override - public String tensorflowName() { - return "min"; - } - @Override public List doDiff(List f1) { return null; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BroadcastMulOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BroadcastMulOp.java index 33aa7b176..406b3f6f2 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BroadcastMulOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BroadcastMulOp.java @@ -75,11 +75,6 @@ public class BroadcastMulOp extends BaseBroadcastOp { throw new NoOpNameFoundException("No ONNX op name found for: " + getClass().getName()); } - @Override - public String tensorflowName() { - return "mul"; - } - @Override public List doDiff(List f1) { return null; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BroadcastSubOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BroadcastSubOp.java index e060db4b6..f4e93ea22 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BroadcastSubOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BroadcastSubOp.java @@ -75,11 +75,6 @@ public class BroadcastSubOp extends BaseBroadcastOp { throw new NoOpNameFoundException("No onnx op opName found for " + opName()); } - @Override - public String tensorflowName(){ - throw new NoOpNameFoundException("No tensorflow op opName found for " + opName()); - } - @Override public List doDiff(List f1) { return null; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/ResizeArea.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/ResizeArea.java new file mode 100644 index 000000000..8acf558e9 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/ResizeArea.java @@ -0,0 +1,114 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit, K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.impl.image; + +import lombok.NoArgsConstructor; +import lombok.NonNull; +import lombok.val; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.imports.graphmapper.tf.TFGraphMapper; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.factory.Nd4j; +import org.tensorflow.framework.AttrValue; +import org.tensorflow.framework.GraphDef; +import org.tensorflow.framework.NodeDef; + +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; + +@NoArgsConstructor +public class ResizeArea extends DynamicCustomOp { + protected boolean alignCorners = false; + protected Integer height = null; + protected Integer width = null; + + public ResizeArea(@NonNull SameDiff sd, @NonNull SDVariable image, int height, int width, + boolean alignCorners) { + super(sd, image); + this.alignCorners = alignCorners; + this.height = height; + this.width = width; + addArgs(); + } + + public ResizeArea(@NonNull INDArray x, INDArray z, int height, int width, + boolean alignCorners) { + super(new INDArray[]{x}, new INDArray[]{z}); + this.alignCorners = alignCorners; + this.height = height; + this.width = width; + addArgs(); + } + + @Override + public String opName() { + return "resize_area"; + } + + @Override + public String tensorflowName() { + return "ResizeArea"; + } + + @Override + public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { + TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); + + val attrC = attributesForNode.get("align_corners"); + this.alignCorners = attrC != null ? attrC.getB() : false; + + addArgs(); + } + + protected void addArgs() { + iArguments.clear(); + if(height != null && width != null){ + INDArray size = Nd4j.createFromArray(new int[]{height,width}); + addInputArgument(size); + //iArguments.add(Long.valueOf(height)); + //iArguments.add(Long.valueOf(width)); + } + addBArgument(alignCorners); + } + + @Override + public Map propertiesForFunction() { + Map ret = new LinkedHashMap<>(); + ret.put("alignCorners", alignCorners); + ret.put("height", height); + ret.put("width", width); + return ret; + } + + @Override + public List doDiff(List f1) { + throw new UnsupportedOperationException(); + } + + @Override + public List calculateOutputDataTypes(List inputDataTypes){ + Preconditions.checkState(inputDataTypes != null && (inputDataTypes.size() == 1 || inputDataTypes.size() == 2), + "Expected 1 or 2 input datatypes for %s, got %s", getClass(), inputDataTypes); + return Collections.singletonList(DataType.FLOAT); + } +} + diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/Conv1DConfig.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/Conv1DConfig.java index 196876cb2..c62341f28 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/Conv1DConfig.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/Conv1DConfig.java @@ -86,7 +86,7 @@ public class Conv1DConfig extends BaseConvolutionConfig { ret.put("s", s); ret.put("p", p); ret.put("d", d); - ret.put("isSameMode", paddingMode); + ret.put("paddingMode", paddingMode); ret.put("dataFormat", dataFormat); return ret; } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/TensorMmul.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/TensorMmul.java index d0c1bae38..008a065ef 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/TensorMmul.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/TensorMmul.java @@ -16,6 +16,8 @@ package org.nd4j.linalg.api.ops.impl.reduce; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.shade.guava.primitives.Ints; import org.nd4j.shade.guava.primitives.Longs; import lombok.NoArgsConstructor; @@ -325,7 +327,8 @@ public class TensorMmul extends DynamicCustomOp { } @Override - public String tensorflowName() { - return "matmul"; + public List calculateOutputDataTypes(List inputDataTypes){ + Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 2, "Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes); + return Collections.singletonList(inputDataTypes.get(0)); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bp/PowBp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bp/PowBp.java new file mode 100644 index 000000000..c46414f79 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bp/PowBp.java @@ -0,0 +1,45 @@ +package org.nd4j.linalg.api.ops.impl.reduce.bp; + +import lombok.NoArgsConstructor; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp; +import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.BaseArithmeticBackpropOp; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +@NoArgsConstructor +public class PowBp extends BaseDynamicTransformOp { + + public PowBp(SameDiff sameDiff, SDVariable x, SDVariable y, SDVariable dLdz) { + super(sameDiff,new SDVariable[]{x,y,dLdz}, false); + } + + public PowBp(INDArray x, INDArray y, INDArray dLdz, + INDArray dLdx, INDArray dLdy) { + super(new INDArray[]{x,y, dLdz}, new INDArray[]{dLdx, dLdy}); + } + + @Override + public String opName() { + return "Pow_bp"; + } + + @Override + public boolean isInplaceCall() { + return false; + } + + @Override + public List calculateOutputDataTypes(List dataTypes){ + Preconditions.checkState(dataTypes != null && dataTypes.size() == 3, "Expected exactly 3 input datatypes for %s, got input %s", getClass(), dataTypes); + //Gradient types: same as input + return Arrays.asList(arg(0).dataType(), arg(1).dataType()); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/Pow.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/Pow.java index 8aafce3d1..08ead2683 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/Pow.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/Pow.java @@ -19,7 +19,9 @@ package org.nd4j.linalg.api.ops.impl.scalar; import lombok.val; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; import org.nd4j.imports.NoOpNameFoundException; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseScalarOp; import org.nd4j.linalg.api.ops.BaseTransformOp; @@ -29,6 +31,7 @@ import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.NodeDef; import java.util.Arrays; +import java.util.Collections; import java.util.List; import java.util.Map; @@ -89,9 +92,8 @@ public class Pow extends BaseScalarOp { } @Override - public List doDiff(List i_v1) { + public List doDiff(List i_v1) { SDVariable g = f().powDerivative(arg(), this.pow).mul(i_v1.get(0)); return Arrays.asList(g); } - } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/RectifiedLinearDerivative.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/RectifiedLinearDerivative.java index 7e4d0fa09..0ee0c07f7 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/RectifiedLinearDerivative.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/RectifiedLinearDerivative.java @@ -19,8 +19,13 @@ public class RectifiedLinearDerivative extends DynamicCustomOp { super(sd, new SDVariable[]{input, gradient}); } - public RectifiedLinearDerivative(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output){ + public RectifiedLinearDerivative(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output) { + this(input, gradient, output, 0.0); + } + + public RectifiedLinearDerivative(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output, double scalar){ super(new INDArray[]{input, gradient}, wrapOrNull(output)); + addTArgument(scalar); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/summarystats/Variance.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/summarystats/Variance.java index 2b5a49682..504012703 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/summarystats/Variance.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/summarystats/Variance.java @@ -88,6 +88,8 @@ public class Variance extends BaseReduceOp { return 0; } + + @Override public String opName() { return "var"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Pow.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Pow.java index df41438fe..e155a4f2a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Pow.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Pow.java @@ -62,11 +62,14 @@ public class Pow extends DynamicCustomOp { //dL/da = b*a^(b-1) * dL/dy //dL/db = a^b * log(a) * dL/dy - SDVariable a = arg(0); + /*SDVariable a = arg(0); SDVariable b = arg(1); SDVariable dlda = b.mul(sameDiff.math().pow(a,b.sub(1))).mul(f1.get(0)); SDVariable dldb = outputVariable().mul(sameDiff.math().log(a)).mul(f1.get(0)); - return Arrays.asList(dlda, dldb); + return Arrays.asList(dlda, dldb);*/ + + SDVariable[] g = f().powBp(arg(0), arg(1), f1.get(0)); + return Arrays.asList(g); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/RSubOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/RSubOp.java index 5b6833fd8..12c852949 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/RSubOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/RSubOp.java @@ -61,16 +61,10 @@ public class RSubOp extends BaseDynamicTransformOp { throw new NoOpNameFoundException("No ONNX op name found for: " + getClass().getName()); } - @Override - public String tensorflowName() { - return "sub"; - } - public RSubOp( INDArray[] inputs, INDArray[] outputs) { super(inputs, outputs); } - @Override public List doDiff(List i_v) { return f().rsubBp(larg(), rarg(), i_v.get(0)); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMax.java index 6d6798701..fb4dc1c80 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMax.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMax.java @@ -61,7 +61,7 @@ public class UnsortedSegmentMax extends DynamicCustomOp { @Override public List calculateOutputDataTypes(List inputDataTypes){ - Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 2, "Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes); + Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 3, "Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes); return Collections.singletonList(inputDataTypes.get(0)); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMean.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMean.java index f51b94218..78774d3da 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMean.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMean.java @@ -61,7 +61,7 @@ public class UnsortedSegmentMean extends DynamicCustomOp { @Override public List calculateOutputDataTypes(List inputDataTypes){ - Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 2, "Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes); + Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 3, "Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes); return Collections.singletonList(inputDataTypes.get(0)); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMin.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMin.java index 1b885676e..cc97c3ddb 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMin.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMin.java @@ -61,7 +61,7 @@ public class UnsortedSegmentMin extends DynamicCustomOp { @Override public List calculateOutputDataTypes(List inputDataTypes){ - Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 2, "Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes); + Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 3, "Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes); return Collections.singletonList(inputDataTypes.get(0)); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentProd.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentProd.java index b2e254fb7..4f18b4cec 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentProd.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentProd.java @@ -61,7 +61,7 @@ public class UnsortedSegmentProd extends DynamicCustomOp { @Override public List calculateOutputDataTypes(List inputDataTypes){ - Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 2, "Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes); + Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 3, "Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes); return Collections.singletonList(inputDataTypes.get(0)); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentSqrtN.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentSqrtN.java index ef34e9f81..e995ec427 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentSqrtN.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentSqrtN.java @@ -60,7 +60,7 @@ public class UnsortedSegmentSqrtN extends DynamicCustomOp { @Override public List calculateOutputDataTypes(List inputDataTypes){ - Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 2, "Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes); + Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 3, "Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes); List out = new ArrayList<>(); for( int i=0; i calculateOutputDataTypes(List inputDataTypes){ - Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 2, "Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes); + Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 3, "Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes); //TODO Allow customizing output type return Collections.singletonList(Nd4j.defaultFloatingPointType()); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/custom/DistributionUniform.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/custom/DistributionUniform.java index 682d7c230..4b104d08b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/custom/DistributionUniform.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/custom/DistributionUniform.java @@ -63,10 +63,15 @@ public class DistributionUniform extends DynamicCustomOp { addArgs(); } - public DistributionUniform(INDArray shape, INDArray out, double min, double max){ + public DistributionUniform(INDArray shape, INDArray out, double min, double max) { + this(shape, out, min, max, null); + } + + public DistributionUniform(INDArray shape, INDArray out, double min, double max, DataType dataType){ super(null, new INDArray[]{shape}, new INDArray[]{out}, Arrays.asList(min, max), (List)null); this.min = min; this.max = max; + this.dataType = dataType; } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/custom/RandomGamma.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/custom/RandomGamma.java new file mode 100644 index 000000000..d7ad376f0 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/custom/RandomGamma.java @@ -0,0 +1,84 @@ + +/* ****************************************************************************** + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.random.custom; + +import lombok.NoArgsConstructor; +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.imports.descriptors.properties.adapters.DataTypeAdapter; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.tensorflow.framework.AttrValue; +import org.tensorflow.framework.GraphDef; +import org.tensorflow.framework.NodeDef; + +import java.util.Collections; +import java.util.List; +import java.util.Map; + +@NoArgsConstructor +public class RandomGamma extends DynamicCustomOp { + + public RandomGamma(@NonNull INDArray shape, @NonNull INDArray alpha, INDArray beta, + int... seeds) { + if (beta != null) { + addInputArgument(shape,alpha,beta); + } + addInputArgument(shape,alpha); + addIArgument(seeds); + } + + public RandomGamma(@NonNull INDArray shape, @NonNull INDArray alpha, INDArray beta) { + + this(shape,alpha,beta,0,0); + } + + public RandomGamma(@NonNull SameDiff sameDiff, @NonNull SDVariable shape, + @NonNull SDVariable alpha, SDVariable beta, int... seeds) { + super(null, sameDiff, beta != null ? new SDVariable[]{shape, alpha, beta} : + new SDVariable[]{shape, alpha}); + addIArgument(seeds); + } + + @Override + public String opName() { + return "random_gamma"; + } + + @Override + public String tensorflowName() { + return "RandomGamma"; + } + + private DataType outputDataType = DataType.FLOAT; + + @Override + public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { + if(attributesForNode.containsKey("alpha")) { + outputDataType = DataTypeAdapter.dtypeConv(attributesForNode.get("alpha").getType()); + } + } + + @Override + public List calculateOutputDataTypes(List inputDataTypes){ + Preconditions.checkState(inputDataTypes != null, "Expected exactly input datatypes for %s, got null", getClass()); + return Collections.singletonList(outputDataType); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/custom/RandomPoisson.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/custom/RandomPoisson.java new file mode 100644 index 000000000..b407ca47d --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/custom/RandomPoisson.java @@ -0,0 +1,79 @@ + +/* ****************************************************************************** + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.random.custom; + +import lombok.NoArgsConstructor; +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.imports.descriptors.properties.adapters.DataTypeAdapter; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.rng.Random; +import org.tensorflow.framework.AttrValue; +import org.tensorflow.framework.GraphDef; +import org.tensorflow.framework.NodeDef; + +import java.util.Collections; +import java.util.List; +import java.util.Map; + +@NoArgsConstructor +public class RandomPoisson extends DynamicCustomOp { + + private DataType outputDataType = DataType.FLOAT; + + public RandomPoisson(@NonNull INDArray shape, @NonNull INDArray rate, int... seeds) { + addInputArgument(shape, rate); + addIArgument(seeds); + } + + public RandomPoisson(@NonNull INDArray shape, @NonNull INDArray rate) { + this(shape, rate, 0,0); + } + + public RandomPoisson(@NonNull SameDiff sameDiff, @NonNull SDVariable shape, @NonNull SDVariable rate, int... seeds) { + super(null, sameDiff, new SDVariable[]{shape, rate}); + addIArgument(seeds); + } + + @Override + public String opName() { + return "random_poisson"; + } + + @Override + public String[] tensorflowNames() { + return new String[]{"RandomPoisson","RandomPoissonV2"}; + } + + @Override + public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { + if(attributesForNode.containsKey("dtype")) { + outputDataType = DataTypeAdapter.dtypeConv(attributesForNode.get("dtype").getType()); + } + } + + @Override + public List calculateOutputDataTypes(List inputDataTypes){ + Preconditions.checkState(inputDataTypes.size() == 2, "Expected exactly 2 input datatypes for %s, got %s", + getClass(), inputDataTypes.size()); + return Collections.singletonList(outputDataType); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/custom/RandomShuffle.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/custom/RandomShuffle.java new file mode 100644 index 000000000..b08970f11 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/custom/RandomShuffle.java @@ -0,0 +1,63 @@ + +/* ****************************************************************************** + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.random.custom; + +import lombok.NoArgsConstructor; +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +import java.util.Collections; +import java.util.List; + +@NoArgsConstructor +public class RandomShuffle extends DynamicCustomOp { + + public RandomShuffle(@NonNull INDArray value, int... seeds) { + addInputArgument(value); + addIArgument(seeds); + } + + public RandomShuffle(@NonNull INDArray value) { + this(value, 0, 0); + } + + public RandomShuffle(@NonNull SameDiff sameDiff, @NonNull SDVariable value, int...seeds) { + super(null, sameDiff, new SDVariable[]{value}); + addIArgument(seeds); + } + + @Override + public String opName() { + return "random_shuffle"; + } + + @Override + public String tensorflowName() { + return "RandomShuffle"; + } + + @Override + public List calculateOutputDataTypes(List inputDataTypes){ + Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 1, "Expected exactly 1 input datatype for %s, got %s", getClass(), inputDataTypes); + return Collections.singletonList(inputDataTypes.get(0)); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/util/PrintAffinity.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/util/PrintAffinity.java new file mode 100644 index 000000000..d21e55916 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/util/PrintAffinity.java @@ -0,0 +1,43 @@ +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg.api.ops.util; + +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.factory.Nd4j; + +/** + * This is a wrapper for PrintAffinity op that just prints out affinity & locality status of INDArray + * + * @author raver119@gmail.com + */ +public class PrintAffinity extends DynamicCustomOp { + + public PrintAffinity() { + // + } + + public PrintAffinity(INDArray array) { + inputArguments.add(array); + } + + @Override + public String opName() { + return "print_affinity"; + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/util/PrintVariable.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/util/PrintVariable.java new file mode 100644 index 000000000..abbf88f15 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/util/PrintVariable.java @@ -0,0 +1,66 @@ +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg.api.ops.util; + +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.factory.Nd4j; + +/** + * This is a wrapper for PrintVariable op that just prints out Variable to the stdout + * + * @author raver119@gmail.com + */ +public class PrintVariable extends DynamicCustomOp { + + public PrintVariable() { + // + } + + public PrintVariable(INDArray array, boolean printSpecial) { + inputArguments.add(array); + bArguments.add(printSpecial); + } + + public PrintVariable(INDArray array) { + this(array, false); + } + + public PrintVariable(INDArray array, String message, boolean printSpecial) { + this(array, Nd4j.create(message), printSpecial); + } + + public PrintVariable(INDArray array, String message) { + this(array, Nd4j.create(message), false); + } + + public PrintVariable(INDArray array, INDArray message, boolean printSpecial) { + this(array, printSpecial); + Preconditions.checkArgument(message.isS(), "Message argument should have String data type, but got [" + message.dataType() +"] instead"); + inputArguments.add(message); + } + + public PrintVariable(INDArray array, INDArray message) { + this(array, message, false); + } + + @Override + public String opName() { + return "print_variable"; + } +} diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/BaseDataBuffer.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/buffer/BaseDataBuffer.java similarity index 61% rename from nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/BaseDataBuffer.java rename to nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/buffer/BaseDataBuffer.java index 15249acc9..5ccbc54cc 100644 --- a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/BaseDataBuffer.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/buffer/BaseDataBuffer.java @@ -81,12 +81,10 @@ public abstract class BaseDataBuffer implements DataBuffer { protected transient DataBuffer wrappedDataBuffer; protected transient long workspaceGenerationId = 0L; - //protected Collection referencing = Collections.synchronizedSet(new HashSet()); - //protected boolean isPersist = false; protected AllocationMode allocationMode; - protected transient Pointer pointer; - protected transient Indexer indexer; - //protected AtomicBoolean dirty = new AtomicBoolean(false); + + protected transient Indexer indexer = null; + protected transient Pointer pointer = null; protected transient boolean attached = false; protected transient MemoryWorkspace parentWorkspace; @@ -94,7 +92,6 @@ public abstract class BaseDataBuffer implements DataBuffer { // Allocator-related stuff. Moved down here to avoid opType casting. protected transient DataBuffer originalBuffer; protected transient long originalOffset = 0; - protected transient Long trackingPoint; protected transient boolean constant = false; protected transient boolean released = false; @@ -203,7 +200,6 @@ public abstract class BaseDataBuffer implements DataBuffer { this.originalOffset = offset; // + underlyingBuffer.originalOffset(); } - pointer = underlyingBuffer.pointer(); setIndexer(underlyingBuffer.indexer()); } @@ -217,378 +213,6 @@ public abstract class BaseDataBuffer implements DataBuffer { return originalBuffer; } - /** - * - * @param data - * @param copy - */ - public BaseDataBuffer(float[] data, boolean copy, long offset) { - this(data, copy); - this.offset = offset; - this.originalOffset = offset; - this.length = data.length - offset; - this.underlyingLength = data.length; - - } - - public BaseDataBuffer(float[] data, boolean copy, long offset, MemoryWorkspace workspace) { - this(data, copy, workspace); - this.offset = offset; - this.originalOffset = offset; - this.length = data.length - offset; - this.underlyingLength = data.length; - - } - - /** - * - * @param data - * @param copy - */ - public BaseDataBuffer(float[] data, boolean copy) { - allocationMode = AllocUtil.getAllocationModeFromContext(); - initTypeAndSize(); - - pointer = new FloatPointer(data); - - setIndexer(FloatIndexer.create((FloatPointer) pointer)); - //wrappedBuffer = pointer.asByteBuffer(); - - length = data.length; - underlyingLength = data.length; - } - - public BaseDataBuffer(float[] data, boolean copy, MemoryWorkspace workspace) { - allocationMode = AllocUtil.getAllocationModeFromContext(); - length = data.length; - underlyingLength = data.length; - attached = true; - parentWorkspace = workspace; - - initTypeAndSize(); - - //log.info("Allocating FloatPointer from array of {} elements", data.length); - - pointer = workspace.alloc(data.length * getElementSize(), dataType(), false).asFloatPointer().put(data); - workspaceGenerationId = workspace.getGenerationId(); - setIndexer(FloatIndexer.create((FloatPointer) pointer)); - //wrappedBuffer = pointer.asByteBuffer(); - } - - public BaseDataBuffer(double[] data, boolean copy, MemoryWorkspace workspace) { - allocationMode = AllocUtil.getAllocationModeFromContext(); - length = data.length; - underlyingLength = data.length; - attached = true; - parentWorkspace = workspace; - - initTypeAndSize(); - - //log.info("Allocating FloatPointer from array of {} elements", data.length); - - pointer = workspace.alloc(data.length * getElementSize(), dataType(), false).asDoublePointer().put(data); - workspaceGenerationId = workspace.getGenerationId(); - indexer = DoubleIndexer.create((DoublePointer) pointer); - //wrappedBuffer = pointer.asByteBuffer(); - } - - - public BaseDataBuffer(int[] data, boolean copy, MemoryWorkspace workspace) { - allocationMode = AllocUtil.getAllocationModeFromContext(); - length = data.length; - underlyingLength = data.length; - attached = true; - parentWorkspace = workspace; - - initTypeAndSize(); - - //log.info("Allocating FloatPointer from array of {} elements", data.length); - - pointer = workspace.alloc(data.length * getElementSize(), dataType(), false).asIntPointer().put(data); - workspaceGenerationId = workspace.getGenerationId(); - indexer = IntIndexer.create((IntPointer) pointer); - //wrappedBuffer = pointer.asByteBuffer(); - } - - public BaseDataBuffer(long[] data, boolean copy, MemoryWorkspace workspace) { - allocationMode = AllocUtil.getAllocationModeFromContext(); - length = data.length; - underlyingLength = data.length; - attached = true; - parentWorkspace = workspace; - - initTypeAndSize(); - - //log.info("Allocating FloatPointer from array of {} elements", data.length); - - pointer = workspace.alloc(data.length * getElementSize(), dataType(), false).asLongPointer().put(data); - workspaceGenerationId = workspace.getGenerationId(); - indexer = LongIndexer.create((LongPointer) pointer); - //wrappedBuffer = pointer.asByteBuffer(); - } - - - /** - * - * @param data - * @param copy - */ - public BaseDataBuffer(double[] data, boolean copy, long offset) { - this(data, copy); - this.offset = offset; - this.originalOffset = offset; - this.underlyingLength = data.length; - this.length = underlyingLength - offset; - } - - public BaseDataBuffer(double[] data, boolean copy, long offset, MemoryWorkspace workspace) { - this(data, copy, workspace); - this.offset = offset; - this.originalOffset = offset; - this.underlyingLength = data.length; - this.length = underlyingLength - offset; - } - - /** - * - * @param data - * @param copy - */ - public BaseDataBuffer(double[] data, boolean copy) { - allocationMode = AllocUtil.getAllocationModeFromContext(); - initTypeAndSize(); - - pointer = new DoublePointer(data); - indexer = DoubleIndexer.create((DoublePointer) pointer); - //wrappedBuffer = pointer.asByteBuffer(); - - length = data.length; - underlyingLength = data.length; - } - - - /** - * - * @param data - * @param copy - */ - public BaseDataBuffer(int[] data, boolean copy, long offset) { - this(data, copy); - this.offset = offset; - this.originalOffset = offset; - this.length = data.length - offset; - this.underlyingLength = data.length; - } - - /** - * - * @param data - * @param copy - */ - public BaseDataBuffer(int[] data, boolean copy) { - allocationMode = AllocUtil.getAllocationModeFromContext(); - initTypeAndSize(); - - pointer = new IntPointer(data); - setIndexer(IntIndexer.create((IntPointer) pointer)); - - length = data.length; - underlyingLength = data.length; - - // // log.info("Creating new buffer of size: {}; dtype: {}; B", data.length, dataType()); - } - - /** - * - * @param data - * @param copy - */ - public BaseDataBuffer(long[] data, boolean copy) { - allocationMode = AllocUtil.getAllocationModeFromContext(); - initTypeAndSize(); - - pointer = new LongPointer(data); - setIndexer(LongIndexer.create((LongPointer) pointer)); - - length = data.length; - underlyingLength = data.length; - } - - /** - * - * @param data - */ - public BaseDataBuffer(double[] data) { - this(data, true); - } - - /** - * - * @param data - */ - public BaseDataBuffer(int[] data) { - this(data, true); - } - - /** - * - * @param data - */ - public BaseDataBuffer(float[] data) { - this(data, true); - } - - public BaseDataBuffer(float[] data, MemoryWorkspace workspace) { - this(data, true, workspace); - } - - /** - * - * @param length - * @param elementSize - */ - public BaseDataBuffer(int length, int elementSize, long offset) { - this(length, elementSize); - this.offset = offset; - this.originalOffset = offset; - this.length = length - offset; - this.underlyingLength = length; - } - - /** - * - * @param length - * @param elementSize - */ - public BaseDataBuffer(long length, int elementSize) { - if (length < 1) - throw new IllegalArgumentException("Length must be >= 1"); - initTypeAndSize(); - allocationMode = AllocUtil.getAllocationModeFromContext(); - this.length = length; - this.underlyingLength = length; - this.elementSize = (byte) elementSize; - - if (dataType() == DataType.DOUBLE) { - pointer = new DoublePointer(length); - indexer = DoubleIndexer.create((DoublePointer) pointer); - } else if (dataType() == DataType.FLOAT) { - pointer = new FloatPointer(length); - setIndexer(FloatIndexer.create((FloatPointer) pointer)); - } else if (dataType() == DataType.INT) { - pointer = new IntPointer(length); - setIndexer(IntIndexer.create((IntPointer) pointer)); - } else if (dataType() == DataType.LONG) { - pointer = new LongPointer(length); - setIndexer(LongIndexer.create((LongPointer) pointer)); - } else if (dataType() == DataType.SHORT) { - pointer = new ShortPointer(length); - setIndexer(ShortIndexer.create((ShortPointer) pointer)); - } else if (dataType() == DataType.BYTE) { - pointer = new BytePointer(length); - setIndexer(ByteIndexer.create((BytePointer) pointer)); - } else if (dataType() == DataType.UBYTE) { - pointer = new BytePointer(length); - setIndexer(UByteIndexer.create((BytePointer) pointer)); - } else if (dataType() == DataType.UTF8) { - pointer = new LongPointer(length); - setIndexer(LongIndexer.create((LongPointer) pointer)); - } - - // log.info("Creating new buffer of size: {}; dtype: {}; C", length, dataType()); - } - - /** - * Create a data buffer from - * the given length - * - * @param buffer - * @param length - */ - public BaseDataBuffer(ByteBuffer buffer, long length, long offset) { - this(buffer, length); - this.offset = offset; - this.originalOffset = offset; - this.underlyingLength = length; - this.length = length - offset; - - } - - /** - * Create a data buffer from - * the given length - * - * @param buffer - * @param length - */ - public BaseDataBuffer(ByteBuffer buffer, long length) { - if (length < 1) - throw new IllegalArgumentException("Length must be >= 1"); - initTypeAndSize(); - - this.length = length; - allocationMode = AllocUtil.getAllocationModeFromContext(); - - switch (dataType()){ - case DOUBLE: - pointer = new DoublePointer(buffer.asDoubleBuffer()); - setIndexer(DoubleIndexer.create((DoublePointer) pointer)); - break; - case FLOAT: - pointer = new FloatPointer(buffer.asFloatBuffer()); - setIndexer(FloatIndexer.create((FloatPointer) pointer)); - break; - case HALF: - pointer = new ShortPointer(buffer.asShortBuffer()); - setIndexer(HalfIndexer.create((ShortPointer) pointer)); - break; - case LONG: - pointer = new LongPointer(buffer.asLongBuffer()); - setIndexer(LongIndexer.create((LongPointer) pointer)); - break; - case INT: - pointer = new IntPointer(buffer.asIntBuffer()); - setIndexer(IntIndexer.create((IntPointer) pointer)); - break; - case SHORT: - pointer = new ShortPointer(buffer.asShortBuffer()); - setIndexer(ShortIndexer.create((ShortPointer) pointer)); - break; - case UBYTE: //Fall through - case BYTE: - pointer = new BytePointer(buffer); - setIndexer(UByteIndexer.create((BytePointer)pointer)); - break; - case BOOL: - pointer = new BooleanPointer(length()); - setIndexer(BooleanIndexer.create((BooleanPointer) pointer)); - break; - case UTF8: - pointer = new BytePointer(length()); - setIndexer(ByteIndexer.create((BytePointer) pointer)); - break; - case BFLOAT16: - pointer = new ShortPointer(length()); - setIndexer(Bfloat16Indexer.create((ShortPointer) pointer)); - break; - case UINT16: - pointer = new ShortPointer(length()); - setIndexer(UShortIndexer.create((ShortPointer) pointer)); - break; - case UINT32: - pointer = new IntPointer(length()); - // FIXME: we need unsigned indexer here - setIndexer(IntIndexer.create((IntPointer) pointer)); - break; - case UINT64: - pointer = new LongPointer(length()); - // FIXME: we need unsigned indexer here - setIndexer(LongIndexer.create((LongPointer) pointer)); - break; - } - -// log.info("Creating new buffer of size: {}; dtype: {}; D", length, dataType()); - } //sets the nio wrapped buffer (allows to be overridden for other use cases like cuda) protected void setNioBuffer() { @@ -598,17 +222,6 @@ public abstract class BaseDataBuffer implements DataBuffer { } - - /** - * - * @param data - * @param length - */ - public BaseDataBuffer(byte[] data, long length) { - this(ByteBuffer.wrap(data), length); - } - - /** * Returns the indexer for the buffer * @@ -662,7 +275,6 @@ public abstract class BaseDataBuffer implements DataBuffer { @Override @Deprecated public void persist() { - //isPersist = true; throw new UnsupportedOperationException(); } @@ -678,230 +290,10 @@ public abstract class BaseDataBuffer implements DataBuffer { throw new UnsupportedOperationException(); } - private void fillPointerWithZero() { + protected void fillPointerWithZero() { Pointer.memset(this.pointer(), 0, getElementSize() * length()); } - /** - * Instantiate a buffer with the given length - * - * @param length the length of the buffer - */ - protected BaseDataBuffer(long length) { - this(length, true); - } - - protected BaseDataBuffer(long length, boolean initialize) { - if (length < 0) - throw new IllegalArgumentException("Length must be >= 0"); - initTypeAndSize(); - this.length = length; - this.underlyingLength = length; - allocationMode = AllocUtil.getAllocationModeFromContext(); - if (length < 0) - throw new IllegalArgumentException("Unable to create a buffer of length <= 0"); - - if (dataType() == DataType.DOUBLE) { - pointer = new DoublePointer(length()); - indexer = DoubleIndexer.create((DoublePointer) pointer); - if (initialize) - fillPointerWithZero(); - } else if (dataType() == DataType.FLOAT) { - pointer = new FloatPointer(length()); - setIndexer(FloatIndexer.create((FloatPointer) pointer)); - - if (initialize) - fillPointerWithZero(); - - } else if (dataType() == DataType.HALF) { - pointer = new ShortPointer(length()); - setIndexer(HalfIndexer.create((ShortPointer) pointer)); - - if (initialize) - fillPointerWithZero(); - } else if (dataType() == DataType.BFLOAT16) { - pointer = new ShortPointer(length()); - setIndexer(Bfloat16Indexer.create((ShortPointer) pointer)); - - if (initialize) - fillPointerWithZero(); - } else if (dataType() == DataType.INT) { - pointer = new IntPointer(length()); - setIndexer(IntIndexer.create((IntPointer) pointer)); - if (initialize) - fillPointerWithZero(); - } else if (dataType() == DataType.LONG) { - pointer = new LongPointer(length()); - setIndexer(LongIndexer.create((LongPointer) pointer)); - - if (initialize) - fillPointerWithZero(); - } else if (dataType() == DataType.BYTE) { - pointer = new BytePointer(length()); - setIndexer(ByteIndexer.create((BytePointer) pointer)); - - if (initialize) - fillPointerWithZero(); - } else if (dataType() == DataType.SHORT) { - pointer = new ShortPointer(length()); - setIndexer(ShortIndexer.create((ShortPointer) pointer)); - - if (initialize) - fillPointerWithZero(); - } else if (dataType() == DataType.UBYTE) { - pointer = new BytePointer(length()); - setIndexer(UByteIndexer.create((BytePointer) pointer)); - - if (initialize) - fillPointerWithZero(); - } else if (dataType() == DataType.UINT16) { - pointer = new ShortPointer(length()); - setIndexer(UShortIndexer.create((ShortPointer) pointer)); - - if (initialize) - fillPointerWithZero(); - } else if (dataType() == DataType.UINT32) { - pointer = new IntPointer(length()); - // FIXME: we need unsigned indexer here - setIndexer(IntIndexer.create((IntPointer) pointer)); - - if (initialize) - fillPointerWithZero(); - } else if (dataType() == DataType.UINT64) { - pointer = new LongPointer(length()); - // FIXME: we need unsigned indexer here - setIndexer(LongIndexer.create((LongPointer) pointer)); - - if (initialize) - fillPointerWithZero(); - } else if (dataType() == DataType.BOOL) { - pointer = new BooleanPointer(length()); - setIndexer(BooleanIndexer.create((BooleanPointer) pointer)); - - if (initialize) - fillPointerWithZero(); - } else if (dataType() == DataType.UTF8) { - pointer = new BytePointer(length()); - setIndexer(ByteIndexer.create((BytePointer) pointer)); - - if (initialize) - fillPointerWithZero(); - } - - //// log.info("Creating new buffer of size: {}; dtype: {}; A", length, dataType()); - } - - protected BaseDataBuffer(long length, boolean initialize, MemoryWorkspace workspace) { - if (length < 1) - throw new IllegalArgumentException("Length must be >= 1"); - initTypeAndSize(); - this.length = length; - this.underlyingLength = length; - allocationMode = AllocUtil.getAllocationModeFromContext(); - - - - if (length < 0) - throw new IllegalArgumentException("Unable to create a buffer of length <= 0"); - - if (dataType() == DataType.DOUBLE) { - attached = true; - parentWorkspace = workspace; - - pointer = workspace.alloc(length * getElementSize(), dataType(), initialize).asDoublePointer(); //new DoublePointer(length()); - indexer = DoubleIndexer.create((DoublePointer) pointer); - - } else if (dataType() == DataType.FLOAT) { - attached = true; - parentWorkspace = workspace; - - pointer = workspace.alloc(length * getElementSize(), dataType(), initialize).asFloatPointer(); //new FloatPointer(length()); - setIndexer(FloatIndexer.create((FloatPointer) pointer)); - - } else if (dataType() == DataType.HALF) { - attached = true; - parentWorkspace = workspace; - - pointer = workspace.alloc(length * getElementSize(), dataType(), initialize).asShortPointer(); //new FloatPointer(length()); - setIndexer(HalfIndexer.create((ShortPointer) pointer)); - - } else if (dataType() == DataType.BFLOAT16) { - attached = true; - parentWorkspace = workspace; - - pointer = workspace.alloc(length * getElementSize(), dataType(), initialize).asShortPointer(); //new FloatPointer(length()); - setIndexer(Bfloat16Indexer.create((ShortPointer) pointer)); - } else if (dataType() == DataType.INT) { - attached = true; - parentWorkspace = workspace; - - pointer = workspace.alloc(length * getElementSize(), dataType(), initialize).asIntPointer(); //new IntPointer(length()); - setIndexer(IntIndexer.create((IntPointer) pointer)); - - } else if (dataType() == DataType.UINT32) { - attached = true; - parentWorkspace = workspace; - - // FIXME: need unsigned indexer here - pointer = workspace.alloc(length * getElementSize(), dataType(), initialize).asIntPointer(); //new IntPointer(length()); - setIndexer(IntIndexer.create((IntPointer) pointer)); - - } else if (dataType() == DataType.UINT64) { - attached = true; - parentWorkspace = workspace; - - // FIXME: need unsigned indexer here - pointer = workspace.alloc(length * getElementSize(), dataType(), initialize).asLongPointer(); //new IntPointer(length()); - setIndexer(LongIndexer.create((LongPointer) pointer)); - - } else if (dataType() == DataType.LONG) { - attached = true; - parentWorkspace = workspace; - - pointer = workspace.alloc(length * getElementSize(), dataType(), initialize).asLongPointer(); //new LongPointer(length()); - setIndexer(LongIndexer.create((LongPointer) pointer)); - } else if (dataType() == DataType.BYTE) { - attached = true; - parentWorkspace = workspace; - - pointer = workspace.alloc(length * getElementSize(), dataType(), initialize).asBytePointer(); //new LongPointer(length()); - setIndexer(ByteIndexer.create((BytePointer) pointer)); - } else if (dataType() == DataType.UBYTE) { - attached = true; - parentWorkspace = workspace; - - pointer = workspace.alloc(length * getElementSize(), dataType(), initialize).asBytePointer(); //new LongPointer(length()); - setIndexer(UByteIndexer.create((BytePointer) pointer)); - } else if (dataType() == DataType.UINT16) { - attached = true; - parentWorkspace = workspace; - - pointer = workspace.alloc(length * getElementSize(), dataType(), initialize).asShortPointer(); //new IntPointer(length()); - setIndexer(UShortIndexer.create((ShortPointer) pointer)); - - } else if (dataType() == DataType.SHORT) { - attached = true; - parentWorkspace = workspace; - - pointer = workspace.alloc(length * getElementSize(), dataType(), initialize).asShortPointer(); //new LongPointer(length()); - setIndexer(ShortIndexer.create((ShortPointer) pointer)); - } else if (dataType() == DataType.BOOL) { - attached = true; - parentWorkspace = workspace; - - pointer = workspace.alloc(length * getElementSize(), dataType(), initialize).asBoolPointer(); //new LongPointer(length()); - setIndexer(BooleanIndexer.create((BooleanPointer) pointer)); - } else if (dataType() == DataType.UTF8) { - attached = true; - parentWorkspace = workspace; - - pointer = workspace.alloc(length * getElementSize(), dataType(), initialize).asLongPointer(); //new LongPointer(length()); - setIndexer(LongIndexer.create((LongPointer) pointer)); - } - - workspaceGenerationId = workspace.getGenerationId(); - - } @Override public void copyAtStride(DataBuffer buf, long n, long stride, long yStride, long offset, long yOffset) { @@ -930,6 +322,9 @@ public abstract class BaseDataBuffer implements DataBuffer { //return referencing; } + public abstract Pointer addressPointer(); + + /* @Override public Pointer addressPointer() { if (released) @@ -937,7 +332,8 @@ public abstract class BaseDataBuffer implements DataBuffer { if (offset() > 0) { Pointer ret; - final long retAddress = pointer().address() + getElementSize() * offset(); + // offset is accounted at native side + final long retAddress = pointer().address(); // directly set address at construction since Pointer.address has not setter. if (dataType() == DataType.DOUBLE) { ret = new DoublePointer(pointer()) { @@ -976,13 +372,14 @@ public abstract class BaseDataBuffer implements DataBuffer { } return pointer(); } + */ @Override public long address() { if (released) throw new IllegalStateException("You can't use DataBuffer once it was released"); - return pointer().address() + getElementSize() * offset(); + return pointer().address(); } @Override @@ -1273,7 +670,7 @@ public abstract class BaseDataBuffer implements DataBuffer { try { UByteIndexer u = (UByteIndexer) indexer; for (int i = 0; i < length(); i++) { - dos.writeByte(u.get(offset() + i)); + dos.writeByte(u.get(i)); } } catch (IOException e) { throw new RuntimeException(e); @@ -1431,29 +828,29 @@ public abstract class BaseDataBuffer implements DataBuffer { } switch (dataType()) { case FLOAT: - return ((FloatIndexer) indexer).get(offset() + i); + return ((FloatIndexer) indexer).get(i); case UINT32: case INT: - return ((IntIndexer) indexer).get(offset() + i); + return ((IntIndexer) indexer).get(i); case BFLOAT16: - return ((Bfloat16Indexer) indexer).get(offset() + i); + return ((Bfloat16Indexer) indexer).get(i); case HALF: - return ((HalfIndexer) indexer).get(offset() + i); + return ((HalfIndexer) indexer).get(i); case UINT16: - return ((UShortIndexer) indexer).get(offset() + i); + return ((UShortIndexer) indexer).get(i); case SHORT: - return ((ShortIndexer) indexer).get(offset() + i); + return ((ShortIndexer) indexer).get(i); case UINT64: case LONG: - return ((LongIndexer) indexer).get(offset() + i); + return ((LongIndexer) indexer).get(i); case BOOL: - return ((BooleanIndexer) indexer).get(offset() + i) ? 1.0 : 0.0; + return ((BooleanIndexer) indexer).get(i) ? 1.0 : 0.0; case DOUBLE: - return ((DoubleIndexer) indexer).get(offset() + i); + return ((DoubleIndexer) indexer).get(i); case BYTE: - return ((ByteIndexer) indexer).get(offset() + i); + return ((ByteIndexer) indexer).get(i); case UBYTE: - return ((UByteIndexer) indexer).get(offset() + i); + return ((UByteIndexer) indexer).get(i); default: throw new UnsupportedOperationException("Cannot get double value from buffer of type " + dataType()); } @@ -1466,29 +863,29 @@ public abstract class BaseDataBuffer implements DataBuffer { switch (dataType()) { case FLOAT: - return (long) ((FloatIndexer) indexer).get(offset() + i); + return (long) ((FloatIndexer) indexer).get(i); case DOUBLE: - return (long) ((DoubleIndexer) indexer).get(offset() + i); + return (long) ((DoubleIndexer) indexer).get(i); case BFLOAT16: - return (long) ((Bfloat16Indexer) indexer).get(offset() + i); + return (long) ((Bfloat16Indexer) indexer).get(i); case HALF: - return (long) ((HalfIndexer) indexer).get(offset() + i); + return (long) ((HalfIndexer) indexer).get( i); case UINT64: case LONG: - return ((LongIndexer) indexer).get(offset() + i); + return ((LongIndexer) indexer).get(i); case UINT32: case INT: - return (long) ((IntIndexer) indexer).get(offset() + i); + return (long) ((IntIndexer) indexer).get(i); case UINT16: - return (long) ((UShortIndexer) indexer).get(offset() + i); + return (long) ((UShortIndexer) indexer).get(i); case SHORT: - return (long) ((ShortIndexer) indexer).get(offset() + i); + return (long) ((ShortIndexer) indexer).get(i); case BYTE: - return (long) ((ByteIndexer) indexer).get(offset() + i); + return (long) ((ByteIndexer) indexer).get(i); case UBYTE: - return (long) ((UByteIndexer) indexer).get(offset() + i); + return (long) ((UByteIndexer) indexer).get(i); case BOOL: - return ((BooleanIndexer) indexer).get(offset() + i) ? 1L : 0L; + return ((BooleanIndexer) indexer).get(i) ? 1L : 0L; default: throw new UnsupportedOperationException("Cannot get long value from buffer of type " + dataType()); } @@ -1505,26 +902,26 @@ public abstract class BaseDataBuffer implements DataBuffer { switch (dataType()) { case DOUBLE: - return (short) ((DoubleIndexer) indexer).get(offset() + i); + return (short) ((DoubleIndexer) indexer).get(i); case BFLOAT16: - return (short) ((Bfloat16Indexer) indexer).get(offset() + i); + return (short) ((Bfloat16Indexer) indexer).get(i); case HALF: - return (short) ((HalfIndexer) indexer).get(offset() + i); + return (short) ((HalfIndexer) indexer).get(i); case BOOL: - return (short) (((BooleanIndexer) indexer).get(offset() + i) ? 1 : 0); + return (short) (((BooleanIndexer) indexer).get(i) ? 1 : 0); case UINT32: case INT: - return (short) ((IntIndexer) indexer).get(offset() + i); + return (short) ((IntIndexer) indexer).get(i); case UINT16: case SHORT: - return ((ShortIndexer) indexer).get(offset() + i); + return ((ShortIndexer) indexer).get(i); case BYTE: - return (short) ((ByteIndexer) indexer).get(offset() + i); + return (short) ((ByteIndexer) indexer).get(i); case UINT64: case LONG: - return (short) ((LongIndexer) indexer).get(offset() + i); + return (short) ((LongIndexer) indexer).get(i); case FLOAT: - return (short) ((FloatIndexer) indexer).get(offset() + i); + return (short) ((FloatIndexer) indexer).get(i); default: throw new UnsupportedOperationException("Cannot get short value from buffer of type " + dataType()); } @@ -1546,29 +943,29 @@ public abstract class BaseDataBuffer implements DataBuffer { switch (dataType()) { case DOUBLE: - return (float) ((DoubleIndexer) indexer).get(offset() + i); + return (float) ((DoubleIndexer) indexer).get(i); case BOOL: - return ((BooleanIndexer) indexer).get(offset() + i) ? 1.f : 0.f; + return ((BooleanIndexer) indexer).get(i) ? 1.f : 0.f; case UINT32: case INT: - return (float) ((IntIndexer) indexer).get(offset() + i); + return (float) ((IntIndexer) indexer).get(i); case UINT16: - return ((UShortIndexer) indexer).get(offset() + i); + return ((UShortIndexer) indexer).get(i); case SHORT: - return (float) ((ShortIndexer) indexer).get(offset() + i); + return (float) ((ShortIndexer) indexer).get(i); case BFLOAT16: - return ((Bfloat16Indexer) indexer).get(offset() + i); + return ((Bfloat16Indexer) indexer).get(i); case HALF: - return ((HalfIndexer) indexer).get(offset() + i); + return ((HalfIndexer) indexer).get(i); case UBYTE: - return (float) ((UByteIndexer) indexer).get(offset() + i); + return (float) ((UByteIndexer) indexer).get(i); case BYTE: - return (float) ((ByteIndexer) indexer).get(offset() + i); + return (float) ((ByteIndexer) indexer).get(i); case UINT64: case LONG: - return (float) ((LongIndexer) indexer).get(offset() + i); + return (float) ((LongIndexer) indexer).get(i); case FLOAT: - return ((FloatIndexer) indexer).get(offset() + i); + return ((FloatIndexer) indexer).get(i); default: throw new UnsupportedOperationException("Cannot get float value from buffer of type " + dataType()); } @@ -1581,29 +978,29 @@ public abstract class BaseDataBuffer implements DataBuffer { switch (dataType()) { case DOUBLE: - return (int) ((DoubleIndexer) indexer).get(offset() + i); + return (int) ((DoubleIndexer) indexer).get(i); case BOOL: - return ((BooleanIndexer) indexer).get(offset() + i) ? 1 : 0; + return ((BooleanIndexer) indexer).get(i) ? 1 : 0; case UINT32: case INT: - return ((IntIndexer) indexer).get(offset() + i); + return ((IntIndexer) indexer).get(i); case BFLOAT16: - return (int) ((Bfloat16Indexer) indexer).get(offset() + i); + return (int) ((Bfloat16Indexer) indexer).get(i); case HALF: - return (int) ((HalfIndexer) indexer).get(offset() + i); + return (int) ((HalfIndexer) indexer).get(i); case UINT16: - return ((UShortIndexer) indexer).get(offset() + i); + return ((UShortIndexer) indexer).get(i); case SHORT: - return ((ShortIndexer) indexer).get(offset() + i); + return ((ShortIndexer) indexer).get(i); case UBYTE: - return ((UByteIndexer) indexer).get(offset() + i); + return ((UByteIndexer) indexer).get(i); case BYTE: - return ((ByteIndexer) indexer).get(offset() + i); + return ((ByteIndexer) indexer).get(i); case UINT64: case LONG: - return (int) ((LongIndexer) indexer).get(offset() + i); + return (int) ((LongIndexer) indexer).get(i); case FLOAT: - return (int) ((FloatIndexer) indexer).get(offset() + i); + return (int) ((FloatIndexer) indexer).get(i); default: throw new UnsupportedOperationException("Cannot get integer value from buffer of type " + dataType()); } @@ -1623,79 +1020,7 @@ public abstract class BaseDataBuffer implements DataBuffer { return getFloat(i); } - public void pointerIndexerByCurrentType(DataType currentType) { - switch (currentType) { - case UINT64: - pointer = new LongPointer(length()); - setIndexer(LongIndexer.create((LongPointer) pointer)); - type = DataType.UINT64; - break; - case LONG: - pointer = new LongPointer(length()); - setIndexer(LongIndexer.create((LongPointer) pointer)); - type = DataType.LONG; - break; - case UINT32: - pointer = new IntPointer(length()); - setIndexer(IntIndexer.create((IntPointer) pointer)); - type = DataType.UINT32; - break; - case INT: - pointer = new IntPointer(length()); - setIndexer(IntIndexer.create((IntPointer) pointer)); - type = DataType.INT; - break; - case UINT16: - pointer = new ShortPointer(length()); - setIndexer(UShortIndexer.create((ShortPointer) pointer)); - type = DataType.UINT16; - break; - case SHORT: - pointer = new ShortPointer(length()); - setIndexer(ShortIndexer.create((ShortPointer) pointer)); - type = DataType.SHORT; - break; - case UBYTE: - pointer = new BytePointer(length()); - setIndexer(UByteIndexer.create((BytePointer) pointer)); - type = DataType.UBYTE; - break; - case BYTE: - pointer = new BytePointer(length()); - setIndexer(ByteIndexer.create((BytePointer) pointer)); - type = DataType.BYTE; - break; - case DOUBLE: - pointer = new DoublePointer(length()); - indexer = DoubleIndexer.create((DoublePointer) pointer); - type = DataType.DOUBLE; - break; - case FLOAT: - pointer = new FloatPointer(length()); - setIndexer(FloatIndexer.create((FloatPointer) pointer)); - type = DataType.FLOAT; - break; - case BFLOAT16: - pointer = new ShortPointer(length()); - setIndexer(Bfloat16Indexer.create((ShortPointer) pointer)); - type = DataType.BFLOAT16; - break; - case HALF: - pointer = new ShortPointer(length()); - setIndexer(HalfIndexer.create((ShortPointer) pointer)); - type = DataType.HALF; - break; - case BOOL: - pointer = new BooleanPointer(length()); - setIndexer(BooleanIndexer.create((BooleanPointer) pointer)); - type = DataType.BOOL; - break; - case COMPRESSED: - break; - default: - throw new UnsupportedOperationException(); - } - } + public abstract void pointerIndexerByCurrentType(DataType currentType); public void putByDestinationType(long i, Number element, DataType globalType) { if (globalType == DataType.INT || type == DataType.INT || globalType == DataType.UINT16 || globalType == DataType.UBYTE || globalType == DataType.SHORT|| globalType == DataType.BYTE || globalType == DataType.BOOL) { @@ -1722,47 +1047,43 @@ public abstract class BaseDataBuffer implements DataBuffer { switch (dataType()) { case BOOL: - ((BooleanIndexer) indexer).put(offset() + i, element == 0.0 ? false : true); + ((BooleanIndexer) indexer).put(i, element == 0.0 ? false : true); break; case BYTE: - ((ByteIndexer) indexer).put(offset() + i, (byte) element); + ((ByteIndexer) indexer).put(i, (byte) element); break; case UBYTE: - ((UByteIndexer) indexer).put(offset() + i, (int) element); + ((UByteIndexer) indexer).put(i, (int) element); break; case UINT16: - ((UShortIndexer) indexer).put(offset() + i, (int)element); + ((UShortIndexer) indexer).put(i, (int)element); break; case SHORT: - ((ShortIndexer) indexer).put(offset() + i, (short) element); + ((ShortIndexer) indexer).put(i, (short) element); break; case UINT32: case INT: - ((IntIndexer) indexer).put(offset() + i, (int) element); + ((IntIndexer) indexer).put(i, (int) element); break; case UINT64: case LONG: - ((LongIndexer) indexer).put(offset() + i, (long) element); + ((LongIndexer) indexer).put(i, (long) element); break; case BFLOAT16: - ((Bfloat16Indexer) indexer).put(offset() + i, element); + ((Bfloat16Indexer) indexer).put(i, element); break; case HALF: - ((HalfIndexer) indexer).put(offset() + i, element); + ((HalfIndexer) indexer).put(i, element); break; case FLOAT: - ((FloatIndexer) indexer).put(offset() + i, element); + ((FloatIndexer) indexer).put(i, element); break; case DOUBLE: - ((DoubleIndexer) indexer).put(offset() + i, element); + ((DoubleIndexer) indexer).put(i, element); break; default: throw new IllegalStateException("Unsupported type: " + dataType()); } - - if (i == length) { - length++; - } } @Override @@ -1772,47 +1093,43 @@ public abstract class BaseDataBuffer implements DataBuffer { switch (dataType()) { case BOOL: - ((BooleanIndexer) indexer).put(offset() + i, element > 0.0); + ((BooleanIndexer) indexer).put(i, element > 0.0); break; case BYTE: - ((ByteIndexer) indexer).put(offset() + i, (byte) element); + ((ByteIndexer) indexer).put(i, (byte) element); break; case UBYTE: - ((UByteIndexer) indexer).put(offset() + i, (short) element); + ((UByteIndexer) indexer).put(i, (short) element); break; case UINT16: - ((UShortIndexer) indexer).put(offset() + i, (int) element); + ((UShortIndexer) indexer).put(i, (int) element); break; case SHORT: - ((ShortIndexer) indexer).put(offset() + i, (short) element); + ((ShortIndexer) indexer).put(i, (short) element); break; case UINT32: case INT: - ((IntIndexer) indexer).put(offset() + i, (int) element); + ((IntIndexer) indexer).put(i, (int) element); break; case UINT64: case LONG: - ((LongIndexer) indexer).put(offset() + i, (long) element); + ((LongIndexer) indexer).put(i, (long) element); break; case BFLOAT16: - ((Bfloat16Indexer) indexer).put(offset() + i, (float) element); + ((Bfloat16Indexer) indexer).put(i, (float) element); break; case HALF: - ((HalfIndexer) indexer).put(offset() + i, (float) element); + ((HalfIndexer) indexer).put(i, (float) element); break; case FLOAT: - ((FloatIndexer) indexer).put(offset() + i, (float) element); + ((FloatIndexer) indexer).put(i, (float) element); break; case DOUBLE: - ((DoubleIndexer) indexer).put(offset() + i, element); + ((DoubleIndexer) indexer).put(i, element); break; default: throw new UnsupportedOperationException("Unsupported data type: " + dataType()); } - - if (i == length) { - length++; - } } @Override @@ -1822,47 +1139,43 @@ public abstract class BaseDataBuffer implements DataBuffer { switch (dataType()) { case BOOL: - ((BooleanIndexer) indexer).put(offset() + i, element == 0 ? false : true); + ((BooleanIndexer) indexer).put(i, element == 0 ? false : true); break; case BYTE: - ((ByteIndexer) indexer).put(offset() + i, (byte) element); + ((ByteIndexer) indexer).put(i, (byte) element); break; case UBYTE: - ((UByteIndexer) indexer).put(offset() + i, element); + ((UByteIndexer) indexer).put(i, element); break; case UINT16: - ((UShortIndexer) indexer).put(offset() + i, element); + ((UShortIndexer) indexer).put(i, element); break; case SHORT: - ((ShortIndexer) indexer).put(offset() + i, (short) element); + ((ShortIndexer) indexer).put(i, (short) element); break; case UINT32: case INT: - ((IntIndexer) indexer).put(offset() + i, element); + ((IntIndexer) indexer).put(i, element); break; case UINT64: case LONG: - ((LongIndexer) indexer).put(offset() + i, element); + ((LongIndexer) indexer).put(i, element); break; case BFLOAT16: - ((Bfloat16Indexer) indexer).put(offset() + i, element); + ((Bfloat16Indexer) indexer).put(i, element); break; case HALF: - ((HalfIndexer) indexer).put(offset() + i, element); + ((HalfIndexer) indexer).put(i, element); break; case FLOAT: - ((FloatIndexer) indexer).put(offset() + i, element); + ((FloatIndexer) indexer).put(i, element); break; case DOUBLE: - ((DoubleIndexer) indexer).put(offset() + i, element); + ((DoubleIndexer) indexer).put(i, element); break; default: throw new UnsupportedOperationException("Unsupported data type: " + dataType()); } - - if (i == length) { - length++; - } } @Override @@ -1872,47 +1185,43 @@ public abstract class BaseDataBuffer implements DataBuffer { switch (dataType()) { case BOOL: - ((BooleanIndexer) indexer).put(offset() + i, element); + ((BooleanIndexer) indexer).put(i, element); break; case BYTE: - ((ByteIndexer) indexer).put(offset() + i, element ? (byte)1 : (byte) 0); + ((ByteIndexer) indexer).put(i, element ? (byte)1 : (byte) 0); break; case UBYTE: - ((UByteIndexer) indexer).put(offset() + i, element ? (byte)1 : (byte) 0); + ((UByteIndexer) indexer).put(i, element ? (byte)1 : (byte) 0); break; case UINT16: - ((UShortIndexer) indexer).put(offset() + i, element ? 1 : 0); + ((UShortIndexer) indexer).put(i, element ? 1 : 0); break; case SHORT: - ((ShortIndexer) indexer).put(offset() + i, element ? (short) 1 : (short) 0); + ((ShortIndexer) indexer).put(i, element ? (short) 1 : (short) 0); break; case INT: case UINT32: - ((IntIndexer) indexer).put(offset() + i, element ? 1 : 0); + ((IntIndexer) indexer).put(i, element ? 1 : 0); break; case UINT64: case LONG: - ((LongIndexer) indexer).put(offset() + i, element ? 1 : 0); + ((LongIndexer) indexer).put(i, element ? 1 : 0); break; case BFLOAT16: - ((Bfloat16Indexer) indexer).put(offset() + i, element ? 1.0f : 0.0f); + ((Bfloat16Indexer) indexer).put(i, element ? 1.0f : 0.0f); break; case HALF: - ((HalfIndexer) indexer).put(offset() + i, element ? 1.0f : 0.0f); + ((HalfIndexer) indexer).put(i, element ? 1.0f : 0.0f); break; case FLOAT: - ((FloatIndexer) indexer).put(offset() + i, element ? 1.0f : 0.0f); + ((FloatIndexer) indexer).put(i, element ? 1.0f : 0.0f); break; case DOUBLE: - ((DoubleIndexer) indexer).put(offset() + i, element ? 1.0 : 0.0); + ((DoubleIndexer) indexer).put(i, element ? 1.0 : 0.0); break; default: throw new UnsupportedOperationException("Unsupported data type: " + dataType()); } - - if (i == length) { - length++; - } } @Override @@ -1922,47 +1231,43 @@ public abstract class BaseDataBuffer implements DataBuffer { switch (dataType()) { case BOOL: - ((BooleanIndexer) indexer).put(offset() + i, element == 0 ? false : true); + ((BooleanIndexer) indexer).put(i, element == 0 ? false : true); break; case BYTE: - ((ByteIndexer) indexer).put(offset() + i, (byte) element); + ((ByteIndexer) indexer).put(i, (byte) element); break; case UBYTE: - ((UByteIndexer) indexer).put(offset() + i, (short) element); + ((UByteIndexer) indexer).put(i, (short) element); break; case UINT16: - ((UShortIndexer) indexer).put(offset() + i, (int) element); + ((UShortIndexer) indexer).put(i, (int) element); break; case SHORT: - ((ShortIndexer) indexer).put(offset() + i, (short) element); + ((ShortIndexer) indexer).put(i, (short) element); break; case UINT32: case INT: - ((IntIndexer) indexer).put(offset() + i, (int) element); + ((IntIndexer) indexer).put(i, (int) element); break; case UINT64: case LONG: - ((LongIndexer) indexer).put(offset() + i, element); + ((LongIndexer) indexer).put(i, element); break; case BFLOAT16: - ((Bfloat16Indexer) indexer).put(offset() + i, (float) element); + ((Bfloat16Indexer) indexer).put(i, (float) element); break; case HALF: - ((HalfIndexer) indexer).put(offset() + i, (float) element); + ((HalfIndexer) indexer).put(i, (float) element); break; case FLOAT: - ((FloatIndexer) indexer).put(offset() + i, (float) element); + ((FloatIndexer) indexer).put(i, (float) element); break; case DOUBLE: - ((DoubleIndexer) indexer).put(offset() + i, (double) element); + ((DoubleIndexer) indexer).put(i, (double) element); break; default: throw new UnsupportedOperationException("Unsupported data type: " + dataType()); } - - if (i == length) { - length++; - } } @Override @@ -2507,31 +1812,6 @@ public abstract class BaseDataBuffer implements DataBuffer { return originalOffset; } - /** - * Returns tracking point for Allocator - * - * PLEASE NOTE: Suitable & meaningful only for specific backends - * - * @return - */ - @Override - public Long getTrackingPoint() { - if (underlyingDataBuffer() != this) - return underlyingDataBuffer() == null ? trackingPoint : underlyingDataBuffer().getTrackingPoint(); - return trackingPoint; - } - - /** - * Sets tracking point used by Allocator - * - * PLEASE NOTE: Suitable & meaningful only for specific backends - * - * @param trackingPoint - */ - public void setTrackingPoint(Long trackingPoint) { - this.trackingPoint = trackingPoint; - } - /** * This method returns whether this DataBuffer is constant, or not. * Constant buffer means that it modified only during creation time, and then it stays the same for all lifecycle. I.e. used in shape info databuffers. @@ -2595,63 +1875,7 @@ public abstract class BaseDataBuffer implements DataBuffer { return null; } - /** - * Reallocate the native memory of the buffer - * @param length the new length of the buffer - * @return this databuffer - * */ - @Override - public DataBuffer reallocate(long length) { - - Pointer oldPointer = pointer; - if (isAttached()) { - long capacity = length * getElementSize(); - switch (dataType()) { - case DOUBLE: - pointer = getParentWorkspace().alloc(capacity, DataType.DOUBLE, false).asDoublePointer(); - indexer = DoubleIndexer.create((DoublePointer) pointer); - break; - case FLOAT: - pointer = getParentWorkspace().alloc(capacity, DataType.FLOAT, false).asFloatPointer(); - indexer = FloatIndexer.create((FloatPointer) pointer); - break; - case INT: - pointer = getParentWorkspace().alloc(capacity, DataType.INT, false).asIntPointer(); - indexer = IntIndexer.create((IntPointer) pointer); - break; - case LONG: - pointer = getParentWorkspace().alloc(capacity, DataType.LONG, false).asLongPointer(); - indexer = LongIndexer.create((LongPointer) pointer); - break; - } - - workspaceGenerationId = getParentWorkspace().getGenerationId(); - } else { - switch (dataType()) { - case INT: - pointer = new IntPointer(length); - indexer = IntIndexer.create((IntPointer) pointer); - break; - case DOUBLE: - pointer = new DoublePointer(length); - indexer = DoubleIndexer.create((DoublePointer) pointer); - break; - case FLOAT: - pointer = new FloatPointer(length); - indexer = FloatIndexer.create((FloatPointer) pointer); - break; - case LONG: - pointer = new LongPointer(length); - indexer = LongIndexer.create((LongPointer) pointer); - break; - } - } - - Pointer.memcpy(pointer, oldPointer, this.length() * getElementSize()); - this.underlyingLength = length; - this.length = length; - return this; - } + public abstract DataBuffer reallocate(long length); /** * @return the capacity of the buffer diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/DataBuffer.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/buffer/DataBuffer.java similarity index 97% rename from nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/DataBuffer.java rename to nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/buffer/DataBuffer.java index 9b1c2ecec..303f6383d 100644 --- a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/DataBuffer.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/buffer/DataBuffer.java @@ -622,23 +622,6 @@ public interface DataBuffer extends Serializable, AutoCloseable { */ void read(InputStream is, AllocationMode allocationMode, long length, DataType dataType); - /** - * Returns tracking point for Allocator - * - * PLEASE NOTE: Suitable & meaningful only for specific backends - * @return - */ - Long getTrackingPoint(); - - /** - * Sets tracking point used by Allocator - * - * PLEASE NOTE: Suitable & meaningful only for specific backends - * - * @param trackingPoint - */ - void setTrackingPoint(Long trackingPoint); - /** * This method returns whether this DataBuffer is constant, or not. * Constant buffer means that it modified only during creation time, and then it stays the same for all lifecycle. I.e. used in shape info databuffers. diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/DataType.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/buffer/DataType.java similarity index 91% rename from nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/DataType.java rename to nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/buffer/DataType.java index 84715f878..7555bce21 100644 --- a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/DataType.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/buffer/DataType.java @@ -17,14 +17,28 @@ package org.nd4j.linalg.api.buffer; public enum DataType { + DOUBLE, FLOAT, + + @Deprecated HALF, + + @Deprecated LONG, + + @Deprecated INT, + + @Deprecated SHORT, + + @Deprecated UBYTE, + + @Deprecated BYTE, + BOOL, UTF8, COMPRESSED, @@ -34,6 +48,13 @@ public enum DataType { UINT64, UNKNOWN; + public static final DataType FLOAT16 = DataType.HALF; + public static final DataType INT32 = DataType.INT; + public static final DataType INT64 = DataType.LONG; + public static final DataType INT16 = DataType.SHORT; + public static final DataType INT8 = DataType.BYTE; + public static final DataType UINT8 = DataType.UBYTE; + public static DataType fromInt(int type) { switch (type) { @@ -94,7 +115,7 @@ public enum DataType { * Note: Boolean values are considered numerical (0/1)
*/ public boolean isNumerical(){ - return this != UTF8 && this != COMPRESSED && this != UNKNOWN; + return this != UTF8 && this != BOOL && this != COMPRESSED && this != UNKNOWN; } /** diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/DataTypeEx.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/buffer/DataTypeEx.java similarity index 100% rename from nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/DataTypeEx.java rename to nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/buffer/DataTypeEx.java diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/allocation/MemoryStrategy.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/buffer/allocation/MemoryStrategy.java similarity index 100% rename from nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/allocation/MemoryStrategy.java rename to nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/buffer/allocation/MemoryStrategy.java diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/factory/DataBufferFactory.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/buffer/factory/DataBufferFactory.java similarity index 91% rename from nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/factory/DataBufferFactory.java rename to nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/buffer/factory/DataBufferFactory.java index 743f34655..abb674499 100644 --- a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/factory/DataBufferFactory.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/buffer/factory/DataBufferFactory.java @@ -25,6 +25,7 @@ import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; +import java.nio.Buffer; import java.nio.ByteBuffer; /** @@ -60,30 +61,13 @@ public interface DataBufferFactory { DataBuffer create(DataBuffer underlyingBuffer, long offset, long length); /** - * Create int buffer - * @param buffer + * Creates a DataBuffer from java.nio.ByteBuffer + * @param underlyingBuffer + * @param offset * @param length * @return */ - DataBuffer createInt(long offset, ByteBuffer buffer, int length); - - /** - * Create a float data buffer - * @param buffer - * @param length - * @return - */ - DataBuffer createFloat(long offset, ByteBuffer buffer, int length); - - /** - * Creates a double data buffer - * @param buffer - * @param length - * @return - */ - DataBuffer createDouble(long offset, ByteBuffer buffer, int length); - - DataBuffer createLong(ByteBuffer buffer, int length); + DataBuffer create(ByteBuffer underlyingBuffer, DataType type, long length, long offset); /** * Create a double data buffer @@ -289,31 +273,6 @@ public interface DataBufferFactory { */ DataBuffer createInt(long offset, float[] data, boolean copy); - - /** - * Create int buffer - * @param buffer - * @param length - * @return - */ - DataBuffer createInt(ByteBuffer buffer, int length); - - /** - * Create a float data buffer - * @param buffer - * @param length - * @return - */ - DataBuffer createFloat(ByteBuffer buffer, int length); - - /** - * Creates a double data buffer - * @param buffer - * @param length - * @return - */ - DataBuffer createDouble(ByteBuffer buffer, int length); - /** * Create a double data buffer * @@ -459,22 +418,6 @@ public interface DataBufferFactory { DataBuffer createDouble(double[] data); - /** - * Create a double buffer - * @param data - * @param length - * @return - */ - DataBuffer createDouble(byte[] data, int length); - - /** - * Create a double buffer - * @param data - * @param length - * @return - */ - DataBuffer createFloat(byte[] data, int length); - /** * Creates a float data buffer * @@ -816,14 +759,6 @@ public interface DataBufferFactory { */ DataBuffer createHalf(int[] data); - /** - * Creates a half-precision data buffer - * - * @param data the data to create the buffer from - * @return the new buffer - */ - DataBuffer createHalf(long offset, byte[] data, int length); - /** * Creates a half-precision data buffer * @@ -831,22 +766,6 @@ public interface DataBufferFactory { */ DataBuffer createHalf(long offset, int length); - /** - * Creates a half-precision data buffer - * - * @return the new buffer - */ - DataBuffer createHalf(ByteBuffer buffer, int length); - - /** - * Creates a half-precision data buffer - * - * @param data - * @param length - * @return - */ - DataBuffer createHalf(byte[] data, int length); - Class intBufferClass(); @@ -858,4 +777,5 @@ public interface DataBufferFactory { Class doubleBufferClass(); + DataBuffer createUtf8Buffer(byte[] data, long product); } diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/util/AllocUtil.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/buffer/util/AllocUtil.java similarity index 100% rename from nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/util/AllocUtil.java rename to nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/buffer/util/AllocUtil.java diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/util/DataTypeUtil.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/buffer/util/DataTypeUtil.java similarity index 100% rename from nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/util/DataTypeUtil.java rename to nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/buffer/util/DataTypeUtil.java diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/compression/CompressedDataBuffer.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/compression/CompressedDataBuffer.java index c6e5bf904..107a68dd4 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/compression/CompressedDataBuffer.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/compression/CompressedDataBuffer.java @@ -89,6 +89,11 @@ public class CompressedDataBuffer extends BaseDataBuffer { // no-op } + @Override + public Pointer addressPointer() { + return pointer; + } + /** * Drop-in replacement wrapper for BaseDataBuffer.read() method, aware of CompressedDataBuffer * @param s @@ -194,6 +199,15 @@ public class CompressedDataBuffer extends BaseDataBuffer { */ @Override public DataBuffer create(int[] data) { - throw new UnsupportedOperationException("This operation isn't supported for CompressedDataBuffer"); + throw new UnsupportedOperationException("This method isn't supported by CompressedDataBuffer"); + } + + public void pointerIndexerByCurrentType(DataType currentType) { + throw new UnsupportedOperationException("This method isn't supported by CompressedDataBuffer"); + } + + @Override + public DataBuffer reallocate(long length) { + throw new UnsupportedOperationException("This method isn't supported by CompressedDataBuffer"); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/convolution/Convolution.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/convolution/Convolution.java index 9c0645156..ae26633e4 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/convolution/Convolution.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/convolution/Convolution.java @@ -98,7 +98,7 @@ public class Convolution { .build(); Nd4j.getExecutioner().execAndReturn(col2Im); - return col2Im.outputArguments()[0]; + return col2Im.outputArguments().get(0); } public static INDArray col2im(INDArray col, INDArray z, int sH, int sW, int pH, int pW, int kH, int kW, @@ -187,7 +187,7 @@ public class Convolution { .build()).build(); Nd4j.getExecutioner().execAndReturn(im2col); - return im2col.outputArguments()[0]; + return im2col.outputArguments().get(0); } public static INDArray im2col(INDArray img, int kh, int kw, int sy, int sx, int ph, int pw, int dH, int dW, boolean isSameMode, @@ -208,7 +208,7 @@ public class Convolution { .build()).build(); Nd4j.getExecutioner().execAndReturn(im2col); - return im2col.outputArguments()[0]; + return im2col.outputArguments().get(0); } /** @@ -298,7 +298,7 @@ public class Convolution { .build()).build(); Nd4j.getExecutioner().execAndReturn(im2col); - return im2col.outputArguments()[0]; + return im2col.outputArguments().get(0); } /** diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Environment.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Environment.java new file mode 100644 index 000000000..1b788220a --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Environment.java @@ -0,0 +1,120 @@ +/* ****************************************************************************** + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.factory; + +/** + * ND4J backend Environment instance + * + * @author Alex Black + */ +public interface Environment { + + /** BLAS major version number (if applicable) */ + int blasMajorVersion(); + /** BLAS minor version number (if applicable) */ + int blasMinorVersion(); + /** BLAS patch version number (if applicable) */ + int blasPatchVersion(); + + /** Returns true if ND4J is set to verbose mode */ + boolean isVerbose(); + /** Set verbose mode */ + void setVerbose(boolean reallyVerbose); + /** Returns true if ND4J is set to debug mode */ + boolean isDebug(); + /** Returns true if ND4J is set to profiling mode */ + boolean isProfiling(); + /** Returns true if ND4J is set to detecting leaks mode */ + boolean isDetectingLeaks(); + /** Returns true if ND4J is set to debug and verbose mode */ + boolean isDebugAndVerbose(); + + /** Set debug mode */ + void setDebug( boolean reallyDebug); + /** Set profiling mode */ + void setProfiling( boolean reallyProfile); + /** Set leaks detection mode */ + void setLeaksDetector( boolean reallyDetect); + /** Returns true if helpers (cuDNN, DNNL/MKLDNN etc) are allowed */ + boolean helpersAllowed(); + /** Set whether helpers (cuDNN, DNNL/MKLDNN etc) are allowed */ + void allowHelpers(boolean reallyAllow); + + /** Returns the TAD (tensor along dimension) threshold for ops */ + int tadThreshold(); + /** Set the TAD (tensor along dimension) threshold for ops */ + void setTadThreshold(int threshold); + + /** Returns the elementwise threshold for ops */ + int elementwiseThreshold(); + /** Set the elementwise threshold for ops */ + void setElementwiseThreshold(int threshold); + + /** Returns the maximum number of threads for C++ op execution (if applicable) */ + int maxThreads(); + /** Set the maximum number of threads for C++ op execution (if applicable) */ + void setMaxThreads(int max); + + /** Returns the maximum number of master threads for C++ op execution (if applicable) */ + int maxMasterThreads(); + /** Set the maximum number of master threads for C++ op execution (if applicable) */ + void setMaxMasterThreads(int max); + + /** Set the maximum primary memory */ + void setMaxPrimaryMemory(long maxBytes); + /** Set the maximum special memory */ + void setMaxSpecialMemory(long maxBytes); + /** Set the maximum device memory */ + void setMaxDeviceMemory(long maxBytes); + + /** Return true if the backend is a CPU backend, or false otherwise */ + boolean isCPU(); + + /** + * This method allows to set memory limit for a specific group of devices. I.e. CUDA or CPU + * @param group + * @param numBytes + */ + void setGroupLimit(int group, long numBytes); + + /** + * This method allows to set memory limit for a specific device. I.e. GPU_0 + * @param deviceId + * @param numBytes + */ + void setDeviceLimit(int deviceId, long numBytes); + + /** + * This method returns current group limit + * @param group + * @return + */ + long getGroupLimit(int group); + + /** + * This method returns current device limit + * @param deviceId + * @return + */ + long getDeviceLimit(int deviceId); + + /** + * This method returns current allocated amount for a specific device. I.e. GPU_0 + * @param deviceId + * @return + */ + long getDeviceCouner(int deviceId); +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java index 2e2efadda..dae946dba 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java @@ -40,7 +40,6 @@ import org.nd4j.graph.FlatArray; import org.nd4j.linalg.api.blas.params.MMulTranspose; import org.nd4j.linalg.api.buffer.*; import org.nd4j.linalg.api.buffer.factory.DataBufferFactory; -import org.nd4j.linalg.api.buffer.factory.DefaultDataBufferFactory; import org.nd4j.linalg.api.buffer.util.DataTypeUtil; import org.nd4j.linalg.api.concurrency.AffinityManager; import org.nd4j.linalg.api.concurrency.BasicAffinityManager; @@ -537,6 +536,14 @@ public class Nd4j { return ret; } + /** + * Get the backend Environment instance + * @return The backend Environment instance + */ + public static Environment getEnvironment(){ + return backend.getEnvironment(); + } + /** * Get the operation executioner instance. * @@ -1036,16 +1043,7 @@ public class Nd4j { * @return the created buffer */ public static DataBuffer createBuffer(ByteBuffer buffer, DataType type, int length, long offset) { - switch (type) { - case INT: - return DATA_BUFFER_FACTORY_INSTANCE.createInt(offset, buffer, length); - case DOUBLE: - return DATA_BUFFER_FACTORY_INSTANCE.createDouble(offset, buffer, length); - case FLOAT: - return DATA_BUFFER_FACTORY_INSTANCE.createFloat(offset, buffer, length); - default: - throw new IllegalArgumentException("Illegal opType " + type); - } + return DATA_BUFFER_FACTORY_INSTANCE.create(buffer, type, length, offset); } /** @@ -1328,38 +1326,9 @@ public class Nd4j { * @return the created buffer */ public static DataBuffer createBuffer(ByteBuffer buffer, DataType type, int length) { - switch (type) { - case INT: - return DATA_BUFFER_FACTORY_INSTANCE.createInt(buffer, length); - case LONG: - return DATA_BUFFER_FACTORY_INSTANCE.createLong(buffer, length); - case DOUBLE: - return DATA_BUFFER_FACTORY_INSTANCE.createDouble(buffer, length); - case FLOAT: - return DATA_BUFFER_FACTORY_INSTANCE.createFloat(buffer, length); - case HALF: - return DATA_BUFFER_FACTORY_INSTANCE.createHalf(buffer, length); - default: - throw new IllegalArgumentException("Illegal opType " + type); - } + return createBuffer(buffer, type, length, 0); } - /** - * Create a buffer based on the data opType - * - * @param data the data to create the buffer with - * @return the created buffer - */ - public static DataBuffer createBuffer(byte[] data, int length) { - DataBuffer ret; - if (dataType() == DataType.DOUBLE) - ret = DATA_BUFFER_FACTORY_INSTANCE.createDouble(data, length); - else if (dataType() == DataType.HALF) - ret = DATA_BUFFER_FACTORY_INSTANCE.createHalf(data, length); - else - ret = DATA_BUFFER_FACTORY_INSTANCE.createFloat(data, length); - return ret; - } /** * Create a buffer equal of length prod(shape) @@ -2198,6 +2167,7 @@ public class Nd4j { private static String writeStringForArray(INDArray write) { if(write.isView() || !Shape.hasDefaultStridesForShape(write)) write = write.dup(); + String format = "0.000000000000000000E0"; return "{\n" + @@ -3919,16 +3889,6 @@ public class Nd4j { return create(shape, stride); } - /** - * Creates an ndarray with the specified shape - * - * @param rows the rows of the ndarray - * @param columns the columns of the ndarray - * @return the instance - */ - public static INDArray create(int rows, int columns) { - return create(rows, columns, order()); - } /** * Creates an ndarray with the specified shape @@ -4378,13 +4338,6 @@ public class Nd4j { return createUninitialized(shape, Nd4j.order()); } - /** - * See {@link #createUninitialized(long)} - */ - public static INDArray createUninitialized(int length) { - return createUninitialized((long)length); - } - /** * This method creates an *uninitialized* ndarray of specified length and default ordering. * @@ -4420,37 +4373,6 @@ public class Nd4j { ////////////////////// OTHER /////////////////////////////// - /** - * Creates a 2D array with specified number of rows, columns initialized with zero. - * - * @param rows number of rows. - * @param columns number of columns. - * @return the created array. - */ - public static INDArray zeros(long rows, long columns) { - return INSTANCE.zeros(rows, columns); - } - - /** - * Creates a 1D array with the specified number of columns initialized with zero. - * - * @param columns number of columns. - * @return the created array - */ - public static INDArray zeros(int columns) { - return INSTANCE.zeros(columns); - } - - /** - * Creates a 1D array with the specified data tyoe and number of columns initialized with zero. - * - * @param dataType data type. - * @param columns number of columns. - * @return the created array. - */ - public static INDArray zeros(DataType dataType, int columns) { - return INSTANCE.create(dataType, new long[]{columns}, 'c', Nd4j.getMemoryManager().getCurrentWorkspace()); - } /** * Creates an array with the specified data tyoe and shape initialized with zero. @@ -4460,7 +4382,10 @@ public class Nd4j { * @return the created array. */ public static INDArray zeros(DataType dataType, @NonNull long... shape) { - return INSTANCE.create(dataType, shape, 'c', Nd4j.getMemoryManager().getCurrentWorkspace()); + if(shape.length == 0) + return Nd4j.scalar(dataType, 0); + + return INSTANCE.create(dataType, shape, Nd4j.order(), Nd4j.getMemoryManager().getCurrentWorkspace()); } /** @@ -4580,31 +4505,6 @@ public class Nd4j { return INSTANCE.valueArrayOf(rows, columns, value); } - /** - * Creates a row vector with the specified number of columns - * - * @param rows the number of rows in the matrix - * @param columns the columns of the ndarray - * @return the created ndarray - */ - public static INDArray ones(int rows, int columns) { - return INSTANCE.ones(rows, columns); - } - - /** - * Create a 2D array with the given rows, columns and data type initialised with ones. - * - * @param dataType data type - * @param rows rows of the new array. - * @param columns columns of the new arrau. - * @return the created array - */ - public static INDArray ones(DataType dataType, int rows, int columns) { - INDArray ret = INSTANCE.createUninitialized(dataType, new long[]{rows, columns}, Nd4j.order(), Nd4j.getMemoryManager().getCurrentWorkspace()); - ret.assign(1); - return ret; - } - /** * Empty like * @@ -4809,8 +4709,7 @@ public class Nd4j { for (int idx : indexes) { if (idx < 0 || idx >= source.shape()[source.rank() - sourceDimension - 1]) { - throw new IllegalStateException( - "Index can't be < 0 and >= " + source.shape()[source.rank() - sourceDimension - 1]); + throw new IllegalStateException("Index can't be < 0 and >= " + source.shape()[source.rank() - sourceDimension - 1]); } } @@ -5178,7 +5077,7 @@ public class Nd4j { pp.toString(NDARRAY_FACTORY_CLASS)); Class convolutionInstanceClazz = (Class) Class .forName(pp.toString(CONVOLUTION_OPS, DefaultConvolutionInstance.class.getName())); - String defaultName = pp.toString(DATA_BUFFER_OPS, DefaultDataBufferFactory.class.getName()); + String defaultName = pp.toString(DATA_BUFFER_OPS, "org.nd4j.linalg.cpu.nativecpu.buffer.DefaultDataBufferFactory"); Class dataBufferFactoryClazz = (Class) Class .forName(pp.toString(DATA_BUFFER_OPS, defaultName)); Class shapeInfoProviderClazz = (Class) Class @@ -5863,7 +5762,7 @@ public class Nd4j { arr[e] = sb.get(e + pos); } - val buffer = new Utf8Buffer(arr, prod); + val buffer = DATA_BUFFER_FACTORY_INSTANCE.createUtf8Buffer(arr, prod); return Nd4j.create(buffer, shapeOf); } catch (Exception e) { throw new RuntimeException(e); diff --git a/nd4j/nd4j-context/src/main/java/org/nd4j/linalg/factory/Nd4jBackend.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4jBackend.java similarity index 98% rename from nd4j/nd4j-context/src/main/java/org/nd4j/linalg/factory/Nd4jBackend.java rename to nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4jBackend.java index ec4739b86..7575c1238 100644 --- a/nd4j/nd4j-context/src/main/java/org/nd4j/linalg/factory/Nd4jBackend.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4jBackend.java @@ -1,5 +1,6 @@ -/******************************************************************************* +/* ****************************************************************************** * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2019 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -143,6 +144,8 @@ public abstract class Nd4jBackend { */ public abstract Class getNDArrayClass(); + public abstract Environment getEnvironment(); + /** * Loads the best available backend. diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/AllocationsTracker.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/AllocationsTracker.java similarity index 100% rename from nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/AllocationsTracker.java rename to nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/AllocationsTracker.java diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/Deallocatable.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/Deallocatable.java similarity index 100% rename from nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/Deallocatable.java rename to nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/Deallocatable.java diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/Deallocator.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/Deallocator.java similarity index 100% rename from nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/Deallocator.java rename to nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/Deallocator.java diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/DeviceAllocationsTracker.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/DeviceAllocationsTracker.java similarity index 100% rename from nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/DeviceAllocationsTracker.java rename to nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/DeviceAllocationsTracker.java diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/MemoryWorkspace.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/MemoryWorkspace.java similarity index 100% rename from nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/MemoryWorkspace.java rename to nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/MemoryWorkspace.java diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/MemoryWorkspaceManager.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/MemoryWorkspaceManager.java similarity index 100% rename from nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/MemoryWorkspaceManager.java rename to nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/MemoryWorkspaceManager.java diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/conf/WorkspaceConfiguration.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/conf/WorkspaceConfiguration.java similarity index 100% rename from nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/conf/WorkspaceConfiguration.java rename to nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/conf/WorkspaceConfiguration.java diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/deallocation/DeallocatorService.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/deallocation/DeallocatorService.java index 30c68d578..9fae57705 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/deallocation/DeallocatorService.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/deallocation/DeallocatorService.java @@ -30,6 +30,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicLong; /** * This class provides unified management for Deallocatable resources @@ -43,6 +44,8 @@ public class DeallocatorService { private Map referenceMap = new ConcurrentHashMap<>(); private List>> deviceMap = new ArrayList<>(); + private AtomicLong counter = new AtomicLong(0); + public DeallocatorService() { // we need to have at least 2 threads, but for CUDA we'd need at least numDevices threads, due to thread->device affinity int numDevices = Nd4j.getAffinityManager().getNumberOfDevices(); @@ -69,6 +72,10 @@ public class DeallocatorService { } } + public long nextValue() { + return counter.incrementAndGet(); + } + /** * This method adds Deallocatable object instance to tracking system * diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/enums/AllocationKind.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/enums/AllocationKind.java similarity index 100% rename from nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/enums/AllocationKind.java rename to nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/enums/AllocationKind.java diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/enums/AllocationPolicy.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/enums/AllocationPolicy.java similarity index 100% rename from nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/enums/AllocationPolicy.java rename to nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/enums/AllocationPolicy.java diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/enums/DebugMode.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/enums/DebugMode.java similarity index 100% rename from nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/enums/DebugMode.java rename to nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/enums/DebugMode.java diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/enums/LearningPolicy.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/enums/LearningPolicy.java similarity index 100% rename from nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/enums/LearningPolicy.java rename to nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/enums/LearningPolicy.java diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/enums/LocationPolicy.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/enums/LocationPolicy.java similarity index 100% rename from nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/enums/LocationPolicy.java rename to nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/enums/LocationPolicy.java diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/enums/MemoryKind.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/enums/MemoryKind.java similarity index 100% rename from nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/enums/MemoryKind.java rename to nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/enums/MemoryKind.java diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/enums/MirroringPolicy.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/enums/MirroringPolicy.java similarity index 100% rename from nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/enums/MirroringPolicy.java rename to nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/enums/MirroringPolicy.java diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/enums/ResetPolicy.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/enums/ResetPolicy.java similarity index 100% rename from nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/enums/ResetPolicy.java rename to nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/enums/ResetPolicy.java diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/enums/SpillPolicy.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/enums/SpillPolicy.java similarity index 100% rename from nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/enums/SpillPolicy.java rename to nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/enums/SpillPolicy.java diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/pointers/ImmortalFloatPointer.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/pointers/ImmortalFloatPointer.java similarity index 100% rename from nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/pointers/ImmortalFloatPointer.java rename to nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/pointers/ImmortalFloatPointer.java diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/pointers/PagedPointer.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/pointers/PagedPointer.java similarity index 97% rename from nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/pointers/PagedPointer.java rename to nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/pointers/PagedPointer.java index 74d4bbca2..35d36ec64 100644 --- a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/pointers/PagedPointer.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/pointers/PagedPointer.java @@ -63,7 +63,7 @@ public class PagedPointer extends Pointer { public PagedPointer(Pointer pointer, long capacity) { this.originalPointer = pointer; - this.address = pointer.address(); + this.address = pointer == null ? 0 : pointer.address(); this.capacity = capacity; this.limit = capacity; diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/pointers/PointersPair.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/pointers/PointersPair.java similarity index 100% rename from nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/pointers/PointersPair.java rename to nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/pointers/PointersPair.java diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/list/NDArrayList.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/list/NDArrayList.java index 939708291..f3b539d8c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/list/NDArrayList.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/list/NDArrayList.java @@ -18,6 +18,7 @@ package org.nd4j.list; import lombok.NonNull; import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.nd4j.linalg.factory.Nd4j; @@ -46,7 +47,12 @@ public class NDArrayList extends BaseNDArrayList { * @param size the initial size of the array */ public NDArrayList(int size) { - this.container = Nd4j.create(10L); + this(DataType.DOUBLE, size); + } + + public NDArrayList(DataType dataType, int size) { + Preconditions.checkState(size >= 0, "Size must be non-negative - got %s", size); + this.container = Nd4j.create(dataType, Math.max(10L, size)); this.size = size; } @@ -84,6 +90,7 @@ public class NDArrayList extends BaseNDArrayList { * directly, this gives you the relevant subset that reflects the content of the list) * @return the view of the underlying ndarray relative to the collection's real size */ + @Override public INDArray array() { if(isEmpty()) { throw new ND4JIllegalStateException("Array is empty!"); @@ -137,6 +144,8 @@ public class NDArrayList extends BaseNDArrayList { return true; } + + @Override public boolean remove(Object o) { int idx = BooleanIndexing.firstIndex(container,new EqualsCondition((double) o)).getInt(0); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/serde/jackson/shaded/NDArrayTextSerializer.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/serde/jackson/shaded/NDArrayTextSerializer.java index 5e966f850..c395959d3 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/serde/jackson/shaded/NDArrayTextSerializer.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/serde/jackson/shaded/NDArrayTextSerializer.java @@ -17,10 +17,10 @@ package org.nd4j.serde.jackson.shaded; -import org.nd4j.linalg.api.buffer.Utf8Buffer; + +import lombok.val; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.shape.Shape; -import org.nd4j.serde.base64.Nd4jBase64; import org.nd4j.shade.jackson.core.JsonGenerator; import org.nd4j.shade.jackson.databind.JsonSerializer; import org.nd4j.shade.jackson.databind.SerializerProvider; @@ -77,10 +77,9 @@ public class NDArrayTextSerializer extends JsonSerializer { jg.writeNumber(v); break; case UTF8: - Utf8Buffer utf8B = ((Utf8Buffer)arr.data()); - long n = utf8B.getNumWords(); + val n = arr.length(); for( int j=0; j - + org.nd4j nd4j-api diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOps.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOps.java index 741978a3c..1d1b837e7 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOps.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOps.java @@ -16,11 +16,8 @@ package org.nd4j.nativeblas; -import lombok.val; import org.bytedeco.javacpp.*; import org.bytedeco.javacpp.annotation.Cast; -import org.bytedeco.javacpp.indexer.LongIndexer; -import org.nd4j.linalg.api.buffer.Utf8Buffer; /** @@ -53,14 +50,12 @@ public interface NativeOps { */ void execIndexReduceScalar(PointerPointer extraPointers, int opNum, - Pointer x, + OpaqueDataBuffer x, @Cast("Nd4jLong *") LongPointer xShapeInfo, - Pointer dX, @Cast("Nd4jLong *") LongPointer dXShapeInfo, Pointer extraParams, - Pointer z, + OpaqueDataBuffer z, @Cast("Nd4jLong *") LongPointer zShapeInfo, - Pointer dZ, @Cast("Nd4jLong *") LongPointer dZShapeInfo); /** @@ -75,17 +70,16 @@ public interface NativeOps { */ void execIndexReduce(PointerPointer extraPointers, int opNum, - Pointer x, + OpaqueDataBuffer x, @Cast("Nd4jLong *") LongPointer xShapeInfo, - Pointer dX, @Cast("Nd4jLong *") LongPointer dXShapeInfo, Pointer extraParams, - Pointer result, + OpaqueDataBuffer result, @Cast("Nd4jLong *") LongPointer resultShapeInfoBuffer, - Pointer dResult, @Cast("Nd4jLong *") LongPointer dResultShapeInfoBuffer, - Pointer hDimension, @Cast("Nd4jLong *") LongPointer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong *") LongPointer dDimensionShape); + OpaqueDataBuffer hDimension, + @Cast("Nd4jLong *") LongPointer hDimensionShape, + @Cast("Nd4jLong *") LongPointer dDimensionShape); /** * @param opNum @@ -100,38 +94,34 @@ public interface NativeOps { */ void execBroadcast(PointerPointer extraPointers, int opNum, - Pointer x, + OpaqueDataBuffer x, @Cast("Nd4jLong *") LongPointer xShapeInfo, - Pointer dx, @Cast("Nd4jLong *") LongPointer dxShapeInfo, - Pointer y, + OpaqueDataBuffer y, @Cast("Nd4jLong *") LongPointer yShapeInfo, - Pointer dy, @Cast("Nd4jLong *") LongPointer dyShapeInfo, - Pointer result, + OpaqueDataBuffer result, @Cast("Nd4jLong *") LongPointer resultShapeInfo, - Pointer dresult, @Cast("Nd4jLong *") LongPointer dresultShapeInfo, - Pointer hDimension, @Cast("Nd4jLong *") LongPointer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong *") LongPointer dDimensionShape); + OpaqueDataBuffer hDimension, + @Cast("Nd4jLong *") LongPointer hDimensionShape, + @Cast("Nd4jLong *") LongPointer dDimensionShape); void execBroadcastBool(PointerPointer extraPointers, int opNum, - Pointer x, + OpaqueDataBuffer x, @Cast("Nd4jLong *") LongPointer xShapeInfo, - Pointer dx, @Cast("Nd4jLong *") LongPointer dxShapeInfo, - Pointer y, + OpaqueDataBuffer y, @Cast("Nd4jLong *") LongPointer yShapeInfo, - Pointer dy, @Cast("Nd4jLong *") LongPointer dyShapeInfo, - Pointer result, + OpaqueDataBuffer result, @Cast("Nd4jLong *") LongPointer resultShapeInfo, - Pointer dresult, @Cast("Nd4jLong *") LongPointer dresultShapeInfo, Pointer extraParams, - Pointer hDimension, @Cast("Nd4jLong *") LongPointer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong *") LongPointer dDimensionShape); + OpaqueDataBuffer hDimension, + @Cast("Nd4jLong *") LongPointer hDimensionShape, + @Cast("Nd4jLong *") LongPointer dDimensionShape); /** @@ -146,33 +136,27 @@ public interface NativeOps { */ void execPairwiseTransform(PointerPointer extraPointers, int opNum, - Pointer x, + OpaqueDataBuffer x, @Cast("Nd4jLong *") LongPointer xShapeInfo, - Pointer dx, @Cast("Nd4jLong *") LongPointer dxShapeInfo, - Pointer y, + OpaqueDataBuffer y, @Cast("Nd4jLong *") LongPointer yShapeInfo, - Pointer dy, @Cast("Nd4jLong *") LongPointer dyShapeInfo, - Pointer result, + OpaqueDataBuffer result, @Cast("Nd4jLong *") LongPointer resultShapeInfo, - Pointer dresult, @Cast("Nd4jLong *") LongPointer dresultShapeInfo, Pointer extraParams); void execPairwiseTransformBool(PointerPointer extraPointers, int opNum, - Pointer x, + OpaqueDataBuffer x, @Cast("Nd4jLong *") LongPointer xShapeInfo, - Pointer dx, @Cast("Nd4jLong *") LongPointer dxShapeInfo, - Pointer y, + OpaqueDataBuffer y, @Cast("Nd4jLong *") LongPointer yShapeInfo, - Pointer dy, @Cast("Nd4jLong *") LongPointer dyShapeInfo, - Pointer result, + OpaqueDataBuffer result, @Cast("Nd4jLong *") LongPointer resultShapeInfo, - Pointer dresult, @Cast("Nd4jLong *") LongPointer dresultShapeInfo, Pointer extraParams); @@ -186,53 +170,45 @@ public interface NativeOps { */ void execReduceFloat(PointerPointer extraPointers, int opNum, - Pointer x, + OpaqueDataBuffer x, @Cast("Nd4jLong *") LongPointer xShapeInfo, - Pointer dx, @Cast("Nd4jLong *") LongPointer dxShapeInfo, Pointer extraParams, - Pointer result, + OpaqueDataBuffer result, @Cast("Nd4jLong *") LongPointer resultShapeInfo, - Pointer dresult, @Cast("Nd4jLong *") LongPointer dresultShapeInfo); void execReduceSame(PointerPointer extraPointers, int opNum, - Pointer x, + OpaqueDataBuffer x, @Cast("Nd4jLong *") LongPointer xShapeInfo, - Pointer dx, @Cast("Nd4jLong *") LongPointer dxShapeInfo, Pointer extraParams, - Pointer result, + OpaqueDataBuffer result, @Cast("Nd4jLong *") LongPointer resultShapeInfo, - Pointer dresult, @Cast("Nd4jLong *") LongPointer dresultShapeInfo); void execReduceBool(PointerPointer extraPointers, int opNum, - Pointer x, + OpaqueDataBuffer x, @Cast("Nd4jLong *") LongPointer xShapeInfo, - Pointer dx, @Cast("Nd4jLong *") LongPointer dxShapeInfo, Pointer extraParams, - Pointer result, + OpaqueDataBuffer result, @Cast("Nd4jLong *") LongPointer resultShapeInfo, - Pointer dresult, @Cast("Nd4jLong *") LongPointer dresultShapeInfo); void execReduceLong(PointerPointer extraPointers, int opNum, - Pointer x, + OpaqueDataBuffer x, @Cast("Nd4jLong *") LongPointer xShapeInfo, - Pointer dx, @Cast("Nd4jLong *") LongPointer dxShapeInfo, Pointer extraParams, - Pointer result, + OpaqueDataBuffer result, @Cast("Nd4jLong *") LongPointer resultShapeInfo, - Pointer dresult, @Cast("Nd4jLong *") LongPointer dresultShapeInfo); /** @@ -245,60 +221,56 @@ public interface NativeOps { */ void execReduceFloat2(PointerPointer extraPointers, int opNum, - Pointer x, + OpaqueDataBuffer x, @Cast("Nd4jLong *") LongPointer xShapeInfo, - Pointer dx, @Cast("Nd4jLong *") LongPointer dxShapeInfo, Pointer extraParams, - Pointer result, + OpaqueDataBuffer result, @Cast("Nd4jLong *") LongPointer resultShapeInfo, - Pointer dresult, @Cast("Nd4jLong *") LongPointer dresultShapeInfo, - Pointer hDimension, @Cast("Nd4jLong *") LongPointer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong *") LongPointer dDimensionShape); + OpaqueDataBuffer hDimension, + @Cast("Nd4jLong *") LongPointer hDimensionShape, + @Cast("Nd4jLong *") LongPointer dDimensionShape); void execReduceSame2(PointerPointer extraPointers, int opNum, - Pointer x, + OpaqueDataBuffer x, @Cast("Nd4jLong *") LongPointer xShapeInfo, - Pointer dx, @Cast("Nd4jLong *") LongPointer dxShapeInfo, Pointer extraParams, - Pointer result, + OpaqueDataBuffer result, @Cast("Nd4jLong *") LongPointer resultShapeInfo, - Pointer dresult, @Cast("Nd4jLong *") LongPointer dresultShapeInfo, - Pointer hDimension, @Cast("Nd4jLong *") LongPointer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong *") LongPointer dDimensionShape); + OpaqueDataBuffer hDimension, + @Cast("Nd4jLong *") LongPointer hDimensionShape, + @Cast("Nd4jLong *") LongPointer dDimensionShape); void execReduceBool2(PointerPointer extraPointers, int opNum, - Pointer x, + OpaqueDataBuffer x, @Cast("Nd4jLong *") LongPointer xShapeInfo, - Pointer dx, @Cast("Nd4jLong *") LongPointer dxShapeInfo, Pointer extraParams, - Pointer result, + OpaqueDataBuffer result, @Cast("Nd4jLong *") LongPointer resultShapeInfo, - Pointer dresult, @Cast("Nd4jLong *") LongPointer dresultShapeInfo, - Pointer hDimension, @Cast("Nd4jLong *") LongPointer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong *") LongPointer dDimensionShape); + OpaqueDataBuffer hDimension, + @Cast("Nd4jLong *") LongPointer hDimensionShape, + @Cast("Nd4jLong *") LongPointer dDimensionShape); void execReduceLong2(PointerPointer extraPointers, int opNum, - Pointer x, + OpaqueDataBuffer x, @Cast("Nd4jLong *") LongPointer xShapeInfo, - Pointer dx, @Cast("Nd4jLong *") LongPointer dxShapeInfo, Pointer extraParams, - Pointer result, + OpaqueDataBuffer result, @Cast("Nd4jLong *") LongPointer resultShapeInfo, - Pointer dresult, @Cast("Nd4jLong *") LongPointer dresultShapeInfo, - Pointer hDimension, @Cast("Nd4jLong *") LongPointer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong *") LongPointer dDimensionShape); + OpaqueDataBuffer hDimension, + @Cast("Nd4jLong *") LongPointer hDimensionShape, + @Cast("Nd4jLong *") LongPointer dDimensionShape); /** * @param opNum @@ -312,13 +284,16 @@ public interface NativeOps { */ void execReduce3(PointerPointer extraPointers, int opNum, - Pointer x, @Cast("Nd4jLong *") LongPointer xShapeInfo, - Pointer dx, @Cast("Nd4jLong *") LongPointer dxShapeInfo, + OpaqueDataBuffer x, + @Cast("Nd4jLong *") LongPointer xShapeInfo, + @Cast("Nd4jLong *") LongPointer dxShapeInfo, Pointer extraParamsVals, - Pointer y, @Cast("Nd4jLong *") LongPointer yShapeInfo, - Pointer dy, @Cast("Nd4jLong *") LongPointer dyShapeInfo, - Pointer result, @Cast("Nd4jLong *") LongPointer resultShapeInfo, - Pointer dresult, @Cast("Nd4jLong *") LongPointer dresultShapeInfo); + OpaqueDataBuffer y, + @Cast("Nd4jLong *") LongPointer yShapeInfo, + @Cast("Nd4jLong *") LongPointer dyShapeInfo, + OpaqueDataBuffer result, + @Cast("Nd4jLong *") LongPointer resultShapeInfo, + @Cast("Nd4jLong *") LongPointer dresultShapeInfo); /** * @param opNum @@ -329,13 +304,16 @@ public interface NativeOps { * @param yShapeInfo */ void execReduce3Scalar(PointerPointer extraPointers, int opNum, - Pointer x, @Cast("Nd4jLong *") LongPointer xShapeInfo, - Pointer dx, @Cast("Nd4jLong *") LongPointer dxShapeInfo, - Pointer extraParamsVals, - Pointer y, @Cast("Nd4jLong *") LongPointer yShapeInfo, - Pointer dy, @Cast("Nd4jLong *") LongPointer dyShapeInfo, - Pointer z, @Cast("Nd4jLong *") LongPointer zShapeInfo, - Pointer dz, @Cast("Nd4jLong *") LongPointer dzShapeInfo); + OpaqueDataBuffer x, + @Cast("Nd4jLong *") LongPointer xShapeInfo, + @Cast("Nd4jLong *") LongPointer dxShapeInfo, + Pointer extraParamsVals, + OpaqueDataBuffer y, + @Cast("Nd4jLong *") LongPointer yShapeInfo, + @Cast("Nd4jLong *") LongPointer dyShapeInfo, + OpaqueDataBuffer z, + @Cast("Nd4jLong *") LongPointer zShapeInfo, + @Cast("Nd4jLong *") LongPointer dzShapeInfo); /** * @param opNum @@ -351,29 +329,37 @@ public interface NativeOps { */ void execReduce3Tad(PointerPointer extraPointers, int opNum, - Pointer x, @Cast("Nd4jLong *") LongPointer xShapeInfo, - Pointer dx, @Cast("Nd4jLong *") LongPointer dxShapeInfo, + OpaqueDataBuffer x, + @Cast("Nd4jLong *") LongPointer xShapeInfo, + @Cast("Nd4jLong *") LongPointer dxShapeInfo, Pointer extraParamsVals, - Pointer y, @Cast("Nd4jLong *") LongPointer yShapeInfo, - Pointer dy, @Cast("Nd4jLong *") LongPointer dyShapeInfo, - Pointer result, @Cast("Nd4jLong *") LongPointer resultShapeInfoBuffer, - Pointer dresult, @Cast("Nd4jLong *") LongPointer dresultShapeInfoBuffer, - Pointer hDimension, @Cast("Nd4jLong *") LongPointer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong *") LongPointer dDimensionShape, - @Cast("Nd4jLong *") LongPointer tadOnlyShapeInfo, @Cast("Nd4jLong *") LongPointer tadOffsets, - @Cast("Nd4jLong *") LongPointer yTadOnlyShapeInfo, @Cast("Nd4jLong *") LongPointer yTadOffsets); + OpaqueDataBuffer y, + @Cast("Nd4jLong *") LongPointer yShapeInfo, + @Cast("Nd4jLong *") LongPointer dyShapeInfo, + OpaqueDataBuffer result, + @Cast("Nd4jLong *") LongPointer resultShapeInfoBuffer, + @Cast("Nd4jLong *") LongPointer dresultShapeInfoBuffer, + OpaqueDataBuffer hDimension, + @Cast("Nd4jLong *") LongPointer hDimensionShape, + @Cast("Nd4jLong *") LongPointer dDimensionShape, + @Cast("Nd4jLong *") LongPointer tadOnlyShapeInfo, @Cast("Nd4jLong *") LongPointer tadOffsets, + @Cast("Nd4jLong *") LongPointer yTadOnlyShapeInfo, @Cast("Nd4jLong *") LongPointer yTadOffsets); void execReduce3All(PointerPointer extraPointers, int opNum, - Pointer x, @Cast("Nd4jLong *") LongPointer xShapeInfo, - Pointer dx, @Cast("Nd4jLong *") LongPointer dxShapeInfo, + OpaqueDataBuffer x, + @Cast("Nd4jLong *") LongPointer xShapeInfo, + @Cast("Nd4jLong *") LongPointer dxShapeInfo, Pointer extraParamsVals, - Pointer y, @Cast("Nd4jLong *") LongPointer yShapeInfo, - Pointer dy, @Cast("Nd4jLong *") LongPointer dyShapeInfo, - Pointer result, @Cast("Nd4jLong *") LongPointer resultShapeInfoBuffer, - Pointer dresult, @Cast("Nd4jLong *") LongPointer dresultShapeInfoBuffer, - Pointer hDimension, @Cast("Nd4jLong *") LongPointer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong *") LongPointer dDimensionShape, + OpaqueDataBuffer y, + @Cast("Nd4jLong *") LongPointer yShapeInfo, + @Cast("Nd4jLong *") LongPointer dyShapeInfo, + OpaqueDataBuffer result, + @Cast("Nd4jLong *") LongPointer resultShapeInfoBuffer, + @Cast("Nd4jLong *") LongPointer dresultShapeInfoBuffer, + OpaqueDataBuffer hDimension, + @Cast("Nd4jLong *") LongPointer hDimensionShape, + @Cast("Nd4jLong *") LongPointer dDimensionShape, @Cast("Nd4jLong *") LongPointer xTadShape, @Cast("Nd4jLong *") LongPointer xOffsets, @Cast("Nd4jLong *") LongPointer yTadShape, @@ -391,22 +377,28 @@ public interface NativeOps { */ void execScalar(PointerPointer extraPointers, int opNum, - Pointer x, @Cast("Nd4jLong *") LongPointer xShapeInfo, - Pointer dx, @Cast("Nd4jLong *") LongPointer dxShapeInfo, - Pointer result, @Cast("Nd4jLong *") LongPointer resultShapeInfo, - Pointer dresult, @Cast("Nd4jLong *") LongPointer dresultShapeInfo, - Pointer scalar, @Cast("Nd4jLong *") LongPointer scalarShapeInfo, - Pointer dscalar, @Cast("Nd4jLong *") LongPointer dscalarShapeInfo, + OpaqueDataBuffer x, + @Cast("Nd4jLong *") LongPointer xShapeInfo, + @Cast("Nd4jLong *") LongPointer dxShapeInfo, + OpaqueDataBuffer result, + @Cast("Nd4jLong *") LongPointer resultShapeInfo, + @Cast("Nd4jLong *") LongPointer dresultShapeInfo, + OpaqueDataBuffer scalar, + @Cast("Nd4jLong *") LongPointer scalarShapeInfo, + @Cast("Nd4jLong *") LongPointer dscalarShapeInfo, Pointer extraParams); void execScalarBool(PointerPointer extraPointers, int opNum, - Pointer x, @Cast("Nd4jLong *") LongPointer xShapeInfo, - Pointer dx, @Cast("Nd4jLong *") LongPointer dxShapeInfo, - Pointer result, @Cast("Nd4jLong *") LongPointer resultShapeInfo, - Pointer dresult, @Cast("Nd4jLong *") LongPointer dresultShapeInfo, - Pointer scalar, @Cast("Nd4jLong *") LongPointer scalarShapeInfo, - Pointer dscalar, @Cast("Nd4jLong *") LongPointer dscalarShapeInfo, + OpaqueDataBuffer x, + @Cast("Nd4jLong *") LongPointer xShapeInfo, + @Cast("Nd4jLong *") LongPointer dxShapeInfo, + OpaqueDataBuffer result, + @Cast("Nd4jLong *") LongPointer resultShapeInfo, + @Cast("Nd4jLong *") LongPointer dresultShapeInfo, + OpaqueDataBuffer scalar, + @Cast("Nd4jLong *") LongPointer scalarShapeInfo, + @Cast("Nd4jLong *") LongPointer dscalarShapeInfo, Pointer extraParams); /** @@ -418,11 +410,13 @@ public interface NativeOps { */ void execSummaryStatsScalar(PointerPointer extraPointers, int opNum, - Pointer x, @Cast("Nd4jLong *") LongPointer xShapeInfo, - Pointer dx, @Cast("Nd4jLong *") LongPointer dxShapeInfo, + OpaqueDataBuffer x, + @Cast("Nd4jLong *") LongPointer xShapeInfo, + @Cast("Nd4jLong *") LongPointer dxShapeInfo, Pointer extraParams, - Pointer z, @Cast("Nd4jLong *") LongPointer zShapeInfo, - Pointer dz, @Cast("Nd4jLong *") LongPointer dzShapeInfo, + OpaqueDataBuffer z, + @Cast("Nd4jLong *") LongPointer zShapeInfo, + @Cast("Nd4jLong *") LongPointer dzShapeInfo, boolean biasCorrected); /** @@ -436,11 +430,13 @@ public interface NativeOps { */ void execSummaryStats(PointerPointer extraPointers, int opNum, - Pointer x, @Cast("Nd4jLong *") LongPointer xShapeInfo, - Pointer dx, @Cast("Nd4jLong *") LongPointer dxShapeInfo, + OpaqueDataBuffer x, + @Cast("Nd4jLong *") LongPointer xShapeInfo, + @Cast("Nd4jLong *") LongPointer dxShapeInfo, Pointer extraParams, - Pointer result, @Cast("Nd4jLong *") LongPointer resultShapeInfo, - Pointer dresult, @Cast("Nd4jLong *") LongPointer dresultShapeInfo, + OpaqueDataBuffer result, + @Cast("Nd4jLong *") LongPointer resultShapeInfo, + @Cast("Nd4jLong *") LongPointer dresultShapeInfo, boolean biasCorrected); /** @@ -454,17 +450,20 @@ public interface NativeOps { * @param dimensionLength */ void execSummaryStatsTad(PointerPointer extraPointers, - int opNum, - Pointer x, @Cast("Nd4jLong *") LongPointer xShapeInfo, - Pointer dx, @Cast("Nd4jLong *") LongPointer dxShapeInfo, - Pointer extraParams, - Pointer result, @Cast("Nd4jLong *") LongPointer resultShapeInfoBuffer, - Pointer dresult, @Cast("Nd4jLong *") LongPointer dresultShapeInfoBuffer, - Pointer hDimension, @Cast("Nd4jLong *") LongPointer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong *") LongPointer dDimensionShape, - boolean biasCorrected, - @Cast("Nd4jLong *") LongPointer tadShapeInfo, - @Cast("Nd4jLong *") LongPointer tadOffsets); + int opNum, + OpaqueDataBuffer x, + @Cast("Nd4jLong *") LongPointer xShapeInfo, + @Cast("Nd4jLong *") LongPointer dxShapeInfo, + Pointer extraParams, + OpaqueDataBuffer result, + @Cast("Nd4jLong *") LongPointer resultShapeInfoBuffer, + @Cast("Nd4jLong *") LongPointer dresultShapeInfoBuffer, + OpaqueDataBuffer hDimension, + @Cast("Nd4jLong *") LongPointer hDimensionShape, + @Cast("Nd4jLong *") LongPointer dDimensionShape, + boolean biasCorrected, + @Cast("Nd4jLong *") LongPointer tadShapeInfo, + @Cast("Nd4jLong *") LongPointer tadOffsets); /** @@ -478,43 +477,53 @@ public interface NativeOps { */ void execTransformFloat(PointerPointer extraPointers, int opNum, - Pointer x, @Cast("Nd4jLong *") LongPointer xShapeInfo, - Pointer dx, @Cast("Nd4jLong *") LongPointer dxShapeInfo, - Pointer result, @Cast("Nd4jLong *") LongPointer resultShapeInfo, - Pointer dresult, @Cast("Nd4jLong *") LongPointer dresultShapeInfo, + OpaqueDataBuffer x, + @Cast("Nd4jLong *") LongPointer xShapeInfo, + @Cast("Nd4jLong *") LongPointer dxShapeInfo, + OpaqueDataBuffer result, + @Cast("Nd4jLong *") LongPointer resultShapeInfo, + @Cast("Nd4jLong *") LongPointer dresultShapeInfo, Pointer extraParams); void execTransformSame(PointerPointer extraPointers, int opNum, - Pointer x, @Cast("Nd4jLong *") LongPointer xShapeInfo, - Pointer dx, @Cast("Nd4jLong *") LongPointer dxShapeInfo, - Pointer result, @Cast("Nd4jLong *") LongPointer resultShapeInfo, - Pointer dresult, @Cast("Nd4jLong *") LongPointer dresultShapeInfo, + OpaqueDataBuffer x, + @Cast("Nd4jLong *") LongPointer xShapeInfo, + @Cast("Nd4jLong *") LongPointer dxShapeInfo, + OpaqueDataBuffer result, + @Cast("Nd4jLong *") LongPointer resultShapeInfo, + @Cast("Nd4jLong *") LongPointer dresultShapeInfo, Pointer extraParams); void execTransformStrict(PointerPointer extraPointers, int opNum, - Pointer x, @Cast("Nd4jLong *") LongPointer xShapeInfo, - Pointer dx, @Cast("Nd4jLong *") LongPointer dxShapeInfo, - Pointer result, @Cast("Nd4jLong *") LongPointer resultShapeInfo, - Pointer dresult, @Cast("Nd4jLong *") LongPointer dresultShapeInfo, - Pointer extraParams); + OpaqueDataBuffer x, + @Cast("Nd4jLong *") LongPointer xShapeInfo, + @Cast("Nd4jLong *") LongPointer dxShapeInfo, + OpaqueDataBuffer result, + @Cast("Nd4jLong *") LongPointer resultShapeInfo, + @Cast("Nd4jLong *") LongPointer dresultShapeInfo, + Pointer extraParams); void execTransformBool(PointerPointer extraPointers, int opNum, - Pointer x, @Cast("Nd4jLong *") LongPointer xShapeInfo, - Pointer dx, @Cast("Nd4jLong *") LongPointer dxShapeInfo, - Pointer result, @Cast("Nd4jLong *") LongPointer resultShapeInfo, - Pointer dresult, @Cast("Nd4jLong *") LongPointer dresultShapeInfo, - Pointer extraParams); + OpaqueDataBuffer x, + @Cast("Nd4jLong *") LongPointer xShapeInfo, + @Cast("Nd4jLong *") LongPointer dxShapeInfo, + OpaqueDataBuffer result, + @Cast("Nd4jLong *") LongPointer resultShapeInfo, + @Cast("Nd4jLong *") LongPointer dresultShapeInfo, + Pointer extraParams); void execTransformAny(PointerPointer extraPointers, int opNum, - Pointer x, @Cast("Nd4jLong *") LongPointer xShapeInfo, - Pointer dx, @Cast("Nd4jLong *") LongPointer dxShapeInfo, - Pointer result, @Cast("Nd4jLong *") LongPointer resultShapeInfo, - Pointer dresult, @Cast("Nd4jLong *") LongPointer dresultShapeInfo, - Pointer extraParams); + OpaqueDataBuffer x, + @Cast("Nd4jLong *") LongPointer xShapeInfo, + @Cast("Nd4jLong *") LongPointer dxShapeInfo, + OpaqueDataBuffer result, + @Cast("Nd4jLong *") LongPointer resultShapeInfo, + @Cast("Nd4jLong *") LongPointer dresultShapeInfo, + Pointer extraParams); /** * ScalarOp along dimension @@ -532,31 +541,43 @@ public interface NativeOps { */ void execScalarTad(PointerPointer extraPointers, int opNum, - Pointer x, @Cast("Nd4jLong *") LongPointer xShapeInfo, - Pointer dx, @Cast("Nd4jLong *") LongPointer dxShapeInfo, - Pointer z, @Cast("Nd4jLong *") LongPointer zShapeInfo, - Pointer dz, @Cast("Nd4jLong *") LongPointer dzShapeInfo, - Pointer scalars, @Cast("Nd4jLong *") LongPointer scalarShapeInfo, - Pointer dscalars, @Cast("Nd4jLong *") LongPointer dscalarShapeInfo, + OpaqueDataBuffer x, + @Cast("Nd4jLong *") LongPointer xShapeInfo, + @Cast("Nd4jLong *") LongPointer dxShapeInfo, + OpaqueDataBuffer z, + @Cast("Nd4jLong *") LongPointer zShapeInfo, + @Cast("Nd4jLong *") LongPointer dzShapeInfo, + OpaqueDataBuffer scalars, + @Cast("Nd4jLong *") LongPointer scalarShapeInfo, + @Cast("Nd4jLong *") LongPointer dscalarShapeInfo, Pointer extraParams, - Pointer hDimension, @Cast("Nd4jLong *") LongPointer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong *") LongPointer dDimensionShape, - @Cast("Nd4jLong *") LongPointer tadShapeInfo, @Cast("Nd4jLong *") LongPointer tadOffsets, - @Cast("Nd4jLong *") LongPointer tadShapeInfoZ, @Cast("Nd4jLong *") LongPointer tadOffsetsZ); + OpaqueDataBuffer hDimension, + @Cast("Nd4jLong *") LongPointer hDimensionShape, + @Cast("Nd4jLong *") LongPointer dDimensionShape, + @Cast("Nd4jLong *") LongPointer tadShapeInfo, + @Cast("Nd4jLong *") LongPointer tadOffsets, + @Cast("Nd4jLong *") LongPointer tadShapeInfoZ, + @Cast("Nd4jLong *") LongPointer tadOffsetsZ); void execScalarBoolTad(PointerPointer extraPointers, int opNum, - Pointer x, @Cast("Nd4jLong *") LongPointer xShapeInfo, - Pointer dx, @Cast("Nd4jLong *") LongPointer dxShapeInfo, - Pointer z, @Cast("Nd4jLong *") LongPointer zShapeInfo, - Pointer dz, @Cast("Nd4jLong *") LongPointer dzShapeInfo, - Pointer scalars, @Cast("Nd4jLong *") LongPointer scalarShapeInfo, - Pointer dscalars, @Cast("Nd4jLong *") LongPointer dscalarShapeInfo, + OpaqueDataBuffer x, + @Cast("Nd4jLong *") LongPointer xShapeInfo, + @Cast("Nd4jLong *") LongPointer dxShapeInfo, + OpaqueDataBuffer z, + @Cast("Nd4jLong *") LongPointer zShapeInfo, + @Cast("Nd4jLong *") LongPointer dzShapeInfo, + OpaqueDataBuffer scalars, + @Cast("Nd4jLong *") LongPointer scalarShapeInfo, + @Cast("Nd4jLong *") LongPointer dscalarShapeInfo, Pointer extraParams, - Pointer hDimension, @Cast("Nd4jLong *") LongPointer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong *") LongPointer dDimensionShape, - @Cast("Nd4jLong *") LongPointer tadShapeInfo, @Cast("Nd4jLong *") LongPointer tadOffsets, - @Cast("Nd4jLong *") LongPointer tadShapeInfoZ, @Cast("Nd4jLong *") LongPointer tadOffsetsZ); + OpaqueDataBuffer hDimension, + @Cast("Nd4jLong *") LongPointer hDimensionShape, + @Cast("Nd4jLong *") LongPointer dDimensionShape, + @Cast("Nd4jLong *") LongPointer tadShapeInfo, + @Cast("Nd4jLong *") LongPointer tadOffsets, + @Cast("Nd4jLong *") LongPointer tadShapeInfoZ, + @Cast("Nd4jLong *") LongPointer tadOffsetsZ); void specialConcat(PointerPointer extraPointers, @@ -675,10 +696,12 @@ public interface NativeOps { /////////////// void pullRows(PointerPointer extraPointers, - Pointer x, @Cast("Nd4jLong *") LongPointer xShapeInfo, - Pointer dx, @Cast("Nd4jLong *") LongPointer dxShapeInfo, - Pointer z, @Cast("Nd4jLong *") LongPointer zShapeInfo, - Pointer dz, @Cast("Nd4jLong *") LongPointer dzShapeInfo, + OpaqueDataBuffer x, + @Cast("Nd4jLong *") LongPointer xShapeInfo, + @Cast("Nd4jLong *") LongPointer dxShapeInfo, + OpaqueDataBuffer z, + @Cast("Nd4jLong *") LongPointer zShapeInfo, + @Cast("Nd4jLong *") LongPointer dzShapeInfo, long n, @Cast("Nd4jLong *") LongPointer indexes, @Cast("Nd4jLong *") LongPointer tadShapeInfo, @@ -777,28 +800,34 @@ public interface NativeOps { void execRandom(PointerPointer extraPointers, int opNum, Pointer state, - Pointer z, @Cast("Nd4jLong *") LongPointer zShapeBuffer, - Pointer dz, @Cast("Nd4jLong *") LongPointer dzShapeBuffer, + OpaqueDataBuffer z, + @Cast("Nd4jLong *") LongPointer zShapeBuffer, + @Cast("Nd4jLong *") LongPointer dzShapeBuffer, Pointer extraArguments); void execRandom3(PointerPointer extraPointers, int opNum, Pointer state, - Pointer x, @Cast("Nd4jLong *") LongPointer xShapeBuffer, - Pointer dx, @Cast("Nd4jLong *") LongPointer dxShapeBuffer, - Pointer y, @Cast("Nd4jLong *") LongPointer yShapeBuffer, - Pointer dy, @Cast("Nd4jLong *") LongPointer dyShapeBuffer, - Pointer z, @Cast("Nd4jLong *") LongPointer zShapeBuffer, - Pointer dz, @Cast("Nd4jLong *") LongPointer dzShapeBuffer, + OpaqueDataBuffer x, + @Cast("Nd4jLong *") LongPointer xShapeBuffer, + @Cast("Nd4jLong *") LongPointer dxShapeBuffer, + OpaqueDataBuffer y, + @Cast("Nd4jLong *") LongPointer yShapeBuffer, + @Cast("Nd4jLong *") LongPointer dyShapeBuffer, + OpaqueDataBuffer z, + @Cast("Nd4jLong *") LongPointer zShapeBuffer, + @Cast("Nd4jLong *") LongPointer dzShapeBuffer, Pointer extraArguments); void execRandom2(PointerPointer extraPointers, int opNum, Pointer state, - Pointer x, @Cast("Nd4jLong *") LongPointer xShapeBuffer, - Pointer dx, @Cast("Nd4jLong *") LongPointer dxShapeBuffer, - Pointer z, @Cast("Nd4jLong *") LongPointer zShapeBuffer, - Pointer dz, @Cast("Nd4jLong *") LongPointer dzShapeBuffer, + OpaqueDataBuffer x, + @Cast("Nd4jLong *") LongPointer xShapeBuffer, + @Cast("Nd4jLong *") LongPointer dxShapeBuffer, + OpaqueDataBuffer z, + @Cast("Nd4jLong *") LongPointer zShapeBuffer, + @Cast("Nd4jLong *") LongPointer dzShapeBuffer, Pointer extraArguments); //////////////////// @@ -967,11 +996,13 @@ public interface NativeOps { void tear(PointerPointer extras, - Pointer tensor, @Cast("Nd4jLong *") LongPointer xShapeInfo, - Pointer dtensor, @Cast("Nd4jLong *") LongPointer dxShapeInfo, - PointerPointer targets, @Cast("Nd4jLong *") LongPointer zShapeInfo, - @Cast("Nd4jLong *") LongPointer tadShapeInfo, - @Cast("Nd4jLong *") LongPointer tadOffsets); + OpaqueDataBuffer tensor, + @Cast("Nd4jLong *") LongPointer xShapeInfo, + @Cast("Nd4jLong *") LongPointer dxShapeInfo, + PointerPointer targets, + @Cast("Nd4jLong *") LongPointer zShapeInfo, + @Cast("Nd4jLong *") LongPointer tadShapeInfo, + @Cast("Nd4jLong *") LongPointer tadOffsets); long encodeBitmap(PointerPointer extraPointers, Pointer dx, LongPointer xShapeInfo, long N, IntPointer dz, float threshold); @@ -1121,10 +1152,13 @@ public interface NativeOps { void setGraphContextCudaContext(OpaqueContext ptr, Pointer stream, Pointer reductionPointer, Pointer allocationPointer); void setGraphContextInputArray(OpaqueContext ptr, int index, Pointer buffer, Pointer shapeInfo, Pointer specialBuffer, Pointer specialShapeInfo); void setGraphContextOutputArray(OpaqueContext ptr, int index, Pointer buffer, Pointer shapeInfo, Pointer specialBuffer, Pointer specialShapeInfo); + void setGraphContextInputBuffer(OpaqueContext ptr, int index, OpaqueDataBuffer databuffer, Pointer shapeInfo, Pointer specialShapeInfo); + void setGraphContextOutputBuffer(OpaqueContext ptr, int index, OpaqueDataBuffer databuffer, Pointer shapeInfo, Pointer specialShapeInfo); void setGraphContextTArguments(OpaqueContext ptr, DoublePointer arguments, int numberOfArguments); void setGraphContextIArguments(OpaqueContext ptr, LongPointer arguments, int numberOfArguments); void setGraphContextBArguments(OpaqueContext ptr, BooleanPointer arguments, int numberOfArguments); void ctxAllowHelpers(OpaqueContext ptr, boolean reallyAllow); + void ctxShapeFunctionOverride(OpaqueContext ptr, boolean reallyOverride); void deleteGraphContext(OpaqueContext ptr); OpaqueRandomGenerator createRandomGenerator(long rootSeed, long nodeSeed); @@ -1161,4 +1195,27 @@ public interface NativeOps { boolean isMinimalRequirementsMet(); boolean isOptimalRequirementsMet(); + + + OpaqueDataBuffer allocateDataBuffer(long elements, int dataType, boolean allocateBoth); + OpaqueDataBuffer dbCreateView(OpaqueDataBuffer dataBuffer, long length, long offset); + Pointer dbPrimaryBuffer(OpaqueDataBuffer dataBuffer); + Pointer dbSpecialBuffer(OpaqueDataBuffer dataBuffer); + void dbExpandBuffer(OpaqueDataBuffer dataBuffer, long elements); + void dbAllocatePrimaryBuffer(OpaqueDataBuffer dataBuffer); + void dbAllocateSpecialBuffer(OpaqueDataBuffer dataBuffer); + void dbSetPrimaryBuffer(OpaqueDataBuffer dataBuffer, Pointer primaryBuffer, long numBytes); + void dbSetSpecialBuffer(OpaqueDataBuffer dataBuffer, Pointer specialBuffer, long numBytes); + void dbSyncToSpecial(OpaqueDataBuffer dataBuffer); + void dbSyncToPrimary(OpaqueDataBuffer dataBuffer); + void dbTickHostRead(OpaqueDataBuffer dataBuffer); + void dbTickHostWrite(OpaqueDataBuffer dataBuffer); + void dbTickDeviceRead(OpaqueDataBuffer dataBuffer); + void dbTickDeviceWrite(OpaqueDataBuffer dataBuffer); + void deleteDataBuffer(OpaqueDataBuffer dataBuffer); + void dbClose(OpaqueDataBuffer dataBuffer); + int dbLocality(OpaqueDataBuffer dataBuffer); + int dbDeviceId(OpaqueDataBuffer dataBuffer); + void dbSetDeviceId(OpaqueDataBuffer dataBuffer, int deviceId); + void dbExpand(OpaqueDataBuffer dataBuffer, long newLength); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/Nd4jBlas.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/Nd4jBlas.java index 5de827d1a..fa92f94f5 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/Nd4jBlas.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/Nd4jBlas.java @@ -21,6 +21,7 @@ import lombok.extern.slf4j.Slf4j; import org.bytedeco.javacpp.Loader; import org.bytedeco.javacpp.Pointer; import org.nd4j.config.ND4JEnvironmentVars; +import org.nd4j.config.ND4JSystemProperties; import org.nd4j.linalg.api.blas.Blas; @@ -52,7 +53,10 @@ public abstract class Nd4jBlas implements Blas { setMaxThreads(numThreads); } - log.info("Number of threads used for OpenMP BLAS: {}", getMaxThreads()); + String logInit = System.getProperty(ND4JSystemProperties.LOG_INITIALIZATION); + if(logInit == null || logInit.isEmpty() || Boolean.parseBoolean(logInit)) { + log.info("Number of threads used for OpenMP BLAS: {}", getMaxThreads()); + } } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/OpaqueDataBuffer.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/OpaqueDataBuffer.java new file mode 100644 index 000000000..d5a84eac3 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/OpaqueDataBuffer.java @@ -0,0 +1,216 @@ +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.nativeblas; + +import lombok.NonNull; +import lombok.extern.slf4j.Slf4j; +import lombok.val; +import org.bytedeco.javacpp.Pointer; +import org.nd4j.linalg.api.buffer.DataType; + +import java.util.concurrent.locks.LockSupport; + +/** + * This class is a opaque pointer to InteropDataBuffer, used for Java/C++ interop related to INDArray DataBuffer + * + * @author saudet + * @author raver119@gmail.com + */ +@Slf4j +public class OpaqueDataBuffer extends Pointer { + // TODO: make this configurable + private static final int MAX_TRIES = 5; + + public OpaqueDataBuffer(Pointer p) { super(p); } + + /** + * This method allocates new InteropDataBuffer and returns pointer to it + * @param numElements + * @param dataType + * @param allocateBoth + * @return + */ + public static OpaqueDataBuffer allocateDataBuffer(long numElements, @NonNull DataType dataType, boolean allocateBoth) { + OpaqueDataBuffer buffer = null; + int ec = 0; + String em = null; + + for (int t = 0; t < MAX_TRIES; t++) { + try { + // try to allocate data buffer + buffer = NativeOpsHolder.getInstance().getDeviceNativeOps().allocateDataBuffer(numElements, dataType.toInt(), allocateBoth); + + // check error code + ec = NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorCode(); + if (ec != 0) { + em = NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorMessage(); + + // if allocation failed it might be caused by casual OOM, so we'll try GC + System.gc(); + + // sleeping for 50ms + Thread.sleep(50); + } else { + // just return the buffer + return buffer; + } + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + // if MAX_TRIES is over, we'll just throw an exception + throw new RuntimeException("Allocation failed: [" + em + "]"); + } + + /** + * This method expands buffer, and copies content to the new buffer + * + * PLEASE NOTE: if InteropDataBuffer doesn't own actual buffers - original pointers won't be released + * @param numElements + */ + public void expand(long numElements) { + int ec = 0; + String em = null; + + for (int t = 0; t < MAX_TRIES; t++) { + try { + // try to expand the buffer + NativeOpsHolder.getInstance().getDeviceNativeOps().dbExpand(this, numElements); + + // check error code + ec = NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorCode(); + if (ec != 0) { + em = NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorMessage(); + + // if expansion failed it might be caused by casual OOM, so we'll try GC + System.gc(); + + Thread.sleep(50); + } else { + // just return + return; + } + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + // if MAX_TRIES is over, we'll just throw an exception + throw new RuntimeException("DataBuffer expansion failed: [" + em + "]"); + } + + /** + * This method creates a view out of this InteropDataBuffer + * + * @param bytesLength + * @param bytesOffset + * @return + */ + public OpaqueDataBuffer createView(long bytesLength, long bytesOffset) { + OpaqueDataBuffer buffer = null; + int ec = 0; + String em = null; + + for (int t = 0; t < MAX_TRIES; t++) { + try { + buffer = NativeOpsHolder.getInstance().getDeviceNativeOps().dbCreateView(this, bytesLength, bytesOffset); + + // check error code + ec = NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorCode(); + if (ec != 0) { + em = NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorMessage(); + + // if view creation failed it might be caused by casual OOM, so we'll try GC + System.gc(); + + // sleeping to let gc kick in + Thread.sleep(50); + } else { + // just return + return buffer; + } + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + // if MAX_TRIES is over, we'll just throw an exception + throw new RuntimeException("DataBuffer expansion failed: [" + em + "]"); + } + + /** + * This method returns pointer to linear buffer, primary one. + * @return + */ + public Pointer primaryBuffer() { + return NativeOpsHolder.getInstance().getDeviceNativeOps().dbPrimaryBuffer(this); + } + + /** + * This method returns pointer to special buffer, device one, if any. + * @return + */ + public Pointer specialBuffer() { + return NativeOpsHolder.getInstance().getDeviceNativeOps().dbSpecialBuffer(this); + } + + /** + * This method returns deviceId of this DataBuffer + * @return + */ + public int deviceId() { + return NativeOpsHolder.getInstance().getDeviceNativeOps().dbDeviceId(this); + } + + /** + * This method allows to set external pointer as primary buffer. + * + * PLEASE NOTE: if InteropDataBuffer owns current memory buffer, it will be released + * @param ptr + * @param numElements + */ + public void setPrimaryBuffer(Pointer ptr, long numElements) { + NativeOpsHolder.getInstance().getDeviceNativeOps().dbSetPrimaryBuffer(this, ptr, numElements); + } + + /** + * This method allows to set external pointer as primary buffer. + * + * PLEASE NOTE: if InteropDataBuffer owns current memory buffer, it will be released + * @param ptr + * @param numElements + */ + public void setSpecialBuffer(Pointer ptr, long numElements) { + NativeOpsHolder.getInstance().getDeviceNativeOps().dbSetSpecialBuffer(this, ptr, numElements); + } + + /** + * This method synchronizes device memory + */ + public void syncToSpecial() { + NativeOpsHolder.getInstance().getDeviceNativeOps().dbSyncToSpecial(this); + } + + /** + * This method synchronizes host memory + */ + public void syncToPrimary() { + NativeOpsHolder.getInstance().getDeviceNativeOps().dbSyncToPrimary(this); + } +} diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/pom.xml b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/pom.xml index ec0eab208..d98c7a6d1 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/pom.xml +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/pom.xml @@ -253,6 +253,7 @@ ${cuda.version}-${cudnn.version}-${javacpp-presets.cuda.version} ${dependency.platform} + junit junit @@ -308,6 +310,13 @@ + + + org.nd4j + nd4j-common-tests + ${project.version} + test + diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/impl/AllocationPoint.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/impl/AllocationPoint.java index f673a15d7..881d1e8b2 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/impl/AllocationPoint.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/impl/AllocationPoint.java @@ -19,6 +19,7 @@ package org.nd4j.jita.allocator.impl; import lombok.Getter; import lombok.NonNull; import lombok.Setter; +import lombok.val; import org.bytedeco.javacpp.Pointer; import org.nd4j.jita.allocator.enums.AllocationStatus; import org.nd4j.jita.allocator.garbage.GarbageBufferReference; @@ -29,9 +30,11 @@ import org.nd4j.jita.allocator.time.providers.MillisecondsProvider; import org.nd4j.jita.allocator.time.providers.OperativeProvider; import org.nd4j.linalg.api.buffer.BaseDataBuffer; import org.nd4j.linalg.api.buffer.DataBuffer; +import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.jcublas.context.CudaContext; import org.nd4j.nativeblas.NativeOps; import org.nd4j.nativeblas.NativeOpsHolder; +import org.nd4j.nativeblas.OpaqueDataBuffer; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -54,8 +57,8 @@ import java.util.concurrent.locks.ReentrantLock; public class AllocationPoint { private static Logger log = LoggerFactory.getLogger(AllocationPoint.class); - // thread safety is guaranteed by cudaLock - private volatile PointersPair pointerInfo; + @Getter + private OpaqueDataBuffer ptrDataBuffer; @Getter @Setter @@ -104,33 +107,27 @@ public class AllocationPoint { */ private volatile int deviceId; - public AllocationPoint() { - // + private long bytes; + + public AllocationPoint(@NonNull OpaqueDataBuffer opaqueDataBuffer, long bytes) { + ptrDataBuffer = opaqueDataBuffer; + this.bytes = bytes; + objectId = Nd4j.getDeallocatorService().nextValue(); } - public void acquireLock() { - //lock.lock(); - } - - public void releaseLock() { - //lock.unlock(); + public void setPointers(Pointer primary, Pointer special, long numberOfElements) { + NativeOpsHolder.getInstance().getDeviceNativeOps().dbSetPrimaryBuffer(ptrDataBuffer, primary, numberOfElements); + NativeOpsHolder.getInstance().getDeviceNativeOps().dbSetSpecialBuffer(ptrDataBuffer, special, numberOfElements); } public int getDeviceId() { - return deviceId; + return ptrDataBuffer.deviceId(); } public void setDeviceId(int deviceId) { - this.deviceId = deviceId; + NativeOpsHolder.getInstance().getDeviceNativeOps().dbSetDeviceId(ptrDataBuffer, deviceId); } - /* - We assume 1D memory chunk allocations. - */ - @Getter - @Setter - private AllocationShape shape; - private AtomicBoolean enqueued = new AtomicBoolean(false); @Getter @@ -164,7 +161,7 @@ public class AllocationPoint { } public long getNumberOfBytes() { - return shape.getNumberOfBytes(); + return bytes; } /* @@ -220,67 +217,25 @@ public class AllocationPoint { * This method returns CUDA pointer object for this allocation. * It can be either device pointer or pinned memory pointer, or null. * - * PLEASE NOTE: Thread safety is guaranteed by reentrant read/write lock * @return */ public Pointer getDevicePointer() { - if (pointerInfo == null) { - log.info("pointerInfo is null"); - return null; - } - return pointerInfo.getDevicePointer(); + return NativeOpsHolder.getInstance().getDeviceNativeOps().dbSpecialBuffer(ptrDataBuffer); } /** * This method returns CUDA pointer object for this allocation. * It can be either device pointer or pinned memory pointer, or null. * - * PLEASE NOTE: Thread safety is guaranteed by reentrant read/write lock * @return */ public Pointer getHostPointer() { - if (pointerInfo == null) - return null; - - return pointerInfo.getHostPointer(); - } - - /** - * This method sets CUDA pointer for this allocation. - * It can be either device pointer, or pinned memory pointer, or null. - * - * PLEASE NOTE: Thread safety is guaranteed by reentrant read/write lock - * @param pointerInfo CUDA pointers wrapped into DevicePointerInfo - */ - public void setPointers(@NonNull PointersPair pointerInfo) { - this.pointerInfo = pointerInfo; - } - - public PointersPair getPointers() { - return this.pointerInfo; + return NativeOpsHolder.getInstance().getDeviceNativeOps().dbPrimaryBuffer(ptrDataBuffer); } public synchronized void tickDeviceRead() { - // this.deviceTicks.incrementAndGet(); - // this.timerShort.triggerEvent(); - // this.timerLong.triggerEvent(); - //this.deviceAccessTime.set(realTimeProvider.getCurrentTime()); - this.accessDeviceRead = (timeProvider.getCurrentTime()); - } - - - /** - * Returns time, in milliseconds, when this point was accessed on host side - * - * @return - */ - public synchronized long getHostReadTime() { - return accessHostRead; - }; - - public synchronized long getHostWriteTime() { - return accessHostWrite; + NativeOpsHolder.getInstance().getDeviceNativeOps().dbTickDeviceRead(ptrDataBuffer); } /** @@ -302,7 +257,7 @@ public class AllocationPoint { } public synchronized void tickHostRead() { - accessHostRead = (timeProvider.getCurrentTime()); + NativeOpsHolder.getInstance().getDeviceNativeOps().dbTickHostRead(ptrDataBuffer); } /** @@ -310,17 +265,14 @@ public class AllocationPoint { * */ public synchronized void tickDeviceWrite() { - // deviceAccessTime.set(realTimeProvider.getCurrentTime()); - tickDeviceRead(); - accessDeviceWrite = (timeProvider.getCurrentTime()); + NativeOpsHolder.getInstance().getDeviceNativeOps().dbTickDeviceWrite(ptrDataBuffer); } /** * This method sets time when this point was changed on host */ public synchronized void tickHostWrite() { - tickHostRead(); - accessHostWrite = (timeProvider.getCurrentTime()); + NativeOpsHolder.getInstance().getDeviceNativeOps().dbTickHostWrite(ptrDataBuffer); } /** @@ -329,10 +281,8 @@ public class AllocationPoint { * @return true, if data is actual, false otherwise */ public synchronized boolean isActualOnHostSide() { - boolean result = accessHostWrite >= accessDeviceWrite - || accessHostRead >= accessDeviceWrite; - - return result; + val s = NativeOpsHolder.getInstance().getDeviceNativeOps().dbLocality(ptrDataBuffer); + return s <= 0; } /** @@ -341,9 +291,8 @@ public class AllocationPoint { * @return */ public synchronized boolean isActualOnDeviceSide() { - boolean result = accessDeviceWrite >= accessHostWrite - || accessDeviceRead >= accessHostWrite; - return result; + val s = NativeOpsHolder.getInstance().getDeviceNativeOps().dbLocality(ptrDataBuffer); + return s >= 0; } /** @@ -355,6 +304,6 @@ public class AllocationPoint { @Override public String toString() { - return "AllocationPoint{" + "deviceId=" + deviceId + ", objectId=" + objectId + ", shape=" + shape + '}'; + return "AllocationPoint{" + "deviceId=" + deviceId + ", objectId=" + objectId + "}"; } } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/impl/AtomicAllocator.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/impl/AtomicAllocator.java index 8ec8734f7..ac35d1933 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/impl/AtomicAllocator.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/impl/AtomicAllocator.java @@ -19,12 +19,10 @@ package org.nd4j.jita.allocator.impl; import lombok.Getter; import lombok.NonNull; import lombok.val; -import org.apache.commons.lang3.RandomUtils; import org.bytedeco.javacpp.Pointer; import org.nd4j.jita.allocator.Allocator; import org.nd4j.jita.allocator.enums.Aggressiveness; import org.nd4j.jita.allocator.enums.AllocationStatus; -import org.nd4j.jita.allocator.garbage.GarbageBufferReference; import org.nd4j.jita.allocator.pointers.CudaPointer; import org.nd4j.jita.allocator.pointers.PointersPair; import org.nd4j.jita.allocator.time.Ring; @@ -37,29 +35,25 @@ import org.nd4j.jita.flow.FlowController; import org.nd4j.jita.handler.MemoryHandler; import org.nd4j.jita.handler.impl.CudaZeroHandler; import org.nd4j.jita.workspace.CudaWorkspace; -import org.nd4j.linalg.api.buffer.BaseDataBuffer; import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.api.buffer.Utf8Buffer; import org.nd4j.linalg.api.memory.enums.MemoryKind; -import org.nd4j.linalg.api.memory.pointers.PagedPointer; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.cache.ConstantHandler; import org.nd4j.linalg.compression.CompressedDataBuffer; import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.jcublas.buffer.BaseCudaDataBuffer; +import org.nd4j.linalg.jcublas.buffer.CudaUtf8Buffer; import org.nd4j.linalg.jcublas.context.CudaContext; import org.nd4j.nativeblas.NativeOpsHolder; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.nd4j.linalg.jcublas.buffer.BaseCudaDataBuffer; -import java.lang.ref.ReferenceQueue; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; -import java.util.concurrent.locks.LockSupport; import java.util.concurrent.locks.ReentrantReadWriteLock; /** @@ -285,16 +279,10 @@ public class AtomicAllocator implements Allocator { */ @Override public Pointer getPointer(@NonNull DataBuffer buffer, CudaContext context) { - if (buffer instanceof Utf8Buffer) - return null; - return memoryHandler.getDevicePointer(buffer, context); } public Pointer getPointer(DataBuffer buffer) { - if (buffer instanceof Utf8Buffer) - return null; - return memoryHandler.getDevicePointer(buffer, getDeviceContext()); } @@ -320,7 +308,7 @@ public class AtomicAllocator implements Allocator { public Pointer getPointer(INDArray array, CudaContext context) { // DataBuffer buffer = array.data().originalDataBuffer() == null ? array.data() : array.data().originalDataBuffer(); if (array.isEmpty() || array.isS()) - return null; + throw new UnsupportedOperationException("Pew-pew"); return memoryHandler.getDevicePointer(array.data(), context); } @@ -372,20 +360,17 @@ public class AtomicAllocator implements Allocator { @Override public void synchronizeHostData(DataBuffer buffer) { // we don't want non-committed ops left behind - //Nd4j.getExecutioner().push(); + Nd4j.getExecutioner().commit(); - // we don't synchronize constant buffers, since we assume they are always valid on host side - if (buffer.isConstant() || buffer.dataType() == DataType.UTF8 || AtomicAllocator.getInstance().getAllocationPoint(buffer).getPointers().getHostPointer() == null) { - return; - } + val oPtr = NativeOpsHolder.getInstance().getDeviceNativeOps().dbPrimaryBuffer(((BaseCudaDataBuffer) buffer).getOpaqueDataBuffer()); - // we actually need synchronization only in device-dependant environment. no-op otherwise - if (memoryHandler.isDeviceDependant()) { - val point = getAllocationPoint(buffer.getTrackingPoint()); - if (point == null) - throw new RuntimeException("AllocationPoint is NULL"); - memoryHandler.synchronizeThreadDevice(Thread.currentThread().getId(), memoryHandler.getDeviceId(), point); - } + // we actually need synchronization only in device-dependant environment. no-op otherwise. managed by native code + NativeOpsHolder.getInstance().getDeviceNativeOps().dbSyncToPrimary(((BaseCudaDataBuffer) buffer).getOpaqueDataBuffer()); + + val cPtr = NativeOpsHolder.getInstance().getDeviceNativeOps().dbPrimaryBuffer(((BaseCudaDataBuffer) buffer).getOpaqueDataBuffer()); + + //assert oPtr.address() == cPtr.address(); + //assert buffer.address() == oPtr.address(); } @@ -446,6 +431,7 @@ public class AtomicAllocator implements Allocator { public AllocationPoint pickExternalBuffer(DataBuffer buffer) { + /** AllocationPoint point = new AllocationPoint(); Long allocId = objectsTracker.getAndIncrement(); point.setObjectId(allocId); @@ -458,6 +444,9 @@ public class AtomicAllocator implements Allocator { point.tickHostRead(); return point; + */ + + throw new UnsupportedOperationException("Pew-pew"); } /** @@ -469,69 +458,8 @@ public class AtomicAllocator implements Allocator { * @param location */ @Override - public AllocationPoint allocateMemory(DataBuffer buffer, AllocationShape requiredMemory, AllocationStatus location, - boolean initialize) { - AllocationPoint point = new AllocationPoint(); - - useTracker.set(System.currentTimeMillis()); - - // we use these longs as tracking codes for memory tracking - Long allocId = objectsTracker.getAndIncrement(); - //point.attachBuffer(buffer); - point.setObjectId(allocId); - point.setShape(requiredMemory); - /* - if (buffer instanceof CudaIntDataBuffer) { - buffer.setConstant(true); - point.setConstant(true); - } - */ - /*int numBuckets = configuration.getNumberOfGcThreads(); - int bucketId = RandomUtils.nextInt(0, numBuckets); - - GarbageBufferReference reference = - new GarbageBufferReference((BaseDataBuffer) buffer, queueMap.get(bucketId), point);*/ - //point.attachReference(reference); - point.setDeviceId(-1); - - if (buffer.isAttached()) { - long reqMem = AllocationUtils.getRequiredMemory(requiredMemory); - - // workaround for init order - getMemoryHandler().getCudaContext(); - point.setDeviceId(Nd4j.getAffinityManager().getDeviceForCurrentThread()); - - val workspace = (CudaWorkspace) Nd4j.getMemoryManager().getCurrentWorkspace(); - - val pair = new PointersPair(); - val ptrDev = workspace.alloc(reqMem, MemoryKind.DEVICE, requiredMemory.getDataType(), initialize); - - if (ptrDev != null) { - pair.setDevicePointer(ptrDev); - point.setAllocationStatus(AllocationStatus.DEVICE); - } else { - // we allocate initial host pointer only - val ptrHost = workspace.alloc(reqMem, MemoryKind.HOST, requiredMemory.getDataType(), initialize); - pair.setHostPointer(ptrHost); - - pair.setDevicePointer(ptrHost); - point.setAllocationStatus(AllocationStatus.HOST); - } - - point.setAttached(true); - - point.setPointers(pair); - } else { - // we stay naive on PointersPair, we just don't know on this level, which pointers are set. MemoryHandler will be used for that - PointersPair pair = memoryHandler.alloc(location, point, requiredMemory, initialize); - point.setPointers(pair); - } - - allocationsMap.put(allocId, point); - //point.tickHostRead(); - point.tickDeviceWrite(); - //point.setAllocationStatus(location); - return point; + public AllocationPoint allocateMemory(DataBuffer buffer, AllocationShape requiredMemory, AllocationStatus location, boolean initialize) { + throw new UnsupportedOperationException("Pew-pew"); } @@ -619,10 +547,11 @@ public class AtomicAllocator implements Allocator { */ if (point.getBuffer() == null) { purgeZeroObject(bucketId, object, point, false); - freeSpace.addAndGet(AllocationUtils.getRequiredMemory(point.getShape())); + //freeSpace.addAndGet(AllocationUtils.getRequiredMemory(point.getShape())); + throw new UnsupportedOperationException("Pew-pew"); - elementsDropped.incrementAndGet(); - continue; + //elementsDropped.incrementAndGet(); + //continue; } else { elementsSurvived.incrementAndGet(); } @@ -682,13 +611,14 @@ public class AtomicAllocator implements Allocator { if (point.getAllocationStatus() == AllocationStatus.DEVICE) { // we deallocate device memory purgeDeviceObject(threadId, deviceId, object, point, false); - freeSpace.addAndGet(AllocationUtils.getRequiredMemory(point.getShape())); + //freeSpace.addAndGet(AllocationUtils.getRequiredMemory(point.getShape())); // and we deallocate host memory, since object is dereferenced - purgeZeroObject(point.getBucketId(), object, point, false); + //purgeZeroObject(point.getBucketId(), object, point, false); - elementsDropped.incrementAndGet(); - continue; + //elementsDropped.incrementAndGet(); + //continue; + throw new UnsupportedOperationException("Pew-pew"); } ; } else { elementsSurvived.incrementAndGet(); @@ -1014,6 +944,31 @@ public class AtomicAllocator implements Allocator { this.memoryHandler.memcpy(dstBuffer, srcBuffer); } + @Override + public void tickHostWrite(DataBuffer buffer) { + getAllocationPoint(buffer).tickHostWrite(); + } + + @Override + public void tickHostWrite(INDArray array) { + getAllocationPoint(array.data()).tickHostWrite(); + } + + @Override + public void tickDeviceWrite(INDArray array) { + getAllocationPoint(array.data()).tickDeviceWrite(); + } + + @Override + public AllocationPoint getAllocationPoint(INDArray array) { + return getAllocationPoint(array.data()); + } + + @Override + public AllocationPoint getAllocationPoint(DataBuffer buffer) { + return ((BaseCudaDataBuffer) buffer).getAllocationPoint(); + } + /** * This method returns deviceId for current thread * All values >= 0 are considered valid device IDs, all values < 0 are considered stubs. @@ -1031,48 +986,6 @@ public class AtomicAllocator implements Allocator { return new CudaPointer(getDeviceId()); } - @Override - public void tickHostWrite(DataBuffer buffer) { - AllocationPoint point = getAllocationPoint(buffer.getTrackingPoint()); - point.tickHostWrite(); - } - - @Override - public void tickHostWrite(INDArray array) { - DataBuffer buffer = - array.data().originalDataBuffer() == null ? array.data() : array.data().originalDataBuffer(); - - tickHostWrite(buffer); - } - - @Override - public void tickDeviceWrite(INDArray array) { - DataBuffer buffer = - array.data().originalDataBuffer() == null ? array.data() : array.data().originalDataBuffer(); - AllocationPoint point = getAllocationPoint(buffer.getTrackingPoint()); - - point.tickDeviceWrite(); - } - - @Override - public AllocationPoint getAllocationPoint(INDArray array) { - if (array.isEmpty()) - return null; - - DataBuffer buffer = array.data().originalDataBuffer() == null ? array.data() : array.data().originalDataBuffer(); - return getAllocationPoint(buffer); - } - - @Override - public AllocationPoint getAllocationPoint(DataBuffer buffer) { - if (buffer instanceof CompressedDataBuffer) { - log.warn("Trying to get AllocationPoint from CompressedDataBuffer"); - throw new RuntimeException("AP CDB"); - } - - return getAllocationPoint(buffer.getTrackingPoint()); - } - @Override public void registerAction(CudaContext context, INDArray result, INDArray... operands) { memoryHandler.registerAction(context, result, operands); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/impl/CudaDeallocator.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/impl/CudaDeallocator.java index ae1ad93cd..0f65b8f00 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/impl/CudaDeallocator.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/impl/CudaDeallocator.java @@ -23,46 +23,21 @@ import org.nd4j.linalg.jcublas.buffer.BaseCudaDataBuffer; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import org.nd4j.linalg.api.memory.Deallocator; +import org.nd4j.nativeblas.NativeOpsHolder; +import org.nd4j.nativeblas.OpaqueDataBuffer; @Slf4j public class CudaDeallocator implements Deallocator { - private AllocationPoint point; + private OpaqueDataBuffer opaqueDataBuffer; public CudaDeallocator(@NonNull BaseCudaDataBuffer buffer) { - this.point = buffer.getAllocationPoint(); - if (this.point == null) - throw new RuntimeException(); + opaqueDataBuffer = buffer.getOpaqueDataBuffer(); } @Override public void deallocate() { log.trace("Deallocating CUDA memory"); - // skipping any allocation that is coming from workspace - if (point.isAttached() || point.isReleased()) { - // TODO: remove allocation point as well? - if (!AtomicAllocator.getInstance().allocationsMap().containsKey(point.getObjectId())) - return; - - AtomicAllocator.getInstance().getFlowController().waitTillReleased(point); - - AtomicAllocator.getInstance().getFlowController().getEventsProvider().storeEvent(point.getLastWriteEvent()); - AtomicAllocator.getInstance().getFlowController().getEventsProvider().storeEvent(point.getLastReadEvent()); - - AtomicAllocator.getInstance().allocationsMap().remove(point.getObjectId()); - - return; - } - - - //log.info("Purging {} bytes...", AllocationUtils.getRequiredMemory(point.getShape())); - if (point.getAllocationStatus() == AllocationStatus.HOST) { - AtomicAllocator.getInstance().purgeZeroObject(point.getBucketId(), point.getObjectId(), point, false); - } else if (point.getAllocationStatus() == AllocationStatus.DEVICE) { - AtomicAllocator.getInstance().purgeDeviceObject(0L, point.getDeviceId(), point.getObjectId(), point, false); - - // and we deallocate host memory, since object is dereferenced - AtomicAllocator.getInstance().purgeZeroObject(point.getBucketId(), point.getObjectId(), point, false); - } + NativeOpsHolder.getInstance().getDeviceNativeOps().deleteDataBuffer(opaqueDataBuffer); } } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/pointers/cuda/cudaStream_t.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/pointers/cuda/cudaStream_t.java index 8d78ee950..7d9bfb629 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/pointers/cuda/cudaStream_t.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/pointers/cuda/cudaStream_t.java @@ -17,6 +17,7 @@ package org.nd4j.jita.allocator.pointers.cuda; import lombok.NonNull; +import lombok.val; import org.bytedeco.javacpp.Pointer; import org.nd4j.jita.allocator.pointers.CudaPointer; import org.nd4j.linalg.exception.ND4JException; @@ -37,8 +38,9 @@ public class cudaStream_t extends CudaPointer { NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps(); int res = nativeOps.streamSynchronize(this); - if (nativeOps.lastErrorCode() != 0) - throw new RuntimeException(nativeOps.lastErrorMessage()); + val ec = nativeOps.lastErrorCode(); + if (ec != 0) + throw new RuntimeException(nativeOps.lastErrorMessage() + "; Error code: " + ec); return res; } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/constant/ProtectedCudaConstantHandler.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/constant/ProtectedCudaConstantHandler.java index 5548d854a..b08248bdb 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/constant/ProtectedCudaConstantHandler.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/constant/ProtectedCudaConstantHandler.java @@ -129,7 +129,7 @@ public class ProtectedCudaConstantHandler implements ConstantHandler { AllocationPoint point = AtomicAllocator.getInstance().getAllocationPoint(dataBuffer); - long requiredMemoryBytes = AllocationUtils.getRequiredMemory(point.getShape()); + long requiredMemoryBytes = point.getNumberOfBytes(); val originalBytes = requiredMemoryBytes; requiredMemoryBytes += 8 - (requiredMemoryBytes % 8); @@ -147,13 +147,13 @@ public class ProtectedCudaConstantHandler implements ConstantHandler { if (currentOffset + requiredMemoryBytes >= MAX_CONSTANT_LENGTH || requiredMemoryBytes > MAX_BUFFER_LENGTH) { if (point.getAllocationStatus() == AllocationStatus.HOST && CudaEnvironment.getInstance().getConfiguration().getMemoryModel() == Configuration.MemoryModel.DELAYED) { - AtomicAllocator.getInstance().getMemoryHandler().alloc(AllocationStatus.DEVICE, point, point.getShape(), - false); + //AtomicAllocator.getInstance().getMemoryHandler().alloc(AllocationStatus.DEVICE, point, point.getShape(), false); + throw new UnsupportedOperationException("Pew-pew"); } val profD = PerformanceTracker.getInstance().helperStartTransaction(); - if (NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(point.getPointers().getDevicePointer(), point.getPointers().getHostPointer(), originalBytes, 1, context.getSpecialStream()) == 0) { + if (NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(point.getDevicePointer(), point.getHostPointer(), originalBytes, 1, context.getSpecialStream()) == 0) { throw new ND4JIllegalStateException("memcpyAsync failed"); } flowController.commitTransfer(context.getSpecialStream()); @@ -176,14 +176,13 @@ public class ProtectedCudaConstantHandler implements ConstantHandler { if (currentOffset >= MAX_CONSTANT_LENGTH) { if (point.getAllocationStatus() == AllocationStatus.HOST && CudaEnvironment.getInstance().getConfiguration().getMemoryModel() == Configuration.MemoryModel.DELAYED) { - AtomicAllocator.getInstance().getMemoryHandler().alloc(AllocationStatus.DEVICE, point, point.getShape(), - false); + //AtomicAllocator.getInstance().getMemoryHandler().alloc(AllocationStatus.DEVICE, point, point.getShape(), false); + throw new UnsupportedOperationException("Pew-pew"); } val profD = PerformanceTracker.getInstance().helperStartTransaction(); - if (NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(point.getPointers().getDevicePointer(), point.getPointers().getHostPointer(), - originalBytes, 1, context.getSpecialStream()) == 0) { + if (NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(point.getDevicePointer(), point.getHostPointer(), originalBytes, 1, context.getSpecialStream()) == 0) { throw new ND4JIllegalStateException("memcpyAsync failed"); } flowController.commitTransfer(context.getSpecialStream()); @@ -202,8 +201,7 @@ public class ProtectedCudaConstantHandler implements ConstantHandler { - NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyConstantAsync(currentOffset, point.getPointers().getHostPointer(), originalBytes, 1, - context.getSpecialStream()); + NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyConstantAsync(currentOffset, point.getHostPointer(), originalBytes, 1, context.getSpecialStream()); flowController.commitTransfer(context.getSpecialStream()); long cAddr = deviceAddresses.get(deviceId).address() + currentOffset; @@ -212,7 +210,10 @@ public class ProtectedCudaConstantHandler implements ConstantHandler { // logger.info("copying to constant: {}, bufferLength: {}, bufferDtype: {}, currentOffset: {}, currentAddres: {}", requiredMemoryBytes, dataBuffer.length(), dataBuffer.dataType(), currentOffset, cAddr); point.setAllocationStatus(AllocationStatus.CONSTANT); - point.getPointers().setDevicePointer(new CudaPointer(cAddr)); + //point.setDevicePointer(new CudaPointer(cAddr)); + if (1 > 0) + throw new UnsupportedOperationException("Pew-pew"); + point.setConstant(true); point.tickDeviceWrite(); point.setDeviceId(deviceId); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/flow/impl/SynchronousFlowController.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/flow/impl/SynchronousFlowController.java index d81de381a..48e981491 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/flow/impl/SynchronousFlowController.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/flow/impl/SynchronousFlowController.java @@ -32,6 +32,7 @@ import org.nd4j.jita.conf.Configuration; import org.nd4j.jita.conf.CudaEnvironment; import org.nd4j.jita.flow.FlowController; import org.nd4j.linalg.api.buffer.DataBuffer; +import org.nd4j.linalg.api.concurrency.AffinityManager; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.executioner.OpExecutioner; import org.nd4j.linalg.api.ops.performance.PerformanceTracker; @@ -70,53 +71,12 @@ public class SynchronousFlowController implements FlowController { */ @Override public void synchronizeToHost(AllocationPoint point) { - - if (!point.isActualOnHostSide()) { - val context = allocator.getDeviceContext(); - - if (!point.isConstant()) - waitTillFinished(point); - - // if this piece of memory is device-dependant, we'll also issue copyback once - if (point.getAllocationStatus() == AllocationStatus.DEVICE && !point.isActualOnHostSide()) { - long perfD = PerformanceTracker.getInstance().helperStartTransaction(); - val bytes = AllocationUtils.getRequiredMemory(point.getShape()); - - if (nativeOps.memcpyAsync(point.getHostPointer(), point.getDevicePointer(), bytes, CudaConstants.cudaMemcpyDeviceToHost, context.getSpecialStream()) == 0) - throw new IllegalStateException("synchronizeToHost memcpyAsync failed: " + point.getShape()); - - commitTransfer(context.getSpecialStream()); - - PerformanceTracker.getInstance().helperRegisterTransaction(point.getDeviceId(), perfD, point.getNumberOfBytes(), MemcpyDirection.DEVICE_TO_HOST); - } - - // updating host read timer - point.tickHostRead(); - } + NativeOpsHolder.getInstance().getDeviceNativeOps().dbSyncToPrimary(point.getPtrDataBuffer()); } @Override public void synchronizeToDevice(@NonNull AllocationPoint point) { - if (point.isConstant()) - return; - - if (!point.isActualOnDeviceSide()) { - if (point.getAllocationStatus() == AllocationStatus.DEVICE) { - val context = allocator.getDeviceContext(); - - long perfD = PerformanceTracker.getInstance().helperStartTransaction(); - - if (nativeOps.memcpyAsync(point.getDevicePointer(), point.getHostPointer(), - AllocationUtils.getRequiredMemory(point.getShape()), - CudaConstants.cudaMemcpyHostToDevice, context.getSpecialStream()) == 0) - throw new IllegalStateException("MemcpyAsync failed: " + point.getShape()); - - commitTransfer(context.getSpecialStream()); - point.tickDeviceRead(); - - PerformanceTracker.getInstance().helperRegisterTransaction(point.getDeviceId(), perfD, point.getNumberOfBytes(), MemcpyDirection.HOST_TO_DEVICE); - } - } + NativeOpsHolder.getInstance().getDeviceNativeOps().dbSyncToSpecial(point.getPtrDataBuffer()); } @Override @@ -147,7 +107,6 @@ public class SynchronousFlowController implements FlowController { val pointData = allocator.getAllocationPoint(operand); val pointShape = allocator.getAllocationPoint(operand.shapeInfoDataBuffer()); - pointData.acquireLock(); if (pointData.getDeviceId() != cId && pointData.getDeviceId() >= 0) { DataBuffer buffer = operand.data().originalDataBuffer() == null ? operand.data() @@ -172,15 +131,12 @@ public class SynchronousFlowController implements FlowController { val cId = allocator.getDeviceId(); - if (result != null && !result.isEmpty() && !result.isS()) { + if (result != null && !result.isEmpty()) { Nd4j.getCompressor().autoDecompress(result); prepareDelayedMemory(result); val pointData = allocator.getAllocationPoint(result); val pointShape = allocator.getAllocationPoint(result.shapeInfoDataBuffer()); - pointData.acquireLock(); - - if (pointData.getDeviceId() != cId && pointData.getDeviceId() >= 0 && (!CudaEnvironment.getInstance().getConfiguration().isCrossDeviceAccessAllowed() || !NativeOpsHolder.getInstance().getDeviceNativeOps().isP2PAvailable())) { DataBuffer buffer = result.data().originalDataBuffer() == null ? result.data() : result.data().originalDataBuffer(); @@ -206,8 +162,7 @@ public class SynchronousFlowController implements FlowController { val pointData = allocator.getAllocationPoint(operand); val pointShape = allocator.getAllocationPoint(operand.shapeInfoDataBuffer()); - - pointData.acquireLock(); + Nd4j.getAffinityManager().ensureLocation(operand, AffinityManager.Location.DEVICE); if (pointData.getDeviceId() != cId && pointData.getDeviceId() >= 0 && (!CudaEnvironment.getInstance().getConfiguration().isCrossDeviceAccessAllowed() || !NativeOpsHolder.getInstance().getDeviceNativeOps().isP2PAvailable())) { DataBuffer buffer = operand.data().originalDataBuffer() == null ? operand.data() @@ -240,14 +195,12 @@ public class SynchronousFlowController implements FlowController { eventsProvider.storeEvent(result.getLastWriteEvent()); result.setLastWriteEvent(eventsProvider.getEvent()); result.getLastWriteEvent().register(context.getOldStream()); - result.releaseLock(); for (AllocationPoint operand : operands) { eventsProvider.storeEvent(operand.getLastReadEvent()); operand.setLastReadEvent(eventsProvider.getEvent()); operand.getLastReadEvent().register(context.getOldStream()); - operand.releaseLock(); } // context.syncOldStream(); } @@ -263,7 +216,6 @@ public class SynchronousFlowController implements FlowController { eventsProvider.storeEvent(pointOperand.getLastWriteEvent()); pointOperand.setLastWriteEvent(eventsProvider.getEvent()); pointOperand.getLastWriteEvent().register(context.getOldStream()); - pointOperand.releaseLock(); } } @@ -276,14 +228,12 @@ public class SynchronousFlowController implements FlowController { eventsProvider.storeEvent(point.getLastWriteEvent()); point.setLastWriteEvent(eventsProvider.getEvent()); point.getLastWriteEvent().register(context.getOldStream()); - point.releaseLock(); for (INDArray operand : operands) { if (operand == null || operand.isEmpty()) continue; val pointOperand = allocator.getAllocationPoint(operand); - pointOperand.releaseLock(); eventsProvider.storeEvent(pointOperand.getLastReadEvent()); pointOperand.setLastReadEvent(eventsProvider.getEvent()); pointOperand.getLastReadEvent().register(context.getOldStream()); @@ -295,7 +245,6 @@ public class SynchronousFlowController implements FlowController { val context = allocator.getDeviceContext(); if (result != null) { - result.acquireLock(); result.setCurrentContext(context); } @@ -303,7 +252,6 @@ public class SynchronousFlowController implements FlowController { if (operand == null) continue; - operand.acquireLock(); operand.setCurrentContext(context); } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/impl/CudaZeroHandler.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/impl/CudaZeroHandler.java index 9b8c1012c..f1cbf4958 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/impl/CudaZeroHandler.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/impl/CudaZeroHandler.java @@ -16,6 +16,7 @@ package org.nd4j.jita.handler.impl; +import lombok.var; import org.nd4j.nativeblas.OpaqueLaunchContext; import org.nd4j.shade.guava.collect.HashBasedTable; import org.nd4j.shade.guava.collect.Table; @@ -44,9 +45,6 @@ import org.nd4j.jita.flow.FlowController; import org.nd4j.jita.flow.impl.GridFlowController; import org.nd4j.jita.handler.MemoryHandler; import org.nd4j.jita.memory.MemoryProvider; -import org.nd4j.jita.memory.impl.CudaCachingZeroProvider; -import org.nd4j.jita.memory.impl.CudaDirectProvider; -import org.nd4j.jita.memory.impl.CudaFullCachingProvider; import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.concurrency.AffinityManager; import org.nd4j.linalg.api.memory.MemoryWorkspace; @@ -99,9 +97,6 @@ public class CudaZeroHandler implements MemoryHandler { private final AtomicBoolean wasInitialised = new AtomicBoolean(false); - @Getter - private final MemoryProvider memoryProvider; - private final FlowController flowController; private final AllocationStatus INITIAL_LOCATION; @@ -148,20 +143,6 @@ public class CudaZeroHandler implements MemoryHandler { throw new RuntimeException("Unknown ExecutionModel: [" + configuration.getExecutionModel() + "]"); } - switch (configuration.getAllocationModel()) { - case CACHE_ALL: - this.memoryProvider = new CudaFullCachingProvider(); - break; - case CACHE_HOST: - this.memoryProvider = new CudaCachingZeroProvider(); - break; - case DIRECT: - this.memoryProvider = new CudaDirectProvider(); - break; - default: - throw new RuntimeException("Unknown AllocationModel: [" + configuration.getAllocationModel() + "]"); - } - int numDevices = NativeOpsHolder.getInstance().getDeviceNativeOps().getAvailableDevices(); for (int i = 0; i < numDevices; i++) { deviceAllocations.add(new ConcurrentHashMap()); @@ -191,7 +172,7 @@ public class CudaZeroHandler implements MemoryHandler { int numBuckets = configuration.getNumberOfGcThreads(); long bucketId = RandomUtils.nextInt(0, numBuckets); - long reqMemory = AllocationUtils.getRequiredMemory(point.getShape()); + long reqMemory = point.getNumberOfBytes(); zeroUseCounter.addAndGet(reqMemory); @@ -221,130 +202,7 @@ public class CudaZeroHandler implements MemoryHandler { public PointersPair alloc(AllocationStatus targetMode, AllocationPoint point, AllocationShape shape, boolean initialize) { - long reqMemory = AllocationUtils.getRequiredMemory(shape); - val context = getCudaContext(); - switch (targetMode) { - case HOST: { - if (MemoryTracker.getInstance().getActiveHostAmount() + reqMemory >= configuration.getMaximumZeroAllocation()) { - - while (MemoryTracker.getInstance().getActiveHostAmount() + reqMemory >= configuration.getMaximumZeroAllocation()) { - - val before = MemoryTracker.getInstance().getActiveHostAmount(); - memoryProvider.purgeCache(); - Nd4j.getMemoryManager().invokeGc(); - val after = MemoryTracker.getInstance().getActiveHostAmount(); - - log.debug("[HOST] before: {}; after: {};", before, after); - - if (MemoryTracker.getInstance().getActiveHostAmount() + reqMemory >= configuration.getMaximumZeroAllocation()) { - try { - log.warn("No available [HOST] memory, sleeping for a while... Consider increasing -Xmx next time."); - log.debug("Currently used: [" + zeroUseCounter.get() + "], allocated objects: [" + zeroAllocations.get(0) + "]"); - - memoryProvider.purgeCache(); - Nd4j.getMemoryManager().invokeGc(); - Thread.sleep(1000); - } catch (Exception e) { - throw new RuntimeException(e); - } - } - } - } - - PointersPair pair = memoryProvider.malloc(shape, point, targetMode); - - if (initialize) { - org.bytedeco.javacpp.Pointer.memset(pair.getHostPointer(), 0, reqMemory); - point.tickHostWrite(); - } - - - pickupHostAllocation(point); - - return pair; - } - case DEVICE: { - int deviceId = getDeviceId(); - - PointersPair returnPair = new PointersPair(); - PointersPair tmpPair = new PointersPair(); - - if (point.getPointers() == null) - point.setPointers(tmpPair); - - if (deviceMemoryTracker.reserveAllocationIfPossible(Thread.currentThread().getId(), deviceId, reqMemory)) { - point.setDeviceId(deviceId); - val pair = memoryProvider.malloc(shape, point, targetMode); - if (pair != null) { - returnPair.setDevicePointer(pair.getDevicePointer()); - - point.setAllocationStatus(AllocationStatus.DEVICE); - - if (point.getPointers() == null) - throw new RuntimeException("PointersPair can't be null"); - - point.getPointers().setDevicePointer(pair.getDevicePointer()); - - deviceAllocations.get(deviceId).put(point.getObjectId(), point.getObjectId()); - - - val p = point.getBucketId(); - - if (p != null) { - val m = zeroAllocations.get(point.getBucketId()); - - // m can be null, if that's point from workspace - just no bucketId for it - if (m != null) - m.remove(point.getObjectId()); - } - - deviceMemoryTracker.addToAllocation(Thread.currentThread().getId(), deviceId, reqMemory); - - if (!initialize) { - point.tickDeviceWrite(); - } else { - nativeOps.memsetAsync(pair.getDevicePointer(), 0, reqMemory, 0, context.getSpecialStream()); - context.getSpecialStream().synchronize(); - - point.tickDeviceWrite(); - } - } else { - log.warn("Out of [DEVICE] memory, host memory will be used instead: deviceId: [{}], requested bytes: [{}]; Approximate free bytes: {}; Real free bytes: {}", deviceId, reqMemory, MemoryTracker.getInstance().getApproximateFreeMemory(deviceId), MemoryTracker.getInstance().getPreciseFreeMemory(deviceId)); - log.info("Total allocated dev_0: {}", MemoryTracker.getInstance().getActiveMemory(0)); - log.info("Cached dev_0: {}", MemoryTracker.getInstance().getCachedAmount(0)); - log.info("Allocated dev_0: {}", MemoryTracker.getInstance().getAllocatedAmount(0)); - log.info("Workspace dev_0: {}", MemoryTracker.getInstance().getWorkspaceAllocatedAmount(0)); - //log.info("Total allocated dev_1: {}", MemoryTracker.getInstance().getActiveMemory(1)); - // if device memory allocation failed (aka returned NULL), keep using host memory instead - - returnPair.setDevicePointer(tmpPair.getHostPointer()); - - point.setAllocationStatus(AllocationStatus.HOST); - - Nd4j.getMemoryManager().invokeGc(); - try { - Thread.sleep(100); - } catch (Exception e) { - - } - } - } else { - log.warn("Hard limit on [DEVICE] memory hit, please consider tuning memory parameters, deviceId [{}]", - deviceId); - - Nd4j.getMemoryManager().invokeGc(); - try { - Thread.sleep(100); - } catch (InterruptedException e) { - // - } - } - - return returnPair; - } - default: - throw new IllegalStateException("Can't allocate memory on target [" + targetMode + "]"); - } + throw new UnsupportedOperationException(); } /** @@ -356,7 +214,7 @@ public class CudaZeroHandler implements MemoryHandler { */ @Override public boolean pingDeviceForFreeMemory(Integer deviceId, long requiredMemory) { - return memoryProvider.pingDeviceForFreeMemory(deviceId, requiredMemory); + return true; } /** @@ -371,47 +229,7 @@ public class CudaZeroHandler implements MemoryHandler { @Override public void relocate(AllocationStatus currentStatus, AllocationStatus targetStatus, AllocationPoint point, AllocationShape shape, CudaContext context) { - //log.info("RELOCATE CALLED: [" +currentStatus+ "] -> ["+targetStatus+"]"); - if (currentStatus == AllocationStatus.DEVICE && targetStatus == AllocationStatus.HOST) { - // DEVICE -> HOST - DataBuffer targetBuffer = point.getBuffer(); - if (targetBuffer == null) - throw new IllegalStateException("Target buffer is NULL!"); - - Pointer devicePointer = new CudaPointer(point.getPointers().getDevicePointer().address()); - - } else if (currentStatus == AllocationStatus.HOST && targetStatus == AllocationStatus.DEVICE) { - // HOST -> DEVICE - - - // TODO: this probably should be removed - if (point.isConstant()) { - //log.info("Skipping relocation for constant"); - return; - } - - if (point.getPointers().getDevicePointer() == null) { - throw new IllegalStateException("devicePointer is NULL!"); - } - - val profD = PerformanceTracker.getInstance().helperStartTransaction(); - - if (nativeOps.memcpyAsync(point.getPointers().getDevicePointer(), point.getPointers().getHostPointer(), - AllocationUtils.getRequiredMemory(shape), CudaConstants.cudaMemcpyHostToDevice, - context.getSpecialStream()) == 0) - throw new IllegalStateException("MemcpyAsync relocate H2D failed: [" + point.getHostPointer().address() - + "] -> [" + point.getDevicePointer().address() + "]"); - - flowController.commitTransfer(context.getSpecialStream()); - - PerformanceTracker.getInstance().helperRegisterTransaction(point.getDeviceId(), profD, point.getNumberOfBytes(), MemcpyDirection.HOST_TO_DEVICE); - - //context.syncOldStream(); - - } else - throw new UnsupportedOperationException("Can't relocate data in requested direction: [" + currentStatus - + "] -> [" + targetStatus + "]"); } /** @@ -440,11 +258,6 @@ public class CudaZeroHandler implements MemoryHandler { @Override @Deprecated public void copyforward(AllocationPoint point, AllocationShape shape) { - /* - Technically that's just a case for relocate, with source as HOST and target point.getAllocationStatus() - */ - log.info("copyforward() called on tp[" + point.getObjectId() + "], shape: " + point.getShape()); - //relocate(AllocationStatus.HOST, point.getAllocationStatus(), point, shape); throw new UnsupportedOperationException("Deprecated call"); } @@ -467,15 +280,7 @@ public class CudaZeroHandler implements MemoryHandler { */ @Override public void free(AllocationPoint point, AllocationStatus target) { - //if (point.getAllocationStatus() == AllocationStatus.DEVICE) - //deviceAllocations.get(point.getDeviceId()).remove(point.getObjectId()); - //zeroAllocations.get(point.getBucketId()).remove(point.getObjectId()); - if (point.getAllocationStatus() == AllocationStatus.DEVICE) - deviceMemoryTracker.subFromAllocation(Thread.currentThread().getId(), point.getDeviceId(), - AllocationUtils.getRequiredMemory(point.getShape())); - - memoryProvider.free(point); } /** @@ -525,7 +330,7 @@ public class CudaZeroHandler implements MemoryHandler { CudaContext tContext = null; if (dstBuffer.isConstant()) { - org.bytedeco.javacpp.Pointer dstPointer = new CudaPointer(point.getPointers().getHostPointer().address() + dstOffset, 0L); + org.bytedeco.javacpp.Pointer dstPointer = new CudaPointer(point.getHostPointer().address() + dstOffset, 0L); org.bytedeco.javacpp.Pointer srcPointerJ = new CudaPointer(srcPointer, length); val profD = PerformanceTracker.getInstance().helperStartTransaction(); @@ -534,14 +339,34 @@ public class CudaZeroHandler implements MemoryHandler { point.tickHostRead(); } else { + // if we're copying something into host memory, but we're on device - we need to provide exact copy to device as well + Pointer rDP = new CudaPointer(point.getDevicePointer().address() + dstOffset); + + if (tContext == null) + tContext = flowController.prepareAction(point); + + var prof = PerformanceTracker.getInstance().helperStartTransaction(); + + flowController.commitTransfer(tContext.getSpecialStream()); + + if (nativeOps.memcpyAsync(rDP, srcPointer, length, CudaConstants.cudaMemcpyHostToDevice, tContext.getSpecialStream()) == 0) + throw new IllegalStateException("MemcpyAsync H2D failed: [" + srcPointer.address() + "] -> [" + rDP.address() + "]"); + + flowController.commitTransfer(tContext.getSpecialStream()); + + PerformanceTracker.getInstance().helperRegisterTransaction(point.getDeviceId(), prof, point.getNumberOfBytes(), MemcpyDirection.HOST_TO_DEVICE); + + flowController.registerAction(tContext, point); + point.tickDeviceWrite(); + // we optionally copy to host memory - if (point.getPointers().getHostPointer() != null) { - Pointer dP = new CudaPointer((point.getPointers().getHostPointer().address()) + dstOffset); + if (point.getHostPointer() != null) { + Pointer dP = new CudaPointer((point.getHostPointer().address()) + dstOffset); CudaContext context = flowController.prepareAction(point); tContext = context; - val prof = PerformanceTracker.getInstance().helperStartTransaction(); + prof = PerformanceTracker.getInstance().helperStartTransaction(); if (nativeOps.memcpyAsync(dP, srcPointer, length, CudaConstants.cudaMemcpyHostToHost, context.getSpecialStream()) == 0) throw new IllegalStateException("MemcpyAsync H2H failed: [" + srcPointer.address() + "] -> [" + dP.address() + "]"); @@ -552,28 +377,10 @@ public class CudaZeroHandler implements MemoryHandler { if (point.getAllocationStatus() == AllocationStatus.HOST) flowController.registerAction(context, point); + + point.tickHostRead(); } } - - // if we're copying something into host memory, but we're on device - we need to provide exact copy to device as well - if (point.getAllocationStatus() == AllocationStatus.DEVICE) { - Pointer rDP = new CudaPointer(point.getPointers().getDevicePointer().address() + dstOffset); - - if (tContext == null) - tContext = flowController.prepareAction(point); - - val prof = PerformanceTracker.getInstance().helperStartTransaction(); - - if (nativeOps.memcpyAsync(rDP, srcPointer, length, CudaConstants.cudaMemcpyHostToDevice, tContext.getSpecialStream()) == 0) - throw new IllegalStateException("MemcpyAsync H2D failed: [" + srcPointer.address() + "] -> [" + rDP.address() + "]"); - - flowController.commitTransfer(tContext.getSpecialStream()); - - PerformanceTracker.getInstance().helperRegisterTransaction(point.getDeviceId(), prof, point.getNumberOfBytes(),MemcpyDirection.HOST_TO_DEVICE); - - flowController.registerAction(tContext, point); - point.tickDeviceWrite(); - } } @Override @@ -581,7 +388,7 @@ public class CudaZeroHandler implements MemoryHandler { CudaContext context) { AllocationPoint point = ((BaseCudaDataBuffer) dstBuffer).getAllocationPoint(); - Pointer dP = new CudaPointer((point.getPointers().getDevicePointer().address()) + dstOffset); + Pointer dP = new CudaPointer((point.getDevicePointer().address()) + dstOffset); if (nativeOps.memcpyAsync(dP, srcPointer, length, CudaConstants.cudaMemcpyDeviceToDevice, context.getOldStream()) == 0) throw new ND4JIllegalStateException("memcpyAsync failed"); @@ -604,7 +411,7 @@ public class CudaZeroHandler implements MemoryHandler { CudaContext context = getCudaContext(); AllocationPoint point = ((BaseCudaDataBuffer) dstBuffer).getAllocationPoint(); - Pointer dP = new CudaPointer((point.getPointers().getHostPointer().address()) + dstOffset); + Pointer dP = new CudaPointer((point.getHostPointer().address()) + dstOffset); val profH = PerformanceTracker.getInstance().helperStartTransaction(); @@ -614,7 +421,7 @@ public class CudaZeroHandler implements MemoryHandler { PerformanceTracker.getInstance().helperRegisterTransaction(point.getDeviceId(), profH, point.getNumberOfBytes(),MemcpyDirection.HOST_TO_HOST); if (point.getAllocationStatus() == AllocationStatus.DEVICE) { - Pointer rDP = new CudaPointer(point.getPointers().getDevicePointer().address() + dstOffset); + Pointer rDP = new CudaPointer(point.getDevicePointer().address() + dstOffset); val profD = PerformanceTracker.getInstance().helperStartTransaction(); @@ -717,23 +524,22 @@ public class CudaZeroHandler implements MemoryHandler { @Override public org.bytedeco.javacpp.Pointer getDevicePointer(DataBuffer buffer, CudaContext context) { // TODO: It would be awesome to get rid of typecasting here - //getCudaContext().syncOldStream(); AllocationPoint dstPoint = ((BaseCudaDataBuffer) buffer).getAllocationPoint(); // if that's device state, we probably might want to update device memory state if (dstPoint.getAllocationStatus() == AllocationStatus.DEVICE) { if (!dstPoint.isActualOnDeviceSide()) { - // log.info("Relocating to GPU"); - relocate(AllocationStatus.HOST, AllocationStatus.DEVICE, dstPoint, dstPoint.getShape(), context); + //relocate(AllocationStatus.HOST, AllocationStatus.DEVICE, dstPoint, dstPoint.getShape(), context); + throw new UnsupportedOperationException("Pew-pew"); } } - // we update memory use counter, to announce that it's somehow used on device - dstPoint.tickDeviceRead(); + if (dstPoint.getDevicePointer() == null) + return null; - // return pointer with offset if needed. length is specified for constructor compatibility purposes - val p = new CudaPointer(dstPoint.getPointers().getDevicePointer(), buffer.length(), - (buffer.offset() * buffer.getElementSize())); + + // return pointer. length is specified for constructor compatibility purposes. Offset is accounted at C++ side + val p = new CudaPointer(dstPoint.getDevicePointer(), buffer.length(), 0); if (OpProfiler.getInstance().getConfig().isCheckLocality()) NativeOpsHolder.getInstance().getDeviceNativeOps().tryPointer(context.getOldStream(), p, 1); @@ -749,10 +555,17 @@ public class CudaZeroHandler implements MemoryHandler { case SHORT: case UINT16: case HALF: + case BFLOAT16: return p.asShortPointer(); case UINT64: case LONG: return p.asLongPointer(); + case UTF8: + case UBYTE: + case BYTE: + return p.asBytePointer(); + case BOOL: + return p.asBooleanPointer(); default: return p; } @@ -769,17 +582,14 @@ public class CudaZeroHandler implements MemoryHandler { AllocationPoint dstPoint = ((BaseCudaDataBuffer) buffer).getAllocationPoint(); // return pointer with offset if needed. length is specified for constructor compatibility purposes - if (dstPoint.getPointers().getHostPointer() == null) { + if (dstPoint.getHostPointer() == null) { return null; } - //dstPoint.tickHostWrite(); - //dstPoint.tickHostRead(); - //log.info("Requesting host pointer for {}", buffer); - //getCudaContext().syncOldStream(); + synchronizeThreadDevice(Thread.currentThread().getId(), dstPoint.getDeviceId(), dstPoint); - CudaPointer p = new CudaPointer(dstPoint.getPointers().getHostPointer(), buffer.length(), - (buffer.offset() * buffer.getElementSize())); + CudaPointer p = new CudaPointer(dstPoint.getHostPointer(), buffer.length(), 0); + switch (buffer.dataType()) { case DOUBLE: return p.asDoublePointer(); @@ -805,6 +615,9 @@ public class CudaZeroHandler implements MemoryHandler { public synchronized void relocateObject(DataBuffer buffer) { AllocationPoint dstPoint = AtomicAllocator.getInstance().getAllocationPoint(buffer); + if (1 > 0) + throw new UnsupportedOperationException("Pew-pew"); + // we don't relocate non-DEVICE buffers (i.e HOST or CONSTANT) if (dstPoint.getAllocationStatus() != AllocationStatus.DEVICE) return; @@ -838,14 +651,14 @@ public class CudaZeroHandler implements MemoryHandler { // if we're out of workspace, we should mark our buffer as detached, so gc will pick it up eventually // host part is optional if (dstPoint.getHostPointer() != null) { - val pairH = alloc(AllocationStatus.HOST, dstPoint, dstPoint.getShape(), false); - dstPoint.getPointers().setHostPointer(pairH.getHostPointer()); + //val pairH = alloc(AllocationStatus.HOST, dstPoint, dstPoint.getShape(), false); + //dstPoint.getPointers().setHostPointer(pairH.getHostPointer()); } - val pairD = alloc(AllocationStatus.DEVICE, dstPoint, dstPoint.getShape(), false); - dstPoint.getPointers().setDevicePointer(pairD.getDevicePointer()); + //val pairD = alloc(AllocationStatus.DEVICE, dstPoint, dstPoint.getShape(), false); + //dstPoint.getPointers().setDevicePointer(pairD.getDevicePointer()); - //log.info("New host pointer: {}; Old host pointer: {}", dstPoint.getHostPointer().address(), ohPtr.address()); + ////log.info("New host pointer: {}; Old host pointer: {}", dstPoint.getHostPointer().address(), ohPtr.address()); CudaContext context = getCudaContext(); @@ -876,10 +689,10 @@ public class CudaZeroHandler implements MemoryHandler { Nd4j.getMemoryManager().memcpy(nBuffer, buffer); - dstPoint.getPointers().setDevicePointer(nBuffer.getAllocationPoint().getDevicePointer()); + //dstPoint.getPointers().setDevicePointer(nBuffer.getAllocationPoint().getDevicePointer()); if (dstPoint.getHostPointer() != null) { - dstPoint.getPointers().setHostPointer(nBuffer.getAllocationPoint().getHostPointer()); + // dstPoint.getPointers().setHostPointer(nBuffer.getAllocationPoint().getHostPointer()); } dstPoint.setDeviceId(deviceId); @@ -908,11 +721,10 @@ public class CudaZeroHandler implements MemoryHandler { context.syncSpecialStream(); } - memoryProvider.free(dstPoint); - deviceMemoryTracker.subFromAllocation(Thread.currentThread().getId(), dstPoint.getDeviceId(), AllocationUtils.getRequiredMemory(dstPoint.getShape())); + //deviceMemoryTracker.subFromAllocation(Thread.currentThread().getId(), dstPoint.getDeviceId(), AllocationUtils.getRequiredMemory(dstPoint.getShape())); // we replace original device pointer with new one - alloc(AllocationStatus.DEVICE, dstPoint, dstPoint.getShape(), false); + //alloc(AllocationStatus.DEVICE, dstPoint, dstPoint.getShape(), false); val profD = PerformanceTracker.getInstance().helperStartTransaction(); @@ -940,6 +752,9 @@ public class CudaZeroHandler implements MemoryHandler { public boolean promoteObject(DataBuffer buffer) { AllocationPoint dstPoint = AtomicAllocator.getInstance().getAllocationPoint(buffer); + if (1 > 0) + throw new UnsupportedOperationException("Pew-pew"); + if (dstPoint.getAllocationStatus() != AllocationStatus.HOST) return false; @@ -952,20 +767,19 @@ public class CudaZeroHandler implements MemoryHandler { Nd4j.getConstantHandler().moveToConstantSpace(buffer); } else { - PointersPair pair = memoryProvider.malloc(dstPoint.getShape(), dstPoint, AllocationStatus.DEVICE); + PointersPair pair = null; //memoryProvider.malloc(dstPoint.getShape(), dstPoint, AllocationStatus.DEVICE); if (pair != null) { Integer deviceId = getDeviceId(); // log.info("Promoting object to device: [{}]", deviceId); - dstPoint.getPointers().setDevicePointer(pair.getDevicePointer()); + //dstPoint.setDevicePointer(pair.getDevicePointer()); dstPoint.setAllocationStatus(AllocationStatus.DEVICE); deviceAllocations.get(deviceId).put(dstPoint.getObjectId(), dstPoint.getObjectId()); zeroAllocations.get(dstPoint.getBucketId()).remove(dstPoint.getObjectId()); - deviceMemoryTracker.addToAllocation(Thread.currentThread().getId(), deviceId, - AllocationUtils.getRequiredMemory(dstPoint.getShape())); + //deviceMemoryTracker.addToAllocation(Thread.currentThread().getId(), deviceId, AllocationUtils.getRequiredMemory(dstPoint.getShape())); dstPoint.tickHostWrite(); @@ -1103,7 +917,7 @@ public class CudaZeroHandler implements MemoryHandler { if (deviceAllocations.get(deviceId).containsKey(objectId)) throw new IllegalStateException("Can't happen ever"); - deviceMemoryTracker.subFromAllocation(threadId, deviceId, AllocationUtils.getRequiredMemory(point.getShape())); + //deviceMemoryTracker.subFromAllocation(threadId, deviceId, AllocationUtils.getRequiredMemory(point.getShape())); point.setAllocationStatus(AllocationStatus.HOST); @@ -1119,6 +933,9 @@ public class CudaZeroHandler implements MemoryHandler { */ @Override public void purgeZeroObject(Long bucketId, Long objectId, AllocationPoint point, boolean copyback) { + if (1 > 0) + throw new UnsupportedOperationException("Pew-pew"); + forget(point, AllocationStatus.HOST); flowController.waitTillReleased(point); @@ -1127,8 +944,8 @@ public class CudaZeroHandler implements MemoryHandler { if (point.getHostPointer() != null) { free(point, AllocationStatus.HOST); - long reqMem = AllocationUtils.getRequiredMemory(point.getShape()) * -1; - zeroUseCounter.addAndGet(reqMem); + //long reqMem = AllocationUtils.getRequiredMemory(point.getShape()) * -1; + //zeroUseCounter.addAndGet(reqMem); } point.setAllocationStatus(AllocationStatus.DEALLOCATED); @@ -1252,4 +1069,9 @@ public class CudaZeroHandler implements MemoryHandler { public FlowController getFlowController() { return flowController; } + + @Override + public MemoryProvider getMemoryProvider() { + return null; + } } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/memory/CudaMemoryManager.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/memory/CudaMemoryManager.java index da36da6db..ad820c109 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/memory/CudaMemoryManager.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/memory/CudaMemoryManager.java @@ -147,7 +147,7 @@ public class CudaMemoryManager extends BasicMemoryManager { // Nd4j.getShapeInfoProvider().purgeCache(); // purge memory cache - AtomicAllocator.getInstance().getMemoryHandler().getMemoryProvider().purgeCache(); + //AtomicAllocator.getInstance().getMemoryHandler().getMemoryProvider().purgeCache(); } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/memory/impl/CudaCachingZeroProvider.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/memory/impl/CudaCachingZeroProvider.java deleted file mode 100644 index 1ba6bf34a..000000000 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/memory/impl/CudaCachingZeroProvider.java +++ /dev/null @@ -1,303 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.jita.memory.impl; - -import lombok.val; -import org.bytedeco.javacpp.Pointer; -import org.nd4j.jita.allocator.enums.AllocationStatus; -import org.nd4j.jita.allocator.impl.AllocationPoint; -import org.nd4j.jita.allocator.impl.AllocationShape; -import org.nd4j.jita.allocator.pointers.CudaPointer; -import org.nd4j.jita.allocator.pointers.PointersPair; -import org.nd4j.jita.allocator.utils.AllocationUtils; -import org.nd4j.jita.conf.Configuration; -import org.nd4j.jita.conf.CudaEnvironment; -import org.nd4j.jita.memory.MemoryProvider; -import org.slf4j.Logger; -import org.nd4j.linalg.factory.Nd4j; -import org.slf4j.LoggerFactory; - -import java.util.ArrayList; -import java.util.List; -import java.util.Queue; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ConcurrentLinkedQueue; -import java.util.concurrent.Semaphore; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.atomic.AtomicLong; -import org.nd4j.jita.allocator.impl.MemoryTracker; - - -/** - * This is MemoryProvider implementation, that adds cache for memory reuse purposes. Only host memory is cached for future reuse. - * - * If some memory chunk gets released via allocator, it'll be probably saved for future reused within same JVM process. - * - * @author raver119@gmail.com - */ -public class CudaCachingZeroProvider extends CudaDirectProvider implements MemoryProvider { - private static Logger log = LoggerFactory.getLogger(CudaCachingZeroProvider.class); - - protected volatile ConcurrentHashMap zeroCache = new ConcurrentHashMap<>(); - - protected final AtomicLong cacheZeroHit = new AtomicLong(0); - protected final AtomicLong cacheZeroMiss = new AtomicLong(0); - - protected final AtomicLong cacheDeviceHit = new AtomicLong(0); - protected final AtomicLong cacheDeviceMiss = new AtomicLong(0); - - - - private final AtomicLong allocRequests = new AtomicLong(0); - - protected final AtomicLong zeroCachedAmount = new AtomicLong(0); - protected List deviceCachedAmount = new ArrayList<>(); - - - protected final Semaphore singleLock = new Semaphore(1); - - // we don't cache allocations greater then this value - //protected final long MAX_SINGLE_ALLOCATION = configuration.getMaximumHostCacheableLength(); - - // maximum cached size of memory - //protected final long MAX_CACHED_MEMORY = configuration.getMaximumHostCache(); - - // memory chunks below this threshold will be guaranteed regardless of number of cache entries - // that especially covers all possible variations of shapeInfoDataBuffers in all possible cases - protected final long FORCED_CACHE_THRESHOLD = 96; - - // number of preallocation entries for each yet-unknown shape - //protected final int PREALLOCATION_LIMIT = configuration.getPreallocationCalls(); - - public CudaCachingZeroProvider() { - - } - - /** - * This method provides PointersPair to memory chunk specified by AllocationShape - * - * PLEASE NOTE: This method can actually ignore malloc request, and give out previously cached free memory chunk with equal shape. - * - * @param shape shape of desired memory chunk - * @param point target AllocationPoint structure - * @param location either HOST or DEVICE - * @return - */ - @Override - public PointersPair malloc(AllocationShape shape, AllocationPoint point, AllocationStatus location) { - long reqMemory = AllocationUtils.getRequiredMemory(shape); - - if (location == AllocationStatus.HOST && reqMemory < CudaEnvironment.getInstance().getConfiguration().getMaximumHostCacheableLength()) { - - val cache = zeroCache.get(shape); - if (cache != null) { - val pointer = cache.poll(); - if (pointer != null) { - cacheZeroHit.incrementAndGet(); - - // since this memory chunk is going to be used now, remove it's amount from - zeroCachedAmount.addAndGet(-1 * reqMemory); - - val pair = new PointersPair(); - pair.setDevicePointer(new CudaPointer(pointer.address())); - pair.setHostPointer(new CudaPointer(pointer.address())); - - point.setAllocationStatus(AllocationStatus.HOST); - - MemoryTracker.getInstance().incrementAllocatedHostAmount(reqMemory); - MemoryTracker.getInstance().decrementCachedHostAmount(reqMemory); - - return pair; - } - } - cacheZeroMiss.incrementAndGet(); - - if (CudaEnvironment.getInstance().getConfiguration().isUsePreallocation() && zeroCachedAmount.get() < CudaEnvironment.getInstance().getConfiguration().getMaximumHostCache() / 10 - && reqMemory < 16 * 1024 * 1024L) { - val preallocator = new CachePreallocator(shape, location, CudaEnvironment.getInstance().getConfiguration().getPreallocationCalls()); - preallocator.start(); - } - - cacheZeroMiss.incrementAndGet(); - return super.malloc(shape, point, location); - } - - return super.malloc(shape, point, location); - } - - - - protected void ensureCacheHolder(AllocationShape shape) { - if (!zeroCache.containsKey(shape)) { - try { - singleLock.acquire(); - if (!zeroCache.containsKey(shape)) { - zeroCache.put(shape, new CacheHolder(shape, zeroCachedAmount)); - } - } catch (Exception e) { - throw new RuntimeException(e); - } finally { - singleLock.release(); - } - } - - } - - /** - * This method frees specific chunk of memory, described by AllocationPoint passed in. - * - * PLEASE NOTE: This method can actually ignore free, and keep released memory chunk for future reuse. - * - * @param point - */ - @Override - public void free(AllocationPoint point) { - if (point.getAllocationStatus() == AllocationStatus.DEVICE) { - super.free(point); - } else { - // if this point has no allocated chunk - step over it - if (point.getHostPointer() == null) - return; - - AllocationShape shape = point.getShape(); - long reqMemory = AllocationUtils.getRequiredMemory(shape); - - // we don't cache too big objects - if (reqMemory > CudaEnvironment.getInstance().getConfiguration().getMaximumHostCacheableLength() || zeroCachedAmount.get() >= CudaEnvironment.getInstance().getConfiguration().getMaximumHostCache()) { - super.free(point); - return; - } - - ensureCacheHolder(shape); - - /* - Now we should decide if this object can be cached or not - */ - CacheHolder cache = zeroCache.get(shape); - - // memory chunks < threshold will be cached no matter what - if (reqMemory <= FORCED_CACHE_THRESHOLD) { - Pointer.memset(point.getHostPointer(), 0, reqMemory); - cache.put(new CudaPointer(point.getHostPointer().address())); - } else { - long cacheEntries = cache.size(); - long cacheHeight = zeroCache.size(); - - // total memory allocated within this bucket - long cacheDepth = cacheEntries * reqMemory; - - Pointer.memset(point.getHostPointer(), 0, reqMemory); - cache.put(new CudaPointer(point.getHostPointer().address())); - } - - MemoryTracker.getInstance().decrementAllocatedHostAmount(reqMemory); - MemoryTracker.getInstance().incrementCachedHostAmount(reqMemory); - } - } - - private float getZeroCacheHitRatio() { - long totalHits = cacheZeroHit.get() + cacheZeroMiss.get(); - float cacheRatio = cacheZeroHit.get() * 100 / (float) totalHits; - return cacheRatio; - } - - private float getDeviceCacheHitRatio() { - long totalHits = cacheDeviceHit.get() + cacheDeviceMiss.get(); - float cacheRatio = cacheDeviceHit.get() * 100 / (float) totalHits; - return cacheRatio; - } - - @Deprecated - public void printCacheStats() { - log.debug("Cached host amount: " + zeroCachedAmount.get()); - log.debug("Cached device amount: " + deviceCachedAmount.get(0).get()); - log.debug("Total shapes in cache: " + zeroCache.size()); - log.debug("Current host hit ratio: " + getZeroCacheHitRatio()); - log.debug("Current device hit ratio: " + getDeviceCacheHitRatio()); - } - - protected class CacheHolder { - private Queue queue = new ConcurrentLinkedQueue<>(); - private volatile int counter = 0; - private long reqMem = 0; - private final AtomicLong allocCounter; - - public CacheHolder(AllocationShape shape, AtomicLong counter) { - this.reqMem = AllocationUtils.getRequiredMemory(shape); - this.allocCounter = counter; - } - - public synchronized int size() { - return counter; - } - - public synchronized Pointer poll() { - val pointer = queue.poll(); - if (pointer != null) - counter--; - - return pointer; - } - - public synchronized void put(Pointer pointer) { - allocCounter.addAndGet(reqMem); - counter++; - queue.add(pointer); - } - } - - protected class CachePreallocator extends Thread implements Runnable { - - private AllocationShape shape; - private AllocationStatus location; - private int target; - - public CachePreallocator(AllocationShape shape, AllocationStatus location, int numberOfEntries) { - this.shape = shape; - this.target = numberOfEntries; - this.location = location; - } - - @Override - public void run() { - ensureCacheHolder(shape); - - for (int i = 0; i < target; i++) { - val point = new AllocationPoint(); - - val pair = CudaCachingZeroProvider.super.malloc(shape, point, this.location); - if (this.location == AllocationStatus.HOST) { - Pointer pointer = new CudaPointer(pair.getHostPointer().address()); - CudaCachingZeroProvider.this.zeroCache.get(shape).put(pointer); - } - } - } - } - - @Override - public void purgeCache() { - for (AllocationShape shape : zeroCache.keySet()) { - Pointer ptr = null; - while ((ptr = zeroCache.get(shape).poll()) != null) { - freeHost(ptr); - MemoryTracker.getInstance().decrementCachedHostAmount(shape.getNumberOfBytes()); - } - } - - zeroCachedAmount.set(0); - } -} diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/memory/impl/CudaDirectProvider.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/memory/impl/CudaDirectProvider.java deleted file mode 100644 index eba4d74d0..000000000 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/memory/impl/CudaDirectProvider.java +++ /dev/null @@ -1,239 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.jita.memory.impl; - -import lombok.val; -import lombok.var; -import org.bytedeco.javacpp.Pointer; -import org.nd4j.jita.allocator.enums.AllocationStatus; -import org.nd4j.jita.allocator.impl.AllocationPoint; -import org.nd4j.jita.allocator.impl.AllocationShape; -import org.nd4j.jita.allocator.impl.AtomicAllocator; -import org.nd4j.jita.allocator.pointers.CudaPointer; -import org.nd4j.jita.allocator.pointers.PointersPair; -import org.nd4j.jita.allocator.utils.AllocationUtils; -import org.nd4j.jita.memory.MemoryProvider; -import org.nd4j.linalg.api.memory.AllocationsTracker; -import org.nd4j.linalg.api.memory.enums.AllocationKind; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.nativeblas.NativeOps; -import org.nd4j.nativeblas.NativeOpsHolder; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import org.nd4j.jita.allocator.impl.MemoryTracker; - -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.atomic.AtomicLong; - -/** - * @author raver119@gmail.com - */ -public class CudaDirectProvider implements MemoryProvider { - - protected static final long DEVICE_RESERVED_SPACE = 1024 * 1024 * 50L; - private static Logger log = LoggerFactory.getLogger(CudaDirectProvider.class); - protected NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps(); - - protected volatile ConcurrentHashMap validator = new ConcurrentHashMap<>(); - - - private AtomicLong emergencyCounter = new AtomicLong(0); - - /** - * This method provides PointersPair to memory chunk specified by AllocationShape - * - * @param shape shape of desired memory chunk - * @param point target AllocationPoint structure - * @param location either HOST or DEVICE - * @return - */ - @Override - public PointersPair malloc(AllocationShape shape, AllocationPoint point, AllocationStatus location) { - - //log.info("shape onCreate: {}, target: {}", shape, location); - - switch (location) { - case HOST: { - long reqMem = AllocationUtils.getRequiredMemory(shape); - - // FIXME: this is WRONG, and directly leads to memleak - if (reqMem < 1) - reqMem = 1; - - val pointer = nativeOps.mallocHost(reqMem, 0); - if (pointer == null) - throw new RuntimeException("Can't allocate [HOST] memory: " + reqMem + "; threadId: " - + Thread.currentThread().getId()); - - // log.info("Host allocation, Thread id: {}, ReqMem: {}, Pointer: {}", Thread.currentThread().getId(), reqMem, pointer != null ? pointer.address() : null); - - val hostPointer = new CudaPointer(pointer); - - val devicePointerInfo = new PointersPair(); - if (point.getPointers().getDevicePointer() == null) { - point.setAllocationStatus(AllocationStatus.HOST); - devicePointerInfo.setDevicePointer(new CudaPointer(hostPointer, reqMem)); - } else - devicePointerInfo.setDevicePointer(point.getDevicePointer()); - - devicePointerInfo.setHostPointer(new CudaPointer(hostPointer, reqMem)); - - point.setPointers(devicePointerInfo); - - MemoryTracker.getInstance().incrementAllocatedHostAmount(reqMem); - - return devicePointerInfo; - } - case DEVICE: { - // cudaMalloc call - val deviceId = AtomicAllocator.getInstance().getDeviceId(); - long reqMem = AllocationUtils.getRequiredMemory(shape); - - // FIXME: this is WRONG, and directly leads to memleak - if (reqMem < 1) - reqMem = 1; - - AllocationsTracker.getInstance().markAllocated(AllocationKind.GENERAL, deviceId, reqMem); - var pointer = nativeOps.mallocDevice(reqMem, deviceId, 0); - if (pointer == null) { - // try to purge stuff if we're low on memory - purgeCache(deviceId); - - // call for gc - Nd4j.getMemoryManager().invokeGc(); - - pointer = nativeOps.mallocDevice(reqMem, deviceId, 0); - if (pointer == null) - return null; - } - - val devicePointer = new CudaPointer(pointer); - - var devicePointerInfo = point.getPointers(); - if (devicePointerInfo == null) - devicePointerInfo = new PointersPair(); - devicePointerInfo.setDevicePointer(new CudaPointer(devicePointer, reqMem)); - - point.setAllocationStatus(AllocationStatus.DEVICE); - point.setDeviceId(deviceId); - MemoryTracker.getInstance().incrementAllocatedAmount(deviceId, reqMem); - return devicePointerInfo; - } - default: - throw new IllegalStateException("Unsupported location for malloc: [" + location + "]"); - } - } - - /** - * This method frees specific chunk of memory, described by AllocationPoint passed in - * - * @param point - */ - @Override - public void free(AllocationPoint point) { - switch (point.getAllocationStatus()) { - case HOST: { - // cudaFreeHost call here - long reqMem = AllocationUtils.getRequiredMemory(point.getShape()); - val nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps(); - - long result = nativeOps.freeHost(point.getPointers().getHostPointer()); - if (result == 0) { - throw new RuntimeException("Can't deallocate [HOST] memory..."); - } - - MemoryTracker.getInstance().decrementAllocatedHostAmount(reqMem); - } - break; - case DEVICE: { - if (point.isConstant()) - return; - - long reqMem = AllocationUtils.getRequiredMemory(point.getShape()); - - val nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps(); - AllocationsTracker.getInstance().markReleased(AllocationKind.GENERAL, point.getDeviceId(), reqMem); - - val pointers = point.getPointers(); - - long result = nativeOps.freeDevice(pointers.getDevicePointer(), 0); - if (result == 0) - throw new RuntimeException("Can't deallocate [DEVICE] memory..."); - - MemoryTracker.getInstance().decrementAllocatedAmount(point.getDeviceId(), reqMem); - } - break; - default: - throw new IllegalStateException("Can't free memory on target [" + point.getAllocationStatus() + "]"); - } - } - - /** - * This method checks specified device for specified amount of memory - * - * @param deviceId - * @param requiredMemory - * @return - */ - public boolean pingDeviceForFreeMemory(Integer deviceId, long requiredMemory) { - /* - long[] totalMem = new long[1]; - long[] freeMem = new long[1]; - - - JCuda.cudaMemGetInfo(freeMem, totalMem); - - long free = freeMem[0]; - long total = totalMem[0]; - long used = total - free; - - /* - We don't want to allocate memory if it's too close to the end of available ram. - */ - //if (configuration != null && used > total * configuration.getMaxDeviceMemoryUsed()) return false; - - /* - if (free + requiredMemory < total * 0.85) - return true; - else return false; - */ - long freeMem = nativeOps.getDeviceFreeMemory(-1); - if (freeMem - requiredMemory < DEVICE_RESERVED_SPACE) - return false; - else - return true; - } - - protected void freeHost(Pointer pointer) { - val nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps(); - nativeOps.freeHost(pointer); - } - - protected void freeDevice(Pointer pointer, int deviceId) { - val nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps(); - nativeOps.freeDevice(pointer, 0); - } - - protected void purgeCache(int deviceId) { - // - } - - @Override - public void purgeCache() { - // no-op - } -} diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/memory/impl/CudaFullCachingProvider.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/memory/impl/CudaFullCachingProvider.java deleted file mode 100644 index 2157dfb56..000000000 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/memory/impl/CudaFullCachingProvider.java +++ /dev/null @@ -1,220 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.jita.memory.impl; - -import lombok.val; -import org.bytedeco.javacpp.Pointer; -import org.nd4j.jita.allocator.enums.AllocationStatus; -import org.nd4j.jita.allocator.impl.AllocationPoint; -import org.nd4j.jita.allocator.impl.AllocationShape; -import org.nd4j.jita.allocator.impl.AtomicAllocator; -import org.nd4j.jita.allocator.impl.MemoryTracker; -import org.nd4j.jita.allocator.pointers.CudaPointer; -import org.nd4j.jita.allocator.pointers.PointersPair; -import org.nd4j.jita.allocator.utils.AllocationUtils; -import org.nd4j.jita.conf.CudaEnvironment; -import org.nd4j.linalg.factory.Nd4j; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.util.ArrayList; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.atomic.AtomicLong; - -/** - * This MemoryProvider implementation does caching for both host and device memory within predefined limits. - * - * @author raver119@gmail.com - */ -public class CudaFullCachingProvider extends CudaCachingZeroProvider { - - //protected final long MAX_GPU_ALLOCATION = configuration.getMaximumSingleDeviceAllocation(); - - //protected final long MAX_GPU_CACHE = configuration.getMaximumDeviceCache(); - - - protected volatile ConcurrentHashMap> deviceCache = - new ConcurrentHashMap<>(); - - - private static Logger log = LoggerFactory.getLogger(CudaFullCachingProvider.class); - - public CudaFullCachingProvider() { - - init(); - } - - public void init() { - int numDevices = Nd4j.getAffinityManager().getNumberOfDevices(); - - deviceCachedAmount = new ArrayList<>(); - - for (int i = 0; i < numDevices; i++) { - deviceCachedAmount.add(new AtomicLong(0)); - } - } - - /** - * This method provides PointersPair to memory chunk specified by AllocationShape - * - * PLEASE NOTE: This method can actually ignore malloc request, and give out previously cached free memory chunk with equal shape. - * - * @param shape shape of desired memory chunk - * @param point target AllocationPoint structure - * @param location either HOST or DEVICE - * @return - */ - @Override - public PointersPair malloc(AllocationShape shape, AllocationPoint point, AllocationStatus location) { - val reqMemory = AllocationUtils.getRequiredMemory(shape); - if (location == AllocationStatus.DEVICE && reqMemory < CudaEnvironment.getInstance().getConfiguration().getMaximumDeviceAllocation()) { - - - val deviceId = AtomicAllocator.getInstance().getDeviceId(); - ensureDeviceCacheHolder(deviceId, shape); - - val cache = deviceCache.get(deviceId).get(shape); - if (cache != null) { - val pointer = cache.poll(); - if (pointer != null) { - cacheDeviceHit.incrementAndGet(); - - deviceCachedAmount.get(deviceId).addAndGet(-reqMemory); - - val pair = new PointersPair(); - pair.setDevicePointer(pointer); - - point.setAllocationStatus(AllocationStatus.DEVICE); - point.setDeviceId(deviceId); - - - MemoryTracker.getInstance().incrementAllocatedAmount(deviceId, reqMemory); - MemoryTracker.getInstance().decrementCachedAmount(deviceId, reqMemory); - - return pair; - } - } - cacheDeviceMiss.incrementAndGet(); - return super.malloc(shape, point, location); - } - return super.malloc(shape, point, location); - } - - /** - * This method frees specific chunk of memory, described by AllocationPoint passed in - * - * PLEASE NOTE: This method can actually ignore free, and keep released memory chunk for future reuse. - * - * @param point - */ - @Override - public void free(AllocationPoint point) { - if (point.getAllocationStatus() == AllocationStatus.DEVICE) { - if (point.isConstant()) - return; - - val shape = point.getShape(); - val deviceId = point.getDeviceId(); - val address = point.getDevicePointer().address(); - val reqMemory = AllocationUtils.getRequiredMemory(shape); - // we don't cache too big objects - - if (reqMemory > CudaEnvironment.getInstance().getConfiguration().getMaximumDeviceCacheableLength() || deviceCachedAmount.get(deviceId).get() >= CudaEnvironment.getInstance().getConfiguration().getMaximumDeviceCache()) { - super.free(point); - return; - } - - ensureDeviceCacheHolder(deviceId, shape); - - val cache = deviceCache.get(deviceId).get(shape); - - if (point.getDeviceId() != deviceId) - throw new RuntimeException("deviceId changed!"); - - // memory chunks < threshold will be cached no matter what - if (reqMemory <= FORCED_CACHE_THRESHOLD) { - cache.put(new CudaPointer(point.getDevicePointer().address())); - MemoryTracker.getInstance().incrementCachedAmount(deviceId, reqMemory); - MemoryTracker.getInstance().decrementAllocatedAmount(deviceId, reqMemory); - return; - } else { - - cache.put(new CudaPointer(point.getDevicePointer().address())); - - MemoryTracker.getInstance().incrementCachedAmount(deviceId, reqMemory); - MemoryTracker.getInstance().decrementAllocatedAmount(deviceId, reqMemory); - return; - } - } - super.free(point); - } - - /** - * This method checks, if storage contains holder for specified shape - * - * @param deviceId - * @param shape - */ - protected void ensureDeviceCacheHolder(Integer deviceId, AllocationShape shape) { - if (!deviceCache.containsKey(deviceId)) { - try { - synchronized (this) { - if (!deviceCache.containsKey(deviceId)) { - deviceCache.put(deviceId, new ConcurrentHashMap()); - } - } - } catch (Exception e) { - throw new RuntimeException(e); - } - } - - if (!deviceCache.get(deviceId).containsKey(shape)) { - try { - singleLock.acquire(); - - if (!deviceCache.get(deviceId).containsKey(shape)) { - deviceCache.get(deviceId).put(shape, new CacheHolder(shape, deviceCachedAmount.get(deviceId))); - } - } catch (Exception e) { - - } finally { - singleLock.release(); - } - } - } - - @Override - protected synchronized void purgeCache(int deviceId) { - for (AllocationShape shape : deviceCache.get(deviceId).keySet()) { - Pointer ptr = null; - while ((ptr = deviceCache.get(deviceId).get(shape).poll()) != null) { - freeDevice(ptr, deviceId); - MemoryTracker.getInstance().decrementCachedAmount(deviceId, shape.getNumberOfBytes()); - } - } - - deviceCachedAmount.get(deviceId).set(0); - } - - @Override - public synchronized void purgeCache() { - for (Integer device : deviceCache.keySet()) { - purgeCache(device); - } - super.purgeCache(); - } -} diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/JCublasBackend.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/JCublasBackend.java index 34970dc19..2f3ad94df 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/JCublasBackend.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/JCublasBackend.java @@ -20,10 +20,12 @@ import lombok.extern.slf4j.Slf4j; import org.bytedeco.javacpp.Loader; import org.nd4j.config.ND4JSystemProperties; import org.nd4j.linalg.api.environment.Nd4jEnvironment; +import org.nd4j.linalg.factory.Environment; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.io.ClassPathResource; import org.nd4j.linalg.io.Resource; +import org.nd4j.nativeblas.CudaEnvironment; import org.nd4j.nativeblas.Nd4jCuda; import java.util.List; @@ -86,6 +88,11 @@ public class JCublasBackend extends Nd4jBackend { return JCublasNDArray.class; } + @Override + public Environment getEnvironment() { + return CudaEnvironment.getInstance(); + } + @Override public void logBackendInit() { String logInitProperty = System.getProperty(ND4JSystemProperties.LOG_INITIALIZATION, "true"); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/JCublasNDArray.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/JCublasNDArray.java index 79d87a01e..df44adb17 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/JCublasNDArray.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/JCublasNDArray.java @@ -17,34 +17,39 @@ package org.nd4j.linalg.jcublas; +import com.google.flatbuffers.FlatBufferBuilder; import lombok.extern.slf4j.Slf4j; import lombok.val; +import org.bytedeco.javacpp.BytePointer; +import org.nd4j.base.Preconditions; +import org.nd4j.graph.FlatArray; import org.nd4j.jita.allocator.enums.AllocationStatus; import org.nd4j.jita.allocator.enums.CudaConstants; import org.nd4j.jita.allocator.impl.AllocationPoint; import org.nd4j.jita.allocator.impl.AtomicAllocator; import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.api.buffer.DataTypeEx; -import org.nd4j.linalg.api.buffer.FloatBuffer; import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.ndarray.BaseNDArray; import org.nd4j.linalg.api.ndarray.BaseNDArrayProxy; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.JvmShapeInfo; -import org.nd4j.linalg.api.ops.executioner.GridExecutioner; import org.nd4j.linalg.api.ops.performance.PerformanceTracker; +import org.nd4j.linalg.api.ops.util.PrintVariable; import org.nd4j.linalg.api.shape.LongShapeDescriptor; import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.jcublas.buffer.CudaLongDataBuffer; +import org.nd4j.linalg.jcublas.buffer.CudaUtf8Buffer; import org.nd4j.linalg.jcublas.context.CudaContext; import org.nd4j.linalg.memory.MemcpyDirection; import org.nd4j.linalg.workspace.WorkspaceUtils; import org.nd4j.nativeblas.NativeOpsHolder; +import java.io.ByteArrayOutputStream; +import java.io.DataOutputStream; +import java.io.IOException; import java.util.List; -import java.util.concurrent.atomic.AtomicLong; /** * @@ -387,10 +392,6 @@ public class JCublasNDArray extends BaseNDArray { super(data, order); } - public JCublasNDArray(FloatBuffer floatBuffer, char order) { - super(floatBuffer, order); - } - public JCublasNDArray(DataBuffer buffer, int[] shape, int[] strides) { super(buffer, shape, strides); } @@ -574,26 +575,16 @@ public class JCublasNDArray extends BaseNDArray { MemcpyDirection direction = MemcpyDirection.HOST_TO_HOST; val prof = PerformanceTracker.getInstance().helperStartTransaction(); - if (dstPoint.getAllocationStatus() == AllocationStatus.DEVICE && srcPoint.getAllocationStatus() == AllocationStatus.DEVICE) { - // d2d copy + if (srcPoint.isActualOnDeviceSide()) { route = 1; NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(dstPoint.getDevicePointer(), srcPoint.getDevicePointer(), this.data.length() * this.data.getElementSize(), CudaConstants.cudaMemcpyDeviceToDevice, blocking ? context.getOldStream() : context.getSpecialStream()); dstPoint.tickDeviceWrite(); direction = MemcpyDirection.DEVICE_TO_DEVICE; - } else if (dstPoint.getAllocationStatus() == AllocationStatus.HOST && srcPoint.getAllocationStatus() == AllocationStatus.DEVICE) { - route = 2; - NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(dstPoint.getHostPointer(), srcPoint.getDevicePointer(), this.data.length() * this.data.getElementSize(), CudaConstants.cudaMemcpyDeviceToHost, blocking ? context.getOldStream() : context.getSpecialStream()); - dstPoint.tickHostWrite(); - direction = MemcpyDirection.DEVICE_TO_HOST; - } else if (dstPoint.getAllocationStatus() == AllocationStatus.DEVICE && srcPoint.getAllocationStatus() == AllocationStatus.HOST) { + } else { route = 3; NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(dstPoint.getDevicePointer(), srcPoint.getHostPointer(), this.data.length() * this.data.getElementSize(), CudaConstants.cudaMemcpyHostToDevice, blocking ? context.getOldStream() : context.getSpecialStream()); dstPoint.tickDeviceWrite(); direction = MemcpyDirection.HOST_TO_DEVICE; - } else { - route = 4; - NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(dstPoint.getHostPointer(), srcPoint.getHostPointer(), this.data.length() * this.data.getElementSize(), CudaConstants.cudaMemcpyHostToHost, blocking ? context.getOldStream() : context.getSpecialStream()); - dstPoint.tickHostWrite(); } @@ -650,30 +641,16 @@ public class JCublasNDArray extends BaseNDArray { Nd4j.getMemoryManager().setCurrentWorkspace(target); -// log.info("Leveraging..."); - INDArray copy = null; if (!this.isView()) { - //if (1 < 0) { Nd4j.getExecutioner().commit(); - DataBuffer buffer = Nd4j.createBuffer(this.length(), false); + val buffer = Nd4j.createBuffer(this.length(), false); - AllocationPoint pointDst = AtomicAllocator.getInstance().getAllocationPoint(buffer); - AllocationPoint pointSrc = AtomicAllocator.getInstance().getAllocationPoint(this.data); + val pointDst = AtomicAllocator.getInstance().getAllocationPoint(buffer); + val pointSrc = AtomicAllocator.getInstance().getAllocationPoint(this.data); - CudaContext context = AtomicAllocator.getInstance().getFlowController().prepareAction(pointDst, pointSrc); -/* - if (NativeOpsHolder.getInstance().getDeviceNativeOps().memsetAsync(pointDst.getDevicePointer(), 0, 1, 0, context.getOldStream()) == 0) - throw new ND4JIllegalStateException("memsetAsync 1 failed"); - - context.syncOldStream(); - - if (NativeOpsHolder.getInstance().getDeviceNativeOps().memsetAsync(pointSrc.getDevicePointer(), 0, 1, 0, context.getOldStream()) == 0) - throw new ND4JIllegalStateException("memsetAsync 2 failed"); - - context.syncOldStream(); -*/ + val context = AtomicAllocator.getInstance().getFlowController().prepareAction(pointDst, pointSrc); MemcpyDirection direction = MemcpyDirection.DEVICE_TO_DEVICE; val perfD = PerformanceTracker.getInstance().helperStartTransaction(); @@ -690,12 +667,11 @@ public class JCublasNDArray extends BaseNDArray { context.syncOldStream(); - PerformanceTracker.getInstance().helperRegisterTransaction(pointDst.getDeviceId(), perfD, pointSrc.getNumberOfBytes(), MemcpyDirection.HOST_TO_DEVICE); + PerformanceTracker.getInstance().helperRegisterTransaction(pointDst.getDeviceId(), perfD, pointSrc.getNumberOfBytes(), direction); copy = Nd4j.createArrayFromShapeBuffer(buffer, this.shapeInfoDataBuffer()); // tag buffer as valid on device side - pointDst.tickHostRead(); pointDst.tickDeviceWrite(); AtomicAllocator.getInstance().getFlowController().registerAction(context, pointDst, pointSrc); @@ -728,6 +704,7 @@ public class JCublasNDArray extends BaseNDArray { val pointDst = AtomicAllocator.getInstance().getAllocationPoint(buffer); val pointSrc = AtomicAllocator.getInstance().getAllocationPoint(this.data); + val context = AtomicAllocator.getInstance().getFlowController().prepareAction(pointDst, pointSrc); MemcpyDirection direction = MemcpyDirection.DEVICE_TO_DEVICE; @@ -764,6 +741,38 @@ public class JCublasNDArray extends BaseNDArray { return copy; } + protected int stringBuffer(FlatBufferBuilder builder, DataBuffer buffer) { + Preconditions.checkArgument(buffer.dataType() == DataType.UTF8, "This method can be called on UTF8 buffers only"); + try { + ByteArrayOutputStream bos = new ByteArrayOutputStream(); + DataOutputStream dos = new DataOutputStream(bos); + + val numWords = this.length(); + val ub = (CudaUtf8Buffer) buffer; + // writing length first + val t = length(); + val ptr = (BytePointer) ub.pointer(); + + // now write all strings as bytes + for (int i = 0; i < ub.length(); i++) { + dos.writeByte(ptr.get(i)); + } + + val bytes = bos.toByteArray(); + return FlatArray.createBufferVector(builder, bytes); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Override + public String getString(long index) { + if (!isS()) + throw new UnsupportedOperationException("This method is usable only on String dataType, but got [" + this.dataType() + "]"); + + return ((CudaUtf8Buffer) data).getString(index); + } + /* @Override public INDArray convertToHalfs() { diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/JCublasNDArrayFactory.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/JCublasNDArrayFactory.java index 0bcb6e562..c529c4f7c 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/JCublasNDArrayFactory.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/JCublasNDArrayFactory.java @@ -18,11 +18,9 @@ package org.nd4j.linalg.jcublas; import lombok.extern.slf4j.Slf4j; import lombok.val; -import lombok.var; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataTypeEx; -import org.nd4j.linalg.api.buffer.Utf8Buffer; import org.nd4j.linalg.api.memory.enums.MemoryKind; import org.nd4j.linalg.api.ops.custom.Flatten; import org.nd4j.linalg.api.ops.impl.shape.Concat; @@ -34,12 +32,10 @@ import org.nd4j.linalg.jcublas.buffer.*; import org.nd4j.linalg.memory.MemcpyDirection; import org.nd4j.linalg.primitives.Pair; import org.bytedeco.javacpp.*; -import org.bytedeco.javacpp.indexer.*; import org.nd4j.jita.allocator.enums.CudaConstants; import org.nd4j.jita.allocator.impl.AllocationPoint; import org.nd4j.jita.allocator.impl.AtomicAllocator; import org.nd4j.jita.allocator.pointers.CudaPointer; -import org.nd4j.jita.allocator.utils.AllocationUtils; import org.nd4j.jita.conf.CudaEnvironment; import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.memory.MemoryWorkspace; @@ -51,19 +47,12 @@ import org.nd4j.linalg.compression.CompressedDataBuffer; import org.nd4j.linalg.compression.CompressionDescriptor; import org.nd4j.linalg.compression.CompressionType; import org.nd4j.linalg.exception.ND4JIllegalStateException; -import org.nd4j.linalg.factory.BaseNDArrayFactory; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.jcublas.blas.*; import org.nd4j.linalg.jcublas.context.CudaContext; import org.nd4j.linalg.util.ArrayUtil; import org.nd4j.nativeblas.*; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import java.io.File; -import java.nio.ByteBuffer; -import java.nio.ByteOrder; -import java.nio.charset.Charset; import java.util.*; /** @@ -216,7 +205,7 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory { @Override public INDArray create(Collection strings, long[] shape, char order) { val pairShape = Nd4j.getShapeInfoProvider().createShapeInformation(shape, order, DataType.UTF8); - val buffer = new Utf8Buffer(strings); + val buffer = new CudaUtf8Buffer(strings); val list = new ArrayList(strings); return Nd4j.createArrayFromShapeBuffer(buffer, pairShape); } @@ -360,8 +349,7 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory { @Override public INDArray concat(int dimension, INDArray... toConcat) { - if (Nd4j.getExecutioner() instanceof GridExecutioner) - ((GridExecutioner) Nd4j.getExecutioner()).flushQueue(); + Nd4j.getExecutioner().push(); return Nd4j.exec(new Concat(dimension, toConcat))[0]; } @@ -517,9 +505,9 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory { AtomicAllocator allocator = AtomicAllocator.getInstance(); CudaContext context = allocator.getFlowController().prepareAction(ret, source); - Pointer x = AtomicAllocator.getInstance().getPointer(source, context); + val x = ((BaseCudaDataBuffer) source.data()).getOpaqueDataBuffer(); + val z = ((BaseCudaDataBuffer) ret.data()).getOpaqueDataBuffer(); Pointer xShape = AtomicAllocator.getInstance().getPointer(source.shapeInfoDataBuffer(), context); - Pointer z = AtomicAllocator.getInstance().getPointer(ret, context); Pointer zShape = AtomicAllocator.getInstance().getPointer(ret.shapeInfoDataBuffer(), context); PointerPointer extras = new PointerPointer(AddressRetriever.retrieveHostPointer(ret.shapeInfoDataBuffer()), @@ -545,14 +533,8 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory { nativeOps.pullRows(extras, - null, - (LongPointer) source.shapeInfoDataBuffer().addressPointer(), - x, - (LongPointer) xShape, - null, - (LongPointer) ret.shapeInfoDataBuffer().addressPointer(), - z, - (LongPointer) zShape, + x, (LongPointer) source.shapeInfoDataBuffer().addressPointer(), (LongPointer) xShape, + z, (LongPointer) ret.shapeInfoDataBuffer().addressPointer(), (LongPointer) zShape, indexes.length, (LongPointer) pIndex, (LongPointer) tadShapeInfo, @@ -601,7 +583,7 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory { throw new ND4JIllegalStateException("All arrays should have equal length for averaging"); AllocationPoint point = allocator.getAllocationPoint(arrays[i]); - xPointers[i] = point.getPointers().getDevicePointer().address(); + xPointers[i] = point.getDevicePointer().address(); point.tickDeviceWrite(); } @@ -710,7 +692,7 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory { throw new ND4JIllegalStateException("All arrays should have equal length for averaging"); AllocationPoint point = allocator.getAllocationPoint(arrays[i]); - xPointers[i] = point.getPointers().getDevicePointer().address(); + xPointers[i] = point.getDevicePointer().address(); point.tickDeviceWrite(); } @@ -1324,11 +1306,11 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory { PointerPointer extraz = new PointerPointer(null, // not used context.getOldStream(), AtomicAllocator.getInstance().getDeviceIdPointer()); + val x = ((BaseCudaDataBuffer) tensor.data()).getOpaqueDataBuffer(); + + nativeOps.tear(extraz, - null, - (LongPointer) tensor.shapeInfoDataBuffer().addressPointer(), - AtomicAllocator.getInstance().getPointer(tensor, context), - (LongPointer) AtomicAllocator.getInstance().getPointer(tensor.shapeInfoDataBuffer(), context), + x, (LongPointer) tensor.shapeInfoDataBuffer().addressPointer(), (LongPointer) AtomicAllocator.getInstance().getPointer(tensor.shapeInfoDataBuffer(), context), new PointerPointer(AtomicAllocator.getInstance().getPointer(tempX, context)), (LongPointer) AtomicAllocator.getInstance().getPointer(result[0].shapeInfoDataBuffer(), context), (LongPointer) AtomicAllocator.getInstance().getPointer(tadBuffers.getFirst(), context), diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBuffer.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBuffer.java index 5e0583d56..6c82cf1de 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBuffer.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBuffer.java @@ -21,6 +21,7 @@ import lombok.NonNull; import lombok.val; import org.bytedeco.javacpp.*; import org.bytedeco.javacpp.indexer.*; +import org.nd4j.base.Preconditions; import org.nd4j.jita.allocator.enums.AllocationStatus; import org.nd4j.jita.allocator.enums.CudaConstants; import org.nd4j.jita.allocator.impl.AllocationPoint; @@ -38,6 +39,8 @@ import org.nd4j.linalg.api.memory.Deallocatable; import org.nd4j.linalg.api.memory.Deallocator; import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.memory.enums.MemoryKind; +import org.nd4j.linalg.api.memory.enums.MirroringPolicy; +import org.nd4j.linalg.api.memory.pointers.PagedPointer; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.performance.PerformanceTracker; import org.nd4j.linalg.exception.ND4JIllegalStateException; @@ -47,7 +50,9 @@ import org.nd4j.linalg.memory.MemcpyDirection; import org.nd4j.linalg.memory.abstracts.DummyWorkspace; import org.nd4j.linalg.util.ArrayUtil; import org.nd4j.linalg.util.LongUtils; +import org.nd4j.nativeblas.NativeOps; import org.nd4j.nativeblas.NativeOpsHolder; +import org.nd4j.nativeblas.OpaqueDataBuffer; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -74,6 +79,7 @@ import java.util.Collection; * @author raver119@gmail.com */ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCudaBuffer, Deallocatable { + protected OpaqueDataBuffer ptrDataBuffer; @Getter protected transient volatile AllocationPoint allocationPoint; @@ -88,10 +94,12 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda } + public OpaqueDataBuffer getOpaqueDataBuffer() { + return ptrDataBuffer; + } + + public BaseCudaDataBuffer(@NonNull Pointer pointer, @NonNull Pointer specialPointer, @NonNull Indexer indexer, long length) { - this.allocationPoint = AtomicAllocator.getInstance().pickExternalBuffer(this); - this.allocationPoint.setPointers(new PointersPair(specialPointer, pointer)); - this.trackingPoint = allocationPoint.getObjectId(); this.allocationMode = AllocationMode.MIXED_DATA_TYPES; this.indexer = indexer; @@ -102,6 +110,12 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda this.length = length; initTypeAndSize(); + + ptrDataBuffer = OpaqueDataBuffer.allocateDataBuffer(0, this.type, false); + this.allocationPoint = new AllocationPoint(ptrDataBuffer, this.type.width() * length); + this.allocationPoint.setPointers(pointer, specialPointer, length); + + Nd4j.getDeallocatorService().pickObject(this); } /** @@ -114,10 +128,11 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda public BaseCudaDataBuffer(Pointer pointer, Indexer indexer, long length) { super(pointer, indexer, length); - //cuda specific bits - this.allocationPoint = AtomicAllocator.getInstance().allocateMemory(this, new AllocationShape(length, elementSize, dataType()), false); - this.trackingPoint = allocationPoint.getObjectId(); + // allocating interop buffer + this.ptrDataBuffer = OpaqueDataBuffer.allocateDataBuffer(length, type, false); + //cuda specific bits + this.allocationPoint = new AllocationPoint(ptrDataBuffer, length * elementSize); Nd4j.getDeallocatorService().pickObject(this); // now we're @@ -222,71 +237,153 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda } public void lazyAllocateHostPointer() { - if (allocationPoint.getPointers().getHostPointer() == null) + if (length() == 0) + return; + + // java side might be unaware of native-side buffer allocation + if (this.indexer == null || this.pointer == null || this.pointer.address() == 0) { initHostPointerAndIndexer(); + } else if (allocationPoint.getHostPointer() != null && allocationPoint.getHostPointer().address() != this.pointer.address()) { + initHostPointerAndIndexer(); + } + } + + protected BaseCudaDataBuffer(ByteBuffer buffer, DataType dtype, long length, long offset) { + this(length, Nd4j.sizeOfDataType(dtype)); + + Pointer temp = null; + + switch (dataType()){ + case DOUBLE: + temp = new DoublePointer(buffer.asDoubleBuffer()); + break; + case FLOAT: + temp = new FloatPointer(buffer.asFloatBuffer()); + break; + case HALF: + temp = new ShortPointer(buffer.asShortBuffer()); + break; + case LONG: + temp = new LongPointer(buffer.asLongBuffer()); + break; + case INT: + temp = new IntPointer(buffer.asIntBuffer()); + break; + case SHORT: + temp = new ShortPointer(buffer.asShortBuffer()); + break; + case UBYTE: //Fall through + case BYTE: + temp = new BytePointer(buffer); + break; + case BOOL: + temp = new BooleanPointer(length()); + break; + case UTF8: + temp = new BytePointer(length()); + break; + case BFLOAT16: + temp = new ShortPointer(length()); + break; + case UINT16: + temp = new ShortPointer(length()); + break; + case UINT32: + temp = new IntPointer(length()); + break; + case UINT64: + temp = new LongPointer(length()); + break; + } + + // copy data to device + val stream = AtomicAllocator.getInstance().getDeviceContext().getSpecialStream(); + val ptr = ptrDataBuffer.specialBuffer(); + + if (offset > 0) + temp = new PagedPointer(temp.address() + offset * getElementSize()); + + NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(ptr, temp, length * Nd4j.sizeOfDataType(dtype), CudaConstants.cudaMemcpyHostToDevice, stream); + stream.synchronize(); + + // mark device buffer as updated + allocationPoint.tickDeviceWrite(); } protected void initHostPointerAndIndexer() { - if (allocationPoint.getPointers().getHostPointer() == null) { + if (length() == 0) + return; + + if (allocationPoint.getHostPointer() == null) { val location = allocationPoint.getAllocationStatus(); if (parentWorkspace == null) { - val ptr = AtomicAllocator.getInstance().getMemoryHandler().alloc(AllocationStatus.HOST, this.allocationPoint, this.allocationPoint.getShape(), false); - this.allocationPoint.getPointers().setHostPointer(ptr.getHostPointer()); + //log.info("dbAllocate step"); + // let cpp allocate primary buffer + NativeOpsHolder.getInstance().getDeviceNativeOps().dbAllocatePrimaryBuffer(ptrDataBuffer); } else { + //log.info("ws alloc step"); val ptr = parentWorkspace.alloc(this.length * this.elementSize, MemoryKind.HOST, this.dataType(), false); - this.allocationPoint.getPointers().setHostPointer(ptr); + ptrDataBuffer.setPrimaryBuffer(ptr, this.length); } this.allocationPoint.setAllocationStatus(location); this.allocationPoint.tickDeviceWrite(); } + val hostPointer = allocationPoint.getHostPointer(); + + assert hostPointer != null; + switch (dataType()) { case DOUBLE: - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length, 0).asDoublePointer(); + this.pointer = new CudaPointer(hostPointer, length, 0).asDoublePointer(); indexer = DoubleIndexer.create((DoublePointer) pointer); break; case FLOAT: - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length, 0).asFloatPointer(); + this.pointer = new CudaPointer(hostPointer, length, 0).asFloatPointer(); indexer = FloatIndexer.create((FloatPointer) pointer); break; case UINT32: case INT: - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length, 0).asIntPointer(); + this.pointer = new CudaPointer(hostPointer, length, 0).asIntPointer(); indexer = IntIndexer.create((IntPointer) pointer); break; case BFLOAT16: - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length, 0).asShortPointer(); + this.pointer = new CudaPointer(hostPointer, length, 0).asShortPointer(); indexer = Bfloat16Indexer.create((ShortPointer) pointer); break; case HALF: - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length, 0).asShortPointer(); + this.pointer = new CudaPointer(hostPointer, length, 0).asShortPointer(); indexer = HalfIndexer.create((ShortPointer) pointer); break; case UINT64: case LONG: - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length, 0).asLongPointer(); + this.pointer = new CudaPointer(hostPointer, length, 0).asLongPointer(); indexer = LongIndexer.create((LongPointer) pointer); break; case UINT16: - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length, 0).asShortPointer(); + this.pointer = new CudaPointer(hostPointer, length, 0).asShortPointer(); indexer = UShortIndexer.create((ShortPointer) pointer); break; case SHORT: - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length, 0).asShortPointer(); + this.pointer = new CudaPointer(hostPointer, length, 0).asShortPointer(); indexer = ShortIndexer.create((ShortPointer) pointer); break; case UBYTE: - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length, 0).asBytePointer(); + this.pointer = new CudaPointer(hostPointer, length, 0).asBytePointer(); indexer = UByteIndexer.create((BytePointer) pointer); break; case BYTE: - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length, 0).asBytePointer(); + this.pointer = new CudaPointer(hostPointer, length, 0).asBytePointer(); indexer = ByteIndexer.create((BytePointer) pointer); break; case BOOL: - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length, 0).asBooleanPointer(); + this.pointer = new CudaPointer(hostPointer, length, 0).asBooleanPointer(); indexer = BooleanIndexer.create((BooleanPointer) pointer); break; + case UTF8: + this.pointer = new CudaPointer(hostPointer, length, 0).asBytePointer(); + indexer = ByteIndexer.create((BytePointer) pointer); + break; default: throw new UnsupportedOperationException(); } @@ -294,21 +391,25 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda protected void initPointers(long length, int elementSize, boolean initialize) { this.allocationMode = AllocationMode.MIXED_DATA_TYPES; - this.allocationPoint = AtomicAllocator.getInstance().allocateMemory(this, new AllocationShape(length, elementSize, dataType()), initialize); this.length = length; - //allocationPoint.attachBuffer(this); this.elementSize = (byte) elementSize; - this.trackingPoint = allocationPoint.getObjectId(); + this.offset = 0; this.originalOffset = 0; + // we allocate native DataBuffer AND it will contain our device pointer + ptrDataBuffer = OpaqueDataBuffer.allocateDataBuffer(length, type, false); + this.allocationPoint = new AllocationPoint(ptrDataBuffer, length * type.width()); + + if (initialize) { + val ctx = AtomicAllocator.getInstance().getDeviceContext(); + val devicePtr = allocationPoint.getDevicePointer(); + NativeOpsHolder.getInstance().getDeviceNativeOps().memsetAsync(devicePtr, 0, length * elementSize, 0, ctx.getSpecialStream()); + ctx.getSpecialStream().synchronize(); + } + + // let deallocator pick up this object Nd4j.getDeallocatorService().pickObject(this); - - // if only host - if (allocationPoint.getPointers().getHostPointer() == null) - return; - - initHostPointerAndIndexer(); } public BaseCudaDataBuffer(long length, int elementSize, boolean initialize) { @@ -323,72 +424,45 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda this.attached = true; this.parentWorkspace = workspace; - this.allocationPoint = AtomicAllocator.getInstance().allocateMemory(this, new AllocationShape(length, this.elementSize, dataType()), initialize); this.length = length; - this.trackingPoint = allocationPoint.getObjectId(); this.offset = 0; this.originalOffset = 0; + // allocating empty databuffer + ptrDataBuffer = OpaqueDataBuffer.allocateDataBuffer(0, type, false); + + if (workspace.getWorkspaceConfiguration().getPolicyMirroring() == MirroringPolicy.FULL) { + val devicePtr = workspace.alloc(length * elementSize, MemoryKind.DEVICE, type, initialize); + + // allocate from workspace, and pass it to native DataBuffer + ptrDataBuffer.setSpecialBuffer(devicePtr, this.length); + + if (initialize) { + val ctx = AtomicAllocator.getInstance().getDeviceContext(); + NativeOpsHolder.getInstance().getDeviceNativeOps().memsetAsync(devicePtr, 0, length * elementSize, 0, ctx.getSpecialStream()); + ctx.getSpecialStream().synchronize(); + } + } else { + // we can register this pointer as device, because it's pinned memory + val devicePtr = workspace.alloc(length * elementSize, MemoryKind.HOST, type, initialize); + ptrDataBuffer.setSpecialBuffer(devicePtr, this.length); + + if (initialize) { + val ctx = AtomicAllocator.getInstance().getDeviceContext(); + NativeOpsHolder.getInstance().getDeviceNativeOps().memsetAsync(devicePtr, 0, length * elementSize, 0, ctx.getSpecialStream()); + ctx.getSpecialStream().synchronize(); + } + } + + this.allocationPoint = new AllocationPoint(ptrDataBuffer, elementSize * length); + + // registering for deallocation Nd4j.getDeallocatorService().pickObject(this); workspaceGenerationId = workspace.getGenerationId(); this.attached = true; this.parentWorkspace = workspace; - - if (allocationPoint.getHostPointer() == null) - return; - - switch (dataType()) { - case DOUBLE: - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length, 0).asDoublePointer(); - indexer = DoubleIndexer.create((DoublePointer) pointer); - break; - case FLOAT: - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length, 0).asFloatPointer(); - indexer = FloatIndexer.create((FloatPointer) pointer); - break; - case UINT32: - case INT: - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length, 0).asIntPointer(); - indexer = IntIndexer.create((IntPointer) pointer); - break; - case BFLOAT16: - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length, 0).asShortPointer(); - indexer = Bfloat16Indexer.create((ShortPointer) pointer); - break; - case HALF: - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length, 0).asShortPointer(); - indexer = HalfIndexer.create((ShortPointer) pointer); - break; - case UINT64: - case LONG: - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length, 0).asLongPointer(); - indexer = LongIndexer.create((LongPointer) pointer); - break; - case BOOL: - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length, 0).asBooleanPointer(); - indexer = BooleanIndexer.create((BooleanPointer) pointer); - break; - case UINT16: - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length, 0).asShortPointer(); - indexer = UShortIndexer.create((ShortPointer) pointer); - break; - case SHORT: - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length, 0).asShortPointer(); - indexer = ShortIndexer.create((ShortPointer) pointer); - break; - case BYTE: - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length, 0).asBytePointer(); - indexer = ByteIndexer.create((BytePointer) pointer); - break; - case UBYTE: - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length, 0).asBytePointer(); - indexer = UByteIndexer.create((BytePointer) pointer); - break; - default: - throw new UnsupportedOperationException("Unknown data type: " + dataType()); - } } @Override @@ -427,57 +501,71 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda this.length = length; this.offset = offset; this.originalOffset = offset; - this.trackingPoint = underlyingBuffer.getTrackingPoint(); this.elementSize = (byte) underlyingBuffer.getElementSize(); - this.allocationPoint = ((BaseCudaDataBuffer) underlyingBuffer).allocationPoint; // in case of view creation, we initialize underlying buffer regardless of anything - ((BaseCudaDataBuffer) underlyingBuffer).lazyAllocateHostPointer();; + ((BaseCudaDataBuffer) underlyingBuffer).lazyAllocateHostPointer(); + + // we're creating view of the native DataBuffer + ptrDataBuffer = ((BaseCudaDataBuffer) underlyingBuffer).ptrDataBuffer.createView(length * underlyingBuffer.getElementSize(), offset * underlyingBuffer.getElementSize()); + this.allocationPoint = new AllocationPoint(ptrDataBuffer, length); + val hostPointer = allocationPoint.getHostPointer(); + + Nd4j.getDeallocatorService().pickObject(this); switch (underlyingBuffer.dataType()) { case DOUBLE: - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), originalBuffer.length()).asDoublePointer(); + this.pointer = new CudaPointer(hostPointer, originalBuffer.length()).asDoublePointer(); indexer = DoubleIndexer.create((DoublePointer) pointer); break; case FLOAT: - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), originalBuffer.length()).asFloatPointer(); + this.pointer = new CudaPointer(hostPointer, originalBuffer.length()).asFloatPointer(); indexer = FloatIndexer.create((FloatPointer) pointer); break; case UINT32: case INT: - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), originalBuffer.length()).asIntPointer(); + this.pointer = new CudaPointer(hostPointer, originalBuffer.length()).asIntPointer(); indexer = IntIndexer.create((IntPointer) pointer); break; case BFLOAT16: - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), originalBuffer.length()).asShortPointer(); + this.pointer = new CudaPointer(hostPointer, originalBuffer.length()).asShortPointer(); indexer = Bfloat16Indexer.create((ShortPointer) pointer); break; case HALF: - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), originalBuffer.length()).asShortPointer(); + this.pointer = new CudaPointer(hostPointer, originalBuffer.length()).asShortPointer(); indexer = HalfIndexer.create((ShortPointer) pointer); break; case UINT64: case LONG: - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), originalBuffer.length()).asLongPointer(); + this.pointer = new CudaPointer(hostPointer, originalBuffer.length()).asLongPointer(); indexer = LongIndexer.create((LongPointer) pointer); break; case UINT16: + this.pointer = new CudaPointer(hostPointer, originalBuffer.length()).asShortPointer(); + indexer = UShortIndexer.create((ShortPointer) pointer); + break; case SHORT: - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), originalBuffer.length()).asShortPointer(); + this.pointer = new CudaPointer(hostPointer, originalBuffer.length()).asShortPointer(); indexer = ShortIndexer.create((ShortPointer) pointer); break; case BOOL: - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), originalBuffer.length()).asBooleanPointer(); + this.pointer = new CudaPointer(hostPointer, originalBuffer.length()).asBooleanPointer(); indexer = BooleanIndexer.create((BooleanPointer) pointer); break; case BYTE: - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), originalBuffer.length()).asBytePointer(); + this.pointer = new CudaPointer(hostPointer, originalBuffer.length()).asBytePointer(); indexer = ByteIndexer.create((BytePointer) pointer); break; case UBYTE: - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), originalBuffer.length()).asBytePointer(); + this.pointer = new CudaPointer(hostPointer, originalBuffer.length()).asBytePointer(); indexer = UByteIndexer.create((BytePointer) pointer); break; + case UTF8: + Preconditions.checkArgument(offset == 0, "String array can't be a view"); + + this.pointer = new CudaPointer(hostPointer, originalBuffer.length()).asBytePointer(); + indexer = ByteIndexer.create((BytePointer) pointer); + break; default: throw new UnsupportedOperationException(); } @@ -519,23 +607,6 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda set(data, data.length, 0, 0); } - public BaseCudaDataBuffer(byte[] data, long length, DataType type) { - this(ByteBuffer.wrap(data), length, type); - } - - public BaseCudaDataBuffer(ByteBuffer buffer, long length, DataType type) { - //super(buffer,length); - this(buffer, length, 0, type); - } - - public BaseCudaDataBuffer(ByteBuffer buffer, long length, long offset, DataType type) { - //super(buffer, length, offset); - this(length, Nd4j.sizeOfDataType(type), offset); - - Pointer srcPtr = new CudaPointer(new Pointer(buffer.order(ByteOrder.nativeOrder()))); - - allocator.memcpyAsync(this, srcPtr, length * elementSize, offset * elementSize); - } /** * This method always returns host pointer @@ -547,12 +618,12 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda if (released) throw new IllegalStateException("You can't use DataBuffer once it was released"); - return allocationPoint.getPointers().getHostPointer().address(); + return allocationPoint.getHostPointer().address(); } @Override public long platformAddress() { - return allocationPoint.getPointers().getDevicePointer().address(); + return allocationPoint.getDevicePointer().address(); } @Override @@ -582,7 +653,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda switch (dataType()) { case BOOL: { val pointer = new BytePointer(ArrayUtil.toBytes(data)); - val srcPtr = new CudaPointer(pointer.address() + (dstOffset * elementSize)); + val srcPtr = new CudaPointer(pointer.address() + (srcOffset * elementSize)); allocator.memcpyAsync(this, srcPtr, length * elementSize, dstOffset * elementSize); @@ -592,7 +663,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda break; case BYTE: { val pointer = new BytePointer(ArrayUtil.toBytes(data)); - val srcPtr = new CudaPointer(pointer.address() + (dstOffset * elementSize)); + val srcPtr = new CudaPointer(pointer.address() + (srcOffset * elementSize)); allocator.memcpyAsync(this, srcPtr, length * elementSize, dstOffset * elementSize); @@ -608,7 +679,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda break; case SHORT: { val pointer = new ShortPointer(ArrayUtil.toShorts(data)); - val srcPtr = new CudaPointer(pointer.address() + (dstOffset * elementSize)); + val srcPtr = new CudaPointer(pointer.address() + (srcOffset * elementSize)); allocator.memcpyAsync(this, srcPtr, length * elementSize, dstOffset * elementSize); @@ -618,7 +689,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda break; case INT: { val pointer = new IntPointer(data); - val srcPtr = new CudaPointer(pointer.address() + (dstOffset * elementSize)); + val srcPtr = new CudaPointer(pointer.address() + (srcOffset * elementSize)); allocator.memcpyAsync(this, srcPtr, length * elementSize, dstOffset * elementSize); @@ -628,7 +699,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda break; case LONG: { val pointer = new LongPointer(LongUtils.toLongs(data)); - val srcPtr = new CudaPointer(pointer.address() + (dstOffset * elementSize)); + val srcPtr = new CudaPointer(pointer.address() + (srcOffset * elementSize)); allocator.memcpyAsync(this, srcPtr, length * elementSize, dstOffset * elementSize); @@ -638,7 +709,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda break; case HALF: { val pointer = new ShortPointer(ArrayUtil.toHalfs(data)); - val srcPtr = new CudaPointer(pointer.address() + (dstOffset * elementSize)); + val srcPtr = new CudaPointer(pointer.address() + (srcOffset * elementSize)); allocator.memcpyAsync(this, srcPtr, length * elementSize, dstOffset * elementSize); @@ -648,7 +719,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda break; case FLOAT: { val pointer = new FloatPointer(ArrayUtil.toFloats(data)); - val srcPtr = new CudaPointer(pointer.address() + (dstOffset * elementSize)); + val srcPtr = new CudaPointer(pointer.address() + (srcOffset * elementSize)); allocator.memcpyAsync(this, srcPtr, length * elementSize, dstOffset * elementSize); @@ -658,7 +729,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda break; case DOUBLE: { val pointer = new DoublePointer(ArrayUtil.toDouble(data)); - val srcPtr = new CudaPointer(pointer.address() + (dstOffset * elementSize)); + val srcPtr = new CudaPointer(pointer.address() + (srcOffset * elementSize)); allocator.memcpyAsync(this, srcPtr, length * elementSize, dstOffset * elementSize); @@ -677,7 +748,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda switch (dataType()) { case BOOL: { val pointer = new BytePointer(ArrayUtil.toBytes(data)); - val srcPtr = new CudaPointer(pointer.address() + (dstOffset * elementSize)); + val srcPtr = new CudaPointer(pointer.address() + (srcOffset * elementSize)); allocator.memcpyAsync(this, srcPtr, length * elementSize, dstOffset * elementSize); @@ -687,7 +758,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda break; case BYTE: { val pointer = new BytePointer(ArrayUtil.toBytes(data)); - val srcPtr = new CudaPointer(pointer.address() + (dstOffset * elementSize)); + val srcPtr = new CudaPointer(pointer.address() + (srcOffset * elementSize)); allocator.memcpyAsync(this, srcPtr, length * elementSize, dstOffset * elementSize); @@ -706,7 +777,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda data = ArrayUtil.cutBelowZero(data); case SHORT: { val pointer = new ShortPointer(ArrayUtil.toShorts(data)); - val srcPtr = new CudaPointer(pointer.address() + (dstOffset * elementSize)); + val srcPtr = new CudaPointer(pointer.address() + (srcOffset * elementSize)); allocator.memcpyAsync(this, srcPtr, length * elementSize, dstOffset * elementSize); @@ -718,7 +789,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda data = ArrayUtil.cutBelowZero(data); case INT: { val pointer = new IntPointer(ArrayUtil.toInts(data)); - val srcPtr = new CudaPointer(pointer.address() + (dstOffset * elementSize)); + val srcPtr = new CudaPointer(pointer.address() + (srcOffset * elementSize)); allocator.memcpyAsync(this, srcPtr, length * elementSize, dstOffset * elementSize); @@ -730,7 +801,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda data = ArrayUtil.cutBelowZero(data); case LONG: { val pointer = new LongPointer(data); - val srcPtr = new CudaPointer(pointer.address() + (dstOffset * elementSize)); + val srcPtr = new CudaPointer(pointer.address() + (srcOffset * elementSize)); allocator.memcpyAsync(this, srcPtr, length * elementSize, dstOffset * elementSize); @@ -740,7 +811,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda break; case BFLOAT16: { val pointer = new ShortPointer(ArrayUtil.toBfloats(data)); - val srcPtr = new CudaPointer(pointer.address() + (dstOffset * elementSize)); + val srcPtr = new CudaPointer(pointer.address() + (srcOffset * elementSize)); allocator.memcpyAsync(this, srcPtr, length * elementSize, dstOffset * elementSize); @@ -750,7 +821,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda break; case HALF: { val pointer = new ShortPointer(ArrayUtil.toHalfs(data)); - val srcPtr = new CudaPointer(pointer.address() + (dstOffset * elementSize)); + val srcPtr = new CudaPointer(pointer.address() + (srcOffset * elementSize)); allocator.memcpyAsync(this, srcPtr, length * elementSize, dstOffset * elementSize); @@ -760,7 +831,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda break; case FLOAT: { val pointer = new FloatPointer(ArrayUtil.toFloats(data)); - val srcPtr = new CudaPointer(pointer.address() + (dstOffset * elementSize)); + val srcPtr = new CudaPointer(pointer.address() + (srcOffset * elementSize)); allocator.memcpyAsync(this, srcPtr, length * elementSize, dstOffset * elementSize); @@ -770,7 +841,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda break; case DOUBLE: { val pointer = new DoublePointer(ArrayUtil.toDouble(data)); - val srcPtr = new CudaPointer(pointer.address() + (dstOffset * elementSize)); + val srcPtr = new CudaPointer(pointer.address() + (srcOffset * elementSize)); allocator.memcpyAsync(this, srcPtr, length * elementSize, dstOffset * elementSize); // we're keeping pointer reference for JVM @@ -796,7 +867,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda switch (dataType()) { case BOOL: { val pointer = new BytePointer(ArrayUtil.toBytes(data)); - val srcPtr = new CudaPointer(pointer.address() + (dstOffset * elementSize)); + val srcPtr = new CudaPointer(pointer.address() + (srcOffset * elementSize)); allocator.memcpyAsync(this, srcPtr, length * elementSize, dstOffset * elementSize); @@ -806,7 +877,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda break; case BYTE: { val pointer = new BytePointer(ArrayUtil.toBytes(data)); - val srcPtr = new CudaPointer(pointer.address() + (dstOffset * elementSize)); + val srcPtr = new CudaPointer(pointer.address() + (srcOffset * elementSize)); allocator.memcpyAsync(this, srcPtr, length * elementSize, dstOffset * elementSize); @@ -822,7 +893,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda break; case SHORT: { val pointer = new ShortPointer(ArrayUtil.toShorts(data)); - val srcPtr = new CudaPointer(pointer.address() + (dstOffset * elementSize)); + val srcPtr = new CudaPointer(pointer.address() + (srcOffset * elementSize)); allocator.memcpyAsync(this, srcPtr, length * elementSize, dstOffset * elementSize); @@ -832,7 +903,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda break; case INT: { val pointer = new IntPointer(ArrayUtil.toInts(data)); - val srcPtr = new CudaPointer(pointer.address() + (dstOffset * elementSize)); + val srcPtr = new CudaPointer(pointer.address() + (srcOffset * elementSize)); allocator.memcpyAsync(this, srcPtr, length * elementSize, dstOffset * elementSize); @@ -842,7 +913,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda break; case LONG: { val pointer = new LongPointer(ArrayUtil.toLongArray(data)); - val srcPtr = new CudaPointer(pointer.address() + (dstOffset * elementSize)); + val srcPtr = new CudaPointer(pointer.address() + (srcOffset * elementSize)); allocator.memcpyAsync(this, srcPtr, length * elementSize, dstOffset * elementSize); @@ -852,7 +923,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda break; case HALF: { val pointer = new ShortPointer(ArrayUtil.toHalfs(data)); - val srcPtr = new CudaPointer(pointer.address() + (dstOffset * elementSize)); + val srcPtr = new CudaPointer(pointer.address() + (srcOffset * elementSize)); allocator.memcpyAsync(this, srcPtr, length * elementSize, dstOffset * elementSize); @@ -862,7 +933,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda break; case FLOAT: { val pointer = new FloatPointer(data); - val srcPtr = new CudaPointer(pointer.address() + (dstOffset * elementSize)); + val srcPtr = new CudaPointer(pointer.address() + (srcOffset * elementSize)); allocator.memcpyAsync(this, srcPtr, length * elementSize, dstOffset * elementSize); @@ -872,7 +943,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda break; case DOUBLE: { DoublePointer pointer = new DoublePointer(ArrayUtil.toDoubles(data)); - Pointer srcPtr = new CudaPointer(pointer.address() + (dstOffset * elementSize)); + Pointer srcPtr = new CudaPointer(pointer.address() + (srcOffset * elementSize)); allocator.memcpyAsync(this, srcPtr, length * elementSize, dstOffset * elementSize); @@ -898,7 +969,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda switch (dataType()) { case BOOL: { val pointer = new BytePointer(ArrayUtil.toBytes(data)); - val srcPtr = new CudaPointer(pointer.address() + (dstOffset * elementSize)); + val srcPtr = new CudaPointer(pointer.address() + (srcOffset * elementSize)); allocator.memcpyAsync(this, srcPtr, length * elementSize, dstOffset * elementSize); @@ -908,7 +979,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda break; case BYTE: { val pointer = new BytePointer(ArrayUtil.toBytes(data)); - val srcPtr = new CudaPointer(pointer.address() + (dstOffset * elementSize)); + val srcPtr = new CudaPointer(pointer.address() + (srcOffset * elementSize)); allocator.memcpyAsync(this, srcPtr, length * elementSize, dstOffset * elementSize); @@ -924,7 +995,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda break; case SHORT: { val pointer = new ShortPointer(ArrayUtil.toShorts(data)); - val srcPtr = new CudaPointer(pointer.address() + (dstOffset * elementSize)); + val srcPtr = new CudaPointer(pointer.address() + (srcOffset * elementSize)); allocator.memcpyAsync(this, srcPtr, length * elementSize, dstOffset * elementSize); @@ -934,7 +1005,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda break; case INT: { val pointer = new IntPointer(ArrayUtil.toInts(data)); - val srcPtr = new CudaPointer(pointer.address() + (dstOffset * elementSize)); + val srcPtr = new CudaPointer(pointer.address() + (srcOffset * elementSize)); allocator.memcpyAsync(this, srcPtr, length * elementSize, dstOffset * elementSize); @@ -944,7 +1015,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda break; case LONG: { val pointer = new LongPointer(ArrayUtil.toLongs(data)); - val srcPtr = new CudaPointer(pointer.address() + (dstOffset * elementSize)); + val srcPtr = new CudaPointer(pointer.address() + (srcOffset * elementSize)); allocator.memcpyAsync(this, srcPtr, length * elementSize, dstOffset * elementSize); @@ -954,7 +1025,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda break; case HALF: { val pointer = new ShortPointer(ArrayUtil.toHalfs(data)); - val srcPtr = new CudaPointer(pointer.address() + (dstOffset * elementSize)); + val srcPtr = new CudaPointer(pointer.address() + (srcOffset * elementSize)); allocator.memcpyAsync(this, srcPtr, length * elementSize, dstOffset * elementSize); @@ -964,7 +1035,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda break; case FLOAT: { val pointer = new FloatPointer(ArrayUtil.toFloats(data)); - val srcPtr = new CudaPointer(pointer.address() + (dstOffset * elementSize)); + val srcPtr = new CudaPointer(pointer.address() + (srcOffset * elementSize)); allocator.memcpyAsync(this, srcPtr, length * elementSize, dstOffset * elementSize); @@ -974,7 +1045,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda break; case DOUBLE: { val pointer = new DoublePointer(data); - val srcPtr = new CudaPointer(pointer.address() + (dstOffset * elementSize)); + val srcPtr = new CudaPointer(pointer.address() + (srcOffset * elementSize)); allocator.memcpyAsync(this, srcPtr, length * elementSize, dstOffset * elementSize); @@ -1249,7 +1320,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda @Override public boolean sameUnderlyingData(DataBuffer buffer) { - return buffer.getTrackingPoint() == getTrackingPoint(); + return ptrDataBuffer.address() == ((BaseCudaDataBuffer) buffer).ptrDataBuffer.address(); } /** @@ -1342,54 +1413,54 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda this.elementSize = (byte) Nd4j.sizeOfDataType(t); this.allocationPoint = AtomicAllocator.getInstance().allocateMemory(this, new AllocationShape(length, elementSize, t), false); - this.trackingPoint = allocationPoint.getObjectId(); + this.type = t; Nd4j.getDeallocatorService().pickObject(this); switch (type) { case DOUBLE: { - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length).asDoublePointer(); + this.pointer = new CudaPointer(allocationPoint.getHostPointer(), length).asDoublePointer(); indexer = DoubleIndexer.create((DoublePointer) pointer); } break; case FLOAT: { - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length).asFloatPointer(); + this.pointer = new CudaPointer(allocationPoint.getHostPointer(), length).asFloatPointer(); indexer = FloatIndexer.create((FloatPointer) pointer); } break; case HALF: { - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length).asShortPointer(); + this.pointer = new CudaPointer(allocationPoint.getHostPointer(), length).asShortPointer(); indexer = HalfIndexer.create((ShortPointer) pointer); } break; case LONG: { - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length).asLongPointer(); + this.pointer = new CudaPointer(allocationPoint.getHostPointer(), length).asLongPointer(); indexer = LongIndexer.create((LongPointer) pointer); } break; case INT: { - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length).asIntPointer(); + this.pointer = new CudaPointer(allocationPoint.getHostPointer(), length).asIntPointer(); indexer = IntIndexer.create((IntPointer) pointer); } break; case SHORT: { - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length).asShortPointer(); + this.pointer = new CudaPointer(allocationPoint.getHostPointer(), length).asShortPointer(); indexer = ShortIndexer.create((ShortPointer) pointer); } break; case UBYTE: { - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length).asBytePointer(); + this.pointer = new CudaPointer(allocationPoint.getHostPointer(), length).asBytePointer(); indexer = UByteIndexer.create((BytePointer) pointer); } break; case BYTE: { - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length).asBytePointer(); + this.pointer = new CudaPointer(allocationPoint.getHostPointer(), length).asBytePointer(); indexer = ByteIndexer.create((BytePointer) pointer); } break; case BOOL: { - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length).asBooleanPointer(); + this.pointer = new CudaPointer(allocationPoint.getHostPointer(), length).asBooleanPointer(); indexer = BooleanIndexer.create((BooleanPointer) pointer); } break; @@ -1511,53 +1582,181 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda return super.getInt(ix); } + public void actualizePointerAndIndexer() { + val cptr = ptrDataBuffer.primaryBuffer(); + + // skip update if pointers are equal + if (cptr != null && pointer != null && cptr.address() == pointer.address()) + return; + + val t = dataType(); + if (t == DataType.BOOL) { + pointer = new PagedPointer(cptr, length).asBoolPointer(); + setIndexer(BooleanIndexer.create((BooleanPointer) pointer)); + } else if (t == DataType.UBYTE) { + pointer = new PagedPointer(cptr, length).asBytePointer(); + setIndexer(UByteIndexer.create((BytePointer) pointer)); + } else if (t == DataType.BYTE) { + pointer = new PagedPointer(cptr, length).asBytePointer(); + setIndexer(ByteIndexer.create((BytePointer) pointer)); + } else if (t == DataType.UINT16) { + pointer = new PagedPointer(cptr, length).asShortPointer(); + setIndexer(UShortIndexer.create((ShortPointer) pointer)); + } else if (t == DataType.SHORT) { + pointer = new PagedPointer(cptr, length).asShortPointer(); + setIndexer(ShortIndexer.create((ShortPointer) pointer)); + } else if (t == DataType.UINT32) { + pointer = new PagedPointer(cptr, length).asIntPointer(); + setIndexer(IntIndexer.create((IntPointer) pointer)); + } else if (t == DataType.INT) { + pointer = new PagedPointer(cptr, length).asIntPointer(); + setIndexer(IntIndexer.create((IntPointer) pointer)); + } else if (t == DataType.UINT64) { + pointer = new PagedPointer(cptr, length).asLongPointer(); + setIndexer(LongIndexer.create((LongPointer) pointer)); + } else if (t == DataType.LONG) { + pointer = new PagedPointer(cptr, length).asLongPointer(); + setIndexer(LongIndexer.create((LongPointer) pointer)); + } else if (t == DataType.BFLOAT16) { + pointer = new PagedPointer(cptr, length).asShortPointer(); + setIndexer(Bfloat16Indexer.create((ShortPointer) pointer)); + } else if (t == DataType.HALF) { + pointer = new PagedPointer(cptr, length).asShortPointer(); + setIndexer(HalfIndexer.create((ShortPointer) pointer)); + } else if (t == DataType.FLOAT) { + pointer = new PagedPointer(cptr, length).asFloatPointer(); + setIndexer(FloatIndexer.create((FloatPointer) pointer)); + } else if (t == DataType.DOUBLE) { + pointer = new PagedPointer(cptr, length).asDoublePointer(); + setIndexer(DoubleIndexer.create((DoublePointer) pointer)); + } else if (t == DataType.UTF8) { + pointer = new PagedPointer(cptr, length()).asBytePointer(); + setIndexer(ByteIndexer.create((BytePointer) pointer)); + } else + throw new IllegalArgumentException("Unknown datatype: " + dataType()); + } + @Override public DataBuffer reallocate(long length) { + val oldHostPointer = this.ptrDataBuffer.primaryBuffer(); + val oldDevicePointer = this.ptrDataBuffer.specialBuffer(); - // we want to be sure this array isn't used anywhere RIGHT AT THIS MOMENT - Nd4j.getExecutioner().commit(); + if (isAttached()) { + val capacity = length * getElementSize(); + + if (oldDevicePointer != null && oldDevicePointer.address() != 0) { + val nPtr = getParentWorkspace().alloc(capacity, MemoryKind.DEVICE, dataType(), false); + NativeOpsHolder.getInstance().getDeviceNativeOps().memcpySync(nPtr, oldDevicePointer, length * getElementSize(), 3, null); + this.ptrDataBuffer.setPrimaryBuffer(nPtr, length); + + allocationPoint.tickDeviceRead(); + } + + if (oldHostPointer != null && oldHostPointer.address() != 0) { + val nPtr = getParentWorkspace().alloc(capacity, MemoryKind.HOST, dataType(), false); + Pointer.memcpy(nPtr, oldHostPointer, this.length() * getElementSize()); + this.ptrDataBuffer.setPrimaryBuffer(nPtr, length); + + allocationPoint.tickHostRead(); + + switch (dataType()) { + case BOOL: + pointer = nPtr.asBoolPointer(); + indexer = BooleanIndexer.create((BooleanPointer) pointer); + break; + case UTF8: + case BYTE: + case UBYTE: + pointer = nPtr.asBytePointer(); + indexer = ByteIndexer.create((BytePointer) pointer); + break; + case UINT16: + case SHORT: + pointer = nPtr.asShortPointer(); + indexer = ShortIndexer.create((ShortPointer) pointer); + break; + case UINT32: + case INT: + pointer = nPtr.asIntPointer(); + indexer = IntIndexer.create((IntPointer) pointer); + break; + case DOUBLE: + pointer = nPtr.asDoublePointer(); + indexer = DoubleIndexer.create((DoublePointer) pointer); + break; + case FLOAT: + pointer = nPtr.asFloatPointer(); + indexer = FloatIndexer.create((FloatPointer) pointer); + break; + case HALF: + pointer = nPtr.asShortPointer(); + indexer = HalfIndexer.create((ShortPointer) pointer); + break; + case BFLOAT16: + pointer = nPtr.asShortPointer(); + indexer = Bfloat16Indexer.create((ShortPointer) pointer); + break; + case UINT64: + case LONG: + pointer = nPtr.asLongPointer(); + indexer = LongIndexer.create((LongPointer) pointer); + break; + } + } - AllocationPoint old = allocationPoint; - allocationPoint = AtomicAllocator.getInstance().allocateMemory(this, new AllocationShape(length, elementSize, dataType()), false); + workspaceGenerationId = getParentWorkspace().getGenerationId(); + } else { + this.ptrDataBuffer.expand(length); + val nPtr = new PagedPointer(this.ptrDataBuffer.primaryBuffer(), length); - Nd4j.getDeallocatorService().pickObject(this); - trackingPoint = allocationPoint.getObjectId(); - val oldLength = this.length; + switch (dataType()) { + case BOOL: + pointer = nPtr.asBoolPointer(); + indexer = BooleanIndexer.create((BooleanPointer) pointer); + break; + case UTF8: + case BYTE: + case UBYTE: + pointer = nPtr.asBytePointer(); + indexer = ByteIndexer.create((BytePointer) pointer); + break; + case UINT16: + case SHORT: + pointer = nPtr.asShortPointer(); + indexer = ShortIndexer.create((ShortPointer) pointer); + break; + case UINT32: + case INT: + pointer = nPtr.asIntPointer(); + indexer = IntIndexer.create((IntPointer) pointer); + break; + case DOUBLE: + pointer = nPtr.asDoublePointer(); + indexer = DoubleIndexer.create((DoublePointer) pointer); + break; + case FLOAT: + pointer = nPtr.asFloatPointer(); + indexer = FloatIndexer.create((FloatPointer) pointer); + break; + case HALF: + pointer = nPtr.asShortPointer(); + indexer = HalfIndexer.create((ShortPointer) pointer); + break; + case BFLOAT16: + pointer = nPtr.asShortPointer(); + indexer = Bfloat16Indexer.create((ShortPointer) pointer); + break; + case UINT64: + case LONG: + pointer = nPtr.asLongPointer(); + indexer = LongIndexer.create((LongPointer) pointer); + break; + } + } + + this.underlyingLength = length; this.length = length; - - // if original buffer had host pointer allocated, we'll reallocate host buffer as well - if (old.getHostPointer() != null) { - lazyAllocateHostPointer(); - } - - val context = AtomicAllocator.getInstance().getDeviceContext(); - NativeOpsHolder.getInstance().getDeviceNativeOps().memsetAsync(allocationPoint.getDevicePointer(), 0, length * elementSize, 0, context.getSpecialStream()); - - MemcpyDirection direction = MemcpyDirection.DEVICE_TO_DEVICE; - val perfD = PerformanceTracker.getInstance().helperStartTransaction(); - - if (old.isActualOnDeviceSide()) { - NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(allocationPoint.getDevicePointer(), old.getDevicePointer(), oldLength * elementSize, CudaConstants.cudaMemcpyDeviceToDevice, context.getSpecialStream()); - } else if (old.isActualOnHostSide()) { - NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(allocationPoint.getDevicePointer(), old.getHostPointer(), oldLength * elementSize, CudaConstants.cudaMemcpyHostToDevice, context.getSpecialStream()); - direction = MemcpyDirection.HOST_TO_DEVICE; - } - - context.getSpecialStream().synchronize(); - - PerformanceTracker.getInstance().helperRegisterTransaction(allocationPoint.getDeviceId(), perfD, allocationPoint.getNumberOfBytes(), direction); - - allocationPoint.tickDeviceWrite(); - - // we need to update length with new value now - //this.length = length; - if(isAttached()){ - // do nothing here, that's workspaces - } else{ - AtomicAllocator.getInstance().freeMemory(old); - } - return this; } @@ -1572,7 +1771,8 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda @Override protected void release() { if (!released) { - AtomicAllocator.getInstance().freeMemory(allocationPoint); + //AtomicAllocator.getInstance().freeMemory(allocationPoint);n + NativeOpsHolder.getInstance().getDeviceNativeOps().dbClose(allocationPoint.getPtrDataBuffer()); allocationPoint.setReleased(true); } released = true; diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaBfloat16DataBuffer.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaBfloat16DataBuffer.java index 193a9e21c..145816a5e 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaBfloat16DataBuffer.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaBfloat16DataBuffer.java @@ -46,6 +46,10 @@ public class CudaBfloat16DataBuffer extends BaseCudaDataBuffer { super(pointer, specialPointer, indexer, length); } + public CudaBfloat16DataBuffer(ByteBuffer buffer, DataType dataType, long length, long offset) { + super(buffer, dataType, length, offset); + } + /** * Base constructor * @@ -128,18 +132,6 @@ public class CudaBfloat16DataBuffer extends BaseCudaDataBuffer { super(data, copy, offset); } - public CudaBfloat16DataBuffer(byte[] data, long length) { - super(data, length, DataType.BFLOAT16); - } - - public CudaBfloat16DataBuffer(ByteBuffer buffer, long length) { - super(buffer, (int) length, DataType.BFLOAT16); - } - - public CudaBfloat16DataBuffer(ByteBuffer buffer, long length, long offset) { - super(buffer, length, offset, DataType.BFLOAT16); - } - @Override public void assign(long[] indices, double[] data, boolean contiguous, long inc) { diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaBoolDataBuffer.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaBoolDataBuffer.java index a1b498785..08dbd9f39 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaBoolDataBuffer.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaBoolDataBuffer.java @@ -50,6 +50,10 @@ public class CudaBoolDataBuffer extends BaseCudaDataBuffer { super(pointer, specialPointer, indexer, length); } + public CudaBoolDataBuffer(ByteBuffer buffer, DataType dataType, long length, long offset) { + super(buffer, dataType, length, offset); + } + /** * Base constructor * @@ -132,18 +136,6 @@ public class CudaBoolDataBuffer extends BaseCudaDataBuffer { super(data, copy, offset); } - public CudaBoolDataBuffer(byte[] data, long length) { - super(data, length, DataType.HALF); - } - - public CudaBoolDataBuffer(ByteBuffer buffer, long length) { - super(buffer, (int) length, DataType.HALF); - } - - public CudaBoolDataBuffer(ByteBuffer buffer, long length, long offset) { - super(buffer, length, offset, DataType.HALF); - } - @Override protected DataBuffer create(long length) { return new CudaBoolDataBuffer(length); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaByteDataBuffer.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaByteDataBuffer.java index 80fb7f804..d35b3c215 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaByteDataBuffer.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaByteDataBuffer.java @@ -49,6 +49,10 @@ public class CudaByteDataBuffer extends BaseCudaDataBuffer { super(pointer, specialPointer, indexer, length); } + public CudaByteDataBuffer(ByteBuffer buffer, DataType dataType, long length, long offset) { + super(buffer, dataType, length, offset); + } + /** * Base constructor * @@ -131,18 +135,6 @@ public class CudaByteDataBuffer extends BaseCudaDataBuffer { super(data, copy, offset); } - public CudaByteDataBuffer(byte[] data, long length) { - super(data, length, DataType.HALF); - } - - public CudaByteDataBuffer(ByteBuffer buffer, long length) { - super(buffer, (int) length, DataType.HALF); - } - - public CudaByteDataBuffer(ByteBuffer buffer, long length, long offset) { - super(buffer, length, offset, DataType.HALF); - } - @Override protected DataBuffer create(long length) { return new CudaByteDataBuffer(length); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaDoubleDataBuffer.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaDoubleDataBuffer.java index 8ccc3cf81..789b213f1 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaDoubleDataBuffer.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaDoubleDataBuffer.java @@ -49,6 +49,10 @@ public class CudaDoubleDataBuffer extends BaseCudaDataBuffer { super(pointer, specialPointer, indexer, length); } + public CudaDoubleDataBuffer(ByteBuffer buffer, DataType dataType, long length, long offset) { + super(buffer, dataType, length, offset); + } + /** * Base constructor * @@ -138,18 +142,6 @@ public class CudaDoubleDataBuffer extends BaseCudaDataBuffer { super(data, copy, offset); } - public CudaDoubleDataBuffer(byte[] data, long length) { - super(data, length, DataType.DOUBLE); - } - - public CudaDoubleDataBuffer(ByteBuffer buffer, long length) { - super(buffer, (int) length, DataType.DOUBLE); - } - - public CudaDoubleDataBuffer(ByteBuffer buffer, long length, long offset) { - super(buffer, length, offset, DataType.DOUBLE); - } - @Override protected DataBuffer create(long length) { return new CudaDoubleDataBuffer(length); @@ -210,14 +202,7 @@ public class CudaDoubleDataBuffer extends BaseCudaDataBuffer { this.length = n; this.elementSize = 8; - //wrappedBuffer = ByteBuffer.allocateDirect(length() * getElementSize()); - //wrappedBuffer.order(ByteOrder.nativeOrder()); - - this.allocationPoint = AtomicAllocator.getInstance().allocateMemory(this, - new AllocationShape(length, elementSize, DataType.DOUBLE), false); - this.trackingPoint = allocationPoint.getObjectId(); - //this.wrappedBuffer = allocationPoint.getPointers().getHostPointer().asByteBuffer(); - //this.wrappedBuffer.order(ByteOrder.nativeOrder()); + this.allocationPoint = AtomicAllocator.getInstance().allocateMemory(this, new AllocationShape(length, elementSize, DataType.DOUBLE), false); setData(arr); } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaFloatDataBuffer.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaFloatDataBuffer.java index c173e2745..f7f70bc75 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaFloatDataBuffer.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaFloatDataBuffer.java @@ -50,6 +50,10 @@ public class CudaFloatDataBuffer extends BaseCudaDataBuffer { super(pointer, specialPointer, indexer, length); } + public CudaFloatDataBuffer(ByteBuffer buffer, DataType dataType, long length, long offset) { + super(buffer, dataType, length, offset); + } + /** * Base constructor * @@ -133,19 +137,6 @@ public class CudaFloatDataBuffer extends BaseCudaDataBuffer { super(data, copy, offset); } - public CudaFloatDataBuffer(byte[] data, long length) { - super(data, length, DataType.FLOAT); - } - - public CudaFloatDataBuffer(ByteBuffer buffer, long length) { - super(buffer, (int) length, DataType.FLOAT); - } - - public CudaFloatDataBuffer(ByteBuffer buffer, long length, long offset) { - super(buffer, length, offset, DataType.FLOAT); - } - - @Override protected DataBuffer create(long length) { return new CudaFloatDataBuffer(length); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaHalfDataBuffer.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaHalfDataBuffer.java index 472e701c1..1fb55e73b 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaHalfDataBuffer.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaHalfDataBuffer.java @@ -49,6 +49,10 @@ public class CudaHalfDataBuffer extends BaseCudaDataBuffer { super(pointer, specialPointer, indexer, length); } + public CudaHalfDataBuffer(ByteBuffer buffer, DataType dataType, long length, long offset) { + super(buffer, dataType, length, offset); + } + /** * Base constructor * @@ -131,18 +135,6 @@ public class CudaHalfDataBuffer extends BaseCudaDataBuffer { super(data, copy, offset); } - public CudaHalfDataBuffer(byte[] data, long length) { - super(data, length, DataType.HALF); - } - - public CudaHalfDataBuffer(ByteBuffer buffer, long length) { - super(buffer, (int) length, DataType.HALF); - } - - public CudaHalfDataBuffer(ByteBuffer buffer, long length, long offset) { - super(buffer, length, offset, DataType.HALF); - } - @Override protected DataBuffer create(long length) { return new CudaHalfDataBuffer(length); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaIntDataBuffer.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaIntDataBuffer.java index 27c0c95e3..95a9c0ce9 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaIntDataBuffer.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaIntDataBuffer.java @@ -46,6 +46,10 @@ public class CudaIntDataBuffer extends BaseCudaDataBuffer { super(pointer, specialPointer, indexer, length); } + public CudaIntDataBuffer(ByteBuffer buffer, DataType dataType, long length, long offset) { + super(buffer, dataType, length, offset); + } + /** * Base constructor * @@ -106,11 +110,6 @@ public class CudaIntDataBuffer extends BaseCudaDataBuffer { super(data, copy, offset); } - - public CudaIntDataBuffer(byte[] data, int length) { - super(data, length, DataType.INT); - } - public CudaIntDataBuffer(double[] data) { super(data); } @@ -135,14 +134,6 @@ public class CudaIntDataBuffer extends BaseCudaDataBuffer { super(data, copy, offset); } - public CudaIntDataBuffer(ByteBuffer buffer, int length) { - super(buffer, length, DataType.INT); - } - - public CudaIntDataBuffer(ByteBuffer buffer, int length, long offset) { - super(buffer, length, offset, DataType.INT); - } - @Override protected DataBuffer create(long length) { return new CudaIntDataBuffer(length); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaLongDataBuffer.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaLongDataBuffer.java index 494148862..381ab5355 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaLongDataBuffer.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaLongDataBuffer.java @@ -16,12 +16,14 @@ package org.nd4j.linalg.jcublas.buffer; +import lombok.Data; import lombok.NonNull; import lombok.val; import org.bytedeco.javacpp.LongPointer; import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacpp.indexer.Indexer; import org.bytedeco.javacpp.indexer.LongIndexer; +import org.nd4j.jita.allocator.impl.AllocationPoint; import org.nd4j.jita.allocator.impl.AllocationShape; import org.nd4j.jita.allocator.impl.AtomicAllocator; import org.nd4j.jita.allocator.pointers.CudaPointer; @@ -30,6 +32,7 @@ import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.util.ArrayUtil; +import org.nd4j.nativeblas.NativeOpsHolder; import java.nio.ByteBuffer; import java.nio.ByteOrder; @@ -55,8 +58,18 @@ public class CudaLongDataBuffer extends BaseCudaDataBuffer { super(pointer, specialPointer, indexer, length); } + public CudaLongDataBuffer(ByteBuffer buffer, DataType dataType, long length, long offset) { + super(buffer, dataType, length, offset); + } + /** + * This constructor is special one - it's used for ShapeInfo + * @param hostPointer + * @param devicePointer + * @param numberOfElements + */ public CudaLongDataBuffer(@NonNull Pointer hostPointer, @NonNull Pointer devicePointer, long numberOfElements) { + super(); this.allocationMode = AllocationMode.MIXED_DATA_TYPES; this.offset = 0; this.originalOffset = 0; @@ -64,14 +77,15 @@ public class CudaLongDataBuffer extends BaseCudaDataBuffer { this.length = numberOfElements; initTypeAndSize(); + // creating empty native DataBuffer and filling it with pointers + ptrDataBuffer = NativeOpsHolder.getInstance().getDeviceNativeOps().allocateDataBuffer(0, DataType.INT64.toInt(), false); + NativeOpsHolder.getInstance().getDeviceNativeOps().dbSetPrimaryBuffer(ptrDataBuffer, hostPointer, numberOfElements); + NativeOpsHolder.getInstance().getDeviceNativeOps().dbSetSpecialBuffer(ptrDataBuffer, devicePointer, numberOfElements); + + // setting up java side of things this.pointer = new CudaPointer(hostPointer, numberOfElements).asLongPointer(); indexer = LongIndexer.create((LongPointer) this.pointer); - - this.allocationPoint = AtomicAllocator.getInstance().pickExternalBuffer(this); - - val pp = new PointersPair(devicePointer, this.pointer); - allocationPoint.setPointers(pp); - trackingPoint = allocationPoint.getObjectId(); + this.allocationPoint = new AllocationPoint(ptrDataBuffer, numberOfElements * DataType.INT64.width()); } /** @@ -179,19 +193,6 @@ public class CudaLongDataBuffer extends BaseCudaDataBuffer { super(data, copy, offset); } - public CudaLongDataBuffer(byte[] data, long length) { - super(data, length, DataType.LONG); - } - - public CudaLongDataBuffer(ByteBuffer buffer, long length) { - super(buffer, (int) length, DataType.LONG); - } - - public CudaLongDataBuffer(ByteBuffer buffer, long length, long offset) { - super(buffer, length, offset, DataType.LONG); - } - - @Override protected DataBuffer create(long length) { return new CudaLongDataBuffer(length); @@ -241,14 +242,7 @@ public class CudaLongDataBuffer extends BaseCudaDataBuffer { this.length = n; this.elementSize = 8; - //wrappedBuffer = ByteBuffer.allocateDirect(length() * getElementSize()); - //wrappedBuffer.order(ByteOrder.nativeOrder()); - - this.allocationPoint = AtomicAllocator.getInstance().allocateMemory(this, - new AllocationShape(length, elementSize, DataType.LONG), false); - this.trackingPoint = allocationPoint.getObjectId(); - //this.wrappedBuffer = allocationPoint.getPointers().getHostPointer().asByteBuffer(); - //this.wrappedBuffer.order(ByteOrder.nativeOrder()); + this.allocationPoint = AtomicAllocator.getInstance().allocateMemory(this, new AllocationShape(length, elementSize, DataType.LONG), false); setData(arr); } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaShortDataBuffer.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaShortDataBuffer.java index 9a67f56aa..645b06723 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaShortDataBuffer.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaShortDataBuffer.java @@ -49,6 +49,10 @@ public class CudaShortDataBuffer extends BaseCudaDataBuffer { super(pointer, specialPointer, indexer, length); } + public CudaShortDataBuffer(ByteBuffer buffer, DataType dataType, long length, long offset) { + super(buffer, dataType, length, offset); + } + /** * Base constructor * @@ -131,18 +135,6 @@ public class CudaShortDataBuffer extends BaseCudaDataBuffer { super(data, copy, offset); } - public CudaShortDataBuffer(byte[] data, long length) { - super(data, length, DataType.HALF); - } - - public CudaShortDataBuffer(ByteBuffer buffer, long length) { - super(buffer, (int) length, DataType.HALF); - } - - public CudaShortDataBuffer(ByteBuffer buffer, long length, long offset) { - super(buffer, length, offset, DataType.HALF); - } - @Override protected DataBuffer create(long length) { return new CudaShortDataBuffer(length); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaUByteDataBuffer.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaUByteDataBuffer.java index 7cc944850..5447ba043 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaUByteDataBuffer.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaUByteDataBuffer.java @@ -49,6 +49,10 @@ public class CudaUByteDataBuffer extends BaseCudaDataBuffer { super(pointer, specialPointer, indexer, length); } + public CudaUByteDataBuffer(ByteBuffer buffer, DataType dataType, long length, long offset) { + super(buffer, dataType, length, offset); + } + /** * Base constructor * @@ -131,18 +135,6 @@ public class CudaUByteDataBuffer extends BaseCudaDataBuffer { super(data, copy, offset); } - public CudaUByteDataBuffer(byte[] data, long length) { - super(data, length, DataType.HALF); - } - - public CudaUByteDataBuffer(ByteBuffer buffer, long length) { - super(buffer, (int) length, DataType.HALF); - } - - public CudaUByteDataBuffer(ByteBuffer buffer, long length, long offset) { - super(buffer, length, offset, DataType.HALF); - } - @Override public void assign(long[] indices, double[] data, boolean contiguous, long inc) { diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaUInt16DataBuffer.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaUInt16DataBuffer.java index 428cb5bcd..809363494 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaUInt16DataBuffer.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaUInt16DataBuffer.java @@ -46,6 +46,10 @@ public class CudaUInt16DataBuffer extends BaseCudaDataBuffer { super(pointer, specialPointer, indexer, length); } + public CudaUInt16DataBuffer(ByteBuffer buffer, DataType dataType, long length, long offset) { + super(buffer, dataType, length, offset); + } + /** * Base constructor * @@ -128,18 +132,6 @@ public class CudaUInt16DataBuffer extends BaseCudaDataBuffer { super(data, copy, offset); } - public CudaUInt16DataBuffer(byte[] data, long length) { - super(data, length, DataType.UINT16); - } - - public CudaUInt16DataBuffer(ByteBuffer buffer, long length) { - super(buffer, (int) length, DataType.UINT16); - } - - public CudaUInt16DataBuffer(ByteBuffer buffer, long length, long offset) { - super(buffer, length, offset, DataType.UINT16); - } - @Override public void assign(long[] indices, double[] data, boolean contiguous, long inc) { diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaUInt32DataBuffer.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaUInt32DataBuffer.java index cd34607ce..1595cfda3 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaUInt32DataBuffer.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaUInt32DataBuffer.java @@ -46,6 +46,10 @@ public class CudaUInt32DataBuffer extends BaseCudaDataBuffer { super(pointer, specialPointer, indexer, length); } + public CudaUInt32DataBuffer(ByteBuffer buffer, DataType dataType, long length, long offset) { + super(buffer, dataType, length, offset); + } + /** * Base constructor * @@ -128,18 +132,6 @@ public class CudaUInt32DataBuffer extends BaseCudaDataBuffer { super(data, copy, offset); } - public CudaUInt32DataBuffer(byte[] data, long length) { - super(data, length, DataType.UINT32); - } - - public CudaUInt32DataBuffer(ByteBuffer buffer, long length) { - super(buffer, (int) length, DataType.UINT32); - } - - public CudaUInt32DataBuffer(ByteBuffer buffer, long length, long offset) { - super(buffer, length, offset, DataType.UINT32); - } - @Override public void assign(long[] indices, double[] data, boolean contiguous, long inc) { diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaUInt64DataBuffer.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaUInt64DataBuffer.java index 0e413827c..a107a5d8c 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaUInt64DataBuffer.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaUInt64DataBuffer.java @@ -42,6 +42,10 @@ public class CudaUInt64DataBuffer extends BaseCudaDataBuffer { super(pointer, indexer, length); } + public CudaUInt64DataBuffer(ByteBuffer buffer, DataType dataType, long length, long offset) { + super(buffer, dataType, length, offset); + } + /** * Base constructor * @@ -128,18 +132,6 @@ public class CudaUInt64DataBuffer extends BaseCudaDataBuffer { super(data, copy, offset); } - public CudaUInt64DataBuffer(byte[] data, long length) { - super(data, length, DataType.UINT64); - } - - public CudaUInt64DataBuffer(ByteBuffer buffer, long length) { - super(buffer, (int) length, DataType.UINT64); - } - - public CudaUInt64DataBuffer(ByteBuffer buffer, long length, long offset) { - super(buffer, length, offset, DataType.UINT64); - } - @Override public void assign(long[] indices, double[] data, boolean contiguous, long inc) { diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaUtf8Buffer.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaUtf8Buffer.java new file mode 100644 index 000000000..50219f563 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaUtf8Buffer.java @@ -0,0 +1,243 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg.jcublas.buffer; + + +import lombok.Getter; +import lombok.NonNull; +import lombok.val; +import org.bytedeco.javacpp.BytePointer; +import org.bytedeco.javacpp.LongPointer; +import org.bytedeco.javacpp.Pointer; +import org.bytedeco.javacpp.indexer.Indexer; +import org.nd4j.base.Preconditions; +import org.nd4j.jita.allocator.impl.AtomicAllocator; +import org.nd4j.linalg.api.buffer.DataBuffer; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.memory.MemoryWorkspace; + +import java.io.UnsupportedEncodingException; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Collection; + +/** + * UTF-8 buffer + * + * @author Adam Gibson + */ +public class CudaUtf8Buffer extends BaseCudaDataBuffer { + + protected Collection references = new ArrayList<>(); + + @Getter + protected long numWords = 0; + + /** + * Meant for creating another view of a buffer + * + * @param pointer the underlying buffer to create a view from + * @param indexer the indexer for the pointer + * @param length the length of the view + */ + public CudaUtf8Buffer(Pointer pointer, Indexer indexer, long length) { + super(pointer, indexer, length); + } + + public CudaUtf8Buffer(ByteBuffer buffer, DataType dataType, long length, long offset) { + super(buffer, dataType, length, offset); + } + + public CudaUtf8Buffer(long length) { + super(length); + } + + public CudaUtf8Buffer(long length, boolean initialize) { + super((length + 1) * 8, 1, initialize); + numWords = length; + } + + public CudaUtf8Buffer(long length, boolean initialize, MemoryWorkspace workspace) { + super((length + 1) * 8, 1, initialize, workspace); + numWords = length; + } + + public CudaUtf8Buffer(int[] ints, boolean copy, MemoryWorkspace workspace) { + super(ints, copy, workspace); + } + + public CudaUtf8Buffer(byte[] data, long numWords) { + super(data.length, 1, false); + + lazyAllocateHostPointer(); + + val bp = (BytePointer) pointer; + bp.put(data); + this.numWords = numWords; + } + + public CudaUtf8Buffer(double[] data, boolean copy) { + super(data, copy); + } + + public CudaUtf8Buffer(double[] data, boolean copy, long offset) { + super(data, copy, offset); + } + + public CudaUtf8Buffer(float[] data, boolean copy) { + super(data, copy); + } + + public CudaUtf8Buffer(long[] data, boolean copy) { + super(data, copy); + } + + public CudaUtf8Buffer(long[] data, boolean copy, MemoryWorkspace workspace) { + super(data, copy); + } + + public CudaUtf8Buffer(float[] data, boolean copy, long offset) { + super(data, copy, offset); + } + + public CudaUtf8Buffer(int[] data, boolean copy, long offset) { + super(data, copy, offset); + } + + public CudaUtf8Buffer(int length, int elementSize) { + super(length, elementSize); + } + + public CudaUtf8Buffer(int length, int elementSize, long offset) { + super(length, elementSize, offset); + } + + public CudaUtf8Buffer(DataBuffer underlyingBuffer, long length, long offset) { + super(underlyingBuffer, length, offset); + this.numWords = length; + + Preconditions.checkArgument(((CudaUtf8Buffer) underlyingBuffer).numWords == numWords, "String array can't be a view"); + } + + public CudaUtf8Buffer(@NonNull Collection strings) { + super(CudaUtf8Buffer.stringBufferRequiredLength(strings), 1, false); + lazyAllocateHostPointer(); + + // at this point we should have fully allocated buffer, time to fill length + val headerLength = (strings.size() + 1) * 8; + val headerPointer = new LongPointer(this.pointer); + val dataPointer = new BytePointer(this.pointer); + + numWords = strings.size(); + + long cnt = 0; + long currentLength = 0; + for (val s: strings) { + headerPointer.put(cnt++, currentLength); + val length = s.length(); + val chars = s.toCharArray(); + + // putting down chars + for (int e = 0; e < length; e++) { + val b = (byte) chars[e]; + val idx = headerLength + currentLength + e; + dataPointer.put(idx, b); + } + + currentLength += length; + } + headerPointer.put(cnt, currentLength); + allocationPoint.tickHostWrite(); + } + + public String getString(long index) { + if (index > numWords) + throw new IllegalArgumentException("Requested index [" + index + "] is above actual number of words stored: [" + numWords + "]"); + + val headerPointer = new LongPointer(this.pointer); + val dataPointer = (BytePointer) (this.pointer); + + val start = headerPointer.get(index); + val end = headerPointer.get(index+1); + + if (end - start > Integer.MAX_VALUE) + throw new IllegalStateException("Array is too long for Java"); + + val dataLength = (int) (end - start); + val bytes = new byte[dataLength]; + + val headerLength = (numWords + 1) * 8; + + for (int e = 0; e < dataLength; e++) { + val idx = headerLength + start + e; + bytes[e] = dataPointer.get(idx); + } + + try { + return new String(bytes, "UTF-8"); + } catch (UnsupportedEncodingException e) { + throw new RuntimeException(e); + } + } + + @Override + protected DataBuffer create(long length) { + return new CudaUtf8Buffer(length); + } + + @Override + public DataBuffer create(double[] data) { + throw new UnsupportedOperationException(); + } + + @Override + public DataBuffer create(float[] data) { + throw new UnsupportedOperationException(); + } + + @Override + public DataBuffer create(int[] data) { + throw new UnsupportedOperationException(); + } + + private static long stringBufferRequiredLength(@NonNull Collection strings) { + // header size first + long size = (strings.size() + 1) * 8; + + for (val s:strings) + size += s.length(); + + return size; + } + + public void put(long index, Pointer pointer) { + throw new UnsupportedOperationException(); + //references.add(pointer); + //((LongIndexer) indexer).put(index, pointer.address()); + } + + /** + * Initialize the opType of this buffer + */ + @Override + protected void initTypeAndSize() { + elementSize = 1; + type = DataType.UTF8; + } + + +} diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/factory/CudaDataBufferFactory.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/factory/CudaDataBufferFactory.java index 72e089e45..5083a2bf9 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/factory/CudaDataBufferFactory.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/factory/CudaDataBufferFactory.java @@ -24,15 +24,11 @@ import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacpp.indexer.*; import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.api.buffer.LongBuffer; -import org.nd4j.linalg.api.buffer.Utf8Buffer; import org.nd4j.linalg.api.buffer.factory.DataBufferFactory; import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.nd4j.linalg.jcublas.buffer.*; import org.nd4j.linalg.util.ArrayUtil; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; import java.nio.ByteBuffer; @@ -64,6 +60,42 @@ public class CudaDataBufferFactory implements DataBufferFactory { return allocationMode; } + @Override + public DataBuffer create(ByteBuffer underlyingBuffer, DataType dataType, long length, long offset) { + switch (dataType) { + case DOUBLE: + return new CudaDoubleDataBuffer(underlyingBuffer, dataType, length, offset); + case FLOAT: + return new CudaFloatDataBuffer(underlyingBuffer, dataType, length, offset); + case HALF: + return new CudaHalfDataBuffer(underlyingBuffer, dataType, length, offset); + case BFLOAT16: + return new CudaBfloat16DataBuffer(underlyingBuffer, dataType, length, offset); + case LONG: + return new CudaLongDataBuffer(underlyingBuffer, dataType, length, offset); + case INT: + return new CudaIntDataBuffer(underlyingBuffer, dataType, length, offset); + case SHORT: + return new CudaShortDataBuffer(underlyingBuffer, dataType, length, offset); + case UBYTE: + return new CudaUByteDataBuffer(underlyingBuffer, dataType, length, offset); + case UINT16: + return new CudaUInt16DataBuffer(underlyingBuffer, dataType, length, offset); + case UINT32: + return new CudaUInt32DataBuffer(underlyingBuffer, dataType, length, offset); + case UINT64: + return new CudaUInt64DataBuffer(underlyingBuffer, dataType, length, offset); + case BYTE: + return new CudaByteDataBuffer(underlyingBuffer, dataType, length, offset); + case BOOL: + return new CudaBoolDataBuffer(underlyingBuffer, dataType, length, offset); + case UTF8: + return new CudaUtf8Buffer(underlyingBuffer, dataType, length, offset); + default: + throw new IllegalStateException("Unknown datatype used: [" + dataType + "]"); + } + } + @Override public DataBuffer create(DataBuffer underlyingBuffer, long offset, long length) { switch (underlyingBuffer.dataType()) { @@ -94,7 +126,7 @@ public class CudaDataBufferFactory implements DataBufferFactory { case BOOL: return new CudaBoolDataBuffer(underlyingBuffer, length, offset); case UTF8: - return new Utf8Buffer(underlyingBuffer, length, offset); + return new CudaUtf8Buffer(underlyingBuffer, length, offset); default: throw new ND4JIllegalStateException("Unknown data buffer type: " + underlyingBuffer.dataType().toString()); } @@ -169,27 +201,6 @@ public class CudaDataBufferFactory implements DataBufferFactory { return new CudaIntDataBuffer(data, copy, workspace); } - @Override - public DataBuffer createInt(long offset, ByteBuffer buffer, int length) { - return new CudaIntDataBuffer(buffer, length, offset); - } - - @Override - public DataBuffer createFloat(long offset, ByteBuffer buffer, int length) { - return new CudaFloatDataBuffer(buffer, length, offset); - } - - @Override - public DataBuffer createDouble(long offset, ByteBuffer buffer, int length) { - return new CudaDoubleDataBuffer(buffer, length, offset); - } - - - @Override - public DataBuffer createLong(ByteBuffer buffer, int length) { - return new CudaLongDataBuffer(buffer, length); - } - @Override public DataBuffer createDouble(long offset, int length) { return new CudaDoubleDataBuffer(length, 8, offset); @@ -315,21 +326,6 @@ public class CudaDataBufferFactory implements DataBufferFactory { return new CudaIntDataBuffer(data, copy, offset); } - @Override - public DataBuffer createInt(ByteBuffer buffer, int length) { - return new CudaIntDataBuffer(buffer, length); - } - - @Override - public DataBuffer createFloat(ByteBuffer buffer, int length) { - return new CudaFloatDataBuffer(buffer, length); - } - - @Override - public DataBuffer createDouble(ByteBuffer buffer, int length) { - return new CudaDoubleDataBuffer(buffer, length); - } - @Override public DataBuffer createDouble(long length) { return new CudaDoubleDataBuffer(length); @@ -384,6 +380,8 @@ public class CudaDataBufferFactory implements DataBufferFactory { return new CudaHalfDataBuffer(length, initialize); case BOOL: return new CudaBoolDataBuffer(length, initialize); + case UTF8: + return new CudaUtf8Buffer(length, true); default: throw new UnsupportedOperationException("Unknown data type: [" + dataType + "]"); } @@ -581,16 +579,6 @@ public class CudaDataBufferFactory implements DataBufferFactory { return new CudaDoubleDataBuffer(data); } - @Override - public DataBuffer createDouble(byte[] data, int length) { - return new CudaDoubleDataBuffer(data, length); - } - - @Override - public DataBuffer createFloat(byte[] data, int length) { - return new CudaFloatDataBuffer(data, length); - } - @Override public DataBuffer createFloat(double[] data) { return new CudaFloatDataBuffer(ArrayUtil.toFloats(data)); @@ -969,18 +957,6 @@ public class CudaDataBufferFactory implements DataBufferFactory { return new CudaHalfDataBuffer(data); } - /** - * Creates a half-precision data buffer - * - * @param offset - * @param data the data to create the buffer from - * @param length - * @return the new buffer - */ - @Override - public DataBuffer createHalf(long offset, byte[] data, int length) { - return new CudaHalfDataBuffer(ArrayUtil.toFloatArray(data), true, offset); - } /** * Creates a half-precision data buffer @@ -994,30 +970,6 @@ public class CudaDataBufferFactory implements DataBufferFactory { return new CudaHalfDataBuffer(length); } - /** - * Creates a half-precision data buffer - * - * @param buffer - * @param length - * @return the new buffer - */ - @Override - public DataBuffer createHalf(ByteBuffer buffer, int length) { - return new CudaHalfDataBuffer(buffer, length); - } - - /** - * Creates a half-precision data buffer - * - * @param data - * @param length - * @return - */ - @Override - public DataBuffer createHalf(byte[] data, int length) { - return new CudaHalfDataBuffer(data, length); - } - @Override public DataBuffer createDouble(long length, boolean initialize, MemoryWorkspace workspace) { return new CudaDoubleDataBuffer(length, initialize, workspace); @@ -1124,4 +1076,7 @@ public class CudaDataBufferFactory implements DataBufferFactory { return new CudaLongDataBuffer(length, initialize, workspace); } + public DataBuffer createUtf8Buffer(byte[] data, long product) { + return new CudaUtf8Buffer(data, product); + } } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/compression/CudaThreshold.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/compression/CudaThreshold.java index 19e8f8df6..f9cbb1794 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/compression/CudaThreshold.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/compression/CudaThreshold.java @@ -24,28 +24,18 @@ import org.apache.commons.math3.util.FastMath; import org.bytedeco.javacpp.*; import org.nd4j.compression.impl.AbstractCompressor; import org.nd4j.jita.allocator.impl.AtomicAllocator; -import org.nd4j.jita.conf.CudaEnvironment; import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataTypeEx; -import org.nd4j.linalg.api.buffer.IntBuffer; -import org.nd4j.linalg.api.concurrency.AffinityManager; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.compression.CompressedDataBuffer; -import org.nd4j.linalg.compression.CompressionDescriptor; import org.nd4j.linalg.compression.CompressionType; import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.indexing.conditions.AbsValueGreaterThan; -import org.nd4j.linalg.indexing.conditions.Conditions; -import org.nd4j.linalg.jcublas.buffer.CudaDoubleDataBuffer; import org.nd4j.linalg.jcublas.context.CudaContext; -import org.nd4j.linalg.ops.transforms.Transforms; -import org.nd4j.nativeblas.NativeOps; import org.nd4j.nativeblas.NativeOpsHolder; import java.util.ArrayList; -import java.util.Arrays; import java.util.List; /** diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java index 16568fbf4..f1bbb6d04 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java @@ -24,10 +24,8 @@ import lombok.val; import lombok.var; import org.bytedeco.javacpp.*; import org.bytedeco.javacpp.indexer.LongIndexer; -import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.serde.FlatBuffersMapper; import org.nd4j.base.Preconditions; -import org.nd4j.jita.allocator.impl.AllocationPoint; import org.nd4j.jita.allocator.impl.AtomicAllocator; import org.nd4j.jita.allocator.pointers.CudaPointer; import org.nd4j.jita.allocator.tad.DeviceTADManager; @@ -36,7 +34,6 @@ import org.nd4j.jita.conf.CudaEnvironment; import org.nd4j.linalg.api.buffer.BaseDataBuffer; import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.api.buffer.Utf8Buffer; import org.nd4j.linalg.api.environment.Nd4jEnvironment; import org.nd4j.linalg.api.memory.pointers.PagedPointer; import org.nd4j.linalg.api.ndarray.INDArray; @@ -50,6 +47,7 @@ import org.nd4j.linalg.api.ops.impl.scatter.ScatterUpdate; import org.nd4j.linalg.api.ops.impl.summarystats.Variance; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.CopyOp; import org.nd4j.linalg.api.ops.performance.PerformanceTracker; +import org.nd4j.linalg.api.ops.util.PrintVariable; import org.nd4j.linalg.api.rng.Random; import org.nd4j.linalg.api.shape.LongShapeDescriptor; import org.nd4j.linalg.api.shape.Shape; @@ -58,13 +56,13 @@ import org.nd4j.linalg.api.shape.options.ArrayOptionsHelper; import org.nd4j.linalg.api.shape.options.ArrayType; import org.nd4j.linalg.cache.TADManager; import org.nd4j.linalg.compression.ThresholdCompression; -import org.nd4j.linalg.exception.ND4JIllegalArgumentException; import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.nd4j.linalg.exception.ND4JOpProfilerException; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.jcublas.buffer.AddressRetriever; import org.nd4j.linalg.jcublas.buffer.BaseCudaDataBuffer; import org.nd4j.linalg.jcublas.buffer.CudaLongDataBuffer; +import org.nd4j.linalg.jcublas.buffer.CudaUtf8Buffer; import org.nd4j.linalg.jcublas.context.CudaContext; import org.nd4j.linalg.primitives.AtomicBoolean; import org.nd4j.linalg.primitives.Pair; @@ -131,12 +129,10 @@ public class CudaExecutioner extends DefaultOpExecutioner { val dimension = op.dimensions().toIntVector(); -// validateDataType(Nd4j.dataType(), op); - if (extraz.get() == null) extraz.set(new PointerPointer(32)); - CudaContext context = AtomicAllocator.getInstance().getFlowController().prepareAction(op.z(), op.x(), op.y()); + val context = AtomicAllocator.getInstance().getDeviceContext(); if (CudaEnvironment.getInstance().getConfiguration().isDebug()) lastOp.set(op.opName()); @@ -146,9 +142,10 @@ public class CudaExecutioner extends DefaultOpExecutioner { Pointer hostZShapeInfo = op.z() == null ? null : AddressRetriever.retrieveHostPointer(op.z().shapeInfoDataBuffer()); - Pointer x = AtomicAllocator.getInstance().getPointer(op.x(), context); - Pointer y = AtomicAllocator.getInstance().getPointer(op.y(), context); - Pointer z = AtomicAllocator.getInstance().getPointer(op.z(), context); + val x = op.x() == null ? null : ((BaseCudaDataBuffer) op.x().data()).getOpaqueDataBuffer(); + val y = op.y() == null ? null : ((BaseCudaDataBuffer) op.y().data()).getOpaqueDataBuffer(); + val z = op.z() == null ? null : ((BaseCudaDataBuffer) op.z().data()).getOpaqueDataBuffer(); + Pointer xShapeInfo = AtomicAllocator.getInstance().getPointer(op.x().shapeInfoDataBuffer(), context); Pair tadBuffers = tadManager.getTADOnlyShapeInfo(op.x(), dimension); @@ -185,23 +182,18 @@ public class CudaExecutioner extends DefaultOpExecutioner { switch (op.getOpType()) { case BROADCAST: nativeOps.execBroadcast(xShapeInfoHostPointer, op.opNum(), - null, (LongPointer) AtomicAllocator.getInstance().getHostPointer(op.x().shapeInfoDataBuffer()), x, (LongPointer) xShapeInfo, - null, (LongPointer) AtomicAllocator.getInstance().getHostPointer(op.y().shapeInfoDataBuffer()), y, (LongPointer) AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(),context), - null, (LongPointer) AtomicAllocator.getInstance().getHostPointer(op.z().shapeInfoDataBuffer()), z, (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), - null, - (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), - AtomicAllocator.getInstance().getPointer(op.dimensions(), context), - null); + x, (LongPointer) AtomicAllocator.getInstance().getHostPointer(op.x().shapeInfoDataBuffer()), (LongPointer) xShapeInfo, + y, (LongPointer) AtomicAllocator.getInstance().getHostPointer(op.y().shapeInfoDataBuffer()), (LongPointer) AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(),context), + z, (LongPointer) AtomicAllocator.getInstance().getHostPointer(op.z().shapeInfoDataBuffer()), (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), + ((BaseCudaDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) AtomicAllocator.getInstance().getHostPointer(op.dimensions().shapeInfoDataBuffer()), (LongPointer) AtomicAllocator.getInstance().getPointer(op.dimensions().shapeInfoDataBuffer(), context)); break; case BROADCAST_BOOL: nativeOps.execBroadcastBool(xShapeInfoHostPointer, op.opNum(), - null, (LongPointer) AtomicAllocator.getInstance().getHostPointer(op.x().shapeInfoDataBuffer()), x, (LongPointer) xShapeInfo, - null, (LongPointer) AtomicAllocator.getInstance().getHostPointer(op.y().shapeInfoDataBuffer()), y, (LongPointer) AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(),context), - null, (LongPointer) AtomicAllocator.getInstance().getHostPointer(op.z().shapeInfoDataBuffer()), z, (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), - null, null, - (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), - AtomicAllocator.getInstance().getPointer(op.dimensions(), context), - null); + x, (LongPointer) AtomicAllocator.getInstance().getHostPointer(op.x().shapeInfoDataBuffer()), (LongPointer) xShapeInfo, + y, (LongPointer) AtomicAllocator.getInstance().getHostPointer(op.y().shapeInfoDataBuffer()), (LongPointer) AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(),context), + z, (LongPointer) AtomicAllocator.getInstance().getHostPointer(op.z().shapeInfoDataBuffer()), (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), + null, + ((BaseCudaDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) AtomicAllocator.getInstance().getHostPointer(op.dimensions().shapeInfoDataBuffer()), (LongPointer) AtomicAllocator.getInstance().getPointer(op.dimensions().shapeInfoDataBuffer(), context)); break; default: throw new UnsupportedOperationException("Unknown op type: " + op.getOpType()); @@ -210,9 +202,6 @@ public class CudaExecutioner extends DefaultOpExecutioner { if (nativeOps.lastErrorCode() != 0) throw new RuntimeException(nativeOps.lastErrorMessage()); - - AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y()); - profilingConfigurableHookOut(op, st); return op.z(); @@ -252,7 +241,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { throw new ND4JIllegalStateException("Op target dimension " + Arrays.toString(dimension) + " contains element that higher then rank of op.X: [" + op.x().rank() + "]"); - val context = AtomicAllocator.getInstance().getFlowController().prepareAction(op.z(), op.x(), op.y()); + val context = AtomicAllocator.getInstance().getDeviceContext(); if (CudaEnvironment.getInstance().getConfiguration().isDebug()) lastOp.set(op.opName()); @@ -269,7 +258,6 @@ public class CudaExecutioner extends DefaultOpExecutioner { DataBuffer offsets = tadBuffers.getSecond(); Pointer devTadOffsets = offsets == null ? null : AtomicAllocator.getInstance().getPointer(offsets, context); - Pointer x = AtomicAllocator.getInstance().getPointer(op.x(), context); Pointer xShapeInfo = AtomicAllocator.getInstance().getPointer(op.x().shapeInfoDataBuffer(), context); if (extraz.get() == null) @@ -333,150 +321,118 @@ public class CudaExecutioner extends DefaultOpExecutioner { Pointer extraArgs = op.extraArgs() != null ? AtomicAllocator.getInstance().getPointer(op.extraArgsDataBuff(argsType), context) : null; Pointer dimensionPointer = AtomicAllocator.getInstance().getPointer(AtomicAllocator.getInstance().getConstantBuffer(dimension), context); //AtomicAllocator.getInstance().getPointer(Nd4j.createBuffer(dimension), context); - if (op instanceof Variance) { - if (ret.isScalar()) { - nativeOps.execSummaryStatsScalar(xShapeInfoHostPointer, op.opNum(), - null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, - extraArgs, - null, (LongPointer) hostZShapeInfo, AtomicAllocator.getInstance().getPointer(op.z(), context), (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer()), - ((Variance) op).isBiasCorrected()); + val x = op.x() == null ? null : ((BaseCudaDataBuffer) op.x().data()).getOpaqueDataBuffer(); + val y = op.y() == null ? null : ((BaseCudaDataBuffer) op.y().data()).getOpaqueDataBuffer(); + val z = op.z() == null ? null : ((BaseCudaDataBuffer) op.z().data()).getOpaqueDataBuffer(); - AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y()); - } else { - nativeOps.execSummaryStatsTad(xShapeInfoHostPointer, op.opNum(), - null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, - extraArgs, - null, (LongPointer) hostZShapeInfo, AtomicAllocator.getInstance().getPointer(op.z(), context), (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), - null, - (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), - AtomicAllocator.getInstance().getPointer(op.dimensions(), context), - null, ((Variance) op).isBiasCorrected(), - (LongPointer) devTadShapeInfo, - (LongPointer) devTadOffsets); - - AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y()); - } - } else if (op.y() != null) { - if (op.isComplexAccumulation()) { - - val dT = new LongPointerWrapper(devTadOffsets); - val yT = new LongPointerWrapper(yDevTadOffsets); - - nativeOps.execReduce3All(xShapeInfoHostPointer, op.opNum(), - null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, - extraArgs, - null, (LongPointer) hostYShapeInfo, AtomicAllocator.getInstance().getPointer(op.y(), context), (LongPointer) AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(),context), - null, (LongPointer) hostZShapeInfo, AtomicAllocator.getInstance().getPointer(op.z(), context), (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), - null, - (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), - AtomicAllocator.getInstance().getPointer(op.dimensions(), context), - null, - (LongPointer) devTadShapeInfo, - dT, - (LongPointer) yDevTadShapeInfo, - yT); - - AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y()); - } else if (ret.isScalar()) { - nativeOps.execReduce3Scalar(xShapeInfoHostPointer, op.opNum(), - null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, - extraArgs, - null, (LongPointer) hostYShapeInfo, AtomicAllocator.getInstance().getPointer(op.y(), context), (LongPointer) AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(), context), - null, (LongPointer) hostZShapeInfo, AtomicAllocator.getInstance().getPointer(op.z(), context), (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context)); - AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y()); - } else { - nativeOps.execReduce3Tad(xShapeInfoHostPointer, op.opNum(), - null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, - extraArgs, - null, (LongPointer) hostYShapeInfo, AtomicAllocator.getInstance().getPointer(op.y(), context), (LongPointer) AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(), context), - null, (LongPointer) hostZShapeInfo, AtomicAllocator.getInstance().getPointer(op.z(), context), (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), - (IntPointer) op.dimensions().data().addressPointer(), - (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), - AtomicAllocator.getInstance().getPointer(op.dimensions(), context), - null, - (LongPointer) devTadShapeInfo, (LongPointer) devTadOffsets, (LongPointer) yDevTadShapeInfo, (LongPointer) yDevTadOffsets); - - AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y()); - } + if (op instanceof Variance) { + if (ret.isScalar()) { + nativeOps.execSummaryStatsScalar(xShapeInfoHostPointer, op.opNum(), + x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, + extraArgs, + z, (LongPointer) hostZShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer()), + ((Variance) op).isBiasCorrected()); } else { + nativeOps.execSummaryStatsTad(xShapeInfoHostPointer, op.opNum(), + x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, + extraArgs, + z, (LongPointer) hostZShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), + ((BaseCudaDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null, + ((Variance) op).isBiasCorrected(), + (LongPointer) devTadShapeInfo, (LongPointer) devTadOffsets); + } + } else if (op.y() != null) { + if (op.isComplexAccumulation()) { + + val dT = new LongPointerWrapper(devTadOffsets); + val yT = new LongPointerWrapper(yDevTadOffsets); + + nativeOps.execReduce3All(xShapeInfoHostPointer, op.opNum(), + x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, + extraArgs, + y, (LongPointer) hostYShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(),context), + z, (LongPointer) hostZShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), + ((BaseCudaDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null, + (LongPointer) devTadShapeInfo, dT, + (LongPointer) yDevTadShapeInfo, yT); + } else if (ret.isScalar()) { + nativeOps.execReduce3Scalar(xShapeInfoHostPointer, op.opNum(), + x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, + extraArgs, + y, (LongPointer) hostYShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(), context), + z, (LongPointer) hostZShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context)); + } else { + nativeOps.execReduce3Tad(xShapeInfoHostPointer, op.opNum(), + x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, + extraArgs, + y, (LongPointer) hostYShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(), context), + z, (LongPointer) hostZShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), + ((BaseCudaDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null, + (LongPointer) devTadShapeInfo, (LongPointer) devTadOffsets, (LongPointer) yDevTadShapeInfo, (LongPointer) yDevTadOffsets); + } + } else { if (ret.isScalar()) { switch (op.getOpType()) { case REDUCE_FLOAT: nativeOps.execReduceFloat(xShapeInfoHostPointer, op.opNum(), - null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, + x, (LongPointer) hostXShapeInfo,(LongPointer) xShapeInfo, extraArgs, - null, (LongPointer) hostZShapeInfo, AtomicAllocator.getInstance().getPointer(op.z(), context), (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer())); + z, (LongPointer) hostZShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer())); break; case REDUCE_BOOL: nativeOps.execReduceBool(xShapeInfoHostPointer, op.opNum(), - null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, + x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, extraArgs, - null, (LongPointer) hostZShapeInfo, AtomicAllocator.getInstance().getPointer(op.z(), context), (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer())); + z, (LongPointer) hostZShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer())); break; case REDUCE_LONG: nativeOps.execReduceLong(xShapeInfoHostPointer, op.opNum(), - null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, + x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, extraArgs, - null, (LongPointer) hostZShapeInfo, AtomicAllocator.getInstance().getPointer(op.z(), context), (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer())); + z, (LongPointer) hostZShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer())); break; case REDUCE_SAME: nativeOps.execReduceSame(xShapeInfoHostPointer, op.opNum(), - null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, + x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, extraArgs, - null, (LongPointer) hostZShapeInfo, AtomicAllocator.getInstance().getPointer(op.z(), context), (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer())); + z, (LongPointer) hostZShapeInfo,(LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer())); break; default: throw new UnsupportedOperationException(); } - - AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y()); } else { switch (op.getOpType()) { case REDUCE_FLOAT: nativeOps.execReduceFloat2(xShapeInfoHostPointer, op.opNum(), - null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, + x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, extraArgs, - null, (LongPointer) hostZShapeInfo, AtomicAllocator.getInstance().getPointer(op.z(), context), (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), - (IntPointer) op.dimensions().data().addressPointer(), - (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), - AtomicAllocator.getInstance().getPointer(op.dimensions(), context), - null); + z, (LongPointer) hostZShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), + ((BaseCudaDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null); break; case REDUCE_BOOL: nativeOps.execReduceBool2(xShapeInfoHostPointer, op.opNum(), - null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, + x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, extraArgs, - null, (LongPointer) hostZShapeInfo, AtomicAllocator.getInstance().getPointer(op.z(), context), (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), - (IntPointer) op.dimensions().data().addressPointer(), - (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), - AtomicAllocator.getInstance().getPointer(op.dimensions(), context), - null); + z, (LongPointer) hostZShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), + ((BaseCudaDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null); break; case REDUCE_SAME: nativeOps.execReduceSame2(xShapeInfoHostPointer, op.opNum(), - null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, + x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, extraArgs, - null, (LongPointer) hostZShapeInfo, AtomicAllocator.getInstance().getPointer(op.z(), context), (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), - (IntPointer) op.dimensions().data().addressPointer(), - (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), - AtomicAllocator.getInstance().getPointer(op.dimensions(), context), - null); + z, (LongPointer) hostZShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), + ((BaseCudaDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null); break; case REDUCE_LONG: nativeOps.execReduceLong2(xShapeInfoHostPointer, op.opNum(), - null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, + x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, extraArgs, - null, (LongPointer) hostZShapeInfo, AtomicAllocator.getInstance().getPointer(op.z(), context), (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), - (IntPointer) op.dimensions().data().addressPointer(), - (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), - AtomicAllocator.getInstance().getPointer(op.dimensions(), context), - null); + z, (LongPointer) hostZShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), + ((BaseCudaDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null); break; default: throw new UnsupportedOperationException(); } - - AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y()); } } @@ -610,7 +566,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { if (CudaEnvironment.getInstance().getConfiguration().isDebug()) lastOp.set(op.opName()); - val context = AtomicAllocator.getInstance().getFlowController().prepareAction(op.z(), op.x(), op.y()); + val context = AtomicAllocator.getInstance().getDeviceContext(); val hostXShapeInfo = op.x() == null ? null : AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer()); @@ -619,10 +575,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { val hostZShapeInfo = op.z() == null ? null : AddressRetriever.retrieveHostPointer(op.z().shapeInfoDataBuffer()); - val x = AtomicAllocator.getInstance().getPointer(op.x(), context); val xShapeInfo = AtomicAllocator.getInstance().getPointer(op.x().shapeInfoDataBuffer(), context); - - val z = AtomicAllocator.getInstance().getPointer(op.z(), context); val zShapeInfo = AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context); Pair tadBuffers = tadManager.getTADOnlyShapeInfo(op.x(), dimension); @@ -644,22 +597,19 @@ public class CudaExecutioner extends DefaultOpExecutioner { Pointer dimensionPointer = AtomicAllocator.getInstance() .getPointer(AtomicAllocator.getInstance().getConstantBuffer(dimension), context); + val x = op.x() == null ? null : ((BaseCudaDataBuffer) op.x().data()).getOpaqueDataBuffer(); + val y = op.y() == null ? null : ((BaseCudaDataBuffer) op.y().data()).getOpaqueDataBuffer(); + val z = op.z() == null ? null : ((BaseCudaDataBuffer) op.z().data()).getOpaqueDataBuffer(); - nativeOps.execIndexReduce(xShapeInfoHostPointer, - op.opNum(), - null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, - extraArgs, - null, (LongPointer) hostZShapeInfo, z, (LongPointer) zShapeInfo, - (IntPointer) op.dimensions().data().addressPointer(), - (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), - AtomicAllocator.getInstance().getPointer(op.dimensions(), context), - null); + nativeOps.execIndexReduce(xShapeInfoHostPointer, op.opNum(), + x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, + extraArgs, + z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, + ((BaseCudaDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null); if (nativeOps.lastErrorCode() != 0) throw new RuntimeException(nativeOps.lastErrorMessage()); - AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y()); - profilingConfigurableHookOut(op, st); return op.z(); @@ -681,7 +631,9 @@ public class CudaExecutioner extends DefaultOpExecutioner { super.exec(op); if (op.z() != null) - AtomicAllocator.getInstance().tickHostWrite(op.z()); + throw new UnsupportedOperationException("Pew-pew"); + //AtomicAllocator.getInstance().tickHostWrite(op.z()); + return null; } @@ -731,12 +683,11 @@ public class CudaExecutioner extends DefaultOpExecutioner { if (extraz.get() == null) extraz.set(new PointerPointer(32)); - CudaContext context = AtomicAllocator.getInstance().getFlowController().prepareAction(op.z(), op.x(), op.y()); + val context = AtomicAllocator.getInstance().getDeviceContext(); if (CudaEnvironment.getInstance().getConfiguration().isDebug()) lastOp.set(op.opName()); - Pointer x = AtomicAllocator.getInstance().getPointer(op.x(), context); Pointer xShapeInfo = AtomicAllocator.getInstance().getPointer(op.x().shapeInfoDataBuffer(), context); @@ -780,35 +731,32 @@ public class CudaExecutioner extends DefaultOpExecutioner { devTadShapeInfoZ, // 12 devTadOffsetsZ); // 13 - Pointer y = AtomicAllocator.getInstance().getPointer(op.y(), context); Pointer yShapeInfo = AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(), context); - Pointer z = AtomicAllocator.getInstance().getPointer(op.z(), context); Pointer zShapeInfo = AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context); Pointer dimensionPointer = AtomicAllocator.getInstance().getPointer(AtomicAllocator.getInstance().getConstantBuffer(op.getDimension()), context); + val x = op.x() == null ? null : ((BaseCudaDataBuffer) op.x().data()).getOpaqueDataBuffer(); + val y = op.y() == null ? null : ((BaseCudaDataBuffer) op.y().data()).getOpaqueDataBuffer(); + val z = op.z() == null ? null : ((BaseCudaDataBuffer) op.z().data()).getOpaqueDataBuffer(); + //log.info("X: {}; Y: {}; Z: {}; dTS: {}, dTO: {}; dTSz: {}; dTOz: {};", x.address(), y.address(), z.address(), devTadShapeInfo.address(), devTadOffsets.address(), devTadShapeInfoZ.address(), devTadOffsetsZ.address()); switch (op.getOpType()) { case BROADCAST: nativeOps.execBroadcast(xShapeInfoHostPointer, op.opNum(), - null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, - null, (LongPointer) hostYShapeInfo, y, (LongPointer) yShapeInfo, - null, (LongPointer) hostZShapeInfo, z, (LongPointer) zShapeInfo, - null, - (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), - AtomicAllocator.getInstance().getPointer(op.dimensions(), context), - null); + x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, + y, (LongPointer) hostYShapeInfo, (LongPointer) yShapeInfo, + z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, + ((BaseCudaDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null); break; case BROADCAST_BOOL: nativeOps.execBroadcastBool(xShapeInfoHostPointer, op.opNum(), - null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, - null, (LongPointer) hostYShapeInfo, y, (LongPointer) yShapeInfo, - null, (LongPointer) hostZShapeInfo, z, (LongPointer) zShapeInfo, - null, null, - (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), - AtomicAllocator.getInstance().getPointer(op.dimensions(), context), - null); + x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, + y, (LongPointer) hostYShapeInfo, (LongPointer) yShapeInfo, + z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, + null, + ((BaseCudaDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null); break; default: throw new UnsupportedOperationException("Unknown opType: " + op.getOpType()); @@ -817,8 +765,6 @@ public class CudaExecutioner extends DefaultOpExecutioner { if (nativeOps.lastErrorCode() != 0) throw new RuntimeException(nativeOps.lastErrorMessage()); - AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y()); - profilingConfigurableHookOut(op, st); return null; @@ -851,12 +797,10 @@ public class CudaExecutioner extends DefaultOpExecutioner { if (dimension[i] >= op.x().rank() && dimension[i] != Integer.MAX_VALUE) throw new ND4JIllegalStateException("Op target dimension " + Arrays.toString(dimension) + " contains element that higher then rank of op.X: [" + op.x().rank() + "]"); - CudaContext context = AtomicAllocator.getInstance().getFlowController().prepareAction(op.z().isScalar() ? null : op.z(), op.x(), op.y()); + val context = AtomicAllocator.getInstance().getDeviceContext(); - Pointer x = AtomicAllocator.getInstance().getPointer(op.x(), context); Pointer xShapeInfo = AtomicAllocator.getInstance().getPointer(op.x().shapeInfoDataBuffer(), context); - Pointer extraArgs = op.extraArgs() != null - ? AtomicAllocator.getInstance().getPointer(op.extraArgsDataBuff(op.x().dataType()), context) : null; + Pointer extraArgs = op.extraArgs() != null ? AtomicAllocator.getInstance().getPointer(op.extraArgsDataBuff(op.x().dataType()), context) : null; val hostXShapeInfo = op.x() == null ? null : AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer()); val hostYShapeInfo = op.y() == null ? null : AddressRetriever.retrieveHostPointer(op.y().shapeInfoDataBuffer()); @@ -873,9 +817,12 @@ public class CudaExecutioner extends DefaultOpExecutioner { DataBuffer offsets = tadBuffers.getSecond(); Pointer devTadOffsets = offsets == null ? null : AtomicAllocator.getInstance().getPointer(offsets, context); - val z = AtomicAllocator.getInstance().getPointer(op.z(), context); val zShapeInfo = AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context); + val x = op.x() == null ? null : ((BaseCudaDataBuffer) op.x().data()).getOpaqueDataBuffer(); + val y = op.y() == null ? null : ((BaseCudaDataBuffer) op.y().data()).getOpaqueDataBuffer(); + val z = op.z() == null ? null : ((BaseCudaDataBuffer) op.z().data()).getOpaqueDataBuffer(); + PointerPointer xShapeInfoHostPointer = extraz.get().put( AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer()), context.getOldStream(), AtomicAllocator.getInstance().getDeviceIdPointer(), context.getBufferAllocation(), @@ -884,28 +831,22 @@ public class CudaExecutioner extends DefaultOpExecutioner { if (op.z().isScalar() || dimension == null || dimension[0] == Integer.MAX_VALUE) { nativeOps.execIndexReduceScalar(xShapeInfoHostPointer, op.opNum(), - null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, + x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, extraArgs, - null, (LongPointer) hostZShapeInfo, z, (LongPointer) zShapeInfo); - - AtomicAllocator.getInstance().registerAction(context, null, op.x(), op.y()); + z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo); } else { - Arrays.sort(dimension); + if (dimension != null && dimension.length > 1) + Arrays.sort(dimension); //long dimensionPointer = AtomicAllocator.getInstance().getPointer(Nd4j.createBuffer(dimension), context); Pointer dimensionPointer = AtomicAllocator.getInstance() .getHostPointer(AtomicAllocator.getInstance().getConstantBuffer(dimension)); nativeOps.execIndexReduce(xShapeInfoHostPointer, op.opNum(), - null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, + x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, extraArgs, - null, (LongPointer) hostZShapeInfo, z, (LongPointer) zShapeInfo, - dimensionPointer, - (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), - AtomicAllocator.getInstance().getPointer(op.dimensions(), context), - null); - - AtomicAllocator.getInstance().registerAction(context, null, op.x(), op.y()); + z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, + ((BaseCudaDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null); } if (nativeOps.lastErrorCode() != 0) @@ -919,7 +860,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { protected CudaContext invoke(ReduceOp op, int[] dimension) { - CudaContext context = AtomicAllocator.getInstance().getFlowController().prepareAction(op.z(), op.x(), op.y()); + val context = AtomicAllocator.getInstance().getDeviceContext(); if(op instanceof BaseReduceOp && ((BaseReduceOp)op).isEmptyReduce()){ //Edge case for TF import compatibility: [x,y].reduce(empty) = [x,y] @@ -962,7 +903,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { if (dimension == null ) dimension = new int[] {Integer.MAX_VALUE}; - if (dimension.length > 1) + if (dimension != null && dimension.length > 1) Arrays.sort(dimension); for (int i = 0; i < dimension.length; i++) @@ -981,7 +922,6 @@ public class CudaExecutioner extends DefaultOpExecutioner { val offsets = op.x().isEmpty() ? null : tadBuffers.getSecond(); val devTadOffsets = offsets == null ? null : AtomicAllocator.getInstance().getPointer(offsets, context); - Pointer x = AtomicAllocator.getInstance().getPointer(op.x(), context); Pointer xShapeInfo = AtomicAllocator.getInstance().getPointer(op.x().shapeInfoDataBuffer(), context); long[] retShape = Shape.reductionShape(op.x(), dimension, true, op.isKeepDims()); @@ -1044,139 +984,114 @@ public class CudaExecutioner extends DefaultOpExecutioner { xShapeInfoHostPointer.put(13, yDevTadOffsets); } - val z = AtomicAllocator.getInstance().getPointer(op.z(), context); val zShapeInfo = AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context); - //log.info("Op.X address: {};", x.address()); + val x = op.x() == null ? null : ((BaseCudaDataBuffer) op.x().data()).getOpaqueDataBuffer(); + val y = op.y() == null ? null : ((BaseCudaDataBuffer) op.y().data()).getOpaqueDataBuffer(); + val z = op.z() == null ? null : ((BaseCudaDataBuffer) op.z().data()).getOpaqueDataBuffer(); op.validateDataTypes(); if (op.z().isScalar()) { if (op instanceof Variance) { nativeOps.execSummaryStatsScalar(xShapeInfoHostPointer, op.opNum(), - null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, + x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, extraArgs, - null, (LongPointer) hostZShapeInfo, z, (LongPointer) zShapeInfo, + z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, ((Variance) op).isBiasCorrected()); - AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y()); } else if (op.y() != null) { - Pointer y = AtomicAllocator.getInstance().getPointer(op.y(), context); Pointer yShapeInfo = AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(), context); nativeOps.execReduce3Scalar(xShapeInfoHostPointer, op.opNum(), - null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, + x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, extraArgs, - null, (LongPointer) hostYShapeInfo, y, (LongPointer) yShapeInfo, - null, (LongPointer) hostZShapeInfo, z, (LongPointer) zShapeInfo); - - AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y()); + y, (LongPointer) hostYShapeInfo, (LongPointer) yShapeInfo, + z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo); } else { switch (op.getOpType()) { case REDUCE_FLOAT: nativeOps.execReduceFloat(xShapeInfoHostPointer, op.opNum(), - null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, + x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, extraArgs, - null, (LongPointer) hostZShapeInfo, z, (LongPointer) zShapeInfo); + z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo); break; case REDUCE_BOOL: nativeOps.execReduceBool(xShapeInfoHostPointer, op.opNum(), - null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, + x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, extraArgs, - null, (LongPointer) hostZShapeInfo, z, (LongPointer) zShapeInfo); + z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo); break; case REDUCE_SAME: nativeOps.execReduceSame(xShapeInfoHostPointer, op.opNum(), - null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, + x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, extraArgs, - null, (LongPointer) hostZShapeInfo, z, (LongPointer) zShapeInfo); + z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo); break; case REDUCE_LONG: nativeOps.execReduceLong(xShapeInfoHostPointer, op.opNum(), - null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, + x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, extraArgs, - null, (LongPointer) hostZShapeInfo, z, (LongPointer) zShapeInfo); + z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo); break; default: throw new UnsupportedOperationException(); } - - AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y()); } } else { val dimensionPointer = AtomicAllocator.getInstance().getPointer(AtomicAllocator.getInstance().getConstantBuffer(dimension), context); //AtomicAllocator.getInstance().getPointer(Nd4j.createBuffer(dimension), context); if (op.y() != null) { - val y = AtomicAllocator.getInstance().getPointer(op.y(), context); val yShapeInfo = AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(), context); nativeOps.execReduce3Tad(xShapeInfoHostPointer, op.opNum(), - null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, + x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, extraArgs, - null, (LongPointer) hostYShapeInfo, y, (LongPointer) yShapeInfo, - null, (LongPointer) hostZShapeInfo, z, (LongPointer) zShapeInfo, - (IntPointer) op.dimensions().data().addressPointer(), - (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), - dimensionPointer, null, (LongPointer) devTadShapeInfo, (LongPointer) devTadOffsets, (LongPointer) yDevTadShapeInfo, (LongPointer) yDevTadOffsets); + y, (LongPointer) hostYShapeInfo, (LongPointer) yShapeInfo, + z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, + ((BaseCudaDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null, + (LongPointer) devTadShapeInfo, (LongPointer) devTadOffsets, (LongPointer) yDevTadShapeInfo, (LongPointer) yDevTadOffsets); } else { if (op instanceof Variance) { nativeOps.execSummaryStatsTad(xShapeInfoHostPointer, op.opNum(), - null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, + x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, extraArgs, - null, (LongPointer) hostZShapeInfo, z, (LongPointer) zShapeInfo, - (IntPointer) op.dimensions().data().addressPointer(), - (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), - AtomicAllocator.getInstance().getPointer(op.dimensions(), context), - null, + z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, + ((BaseCudaDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null, ((Variance) op).isBiasCorrected(), - (LongPointer) devTadShapeInfo, - (LongPointer) devTadOffsets); + (LongPointer) devTadShapeInfo, (LongPointer) devTadOffsets); } else { switch (op.getOpType()) { case REDUCE_FLOAT: nativeOps.execReduceFloat2(xShapeInfoHostPointer, op.opNum(), - null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, + x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, extraArgs, - null, (LongPointer) hostZShapeInfo, z, (LongPointer) zShapeInfo, - (IntPointer) op.dimensions().data().addressPointer(), - (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), - AtomicAllocator.getInstance().getPointer(op.dimensions(), context), - null); + z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, + ((BaseCudaDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null); break; case REDUCE_SAME: nativeOps.execReduceSame2(xShapeInfoHostPointer, op.opNum(), - null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, + x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, extraArgs, - null, (LongPointer) hostZShapeInfo, z, (LongPointer) zShapeInfo, - (IntPointer) op.dimensions().data().addressPointer(), - (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), - AtomicAllocator.getInstance().getPointer(op.dimensions(), context), - null); + z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, + ((BaseCudaDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null); break; case REDUCE_BOOL: nativeOps.execReduceBool2(xShapeInfoHostPointer, op.opNum(), - null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, + x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, extraArgs, - null, (LongPointer) hostZShapeInfo, z, (LongPointer) zShapeInfo, - (IntPointer) op.dimensions().data().addressPointer(), - (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), - AtomicAllocator.getInstance().getPointer(op.dimensions(), context), - null); + z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, + ((BaseCudaDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null); break; case REDUCE_LONG: nativeOps.execReduceLong2(xShapeInfoHostPointer, op.opNum(), - null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, + x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, extraArgs, - null, (LongPointer) hostZShapeInfo, z, (LongPointer) zShapeInfo, - (IntPointer) op.dimensions().data().addressPointer(), - (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), - AtomicAllocator.getInstance().getPointer(op.dimensions(), context), - null); + z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, + ((BaseCudaDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null); break; default: throw new UnsupportedOperationException(); } } } - - AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y()); } if (nativeOps.lastErrorCode() != 0) @@ -1193,9 +1108,10 @@ public class CudaExecutioner extends DefaultOpExecutioner { protected CudaContext intercept(ScalarOp op, int[] dimension) { long st = profilingConfigurableHookIn(op); - Arrays.sort(dimension); + if (dimension != null && dimension.length > 1) + Arrays.sort(dimension); - CudaContext context = AtomicAllocator.getInstance().getFlowController().prepareAction(op.z(), op.x(), op.y()); + val context = AtomicAllocator.getInstance().getDeviceContext(); if (CudaEnvironment.getInstance().getConfiguration().isDebug()) lastOp.set(op.opName()); @@ -1204,9 +1120,6 @@ public class CudaExecutioner extends DefaultOpExecutioner { val hostYShapeInfo = op.y() == null ? null : AddressRetriever.retrieveHostPointer(op.y().shapeInfoDataBuffer()); val hostZShapeInfo = op.z() == null ? null : AddressRetriever.retrieveHostPointer(op.z().shapeInfoDataBuffer()); - val x = AtomicAllocator.getInstance().getPointer(op.x(), context); - val y = AtomicAllocator.getInstance().getPointer(op.y(), context); - val z = AtomicAllocator.getInstance().getPointer(op.z(), context); val xShapeInfo = AtomicAllocator.getInstance().getPointer(op.x().shapeInfoDataBuffer(), context); val yShapeInfo = AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(), context); val zShapeInfo = AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context); @@ -1239,30 +1152,28 @@ public class CudaExecutioner extends DefaultOpExecutioner { val dimensionPointer = AtomicAllocator.getInstance().getPointer(AtomicAllocator.getInstance().getConstantBuffer(dimension), context); + val x = op.x() == null ? null : ((BaseCudaDataBuffer) op.x().data()).getOpaqueDataBuffer(); + val y = op.y() == null ? null : ((BaseCudaDataBuffer) op.y().data()).getOpaqueDataBuffer(); + val z = op.z() == null ? null : ((BaseCudaDataBuffer) op.z().data()).getOpaqueDataBuffer(); + switch (op.getOpType()) { case SCALAR: nativeOps.execScalarTad(extraPointers, op.opNum(), - null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, - null, (LongPointer) hostZShapeInfo, z, (LongPointer) zShapeInfo, - null, (LongPointer) hostYShapeInfo, y, (LongPointer) yShapeInfo, + x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, + z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, + y, (LongPointer) hostYShapeInfo, (LongPointer) yShapeInfo, extraArgs, - null, - (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), - AtomicAllocator.getInstance().getPointer(op.dimensions(), context), - null, + ((BaseCudaDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null, (LongPointer) devTadShapeInfo, (LongPointer) devTadOffsets, (LongPointer) devTadShapeInfoZ, (LongPointer) devTadOffsetsZ); break; case SCALAR_BOOL: nativeOps.execScalarBoolTad(extraPointers, op.opNum(), - null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, - null, (LongPointer) hostZShapeInfo, z, (LongPointer) zShapeInfo, - null, (LongPointer) hostYShapeInfo, y, (LongPointer) yShapeInfo, + x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, + z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, + y, (LongPointer) hostYShapeInfo, (LongPointer) yShapeInfo, extraArgs, - null, - (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), - AtomicAllocator.getInstance().getPointer(op.dimensions(), context), - null, + ((BaseCudaDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null, (LongPointer) devTadShapeInfo, (LongPointer) devTadOffsets, (LongPointer) devTadShapeInfoZ, (LongPointer) devTadOffsetsZ); break; @@ -1273,8 +1184,6 @@ public class CudaExecutioner extends DefaultOpExecutioner { if (nativeOps.lastErrorCode() != 0) throw new RuntimeException(nativeOps.lastErrorMessage()); - AtomicAllocator.getInstance().getFlowController().registerAction(context, op.z(), op.x(), op.y()); - profilingConfigurableHookOut(op, st); return null; @@ -1293,6 +1202,19 @@ public class CudaExecutioner extends DefaultOpExecutioner { // validateDataType(Nd4j.dataType(), op); + if(op.z() == null){ + switch (op.getOpType()) { + case SCALAR: + op.setZ(op.x().ulike()); + break; + case SCALAR_BOOL: + op.setZ(Nd4j.createUninitialized(DataType.BOOL, op.x().shape())); + break; + default: + throw new ND4JIllegalStateException("Unknown op type: [" + op.getOpType() +"]"); + } + } + if (op.x().length() != op.z().length()) throw new ND4JIllegalStateException("op.X length should be equal to op.Y length: [" + Arrays.toString(op.x().shapeInfoDataBuffer().asInt()) + "] != [" @@ -1309,17 +1231,15 @@ public class CudaExecutioner extends DefaultOpExecutioner { return null; } - CudaContext context = AtomicAllocator.getInstance().getFlowController().prepareAction(op.z(), op.x(), op.y()); + val context = AtomicAllocator.getInstance().getDeviceContext(); val hostXShapeInfo = op.x() == null ? null : AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer()); val hostYShapeInfo = op.scalar() == null ? null : AddressRetriever.retrieveHostPointer(op.scalar().shapeInfoDataBuffer()); val hostZShapeInfo = op.z() == null ? null : AddressRetriever.retrieveHostPointer(op.z().shapeInfoDataBuffer()); - Pointer x = AtomicAllocator.getInstance().getPointer(op.x(), context); Pointer xShapeInfo = AtomicAllocator.getInstance().getPointer(op.x().shapeInfoDataBuffer(), context); Pointer extraArgs = op.extraArgs() != null ? AtomicAllocator.getInstance().getPointer(op.extraArgsDataBuff(op.getOpType() == Op.Type.SCALAR_BOOL ? op.x().dataType() : op.z().dataType()), context) : null; - Pointer z = AtomicAllocator.getInstance().getPointer(op.z(), context); Pointer zShapeInfo = AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context); PointerPointer xShapeInfoHostPointer = extraz.get().put( @@ -1328,19 +1248,23 @@ public class CudaExecutioner extends DefaultOpExecutioner { context.getBufferReduction(), context.getBufferScalar(), context.getBufferSpecial(), hostYShapeInfo, hostZShapeInfo, null, null); + val x = op.x() == null ? null : ((BaseCudaDataBuffer) op.x().data()).getOpaqueDataBuffer(); + val y = op.scalar() == null ? null : ((BaseCudaDataBuffer) op.scalar().data()).getOpaqueDataBuffer(); + val z = op.z() == null ? null : ((BaseCudaDataBuffer) op.z().data()).getOpaqueDataBuffer(); + switch (op.getOpType()) { case SCALAR_BOOL: nativeOps.execScalarBool(xShapeInfoHostPointer, op.opNum(), - null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, - null, (LongPointer) hostZShapeInfo, z, (LongPointer) zShapeInfo, - null, (LongPointer) hostYShapeInfo, AtomicAllocator.getInstance().getPointer(op.scalar(), context), (LongPointer) AtomicAllocator.getInstance().getPointer(op.scalar().shapeInfoDataBuffer(), context), + x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, + z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, + y, (LongPointer) hostYShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(op.scalar().shapeInfoDataBuffer(), context), extraArgs); break; case SCALAR: nativeOps.execScalar(xShapeInfoHostPointer, op.opNum(), - null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, - null, (LongPointer) hostZShapeInfo, z, (LongPointer) zShapeInfo, - null, (LongPointer) hostYShapeInfo, AtomicAllocator.getInstance().getPointer(op.scalar(), context), (LongPointer) AtomicAllocator.getInstance().getPointer(op.scalar().shapeInfoDataBuffer(), context), + x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, + z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, + y, (LongPointer) hostYShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(op.scalar().shapeInfoDataBuffer(), context), extraArgs); break; default: @@ -1350,8 +1274,6 @@ public class CudaExecutioner extends DefaultOpExecutioner { if (nativeOps.lastErrorCode() != 0) throw new RuntimeException(nativeOps.lastErrorMessage()); - AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.scalar()); - profilingConfigurableHookOut(op, st); return null; @@ -1369,7 +1291,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { if (extraz.get() == null) extraz.set(new PointerPointer(32)); - CudaContext context = allocator.getFlowController().prepareAction(op.z(), op.x(), op.y()); + val context = allocator.getDeviceContext(); if (CudaEnvironment.getInstance().getConfiguration().isDebug()) lastOp.set(op.opName()); @@ -1377,7 +1299,6 @@ public class CudaExecutioner extends DefaultOpExecutioner { // special temp array for IsMax along dimension INDArray ret = null; - Pointer x = allocator.getPointer(op.x(), context); Pointer xShapeInfo = allocator.getPointer(op.x().shapeInfoDataBuffer(), context); @@ -1413,7 +1334,6 @@ public class CudaExecutioner extends DefaultOpExecutioner { op.validateDataTypes(experimentalMode.get()); - Pointer z = allocator.getPointer(op.z(), context); Pointer zShapeInfo = allocator.getPointer(op.z().shapeInfoDataBuffer(), context); @@ -1440,31 +1360,30 @@ public class CudaExecutioner extends DefaultOpExecutioner { retHostShape); - + val x = op.x() == null ? null : ((BaseCudaDataBuffer) op.x().data()).getOpaqueDataBuffer(); + val y = op.y() == null ? null : ((BaseCudaDataBuffer) op.y().data()).getOpaqueDataBuffer(); + val z = op.z() == null ? null : ((BaseCudaDataBuffer) op.z().data()).getOpaqueDataBuffer(); if (op.y() != null) { - Pointer y = allocator.getPointer(op.y(), context); Pointer yShapeInfo = allocator.getPointer(op.y().shapeInfoDataBuffer(), context); if (op.x().length() != op.y().length() || op.x().length() != op.z().length()) throw new ND4JIllegalStateException("X, Y and Z arguments should have the same length for PairwiseTransform"); - ///log.info("X: {}; Y: {}; Z: {}; E: {};", x.address(), y.address(), z.address(), extraArgs != null ? extraArgs.address() : null); - switch (op.getOpType()) { case TRANSFORM_BOOL: case PAIRWISE_BOOL: nativeOps.execPairwiseTransformBool(xShapeInfoHostPointer, op.opNum(), - null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, - null, (LongPointer) hostYShapeInfo, y, (LongPointer) yShapeInfo, - null, (LongPointer) hostZShapeInfo, z, (LongPointer) zShapeInfo, + x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, + y, (LongPointer) hostYShapeInfo, (LongPointer) yShapeInfo, + z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, extraArgs); break; default: nativeOps.execPairwiseTransform(xShapeInfoHostPointer, op.opNum(), - null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, - null, (LongPointer) hostYShapeInfo, y, (LongPointer) yShapeInfo, - null, (LongPointer) hostZShapeInfo, z, (LongPointer) zShapeInfo, + x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, + y, (LongPointer) hostYShapeInfo, (LongPointer) yShapeInfo, + z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, extraArgs); break; } @@ -1472,32 +1391,32 @@ public class CudaExecutioner extends DefaultOpExecutioner { switch (op.getOpType()) { case TRANSFORM_ANY: nativeOps.execTransformAny(xShapeInfoHostPointer, op.opNum(), - null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, - null, (LongPointer) hostZShapeInfo, z, (LongPointer) zShapeInfo, + x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, + z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, extraArgs); break; case TRANSFORM_FLOAT: nativeOps.execTransformFloat(xShapeInfoHostPointer, op.opNum(), - null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, - null, (LongPointer) hostZShapeInfo, z, (LongPointer) zShapeInfo, + x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, + z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, extraArgs); break; case TRANSFORM_BOOL: nativeOps.execTransformBool(xShapeInfoHostPointer, op.opNum(), - null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, - null, (LongPointer) hostZShapeInfo, z, (LongPointer) zShapeInfo, + x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, + z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, extraArgs); break; case TRANSFORM_SAME: nativeOps.execTransformSame(xShapeInfoHostPointer, op.opNum(), - null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, - null, (LongPointer) hostZShapeInfo, z, (LongPointer) zShapeInfo, + x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, + z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, extraArgs); break; case TRANSFORM_STRICT: nativeOps.execTransformStrict(xShapeInfoHostPointer, op.opNum(), - null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, - null, (LongPointer) hostZShapeInfo, z, (LongPointer) zShapeInfo, + x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, + z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, extraArgs); break; default: @@ -1508,8 +1427,6 @@ public class CudaExecutioner extends DefaultOpExecutioner { if (nativeOps.lastErrorCode() != 0) throw new RuntimeException(nativeOps.lastErrorMessage()); - AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y()); - if (extraArgs != null) extraArgs.address(); @@ -1530,146 +1447,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { @Override public void exec(Batch batch) { - val surfaceBuffer = (BaseCudaDataBuffer) getBuffer(batch); - surfaceBuffer.lazyAllocateHostPointer(); - - val context = AtomicAllocator.getInstance().getDeviceContext(); - - val pointer = (IntPointer) new CudaPointer(AtomicAllocator.getInstance().getHostPointer(surfaceBuffer)) - .asIntPointer(); - val surfacePoint = AtomicAllocator.getInstance().getAllocationPoint(surfaceBuffer); - - int maxTypes = 5; - - int maxIntArrays = batch.getSample().maxIntArrays(); - - int maxArraySize = batch.getSample().maxIntArraySize(); - - - int indexPos = maxTypes * (Batch.getBatchLimit() * 16); - int intArraysPos = indexPos + (batch.getSample().maxIndexArguments() * (Batch.getBatchLimit() * 16)); - int realPos = (intArraysPos + (maxIntArrays * maxArraySize * (Batch.getBatchLimit() * 16))) - / (Nd4j.dataType() == DataType.DOUBLE ? 2 : 1); - - if (Nd4j.dataType() == DataType.HALF) - realPos *= 2; - - int argsPos = (realPos + (batch.getSample().maxRealArguments() * (Batch.getBatchLimit() * 16))) - / (Nd4j.dataType() == DataType.FLOAT ? 2 : 1); - - if (Nd4j.dataType() == DataType.HALF) - argsPos /= 4; - - int shapesPos = argsPos + (batch.getSample().maxArguments() * (Batch.getBatchLimit() * 16)); - DataType dataType = null; - for (int i = 0; i < batch.getNumAggregates(); i++) { - T op = batch.getAggregates().get(i); - - if (i == 0) - dataType = op.getArguments().get(0).dataType(); - - // put num arguments - int idx = i * maxTypes; - pointer.put(idx, op.getArguments().size()); - pointer.put(idx + 1, op.getShapes().size()); - pointer.put(idx + 2, op.getIndexingArguments().size()); - pointer.put(idx + 3, op.getRealArguments().size()); - pointer.put(idx + 4, op.getIntArrayArguments().size()); - - - // putting indexing arguments - for (int e = 0; e < op.getIndexingArguments().size(); e++) { - idx = indexPos + i * batch.getSample().maxIndexArguments(); - pointer.put(idx + e, op.getIndexingArguments().get(e)); - } - - // putting intArray values - int bsize = maxIntArrays * maxArraySize; - for (int e = 0; e < op.getIntArrayArguments().size(); e++) { - int step = (i * bsize) + (e * maxArraySize); - if (op.getIntArrayArguments().get(e) != null) - for (int x = 0; x < op.getIntArrayArguments().get(e).length; x++) { - idx = intArraysPos + step + x; - pointer.put(idx, op.getIntArrayArguments().get(e)[x]); - } - } - - // TODO: variable datatype should be handled here - // putting real arguments - switch (dataType) { - case FLOAT: { - FloatPointer realPtr = new FloatPointer(pointer); - for (int e = 0; e < op.getRealArguments().size(); e++) { - idx = realPos + i * op.maxRealArguments(); - realPtr.put(idx + e, op.getRealArguments().get(e).floatValue()); - } - } - break; - case DOUBLE: { - DoublePointer dPtr = new DoublePointer(pointer); - for (int e = 0; e < op.getRealArguments().size(); e++) { - idx = realPos + (i * op.maxRealArguments()); - dPtr.put(idx + e, op.getRealArguments().get(e).doubleValue()); - } - } - break; - case HALF: { - ShortPointer sPtr = new ShortPointer(pointer); - for (int e = 0; e < op.getRealArguments().size(); e++) { - idx = realPos + (i * op.maxRealArguments()); - sPtr.put(idx + e, BaseDataBuffer.fromFloat(op.getRealArguments().get(e).floatValue())); - } - } - break; - default: - throw new UnsupportedOperationException("Unknown data type"); - } - - // putting arguments pointers - PointerPointer ptrPtr = new PointerPointer(pointer); - for (int e = 0; e < op.getArguments().size(); e++) { - idx = argsPos + i * batch.getSample().maxArguments(); - - if (op.getArguments().get(e) != null) { - ptrPtr.put(idx + e, AtomicAllocator.getInstance().getPointer(op.getArguments().get(e), context)); - AtomicAllocator.getInstance().getAllocationPoint(op.getArguments().get(e)).tickDeviceWrite(); - } - } - - - // putting shape pointers - for (int e = 0; e < op.getShapes().size(); e++) { - idx = shapesPos + i * batch.getSample().maxShapes(); - - if (op.getShapes().get(e) != null) { - ptrPtr.put(idx + e, AtomicAllocator.getInstance().getPointer(op.getShapes().get(e), context)); - AtomicAllocator.getInstance().getAllocationPoint(op.getShapes().get(e)).tickDeviceWrite(); - } - } - } - - // trigger write, so getPointer request will force relocation to GPU - surfacePoint.tickHostWrite(); - - PointerPointer extraArgs = new PointerPointer(32); - extraArgs.put(0, null); - extraArgs.put(1, context.getOldStream()); - extraArgs.put(2, new CudaPointer(Math.min(batch.getNumAggregates(), - CudaEnvironment.getInstance().getConfiguration().getMaximumGridSize()))); - extraArgs.put(3, new CudaPointer(batch.getSample().getThreadsPerInstance())); - extraArgs.put(4, new CudaPointer(batch.getSample().getSharedMemorySize())); - - - nativeOps.execAggregateBatch(extraArgs, batch.getNumAggregates(), batch.opNum(), - batch.getSample().maxArguments(), batch.getSample().maxShapes(), - batch.getSample().maxIntArrays(), batch.getSample().maxIntArraySize(), - batch.getSample().maxIndexArguments(), batch.getSample().maxRealArguments(), - AtomicAllocator.getInstance().getPointer(surfaceBuffer, context), FlatBuffersMapper.getDataTypeAsByte(dataType)); - - if (nativeOps.lastErrorCode() != 0) - throw new RuntimeException(nativeOps.lastErrorMessage()); - - surfacePoint.tickHostWrite(); + throw new UnsupportedOperationException("Pew-pew"); } @Override @@ -1688,84 +1466,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { @Override public void exec(Aggregate op) { - int numArguments = op.getArguments().size(); - int numShapeArguments = op.getShapes().size(); - int numIndexArguments = op.getIndexingArguments().size(); - int numIntArrays = op.getIntArrayArguments().size(); - int numRealArguments = op.getRealArguments().size(); - - val context = (CudaContext) AtomicAllocator.getInstance().getDeviceContext(); - - val extraArgs = new PointerPointer(32); - extraArgs.put(0, null); - extraArgs.put(1, context.getOldStream()); - extraArgs.put(2, new CudaPointer(1)); - extraArgs.put(3, new CudaPointer(op.getThreadsPerInstance())); - extraArgs.put(4, new CudaPointer(op.getSharedMemorySize())); - - long arguments[] = new long[numArguments]; - val dataType = op.getArguments().get(0).dataType(); - - for (int x = 0; x < numArguments; x++) { - arguments[x] = op.getArguments().get(x) == null ? 0 - : AtomicAllocator.getInstance().getPointer(op.getArguments().get(x), context).address(); - - if (op.getArguments().get(x) != null) - AtomicAllocator.getInstance().getAllocationPoint(op.getArguments().get(x)).tickDeviceWrite(); - } - - DataBuffer tempX = AllocationUtils.getPointersBuffer(arguments); - PointerPointer xPtr = new PointerPointer(AtomicAllocator.getInstance().getPointer(tempX, context)); - - - long shapes[] = new long[numShapeArguments]; - for (int x = 0; x < numShapeArguments; x++) { - shapes[x] = op.getShapes().get(x) == null ? 0 - : AtomicAllocator.getInstance().getPointer(op.getShapes().get(x), context).address(); - - if (op.getShapes().get(x) != null) - AtomicAllocator.getInstance().getAllocationPoint(op.getShapes().get(x)).tickDeviceWrite(); - } - - DataBuffer tempS = AllocationUtils.getPointersBuffer(shapes); - PointerPointer sPtr = new PointerPointer(AtomicAllocator.getInstance().getPointer(tempS, context)); - - - long ints[] = new long[numIntArrays]; - for (int x = 0; x < numIntArrays; x++) { - if (op.getIntArrayArguments().get(x) != null) { - DataBuffer intBuf = Nd4j.getDataBufferFactory().createInt(op.getIntArrayArguments().get(x)); - ints[x] = AtomicAllocator.getInstance().getPointer(intBuf, context).address(); - } - - } - - DataBuffer tempI = AllocationUtils.getPointersBuffer(ints); - PointerPointer iPtr = new PointerPointer(AtomicAllocator.getInstance().getPointer(tempI, context)); - - int[] indexes = new int[numIndexArguments]; - for (int x = 0; x < numIndexArguments; x++) { - indexes[x] = op.getIndexingArguments().get(x); - } - - DataBuffer intBuffer = Nd4j.getDataBufferFactory().createInt(indexes); - - double[] reals = new double[numRealArguments]; - INDArray realsBuffer; - for (int x = 0; x < numRealArguments; x++) { - reals[x] = op.getRealArguments().get(x).doubleValue(); - } - - realsBuffer = Nd4j.create(reals, new long[]{reals.length}, dataType); - - nativeOps.execAggregate(extraArgs, op.opNum(), xPtr, numArguments, sPtr, numShapeArguments, - (IntPointer) AtomicAllocator.getInstance().getPointer(intBuffer, context), - numIndexArguments, iPtr, numIntArrays, - AtomicAllocator.getInstance().getPointer(realsBuffer.data(), context), - numRealArguments, FlatBuffersMapper.getDataTypeAsByte(dataType)); - - if (nativeOps.lastErrorCode() != 0) - throw new RuntimeException(nativeOps.lastErrorMessage()); + throw new UnsupportedOperationException("Pew-pew"); } /** @@ -1797,7 +1498,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { if (CudaEnvironment.getInstance().getConfiguration().isDebug()) lastOp.set(op.opName()); - CudaContext context = AtomicAllocator.getInstance().getFlowController().prepareAction(op.z(), op.x(), op.y()); + val context = AtomicAllocator.getInstance().getDeviceContext(); PointerPointer extraZZ = extraz.get().put(AddressRetriever.retrieveHostPointer(op.z().shapeInfoDataBuffer()), context.getOldStream(), AtomicAllocator.getInstance().getDeviceIdPointer()); @@ -1806,34 +1507,36 @@ public class CudaExecutioner extends DefaultOpExecutioner { val hostYShapeInfo = op.y() == null ? null : AddressRetriever.retrieveHostPointer(op.y().shapeInfoDataBuffer()); val hostZShapeInfo = op.z() == null ? null : AddressRetriever.retrieveHostPointer(op.z().shapeInfoDataBuffer()); + val x = op.x() == null ? null : ((BaseCudaDataBuffer) op.x().data()).getOpaqueDataBuffer(); + val y = op.y() == null ? null : ((BaseCudaDataBuffer) op.y().data()).getOpaqueDataBuffer(); + val z = op.z() == null ? null : ((BaseCudaDataBuffer) op.z().data()).getOpaqueDataBuffer(); + if (op.x() != null && op.y() != null && op.z() != null) { // triple arg call nativeOps.execRandom3(extraZZ, op.opNum(), rng.getStatePointer(), // rng state ptr - null, (LongPointer) hostXShapeInfo, AtomicAllocator.getInstance().getPointer(op.x(), context), (LongPointer) AtomicAllocator.getInstance().getPointer(op.x().shapeInfoDataBuffer(), context), - null, (LongPointer) hostYShapeInfo, AtomicAllocator.getInstance().getPointer(op.y(), context), (LongPointer) AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(), context), - null, (LongPointer) hostZShapeInfo, AtomicAllocator.getInstance().getPointer(op.z(), context), (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), + x, (LongPointer) hostXShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(op.x().shapeInfoDataBuffer(), context), + y, (LongPointer) hostYShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(), context), + z, (LongPointer) hostZShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), AtomicAllocator.getInstance().getPointer(op.extraArgsDataBuff(op.z().dataType()), context)); } else if (op.x() != null && op.z() != null) { //double arg call nativeOps.execRandom2(extraZZ, op.opNum(), rng.getStatePointer(), // rng state ptr - null, (LongPointer) hostXShapeInfo, AtomicAllocator.getInstance().getPointer(op.x(), context), (LongPointer) AtomicAllocator.getInstance().getPointer(op.x().shapeInfoDataBuffer(), context), - null, (LongPointer) hostZShapeInfo, AtomicAllocator.getInstance().getPointer(op.z(), context), (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), + x, (LongPointer) hostXShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(op.x().shapeInfoDataBuffer(), context), + z, (LongPointer) hostZShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), AtomicAllocator.getInstance().getPointer(op.extraArgsDataBuff(op.z().dataType()),context)); } else { // single arg call nativeOps.execRandom(extraZZ, op.opNum(), rng.getStatePointer(), // rng state ptr - null, (LongPointer) hostZShapeInfo, AtomicAllocator.getInstance().getPointer(op.z(), context), (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), + z, (LongPointer) hostZShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), AtomicAllocator.getInstance().getPointer(op.extraArgsDataBuff(op.z().dataType()), context)); } if (nativeOps.lastErrorCode() != 0) throw new RuntimeException(nativeOps.lastErrorMessage()); - AtomicAllocator.getInstance().getFlowController().registerAction(context, op.z(), op.x(), op.y()); - profilingConfigurableHookOut(op, st); return op.z(); @@ -1931,6 +1634,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { val extras = extraz.get().put(1, context.getOldStream()); + ((BaseCudaDataBuffer) buffer).getOpaqueDataBuffer().syncToSpecial(); NativeOpsHolder.getInstance().getDeviceNativeOps().encodeThresholdP1(extras, @@ -1954,22 +1658,16 @@ public class CudaExecutioner extends DefaultOpExecutioner { blocksBuffer.put(0, numMatches); } -/* - log.info("Totals: {}", numMatches); - - - log.info("Number of blocks for compression: {}", numBlocks); - log.info("BlocksCounts: {}", Arrays.toString(blocksBuffer.asInt())); -*/ DataBuffer encodedBuffer = Nd4j.getMemoryManager().getCurrentWorkspace() == null ? Nd4j.getDataBufferFactory().createInt(4+numMatches, false) : Nd4j.getDataBufferFactory().createInt(4+numMatches, false, Nd4j.getMemoryManager().getCurrentWorkspace()); - AtomicAllocator.getInstance().getAllocationPoint(encodedBuffer).tickHostWrite(); + encodedBuffer.put(0, numMatches); encodedBuffer.put(1, (int) buffer.length()); encodedBuffer.put(2, Float.floatToIntBits((float) threshold)); - AtomicAllocator.getInstance().getAllocationPoint(encodedBuffer).tickHostWrite(); encodedBuffer.put(3, ThresholdCompression.FLEXIBLE_ENCODING); + ((BaseCudaDataBuffer) encodedBuffer).getOpaqueDataBuffer().syncToSpecial(); + int prefixThreads = 512; int numElts = numBlocks; @@ -2082,7 +1780,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { // format id buffer.put(3, ThresholdCompression.BITMAP_ENCODING); - CudaContext context = AtomicAllocator.getInstance().getFlowController().prepareAction(indArray); + val context = AtomicAllocator.getInstance().getDeviceContext(); if (extraz.get() == null) extraz.set(new PointerPointer(32)); @@ -2095,17 +1793,20 @@ public class CudaExecutioner extends DefaultOpExecutioner { context.getBufferReduction() ); + + val src = AtomicAllocator.getInstance().getPointer(indArray, context); + val dst = (IntPointer) AtomicAllocator.getInstance().getPointer(buffer, context); + ((BaseCudaDataBuffer) buffer).getOpaqueDataBuffer().syncToSpecial(); + long val = nativeOps.encodeBitmap(extras, - AtomicAllocator.getInstance().getPointer(indArray, context), (LongPointer) AtomicAllocator.getInstance().getHostPointer(indArray.shapeInfoDataBuffer()), + src, (LongPointer) AtomicAllocator.getInstance().getHostPointer(indArray.shapeInfoDataBuffer()), length, - (IntPointer) AtomicAllocator.getInstance().getPointer(buffer, context), + dst, (float) threshold); if (nativeOps.lastErrorCode() != 0) throw new RuntimeException(nativeOps.lastErrorMessage()); - AtomicAllocator.getInstance().getFlowController().registerAction(context, indArray); - AtomicAllocator.getInstance().getAllocationPoint(buffer).tickDeviceWrite(); return val; @@ -2114,7 +1815,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { @Override public INDArray bitmapDecode(INDArray encoded, INDArray target) { - CudaContext context = AtomicAllocator.getInstance().getFlowController().prepareAction(target); + val context = AtomicAllocator.getInstance().getDeviceContext(); if (extraz.get() == null) extraz.set(new PointerPointer(32)); @@ -2131,8 +1832,6 @@ public class CudaExecutioner extends DefaultOpExecutioner { if (nativeOps.lastErrorCode() != 0) throw new RuntimeException(nativeOps.lastErrorMessage()); - AtomicAllocator.getInstance().getFlowController().registerAction(context, target); - return target; } @@ -2207,15 +1906,15 @@ public class CudaExecutioner extends DefaultOpExecutioner { return Collections.emptyList(); } - val inputBuffers = new PointerPointer<>(op.inputArguments().length * 2); - val inputShapes = new PointerPointer<>(op.inputArguments().length); + val inputBuffers = new PointerPointer<>(op.inputArguments().size() * 2); + val inputShapes = new PointerPointer<>(op.inputArguments().size()); int cnt= 0; for (val in: op.inputArguments()) { // NOT A TYPO: shape functions work on host side only if (!in.isEmpty()) { inputBuffers.put(cnt, in.data().addressPointer()); - inputBuffers.put(cnt + op.inputArguments().length, AtomicAllocator.getInstance().getPointer(in.data())); + inputBuffers.put(cnt + op.inputArguments().size(), AtomicAllocator.getInstance().getPointer(in.data())); } inputShapes.put(cnt++, in.shapeInfoDataBuffer().addressPointer()); @@ -2240,7 +1939,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { for (val t: op.tArgs()) tArgs.put(cnt++, t); - OpaqueShapeList ptrptr = nativeOps.calculateOutputShapes2(null, hash, inputBuffers, inputShapes, op.inputArguments().length, tArgs, op.tArgs().length, iArgs, op.iArgs().length, bArgs, op.numBArguments()); + OpaqueShapeList ptrptr = nativeOps.calculateOutputShapes2(null, hash, inputBuffers, inputShapes, op.inputArguments().size(), tArgs, op.tArgs().length, iArgs, op.iArgs().length, bArgs, op.numBArguments()); if (nativeOps.lastErrorCode() != 0) throw new RuntimeException(nativeOps.lastErrorMessage()); @@ -2269,7 +1968,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { Nd4j.getExecutioner().commit(); - // + boolean shapeOverride = false; if (op.numOutputArguments() == 0 && !op.isInplaceCall()) { try { val list = this.calculateOutputShape(op); @@ -2279,8 +1978,9 @@ public class CudaExecutioner extends DefaultOpExecutioner { for (val shape: list) op.addOutputArgument(Nd4j.create(shape)); + shapeOverride = true; } catch (Exception e) { - throw new ND4JIllegalStateException("Op name " + op.opName() + " failed to execute. You can't execute non-inplace CustomOp without outputs being specified"); + throw new ND4JIllegalStateException("Op name " + op.opName() + " - no output arrays were provided and calculateOutputShape failed to execute", e); } } @@ -2289,6 +1989,10 @@ public class CudaExecutioner extends DefaultOpExecutioner { val name = op.opName(); try (val context = (CudaOpContext) buildContext()) { + // optionally skip shape validation on op execution + if (shapeOverride) + context.shapeFunctionOverride(true); + context.markInplace(op.isInplaceCall()); // transferring rng state @@ -2306,6 +2010,17 @@ public class CudaExecutioner extends DefaultOpExecutioner { val result = exec(op, context); val states = context.getRngStates(); + // check if input && output needs update + for (val in:op.inputArguments()) { + if (!in.isEmpty()) + ((BaseCudaDataBuffer) in.data()).actualizePointerAndIndexer(); + } + + for (val out:op.outputArguments()) { + if (!out.isEmpty()) + ((BaseCudaDataBuffer) out.data()).actualizePointerAndIndexer(); + } + // pulling states back Nd4j.getRandom().setStates(states.getFirst(), states.getSecond()); @@ -2389,7 +2104,9 @@ public class CudaExecutioner extends DefaultOpExecutioner { val array = Nd4j.create(shapeOf, stridesOf, 0, order); Pointer.memcpy(AtomicAllocator.getInstance().getHostPointer(array), buffer, ArrayUtil.prod(shapeOf) * Nd4j.sizeOfDataType()); - AtomicAllocator.getInstance().getAllocationPoint(array).tickHostWrite(); + //AtomicAllocator.getInstance().getAllocationPoint(array).tickHostWrite(); + if (1 > 0) + throw new UnsupportedOperationException("Pew-pew"); String nodeName = nativeOps.getVariableName(var); newMap.put(nodeName, array); @@ -2445,7 +2162,9 @@ public class CudaExecutioner extends DefaultOpExecutioner { } @Override - public String getString(Utf8Buffer buffer, long index) { + public String getString(DataBuffer buffer, long index) { + Preconditions.checkArgument(buffer instanceof CudaUtf8Buffer, "Expected Utf8Buffer"); + val addr = ((LongIndexer) buffer.indexer()).get(index); val ptr = new PagedPointer(addr); val str = new Nd4jCuda.utf8string(ptr); @@ -2459,7 +2178,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { @Override public void scatterUpdate(ScatterUpdate.UpdateOp op, @NonNull INDArray array, @NonNull INDArray indices, @NonNull INDArray updates, @NonNull int[] axis) { - CudaContext context = AtomicAllocator.getInstance().getFlowController().prepareAction(array, indices, updates); + val context = AtomicAllocator.getInstance().getDeviceContext(); val tadX = tadManager.getTADOnlyShapeInfo(array, axis); val tadY = tadManager.getTADOnlyShapeInfo(updates, axis); @@ -2479,8 +2198,6 @@ public class CudaExecutioner extends DefaultOpExecutioner { if (nativeOps.lastErrorCode() != 0) throw new RuntimeException(nativeOps.lastErrorMessage()); - - AtomicAllocator.getInstance().getFlowController().registerAction(context, array, indices, updates); } @Override @@ -2502,13 +2219,6 @@ public class CudaExecutioner extends DefaultOpExecutioner { if (status != 0) throw new RuntimeException("Op [" + op.opName() + "] execution failed"); - - - for (val arr:op.outputArguments()) - AtomicAllocator.getInstance().registerAction(ctx, arr); - - AtomicAllocator.getInstance().registerAction(ctx, null, op.inputArguments()); - profilingConfigurableHookOut(op, st); if (context.getOutputArrays().isEmpty()) diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaOpContext.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaOpContext.java index b75f688fe..6f37be02a 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaOpContext.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaOpContext.java @@ -29,6 +29,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseOpContext; import org.nd4j.linalg.api.ops.OpContext; import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.jcublas.buffer.BaseCudaDataBuffer; import org.nd4j.linalg.jcublas.context.CudaContext; import org.nd4j.linalg.primitives.Pair; import org.nd4j.nativeblas.NativeOps; @@ -88,8 +89,7 @@ public class CudaOpContext extends BaseOpContext implements OpContext { @Override public void setInputArray(int index, @NonNull INDArray array) { val ctx = AtomicAllocator.getInstance().getFlowController().prepareAction(null, array); - - nativeOps.setGraphContextInputArray(context, index, array.isEmpty() ? null : array.data().addressPointer(), array.shapeInfoDataBuffer().addressPointer(), array.isEmpty() ? null : AtomicAllocator.getInstance().getPointer(array, ctx), AtomicAllocator.getInstance().getPointer(array.shapeInfoDataBuffer())); + nativeOps.setGraphContextInputBuffer(context, index, array.isEmpty() ? null : ((BaseCudaDataBuffer) array.data()).getOpaqueDataBuffer(), array.shapeInfoDataBuffer().addressPointer(), AtomicAllocator.getInstance().getPointer(array.shapeInfoDataBuffer())); super.setInputArray(index, array); } @@ -97,33 +97,13 @@ public class CudaOpContext extends BaseOpContext implements OpContext { @Override public void setOutputArray(int index, @NonNull INDArray array) { val ctx = AtomicAllocator.getInstance().getFlowController().prepareAction(array, null); - - nativeOps.setGraphContextOutputArray(context, index, array.isEmpty() ? null : array.data().addressPointer(), array.shapeInfoDataBuffer().addressPointer(), array.isEmpty() ? null : AtomicAllocator.getInstance().getPointer(array, ctx), AtomicAllocator.getInstance().getPointer(array.shapeInfoDataBuffer())); + nativeOps.setGraphContextOutputBuffer(context, index, array.isEmpty() ? null : ((BaseCudaDataBuffer) array.data()).getOpaqueDataBuffer(), array.shapeInfoDataBuffer().addressPointer(), AtomicAllocator.getInstance().getPointer(array.shapeInfoDataBuffer())); super.setOutputArray(index, array); } @Override public Pointer contextPointer() { - for (val v:fastpath_in.values()) { - if (v.isEmpty() || v.isS()) - continue; - - AtomicAllocator.getInstance().getAllocationPoint(v).tickHostRead(); - AtomicAllocator.getInstance().getAllocationPoint(v).tickDeviceRead(); - - //if (context.isInplace()) - //AtomicAllocator.getInstance().getAllocationPoint(v).tickDeviceWrite(); - } - - for (val v:fastpath_out.values()) { - if (v.isEmpty() || v.isS()) - continue; - - AtomicAllocator.getInstance().getAllocationPoint(v).tickHostRead(); - AtomicAllocator.getInstance().getAllocationPoint(v).tickDeviceRead(); - } - return context; } @@ -141,4 +121,9 @@ public class CudaOpContext extends BaseOpContext implements OpContext { public void allowHelpers(boolean reallyAllow) { nativeOps.ctxAllowHelpers(context, reallyAllow); } + + @Override + public void shapeFunctionOverride(boolean reallyOverride) { + nativeOps.ctxShapeFunctionOverride(context, reallyOverride); + } } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/rng/CudaNativeRandom.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/rng/CudaNativeRandom.java index 6e2d8ebf0..68fff737a 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/rng/CudaNativeRandom.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/rng/CudaNativeRandom.java @@ -54,6 +54,10 @@ public class CudaNativeRandom extends NativeRandom { public void init() { nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps(); statePointer = nativeOps.createRandomGenerator(this.seed, this.seed ^ 0xdeadbeef); + + if (nativeOps.lastErrorCode() != 0) + throw new RuntimeException(nativeOps.lastErrorMessage()); + setSeed(seed); } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/CudaEnvironment.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/CudaEnvironment.java new file mode 100644 index 000000000..16abeef9a --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/CudaEnvironment.java @@ -0,0 +1,195 @@ +/* ****************************************************************************** + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.nativeblas; + +import org.nd4j.linalg.factory.Environment; +import org.nd4j.nativeblas.Nd4jCuda; + +/** + * CUDA backend implementation of {@link Environment} + * + * @author Alex Black + */ +public class CudaEnvironment implements Environment { + + + private static final CudaEnvironment INSTANCE = new CudaEnvironment(Nd4jCuda.Environment.getInstance()); + + private final Nd4jCuda.Environment e; + + public static CudaEnvironment getInstance(){ + return INSTANCE; + } + + protected CudaEnvironment(Nd4jCuda.Environment environment){ + this.e = environment; + } + + @Override + public int blasMajorVersion() { + return e.blasMajorVersion(); + } + + @Override + public int blasMinorVersion() { + return e.blasMinorVersion(); + } + + @Override + public int blasPatchVersion() { + return e.blasMajorVersion(); + } + + @Override + public boolean isVerbose() { + return e.isVerbose(); + } + + @Override + public void setVerbose(boolean reallyVerbose) { + e.setVerbose(reallyVerbose); + } + + @Override + public boolean isDebug() { + return e.isDebug(); + } + + @Override + public boolean isProfiling() { + return e.isProfiling(); + } + + @Override + public boolean isDetectingLeaks() { + return e.isDetectingLeaks(); + } + + @Override + public boolean isDebugAndVerbose() { + return e.isDebugAndVerbose(); + } + + @Override + public void setDebug(boolean reallyDebug) { + e.setDebug(reallyDebug); + } + + @Override + public void setProfiling(boolean reallyProfile) { + e.setProfiling(reallyProfile); + } + + @Override + public void setLeaksDetector(boolean reallyDetect) { + e.setLeaksDetector(reallyDetect); + } + + @Override + public boolean helpersAllowed() { + return e.helpersAllowed(); + } + + @Override + public void allowHelpers(boolean reallyAllow) { + e.allowHelpers(reallyAllow); + } + + @Override + public int tadThreshold() { + return e.tadThreshold(); + } + + @Override + public void setTadThreshold(int threshold) { + e.setTadThreshold(threshold); + } + + @Override + public int elementwiseThreshold() { + return e.elementwiseThreshold(); + } + + @Override + public void setElementwiseThreshold(int threshold) { + e.setElementwiseThreshold(threshold); + } + + @Override + public int maxThreads() { + return e.maxThreads(); + } + + @Override + public void setMaxThreads(int max) { + e.setMaxThreads(max); + } + + @Override + public int maxMasterThreads() { + return e.maxMasterThreads(); + } + + @Override + public void setMaxMasterThreads(int max) { + e.setMaxMasterThreads(max); + } + + @Override + public void setMaxPrimaryMemory(long maxBytes) { + e.setMaxPrimaryMemory(maxBytes); + } + + @Override + public void setMaxSpecialMemory(long maxBytes) { + e.setMaxSpecialyMemory(maxBytes); + } + + @Override + public void setMaxDeviceMemory(long maxBytes) { + e.setMaxDeviceMemory(maxBytes); + } + + @Override + public boolean isCPU() { + return e.isCPU(); + } + + @Override + public void setGroupLimit(int group, long numBytes) { + e.setGroupLimit(group, numBytes); + } + + @Override + public void setDeviceLimit(int deviceId, long numBytes) { + e.setDeviceLimit(deviceId, numBytes); + } + + @Override + public long getGroupLimit(int group) { + return e.getGroupLimit(group); + } + + @Override + public long getDeviceLimit(int deviceId) { + return e.getDeviceLimit(deviceId); + } + + @Override + public long getDeviceCouner(int deviceId) { + return e.getDeviceCounter(deviceId); + } +} diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java index 94c5601c1..5aa685c7a 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java @@ -175,6 +175,74 @@ public class Nd4jCuda extends org.nd4j.nativeblas.Nd4jCudaHelper { } } +@Name("std::vector") public static class ConstNDArrayVector extends Pointer { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public ConstNDArrayVector(Pointer p) { super(p); } + public ConstNDArrayVector(NDArray value) { this(1); put(0, value); } + public ConstNDArrayVector(NDArray ... array) { this(array.length); put(array); } + public ConstNDArrayVector() { allocate(); } + public ConstNDArrayVector(long n) { allocate(n); } + private native void allocate(); + private native void allocate(@Cast("size_t") long n); + public native @Name("operator=") @ByRef ConstNDArrayVector put(@ByRef ConstNDArrayVector x); + + public boolean empty() { return size() == 0; } + public native long size(); + public void clear() { resize(0); } + public native void resize(@Cast("size_t") long n); + + @Index(function = "at") public native @Const NDArray get(@Cast("size_t") long i); + public native ConstNDArrayVector put(@Cast("size_t") long i, NDArray value); + + public native @ByVal Iterator insert(@ByVal Iterator pos, @Const NDArray value); + public native @ByVal Iterator erase(@ByVal Iterator pos); + public native @ByVal Iterator begin(); + public native @ByVal Iterator end(); + @NoOffset @Name("iterator") public static class Iterator extends Pointer { + public Iterator(Pointer p) { super(p); } + public Iterator() { } + + public native @Name("operator++") @ByRef Iterator increment(); + public native @Name("operator==") boolean equals(@ByRef Iterator it); + public native @Name("operator*") @Const NDArray get(); + } + + public NDArray[] get() { + NDArray[] array = new NDArray[size() < Integer.MAX_VALUE ? (int)size() : Integer.MAX_VALUE]; + for (int i = 0; i < array.length; i++) { + array[i] = get(i); + } + return array; + } + @Override public String toString() { + return java.util.Arrays.toString(get()); + } + + public NDArray pop_back() { + long size = size(); + NDArray value = get(size - 1); + resize(size - 1); + return value; + } + public ConstNDArrayVector push_back(NDArray value) { + long size = size(); + resize(size + 1); + return put(size, value); + } + public ConstNDArrayVector put(NDArray value) { + if (size() != 1) { resize(1); } + return put(0, value); + } + public ConstNDArrayVector put(NDArray ... array) { + if (size() != array.length) { resize(array.length); } + for (int i = 0; i < array.length; i++) { + put(i, array[i]); + } + return this; + } +} + @NoOffset @Name("std::pair") public static class IntIntPair extends Pointer { static { Loader.load(); } /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ @@ -240,12 +308,167 @@ public class Nd4jCuda extends org.nd4j.nativeblas.Nd4jCudaHelper { QINT16 = 16, BFLOAT16 = 17, UTF8 = 50, + UTF16 = 51, + UTF32 = 52, ANY = 100, AUTO = 200; // #endif +// Parsed from array/DataBuffer.h + +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// @author Yurii Shyrma (iuriish@yahoo.com) +// + +// #ifndef DEV_TESTS_DATABUFFER_H +// #define DEV_TESTS_DATABUFFER_H + +// #include +// #include +// #include +// #include +// #include +// #include +// #include + +@Namespace("nd4j") @NoOffset public static class DataBuffer extends Pointer { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public DataBuffer(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public DataBuffer(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public DataBuffer position(long position) { + return (DataBuffer)super.position(position); + } + + + public DataBuffer(Pointer primary, Pointer special, + @Cast("const size_t") long lenInBytes, @Cast("const nd4j::DataType") int dataType, + @Cast("const bool") boolean isOwnerPrimary/*=false*/, @Cast("const bool") boolean isOwnerSpecial/*=false*/, + Workspace workspace/*=nullptr*/) { super((Pointer)null); allocate(primary, special, lenInBytes, dataType, isOwnerPrimary, isOwnerSpecial, workspace); } + private native void allocate(Pointer primary, Pointer special, + @Cast("const size_t") long lenInBytes, @Cast("const nd4j::DataType") int dataType, + @Cast("const bool") boolean isOwnerPrimary/*=false*/, @Cast("const bool") boolean isOwnerSpecial/*=false*/, + Workspace workspace/*=nullptr*/); + public DataBuffer(Pointer primary, Pointer special, + @Cast("const size_t") long lenInBytes, @Cast("const nd4j::DataType") int dataType) { super((Pointer)null); allocate(primary, special, lenInBytes, dataType); } + private native void allocate(Pointer primary, Pointer special, + @Cast("const size_t") long lenInBytes, @Cast("const nd4j::DataType") int dataType); + + public DataBuffer(Pointer primary, + @Cast("const size_t") long lenInBytes, @Cast("const nd4j::DataType") int dataType, + @Cast("const bool") boolean isOwnerPrimary/*=false*/, + Workspace workspace/*=nullptr*/) { super((Pointer)null); allocate(primary, lenInBytes, dataType, isOwnerPrimary, workspace); } + private native void allocate(Pointer primary, + @Cast("const size_t") long lenInBytes, @Cast("const nd4j::DataType") int dataType, + @Cast("const bool") boolean isOwnerPrimary/*=false*/, + Workspace workspace/*=nullptr*/); + public DataBuffer(Pointer primary, + @Cast("const size_t") long lenInBytes, @Cast("const nd4j::DataType") int dataType) { super((Pointer)null); allocate(primary, lenInBytes, dataType); } + private native void allocate(Pointer primary, + @Cast("const size_t") long lenInBytes, @Cast("const nd4j::DataType") int dataType); + + public DataBuffer(@Const Pointer hostBuffer, + @Cast("const nd4j::DataType") int dataType, @Cast("const size_t") long lenInBytes, + Workspace workspace/*=nullptr*/) { super((Pointer)null); allocate(hostBuffer, dataType, lenInBytes, workspace); } + private native void allocate(@Const Pointer hostBuffer, + @Cast("const nd4j::DataType") int dataType, @Cast("const size_t") long lenInBytes, + Workspace workspace/*=nullptr*/); + public DataBuffer(@Const Pointer hostBuffer, + @Cast("const nd4j::DataType") int dataType, @Cast("const size_t") long lenInBytes) { super((Pointer)null); allocate(hostBuffer, dataType, lenInBytes); } + private native void allocate(@Const Pointer hostBuffer, + @Cast("const nd4j::DataType") int dataType, @Cast("const size_t") long lenInBytes); + + public DataBuffer(@Cast("const size_t") long lenInBytes, @Cast("const nd4j::DataType") int dataType, Workspace workspace/*=nullptr*/, @Cast("const bool") boolean allocBoth/*=false*/) { super((Pointer)null); allocate(lenInBytes, dataType, workspace, allocBoth); } + private native void allocate(@Cast("const size_t") long lenInBytes, @Cast("const nd4j::DataType") int dataType, Workspace workspace/*=nullptr*/, @Cast("const bool") boolean allocBoth/*=false*/); + public DataBuffer(@Cast("const size_t") long lenInBytes, @Cast("const nd4j::DataType") int dataType) { super((Pointer)null); allocate(lenInBytes, dataType); } + private native void allocate(@Cast("const size_t") long lenInBytes, @Cast("const nd4j::DataType") int dataType); + + public DataBuffer(@Const @ByRef DataBuffer other) { super((Pointer)null); allocate(other); } + private native void allocate(@Const @ByRef DataBuffer other); + public DataBuffer() { super((Pointer)null); allocate(); } + private native void allocate(); + + public native @ByRef @Name("operator =") DataBuffer put(@Const @ByRef DataBuffer other); + + public native @Cast("nd4j::DataType") int getDataType(); + public native void setDataType(@Cast("nd4j::DataType") int dataType); + public native @Cast("size_t") long getLenInBytes(); + + public native Pointer primary(); + public native Pointer special(); + + public native void allocatePrimary(); + public native void allocateSpecial(); + + public native void writePrimary(); + public native void writeSpecial(); + public native void readPrimary(); + public native void readSpecial(); + public native @Cast("bool") boolean isPrimaryActual(); + public native @Cast("bool") boolean isSpecialActual(); + + public native void expand(@Cast("const uint64_t") long size); + + public native int deviceId(); + public native void setDeviceId(int deviceId); + public native void migrate(); + + public native void syncToPrimary(@Const LaunchContext context, @Cast("const bool") boolean forceSync/*=false*/); + public native void syncToPrimary(@Const LaunchContext context); + public native void syncToSpecial(@Cast("const bool") boolean forceSync/*=false*/); + public native void syncToSpecial(); + + public native void setToZeroBuffers(@Cast("const bool") boolean both/*=false*/); + public native void setToZeroBuffers(); + + public native void copyBufferFrom(@Const @ByRef DataBuffer other, @Cast("size_t") long sizeToCopyinBytes/*=0*/, @Cast("const Nd4jLong") long offsetThis/*=0*/, @Cast("const Nd4jLong") long offsetOther/*=0*/); + public native void copyBufferFrom(@Const @ByRef DataBuffer other); + + public static native void memcpy(@Const @ByRef DataBuffer dst, @Const @ByRef DataBuffer src); + + public native void setPrimaryBuffer(Pointer buffer, @Cast("size_t") long length); + public native void setSpecialBuffer(Pointer buffer, @Cast("size_t") long length); + + /** + * This method deletes buffers, if we're owners + */ + public native @Name("close") void _close(); +} +///// IMLEMENTATION OF INLINE METHODS ///// + +//////////////////////////////////////////////////////////////////////// + + +//////////////////////////////////////////////////////////////////////// + + + + + +// #endif //DEV_TESTS_DATABUFFER_H + + // Parsed from array/ConstantDescriptor.h /******************************************************************************* @@ -272,7 +495,7 @@ public class Nd4jCuda extends org.nd4j.nativeblas.Nd4jCudaHelper { // #define DEV_TESTS_CONSTANTDESCRIPTOR_H // #include -// #include +// #include // #include // #include // #include @@ -503,6 +726,39 @@ public class Nd4jCuda extends org.nd4j.nativeblas.Nd4jCudaHelper { // #endif //DEV_TESTS_ERRORREFERENCE_H +// Parsed from execution/Engine.h + +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +// #ifndef SD_ENGINE_H +// #define SD_ENGINE_H + /** enum samediff::Engine */ + public static final int + ENGINE_CPU = 0, + ENGINE_CUDA = 1; + + +// #endif //SD_ENGINE_H + + // Parsed from memory/MemoryType.h // @@ -552,6 +808,7 @@ public class Nd4jCuda extends org.nd4j.nativeblas.Nd4jCudaHelper { // #include // #include // #include +// #include @Namespace("nd4j") @NoOffset public static class Environment extends Pointer { static { Loader.load(); } /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ @@ -590,10 +847,30 @@ public class Nd4jCuda extends org.nd4j.nativeblas.Nd4jCudaHelper { public native int maxMasterThreads(); public native void setMaxMasterThreads(int max); + /* + * Legacy memory limits API, still used in new API as simplified version + */ public native void setMaxPrimaryMemory(@Cast("uint64_t") long maxBytes); public native void setMaxSpecialyMemory(@Cast("uint64_t") long maxBytes); public native void setMaxDeviceMemory(@Cast("uint64_t") long maxBytes); + public native @Cast("uint64_t") long maxPrimaryMemory(); + public native @Cast("uint64_t") long maxSpecialMemory(); + //////////////////////// + + /* + * Methods for memory limits/counters + */ + public native void setGroupLimit(int group, @Cast("Nd4jLong") long numBytes); + public native void setDeviceLimit(int deviceId, @Cast("Nd4jLong") long numBytes); + + public native @Cast("Nd4jLong") long getGroupLimit(int group); + public native @Cast("Nd4jLong") long getDeviceLimit(int deviceId); + + public native @Cast("Nd4jLong") long getGroupCounter(int group); + public native @Cast("Nd4jLong") long getDeviceCounter(int deviceId); + //////////////////////// + public native @Cast("bool") boolean isUseMKLDNN(); public native void setUseMKLDNN(@Cast("bool") boolean useMKLDNN); @@ -753,6 +1030,7 @@ bool verbose = false; // #include // #include // #include +// #include // #include // #include // #include @@ -760,6 +1038,7 @@ bool verbose = false; // #include // #include // #include +// #include /** * This function returns last error code stored, @@ -801,25 +1080,19 @@ public native void setTADThreshold(int num); */ public native void execIndexReduceScalar(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo); public native void execIndexReduceScalar(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo); public native void execIndexReduceScalar(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo); /** * @@ -834,31 +1107,22 @@ public native void execIndexReduceScalar(@Cast("Nd4jPointer*") PointerPointer ex */ public native void execIndexReduce(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongPointer dDimensionShape); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, @Cast("Nd4jLong*") LongPointer dDimensionShape); public native void execIndexReduce(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongBuffer dDimensionShape); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, @Cast("Nd4jLong*") LongBuffer dDimensionShape); public native void execIndexReduce(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") long[] hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") long[] dDimensionShape); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") long[] hDimensionShape, @Cast("Nd4jLong*") long[] dDimensionShape); /** * @@ -875,74 +1139,50 @@ public native void execIndexReduce(@Cast("Nd4jPointer*") PointerPointer extraPoi public native void execBroadcast( @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, - Pointer hY, @Cast("Nd4jLong*") LongPointer hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") LongPointer dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongPointer dDimensionShape); + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") LongPointer hYShapeInfo, @Cast("Nd4jLong*") LongPointer dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, @Cast("Nd4jLong*") LongPointer dDimensionShape); public native void execBroadcast( @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, - Pointer hY, @Cast("Nd4jLong*") LongBuffer hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") LongBuffer dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongBuffer dDimensionShape); + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") LongBuffer hYShapeInfo, @Cast("Nd4jLong*") LongBuffer dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, @Cast("Nd4jLong*") LongBuffer dDimensionShape); public native void execBroadcast( @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, - Pointer hY, @Cast("Nd4jLong*") long[] hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") long[] dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") long[] hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") long[] dDimensionShape); + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") long[] hYShapeInfo, @Cast("Nd4jLong*") long[] dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") long[] hDimensionShape, @Cast("Nd4jLong*") long[] dDimensionShape); public native void execBroadcastBool( @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, - Pointer hY, @Cast("Nd4jLong*") LongPointer hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") LongPointer dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") LongPointer hYShapeInfo, @Cast("Nd4jLong*") LongPointer dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo, Pointer extraParams, - Pointer hDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongPointer dDimensionShape); + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, @Cast("Nd4jLong*") LongPointer dDimensionShape); public native void execBroadcastBool( @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, - Pointer hY, @Cast("Nd4jLong*") LongBuffer hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") LongBuffer dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") LongBuffer hYShapeInfo, @Cast("Nd4jLong*") LongBuffer dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, Pointer extraParams, - Pointer hDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongBuffer dDimensionShape); + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, @Cast("Nd4jLong*") LongBuffer dDimensionShape); public native void execBroadcastBool( @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, - Pointer hY, @Cast("Nd4jLong*") long[] hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") long[] dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") long[] hYShapeInfo, @Cast("Nd4jLong*") long[] dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo, Pointer extraParams, - Pointer hDimension, @Cast("Nd4jLong*") long[] hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") long[] dDimensionShape); + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") long[] hDimensionShape, @Cast("Nd4jLong*") long[] dDimensionShape); /** * @@ -959,63 +1199,45 @@ public native void execBroadcastBool( public native void execPairwiseTransform( @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, - Pointer hY, @Cast("Nd4jLong*") LongPointer hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") LongPointer dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") LongPointer hYShapeInfo, @Cast("Nd4jLong*") LongPointer dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo, Pointer extraParams); public native void execPairwiseTransform( @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, - Pointer hY, @Cast("Nd4jLong*") LongBuffer hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") LongBuffer dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") LongBuffer hYShapeInfo, @Cast("Nd4jLong*") LongBuffer dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, Pointer extraParams); public native void execPairwiseTransform( @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, - Pointer hY, @Cast("Nd4jLong*") long[] hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") long[] dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") long[] hYShapeInfo, @Cast("Nd4jLong*") long[] dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo, Pointer extraParams); public native void execPairwiseTransformBool( @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, - Pointer hY, @Cast("Nd4jLong*") LongPointer hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") LongPointer dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") LongPointer hYShapeInfo, @Cast("Nd4jLong*") LongPointer dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo, Pointer extraParams); public native void execPairwiseTransformBool( @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, - Pointer hY, @Cast("Nd4jLong*") LongBuffer hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") LongBuffer dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") LongBuffer hYShapeInfo, @Cast("Nd4jLong*") LongBuffer dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, Pointer extraParams); public native void execPairwiseTransformBool( @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, - Pointer hY, @Cast("Nd4jLong*") long[] hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") long[] dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") long[] hYShapeInfo, @Cast("Nd4jLong*") long[] dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo, Pointer extraParams); /** @@ -1029,92 +1251,68 @@ public native void execPairwiseTransformBool( */ public native void execReduceFloat(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo); public native void execReduceFloat(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo); public native void execReduceFloat(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo); public native void execReduceSame(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo); public native void execReduceSame(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo); public native void execReduceSame(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo); public native void execReduceBool(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo); public native void execReduceBool(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo); public native void execReduceBool(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo); public native void execReduceLong(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo); public native void execReduceLong(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo); public native void execReduceLong(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo); /** * @@ -1127,118 +1325,82 @@ public native void execReduceLong(@Cast("Nd4jPointer*") PointerPointer extraPoin */ public native void execReduceFloat2(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongPointer dDimensionShape); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, @Cast("Nd4jLong*") LongPointer dDimensionShape); public native void execReduceFloat2(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongBuffer dDimensionShape); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, @Cast("Nd4jLong*") LongBuffer dDimensionShape); public native void execReduceFloat2(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") long[] hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") long[] dDimensionShape); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") long[] hDimensionShape, @Cast("Nd4jLong*") long[] dDimensionShape); public native void execReduceSame2(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongPointer dDimensionShape); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, @Cast("Nd4jLong*") LongPointer dDimensionShape); public native void execReduceSame2(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongBuffer dDimensionShape); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, @Cast("Nd4jLong*") LongBuffer dDimensionShape); public native void execReduceSame2(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") long[] hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") long[] dDimensionShape); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") long[] hDimensionShape, @Cast("Nd4jLong*") long[] dDimensionShape); public native void execReduceBool2(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongPointer dDimensionShape); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, @Cast("Nd4jLong*") LongPointer dDimensionShape); public native void execReduceBool2(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongBuffer dDimensionShape); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, @Cast("Nd4jLong*") LongBuffer dDimensionShape); public native void execReduceBool2(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") long[] hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") long[] dDimensionShape); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") long[] hDimensionShape, @Cast("Nd4jLong*") long[] dDimensionShape); public native void execReduceLong2(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongPointer dDimensionShape); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, @Cast("Nd4jLong*") LongPointer dDimensionShape); public native void execReduceLong2(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongBuffer dDimensionShape); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, @Cast("Nd4jLong*") LongBuffer dDimensionShape); public native void execReduceLong2(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") long[] hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") long[] dDimensionShape); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") long[] hDimensionShape, @Cast("Nd4jLong*") long[] dDimensionShape); /** * @@ -1253,31 +1415,22 @@ public native void execReduceLong2(@Cast("Nd4jPointer*") PointerPointer extraPoi */ public native void execReduce3(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParamsVals, - Pointer hY, @Cast("Nd4jLong*") LongPointer hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") LongPointer dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo); + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") LongPointer hYShapeInfo, @Cast("Nd4jLong*") LongPointer dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo); public native void execReduce3(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParamsVals, - Pointer hY, @Cast("Nd4jLong*") LongBuffer hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") LongBuffer dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo); + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") LongBuffer hYShapeInfo, @Cast("Nd4jLong*") LongBuffer dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo); public native void execReduce3(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, Pointer extraParamsVals, - Pointer hY, @Cast("Nd4jLong*") long[] hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") long[] dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo); + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") long[] hYShapeInfo, @Cast("Nd4jLong*") long[] dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo); /** * @@ -1290,31 +1443,22 @@ public native void execReduce3(@Cast("Nd4jPointer*") PointerPointer extraPointer */ public native void execReduce3Scalar(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParamsVals, - Pointer hY, @Cast("Nd4jLong*") LongPointer hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") LongPointer dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo); + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") LongPointer hYShapeInfo, @Cast("Nd4jLong*") LongPointer dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo); public native void execReduce3Scalar(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParamsVals, - Pointer hY, @Cast("Nd4jLong*") LongBuffer hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") LongBuffer dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo); + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") LongBuffer hYShapeInfo, @Cast("Nd4jLong*") LongBuffer dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo); public native void execReduce3Scalar(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, Pointer extraParamsVals, - Pointer hY, @Cast("Nd4jLong*") long[] hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") long[] dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo); + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") long[] hYShapeInfo, @Cast("Nd4jLong*") long[] dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo); /** * * @param opNum @@ -1330,82 +1474,58 @@ public native void execReduce3Scalar(@Cast("Nd4jPointer*") PointerPointer extraP */ public native void execReduce3Tad(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParamsVals, - Pointer hY, @Cast("Nd4jLong*") LongPointer hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") LongPointer dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongPointer dDimensionShape, + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") LongPointer hYShapeInfo, @Cast("Nd4jLong*") LongPointer dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, @Cast("Nd4jLong*") LongPointer dDimensionShape, @Cast("Nd4jLong*") LongPointer tadOnlyShapeInfo, @Cast("Nd4jLong*") LongPointer tadOffsets, @Cast("Nd4jLong*") LongPointer yTadOnlyShapeInfo, @Cast("Nd4jLong*") LongPointer yTadOffsets); public native void execReduce3Tad(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParamsVals, - Pointer hY, @Cast("Nd4jLong*") LongBuffer hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") LongBuffer dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongBuffer dDimensionShape, + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") LongBuffer hYShapeInfo, @Cast("Nd4jLong*") LongBuffer dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, @Cast("Nd4jLong*") LongBuffer dDimensionShape, @Cast("Nd4jLong*") LongBuffer tadOnlyShapeInfo, @Cast("Nd4jLong*") LongBuffer tadOffsets, @Cast("Nd4jLong*") LongBuffer yTadOnlyShapeInfo, @Cast("Nd4jLong*") LongBuffer yTadOffsets); public native void execReduce3Tad(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, Pointer extraParamsVals, - Pointer hY, @Cast("Nd4jLong*") long[] hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") long[] dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") long[] hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") long[] dDimensionShape, + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") long[] hYShapeInfo, @Cast("Nd4jLong*") long[] dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") long[] hDimensionShape, @Cast("Nd4jLong*") long[] dDimensionShape, @Cast("Nd4jLong*") long[] tadOnlyShapeInfo, @Cast("Nd4jLong*") long[] tadOffsets, @Cast("Nd4jLong*") long[] yTadOnlyShapeInfo, @Cast("Nd4jLong*") long[] yTadOffsets); public native void execReduce3All(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParamsVals, - Pointer hY, @Cast("Nd4jLong*") LongPointer hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") LongPointer dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongPointer dDimensionShape, + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") LongPointer hYShapeInfo, @Cast("Nd4jLong*") LongPointer dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, @Cast("Nd4jLong*") LongPointer dDimensionShape, @Cast("Nd4jLong*") LongPointer xTadShapeInfo, @Cast("Nd4jLong*") LongPointer xOffsets, @Cast("Nd4jLong*") LongPointer yTadShapeInfo, @Cast("Nd4jLong*") LongPointer yOffsets); public native void execReduce3All(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParamsVals, - Pointer hY, @Cast("Nd4jLong*") LongBuffer hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") LongBuffer dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongBuffer dDimensionShape, + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") LongBuffer hYShapeInfo, @Cast("Nd4jLong*") LongBuffer dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, @Cast("Nd4jLong*") LongBuffer dDimensionShape, @Cast("Nd4jLong*") LongBuffer xTadShapeInfo, @Cast("Nd4jLong*") LongBuffer xOffsets, @Cast("Nd4jLong*") LongBuffer yTadShapeInfo, @Cast("Nd4jLong*") LongBuffer yOffsets); public native void execReduce3All(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, Pointer extraParamsVals, - Pointer hY, @Cast("Nd4jLong*") long[] hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") long[] dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") long[] hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") long[] dDimensionShape, + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") long[] hYShapeInfo, @Cast("Nd4jLong*") long[] dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") long[] hDimensionShape, @Cast("Nd4jLong*") long[] dDimensionShape, @Cast("Nd4jLong*") long[] xTadShapeInfo, @Cast("Nd4jLong*") long[] xOffsets, @Cast("Nd4jLong*") long[] yTadShapeInfo, @Cast("Nd4jLong*") long[] yOffsets); @@ -1422,58 +1542,40 @@ public native void execReduce3All(@Cast("Nd4jPointer*") PointerPointer extraPoin */ public native void execScalar(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, - Pointer hScalar, @Cast("Nd4jLong*") LongPointer hSscalarShapeInfo, - Pointer dScalar, @Cast("Nd4jLong*") LongPointer dSscalarShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbScalar, @Cast("Nd4jLong*") LongPointer hSscalarShapeInfo, @Cast("Nd4jLong*") LongPointer dSscalarShapeInfo, Pointer extraParams); public native void execScalar(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, - Pointer hScalar, @Cast("Nd4jLong*") LongBuffer hSscalarShapeInfo, - Pointer dScalar, @Cast("Nd4jLong*") LongBuffer dSscalarShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbScalar, @Cast("Nd4jLong*") LongBuffer hSscalarShapeInfo, @Cast("Nd4jLong*") LongBuffer dSscalarShapeInfo, Pointer extraParams); public native void execScalar(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, - Pointer hScalar, @Cast("Nd4jLong*") long[] hSscalarShapeInfo, - Pointer dScalar, @Cast("Nd4jLong*") long[] dSscalarShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbScalar, @Cast("Nd4jLong*") long[] hSscalarShapeInfo, @Cast("Nd4jLong*") long[] dSscalarShapeInfo, Pointer extraParams); public native void execScalarBool(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, - Pointer hScalar, @Cast("Nd4jLong*") LongPointer hSscalarShapeInfo, - Pointer dScalar, @Cast("Nd4jLong*") LongPointer dSscalarShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbScalar, @Cast("Nd4jLong*") LongPointer hSscalarShapeInfo, @Cast("Nd4jLong*") LongPointer dSscalarShapeInfo, Pointer extraParams); public native void execScalarBool(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, - Pointer hScalar, @Cast("Nd4jLong*") LongBuffer hSscalarShapeInfo, - Pointer dScalar, @Cast("Nd4jLong*") LongBuffer dSscalarShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbScalar, @Cast("Nd4jLong*") LongBuffer hSscalarShapeInfo, @Cast("Nd4jLong*") LongBuffer dSscalarShapeInfo, Pointer extraParams); public native void execScalarBool(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, - Pointer hScalar, @Cast("Nd4jLong*") long[] hSscalarShapeInfo, - Pointer dScalar, @Cast("Nd4jLong*") long[] dSscalarShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbScalar, @Cast("Nd4jLong*") long[] hSscalarShapeInfo, @Cast("Nd4jLong*") long[] dSscalarShapeInfo, Pointer extraParams); /** @@ -1485,27 +1587,21 @@ public native void execScalarBool(@Cast("Nd4jPointer*") PointerPointer extraPoin */ public native void execSummaryStatsScalar(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo, @Cast("bool") boolean biasCorrected); public native void execSummaryStatsScalar(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, @Cast("bool") boolean biasCorrected); public native void execSummaryStatsScalar(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo, @Cast("bool") boolean biasCorrected); /** * @@ -1518,27 +1614,21 @@ public native void execSummaryStatsScalar(@Cast("Nd4jPointer*") PointerPointer e */ public native void execSummaryStats(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo, @Cast("bool") boolean biasCorrected); public native void execSummaryStats(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, @Cast("bool") boolean biasCorrected); public native void execSummaryStats(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo, @Cast("bool") boolean biasCorrected); /** * @@ -1553,35 +1643,26 @@ public native void execSummaryStats(@Cast("Nd4jPointer*") PointerPointer extraPo */ public native void execSummaryStatsTad(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongPointer dDimensionShape, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, @Cast("Nd4jLong*") LongPointer dDimensionShape, @Cast("bool") boolean biasCorrected, @Cast("Nd4jLong*") LongPointer tadShapeInfo, @Cast("Nd4jLong*") LongPointer tadOffsets); public native void execSummaryStatsTad(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongBuffer dDimensionShape, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, @Cast("Nd4jLong*") LongBuffer dDimensionShape, @Cast("bool") boolean biasCorrected, @Cast("Nd4jLong*") LongBuffer tadShapeInfo, @Cast("Nd4jLong*") LongBuffer tadOffsets); public native void execSummaryStatsTad(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") long[] hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") long[] dDimensionShape, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") long[] hDimensionShape, @Cast("Nd4jLong*") long[] dDimensionShape, @Cast("bool") boolean biasCorrected, @Cast("Nd4jLong*") long[] tadShapeInfo, @Cast("Nd4jLong*") long[] tadOffsets); @@ -1597,112 +1678,82 @@ public native void execSummaryStatsTad(@Cast("Nd4jPointer*") PointerPointer extr */ public native void execTransformFloat(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo, Pointer extraParams); public native void execTransformFloat(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, Pointer extraParams); public native void execTransformFloat(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo, Pointer extraParams); public native void execTransformSame(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo, Pointer extraParams); public native void execTransformSame(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, Pointer extraParams); public native void execTransformSame(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo, Pointer extraParams); public native void execTransformBool(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo, Pointer extraParams); public native void execTransformBool(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, Pointer extraParams); public native void execTransformBool(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo, Pointer extraParams); public native void execTransformAny(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo, Pointer extraParams); public native void execTransformAny(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, Pointer extraParams); public native void execTransformAny(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo, Pointer extraParams); public native void execTransformStrict(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo, Pointer extraParams); public native void execTransformStrict(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, Pointer extraParams); public native void execTransformStrict(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo, Pointer extraParams); /** @@ -1720,81 +1771,57 @@ public native void execTransformStrict(@Cast("Nd4jPointer*") PointerPointer extr */ public native void execScalarTad(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, - Pointer hScalars, @Cast("Nd4jLong*") LongPointer hScalarShapeInfo, - Pointer dScalars, @Cast("Nd4jLong*") LongPointer dScalarShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbScalars, @Cast("Nd4jLong*") LongPointer hScalarShapeInfo, @Cast("Nd4jLong*") LongPointer dScalarShapeInfo, Pointer extraParams, - Pointer hDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongPointer dDimensionShape, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, @Cast("Nd4jLong*") LongPointer dDimensionShape, @Cast("Nd4jLong*") LongPointer tadShapeInfo, @Cast("Nd4jLong*") LongPointer tadOffsets, @Cast("Nd4jLong*") LongPointer tadShapeInfoZ, @Cast("Nd4jLong*") LongPointer tadOffsetsZ); public native void execScalarTad(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, - Pointer hScalars, @Cast("Nd4jLong*") LongBuffer hScalarShapeInfo, - Pointer dScalars, @Cast("Nd4jLong*") LongBuffer dScalarShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbScalars, @Cast("Nd4jLong*") LongBuffer hScalarShapeInfo, @Cast("Nd4jLong*") LongBuffer dScalarShapeInfo, Pointer extraParams, - Pointer hDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongBuffer dDimensionShape, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, @Cast("Nd4jLong*") LongBuffer dDimensionShape, @Cast("Nd4jLong*") LongBuffer tadShapeInfo, @Cast("Nd4jLong*") LongBuffer tadOffsets, @Cast("Nd4jLong*") LongBuffer tadShapeInfoZ, @Cast("Nd4jLong*") LongBuffer tadOffsetsZ); public native void execScalarTad(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, - Pointer hScalars, @Cast("Nd4jLong*") long[] hScalarShapeInfo, - Pointer dScalars, @Cast("Nd4jLong*") long[] dScalarShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbScalars, @Cast("Nd4jLong*") long[] hScalarShapeInfo, @Cast("Nd4jLong*") long[] dScalarShapeInfo, Pointer extraParams, - Pointer hDimension, @Cast("Nd4jLong*") long[] hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") long[] dDimensionShape, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") long[] hDimensionShape, @Cast("Nd4jLong*") long[] dDimensionShape, @Cast("Nd4jLong*") long[] tadShapeInfo, @Cast("Nd4jLong*") long[] tadOffsets, @Cast("Nd4jLong*") long[] tadShapeInfoZ, @Cast("Nd4jLong*") long[] tadOffsetsZ); public native void execScalarBoolTad(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, - Pointer hScalars, @Cast("Nd4jLong*") LongPointer hScalarShapeInfo, - Pointer dScalars, @Cast("Nd4jLong*") LongPointer dScalarShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbScalars, @Cast("Nd4jLong*") LongPointer hScalarShapeInfo, @Cast("Nd4jLong*") LongPointer dScalarShapeInfo, Pointer extraParams, - Pointer hDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongPointer dDimensionShape, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, @Cast("Nd4jLong*") LongPointer dDimensionShape, @Cast("Nd4jLong*") LongPointer tadShapeInfo, @Cast("Nd4jLong*") LongPointer tadOffsets, @Cast("Nd4jLong*") LongPointer tadShapeInfoZ, @Cast("Nd4jLong*") LongPointer tadOffsetsZ); public native void execScalarBoolTad(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, - Pointer hScalars, @Cast("Nd4jLong*") LongBuffer hScalarShapeInfo, - Pointer dScalars, @Cast("Nd4jLong*") LongBuffer dScalarShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbScalars, @Cast("Nd4jLong*") LongBuffer hScalarShapeInfo, @Cast("Nd4jLong*") LongBuffer dScalarShapeInfo, Pointer extraParams, - Pointer hDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongBuffer dDimensionShape, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, @Cast("Nd4jLong*") LongBuffer dDimensionShape, @Cast("Nd4jLong*") LongBuffer tadShapeInfo, @Cast("Nd4jLong*") LongBuffer tadOffsets, @Cast("Nd4jLong*") LongBuffer tadShapeInfoZ, @Cast("Nd4jLong*") LongBuffer tadOffsetsZ); public native void execScalarBoolTad(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, - Pointer hScalars, @Cast("Nd4jLong*") long[] hScalarShapeInfo, - Pointer dScalars, @Cast("Nd4jLong*") long[] dScalarShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbScalars, @Cast("Nd4jLong*") long[] hScalarShapeInfo, @Cast("Nd4jLong*") long[] dScalarShapeInfo, Pointer extraParams, - Pointer hDimension, @Cast("Nd4jLong*") long[] hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") long[] dDimensionShape, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") long[] hDimensionShape, @Cast("Nd4jLong*") long[] dDimensionShape, @Cast("Nd4jLong*") long[] tadShapeInfo, @Cast("Nd4jLong*") long[] tadOffsets, @Cast("Nd4jLong*") long[] tadShapeInfoZ, @Cast("Nd4jLong*") long[] tadOffsetsZ); @@ -2157,10 +2184,8 @@ public native void deleteTadPack(OpaqueTadPack ptr); * @param zTadOffsets */ public native void pullRows(@Cast("Nd4jPointer*") PointerPointer extraPointers, - Pointer x, @Cast("Nd4jLong*") LongPointer xShapeInfo, - Pointer dx, @Cast("Nd4jLong*") LongPointer dxShapeInfo, - Pointer z, @Cast("Nd4jLong*") LongPointer zShapeInfo, - Pointer dz, @Cast("Nd4jLong*") LongPointer dzShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer xShapeInfo, @Cast("Nd4jLong*") LongPointer dxShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer zShapeInfo, @Cast("Nd4jLong*") LongPointer dzShapeInfo, @Cast("Nd4jLong") long n, @Cast("Nd4jLong*") LongPointer indexes, @Cast("Nd4jLong*") LongPointer tadShapeInfo, @@ -2168,10 +2193,8 @@ public native void pullRows(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") LongPointer zTadShapeInfo, @Cast("Nd4jLong*") LongPointer zTadOffsets); public native void pullRows(@Cast("Nd4jPointer*") PointerPointer extraPointers, - Pointer x, @Cast("Nd4jLong*") LongBuffer xShapeInfo, - Pointer dx, @Cast("Nd4jLong*") LongBuffer dxShapeInfo, - Pointer z, @Cast("Nd4jLong*") LongBuffer zShapeInfo, - Pointer dz, @Cast("Nd4jLong*") LongBuffer dzShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer xShapeInfo, @Cast("Nd4jLong*") LongBuffer dxShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer zShapeInfo, @Cast("Nd4jLong*") LongBuffer dzShapeInfo, @Cast("Nd4jLong") long n, @Cast("Nd4jLong*") LongBuffer indexes, @Cast("Nd4jLong*") LongBuffer tadShapeInfo, @@ -2179,10 +2202,8 @@ public native void pullRows(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") LongBuffer zTadShapeInfo, @Cast("Nd4jLong*") LongBuffer zTadOffsets); public native void pullRows(@Cast("Nd4jPointer*") PointerPointer extraPointers, - Pointer x, @Cast("Nd4jLong*") long[] xShapeInfo, - Pointer dx, @Cast("Nd4jLong*") long[] dxShapeInfo, - Pointer z, @Cast("Nd4jLong*") long[] zShapeInfo, - Pointer dz, @Cast("Nd4jLong*") long[] dzShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] xShapeInfo, @Cast("Nd4jLong*") long[] dxShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] zShapeInfo, @Cast("Nd4jLong*") long[] dzShapeInfo, @Cast("Nd4jLong") long n, @Cast("Nd4jLong*") long[] indexes, @Cast("Nd4jLong*") long[] tadShapeInfo, @@ -2448,20 +2469,17 @@ public native void execAggregateBatch(@Cast("Nd4jPointer*") PointerPointer extra public native void execRandom(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, @Cast("Nd4jPointer") Pointer state, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeBuffer, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeBuffer, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeBuffer, @Cast("Nd4jLong*") LongPointer dZShapeBuffer, Pointer extraArguments); public native void execRandom(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, @Cast("Nd4jPointer") Pointer state, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeBuffer, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeBuffer, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeBuffer, @Cast("Nd4jLong*") LongBuffer dZShapeBuffer, Pointer extraArguments); public native void execRandom(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, @Cast("Nd4jPointer") Pointer state, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeBuffer, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeBuffer, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeBuffer, @Cast("Nd4jLong*") long[] dZShapeBuffer, Pointer extraArguments); /** @@ -2480,32 +2498,23 @@ public native void execRandom(@Cast("Nd4jPointer*") PointerPointer extraPointers public native void execRandom3(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, @Cast("Nd4jPointer") Pointer state, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeBuffer, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeBuffer, - Pointer hY, @Cast("Nd4jLong*") LongPointer hYShapeBuffer, - Pointer dY, @Cast("Nd4jLong*") LongPointer dYShapeBuffer, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeBuffer, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeBuffer, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeBuffer, @Cast("Nd4jLong*") LongPointer dXShapeBuffer, + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") LongPointer hYShapeBuffer, @Cast("Nd4jLong*") LongPointer dYShapeBuffer, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeBuffer, @Cast("Nd4jLong*") LongPointer dZShapeBuffer, Pointer extraArguments); public native void execRandom3(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, @Cast("Nd4jPointer") Pointer state, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeBuffer, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeBuffer, - Pointer hY, @Cast("Nd4jLong*") LongBuffer hYShapeBuffer, - Pointer dY, @Cast("Nd4jLong*") LongBuffer dYShapeBuffer, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeBuffer, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeBuffer, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeBuffer, @Cast("Nd4jLong*") LongBuffer dXShapeBuffer, + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") LongBuffer hYShapeBuffer, @Cast("Nd4jLong*") LongBuffer dYShapeBuffer, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeBuffer, @Cast("Nd4jLong*") LongBuffer dZShapeBuffer, Pointer extraArguments); public native void execRandom3(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, @Cast("Nd4jPointer") Pointer state, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeBuffer, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeBuffer, - Pointer hY, @Cast("Nd4jLong*") long[] hYShapeBuffer, - Pointer dY, @Cast("Nd4jLong*") long[] dYShapeBuffer, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeBuffer, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeBuffer, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeBuffer, @Cast("Nd4jLong*") long[] dXShapeBuffer, + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") long[] hYShapeBuffer, @Cast("Nd4jLong*") long[] dYShapeBuffer, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeBuffer, @Cast("Nd4jLong*") long[] dZShapeBuffer, Pointer extraArguments); /** @@ -2522,26 +2531,20 @@ public native void execRandom3(@Cast("Nd4jPointer*") PointerPointer extraPointer public native void execRandom2(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, @Cast("Nd4jPointer") Pointer state, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeBuffer, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeBuffer, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeBuffer, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeBuffer, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeBuffer, @Cast("Nd4jLong*") LongPointer dXShapeBuffer, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeBuffer, @Cast("Nd4jLong*") LongPointer dZShapeBuffer, Pointer extraArguments); public native void execRandom2(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, @Cast("Nd4jPointer") Pointer state, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeBuffer, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeBuffer, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeBuffer, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeBuffer, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeBuffer, @Cast("Nd4jLong*") LongBuffer dXShapeBuffer, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeBuffer, @Cast("Nd4jLong*") LongBuffer dZShapeBuffer, Pointer extraArguments); public native void execRandom2(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, @Cast("Nd4jPointer") Pointer state, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeBuffer, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeBuffer, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeBuffer, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeBuffer, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeBuffer, @Cast("Nd4jLong*") long[] dXShapeBuffer, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeBuffer, @Cast("Nd4jLong*") long[] dZShapeBuffer, Pointer extraArguments); @@ -2584,52 +2587,6 @@ public native void reSeedBuffer(@Cast("Nd4jPointer*") PointerPointer extraPointe */ public native void destroyRandom(@Cast("Nd4jPointer") Pointer ptrRandom); -/** - * Grid operations - */ - - - - -/** - * - * @param extras - * @param opTypeA - * @param opNumA - * @param opTypeB - * @param opNumB - * @param N - * @param dx - * @param xShapeInfo - * @param dy - * @param yShapeInfo - * @param dz - * @param zShapeInfo - * @param extraA - * @param extraB - * @param scalarA - * @param scalarB - */ - /* -ND4J_EXPORT void execMetaPredicateShape(Nd4jPointer *extras, - const int opTypeA, - const int opNumA, - const int opTypeB, - const int opNumB, - Nd4jLong N, - void *hX, Nd4jLong *hXShapeBuffer, - void *dX, Nd4jLong *dXShapeBuffer, - void *hY, Nd4jLong *hYShapeBuffer, - void *dY, Nd4jLong *dYShapeBuffer, - void *hZ, Nd4jLong *hZShapeBuffer, - void *dZ, Nd4jLong *dZShapeBuffer, - void *extraA, - void *extraB, - double scalarA, - double scalarB); - -*/ - /** * * @param data @@ -2792,23 +2749,20 @@ public native @Cast("Nd4jPointer") Pointer pointerForAddress(@Cast("Nd4jLong") l * @return */ public native void tear(@Cast("Nd4jPointer*") PointerPointer extraPointers, - Pointer x, @Cast("Nd4jLong*") LongPointer xShapeInfo, - Pointer dx, @Cast("Nd4jLong*") LongPointer dxShapeInfo, - @Cast("Nd4jPointer*") PointerPointer targets, @Cast("Nd4jLong*") LongPointer zShapeInfo, - @Cast("Nd4jLong*") LongPointer tadShapeInfo, - @Cast("Nd4jLong*") LongPointer tadOffsets); + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer xShapeInfo, @Cast("Nd4jLong*") LongPointer dxShapeInfo, + @Cast("Nd4jPointer*") PointerPointer targets, @Cast("Nd4jLong*") LongPointer zShapeInfo, + @Cast("Nd4jLong*") LongPointer tadShapeInfo, + @Cast("Nd4jLong*") LongPointer tadOffsets); public native void tear(@Cast("Nd4jPointer*") PointerPointer extraPointers, - Pointer x, @Cast("Nd4jLong*") LongBuffer xShapeInfo, - Pointer dx, @Cast("Nd4jLong*") LongBuffer dxShapeInfo, - @Cast("Nd4jPointer*") PointerPointer targets, @Cast("Nd4jLong*") LongBuffer zShapeInfo, - @Cast("Nd4jLong*") LongBuffer tadShapeInfo, - @Cast("Nd4jLong*") LongBuffer tadOffsets); + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer xShapeInfo, @Cast("Nd4jLong*") LongBuffer dxShapeInfo, + @Cast("Nd4jPointer*") PointerPointer targets, @Cast("Nd4jLong*") LongBuffer zShapeInfo, + @Cast("Nd4jLong*") LongBuffer tadShapeInfo, + @Cast("Nd4jLong*") LongBuffer tadOffsets); public native void tear(@Cast("Nd4jPointer*") PointerPointer extraPointers, - Pointer x, @Cast("Nd4jLong*") long[] xShapeInfo, - Pointer dx, @Cast("Nd4jLong*") long[] dxShapeInfo, - @Cast("Nd4jPointer*") PointerPointer targets, @Cast("Nd4jLong*") long[] zShapeInfo, - @Cast("Nd4jLong*") long[] tadShapeInfo, - @Cast("Nd4jLong*") long[] tadOffsets); + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] xShapeInfo, @Cast("Nd4jLong*") long[] dxShapeInfo, + @Cast("Nd4jPointer*") PointerPointer targets, @Cast("Nd4jLong*") long[] zShapeInfo, + @Cast("Nd4jLong*") long[] tadShapeInfo, + @Cast("Nd4jLong*") long[] tadOffsets); public native @Cast("Nd4jLong") long encodeBitmap(@Cast("Nd4jPointer*") PointerPointer extraPointers, Pointer dx, @Cast("Nd4jLong*") LongPointer xShapeInfo, @Cast("Nd4jLong") long N, IntPointer dz, float threshold); public native @Cast("Nd4jLong") long encodeBitmap(@Cast("Nd4jPointer*") PointerPointer extraPointers, Pointer dx, @Cast("Nd4jLong*") LongBuffer xShapeInfo, @Cast("Nd4jLong") long N, IntBuffer dz, float threshold); @@ -3100,10 +3054,13 @@ public native void deleteShapeBuffer(OpaqueConstantDataBuffer ptr); public native OpaqueContext createGraphContext(int nodeId); public native OpaqueRandomGenerator getGraphContextRandomGenerator(OpaqueContext ptr); public native void ctxAllowHelpers(OpaqueContext ptr, @Cast("bool") boolean reallyAllow); +public native void ctxShapeFunctionOverride(OpaqueContext ptr, @Cast("bool") boolean reallyOverride); public native void markGraphContextInplace(OpaqueContext ptr, @Cast("bool") boolean reallyInplace); public native void setGraphContextCudaContext(OpaqueContext ptr, Pointer stream, Pointer reductionPointer, Pointer allocationPointer); public native void setGraphContextInputArray(OpaqueContext ptr, int index, Pointer buffer, Pointer shapeInfo, Pointer specialBuffer, Pointer specialShapeInfo); public native void setGraphContextOutputArray(OpaqueContext ptr, int index, Pointer buffer, Pointer shapeInfo, Pointer specialBuffer, Pointer specialShapeInfo); +public native void setGraphContextInputBuffer(OpaqueContext ptr, int index, OpaqueDataBuffer buffer, Pointer shapeInfo, Pointer specialShapeInfo); +public native void setGraphContextOutputBuffer(OpaqueContext ptr, int index, OpaqueDataBuffer buffer, Pointer shapeInfo, Pointer specialShapeInfo); public native void setGraphContextTArguments(OpaqueContext ptr, DoublePointer arguments, int numberOfArguments); public native void setGraphContextTArguments(OpaqueContext ptr, DoubleBuffer arguments, int numberOfArguments); public native void setGraphContextTArguments(OpaqueContext ptr, double[] arguments, int numberOfArguments); @@ -3136,6 +3093,28 @@ public native @Cast("Nd4jPointer") Pointer lcCopyStream(OpaqueLaunchContext lc); public native @Cast("Nd4jPointer") Pointer lcBlasHandle(OpaqueLaunchContext lc); public native @Cast("Nd4jPointer") Pointer lcSolverHandle(OpaqueLaunchContext lc); +public native OpaqueDataBuffer allocateDataBuffer(@Cast("Nd4jLong") long elements, int dataType, @Cast("bool") boolean allocateBoth); +public native OpaqueDataBuffer dbCreateView(OpaqueDataBuffer dataBuffer, @Cast("Nd4jLong") long length, @Cast("Nd4jLong") long offset); +public native @Cast("Nd4jPointer") Pointer dbPrimaryBuffer(OpaqueDataBuffer dataBuffer); +public native @Cast("Nd4jPointer") Pointer dbSpecialBuffer(OpaqueDataBuffer dataBuffer); +public native void dbExpandBuffer(OpaqueDataBuffer dataBuffer, @Cast("Nd4jLong") long elements); +public native void dbAllocatePrimaryBuffer(OpaqueDataBuffer dataBuffer); +public native void dbAllocateSpecialBuffer(OpaqueDataBuffer dataBuffer); +public native void dbSetPrimaryBuffer(OpaqueDataBuffer dataBuffer, @Cast("Nd4jPointer") Pointer primaryBuffer, @Cast("Nd4jLong") long numBytes); +public native void dbSetSpecialBuffer(OpaqueDataBuffer dataBuffer, @Cast("Nd4jPointer") Pointer specialBuffer, @Cast("Nd4jLong") long numBytes); +public native void dbSyncToSpecial(OpaqueDataBuffer dataBuffer); +public native void dbSyncToPrimary(OpaqueDataBuffer dataBuffer); +public native int dbLocality(OpaqueDataBuffer dataBuffer); +public native int dbDeviceId(OpaqueDataBuffer dataBuffer); +public native void dbSetDeviceId(OpaqueDataBuffer dataBuffer, int deviceId); +public native void dbTickHostRead(OpaqueDataBuffer dataBuffer); +public native void dbTickHostWrite(OpaqueDataBuffer dataBuffer); +public native void dbTickDeviceRead(OpaqueDataBuffer dataBuffer); +public native void dbTickDeviceWrite(OpaqueDataBuffer dataBuffer); +public native void dbClose(OpaqueDataBuffer dataBuffer); +public native void deleteDataBuffer(OpaqueDataBuffer dataBuffer); +public native void dbExpand(OpaqueDataBuffer dataBuffer, @Cast("Nd4jLong") long elements); + public native int binaryLevel(); public native int optimalLevel(); @@ -3632,27 +3611,12 @@ public native @Cast("bool") boolean isOptimalRequirementsMet(); // #include // #include // #include +// #include +// #include +// #include - @Namespace("nd4j") public static native @ByVal @Name("operator -") NDArray subtract(float arg0, @Const @ByRef NDArray arg1); - @Namespace("nd4j") public static native @ByVal @Name("operator -") NDArray subtract(@Cast("const float16") short arg0, @Const @ByRef NDArray arg1); - @Namespace("nd4j") public static native @ByVal @Name("operator -") NDArray subtract(double arg0, @Const @ByRef NDArray arg1); - @Namespace("nd4j") public static native @ByVal @Name("operator -") NDArray subtract(int arg0, @Const @ByRef NDArray arg1); - @Namespace("nd4j") public static native @ByVal @Name("operator +") NDArray add(float arg0, @Const @ByRef NDArray arg1); - @Namespace("nd4j") public static native @ByVal @Name("operator +") NDArray add(@Cast("const float16") short arg0, @Const @ByRef NDArray arg1); - @Namespace("nd4j") public static native @ByVal @Name("operator +") NDArray add(double arg0, @Const @ByRef NDArray arg1); - @Namespace("nd4j") public static native @ByVal @Name("operator +") NDArray add(int arg0, @Const @ByRef NDArray arg1); - - @Namespace("nd4j") public static native @ByVal @Name("operator *") NDArray multiply(float arg0, @Const @ByRef NDArray arg1); - @Namespace("nd4j") public static native @ByVal @Name("operator *") NDArray multiply(@Cast("const float16") short arg0, @Const @ByRef NDArray arg1); - @Namespace("nd4j") public static native @ByVal @Name("operator *") NDArray multiply(double arg0, @Const @ByRef NDArray arg1); - @Namespace("nd4j") public static native @ByVal @Name("operator *") NDArray multiply(int arg0, @Const @ByRef NDArray arg1); - - @Namespace("nd4j") public static native @ByVal @Name("operator /") NDArray divide(float arg0, @Const @ByRef NDArray arg1); - @Namespace("nd4j") public static native @ByVal @Name("operator /") NDArray divide(@Cast("const float16") short arg0, @Const @ByRef NDArray arg1); - @Namespace("nd4j") public static native @ByVal @Name("operator /") NDArray divide(double arg0, @Const @ByRef NDArray arg1); - @Namespace("nd4j") public static native @ByVal @Name("operator /") NDArray divide(int arg0, @Const @ByRef NDArray arg1); @Namespace("nd4j") public static native @ByVal NDArray mmul(@Const @ByRef NDArray arg0, @Const @ByRef NDArray arg1); @@ -3861,10 +3825,13 @@ public native @Cast("bool") boolean isOptimalRequirementsMet(); * @param writeList * @param readList */ - // TODO: it would be nice to have NDArray::registerSpecialUse signature that accepts something else beyond initializer_list - - // TODO: it would be nice to have NDArray::registerSpecialUse signature that accepts something else beyond initializer_list + public static native void registerSpecialUse(@Const @ByRef ConstNDArrayVector writeList, @Const @ByRef ConstNDArrayVector readList); + public static native void prepareSpecialUse(@Const @ByRef ConstNDArrayVector writeList, @Const @ByRef ConstNDArrayVector readList, @Cast("bool") boolean synchronizeWritables/*=false*/); + public static native void prepareSpecialUse(@Const @ByRef ConstNDArrayVector writeList, @Const @ByRef ConstNDArrayVector readList); + public static native void registerPrimaryUse(@Const @ByRef ConstNDArrayVector writeList, @Const @ByRef ConstNDArrayVector readList); + public static native void preparePrimaryUse(@Const @ByRef ConstNDArrayVector writeList, @Const @ByRef ConstNDArrayVector readList, @Cast("bool") boolean synchronizeWritables/*=false*/); + public static native void preparePrimaryUse(@Const @ByRef ConstNDArrayVector writeList, @Const @ByRef ConstNDArrayVector readList); /** * This method returns buffer pointer offset by given number of elements, wrt own data type @@ -3903,9 +3870,9 @@ public native @Cast("bool") boolean isOptimalRequirementsMet(); * axis - axis along which to repeat elements * repeats - number of repetitions */ - public native NDArray repeat(int axis, @StdVector IntPointer repeats); - public native NDArray repeat(int axis, @StdVector IntBuffer repeats); - public native NDArray repeat(int axis, @StdVector int[] repeats); + public native @ByVal NDArray repeat(int axis, @StdVector IntPointer repeats); + public native @ByVal NDArray repeat(int axis, @StdVector IntBuffer repeats); + public native @ByVal NDArray repeat(int axis, @StdVector int[] repeats); /** * This method fills this array with zeros @@ -3918,14 +3885,7 @@ public native @Cast("bool") boolean isOptimalRequirementsMet(); * @param array * @return */ - public static native @ByVal NDArray quantize(@ByRef NDArray array); - - /** - * This method returns quantized copy of given array - * - * @param array - * @return - */ + public static native @ByVal NDArray quantize(@Const @ByRef NDArray array); /** * fill target array by repeating current array @@ -3946,10 +3906,9 @@ public native @Cast("bool") boolean isOptimalRequirementsMet(); /** * cast array elements to given dtype */ + public native @ByVal NDArray cast(@Cast("nd4j::DataType") int dtype); - public native NDArray cast(@Cast("nd4j::DataType") int dtype); - - public native void cast(NDArray target, @Cast("nd4j::DataType") int dtype); + public native void cast(@ByRef NDArray target, @Cast("nd4j::DataType") int dtype); /** * returns _context @@ -4120,26 +4079,12 @@ public native @Cast("bool") boolean isOptimalRequirementsMet(); /** * this method assigns given value to all elements in array */ - public native void assign(double value, @Cast("bool") boolean allowParallelism/*=true*/); - public native void assign(double value); - public native void assign(float value, @Cast("bool") boolean allowParallelism/*=true*/); - public native void assign(float value); - public native void assign(@Cast("const float16") short value, @Cast("bool") boolean allowParallelism/*=true*/); - public native void assign(@Cast("const float16") short value); - public native void assign(@Cast("const Nd4jLong") long value, @Cast("bool") boolean allowParallelism/*=true*/); - public native void assign(@Cast("const Nd4jLong") long value); - public native void assign(int value, @Cast("bool") boolean allowParallelism/*=true*/); - public native void assign(int value); - public native void assign(@Cast("const uint8_t") byte value, @Cast("bool") boolean allowParallelism/*=true*/); - public native void assign(@Cast("const uint8_t") byte value); - public native void assign(@Cast("const bool") boolean value, @Cast("bool") boolean allowParallelism/*=true*/); - public native void assign(@Cast("const bool") boolean value); /** * returns new copy of this array, optionally in different order */ - public native NDArray dup(byte newOrder/*='a'*/); - public native NDArray dup(); + public native @ByVal NDArray dup(byte newOrder/*='a'*/); + public native @ByVal NDArray dup(); /** * returns sum of all elements of array @@ -4176,9 +4121,9 @@ public native @Cast("bool") boolean isOptimalRequirementsMet(); * index - the number of array to be returned among set of possible arrays * dimensions - array of dimensions to point on */ - public native NDArray tensorAlongDimension(@Cast("Nd4jLong") long index, @StdVector IntPointer dimensions); - public native NDArray tensorAlongDimension(@Cast("Nd4jLong") long index, @StdVector IntBuffer dimensions); - public native NDArray tensorAlongDimension(@Cast("Nd4jLong") long index, @StdVector int[] dimensions); + public native @ByVal NDArray tensorAlongDimension(@Cast("Nd4jLong") long index, @StdVector IntPointer dimensions); + public native @ByVal NDArray tensorAlongDimension(@Cast("Nd4jLong") long index, @StdVector IntBuffer dimensions); + public native @ByVal NDArray tensorAlongDimension(@Cast("Nd4jLong") long index, @StdVector int[] dimensions); /** * returns the number of arrays pointing on specified dimension(s) @@ -4200,54 +4145,54 @@ public native @Cast("bool") boolean isOptimalRequirementsMet(); * add given row vector to all rows of this array * row - row vector to add */ - public native void addiRowVector(@Const NDArray row); + public native void addiRowVector(@Const @ByRef NDArray row); /** * add given row vector to all rows of this array, store result in target * row - row vector to add * target - where to store result */ - public native void addRowVector(@Const NDArray row, NDArray target); + public native void addRowVector(@Const @ByRef NDArray row, @ByRef NDArray target); /** * subtract given row vector from all rows of this array, store result in target * row - row vector to subtract * target - where to store result */ - public native void subRowVector(@Const NDArray row, NDArray target); + public native void subRowVector(@Const @ByRef NDArray row, @ByRef NDArray target); /** * multiply all rows of this array on given row vector, store result in target * row - row vector to multiply on * target - where to store result */ - public native void mulRowVector(@Const NDArray row, NDArray target); + public native void mulRowVector(@Const @ByRef NDArray row, @ByRef NDArray target); /** * divide all rows of this array on given row vector, store result in target * row - row vector to divide on * target - where to store result */ - public native void divRowVector(@Const NDArray row, NDArray target); + public native void divRowVector(@Const @ByRef NDArray row, @ByRef NDArray target); /** * add given column vector to all columns of this array, store result in target * column - column vector to add * target - where to store result */ - public native void addColumnVector(@Const NDArray column, NDArray target); + public native void addColumnVector(@Const @ByRef NDArray column, @ByRef NDArray target); /** * add given column vector to all columns of this array, this array becomes affected (in-place operation) * column - column vector to add */ - public native void addiColumnVector(@Const NDArray column); + public native void addiColumnVector(@Const @ByRef NDArray column); /** * multiply all columns of this array on given column vector, this array becomes affected (in-place operation) * column - column vector to multiply on */ - public native void muliColumnVector(@Const NDArray column); + public native void muliColumnVector(@Const @ByRef NDArray column); /** * returns number of bytes used by _buffer & _shapeInfo @@ -4258,6 +4203,7 @@ public native @Cast("bool") boolean isOptimalRequirementsMet(); * these methods suited for FlatBuffers use */ public native @Cast("Nd4jLong*") @StdVector LongPointer getShapeAsVector(); + public native @StdVector IntPointer getShapeAsVectorInt(); public native @Cast("Nd4jLong*") @StdVector LongPointer getShapeInfoAsVector(); public native @Cast("int64_t*") @StdVector LongPointer getShapeInfoAsFlatVector(); public native @Cast("int64_t*") @StdVector LongPointer getShapeAsFlatVector(); @@ -4283,9 +4229,9 @@ public native @Cast("bool") boolean isOptimalRequirementsMet(); * * if permute have been applied before or there are weird strides, then new buffer is allocated for new array */ - public native @ByVal NDArray reshape(byte order, @Cast("Nd4jLong*") @StdVector LongPointer shape); - public native @ByVal NDArray reshape(byte order, @Cast("Nd4jLong*") @StdVector LongBuffer shape); - public native @ByVal NDArray reshape(byte order, @Cast("Nd4jLong*") @StdVector long[] shape); + public native @ByVal NDArray reshape(byte order, @Cast("Nd4jLong*") @StdVector LongPointer shape); + public native @ByVal NDArray reshape(byte order, @Cast("Nd4jLong*") @StdVector LongBuffer shape); + public native @ByVal NDArray reshape(byte order, @Cast("Nd4jLong*") @StdVector long[] shape); /** * calculate strides and set given order @@ -4324,12 +4270,6 @@ public native @Cast("bool") boolean isOptimalRequirementsMet(); */ public native void tile(@ByRef NDArray target); - /** - * returns an array which is result of broadcasting of this and other arrays - * other - input array - */ - public native NDArray broadcast(@Const @ByRef NDArray other); - /** * check whether array is identity matrix */ @@ -4340,7 +4280,6 @@ public native @Cast("bool") boolean isOptimalRequirementsMet(); */ public native @Cast("bool") boolean isUnitary(); - /** * operator returns subarray with buffer pointing at this->_buffer with offset defined by given intervals * idx - intervals of indexes which define the subarrays to point on, idx has form {dim0Start,dim0End, dim1Start,dim1End, ....} and length (2 * this->rankOf()) @@ -4386,25 +4325,6 @@ public native @Cast("bool") boolean isOptimalRequirementsMet(); public native void getSubArrShapeAndOffsets(@StdVector int[] dimsToExclude, @Cast("Nd4jLong*&") @ByPtrRef long[] subArrShapeInfo, @Cast("Nd4jLong*&") @ByPtrRef long[] subArrOffsets, @Cast("bool") boolean keepUnitiesInShape/*=false*/); public native void getSubArrShapeAndOffsets(@StdVector int[] dimsToExclude, @Cast("Nd4jLong*&") @ByPtrRef long[] subArrShapeInfo, @Cast("Nd4jLong*&") @ByPtrRef long[] subArrOffsets); - /** - * addition operator: array + other - * other - input array to add - */ - public native @ByVal @Name("operator +") NDArray add(@Const @ByRef NDArray other); - - /** - * addition operator: array + scalar - * scalar - input scalar to add - */ - - /** - * friend functions which implement addition operator: scalar + array - * scalar - input scalar to add - */ - //template - //friend NDArray nd4j::operator+(const T scalar, const NDArray& arr); - - /** * addition unary operator array += other * other - input array to add @@ -4417,39 +4337,11 @@ public native @Cast("bool") boolean isOptimalRequirementsMet(); */ public native @Name("operator -=") void subtractPut(@Const @ByRef NDArray other); - /** - * subtraction operator: array - other - * other - input array to subtract - */ - public native @ByVal @Name("operator -") NDArray subtract(@Const @ByRef NDArray other); - - /** - * subtraction operator: array - scalar - * scalar - input scalar to subtract - */ - /** * negative operator, it changes sign of all array elements on opposite */ public native @ByVal @Name("operator -") NDArray subtract(); - /** - * friend functions which implement subtraction operator: scalar - array - * scalar - input scalar to subtract - */ - //friend NDArray nd4j::operator-(const float scalar, const NDArray& arr); - - /** - * pairwise multiplication operator: array * other - * other - input array to multiply on - */ - public native @ByVal @Name("operator *") NDArray multiply(@Const @ByRef NDArray other); - - /** - * multiplication operator: array * scalar - * scalar - input scalar to multiply on - */ - /** * pairwise multiplication unary operator array *= other * other - input array to multiply on @@ -4461,17 +4353,6 @@ public native @Cast("bool") boolean isOptimalRequirementsMet(); * scalar - input scalar to multiply on */ - /** - * pairwise division operator: array / other - * other - input array to divide on - */ - public native @ByVal @Name("operator /") NDArray divide(@Const @ByRef NDArray other); - - /** - * division operator: array / scalar - * scalar - input scalar to divide each array element on - */ - /** * pairwise division unary operator: array /= other * other - input array to divide on @@ -4510,7 +4391,7 @@ public native @Cast("bool") boolean isOptimalRequirementsMet(); * return vector with buffer which points on corresponding diagonal elements of array * type - means of vector to be returned: column ('c') or row ('r') */ - public native NDArray diagonal(byte type ); + public native @ByVal NDArray diagonal(byte type ); /** * fill target matrix with given value in one or two directions from main diagonal: @@ -4533,13 +4414,13 @@ public native @Cast("bool") boolean isOptimalRequirementsMet(); public native @ByVal NDArray tileToShape(@Cast("const Nd4jLong*") LongPointer shapeInfo); public native @ByVal NDArray tileToShape(@Cast("const Nd4jLong*") LongBuffer shapeInfo); public native @ByVal NDArray tileToShape(@Cast("const Nd4jLong*") long[] shapeInfo); - public native void tileToShape(@Cast("Nd4jLong*") @StdVector LongPointer shape, NDArray target/*=nullptr*/); - public native void tileToShape(@Cast("Nd4jLong*") @StdVector LongBuffer shape, NDArray target/*=nullptr*/); - public native void tileToShape(@Cast("Nd4jLong*") @StdVector long[] shape, NDArray target/*=nullptr*/); + public native void tileToShape(@Cast("Nd4jLong*") @StdVector LongPointer shape, @ByRef NDArray target); + public native void tileToShape(@Cast("Nd4jLong*") @StdVector LongBuffer shape, @ByRef NDArray target); + public native void tileToShape(@Cast("Nd4jLong*") @StdVector long[] shape, @ByRef NDArray target); // #ifndef __JAVACPP_HACK__ // #endif - public native NDArray asT(@Cast("nd4j::DataType") int dtype); + public native @ByVal NDArray asT(@Cast("nd4j::DataType") int dtype); public native void linspace(double start); @@ -4551,17 +4432,15 @@ public native @Cast("bool") boolean isOptimalRequirementsMet(); */ public native double getTrace(); - public native ResultSet multipleTensorsAlongDimension(@StdVector IntPointer indices, @StdVector IntPointer dimensions); - public native ResultSet multipleTensorsAlongDimension(@StdVector IntBuffer indices, @StdVector IntBuffer dimensions); - public native ResultSet multipleTensorsAlongDimension(@StdVector int[] indices, @StdVector int[] dimensions); + public native @ByVal ResultSet multipleTensorsAlongDimension(@StdVector IntPointer indices, @StdVector IntPointer dimensions); + public native @ByVal ResultSet multipleTensorsAlongDimension(@StdVector IntBuffer indices, @StdVector IntBuffer dimensions); + public native @ByVal ResultSet multipleTensorsAlongDimension(@StdVector int[] indices, @StdVector int[] dimensions); - public native ResultSet allTensorsAlongDimension(@StdVector IntPointer dimensions); - public native ResultSet allTensorsAlongDimension(@StdVector IntBuffer dimensions); - public native ResultSet allTensorsAlongDimension(@StdVector int[] dimensions); + public native @ByVal ResultSet allTensorsAlongDimension(@StdVector IntPointer dimensions); + public native @ByVal ResultSet allTensorsAlongDimension(@StdVector IntBuffer dimensions); + public native @ByVal ResultSet allTensorsAlongDimension(@StdVector int[] dimensions); - //ResultSet allTensorsAlongDims(const std::vector& dimensions) const; - - public native ResultSet allExamples(); + public native @ByVal ResultSet allExamples(); /** * set _shapeInfo @@ -4669,7 +4548,7 @@ public native @Cast("bool") boolean isOptimalRequirementsMet(); /** * returns true if these two NDArrays have same rank, dimensions, strides, ews and order */ - public native @Cast("bool") boolean isSameShapeStrict(@Const NDArray other); + public native @Cast("bool") boolean isSameShapeStrict(@Const @ByRef NDArray other); /** * returns true if buffer && shapeInfo were defined (non nullptr) @@ -4728,11 +4607,6 @@ public native @Cast("bool") boolean isOptimalRequirementsMet(); */ public native void p(@Cast("const Nd4jLong") long i, @Cast("const Nd4jLong") long j, @Cast("const Nd4jLong") long k, @Cast("const Nd4jLong") long l, @Const @ByRef NDArray value); - /** - * creates array which points on certain sub-range of this array, sub-range is defined by given indices - */ - - /** * returns true if array is 2D */ @@ -4803,59 +4677,6 @@ public native @Cast("bool") boolean isOptimalRequirementsMet(); */ public native @Cast("bool") boolean isS(); - /** - * inline accessing operator for matrix, i - absolute index - */ - //FORCEINLINE NDArray operator()(const Nd4jLong i) const; - - /** - * inline modifying operator for matrix, i - absolute index - */ - //FORCEINLINE NDArray& operator()(const Nd4jLong i); - - /** - * inline accessing operator for 2D array, i - row, j - column - */ - //FORCEINLINE NDArray operator()(const Nd4jLong i, const Nd4jLong j) const; - - /** - * inline modifying operator for 2D array, i - row, j - column - */ - //FORCEINLINE NDArray& operator()(const Nd4jLong i, const Nd4jLong j); - - /** - * inline accessing operator for 3D array, i - height, j - width, k - depth - */ - //FORCEINLINE NDArray operator()(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k) const; - - /** - * inline modifying operator for 3D array, i - height, j - width, k - depth - */ - //FORCEINLINE NDArray& operator()(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k); - - /** - * inline modifying operator for 4D array, i - height, j - width, k - depth - */ - //FORCEINLINE NDArray& operator()(const Nd4jLong t, const Nd4jLong u, const Nd4jLong v, const Nd4jLong w); - - /** - * inline accessing operator for 4D array, i - height, j - width, k - depth - */ - //FORCEINLINE NDArray operator()(const Nd4jLong t, const Nd4jLong u, const Nd4jLong v, const Nd4jLong w) const; - - /** - * inline modifying operator for ND array - * idx - array with corresponding indexes, for example {2,10,0,5,...,8}, number of indexes should be equal to array rank - */ - //FORCEINLINE NDArray& operator()(const Nd4jLong* idx); - - /** - * inline accessing operator for ND array - * idx - array with corresponding indexes, for example {2,10,0,5,...,8}, number of indexes should be equal to array rank - */ - //FORCEINLINE NDArray operator()(const Nd4jLong* idx) const; - - public native @Cast("bool") boolean isAttached(); public native NDArray detach(); @@ -4871,268 +4692,75 @@ public native @Cast("bool") boolean isOptimalRequirementsMet(); ////////////////////////////////////////////////////////////////////////// ///// IMLEMENTATION OF INLINE METHODS ///// ////////////////////////////////////////////////////////////////////////// - - - ////////////////////////////////////////////////////////////////////////// - - ////////////////////////////////////////////////////////////////////////// - - - ////////////////////////////////////////////////////////////////////////// - - - ////////////////////////////////////////////////////////////////////////// - - - ////////////////////////////////////////////////////////////////////////// - - - ////////////////////////////////////////////////////////////////////////// - - - ////////////////////////////////////////////////////////////////////////// - - - ////////////////////////////////////////////////////////////////////////// - - - ////////////////////////////////////////////////////////////////////////// - - - ////////////////////////////////////////////////////////////////////////// - - - ////////////////////////////////////////////////////////////////////////// - - - - ////////////////////////////////////////////////////////////////////////// - - - ////////////////////////////////////////////////////////////////////////// - - - ////////////////////////////////////////////////////////////////////////// - - - ////////////////////////////////////////////////////////////////////////// - - - ////////////////////////////////////////////////////////////////////////// - - - ////////////////////////////////////////////////////////////////////////// - - - ////////////////////////////////////////////////////////////////////////// - - - ////////////////////////////////////////////////////////////////////////// - ////////////////////////////////////////////////////////////////////////// -// accessing operator for matrix, i - absolute index -/* -NDArray NDArray::operator()(const Nd4jLong i) const { - if (i >= shape::length(_shapeInfo)) - throw std::invalid_argument("NDArray::operator(i): input index is out of array length !"); - - auto ews = shape::elementWiseStride(_shapeInfo); - char order = ordering(); - - if(ews == 1 && order == 'c') { - auto cast = reinterpret_cast(_buffer) + (i * this->sizeOfT()); - NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace())); - return result; - } else if(ews > 1 && order == 'c') { - auto cast = reinterpret_cast(_buffer) + (i * ews * this->sizeOfT()); - NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace())); - return result; - } else { - Nd4jLong idx[MAX_RANK]; - shape::ind2subC(rankOf(), shapeOf(), i, idx); - auto xOffset = shape::getOffset(getShapeInfo(), idx); - - auto cast = reinterpret_cast(_buffer) + (xOffset * this->sizeOfT()); - NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace())); - return result; - } -} -*/ -////////////////////////////////////////////////////////////////////////// -// modifying operator for matrix, i - absolute index -/* -NDArray& NDArray::operator()(const Nd4jLong i) { - if (i >= shape::length(_shapeInfo)) - throw std::invalid_argument("NDArray::operator(i): input index is out of array length !"); - - auto ews = shape::elementWiseStride(_shapeInfo); - auto order = ordering(); - - if(ews == 1 && order == 'c') { - auto cast = reinterpret_cast(_buffer) + (i * this->sizeOfT()); - NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace())); - - // FIXME: bad - return result; - } else if(ews > 1 && order == 'c') { - auto cast = reinterpret_cast(_buffer) + (i * ews * this->sizeOfT()); - NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace())); - return result; - } else { - Nd4jLong idx[MAX_RANK]; - shape::ind2subC(rankOf(), shapeOf(), i, idx); - auto xOffset = shape::getOffset(getShapeInfo(), idx); - - auto cast = reinterpret_cast(_buffer) + (xOffset * this->sizeOfT()); - NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace())); - return result; - } -}*/ ////////////////////////////////////////////////////////////////////////// -// accessing operator for 2D matrix, i - row, j - column -/* -NDArray NDArray::operator()(const Nd4jLong i, const Nd4jLong j) const { - if (rankOf() != 2 || i >= shapeOf()[0] || j >= shapeOf()[1]) - throw std::invalid_argument("NDArray::operator(i,j): one of input indexes is out of array length or rank!=2 !"); - - Nd4jLong coords[2] = {i, j}; - auto xOffset = shape::getOffset(getShapeInfo(), coords); - - // TODO: do we really want a view here? - auto cast = reinterpret_cast(_buffer) + (xOffset * this->sizeOfT()); - NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace())); - return result; -} -*/ -////////////////////////////////////////////////////////////////////////// -// modifying operator for 2D matrix, i - row, j - column -/* -NDArray& NDArray::operator()(const Nd4jLong i, const Nd4jLong j) { - if (rankOf() != 2 || i >= shapeOf()[0] || j >= shapeOf()[1]) - throw std::invalid_argument("NDArray::operator(i,j): one of input indexes is out of array length or rank!=2 !"); - - Nd4jLong coords[2] = {i, j}; - auto xOffset = shape::getOffset(getShapeInfo(), coords); - - auto cast = reinterpret_cast(_buffer) + (xOffset * this->sizeOfT()); - NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace())); - - //FIXME: bad, will crash! - return result; -} -*/ ////////////////////////////////////////////////////////////////////////// -// accessing operator for 3D array, i - row, j - column -/* -NDArray NDArray::operator()(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k) const { - if (rankOf() != 3 || i >= shapeOf()[0] || j >= shapeOf()[1] || j >= shapeOf()[2]) - throw std::invalid_argument("NDArray::operator(i,j,k): one of input indexes is out of array length or rank!=3 !"); - - Nd4jLong coords[3] = {i, j, k}; - auto xOffset = shape::getOffset(getShapeInfo(), coords); - - auto cast = reinterpret_cast(_buffer) + (xOffset * this->sizeOfT()); - NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace())); - return result; -} -*/ ////////////////////////////////////////////////////////////////////////// -// modifying operator for 3D array -/* -NDArray& NDArray::operator()(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k) { - if (rankOf() != 3 || i >= shapeOf()[0] || j >= shapeOf()[1] || k >= shapeOf()[2]) - throw std::invalid_argument("NDArray::operator(i,j,k): one of input indexes is out of array length or rank!=3 !"); - Nd4jLong coords[3] = {i, j, k}; - auto xOffset = shape::getOffset(getShapeInfo(), coords); - - auto cast = reinterpret_cast(_buffer) + (xOffset * this->sizeOfT()); - NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace())); - - //FIXME: bad, will crash! - return result; -} -*/ -/* -NDArray NDArray::operator()(const Nd4jLong t, const Nd4jLong u, const Nd4jLong v, const Nd4jLong w) const { - - if (rankOf() != 4 || t >= shapeOf()[0] || u >= shapeOf()[1] || v >= shapeOf()[2] || w >= shapeOf()[3]) - throw std::invalid_argument("NDArray::operator(t,u,v,w): one of input indexes is out of array length or rank!=4 !"); - - Nd4jLong coords[4] = {t, u, v, w}; - auto xOffset = shape::getOffset(getShapeInfo(), coords); - - auto cast = reinterpret_cast(_buffer) + (xOffset * this->sizeOfT()); - NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace())); - return result; -} -*/ -/* -NDArray& NDArray::operator()(const Nd4jLong t, const Nd4jLong u, const Nd4jLong v, const Nd4jLong w) { - - if (rankOf() != 4 || t >= shapeOf()[0] || u >= shapeOf()[1] || v >= shapeOf()[2] || w >= shapeOf()[3]) - throw std::invalid_argument("NDArray::operator(t,u,v,w): one of input indexes is out of array length or rank!=4 !"); - - Nd4jLong coords[4] = {t, u, v, w}; - auto xOffset = shape::getOffset(getShapeInfo(), coords); - - // FIXME - auto cast = reinterpret_cast(_buffer) + (xOffset * this->sizeOfT()); - NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace())); - return result; -} -*/ ////////////////////////////////////////////////////////////////////////// -/* -NDArray NDArray::operator()(const Nd4jLong* idx) const { - for(int i = 0; i < rankOf(); ++i) - if (idx[i] >= sizeAt(i)) - throw std::invalid_argument("NDArray::operator(const Nd4jLong* idx): input index is out of dimension length !"); - auto xOffset = shape::getOffset(getShapeInfo(), idx); - - auto cast = reinterpret_cast(_buffer) + (xOffset * this->sizeOfT()); - NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace())); - return result; -} -*/ ////////////////////////////////////////////////////////////////////////// -/* -NDArray& NDArray::operator()(const Nd4jLong* idx) { - - for(int i = 0; i < rankOf(); ++i) - if (idx[i] >= sizeAt(i)) - throw std::invalid_argument("NDArray::operator(const Nd4jLong* idx): input index is out of dimension length !"); - - auto xOffset = shape::getOffset(getShapeInfo(), idx); - - auto cast = reinterpret_cast(_buffer) + (xOffset * this->sizeOfT()); - NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace())); - - // FIXME - return result; -} -*/ - ////////////////////////////////////////////////////////////////////////// - +////////////////////////////////////////////////////////////////////////// + + +////////////////////////////////////////////////////////////////////////// + + +////////////////////////////////////////////////////////////////////////// + + +////////////////////////////////////////////////////////////////////////// + + +////////////////////////////////////////////////////////////////////////// + + + +////////////////////////////////////////////////////////////////////////// + + +////////////////////////////////////////////////////////////////////////// + + +////////////////////////////////////////////////////////////////////////// + + +////////////////////////////////////////////////////////////////////////// + + +////////////////////////////////////////////////////////////////////////// + + +////////////////////////////////////////////////////////////////////////// + + +////////////////////////////////////////////////////////////////////////// + + +////////////////////////////////////////////////////////////////////////// + + + +////////////////////////////////////////////////////////////////////////// + + +////////////////////////////////////////////////////////////////////////// +// still the definition of inline function must be in header file - ////////////////////////////////////////////////////////////////////////// - // still the definition of inline function must be in header file - ////////////////////////////////////////////////////////////////////////// @@ -5251,7 +4879,7 @@ NDArray& NDArray::operator()(const Nd4jLong* idx) { // #include // #include -// #include +// #include // #include // #include // #include @@ -5402,6 +5030,7 @@ NDArray& NDArray::operator()(const Nd4jLong* idx) { // #include // #include // #include +// #include // #ifdef __CUDACC__ // #endif @@ -5853,7 +5482,7 @@ NDArray& NDArray::operator()(const Nd4jLong* idx) { //#include // #include -// #include +// #include // #include // #include // #include @@ -5944,7 +5573,7 @@ NDArray& NDArray::operator()(const Nd4jLong* idx) { // #include // #include // #include -// #include +// #include // #include // #include // #include @@ -6060,7 +5689,7 @@ NDArray& NDArray::operator()(const Nd4jLong* idx) { // #include // #include // #include -// #include +// #include // #include // #include // #include @@ -6616,6 +6245,7 @@ NDArray& NDArray::operator()(const Nd4jLong* idx) { // #include // #include // #include +// #include // CUDA-specific includes // #ifdef __CUDACC__ @@ -6666,12 +6296,13 @@ NDArray& NDArray::operator()(const Nd4jLong* idx) { // this method returns workspace for object allocations public native Workspace oWorkspace(); - public native void setVariableSpace(VariableSpace variableSpace); public native RandomBuffer getRNG(); public native void setRNG(RandomBuffer rng); + public native void setTargetEngine(@Cast("samediff::Engine") int engine); + public native VariableSpace getVariableSpace(); public native LaunchContext launchContext(); @@ -6753,10 +6384,12 @@ NDArray& NDArray::operator()(const Nd4jLong* idx) { public native void setInputArray(int index, NDArray array, @Cast("bool") boolean removable/*=false*/); public native void setInputArray(int index, NDArray array); public native void setInputArray(int index, Pointer buffer, Pointer shapeInfo, Pointer specialBuffer, Pointer specialShapeInfo); + public native void setInputArray(int index, Pointer databuffer, Pointer shapeInfo, Pointer specialShapeInfo); public native void setOutputArray(int index, NDArray array, @Cast("bool") boolean removable/*=false*/); public native void setOutputArray(int index, NDArray array); public native void setOutputArray(int index, Pointer buffer, Pointer shapeInfo, Pointer specialBuffer, Pointer specialShapeInfo); + public native void setOutputArray(int index, Pointer databuffer, Pointer shapeInfo, Pointer specialShapeInfo); public native void setTArguments(DoublePointer arguments, int numberOfArguments); public native void setTArguments(DoubleBuffer arguments, int numberOfArguments); @@ -6778,9 +6411,11 @@ NDArray& NDArray::operator()(const Nd4jLong* idx) { public native void setCudaContext(@Cast("Nd4jPointer") Pointer cudaStream, @Cast("Nd4jPointer") Pointer reductionPointer, @Cast("Nd4jPointer") Pointer allocationPointer); - public native void allowHelpers(@Cast("bool") boolean reallyAllow); public native @Cast("bool") boolean helpersAllowed(); + + public native void setShapeFunctionOverride(@Cast("bool") boolean reallyOverride); + public native @Cast("bool") boolean shapeFunctionOverride(); } @@ -6820,6 +6455,11 @@ NDArray& NDArray::operator()(const Nd4jLong* idx) { // #include // #include // #include +// #include + +// #ifndef __STANDALONE_BUILD__ +// #include +// #endif @Namespace("nd4j::graph") @NoOffset public static class ContextPrototype extends Pointer { static { Loader.load(); } @@ -6865,6 +6505,8 @@ NDArray& NDArray::operator()(const Nd4jLong* idx) { public native @Cast("bool*") @StdVector BooleanPointer getBArguments(); public native @StdVector IntPointer getAxis(); + public native @Cast("samediff::Engine") int engine(); + public native @Cast("size_t") long numT(); public native @Cast("size_t") long numI(); public native @Cast("size_t") long numB(); @@ -9429,6 +9071,7 @@ public static final int PREALLOC_SIZE = 33554432; // #define SD_PLATFORMHELPER_H // #include +// #include // #include // #include // #include @@ -9444,6 +9087,8 @@ public static final int PREALLOC_SIZE = 33554432; public native @StdString BytePointer name(); + public native @Cast("samediff::Engine") int engine(); + public native @Cast("Nd4jLong") long hash(); /** @@ -10053,10 +9698,11 @@ public static final int PREALLOC_SIZE = 33554432; // #include // #include -// #include +// #include // #include // #include // #include +// #include // handlers part // #include @@ -10094,13 +9740,13 @@ public static final int PREALLOC_SIZE = 33554432; public native void registerHelper(PlatformHelper op); - public native @Cast("bool") boolean hasHelper(@Cast("Nd4jLong") long hash); + public native @Cast("bool") boolean hasHelper(@Cast("Nd4jLong") long hash, @Cast("samediff::Engine") int engine); public native DeclarableOp getOperation(@Cast("char*") String name); public native DeclarableOp getOperation(@Cast("char*") BytePointer name); public native DeclarableOp getOperation(@Cast("Nd4jLong") long hash); - public native PlatformHelper getPlatformHelper(@Cast("Nd4jLong") long hash); + public native PlatformHelper getPlatformHelper(@Cast("Nd4jLong") long hash, @Cast("samediff::Engine") int engine); public native @Cast("Nd4jLong*") @StdVector LongPointer getAllHashes(); @@ -10226,6 +9872,7 @@ public static final int PREALLOC_SIZE = 33554432; // #include // #include // #include +// #include "config.h" // #endif // used for MKLDNN etc @@ -10316,7 +9963,7 @@ public static final int PREALLOC_SIZE = 33554432; // #ifndef DEV_TESTS_SHAPEDESCRIPTOR_H // #define DEV_TESTS_SHAPEDESCRIPTOR_H -// #include +// #include // #include // #include // #include @@ -10629,7 +10276,11 @@ public static final int PREALLOC_SIZE = 33554432; // #include // #include // #include +// #include +// #include +// #include // #include +// #include // #include // #include // #include diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCudaPresets.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCudaPresets.java index 8c2109f7c..aa6d91519 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCudaPresets.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCudaPresets.java @@ -33,10 +33,12 @@ import org.bytedeco.javacpp.tools.InfoMapper; @Properties(target = "org.nd4j.nativeblas.Nd4jCuda", helper = "org.nd4j.nativeblas.Nd4jCudaHelper", value = {@Platform(define = "LIBND4J_ALL_OPS", include = { "array/DataType.h", + "array/DataBuffer.h", "array/ConstantDescriptor.h", "array/ConstantDataBuffer.h", "array/TadPack.h", "execution/ErrorReference.h", + "execution/Engine.h", "memory/MemoryType.h", "Environment.h", "types/utf8string.h", @@ -165,6 +167,7 @@ public class Nd4jCudaPresets implements LoadEnabled, InfoMapper { .put(new Info("OpaqueContext").pointerTypes("OpaqueContext")) .put(new Info("OpaqueRandomGenerator").pointerTypes("OpaqueRandomGenerator")) .put(new Info("OpaqueLaunchContext").pointerTypes("OpaqueLaunchContext")) + .put(new Info("OpaqueDataBuffer").pointerTypes("OpaqueDataBuffer")) .put(new Info("const char").valueTypes("byte").pointerTypes("@Cast(\"char*\") String", "@Cast(\"char*\") BytePointer")) .put(new Info("char").valueTypes("char").pointerTypes("@Cast(\"char*\") BytePointer", @@ -185,10 +188,11 @@ public class Nd4jCudaPresets implements LoadEnabled, InfoMapper { "nd4j::graph::FlatResult", "nd4j::graph::FlatVariable", "nd4j::NDArray::subarray").skip()) .put(new Info("std::string").annotations("@StdString").valueTypes("BytePointer", "String") .pointerTypes("@Cast({\"char*\", \"std::string*\"}) BytePointer")) - .put(new Info("std::pair").pointerTypes("IntIntPair").define()) - .put(new Info("std::vector >").pointerTypes("IntVectorVector").define()) + .put(new Info("std::pair").pointerTypes("IntIntPair").define()) + .put(new Info("std::vector >").pointerTypes("IntVectorVector").define()) .put(new Info("std::vector >").pointerTypes("LongVectorVector").define()) - .put(new Info("std::vector").pointerTypes("NDArrayVector").define()) + .put(new Info("std::vector").pointerTypes("NDArrayVector").define()) + .put(new Info("std::vector").pointerTypes("ConstNDArrayVector").define()) .put(new Info("bool").cast().valueTypes("boolean").pointerTypes("BooleanPointer", "boolean[]")) .put(new Info("nd4j::graph::ResultWrapper").base("org.nd4j.nativeblas.ResultWrapperAbstraction").define()) .put(new Info("nd4j::IndicesList").purify()); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/test/java/org/nd4j/jita/allocator/AllocatorTest.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/test/java/org/nd4j/jita/allocator/AllocatorTest.java deleted file mode 100644 index c19adf4ad..000000000 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/test/java/org/nd4j/jita/allocator/AllocatorTest.java +++ /dev/null @@ -1,552 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.jita.allocator; - -import lombok.extern.slf4j.Slf4j; -import lombok.var; -import org.apache.commons.lang3.RandomUtils; -import org.bytedeco.javacpp.Pointer; -import org.junit.Ignore; -import org.junit.Test; -import org.nd4j.jita.allocator.impl.AtomicAllocator; -import org.nd4j.jita.allocator.impl.MemoryTracker; - -import lombok.val; - -import org.nd4j.jita.conf.CudaEnvironment; -import org.nd4j.jita.flow.FlowController; -import org.nd4j.jita.memory.impl.CudaFullCachingProvider; -import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration; -import org.nd4j.linalg.api.memory.enums.MemoryKind; -import org.nd4j.linalg.api.memory.enums.MirroringPolicy; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.executors.ExecutorServiceProvider; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.api.memory.enums.AllocationPolicy; -import org.nd4j.jita.memory.impl.CudaDirectProvider; -import org.nd4j.jita.memory.impl.CudaCachingZeroProvider; -import org.nd4j.jita.allocator.utils.AllocationUtils; -import org.nd4j.jita.allocator.enums.AllocationStatus; -import org.nd4j.jita.allocator.impl.AllocationPoint; -import org.nd4j.jita.allocator.enums.AllocationStatus; -import org.nd4j.jita.allocator.impl.AllocationShape; -import org.nd4j.linalg.primitives.Pair; - -import java.util.*; -import java.util.concurrent.Callable; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Future; - -import static org.junit.Assert.*; - -@Slf4j -@Ignore("AB 2019/05/23 - Getting stuck (tests never finishing) on CI - see issue #7657") -public class AllocatorTest { - private static final long SAFETY_OFFSET = 1024L; - - @Test - public void testCounters() { - int deviceId = 0; - MemoryTracker tracker = new MemoryTracker(); - - assertTrue(0 == tracker.getAllocatedAmount(deviceId)); - assertTrue(0 == tracker.getCachedAmount(deviceId)); - //assertTrue(0 == tracker.getTotalMemory(deviceId)); - - tracker.incrementAllocatedAmount(deviceId, 10); - assertTrue(10 == tracker.getAllocatedAmount(deviceId)); - - tracker.incrementCachedAmount(deviceId, 5); - assertTrue(5 == tracker.getCachedAmount(deviceId)); - - tracker.decrementAllocatedAmount(deviceId, 5); - assertTrue(5 == tracker.getAllocatedAmount(deviceId)); - - tracker.decrementCachedAmount(deviceId, 5); - assertTrue(0 == tracker.getCachedAmount(deviceId)); - - //assertTrue(0 == tracker.getTotalMemory(deviceId)); - - for (int e = 0; e < Nd4j.getAffinityManager().getNumberOfDevices(); e++) { - val ttl = tracker.getTotalMemory(e); - log.info("Device_{} {} bytes", e, ttl); - assertNotEquals(0, ttl); - } - } - - @Test - public void testWorkspaceInitSize() { - - long initSize = 1024; - MemoryTracker tracker = MemoryTracker.getInstance(); - - WorkspaceConfiguration workspaceConfig = WorkspaceConfiguration.builder() - .policyAllocation(AllocationPolicy.STRICT) - .initialSize(initSize) - .build(); - - try (val ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(workspaceConfig, "test121")) { - assertEquals(initSize + SAFETY_OFFSET, tracker.getWorkspaceAllocatedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread())); - } - - val ws = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread("test121"); - ws.destroyWorkspace(); - - assertEquals(0, tracker.getWorkspaceAllocatedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread())); - } - - - @Test - public void testWorkspaceSpilledSize() { - - long initSize = 0; - MemoryTracker tracker = MemoryTracker.getInstance(); - - WorkspaceConfiguration workspaceConfig = WorkspaceConfiguration.builder() - .policyAllocation(AllocationPolicy.STRICT) - .initialSize(initSize) - .build(); - - try (val ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(workspaceConfig, "test99323")) { - assertEquals(0L, tracker.getWorkspaceAllocatedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread())); - - val array = Nd4j.createFromArray(1.f, 2.f, 3.f, 4.f); - - assertEquals(array.length() * array.data().getElementSize(), tracker.getWorkspaceAllocatedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread())); - } - - val ws = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread("test99323"); - ws.destroyWorkspace(); - - assertEquals(0, tracker.getWorkspaceAllocatedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread())); - } - - @Test - public void testWorkspaceSpilledSizeHost() { - - long initSize = 0; - MemoryTracker tracker = MemoryTracker.getInstance(); - - WorkspaceConfiguration workspaceConfig = WorkspaceConfiguration.builder() - .policyAllocation(AllocationPolicy.STRICT) - .policyMirroring(MirroringPolicy.HOST_ONLY) - .initialSize(initSize) - .build(); - - try (val ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(workspaceConfig, "test99323222")) { - assertEquals(0L, tracker.getWorkspaceAllocatedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread())); - - val array = Nd4j.createFromArray(1.f, 2.f, 3.f, 4.f); - - assertEquals(0, tracker.getWorkspaceAllocatedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread())); - } - - val ws = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread("test99323222"); - ws.destroyWorkspace(); - - assertEquals(0, tracker.getWorkspaceAllocatedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread())); - } - - - @Ignore - @Test - public void testWorkspaceAlloc() { - - long initSize = 0; - long allocSize = 48; - - val workspaceConfig = WorkspaceConfiguration.builder() - .policyAllocation(AllocationPolicy.STRICT) - .initialSize(initSize) - .policyMirroring(MirroringPolicy.HOST_ONLY) // Commenting this out makes it so that assert is not triggered (for at least 40 secs or so...) - .build(); - - try (val ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(workspaceConfig, "test")) { - final INDArray zeros = Nd4j.zeros(allocSize, 'c'); - System.out.println("Alloc1:" + MemoryTracker.getInstance().getWorkspaceAllocatedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread())); - assertTrue(allocSize == - MemoryTracker.getInstance().getWorkspaceAllocatedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread())); - } - assertTrue(allocSize == - MemoryTracker.getInstance().getWorkspaceAllocatedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread())); - /*Nd4j.getWorkspaceManager().destroyWorkspace(ws); - assertTrue(0L == - MemoryTracker.getInstance().getWorkspaceAllocatedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread()));*/ - } - - @Test - public void testDirectProvider() { - INDArray input = Nd4j.zeros(1024); - CudaDirectProvider provider = new CudaDirectProvider(); - AllocationShape shape = AllocationUtils.buildAllocationShape(input); - AllocationPoint point = new AllocationPoint(); - point.setShape(shape); - - val allocBefore = MemoryTracker.getInstance().getAllocatedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread()); - val cachedBefore = MemoryTracker.getInstance().getCachedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread()); - - val pointers = provider.malloc(shape, point, AllocationStatus.DEVICE); - point.setPointers(pointers); - - System.out.println(MemoryTracker.getInstance().getAllocatedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread())); - System.out.println(MemoryTracker.getInstance().getCachedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread())); - - val allocMiddle = MemoryTracker.getInstance().getAllocatedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread()); - val cachedMiddle = MemoryTracker.getInstance().getCachedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread()); - - provider.free(point); - - System.out.println(MemoryTracker.getInstance().getAllocatedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread())); - System.out.println(MemoryTracker.getInstance().getCachedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread())); - - val allocAfter = MemoryTracker.getInstance().getAllocatedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread()); - val cachedAfter = MemoryTracker.getInstance().getCachedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread()); - - assertTrue(allocBefore < allocMiddle); - assertEquals(allocBefore, allocAfter); - - assertEquals(cachedBefore, cachedMiddle); - assertEquals(cachedBefore, cachedAfter); - } - - @Test - public void testZeroCachingProvider() { - INDArray input = Nd4j.zeros(1024); - CudaCachingZeroProvider provider = new CudaCachingZeroProvider(); - AllocationShape shape = AllocationUtils.buildAllocationShape(input); - AllocationPoint point = new AllocationPoint(); - point.setShape(shape); - - val allocBefore = MemoryTracker.getInstance().getAllocatedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread()); - val cachedBefore = MemoryTracker.getInstance().getCachedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread()); - - val pointers = provider.malloc(shape, point, AllocationStatus.DEVICE); - point.setPointers(pointers); - - System.out.println(MemoryTracker.getInstance().getAllocatedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread())); - System.out.println(MemoryTracker.getInstance().getCachedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread())); - - val allocMiddle = MemoryTracker.getInstance().getAllocatedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread()); - val cachedMiddle = MemoryTracker.getInstance().getCachedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread()); - - provider.free(point); - - System.out.println(MemoryTracker.getInstance().getAllocatedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread())); - System.out.println(MemoryTracker.getInstance().getCachedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread())); - - val allocAfter = MemoryTracker.getInstance().getAllocatedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread()); - val cachedAfter = MemoryTracker.getInstance().getCachedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread()); - - assertTrue(allocBefore < allocMiddle); - assertEquals(allocBefore, allocAfter); - - assertEquals(cachedBefore, cachedMiddle); - assertEquals(cachedBefore, cachedAfter); - } - - @Test - public void testFullCachingProvider() { - INDArray input = Nd4j.zeros(1024); - val provider = new CudaFullCachingProvider(); - AllocationShape shape = AllocationUtils.buildAllocationShape(input); - AllocationPoint point = new AllocationPoint(); - point.setShape(shape); - - val allocBefore = MemoryTracker.getInstance().getAllocatedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread()); - val cachedBefore = MemoryTracker.getInstance().getCachedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread()); - - val pointers = provider.malloc(shape, point, AllocationStatus.DEVICE); - point.setPointers(pointers); - - System.out.println(MemoryTracker.getInstance().getAllocatedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread())); - System.out.println(MemoryTracker.getInstance().getCachedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread())); - - val allocMiddle = MemoryTracker.getInstance().getAllocatedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread()); - val cachedMiddle = MemoryTracker.getInstance().getCachedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread()); - - provider.free(point); - - System.out.println(MemoryTracker.getInstance().getAllocatedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread())); - System.out.println(MemoryTracker.getInstance().getCachedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread())); - - val allocAfter = MemoryTracker.getInstance().getAllocatedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread()); - val cachedAfter = MemoryTracker.getInstance().getCachedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread()); - - assertTrue(allocBefore < allocMiddle); - assertEquals(allocBefore, allocAfter); - - //assertEquals(0, cachedBefore); - //assertEquals(0, cachedMiddle); - //assertEquals(shape.getNumberOfBytes(), cachedAfter); - - assertEquals(cachedBefore, cachedMiddle); - assertTrue(cachedBefore < cachedAfter); - } - - @Test - public void testCyclicCreation() throws Exception { - Nd4j.create(100); - - log.info("Approximate free memory: {}", MemoryTracker.getInstance().getApproximateFreeMemory(Nd4j.getAffinityManager().getDeviceForCurrentThread())); - log.info("Real free memory: {}", MemoryTracker.getInstance().getPreciseFreeMemory(Nd4j.getAffinityManager().getDeviceForCurrentThread())); - - val timeStart = System.currentTimeMillis(); - - while (true) { - //val array = Nd4j.create(DataType.FLOAT, 1000, 1000); - val array = Nd4j.create(DataType.FLOAT, RandomUtils.nextInt(100, 1000), RandomUtils.nextInt(100, 1000)); - - val timeEnd = System.currentTimeMillis(); - if (timeEnd - timeStart > 5 * 60 * 1000) { - log.info("Exiting..."); - break; - } - } - - while (true) { - log.info("Cached device memory: {}", MemoryTracker.getInstance().getCachedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread())); - log.info("Active device memory: {}", MemoryTracker.getInstance().getAllocatedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread())); - log.info("Cached host memory: {}", MemoryTracker.getInstance().getCachedHostAmount()); - log.info("Active host memory: {}", MemoryTracker.getInstance().getAllocatedHostAmount()); - - System.gc(); - Thread.sleep(30000); - } - } - - @Test - public void testAllocations() { - INDArray x = Nd4j.create(DataType.FLOAT, 10, 5); - assertArrayEquals(new long[]{10, 5}, x.shape()); - - for (DataType dataType : DataType.values()) { - for (int i = 0; i < 10; ++i) { - - x = Nd4j.create(DataType.FLOAT, 10 * i + 1, 5 * i + 2); - assertArrayEquals(new long[]{10 * i + 1, 5 * i + 2}, x.shape()); - - val pointX = AtomicAllocator.getInstance().getAllocationPoint(x.shapeInfoDataBuffer()); - assertNotNull(pointX); - assertTrue(x.shapeInfoDataBuffer().isConstant()); - - assertNotNull(pointX.getHostPointer()); - assertNotNull(pointX.getDevicePointer()); - - assertEquals(64, pointX.getShape().getNumberOfBytes()); - } - } - } - - @Test - public void testAllocations1() { - INDArray x = Nd4j.zeros(1,10); - - for (int i = 0; i < 100000; ++i) { - INDArray toAdd = Nd4j.ones(1,10); - x.putRow(i+1, toAdd); - } - - assertTrue(x.shapeInfoDataBuffer().isConstant()); - - val pointX = AtomicAllocator.getInstance().getAllocationPoint(x.shapeInfoDataBuffer()); - assertNotNull(pointX); - - assertNotNull(pointX); - assertTrue(x.shapeInfoDataBuffer().isConstant()); - - assertNotNull(pointX.getHostPointer()); - assertNotNull(pointX.getDevicePointer()); - - assertEquals(64, pointX.getShape().getNumberOfBytes()); - } - - @Test - public void testReallocate() { - INDArray x = Nd4j.create(DataType.FLOAT, 10, 5); - var pointX = AtomicAllocator.getInstance().getAllocationPoint(x.data()); - - assertNotNull(pointX); - - assertEquals(200, pointX.getShape().getNumberOfBytes()); - - val hostP = pointX.getHostPointer(); - val deviceP = pointX.getDevicePointer(); - - assertEquals(50, x.data().capacity()); - x.data().reallocate(500); - - pointX = AtomicAllocator.getInstance().getAllocationPoint(x.data()); - - assertEquals(500, x.data().capacity()); - assertEquals(2000, pointX.getShape().getNumberOfBytes()); - - assertNotEquals(hostP, pointX.getHostPointer()); - assertNotEquals(deviceP, pointX.getDevicePointer()); - } - - @Test - public void testDataMigration() { - - for (boolean p2pEnabled : new boolean[]{true, false}) { - - CudaEnvironment.getInstance().getConfiguration().allowCrossDeviceAccess(p2pEnabled); - - Thread[] threads = new Thread[4]; - List> sumsPerList = new ArrayList<>(); - List lst = new ArrayList<>(); - - for (int i = 0; i < 4; ++i) { - threads[i] = new Thread() { - @Override - public void run() { - INDArray x = Nd4j.rand(1, 10); - Pair pair = new Pair<>(); - pair.setFirst(Nd4j.sum(x)); - pair.setSecond(x); - sumsPerList.add(pair); - lst.add(x); - } - }; - threads[i].start(); - } - - try { - for (val thread : threads) { - thread.join(); - } - } catch (InterruptedException e) { - log.info("Interrupted"); - } - - Collections.shuffle(lst); - - for (int i = 0; i < lst.size(); ++i) { - INDArray data = lst.get(i); - - for (int j = 0; j < sumsPerList.size(); ++j) { - if (sumsPerList.get(j).getFirst().equals(data)) - assertEquals(sumsPerList.get(j).getSecond(), data); - - } - } - } - } - - - @Ignore - @Test - public void testHostFallback() { - // Take device memory - long bytesFree = MemoryTracker.getInstance().getApproximateFreeMemory(0); - Pointer p = Nd4j.getMemoryManager().allocate((long)(bytesFree*0.75), MemoryKind.DEVICE, true); - - // Fallback to host - INDArray x1 = Nd4j.create(1, (long)(bytesFree*0.15)); - val pointX = AtomicAllocator.getInstance().getAllocationPoint(x1.shapeInfoDataBuffer()); - - assertNotNull(pointX); - assertNotNull(pointX.getHostPointer()); - assertNotNull(pointX.getDevicePointer()); - - Nd4j.getMemoryManager().release(p, MemoryKind.DEVICE); - } - - @Test - public void testAffinityGuarantees() { - ExecutorService service = ExecutorServiceProvider.getExecutorService(); - final INDArray steady = Nd4j.rand(1,100); - Map deviceData = new HashMap<>(); - - Future>[] results = new Future[10]; - for (int i = 0; i < results.length; ++i) { - results[i] = service.submit(new Callable>() { - @Override - public List call() { - List retVal = new ArrayList<>(); - for (int i = 0; i < 100; ++i) { - INDArray x = Nd4j.rand(1, 100); - System.out.println("Device for x:" + Nd4j.getAffinityManager().getDeviceForArray(x)); - System.out.println("Device for steady: " + Nd4j.getAffinityManager().getDeviceForArray(steady)); - deviceData.put(x, Nd4j.getAffinityManager().getDeviceForArray(x)); - deviceData.put(steady, Nd4j.getAffinityManager().getDeviceForArray(steady)); - retVal.add(x); - } - Thread[] innerThreads = new Thread[4]; - for (int k = 0; k < 4; ++k) { - innerThreads[k] = new Thread() { - @Override - public void run() { - for (val res : retVal) { - assertEquals(deviceData.get(res), Nd4j.getAffinityManager().getDeviceForArray(res)); - assertEquals(deviceData.get(steady), Nd4j.getAffinityManager().getDeviceForArray(steady)); - } - } - }; - innerThreads[k].start(); - } - try { - for (int k = 0; k < 4; ++k) { - innerThreads[k].join(); - } - } catch (InterruptedException e) { - log.info(e.getMessage()); - } - return retVal; - } - }); - - try { - List resArray = results[i].get(); - for (val res : resArray) { - assertEquals(deviceData.get(res), Nd4j.getAffinityManager().getDeviceForArray(res)); - assertEquals(deviceData.get(steady), Nd4j.getAffinityManager().getDeviceForArray(steady)); - } - } catch (Exception e) { - log.info(e.getMessage()); - } - } - } - - @Test - public void testEventsRelease() { - FlowController controller = AtomicAllocator.getInstance().getFlowController(); - long currEventsNumber = controller.getEventsProvider().getEventsNumber(); - - INDArray x = Nd4j.rand(1,10); - controller.prepareAction(x); - assertEquals(currEventsNumber+1, controller.getEventsProvider().getEventsNumber()); - - INDArray arg1 = Nd4j.rand(1,100); - INDArray arg2 = Nd4j.rand(1,200); - INDArray arg3 = Nd4j.rand(1,300); - controller.prepareAction(x, arg1, arg2, arg3); - assertEquals(currEventsNumber+5, controller.getEventsProvider().getEventsNumber()); - } - - @Test - public void testDataBuffers() { - INDArray x = Nd4j.create(DataType.FLOAT, 10, 5); - val pointX = AtomicAllocator.getInstance().getAllocationPoint(x.shapeInfoDataBuffer()); - assertEquals(50, x.data().capacity()); - x.data().destroy(); - assertNull(x.data()); - assertEquals(64, pointX.getShape().getNumberOfBytes()); - System.out.println(pointX.getHostPointer()); - System.out.println(pointX.getDevicePointer()); - } -} diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/test/java/org/nd4j/jita/allocator/DeviceLocalNDArrayTests.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/test/java/org/nd4j/jita/allocator/DeviceLocalNDArrayTests.java index 9584d5692..a616e7aa1 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/test/java/org/nd4j/jita/allocator/DeviceLocalNDArrayTests.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/test/java/org/nd4j/jita/allocator/DeviceLocalNDArrayTests.java @@ -19,6 +19,7 @@ package org.nd4j.jita.allocator; import lombok.extern.slf4j.Slf4j; import lombok.val; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.util.DeviceLocalNDArray; @@ -29,7 +30,7 @@ import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; @Slf4j -public class DeviceLocalNDArrayTests { +public class DeviceLocalNDArrayTests extends BaseND4JTest { @Test public void testDeviceLocalArray_1() throws Exception{ diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/test/java/org/nd4j/jita/allocator/impl/MemoryTrackerTest.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/test/java/org/nd4j/jita/allocator/impl/MemoryTrackerTest.java index d952c775a..9854b797a 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/test/java/org/nd4j/jita/allocator/impl/MemoryTrackerTest.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/test/java/org/nd4j/jita/allocator/impl/MemoryTrackerTest.java @@ -19,13 +19,14 @@ package org.nd4j.jita.allocator.impl; import lombok.extern.slf4j.Slf4j; import lombok.val; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.factory.Nd4j; import static org.junit.Assert.*; @Slf4j -public class MemoryTrackerTest { +public class MemoryTrackerTest extends BaseND4JTest { @Test public void testAllocatedDelta() { diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/test/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBufferTest.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/test/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBufferTest.java index 2f5d53a40..6f62dc53a 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/test/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBufferTest.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/test/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBufferTest.java @@ -2,150 +2,202 @@ package org.nd4j.linalg.jcublas.buffer; import lombok.extern.slf4j.Slf4j; import lombok.val; +import org.junit.Before; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.jita.allocator.impl.AtomicAllocator; import org.nd4j.jita.workspace.CudaWorkspace; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.util.PrintVariable; import org.nd4j.linalg.factory.Nd4j; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.util.ArrayList; +import java.util.concurrent.atomic.AtomicInteger; + import static org.junit.Assert.*; @Slf4j -public class BaseCudaDataBufferTest { +public class BaseCudaDataBufferTest extends BaseND4JTest { - @Test - public void testShapeCache_1() { - val x = Nd4j.create(DataType.FLOAT, 3, 5); - - assertEquals(DataType.FLOAT, x.dataType()); - assertArrayEquals(new long[]{3, 5}, x.shape()); - assertArrayEquals(new long[]{5, 1}, x.stride()); - assertEquals(1, x.elementWiseStride()); - assertEquals('c', x.ordering()); - - val pair = Nd4j.getShapeInfoProvider().createShapeInformation(x.shape(), x.stride(), x.elementWiseStride(), x.ordering(), x.dataType(), x.isEmpty()); - val db = pair.getFirst(); - val jvm = pair.getSecond(); - - log.info("array shapeInfo: {}", x.shapeInfoJava()); - log.info("direct shapeInfo: {}", jvm); - - val pointX = AtomicAllocator.getInstance().getAllocationPoint(x.shapeInfoDataBuffer()); - val pointM = AtomicAllocator.getInstance().getAllocationPoint(db); - - assertNotNull(pointX); - assertNotNull(pointM); - - assertNotNull(pointX.getHostPointer()); - assertNotNull(pointX.getDevicePointer()); - - assertNotNull(pointM.getHostPointer()); - assertNotNull(pointM.getDevicePointer()); - - - log.info("X hPtr: {}; dPtr: {}", pointX.getHostPointer().address(), pointX.getDevicePointer().address()); - log.info("M hPtr: {}; dPtr: {}", pointM.getHostPointer().address(), pointM.getDevicePointer().address()); - - assertEquals(pointM.getHostPointer().address(), pointX.getHostPointer().address()); - assertEquals(pointM.getDevicePointer().address(), pointX.getDevicePointer().address()); - - assertArrayEquals(x.shapeInfoJava(), jvm); + @Before + public void setUp() { + // } @Test - public void testTadCache_1() { - val x = Nd4j.create(DataType.FLOAT, 3, 5); - val row = x.getRow(1); - val tad = x.tensorAlongDimension(1, 1); + public void testBasicAllocation_1() { + val array = Nd4j.create(DataType.FLOAT, 5); - val pointX = AtomicAllocator.getInstance().getAllocationPoint(row.shapeInfoDataBuffer()); - val pointM = AtomicAllocator.getInstance().getAllocationPoint(tad.shapeInfoDataBuffer()); + // basic validation + assertNotNull(array); + assertNotNull(array.data()); + assertNotNull(((BaseCudaDataBuffer) array.data()).getOpaqueDataBuffer()); - assertNotNull(pointX); - assertNotNull(pointM); + // shape part + assertArrayEquals(new long[]{1, 5, 1, 8192, 1, 99}, array.shapeInfoJava()); + assertArrayEquals(new long[]{1, 5, 1, 8192, 1, 99}, array.shapeInfoDataBuffer().asLong()); - assertNotNull(pointX.getHostPointer()); - assertNotNull(pointX.getDevicePointer()); - - assertNotNull(pointM.getHostPointer()); - assertNotNull(pointM.getDevicePointer()); - - - log.info("X hPtr: {}; dPtr: {}", pointX.getHostPointer().address(), pointX.getDevicePointer().address()); - log.info("M hPtr: {}; dPtr: {}", pointM.getHostPointer().address(), pointM.getDevicePointer().address()); - - assertEquals(pointM.getHostPointer().address(), pointX.getHostPointer().address()); - assertEquals(pointM.getDevicePointer().address(), pointX.getDevicePointer().address()); - - assertArrayEquals(row.shapeInfoJava(), tad.shapeInfoJava()); - } - - - @Test - public void testHostAllocation_1() { - val x = Nd4j.create(DataType.FLOAT, 3, 5); - - val pointX = AtomicAllocator.getInstance().getAllocationPoint(x.data()); - - assertNotNull(pointX); - - assertNull(pointX.getHostPointer()); - assertNotNull(pointX.getDevicePointer()); - - - x.getDouble(0); - - assertNotNull(pointX.getHostPointer()); + // arrat as full of zeros at this point + assertArrayEquals(new float[] {0.f, 0.f, 0.f, 0.f, 0.f}, array.data().asFloat(), 1e-5f); } @Test - public void testHostAllocation_2() { - val x = Nd4j.createFromArray(new double[]{1, 2, 3, 4, 5}); + public void testBasicAllocation_2() { + val array = Nd4j.createFromArray(1.f, 2.f, 3.f, 4.f, 5.f); - val pointX = AtomicAllocator.getInstance().getAllocationPoint(x.data()); + // basic validation + assertNotNull(array); + assertNotNull(array.data()); + assertNotNull(((BaseCudaDataBuffer) array.data()).getOpaqueDataBuffer()); - assertNotNull(pointX); + // shape part + assertArrayEquals(new long[]{1, 5, 1, 8192, 1, 99}, array.shapeInfoJava()); + assertArrayEquals(new long[]{1, 5, 1, 8192, 1, 99}, array.shapeInfoDataBuffer().asLong()); - assertNull(pointX.getHostPointer()); - assertNotNull(pointX.getDevicePointer()); - - val sum = x.sumNumber().doubleValue(); - - assertNull(pointX.getHostPointer()); - - assertEquals(15, sum, 1e-5); - - x.getDouble(0); - - assertNotNull(pointX.getHostPointer()); + // arrat as full of values at this point + assertArrayEquals(new float[] {1.f, 2.f, 3.f, 4.f, 5.f}, array.data().asFloat(), 1e-5f); } @Test - public void testHostAllocation_3() { - val wsConf = WorkspaceConfiguration.builder() - .initialSize(10 * 1024 * 1024) - .build(); + public void testBasicView_1() { + val array = Nd4j.createFromArray(1.f, 2.f, 3.f, 4.f, 5.f, 6.f).reshape(3, 2); - try (val ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(wsConf, "someworkspaceid")) { - val x = Nd4j.create(DataType.DOUBLE, 3, 5); + // basic validation + assertNotNull(array); + assertNotNull(array.data()); + assertNotNull(((BaseCudaDataBuffer) array.data()).getOpaqueDataBuffer()); - val pointX = AtomicAllocator.getInstance().getAllocationPoint(x.data()); + // checking TAD equality + val row = array.getRow(1); + assertArrayEquals(new float[]{3.0f, 4.0f}, row.data().dup().asFloat(), 1e-5f); + } - assertNotNull(pointX); + @Test + public void testScalar_1() { + val scalar = Nd4j.scalar(119.f); - assertNull(pointX.getHostPointer()); - assertNotNull(pointX.getDevicePointer()); + // basic validation + assertNotNull(scalar); + assertNotNull(scalar.data()); + assertNotNull(((BaseCudaDataBuffer) scalar.data()).getOpaqueDataBuffer()); - assertEquals(0, ((CudaWorkspace) ws).getHostOffset()); + // shape part + assertArrayEquals(new long[]{0, 8192, 1, 99}, scalar.shapeInfoJava()); + assertArrayEquals(new long[]{0, 8192, 1, 99}, scalar.shapeInfoDataBuffer().asLong()); - x.getDouble(0); + // pointers part + val devPtr = AtomicAllocator.getInstance().getPointer(scalar.data()); + val hostPtr = AtomicAllocator.getInstance().getHostPointer(scalar.data()); + // dev pointer supposed to exist, and host pointer is not + assertNotNull(devPtr); + assertNull(hostPtr); - assertEquals(ws.getPrimaryOffset(), ((CudaWorkspace) ws).getHostOffset()); - assertNotEquals(0, ws.getPrimaryOffset()); + assertEquals(119.f, scalar.getFloat(0), 1e-5f); + } - assertNotNull(pointX.getHostPointer()); + @Test + public void testSerDe_1() throws Exception { + val array = Nd4j.createFromArray(1.f, 2.f, 3.f, 4.f, 5.f, 6.f); + val baos = new ByteArrayOutputStream(); + + Nd4j.write(baos, array); + INDArray restored = Nd4j.read(new ByteArrayInputStream(baos.toByteArray())); + + // basic validation + assertNotNull(restored); + assertNotNull(restored.data()); + assertNotNull(((BaseCudaDataBuffer) restored.data()).getOpaqueDataBuffer()); + + // shape part + assertArrayEquals(new long[]{1, 6, 1, 8192, 1, 99}, restored.shapeInfoJava()); + assertArrayEquals(new long[]{1, 6, 1, 8192, 1, 99}, restored.shapeInfoDataBuffer().asLong()); + + // data equality + assertArrayEquals(array.data().asFloat(), restored.data().asFloat(), 1e-5f); + } + + @Test + public void testBasicOpInvocation_1() { + val array1 = Nd4j.createFromArray(1.f, 2.f, 3.f, 4.f, 5.f, 6.f); + val array2 = Nd4j.createFromArray(1.f, 2.f, 3.f, 4.f, 5.f, 6.f); + + // shape pointers must be equal here + val devPtr1 = AtomicAllocator.getInstance().getPointer(array1.shapeInfoDataBuffer()); + val devPtr2 = AtomicAllocator.getInstance().getPointer(array2.shapeInfoDataBuffer()); + + val hostPtr1 = AtomicAllocator.getInstance().getHostPointer(array1.shapeInfoDataBuffer()); + val hostPtr2 = AtomicAllocator.getInstance().getHostPointer(array2.shapeInfoDataBuffer()); + + // pointers must be equal on host and device, since we have shape cache + assertEquals(devPtr1.address(), devPtr2.address()); + assertEquals(hostPtr1.address(), hostPtr2.address()); + + assertEquals(array1, array2); + } + + @Test + public void testBasicOpInvocation_2() { + val array1 = Nd4j.createFromArray(1.f, 200.f, 3.f, 4.f, 5.f, 6.f); + val array2 = Nd4j.createFromArray(1.f, 2.f, 3.f, 4.f, 5.f, 6.f); + + assertNotEquals(array1, array2); + } + + @Test + public void testBasicOpInvocation_3() { + val array = Nd4j.create(DataType.FLOAT, 6); + val exp = Nd4j.createFromArray(1.f, 1.f, 1.f, 1.f, 1.f, 1.f); + + array.addi(1.0f); + + assertEquals(exp, array); + } + + @Test + public void testCustomOpInvocation_1() { + val array = Nd4j.createFromArray(1.f, 2.f, 3.f, 4.f, 5.f, 6.f); + + Nd4j.exec(new PrintVariable(array, true)); + Nd4j.exec(new PrintVariable(array)); + } + + @Test + public void testMultiDeviceMigration_1() throws Exception { + if (Nd4j.getAffinityManager().getNumberOfDevices() < 2) + return; + + // creating all arrays within main thread context + val list = new ArrayList(); + for (int e = 0; e < Nd4j.getAffinityManager().getNumberOfDevices(); e++) + list.add(Nd4j.create(DataType.FLOAT, 3, 5)); + + val cnt = new AtomicInteger(0); + + // now we're creating threads + for (int e = 0; e < Nd4j.getAffinityManager().getNumberOfDevices(); e++) { + val f = e; + val t = new Thread(new Runnable() { + @Override + public void run() { + // issuing one operation, just to see how migration works + list.get(f).addi(1.0f); + + // synchronizing immediately + Nd4j.getExecutioner().commit(); + cnt.incrementAndGet(); + } + }); + + t.start(); + t.join(); } + + // there shoul dbe no exceptions during execution + assertEquals(Nd4j.getAffinityManager().getNumberOfDevices(), cnt.get()); } } \ No newline at end of file diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/pom.xml b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/pom.xml index 64be62442..48cdc3e03 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/pom.xml +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/pom.xml @@ -61,6 +61,7 @@ ${mkl.version}-${javacpp-presets.version} ${dependency.platform2} + org.nd4j nd4j-native-api diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/CpuBackend.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/CpuBackend.java index 627105bda..d3fe308e6 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/CpuBackend.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/CpuBackend.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.cpu.nativecpu; +import org.nd4j.linalg.factory.Environment; import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.io.ClassPathResource; import org.nd4j.linalg.io.Resource; @@ -61,6 +62,11 @@ public class CpuBackend extends Nd4jBackend { return NDArray.class; } + @Override + public Environment getEnvironment() { + return CpuEnvironment.getInstance(); + } + @Override public void logBackendInit() { //No additional logging for CPU backend diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/CpuEnvironment.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/CpuEnvironment.java new file mode 100644 index 000000000..363e8857b --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/CpuEnvironment.java @@ -0,0 +1,195 @@ +/* ****************************************************************************** + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.cpu.nativecpu; + +import org.nd4j.linalg.factory.Environment; +import org.nd4j.nativeblas.Nd4jCpu; + +/** + * CPU backend implementation of {@link Environment} + * + * @author Alex Black + */ +public class CpuEnvironment implements Environment { + + + private static final CpuEnvironment INSTANCE = new CpuEnvironment(Nd4jCpu.Environment.getInstance()); + + private final Nd4jCpu.Environment e; + + public static CpuEnvironment getInstance(){ + return INSTANCE; + } + + protected CpuEnvironment(Nd4jCpu.Environment environment){ + this.e = environment; + } + + @Override + public int blasMajorVersion() { + return e.blasMajorVersion(); + } + + @Override + public int blasMinorVersion() { + return e.blasMinorVersion(); + } + + @Override + public int blasPatchVersion() { + return e.blasMajorVersion(); + } + + @Override + public boolean isVerbose() { + return e.isVerbose(); + } + + @Override + public void setVerbose(boolean reallyVerbose) { + e.setVerbose(reallyVerbose); + } + + @Override + public boolean isDebug() { + return e.isDebug(); + } + + @Override + public boolean isProfiling() { + return e.isProfiling(); + } + + @Override + public boolean isDetectingLeaks() { + return e.isDetectingLeaks(); + } + + @Override + public boolean isDebugAndVerbose() { + return e.isDebugAndVerbose(); + } + + @Override + public void setDebug(boolean reallyDebug) { + e.setDebug(reallyDebug); + } + + @Override + public void setProfiling(boolean reallyProfile) { + e.setProfiling(reallyProfile); + } + + @Override + public void setLeaksDetector(boolean reallyDetect) { + e.setLeaksDetector(reallyDetect); + } + + @Override + public boolean helpersAllowed() { + return e.helpersAllowed(); + } + + @Override + public void allowHelpers(boolean reallyAllow) { + e.allowHelpers(reallyAllow); + } + + @Override + public int tadThreshold() { + return e.tadThreshold(); + } + + @Override + public void setTadThreshold(int threshold) { + e.setTadThreshold(threshold); + } + + @Override + public int elementwiseThreshold() { + return e.elementwiseThreshold(); + } + + @Override + public void setElementwiseThreshold(int threshold) { + e.setElementwiseThreshold(threshold); + } + + @Override + public int maxThreads() { + return e.maxThreads(); + } + + @Override + public void setMaxThreads(int max) { + e.setMaxThreads(max); + } + + @Override + public int maxMasterThreads() { + return e.maxMasterThreads(); + } + + @Override + public void setMaxMasterThreads(int max) { + e.setMaxMasterThreads(max); + } + + @Override + public void setMaxPrimaryMemory(long maxBytes) { + e.setMaxPrimaryMemory(maxBytes); + } + + @Override + public void setMaxSpecialMemory(long maxBytes) { + e.setMaxSpecialyMemory(maxBytes); + } + + @Override + public void setMaxDeviceMemory(long maxBytes) { + e.setMaxDeviceMemory(maxBytes); + } + + @Override + public boolean isCPU() { + return e.isCPU(); + } + + @Override + public void setGroupLimit(int group, long numBytes) { + e.setGroupLimit(group, numBytes); + } + + @Override + public void setDeviceLimit(int deviceId, long numBytes) { + e.setDeviceLimit(deviceId, numBytes); + } + + @Override + public long getGroupLimit(int group) { + return e.getGroupLimit(group); + } + + @Override + public long getDeviceLimit(int deviceId) { + return e.getDeviceLimit(deviceId); + } + + @Override + public long getDeviceCouner(int deviceId) { + return e.getDeviceCounter(deviceId); + } +} diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/CpuNDArrayFactory.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/CpuNDArrayFactory.java index 03904125d..38cb6610e 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/CpuNDArrayFactory.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/CpuNDArrayFactory.java @@ -25,27 +25,22 @@ import org.nd4j.config.ND4JSystemProperties; import org.nd4j.linalg.api.buffer.*; import org.nd4j.linalg.api.ops.custom.Flatten; import org.nd4j.linalg.api.ops.impl.shape.Concat; -import org.nd4j.linalg.api.ops.performance.PerformanceTracker; import org.nd4j.linalg.api.shape.options.ArrayOptionsHelper; import org.nd4j.linalg.api.shape.options.ArrayType; import org.nd4j.linalg.compression.CompressionUtils; -import org.nd4j.linalg.memory.MemcpyDirection; +import org.nd4j.linalg.cpu.nativecpu.buffer.BaseCpuDataBuffer; +import org.nd4j.linalg.cpu.nativecpu.buffer.LongBuffer; +import org.nd4j.linalg.cpu.nativecpu.buffer.Utf8Buffer; import org.nd4j.linalg.primitives.Pair; import org.bytedeco.javacpp.*; import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.shape.Shape; -import org.nd4j.linalg.api.shape.options.ArrayOptionsHelper; -import org.nd4j.linalg.api.shape.options.ArrayType; -import org.nd4j.linalg.cache.TADManager; import org.nd4j.linalg.compression.CompressedDataBuffer; import org.nd4j.linalg.compression.CompressionDescriptor; import org.nd4j.linalg.compression.CompressionType; -import org.nd4j.linalg.compression.CompressionUtils; import org.nd4j.linalg.cpu.nativecpu.blas.*; import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.util.ArrayUtil; import org.nd4j.nativeblas.BaseNativeNDArrayFactory; import org.nd4j.nativeblas.LongPointerWrapper; @@ -102,7 +97,8 @@ public class CpuNDArrayFactory extends BaseNativeNDArrayFactory { System.exit(1); } - if (!nativeOps.isOptimalRequirementsMet() && !Boolean.parseBoolean(System.getenv(ND4JEnvironmentVars.ND4J_IGNORE_AVX))) { + if (!nativeOps.isOptimalRequirementsMet() && !Boolean.parseBoolean(System.getenv(ND4JEnvironmentVars.ND4J_IGNORE_AVX)) && + !Boolean.parseBoolean(System.getProperty(ND4JSystemProperties.ND4J_IGNORE_AVX))) { val binaryLevel = nativeOps.binaryLevel(); val optimalLevel = nativeOps.optimalLevel(); @@ -565,11 +561,9 @@ public class CpuNDArrayFactory extends BaseNativeNDArrayFactory { } nativeOps.tear(null, - tensor.data().pointer(), (LongPointer) tensor.shapeInfoDataBuffer().pointer(), - null, null, + ((BaseCpuDataBuffer) tensor.data()).getOpaqueDataBuffer(), (LongPointer) tensor.shapeInfoDataBuffer().pointer(), null, targets, (LongPointer) result[0].shapeInfoDataBuffer().pointer(), - (LongPointer) tadBuffers.getFirst().pointer(), - new LongPointerWrapper(tadBuffers.getSecond().pointer()) + (LongPointer) tadBuffers.getFirst().pointer(), new LongPointerWrapper(tadBuffers.getSecond().pointer()) ); if (nativeOps.lastErrorCode() != 0) @@ -708,10 +702,8 @@ public class CpuNDArrayFactory extends BaseNativeNDArrayFactory { nativeOps.pullRows(dummy, - source.data().addressPointer(), (LongPointer) source.shapeInfoDataBuffer().addressPointer(), - null, null, - ret.data().addressPointer(), (LongPointer) ret.shapeInfoDataBuffer().addressPointer(), - null, null, + ((BaseCpuDataBuffer) source.data()).getOpaqueDataBuffer(), (LongPointer) source.shapeInfoDataBuffer().addressPointer(), null, + ((BaseCpuDataBuffer) ret.data()).getOpaqueDataBuffer(), (LongPointer) ret.shapeInfoDataBuffer().addressPointer(), null, indexes.length, pIndex, (LongPointer) hostTadShapeInfo, new LongPointerWrapper(hostTadOffsets), diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/CpuTADManager.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/CpuTADManager.java index 36599c859..c48178055 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/CpuTADManager.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/CpuTADManager.java @@ -18,20 +18,13 @@ package org.nd4j.linalg.cpu.nativecpu; import lombok.NonNull; import lombok.val; -import org.bytedeco.javacpp.IntPointer; -import org.bytedeco.javacpp.LongPointer; -import org.bytedeco.javacpp.Pointer; import org.nd4j.linalg.api.buffer.DataBuffer; -import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.api.buffer.IntBuffer; -import org.nd4j.linalg.api.buffer.LongBuffer; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.cache.ConstantHandler; import org.nd4j.linalg.cache.TADManager; import org.nd4j.linalg.cache.TadDescriptor; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.primitives.Pair; -import org.nd4j.nativeblas.LongPointerWrapper; import org.nd4j.nativeblas.NativeOps; import java.util.Arrays; diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/NDArray.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/NDArray.java index a6cd47fb0..aae332d78 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/NDArray.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/NDArray.java @@ -17,8 +17,12 @@ package org.nd4j.linalg.cpu.nativecpu; +import com.google.flatbuffers.FlatBufferBuilder; import lombok.val; +import org.bytedeco.javacpp.BytePointer; import org.bytedeco.javacpp.Pointer; +import org.nd4j.base.Preconditions; +import org.nd4j.graph.FlatArray; import org.nd4j.linalg.api.buffer.*; import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.ndarray.BaseNDArray; @@ -27,10 +31,17 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.JvmShapeInfo; import org.nd4j.linalg.api.ops.performance.PerformanceTracker; import org.nd4j.linalg.api.shape.LongShapeDescriptor; +import org.nd4j.linalg.cpu.nativecpu.buffer.DoubleBuffer; +import org.nd4j.linalg.cpu.nativecpu.buffer.FloatBuffer; +import org.nd4j.linalg.cpu.nativecpu.buffer.LongBuffer; +import org.nd4j.linalg.cpu.nativecpu.buffer.Utf8Buffer; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.memory.MemcpyDirection; import org.nd4j.linalg.workspace.WorkspaceUtils; +import java.io.ByteArrayOutputStream; +import java.io.DataOutputStream; +import java.io.IOException; import java.util.List; @@ -488,4 +499,36 @@ public class NDArray extends BaseNDArray { public LongShapeDescriptor shapeDescriptor() { return LongShapeDescriptor.fromShape(shape(), stride(), elementWiseStride(), ordering(), dataType(), isEmpty()); } + + protected int stringBuffer(FlatBufferBuilder builder, DataBuffer buffer) { + Preconditions.checkArgument(buffer.dataType() == DataType.UTF8, "This method can be called on UTF8 buffers only"); + try { + ByteArrayOutputStream bos = new ByteArrayOutputStream(); + DataOutputStream dos = new DataOutputStream(bos); + + val numWords = this.length(); + val ub = (Utf8Buffer) buffer; + // writing length first + val t = length(); + val ptr = (BytePointer) ub.pointer(); + + // now write all strings as bytes + for (int i = 0; i < ub.length(); i++) { + dos.writeByte(ptr.get(i)); + } + + val bytes = bos.toByteArray(); + return FlatArray.createBufferVector(builder, bytes); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Override + public String getString(long index) { + if (!isS()) + throw new UnsupportedOperationException("This method is usable only on String dataType, but got [" + this.dataType() + "]"); + + return ((Utf8Buffer) data).getString(index); + } } diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/BFloat16Buffer.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/BFloat16Buffer.java similarity index 91% rename from nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/BFloat16Buffer.java rename to nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/BFloat16Buffer.java index f2c8a9202..819b30339 100644 --- a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/BFloat16Buffer.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/BFloat16Buffer.java @@ -14,11 +14,14 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -package org.nd4j.linalg.api.buffer; +package org.nd4j.linalg.cpu.nativecpu.buffer; import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacpp.indexer.Indexer; +import org.nd4j.linalg.api.buffer.BaseDataBuffer; +import org.nd4j.linalg.api.buffer.DataBuffer; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; import java.nio.ByteBuffer; @@ -28,7 +31,7 @@ import java.nio.ByteBuffer; * * @author Adam Gibson */ -public class BFloat16Buffer extends BaseDataBuffer { +public class BFloat16Buffer extends BaseCpuDataBuffer { /** * Meant for creating another view of a buffer @@ -50,6 +53,10 @@ public class BFloat16Buffer extends BaseDataBuffer { } + public BFloat16Buffer(ByteBuffer buffer, DataType dataType, long length, long offset) { + super(buffer, dataType, length, offset); + } + public BFloat16Buffer(long length, boolean initialize) { super(length, initialize); } @@ -111,18 +118,6 @@ public class BFloat16Buffer extends BaseDataBuffer { super(data, copy, offset); } - public BFloat16Buffer(ByteBuffer buffer, int length) { - super(buffer, length); - } - - public BFloat16Buffer(ByteBuffer buffer, int length, long offset) { - super(buffer, length, offset); - } - - public BFloat16Buffer(byte[] data, int length) { - super(data, length); - } - public BFloat16Buffer(float[] floats, boolean copy) { super(floats, copy); } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/BaseCpuDataBuffer.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/BaseCpuDataBuffer.java new file mode 100644 index 000000000..ec4a0e51a --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/BaseCpuDataBuffer.java @@ -0,0 +1,939 @@ +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg.cpu.nativecpu.buffer; + +import lombok.val; +import org.bytedeco.javacpp.*; +import org.bytedeco.javacpp.indexer.*; +import org.nd4j.linalg.api.buffer.BaseDataBuffer; +import org.nd4j.linalg.api.buffer.DataBuffer; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.buffer.util.AllocUtil; +import org.nd4j.linalg.api.memory.Deallocatable; +import org.nd4j.linalg.api.memory.Deallocator; +import org.nd4j.linalg.api.memory.MemoryWorkspace; +import org.nd4j.linalg.api.memory.pointers.PagedPointer; +import org.nd4j.linalg.cpu.nativecpu.ops.NativeOpExecutioner; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.memory.deallocation.DeallocatorService; +import org.nd4j.nativeblas.NativeOps; +import org.nd4j.nativeblas.NativeOpsHolder; +import org.nd4j.nativeblas.OpaqueDataBuffer; + +import java.nio.ByteBuffer; + +import static org.nd4j.linalg.api.buffer.DataType.INT16; +import static org.nd4j.linalg.api.buffer.DataType.INT8; + +/** + * Base implementation for DataBuffer for CPU-like backend + * + * @author raver119@gmail.com + */ +public abstract class BaseCpuDataBuffer extends BaseDataBuffer implements Deallocatable { + + protected transient OpaqueDataBuffer ptrDataBuffer; + + private final long instanceId = Nd4j.getDeallocatorService().nextValue(); + + protected BaseCpuDataBuffer() { + + } + + + @Override + public String getUniqueId() { + return "BCDB_" + instanceId; + } + + @Override + public Deallocator deallocator() { + return new CpuDeallocator(this); + } + + public OpaqueDataBuffer getOpaqueDataBuffer() { + return ptrDataBuffer; + } + + @Override + public int targetDevice() { + // TODO: once we add NUMA support this might change. Or might not. + return 0; + } + + + /** + * + * @param length + * @param elementSize + */ + public BaseCpuDataBuffer(long length, int elementSize) { + if (length < 1) + throw new IllegalArgumentException("Length must be >= 1"); + initTypeAndSize(); + allocationMode = AllocUtil.getAllocationModeFromContext(); + this.length = length; + this.underlyingLength = length; + this.elementSize = (byte) elementSize; + + if (dataType() != DataType.UTF8) + ptrDataBuffer = OpaqueDataBuffer.allocateDataBuffer(length, dataType(), false); + + if (dataType() == DataType.DOUBLE) { + pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asDoublePointer(); + + indexer = DoubleIndexer.create((DoublePointer) pointer); + } else if (dataType() == DataType.FLOAT) { + pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asFloatPointer(); + + setIndexer(FloatIndexer.create((FloatPointer) pointer)); + } else if (dataType() == DataType.INT32) { + pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asIntPointer(); + + setIndexer(IntIndexer.create((IntPointer) pointer)); + } else if (dataType() == DataType.LONG) { + pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asLongPointer(); + + setIndexer(LongIndexer.create((LongPointer) pointer)); + } else if (dataType() == DataType.SHORT) { + pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asShortPointer(); + + setIndexer(ShortIndexer.create((ShortPointer) pointer)); + } else if (dataType() == DataType.BYTE) { + pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asBytePointer(); + + setIndexer(ByteIndexer.create((BytePointer) pointer)); + } else if (dataType() == DataType.UBYTE) { + pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asBytePointer(); + + setIndexer(UByteIndexer.create((BytePointer) pointer)); + } else if (dataType() == DataType.UTF8) { + ptrDataBuffer = OpaqueDataBuffer.allocateDataBuffer(length, INT8, false); + pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asBytePointer(); + + setIndexer(ByteIndexer.create((BytePointer) pointer)); + } + + Nd4j.getDeallocatorService().pickObject(this); + } + + /** + * + * @param length + * @param elementSize + */ + public BaseCpuDataBuffer(int length, int elementSize, long offset) { + this(length, elementSize); + this.offset = offset; + this.originalOffset = offset; + this.length = length - offset; + this.underlyingLength = length; + } + + + protected BaseCpuDataBuffer(DataBuffer underlyingBuffer, long length, long offset) { + super(underlyingBuffer, length, offset); + + // for vew we need "externally managed" pointer and deallocator registration + ptrDataBuffer = ((BaseCpuDataBuffer) underlyingBuffer).ptrDataBuffer.createView(length * underlyingBuffer.getElementSize(), offset * underlyingBuffer.getElementSize()); + Nd4j.getDeallocatorService().pickObject(this); + + + // update pointer now + actualizePointerAndIndexer(); + } + + protected BaseCpuDataBuffer(ByteBuffer buffer, DataType dtype, long length, long offset) { + this(length, Nd4j.sizeOfDataType(dtype)); + + Pointer temp = null; + + switch (dataType()){ + case DOUBLE: + temp = new DoublePointer(buffer.asDoubleBuffer()); + break; + case FLOAT: + temp = new FloatPointer(buffer.asFloatBuffer()); + break; + case HALF: + temp = new ShortPointer(buffer.asShortBuffer()); + break; + case LONG: + temp = new LongPointer(buffer.asLongBuffer()); + break; + case INT: + temp = new IntPointer(buffer.asIntBuffer()); + break; + case SHORT: + temp = new ShortPointer(buffer.asShortBuffer()); + break; + case UBYTE: //Fall through + case BYTE: + temp = new BytePointer(buffer); + break; + case BOOL: + temp = new BooleanPointer(length()); + break; + case UTF8: + temp = new BytePointer(length()); + break; + case BFLOAT16: + temp = new ShortPointer(length()); + break; + case UINT16: + temp = new ShortPointer(length()); + break; + case UINT32: + temp = new IntPointer(length()); + break; + case UINT64: + temp = new LongPointer(length()); + break; + } + + val ptr = ptrDataBuffer.primaryBuffer(); + + if (offset > 0) + temp = new PagedPointer(temp.address() + offset * getElementSize()); + + Pointer.memcpy(ptr, temp, length * Nd4j.sizeOfDataType(dtype)); + } + + @Override + public void pointerIndexerByCurrentType(DataType currentType) { + + type = currentType; + + if (ptrDataBuffer == null) { + ptrDataBuffer = OpaqueDataBuffer.allocateDataBuffer(length(), type, false); + Nd4j.getDeallocatorService().pickObject(this); + } + + actualizePointerAndIndexer(); + } + + /** + * Instantiate a buffer with the given length + * + * @param length the length of the buffer + */ + protected BaseCpuDataBuffer(long length) { + this(length, true); + } + + protected BaseCpuDataBuffer(long length, boolean initialize) { + if (length < 0) + throw new IllegalArgumentException("Length must be >= 0"); + initTypeAndSize(); + this.length = length; + this.underlyingLength = length; + allocationMode = AllocUtil.getAllocationModeFromContext(); + if (length < 0) + throw new IllegalArgumentException("Unable to create a buffer of length <= 0"); + + if (dataType() != DataType.UTF8) + ptrDataBuffer = OpaqueDataBuffer.allocateDataBuffer(length, dataType(), false); + + if (dataType() == DataType.DOUBLE) { + pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asDoublePointer(); + + indexer = DoubleIndexer.create((DoublePointer) pointer); + + if (initialize) + fillPointerWithZero(); + } else if (dataType() == DataType.FLOAT) { + pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asFloatPointer(); + + setIndexer(FloatIndexer.create((FloatPointer) pointer)); + + if (initialize) + fillPointerWithZero(); + + } else if (dataType() == DataType.HALF) { + pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asShortPointer(); + + setIndexer(HalfIndexer.create((ShortPointer) pointer)); + + if (initialize) + fillPointerWithZero(); + } else if (dataType() == DataType.BFLOAT16) { + pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asShortPointer(); + + setIndexer(Bfloat16Indexer.create((ShortPointer) pointer)); + + if (initialize) + fillPointerWithZero(); + } else if (dataType() == DataType.INT) { + pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asIntPointer(); + + setIndexer(IntIndexer.create((IntPointer) pointer)); + if (initialize) + fillPointerWithZero(); + } else if (dataType() == DataType.LONG) { + pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asLongPointer(); + + setIndexer(LongIndexer.create((LongPointer) pointer)); + + if (initialize) + fillPointerWithZero(); + } else if (dataType() == DataType.BYTE) { + pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asBytePointer(); + + setIndexer(ByteIndexer.create((BytePointer) pointer)); + + if (initialize) + fillPointerWithZero(); + } else if (dataType() == DataType.SHORT) { + pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asShortPointer(); + + setIndexer(ShortIndexer.create((ShortPointer) pointer)); + + if (initialize) + fillPointerWithZero(); + } else if (dataType() == DataType.UBYTE) { + pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asBytePointer(); + + setIndexer(UByteIndexer.create((BytePointer) pointer)); + + if (initialize) + fillPointerWithZero(); + } else if (dataType() == DataType.UINT16) { + pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asShortPointer(); + + setIndexer(UShortIndexer.create((ShortPointer) pointer)); + + if (initialize) + fillPointerWithZero(); + } else if (dataType() == DataType.UINT32) { + pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asIntPointer(); + + // FIXME: we need unsigned indexer here + setIndexer(IntIndexer.create((IntPointer) pointer)); + + if (initialize) + fillPointerWithZero(); + } else if (dataType() == DataType.UINT64) { + pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asLongPointer(); + + // FIXME: we need unsigned indexer here + setIndexer(LongIndexer.create((LongPointer) pointer)); + + if (initialize) + fillPointerWithZero(); + } else if (dataType() == DataType.BOOL) { + pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asBoolPointer(); + + setIndexer(BooleanIndexer.create((BooleanPointer) pointer)); + + if (initialize) + fillPointerWithZero(); + } else if (dataType() == DataType.UTF8) { + // we are allocating buffer as INT8 intentionally + ptrDataBuffer = OpaqueDataBuffer.allocateDataBuffer(length(), INT8, false); + pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length()).asBytePointer(); + + setIndexer(ByteIndexer.create((BytePointer) pointer)); + + if (initialize) + fillPointerWithZero(); + } + + Nd4j.getDeallocatorService().pickObject(this); + } + + public void actualizePointerAndIndexer() { + val cptr = ptrDataBuffer.primaryBuffer(); + + // skip update if pointers are equal + if (cptr != null && pointer != null && cptr.address() == pointer.address()) + return; + + val t = dataType(); + if (t == DataType.BOOL) { + pointer = new PagedPointer(cptr, length).asBoolPointer(); + setIndexer(BooleanIndexer.create((BooleanPointer) pointer)); + } else if (t == DataType.UBYTE) { + pointer = new PagedPointer(cptr, length).asBytePointer(); + setIndexer(UByteIndexer.create((BytePointer) pointer)); + } else if (t == DataType.BYTE) { + pointer = new PagedPointer(cptr, length).asBytePointer(); + setIndexer(ByteIndexer.create((BytePointer) pointer)); + } else if (t == DataType.UINT16) { + pointer = new PagedPointer(cptr, length).asShortPointer(); + setIndexer(UShortIndexer.create((ShortPointer) pointer)); + } else if (t == DataType.SHORT) { + pointer = new PagedPointer(cptr, length).asShortPointer(); + setIndexer(ShortIndexer.create((ShortPointer) pointer)); + } else if (t == DataType.UINT32) { + pointer = new PagedPointer(cptr, length).asIntPointer(); + setIndexer(IntIndexer.create((IntPointer) pointer)); + } else if (t == DataType.INT) { + pointer = new PagedPointer(cptr, length).asIntPointer(); + setIndexer(IntIndexer.create((IntPointer) pointer)); + } else if (t == DataType.UINT64) { + pointer = new PagedPointer(cptr, length).asLongPointer(); + setIndexer(LongIndexer.create((LongPointer) pointer)); + } else if (t == DataType.LONG) { + pointer = new PagedPointer(cptr, length).asLongPointer(); + setIndexer(LongIndexer.create((LongPointer) pointer)); + } else if (t == DataType.BFLOAT16) { + pointer = new PagedPointer(cptr, length).asShortPointer(); + setIndexer(Bfloat16Indexer.create((ShortPointer) pointer)); + } else if (t == DataType.HALF) { + pointer = new PagedPointer(cptr, length).asShortPointer(); + setIndexer(HalfIndexer.create((ShortPointer) pointer)); + } else if (t == DataType.FLOAT) { + pointer = new PagedPointer(cptr, length).asFloatPointer(); + setIndexer(FloatIndexer.create((FloatPointer) pointer)); + } else if (t == DataType.DOUBLE) { + pointer = new PagedPointer(cptr, length).asDoublePointer(); + setIndexer(DoubleIndexer.create((DoublePointer) pointer)); + } else if (t == DataType.UTF8) { + pointer = new PagedPointer(cptr, length()).asBytePointer(); + setIndexer(ByteIndexer.create((BytePointer) pointer)); + } else + throw new IllegalArgumentException("Unknown datatype: " + dataType()); + } + + @Override + public Pointer addressPointer() { + // we're fetching actual pointer right from C++ + val tempPtr = new PagedPointer(ptrDataBuffer.primaryBuffer()); + + switch (this.type) { + case DOUBLE: return tempPtr.asDoublePointer(); + case FLOAT: return tempPtr.asFloatPointer(); + case UINT16: + case SHORT: + case BFLOAT16: + case HALF: return tempPtr.asShortPointer(); + case UINT32: + case INT: return tempPtr.asIntPointer(); + case UBYTE: + case BYTE: return tempPtr.asBytePointer(); + case UINT64: + case LONG: return tempPtr.asLongPointer(); + case BOOL: return tempPtr.asBoolPointer(); + default: return tempPtr.asBytePointer(); + } + } + + protected BaseCpuDataBuffer(long length, boolean initialize, MemoryWorkspace workspace) { + if (length < 1) + throw new IllegalArgumentException("Length must be >= 1"); + initTypeAndSize(); + this.length = length; + this.underlyingLength = length; + allocationMode = AllocUtil.getAllocationModeFromContext(); + + + + if (length < 0) + throw new IllegalArgumentException("Unable to create a buffer of length <= 0"); + + // creating empty native DataBuffer + ptrDataBuffer = OpaqueDataBuffer.allocateDataBuffer(0, dataType(), false); + + if (dataType() == DataType.DOUBLE) { + attached = true; + parentWorkspace = workspace; + + pointer = workspace.alloc(length * getElementSize(), dataType(), initialize).asDoublePointer(); //new DoublePointer(length()); + indexer = DoubleIndexer.create((DoublePointer) pointer); + + } else if (dataType() == DataType.FLOAT) { + attached = true; + parentWorkspace = workspace; + + pointer = workspace.alloc(length * getElementSize(), dataType(), initialize).asFloatPointer(); //new FloatPointer(length()); + setIndexer(FloatIndexer.create((FloatPointer) pointer)); + + } else if (dataType() == DataType.HALF) { + attached = true; + parentWorkspace = workspace; + + pointer = workspace.alloc(length * getElementSize(), dataType(), initialize).asShortPointer(); //new FloatPointer(length()); + setIndexer(HalfIndexer.create((ShortPointer) pointer)); + + } else if (dataType() == DataType.BFLOAT16) { + attached = true; + parentWorkspace = workspace; + + pointer = workspace.alloc(length * getElementSize(), dataType(), initialize).asShortPointer(); //new FloatPointer(length()); + setIndexer(Bfloat16Indexer.create((ShortPointer) pointer)); + } else if (dataType() == DataType.INT) { + attached = true; + parentWorkspace = workspace; + + pointer = workspace.alloc(length * getElementSize(), dataType(), initialize).asIntPointer(); //new IntPointer(length()); + setIndexer(IntIndexer.create((IntPointer) pointer)); + + } else if (dataType() == DataType.UINT32) { + attached = true; + parentWorkspace = workspace; + + // FIXME: need unsigned indexer here + pointer = workspace.alloc(length * getElementSize(), dataType(), initialize).asIntPointer(); //new IntPointer(length()); + setIndexer(IntIndexer.create((IntPointer) pointer)); + + } else if (dataType() == DataType.UINT64) { + attached = true; + parentWorkspace = workspace; + + // FIXME: need unsigned indexer here + pointer = workspace.alloc(length * getElementSize(), dataType(), initialize).asLongPointer(); //new IntPointer(length()); + setIndexer(LongIndexer.create((LongPointer) pointer)); + + } else if (dataType() == DataType.LONG) { + attached = true; + parentWorkspace = workspace; + + pointer = workspace.alloc(length * getElementSize(), dataType(), initialize).asLongPointer(); //new LongPointer(length()); + setIndexer(LongIndexer.create((LongPointer) pointer)); + } else if (dataType() == DataType.BYTE) { + attached = true; + parentWorkspace = workspace; + + pointer = workspace.alloc(length * getElementSize(), dataType(), initialize).asBytePointer(); //new LongPointer(length()); + setIndexer(ByteIndexer.create((BytePointer) pointer)); + } else if (dataType() == DataType.UBYTE) { + attached = true; + parentWorkspace = workspace; + + pointer = workspace.alloc(length * getElementSize(), dataType(), initialize).asBytePointer(); //new LongPointer(length()); + setIndexer(UByteIndexer.create((BytePointer) pointer)); + } else if (dataType() == DataType.UINT16) { + attached = true; + parentWorkspace = workspace; + + pointer = workspace.alloc(length * getElementSize(), dataType(), initialize).asShortPointer(); //new IntPointer(length()); + setIndexer(UShortIndexer.create((ShortPointer) pointer)); + + } else if (dataType() == DataType.SHORT) { + attached = true; + parentWorkspace = workspace; + + pointer = workspace.alloc(length * getElementSize(), dataType(), initialize).asShortPointer(); //new LongPointer(length()); + setIndexer(ShortIndexer.create((ShortPointer) pointer)); + } else if (dataType() == DataType.BOOL) { + attached = true; + parentWorkspace = workspace; + + pointer = workspace.alloc(length * getElementSize(), dataType(), initialize).asBoolPointer(); //new LongPointer(length()); + setIndexer(BooleanIndexer.create((BooleanPointer) pointer)); + } else if (dataType() == DataType.UTF8) { + attached = true; + parentWorkspace = workspace; + + pointer = workspace.alloc(length * getElementSize(), dataType(), initialize).asLongPointer(); //new LongPointer(length()); + setIndexer(LongIndexer.create((LongPointer) pointer)); + } + + // storing pointer into native DataBuffer + ptrDataBuffer.setPrimaryBuffer(pointer, length); + + // adding deallocator reference + Nd4j.getDeallocatorService().pickObject(this); + + workspaceGenerationId = workspace.getGenerationId(); + } + + public BaseCpuDataBuffer(Pointer pointer, Indexer indexer, long length) { + super(pointer, indexer, length); + + ptrDataBuffer = OpaqueDataBuffer.allocateDataBuffer(0, type, false); + ptrDataBuffer.setPrimaryBuffer(this.pointer, length); + Nd4j.getDeallocatorService().pickObject(this);; + } + + /** + * + * @param data + * @param copy + */ + public BaseCpuDataBuffer(float[] data, boolean copy, long offset) { + this(data, copy); + this.offset = offset; + this.originalOffset = offset; + this.length = data.length - offset; + this.underlyingLength = data.length; + + } + + public BaseCpuDataBuffer(float[] data, boolean copy, long offset, MemoryWorkspace workspace) { + this(data, copy, workspace); + this.offset = offset; + this.originalOffset = offset; + this.length = data.length - offset; + this.underlyingLength = data.length; + } + + /** + * + * @param data + * @param copy + */ + public BaseCpuDataBuffer(float[] data, boolean copy) { + allocationMode = AllocUtil.getAllocationModeFromContext(); + initTypeAndSize(); + + pointer = new FloatPointer(data); + + // creating & registering native DataBuffer + ptrDataBuffer = OpaqueDataBuffer.allocateDataBuffer(data.length, DataType.FLOAT, false); + ptrDataBuffer.setPrimaryBuffer(pointer, data.length); + Nd4j.getDeallocatorService().pickObject(this); + + setIndexer(FloatIndexer.create((FloatPointer) pointer)); + //wrappedBuffer = pointer.asByteBuffer(); + + length = data.length; + underlyingLength = data.length; + } + + public BaseCpuDataBuffer(float[] data, boolean copy, MemoryWorkspace workspace) { + allocationMode = AllocUtil.getAllocationModeFromContext(); + length = data.length; + underlyingLength = data.length; + attached = true; + parentWorkspace = workspace; + + initTypeAndSize(); + + //log.info("Allocating FloatPointer from array of {} elements", data.length); + + pointer = workspace.alloc(data.length * getElementSize(), dataType(), false).asFloatPointer().put(data); + + this.ptrDataBuffer = OpaqueDataBuffer.allocateDataBuffer(0, dataType(), false); + this.ptrDataBuffer.setPrimaryBuffer(pointer, this.length); + Nd4j.getDeallocatorService().pickObject(this); + + workspaceGenerationId = workspace.getGenerationId(); + setIndexer(FloatIndexer.create((FloatPointer) pointer)); + //wrappedBuffer = pointer.asByteBuffer(); + } + + public BaseCpuDataBuffer(double[] data, boolean copy, MemoryWorkspace workspace) { + allocationMode = AllocUtil.getAllocationModeFromContext(); + length = data.length; + underlyingLength = data.length; + attached = true; + parentWorkspace = workspace; + + initTypeAndSize(); + + //log.info("Allocating FloatPointer from array of {} elements", data.length); + + pointer = workspace.alloc(data.length * getElementSize(), dataType(), false).asDoublePointer().put(data); + + this.ptrDataBuffer = OpaqueDataBuffer.allocateDataBuffer(0, dataType(), false); + this.ptrDataBuffer.setPrimaryBuffer(pointer, this.length); + Nd4j.getDeallocatorService().pickObject(this); + + workspaceGenerationId = workspace.getGenerationId(); + indexer = DoubleIndexer.create((DoublePointer) pointer); + //wrappedBuffer = pointer.asByteBuffer(); + } + + + public BaseCpuDataBuffer(int[] data, boolean copy, MemoryWorkspace workspace) { + allocationMode = AllocUtil.getAllocationModeFromContext(); + length = data.length; + underlyingLength = data.length; + attached = true; + parentWorkspace = workspace; + + initTypeAndSize(); + + //log.info("Allocating FloatPointer from array of {} elements", data.length); + + pointer = workspace.alloc(data.length * getElementSize(), dataType(), false).asIntPointer().put(data); + + this.ptrDataBuffer = OpaqueDataBuffer.allocateDataBuffer(0, dataType(), false); + this.ptrDataBuffer.setPrimaryBuffer(pointer, this.length); + Nd4j.getDeallocatorService().pickObject(this); + + workspaceGenerationId = workspace.getGenerationId(); + indexer = IntIndexer.create((IntPointer) pointer); + //wrappedBuffer = pointer.asByteBuffer(); + } + + public BaseCpuDataBuffer(long[] data, boolean copy, MemoryWorkspace workspace) { + allocationMode = AllocUtil.getAllocationModeFromContext(); + length = data.length; + underlyingLength = data.length; + attached = true; + parentWorkspace = workspace; + + initTypeAndSize(); + + //log.info("Allocating FloatPointer from array of {} elements", data.length); + + pointer = workspace.alloc(data.length * getElementSize(), dataType(), false).asLongPointer().put(data); + + this.ptrDataBuffer = OpaqueDataBuffer.allocateDataBuffer(0, dataType(), false); + this.ptrDataBuffer.setPrimaryBuffer(pointer, this.length); + Nd4j.getDeallocatorService().pickObject(this); + + workspaceGenerationId = workspace.getGenerationId(); + indexer = LongIndexer.create((LongPointer) pointer); + //wrappedBuffer = pointer.asByteBuffer(); + } + + + /** + * + * @param data + * @param copy + */ + public BaseCpuDataBuffer(double[] data, boolean copy, long offset) { + this(data, copy); + this.offset = offset; + this.originalOffset = offset; + this.underlyingLength = data.length; + this.length = underlyingLength - offset; + } + + public BaseCpuDataBuffer(double[] data, boolean copy, long offset, MemoryWorkspace workspace) { + this(data, copy, workspace); + this.offset = offset; + this.originalOffset = offset; + this.underlyingLength = data.length; + this.length = underlyingLength - offset; + } + + /** + * + * @param data + * @param copy + */ + public BaseCpuDataBuffer(double[] data, boolean copy) { + allocationMode = AllocUtil.getAllocationModeFromContext(); + initTypeAndSize(); + + pointer = new DoublePointer(data); + indexer = DoubleIndexer.create((DoublePointer) pointer); + + // creating & registering native DataBuffer + ptrDataBuffer = OpaqueDataBuffer.allocateDataBuffer(data.length, DataType.DOUBLE, false); + ptrDataBuffer.setPrimaryBuffer(pointer, data.length); + Nd4j.getDeallocatorService().pickObject(this); + + length = data.length; + underlyingLength = data.length; + } + + + /** + * + * @param data + * @param copy + */ + public BaseCpuDataBuffer(int[] data, boolean copy, long offset) { + this(data, copy); + this.offset = offset; + this.originalOffset = offset; + this.length = data.length - offset; + this.underlyingLength = data.length; + } + + /** + * + * @param data + * @param copy + */ + public BaseCpuDataBuffer(int[] data, boolean copy) { + allocationMode = AllocUtil.getAllocationModeFromContext(); + initTypeAndSize(); + + pointer = new IntPointer(data); + setIndexer(IntIndexer.create((IntPointer) pointer)); + + // creating & registering native DataBuffer + ptrDataBuffer = OpaqueDataBuffer.allocateDataBuffer(data.length, DataType.INT32, false); + ptrDataBuffer.setPrimaryBuffer(pointer, data.length); + Nd4j.getDeallocatorService().pickObject(this); + + length = data.length; + underlyingLength = data.length; + } + + /** + * + * @param data + * @param copy + */ + public BaseCpuDataBuffer(long[] data, boolean copy) { + allocationMode = AllocUtil.getAllocationModeFromContext(); + initTypeAndSize(); + + pointer = new LongPointer(data); + setIndexer(LongIndexer.create((LongPointer) pointer)); + + // creating & registering native DataBuffer + ptrDataBuffer = OpaqueDataBuffer.allocateDataBuffer(data.length, DataType.INT64, false); + ptrDataBuffer.setPrimaryBuffer(pointer, data.length); + Nd4j.getDeallocatorService().pickObject(this); + + length = data.length; + underlyingLength = data.length; + } + + + /** + * + * @param data + */ + public BaseCpuDataBuffer(double[] data) { + this(data, true); + } + + /** + * + * @param data + */ + public BaseCpuDataBuffer(int[] data) { + this(data, true); + } + + /** + * + * @param data + */ + public BaseCpuDataBuffer(float[] data) { + this(data, true); + } + + public BaseCpuDataBuffer(float[] data, MemoryWorkspace workspace) { + this(data, true, workspace); + } + + /** + * Reallocate the native memory of the buffer + * @param length the new length of the buffer + * @return this databuffer + * */ + @Override + public DataBuffer reallocate(long length) { + val oldPointer = ptrDataBuffer.primaryBuffer(); + + if (isAttached()) { + val capacity = length * getElementSize(); + val nPtr = getParentWorkspace().alloc(capacity, dataType(), false); + this.ptrDataBuffer.setPrimaryBuffer(nPtr, length); + + switch (dataType()) { + case BOOL: + pointer = nPtr.asBoolPointer(); + indexer = BooleanIndexer.create((BooleanPointer) pointer); + break; + case UTF8: + case BYTE: + case UBYTE: + pointer = nPtr.asBytePointer(); + indexer = ByteIndexer.create((BytePointer) pointer); + break; + case UINT16: + case SHORT: + pointer = nPtr.asShortPointer(); + indexer = ShortIndexer.create((ShortPointer) pointer); + break; + case UINT32: + case INT: + pointer = nPtr.asIntPointer(); + indexer = IntIndexer.create((IntPointer) pointer); + break; + case DOUBLE: + pointer = nPtr.asDoublePointer(); + indexer = DoubleIndexer.create((DoublePointer) pointer); + break; + case FLOAT: + pointer = nPtr.asFloatPointer(); + indexer = FloatIndexer.create((FloatPointer) pointer); + break; + case HALF: + pointer = nPtr.asShortPointer(); + indexer = HalfIndexer.create((ShortPointer) pointer); + break; + case BFLOAT16: + pointer = nPtr.asShortPointer(); + indexer = Bfloat16Indexer.create((ShortPointer) pointer); + break; + case UINT64: + case LONG: + pointer = nPtr.asLongPointer(); + indexer = LongIndexer.create((LongPointer) pointer); + break; + } + + Pointer.memcpy(pointer, oldPointer, this.length() * getElementSize()); + workspaceGenerationId = getParentWorkspace().getGenerationId(); + } else { + this.ptrDataBuffer.expand(length); + val nPtr = new PagedPointer(this.ptrDataBuffer.primaryBuffer(), length); + + switch (dataType()) { + case BOOL: + pointer = nPtr.asBoolPointer(); + indexer = BooleanIndexer.create((BooleanPointer) pointer); + break; + case UTF8: + case BYTE: + case UBYTE: + pointer = nPtr.asBytePointer(); + indexer = ByteIndexer.create((BytePointer) pointer); + break; + case UINT16: + case SHORT: + pointer = nPtr.asShortPointer(); + indexer = ShortIndexer.create((ShortPointer) pointer); + break; + case UINT32: + case INT: + pointer = nPtr.asIntPointer(); + indexer = IntIndexer.create((IntPointer) pointer); + break; + case DOUBLE: + pointer = nPtr.asDoublePointer(); + indexer = DoubleIndexer.create((DoublePointer) pointer); + break; + case FLOAT: + pointer = nPtr.asFloatPointer(); + indexer = FloatIndexer.create((FloatPointer) pointer); + break; + case HALF: + pointer = nPtr.asShortPointer(); + indexer = HalfIndexer.create((ShortPointer) pointer); + break; + case BFLOAT16: + pointer = nPtr.asShortPointer(); + indexer = Bfloat16Indexer.create((ShortPointer) pointer); + break; + case UINT64: + case LONG: + pointer = nPtr.asLongPointer(); + indexer = LongIndexer.create((LongPointer) pointer); + break; + } + } + + this.underlyingLength = length; + this.length = length; + return this; + } + +} diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/BoolBuffer.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/BoolBuffer.java similarity index 91% rename from nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/BoolBuffer.java rename to nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/BoolBuffer.java index 51ea5ca25..6f1bb5f99 100644 --- a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/BoolBuffer.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/BoolBuffer.java @@ -14,11 +14,14 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -package org.nd4j.linalg.api.buffer; +package org.nd4j.linalg.cpu.nativecpu.buffer; import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacpp.indexer.Indexer; +import org.nd4j.linalg.api.buffer.BaseDataBuffer; +import org.nd4j.linalg.api.buffer.DataBuffer; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; import java.nio.ByteBuffer; @@ -28,7 +31,7 @@ import java.nio.ByteBuffer; * * @author Adam Gibson */ -public class BoolBuffer extends BaseDataBuffer { +public class BoolBuffer extends BaseCpuDataBuffer { /** * Meant for creating another view of a buffer @@ -47,7 +50,10 @@ public class BoolBuffer extends BaseDataBuffer { */ public BoolBuffer(long length) { super(length); + } + public BoolBuffer(ByteBuffer buffer, DataType dataType, long length, long offset) { + super(buffer, dataType, length, offset); } public BoolBuffer(long length, boolean initialize) { @@ -111,18 +117,6 @@ public class BoolBuffer extends BaseDataBuffer { super(data, copy, offset); } - public BoolBuffer(ByteBuffer buffer, int length) { - super(buffer, length); - } - - public BoolBuffer(ByteBuffer buffer, int length, long offset) { - super(buffer, length, offset); - } - - public BoolBuffer(byte[] data, int length) { - super(data, length); - } - public BoolBuffer(float[] floats, boolean copy) { super(floats, copy); } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/CpuDeallocator.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/CpuDeallocator.java new file mode 100644 index 000000000..3b8a46fa6 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/CpuDeallocator.java @@ -0,0 +1,44 @@ +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg.cpu.nativecpu.buffer; + +import lombok.extern.slf4j.Slf4j; +import org.nd4j.linalg.api.memory.Deallocator; +import org.nd4j.nativeblas.NativeOpsHolder; +import org.nd4j.nativeblas.OpaqueDataBuffer; + +/** + * This class is responsible for OpaqueDataBuffer deletion on native side, once it's not used anymore in Java + * + * @author raver119@gmail.com + */ +@Slf4j +public class CpuDeallocator implements Deallocator { + private OpaqueDataBuffer opaqueDataBuffer; + + public CpuDeallocator(BaseCpuDataBuffer buffer) { + opaqueDataBuffer = buffer.getOpaqueDataBuffer(); + } + + @Override + public void deallocate() { + if (opaqueDataBuffer == null) + throw new RuntimeException("opaqueDataBuffer is null"); + + NativeOpsHolder.getInstance().getDeviceNativeOps().deleteDataBuffer(opaqueDataBuffer); + } +} diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/factory/DefaultDataBufferFactory.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/DefaultDataBufferFactory.java similarity index 93% rename from nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/factory/DefaultDataBufferFactory.java rename to nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/DefaultDataBufferFactory.java index 65d605e00..54b02e309 100644 --- a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/factory/DefaultDataBufferFactory.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/DefaultDataBufferFactory.java @@ -14,7 +14,7 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -package org.nd4j.linalg.api.buffer.factory; +package org.nd4j.linalg.cpu.nativecpu.buffer; import lombok.NonNull; import org.bytedeco.javacpp.DoublePointer; @@ -26,6 +26,7 @@ import org.bytedeco.javacpp.indexer.FloatIndexer; import org.bytedeco.javacpp.indexer.Indexer; import org.bytedeco.javacpp.indexer.IntIndexer; import org.nd4j.linalg.api.buffer.*; +import org.nd4j.linalg.api.buffer.factory.DataBufferFactory; import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.util.ArrayUtil; @@ -93,20 +94,6 @@ public class DefaultDataBufferFactory implements DataBufferFactory { return null; } - @Override - public DataBuffer createInt(long offset, ByteBuffer buffer, int length) { - return new IntBuffer(buffer, length, offset); - } - - @Override - public DataBuffer createFloat(long offset, ByteBuffer buffer, int length) { - return new FloatBuffer(buffer, length, offset); - } - - @Override - public DataBuffer createDouble(long offset, ByteBuffer buffer, int length) { - return new DoubleBuffer(buffer, length, offset); - } @Override public DataBuffer createDouble(long offset, int length) { @@ -236,25 +223,6 @@ public class DefaultDataBufferFactory implements DataBufferFactory { return new IntBuffer(ArrayUtil.toInts(data), copy, offset); } - @Override - public DataBuffer createInt(ByteBuffer buffer, int length) { - return new IntBuffer(buffer, length); - } - - @Override - public DataBuffer createLong(ByteBuffer buffer, int length) { - return new LongBuffer(buffer, length); - } - - @Override - public DataBuffer createFloat(ByteBuffer buffer, int length) { - return new FloatBuffer(buffer, length); - } - - @Override - public DataBuffer createDouble(ByteBuffer buffer, int length) { - return new DoubleBuffer(buffer, length); - } @Override public DataBuffer createDouble(long length) { @@ -281,6 +249,42 @@ public class DefaultDataBufferFactory implements DataBufferFactory { return new FloatBuffer(length, initialize, workspace); } + @Override + public DataBuffer create(ByteBuffer underlyingBuffer, DataType dataType, long length, long offset) { + switch (dataType) { + case DOUBLE: + return new DoubleBuffer(underlyingBuffer, dataType, length, offset); + case FLOAT: + return new FloatBuffer(underlyingBuffer, dataType, length, offset); + case HALF: + return new HalfBuffer(underlyingBuffer, dataType, length, offset); + case BFLOAT16: + return new BFloat16Buffer(underlyingBuffer, dataType, length, offset); + case LONG: + return new LongBuffer(underlyingBuffer, dataType, length, offset); + case INT: + return new IntBuffer(underlyingBuffer, dataType, length, offset); + case SHORT: + return new Int16Buffer(underlyingBuffer, dataType, length, offset); + case UBYTE: + return new UInt8Buffer(underlyingBuffer, dataType, length, offset); + case UINT16: + return new UInt16Buffer(underlyingBuffer, dataType, length, offset); + case UINT32: + return new UInt32Buffer(underlyingBuffer, dataType, length, offset); + case UINT64: + return new UInt64Buffer(underlyingBuffer, dataType, length, offset); + case BYTE: + return new Int8Buffer(underlyingBuffer, dataType, length, offset); + case BOOL: + return new BoolBuffer(underlyingBuffer, dataType, length, offset); + case UTF8: + return new Utf8Buffer(underlyingBuffer, dataType, length, offset); + default: + throw new IllegalStateException("Unknown datatype used: [" + dataType + "]"); + } + } + @Override public DataBuffer create(@NonNull DataType dataType, long length, boolean initialize) { switch (dataType) { @@ -310,11 +314,11 @@ public class DefaultDataBufferFactory implements DataBufferFactory { return new Int8Buffer(length, initialize); case BOOL: return new BoolBuffer(length, initialize); + case UTF8: + return new Utf8Buffer(length, true); default: throw new IllegalStateException("Unknown datatype used: [" + dataType + "]"); - } - } @Override @@ -540,16 +544,6 @@ public class DefaultDataBufferFactory implements DataBufferFactory { return createDouble(data, true); } - @Override - public DataBuffer createDouble(byte[] data, int length) { - return new DoubleBuffer(ByteBuffer.wrap(data), length); - } - - @Override - public DataBuffer createFloat(byte[] data, int length) { - return new FloatBuffer(ByteBuffer.wrap(data), length); - } - @Override public DataBuffer createFloat(double[] data) { return createFloat(data, true); @@ -958,18 +952,6 @@ public class DefaultDataBufferFactory implements DataBufferFactory { throw new UnsupportedOperationException("FP16 isn't supported for CPU yet"); } - /** - * Creates a half-precision data buffer - * - * @param offset - * @param data the data to create the buffer from - * @param length - * @return the new buffer - */ - @Override - public DataBuffer createHalf(long offset, byte[] data, int length) { - throw new UnsupportedOperationException("FP16 isn't supported for CPU yet"); - } /** * Creates a half-precision data buffer @@ -983,30 +965,6 @@ public class DefaultDataBufferFactory implements DataBufferFactory { throw new UnsupportedOperationException("FP16 isn't supported for CPU yet"); } - /** - * Creates a half-precision data buffer - * - * @param buffer - * @param length - * @return the new buffer - */ - @Override - public DataBuffer createHalf(ByteBuffer buffer, int length) { - throw new UnsupportedOperationException("FP16 isn't supported for CPU yet"); - } - - /** - * Creates a half-precision data buffer - * - * @param data - * @param length - * @return - */ - @Override - public DataBuffer createHalf(byte[] data, int length) { - throw new UnsupportedOperationException("FP16 isn't supported for CPU yet"); - } - @Override public DataBuffer createHalf(long length, boolean initialize, MemoryWorkspace workspace) { throw new UnsupportedOperationException("FP16 isn't supported for CPU yet"); @@ -1046,4 +1004,8 @@ public class DefaultDataBufferFactory implements DataBufferFactory { public Class doubleBufferClass() { return DoubleBuffer.class; } + + public DataBuffer createUtf8Buffer(byte[] data, long product) { + return new Utf8Buffer(data, product); + } } diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/DoubleBuffer.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/DoubleBuffer.java similarity index 91% rename from nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/DoubleBuffer.java rename to nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/DoubleBuffer.java index 8bd4bd6a1..25d1997d1 100644 --- a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/DoubleBuffer.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/DoubleBuffer.java @@ -14,11 +14,14 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -package org.nd4j.linalg.api.buffer; +package org.nd4j.linalg.cpu.nativecpu.buffer; import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacpp.indexer.Indexer; +import org.nd4j.linalg.api.buffer.BaseDataBuffer; +import org.nd4j.linalg.api.buffer.DataBuffer; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; import java.nio.ByteBuffer; @@ -28,7 +31,7 @@ import java.nio.ByteBuffer; * * @author Adam Gibson */ -public class DoubleBuffer extends BaseDataBuffer { +public class DoubleBuffer extends BaseCpuDataBuffer { /** * Meant for creating another view of a buffer * @@ -40,6 +43,10 @@ public class DoubleBuffer extends BaseDataBuffer { super(pointer, indexer, length); } + public DoubleBuffer(ByteBuffer buffer, DataType dataType, long length, long offset) { + super(buffer, dataType, length, offset); + } + public DoubleBuffer(long length) { super(length); } @@ -100,18 +107,6 @@ public class DoubleBuffer extends BaseDataBuffer { super(data, copy, offset); } - public DoubleBuffer(ByteBuffer buffer, int length) { - super(buffer, length); - } - - public DoubleBuffer(ByteBuffer buffer, int length, long offset) { - super(buffer, length, offset); - } - - public DoubleBuffer(byte[] data, int length) { - super(data, length); - } - public DoubleBuffer(double[] doubles, boolean copy) { super(doubles, copy); } diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/FloatBuffer.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/FloatBuffer.java similarity index 91% rename from nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/FloatBuffer.java rename to nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/FloatBuffer.java index 5b598c920..1a6d2846a 100644 --- a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/FloatBuffer.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/FloatBuffer.java @@ -14,11 +14,14 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -package org.nd4j.linalg.api.buffer; +package org.nd4j.linalg.cpu.nativecpu.buffer; import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacpp.indexer.Indexer; +import org.nd4j.linalg.api.buffer.BaseDataBuffer; +import org.nd4j.linalg.api.buffer.DataBuffer; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; import java.nio.ByteBuffer; @@ -28,7 +31,7 @@ import java.nio.ByteBuffer; * * @author Adam Gibson */ -public class FloatBuffer extends BaseDataBuffer { +public class FloatBuffer extends BaseCpuDataBuffer { /** * Meant for creating another view of a buffer @@ -50,6 +53,10 @@ public class FloatBuffer extends BaseDataBuffer { } + public FloatBuffer(ByteBuffer buffer, DataType dataType, long length, long offset) { + super(buffer, dataType, length, offset); + } + public FloatBuffer(long length, boolean initialize) { super(length, initialize); } @@ -111,18 +118,6 @@ public class FloatBuffer extends BaseDataBuffer { super(data, copy, offset); } - public FloatBuffer(ByteBuffer buffer, int length) { - super(buffer, length); - } - - public FloatBuffer(ByteBuffer buffer, int length, long offset) { - super(buffer, length, offset); - } - - public FloatBuffer(byte[] data, int length) { - super(data, length); - } - public FloatBuffer(float[] floats, boolean copy) { super(floats, copy); } diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/HalfBuffer.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/HalfBuffer.java similarity index 91% rename from nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/HalfBuffer.java rename to nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/HalfBuffer.java index d2cb2cfcc..1fdb338b2 100644 --- a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/HalfBuffer.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/HalfBuffer.java @@ -14,11 +14,14 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -package org.nd4j.linalg.api.buffer; +package org.nd4j.linalg.cpu.nativecpu.buffer; import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacpp.indexer.Indexer; +import org.nd4j.linalg.api.buffer.BaseDataBuffer; +import org.nd4j.linalg.api.buffer.DataBuffer; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; import java.nio.ByteBuffer; @@ -28,7 +31,7 @@ import java.nio.ByteBuffer; * * @author Adam Gibson */ -public class HalfBuffer extends BaseDataBuffer { +public class HalfBuffer extends BaseCpuDataBuffer { /** * Meant for creating another view of a buffer @@ -47,7 +50,10 @@ public class HalfBuffer extends BaseDataBuffer { */ public HalfBuffer(long length) { super(length); + } + public HalfBuffer(ByteBuffer buffer, DataType dataType, long length, long offset) { + super(buffer, dataType, length, offset); } public HalfBuffer(long length, boolean initialize) { @@ -111,18 +117,6 @@ public class HalfBuffer extends BaseDataBuffer { super(data, copy, offset); } - public HalfBuffer(ByteBuffer buffer, int length) { - super(buffer, length); - } - - public HalfBuffer(ByteBuffer buffer, int length, long offset) { - super(buffer, length, offset); - } - - public HalfBuffer(byte[] data, int length) { - super(data, length); - } - public HalfBuffer(float[] floats, boolean copy) { super(floats, copy); } diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/Int16Buffer.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/Int16Buffer.java similarity index 91% rename from nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/Int16Buffer.java rename to nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/Int16Buffer.java index f5cd2245f..7bf6eb969 100644 --- a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/Int16Buffer.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/Int16Buffer.java @@ -14,11 +14,14 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -package org.nd4j.linalg.api.buffer; +package org.nd4j.linalg.cpu.nativecpu.buffer; import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacpp.indexer.Indexer; +import org.nd4j.linalg.api.buffer.BaseDataBuffer; +import org.nd4j.linalg.api.buffer.DataBuffer; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; import java.nio.ByteBuffer; @@ -28,7 +31,7 @@ import java.nio.ByteBuffer; * * @author Adam Gibson */ -public class Int16Buffer extends BaseDataBuffer { +public class Int16Buffer extends BaseCpuDataBuffer { /** * Meant for creating another view of a buffer @@ -47,7 +50,10 @@ public class Int16Buffer extends BaseDataBuffer { */ public Int16Buffer(long length) { super(length); + } + public Int16Buffer(ByteBuffer buffer, DataType dataType, long length, long offset) { + super(buffer, dataType, length, offset); } public Int16Buffer(long length, boolean initialize) { @@ -111,18 +117,6 @@ public class Int16Buffer extends BaseDataBuffer { super(data, copy, offset); } - public Int16Buffer(ByteBuffer buffer, int length) { - super(buffer, length); - } - - public Int16Buffer(ByteBuffer buffer, int length, long offset) { - super(buffer, length, offset); - } - - public Int16Buffer(byte[] data, int length) { - super(data, length); - } - public Int16Buffer(float[] floats, boolean copy) { super(floats, copy); } diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/Int8Buffer.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/Int8Buffer.java similarity index 91% rename from nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/Int8Buffer.java rename to nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/Int8Buffer.java index aeec19961..7f14d9ae8 100644 --- a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/Int8Buffer.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/Int8Buffer.java @@ -14,11 +14,14 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -package org.nd4j.linalg.api.buffer; +package org.nd4j.linalg.cpu.nativecpu.buffer; import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacpp.indexer.Indexer; +import org.nd4j.linalg.api.buffer.BaseDataBuffer; +import org.nd4j.linalg.api.buffer.DataBuffer; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; import java.nio.ByteBuffer; @@ -28,7 +31,7 @@ import java.nio.ByteBuffer; * * @author Adam Gibson */ -public class Int8Buffer extends BaseDataBuffer { +public class Int8Buffer extends BaseCpuDataBuffer { /** * Meant for creating another view of a buffer @@ -47,7 +50,10 @@ public class Int8Buffer extends BaseDataBuffer { */ public Int8Buffer(long length) { super(length); + } + public Int8Buffer(ByteBuffer buffer, DataType dataType, long length, long offset) { + super(buffer, dataType, length, offset); } public Int8Buffer(long length, boolean initialize) { @@ -111,18 +117,6 @@ public class Int8Buffer extends BaseDataBuffer { super(data, copy, offset); } - public Int8Buffer(ByteBuffer buffer, int length) { - super(buffer, length); - } - - public Int8Buffer(ByteBuffer buffer, int length, long offset) { - super(buffer, length, offset); - } - - public Int8Buffer(byte[] data, int length) { - super(data, length); - } - public Int8Buffer(float[] floats, boolean copy) { super(floats, copy); } diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/IntBuffer.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/IntBuffer.java similarity index 90% rename from nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/IntBuffer.java rename to nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/IntBuffer.java index 20ec86bfd..de4282993 100644 --- a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/IntBuffer.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/IntBuffer.java @@ -14,11 +14,14 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -package org.nd4j.linalg.api.buffer; +package org.nd4j.linalg.cpu.nativecpu.buffer; import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacpp.indexer.Indexer; +import org.nd4j.linalg.api.buffer.BaseDataBuffer; +import org.nd4j.linalg.api.buffer.DataBuffer; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; import java.nio.ByteBuffer; @@ -28,7 +31,7 @@ import java.nio.ByteBuffer; * * @author Adam Gibson */ -public class IntBuffer extends BaseDataBuffer { +public class IntBuffer extends BaseCpuDataBuffer { /** * Meant for creating another view of a buffer @@ -53,18 +56,14 @@ public class IntBuffer extends BaseDataBuffer { super(length, initialize, workspace); } + public IntBuffer(ByteBuffer buffer, DataType dataType, long length, long offset) { + super(buffer, dataType, length, offset); + } + public IntBuffer(int[] ints, boolean copy, MemoryWorkspace workspace) { super(ints, copy, workspace); } - public IntBuffer(ByteBuffer buffer, int length, long offset) { - super(buffer, length, offset); - } - - public IntBuffer(byte[] data, int length) { - super(data, length); - } - public IntBuffer(double[] data, boolean copy) { super(data, copy); } @@ -97,10 +96,6 @@ public class IntBuffer extends BaseDataBuffer { super(underlyingBuffer, length, offset); } - public IntBuffer(ByteBuffer buffer, int length) { - super(buffer, length); - } - @Override protected DataBuffer create(long length) { return new IntBuffer(length); diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/LongBuffer.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/LongBuffer.java similarity index 83% rename from nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/LongBuffer.java rename to nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/LongBuffer.java index 42981e135..7ab2e8c61 100644 --- a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/LongBuffer.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/LongBuffer.java @@ -14,17 +14,22 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -package org.nd4j.linalg.api.buffer; +package org.nd4j.linalg.cpu.nativecpu.buffer; import lombok.NonNull; -import lombok.val; import org.bytedeco.javacpp.LongPointer; import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacpp.indexer.Indexer; import org.bytedeco.javacpp.indexer.LongIndexer; +import org.nd4j.linalg.api.buffer.BaseDataBuffer; +import org.nd4j.linalg.api.buffer.DataBuffer; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.memory.pointers.PagedPointer; +import org.nd4j.linalg.cpu.nativecpu.ops.NativeOpExecutioner; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.nativeblas.NativeOpsHolder; import java.nio.ByteBuffer; @@ -33,7 +38,7 @@ import java.nio.ByteBuffer; * * @author Adam Gibson */ -public class LongBuffer extends BaseDataBuffer { +public class LongBuffer extends BaseCpuDataBuffer { /** * Meant for creating another view of a buffer @@ -58,17 +63,14 @@ public class LongBuffer extends BaseDataBuffer { super(length, initialize, workspace); } + public LongBuffer(ByteBuffer buffer, DataType dataType, long length, long offset) { + super(buffer, dataType, length, offset); + } + public LongBuffer(int[] ints, boolean copy, MemoryWorkspace workspace) { super(ints, copy, workspace); } - public LongBuffer(ByteBuffer buffer, int length, long offset) { - super(buffer, length, offset); - } - - public LongBuffer(byte[] data, int length) { - super(data, length); - } public LongBuffer(double[] data, boolean copy) { super(data, copy); @@ -110,10 +112,6 @@ public class LongBuffer extends BaseDataBuffer { super(underlyingBuffer, length, offset); } - public LongBuffer(ByteBuffer buffer, int length) { - super(buffer, length); - } - public LongBuffer(@NonNull Pointer hostPointer, long numberOfElements) { this.allocationMode = AllocationMode.MIXED_DATA_TYPES; this.offset = 0; @@ -124,6 +122,13 @@ public class LongBuffer extends BaseDataBuffer { this.pointer = new PagedPointer(hostPointer, numberOfElements).asLongPointer(); indexer = LongIndexer.create((LongPointer) this.pointer); + + // we still want this buffer to have native representation + + ptrDataBuffer = NativeOpsHolder.getInstance().getDeviceNativeOps().allocateDataBuffer(0, DataType.INT64.toInt(), false); + NativeOpsHolder.getInstance().getDeviceNativeOps().dbSetPrimaryBuffer(ptrDataBuffer, this.pointer, numberOfElements); + + Nd4j.getDeallocatorService().pickObject(this); } @Override diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/UInt16Buffer.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/UInt16Buffer.java similarity index 91% rename from nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/UInt16Buffer.java rename to nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/UInt16Buffer.java index 9d0e8d02c..d4bc705ec 100644 --- a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/UInt16Buffer.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/UInt16Buffer.java @@ -14,11 +14,14 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -package org.nd4j.linalg.api.buffer; +package org.nd4j.linalg.cpu.nativecpu.buffer; import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacpp.indexer.Indexer; +import org.nd4j.linalg.api.buffer.BaseDataBuffer; +import org.nd4j.linalg.api.buffer.DataBuffer; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; import java.nio.ByteBuffer; @@ -28,7 +31,7 @@ import java.nio.ByteBuffer; * * @author Adam Gibson */ -public class UInt16Buffer extends BaseDataBuffer { +public class UInt16Buffer extends BaseCpuDataBuffer { /** * Meant for creating another view of a buffer @@ -47,7 +50,10 @@ public class UInt16Buffer extends BaseDataBuffer { */ public UInt16Buffer(long length) { super(length); + } + public UInt16Buffer(ByteBuffer buffer, DataType dataType, long length, long offset) { + super(buffer, dataType, length, offset); } public UInt16Buffer(long length, boolean initialize) { @@ -111,18 +117,6 @@ public class UInt16Buffer extends BaseDataBuffer { super(data, copy, offset); } - public UInt16Buffer(ByteBuffer buffer, int length) { - super(buffer, length); - } - - public UInt16Buffer(ByteBuffer buffer, int length, long offset) { - super(buffer, length, offset); - } - - public UInt16Buffer(byte[] data, int length) { - super(data, length); - } - public UInt16Buffer(float[] floats, boolean copy) { super(floats, copy); } diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/UInt32Buffer.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/UInt32Buffer.java similarity index 91% rename from nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/UInt32Buffer.java rename to nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/UInt32Buffer.java index 7df2621c7..b18fafc5c 100644 --- a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/UInt32Buffer.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/UInt32Buffer.java @@ -14,11 +14,14 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -package org.nd4j.linalg.api.buffer; +package org.nd4j.linalg.cpu.nativecpu.buffer; import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacpp.indexer.Indexer; +import org.nd4j.linalg.api.buffer.BaseDataBuffer; +import org.nd4j.linalg.api.buffer.DataBuffer; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; import java.nio.ByteBuffer; @@ -28,7 +31,7 @@ import java.nio.ByteBuffer; * * @author Adam Gibson */ -public class UInt32Buffer extends BaseDataBuffer { +public class UInt32Buffer extends BaseCpuDataBuffer { /** * Meant for creating another view of a buffer @@ -47,7 +50,10 @@ public class UInt32Buffer extends BaseDataBuffer { */ public UInt32Buffer(long length) { super(length); + } + public UInt32Buffer(ByteBuffer buffer, DataType dataType, long length, long offset) { + super(buffer, dataType, length, offset); } public UInt32Buffer(long length, boolean initialize) { @@ -111,18 +117,6 @@ public class UInt32Buffer extends BaseDataBuffer { super(data, copy, offset); } - public UInt32Buffer(ByteBuffer buffer, int length) { - super(buffer, length); - } - - public UInt32Buffer(ByteBuffer buffer, int length, long offset) { - super(buffer, length, offset); - } - - public UInt32Buffer(byte[] data, int length) { - super(data, length); - } - public UInt32Buffer(float[] floats, boolean copy) { super(floats, copy); } diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/UInt64Buffer.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/UInt64Buffer.java similarity index 91% rename from nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/UInt64Buffer.java rename to nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/UInt64Buffer.java index 15af50dd1..84adf29b6 100644 --- a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/UInt64Buffer.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/UInt64Buffer.java @@ -14,11 +14,14 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -package org.nd4j.linalg.api.buffer; +package org.nd4j.linalg.cpu.nativecpu.buffer; import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacpp.indexer.Indexer; +import org.nd4j.linalg.api.buffer.BaseDataBuffer; +import org.nd4j.linalg.api.buffer.DataBuffer; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; import java.nio.ByteBuffer; @@ -28,7 +31,7 @@ import java.nio.ByteBuffer; * * @author Adam Gibson */ -public class UInt64Buffer extends BaseDataBuffer { +public class UInt64Buffer extends BaseCpuDataBuffer { /** * Meant for creating another view of a buffer @@ -47,7 +50,10 @@ public class UInt64Buffer extends BaseDataBuffer { */ public UInt64Buffer(long length) { super(length); + } + public UInt64Buffer(ByteBuffer buffer, DataType dataType, long length, long offset) { + super(buffer, dataType, length, offset); } public UInt64Buffer(long length, boolean initialize) { @@ -111,18 +117,6 @@ public class UInt64Buffer extends BaseDataBuffer { super(data, copy, offset); } - public UInt64Buffer(ByteBuffer buffer, int length) { - super(buffer, length); - } - - public UInt64Buffer(ByteBuffer buffer, int length, long offset) { - super(buffer, length, offset); - } - - public UInt64Buffer(byte[] data, int length) { - super(data, length); - } - public UInt64Buffer(float[] floats, boolean copy) { super(floats, copy); } diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/UInt8Buffer.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/UInt8Buffer.java similarity index 91% rename from nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/UInt8Buffer.java rename to nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/UInt8Buffer.java index 56f311e9e..d207d370a 100644 --- a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/UInt8Buffer.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/UInt8Buffer.java @@ -14,11 +14,14 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -package org.nd4j.linalg.api.buffer; +package org.nd4j.linalg.cpu.nativecpu.buffer; import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacpp.indexer.Indexer; +import org.nd4j.linalg.api.buffer.BaseDataBuffer; +import org.nd4j.linalg.api.buffer.DataBuffer; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; import java.nio.ByteBuffer; @@ -28,7 +31,7 @@ import java.nio.ByteBuffer; * * @author Adam Gibson */ -public class UInt8Buffer extends BaseDataBuffer { +public class UInt8Buffer extends BaseCpuDataBuffer { /** * Meant for creating another view of a buffer @@ -47,7 +50,10 @@ public class UInt8Buffer extends BaseDataBuffer { */ public UInt8Buffer(long length) { super(length); + } + public UInt8Buffer(ByteBuffer buffer, DataType dataType, long length, long offset) { + super(buffer, dataType, length, offset); } public UInt8Buffer(long length, boolean initialize) { @@ -111,18 +117,6 @@ public class UInt8Buffer extends BaseDataBuffer { super(data, copy, offset); } - public UInt8Buffer(ByteBuffer buffer, int length) { - super(buffer, length); - } - - public UInt8Buffer(ByteBuffer buffer, int length, long offset) { - super(buffer, length, offset); - } - - public UInt8Buffer(byte[] data, int length) { - super(data, length); - } - public UInt8Buffer(float[] floats, boolean copy) { super(floats, copy); } diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/Utf8Buffer.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/Utf8Buffer.java similarity index 89% rename from nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/Utf8Buffer.java rename to nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/Utf8Buffer.java index e2cdc9c2f..3f33cc044 100644 --- a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/Utf8Buffer.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/Utf8Buffer.java @@ -14,7 +14,7 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -package org.nd4j.linalg.api.buffer; +package org.nd4j.linalg.cpu.nativecpu.buffer; import lombok.Getter; @@ -23,11 +23,11 @@ import lombok.val; import org.bytedeco.javacpp.BytePointer; import org.bytedeco.javacpp.LongPointer; import org.bytedeco.javacpp.Pointer; -import org.bytedeco.javacpp.indexer.ByteIndexer; import org.bytedeco.javacpp.indexer.Indexer; -import org.bytedeco.javacpp.indexer.LongIndexer; +import org.nd4j.linalg.api.buffer.BaseDataBuffer; +import org.nd4j.linalg.api.buffer.DataBuffer; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; -import org.nd4j.linalg.api.memory.pointers.PagedPointer; import java.io.UnsupportedEncodingException; import java.nio.ByteBuffer; @@ -39,7 +39,7 @@ import java.util.Collection; * * @author Adam Gibson */ -public class Utf8Buffer extends BaseDataBuffer { +public class Utf8Buffer extends BaseCpuDataBuffer { protected Collection references = new ArrayList<>(); @@ -62,21 +62,30 @@ public class Utf8Buffer extends BaseDataBuffer { } public Utf8Buffer(long length, boolean initialize) { - super(length, initialize); + /** + * Special case: we're creating empty buffer for length strings, each of 0 chars + */ + super((length + 1) * 8, true); + numWords = length; } public Utf8Buffer(long length, boolean initialize, MemoryWorkspace workspace) { - super(length, initialize, workspace); + /** + * Special case: we're creating empty buffer for length strings, each of 0 chars + */ + + super((length + 1) * 8, true, workspace); + numWords = length; + } + + public Utf8Buffer(ByteBuffer buffer, DataType dataType, long length, long offset) { + super(buffer, dataType, length, offset); } public Utf8Buffer(int[] ints, boolean copy, MemoryWorkspace workspace) { super(ints, copy, workspace); } - public Utf8Buffer(ByteBuffer buffer, int length, long offset) { - super(buffer, length, offset); - } - public Utf8Buffer(byte[] data, long numWords) { super(data.length, false); @@ -155,10 +164,6 @@ public class Utf8Buffer extends BaseDataBuffer { headerPointer.put(cnt, currentLength); } - public Utf8Buffer(ByteBuffer buffer, int length) { - super(buffer, length); - } - public String getString(long index) { if (index > numWords) throw new IllegalArgumentException("Requested index [" + index + "] is above actual number of words stored: [" + numWords + "]"); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/CpuOpContext.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/CpuOpContext.java index 6700f9019..fce391a05 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/CpuOpContext.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/CpuOpContext.java @@ -24,6 +24,7 @@ import org.bytedeco.javacpp.Pointer; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseOpContext; import org.nd4j.linalg.api.ops.OpContext; +import org.nd4j.linalg.cpu.nativecpu.buffer.BaseCpuDataBuffer; import org.nd4j.linalg.primitives.Pair; import org.nd4j.nativeblas.NativeOps; import org.nd4j.nativeblas.NativeOpsHolder; @@ -84,14 +85,16 @@ public class CpuOpContext extends BaseOpContext implements OpContext { @Override public void setInputArray(int index, @NonNull INDArray array) { - nativeOps.setGraphContextInputArray(context, index, array.isEmpty() ? null : array.data().addressPointer(), array.shapeInfoDataBuffer().addressPointer(), null, null); + //nativeOps.setGraphContextInputArray(context, index, array.isEmpty() ? null : array.data().addressPointer(), array.shapeInfoDataBuffer().addressPointer(), null, null); + nativeOps.setGraphContextInputBuffer(context, index, array.isEmpty() ? null : ((BaseCpuDataBuffer) array.data()).getOpaqueDataBuffer(), array.shapeInfoDataBuffer().addressPointer(), null); super.setInputArray(index, array); } @Override public void setOutputArray(int index, @NonNull INDArray array) { - nativeOps.setGraphContextOutputArray(context, index, array.isEmpty() ? null : array.data().addressPointer(), array.shapeInfoDataBuffer().addressPointer(), null, null); + //nativeOps.setGraphContextOutputArray(context, index, array.isEmpty() ? null : array.data().addressPointer(), array.shapeInfoDataBuffer().addressPointer(), null, null); + nativeOps.setGraphContextOutputBuffer(context, index, array.isEmpty() ? null : ((BaseCpuDataBuffer) array.data()).getOpaqueDataBuffer(), array.shapeInfoDataBuffer().addressPointer(), null); super.setOutputArray(index, array); } @@ -110,4 +113,9 @@ public class CpuOpContext extends BaseOpContext implements OpContext { public void allowHelpers(boolean reallyAllow) { nativeOps.ctxAllowHelpers(context, reallyAllow); } + + @Override + public void shapeFunctionOverride(boolean reallyOverride) { + nativeOps.ctxShapeFunctionOverride(context, reallyOverride); + } } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java index d12efba59..dfd81c80b 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java @@ -25,7 +25,6 @@ import lombok.val; import org.bytedeco.javacpp.*; import org.bytedeco.javacpp.indexer.LongIndexer; import org.nd4j.autodiff.functions.DifferentialFunction; -import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.serde.FlatBuffersMapper; import org.nd4j.base.Preconditions; import org.nd4j.compression.impl.AbstractCompressor; @@ -57,6 +56,9 @@ import org.nd4j.linalg.compression.CompressionDescriptor; import org.nd4j.linalg.compression.CompressionType; import org.nd4j.linalg.compression.ThresholdCompression; import org.nd4j.linalg.cpu.nativecpu.CpuTADManager; +import org.nd4j.linalg.cpu.nativecpu.buffer.BaseCpuDataBuffer; +import org.nd4j.linalg.cpu.nativecpu.buffer.LongBuffer; +import org.nd4j.linalg.cpu.nativecpu.buffer.Utf8Buffer; import org.nd4j.linalg.cpu.nativecpu.rng.CpuNativeRandom; import org.nd4j.linalg.exception.ND4JIllegalArgumentException; import org.nd4j.linalg.exception.ND4JIllegalStateException; @@ -67,15 +69,7 @@ import org.nd4j.linalg.primitives.AtomicBoolean; import org.nd4j.linalg.primitives.Optional; import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.util.ArrayUtil; -import org.nd4j.nativeblas.LongPointerWrapper; -import org.nd4j.nativeblas.NativeOps; -import org.nd4j.nativeblas.NativeOpsHolder; -import org.nd4j.nativeblas.Nd4jCpu; -import org.nd4j.nativeblas.OpaqueConstantDataBuffer; -import org.nd4j.nativeblas.OpaqueShapeList; -import org.nd4j.nativeblas.OpaqueTadPack; -import org.nd4j.nativeblas.OpaqueVariable; -import org.nd4j.nativeblas.OpaqueVariablesSet; +import org.nd4j.nativeblas.*; import java.util.*; @@ -209,29 +203,20 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { long st = profilingConfigurableHookIn(op, tadBuffers.getFirst()); - Pointer x = op.x().data().addressPointer(); - Pointer z = op.z().data().addressPointer(); + val x = ((BaseCpuDataBuffer) op.x().data()).getOpaqueDataBuffer(); + val z = ((BaseCpuDataBuffer) op.z().data()).getOpaqueDataBuffer(); if (op.z().isScalar()) { loop.execIndexReduceScalar(dummy, op.opNum(), - op.x().data().addressPointer(), - (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), - null, - null, + x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, getPointerForExtraArgs(op, op.x().dataType()), - op.z().data().addressPointer(), (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), - null, null); + z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null); } else { loop.execIndexReduce(dummy, op.opNum(), - x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), - null, null, + x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, getPointerForExtraArgs(op, op.x().dataType()), - z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), - null, null, - op.dimensions().data().addressPointer(), - (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), - null, - null); + z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, + ((BaseCpuDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null); } if (loop.lastErrorCode() != 0) @@ -398,30 +383,26 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { * This gives us a pointer which is passed around in libnd4j. */ Pointer dimensionAddress = constantHandler.getConstantBuffer(dimension, DataType.INT).addressPointer(); + val x = ((BaseCpuDataBuffer) op.x().data()).getOpaqueDataBuffer(); + val z = ((BaseCpuDataBuffer) op.z().data()).getOpaqueDataBuffer(); if (op instanceof Variance) { if (ret.isScalar()) { loop.execSummaryStatsScalar(null, op.opNum(), - op.x().data().addressPointer(), (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), - null, null, + x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, getPointerForExtraArgs(op, op.z().dataType()), - op.z().data().addressPointer(), (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), - null, null, + z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, ((Variance) op).isBiasCorrected()); } else { Variance var = (Variance) op; try { loop.execSummaryStatsTad(null, op.opNum(), - op.x().data().addressPointer(), (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), - null, null, - getPointerForExtraArgs(op, op.z().dataType()), - op.z().data().addressPointer(), (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), - null, null, - op.dimensions().data().addressPointer(), - (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), - null, - null, - var.isBiasCorrected(), null, null);} catch (Throwable t){ + x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, + getPointerForExtraArgs(op, op.z().dataType()), + z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, + ((BaseCpuDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null, + var.isBiasCorrected(), null, null); + } catch (Throwable t){ String str = opInfoString(op, Optional.of(dimension)); throw new RuntimeException("Native AccumulationOp execution (double) failed: " + str, t); } @@ -430,24 +411,17 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { } //pairwise reduction like similarity of two arrays else if (op.y() != null && op.getOpType() == Op.Type.REDUCE3) { + val y = ((BaseCpuDataBuffer) op.y().data()).getOpaqueDataBuffer(); if (op.isComplexAccumulation()) { try { loop.execReduce3All(null, op.opNum(), - op.x().data().addressPointer(), (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), - null, null, + x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, getPointerForExtraArgs(op, op.z().dataType()), - op.y().data().addressPointer(), (LongPointer) op.y().shapeInfoDataBuffer().addressPointer(), - null, null, - op.z().data().addressPointer(), (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), - null, null, - op.dimensions().data().addressPointer(), - (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), - null, - null, - (LongPointer) tadBuffers.getFirst().addressPointer(), - new LongPointerWrapper(tadBuffers.getSecond().addressPointer()), - (LongPointer) yTadBuffers.getFirst().addressPointer(), - new LongPointerWrapper(yTadBuffers.getSecond().addressPointer()) + y, (LongPointer) op.y().shapeInfoDataBuffer().addressPointer(), null, + z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, + ((BaseCpuDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null, + (LongPointer) tadBuffers.getFirst().addressPointer(), new LongPointerWrapper(tadBuffers.getSecond().addressPointer()), + (LongPointer) yTadBuffers.getFirst().addressPointer(), new LongPointerWrapper(yTadBuffers.getSecond().addressPointer()) ); } catch (Throwable t){ String str = opInfoString(op, Optional.of(dimension)); @@ -455,27 +429,18 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { } } else if (ret.isScalar()) { loop.execReduce3Scalar(null, op.opNum(), - op.x().data().addressPointer(), (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), - null, null, + x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, getPointerForExtraArgs(op, op.z().dataType()), - op.y().data().addressPointer(), (LongPointer) op.y().shapeInfoDataBuffer().addressPointer(), - null, null, - ret.data().addressPointer(), (LongPointer) ret.shapeInfoDataBuffer().addressPointer(), - null, null); + y, (LongPointer) op.y().shapeInfoDataBuffer().addressPointer(), null, + z, (LongPointer) ret.shapeInfoDataBuffer().addressPointer(), null); } else { try { loop.execReduce3Tad(null, op.opNum(), - op.x().data().addressPointer(), (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), - null, null, + x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, getPointerForExtraArgs(op, op.z().dataType()), - op.y().data().addressPointer(), (LongPointer) op.y().shapeInfoDataBuffer().addressPointer(), - null, null, - op.z().data().addressPointer(), (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), - null, null, - op.dimensions().data().addressPointer(), - (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), - null, - null, + y, (LongPointer) op.y().shapeInfoDataBuffer().addressPointer(), null, + z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, + ((BaseCpuDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null, null, null, null, null); } catch (Throwable t){ String str = opInfoString(op, Optional.of(dimension)); @@ -488,35 +453,27 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { switch (op.getOpType()) { case REDUCE_FLOAT: loop.execReduceFloat(null, op.opNum(), - op.x().data().addressPointer(), (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), - null, null, + x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, getPointerForExtraArgs(op, op.z().dataType()), - ret.data().addressPointer(), (LongPointer) ret.shapeInfoDataBuffer().addressPointer(), - null, null); + z, (LongPointer) ret.shapeInfoDataBuffer().addressPointer(), null); break; case REDUCE_BOOL: loop.execReduceBool(null, op.opNum(), - op.x().data().addressPointer(), (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), - null, null, + x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, getPointerForExtraArgs(op, op.x().dataType()), - ret.data().addressPointer(), (LongPointer) ret.shapeInfoDataBuffer().addressPointer(), - null, null); + z, (LongPointer) ret.shapeInfoDataBuffer().addressPointer(), null); break; case REDUCE_SAME: loop.execReduceSame(null, op.opNum(), - op.x().data().addressPointer(), (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), - null, null, + x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, getPointerForExtraArgs(op, op.x().dataType()), - ret.data().addressPointer(), (LongPointer) ret.shapeInfoDataBuffer().addressPointer(), - null, null); + z, (LongPointer) ret.shapeInfoDataBuffer().addressPointer(), null); break; case REDUCE_LONG: loop.execReduceLong(null, op.opNum(), - op.x().data().addressPointer(), (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), - null, null, + x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, getPointerForExtraArgs(op, op.x().dataType()), - ret.data().addressPointer(), (LongPointer) ret.shapeInfoDataBuffer().addressPointer(), - null, null); + z, (LongPointer) ret.shapeInfoDataBuffer().addressPointer(), null); break; default: throw new UnsupportedOperationException("Unsupported op used in reduce: "+ op.getOpType()); @@ -525,51 +482,34 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { switch (op.getOpType()) { case REDUCE_FLOAT: loop.execReduceFloat2(null, op.opNum(), - op.x().data().addressPointer(), (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), - null, null, + x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, getPointerForExtraArgs(op, op.z().dataType()), - op.z().data().addressPointer(), (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), - null, null, - op.dimensions().data().addressPointer(), - (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), - null, - null); + z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, + ((BaseCpuDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null); break; case REDUCE_LONG: loop.execReduceLong2(null, op.opNum(), - op.x().data().addressPointer(), (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), - null, null, + x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, getPointerForExtraArgs(op, op.x().dataType()), - op.z().data().addressPointer(), (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), - null, null, - op.dimensions().data().addressPointer(), - (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), - null, - null); + z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, + ((BaseCpuDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), + (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null); break; case REDUCE_SAME: loop.execReduceSame2(null, op.opNum(), - op.x().data().addressPointer(), (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), - null, null, + x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, getPointerForExtraArgs(op, op.z().dataType()), - op.z().data().addressPointer(), (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), - null, null, - op.dimensions().data().addressPointer(), - (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), - null, - null); + z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, + ((BaseCpuDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), + (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null); break; case REDUCE_BOOL: loop.execReduceBool2(null, op.opNum(), - op.x().data().addressPointer(), (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), - null, null, + x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, getPointerForExtraArgs(op, op.x().dataType()), - op.z().data().addressPointer(), (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), - null, null, - op.dimensions().data().addressPointer(), - (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), - null, - null); + z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, + ((BaseCpuDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), + (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null); break; default: throw new UnsupportedOperationException("Unsupported op used in reduce: "+ op.getOpType()); @@ -621,39 +561,28 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { if (extraz.get() == null) extraz.set(new PointerPointer(32)); - //PointerPointer dummy = extraz.get().put(hostTadShapeInfo, hostTadOffsets, devTadShapeInfoZ, devTadOffsetsZ); - + val x = ((BaseCpuDataBuffer) op.x().data()).getOpaqueDataBuffer(); + val y = ((BaseCpuDataBuffer) op.y().data()).getOpaqueDataBuffer(); + val z = ((BaseCpuDataBuffer) op.z().data()).getOpaqueDataBuffer(); switch (op.getOpType()) { case SCALAR: loop.execScalarTad(null, op.opNum(), - op.x().data().addressPointer(), (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), - null, null, - op.z().data().addressPointer(), (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), - null, null, - op.y().data().addressPointer(), (LongPointer) op.y().shapeInfoDataBuffer().addressPointer(), - null, null, + x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, + z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, + y, (LongPointer) op.y().shapeInfoDataBuffer().addressPointer(), null, getPointerForExtraArgs(op, op.z().dataType()), - op.dimensions().data().addressPointer(), - (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), - null, - null, + ((BaseCpuDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(),null, (LongPointer) hostTadShapeInfo, (LongPointer) hostTadOffsets, (LongPointer) devTadShapeInfoZ, (LongPointer) devTadOffsetsZ); break; case SCALAR_BOOL: loop.execScalarBoolTad(null, op.opNum(), - op.x().data().addressPointer(), (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), - null, null, - op.z().data().addressPointer(), (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), - null, null, - op.y().data().addressPointer(), (LongPointer) op.y().shapeInfoDataBuffer().addressPointer(), - null, null, + x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, + z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, + y, (LongPointer) op.y().shapeInfoDataBuffer().addressPointer(), null, getPointerForExtraArgs(op, op.z().dataType()), - op.dimensions().data().addressPointer(), - (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), - null, - null, + ((BaseCpuDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null, (LongPointer) hostTadShapeInfo, (LongPointer) hostTadOffsets, (LongPointer) devTadShapeInfoZ, (LongPointer) devTadOffsetsZ); break; @@ -670,6 +599,19 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { //validateDataType(Nd4j.dataType(), op); + if(op.z() == null){ + switch (op.getOpType()) { + case SCALAR: + op.setZ(op.x().ulike()); + break; + case SCALAR_BOOL: + op.setZ(Nd4j.createUninitialized(DataType.BOOL, op.x().shape())); + break; + default: + throw new ND4JIllegalStateException("Unknown op type: [" + op.getOpType() +"]"); + } + } + if (op.x().length() != op.z().length()) throw new ND4JIllegalStateException("op.X length should be equal to op.Z length: " + "x.length()=" + op.x().length() + ", z.length()=" + op.z().length() + " - x shape info = [" @@ -681,28 +623,26 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { return op.z(); } + val x = ((BaseCpuDataBuffer) op.x().data()).getOpaqueDataBuffer(); + val scalar = ((BaseCpuDataBuffer) op.scalar().data()).getOpaqueDataBuffer(); + val z = ((BaseCpuDataBuffer) op.z().data()).getOpaqueDataBuffer(); + switch (op.getOpType()) { case SCALAR: loop.execScalar(null, op.opNum(), - op.x().data().addressPointer(), (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), - null, null, - op.z().data().addressPointer(), (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), - null, null, - op.scalar().data().addressPointer(), (LongPointer) op.scalar().shapeInfoDataBuffer().addressPointer(), - null, null, + x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, + z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, + scalar, (LongPointer) op.scalar().shapeInfoDataBuffer().addressPointer(), null, getPointerForExtraArgs(op, op.z().dataType())); break; case SCALAR_BOOL: loop.execScalarBool(null, op.opNum(), - op.x().data().addressPointer(), (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), - null, null, - op.z().data().addressPointer(), (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), - null, null, - op.scalar().data().addressPointer(), (LongPointer) op.scalar().shapeInfoDataBuffer().addressPointer(), - null, null, + x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, + z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, + scalar, (LongPointer) op.scalar().shapeInfoDataBuffer().addressPointer(), null, getPointerForExtraArgs(op, op.x().dataType())); break; default: @@ -807,6 +747,10 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { "; y: " + op.y().length() + ", shape " + Arrays.toString(op.y().shape()) + "; z: " + op.z().length() + ", shape " + Arrays.toString(op.z().shape())); + val x = ((BaseCpuDataBuffer) op.x().data()).getOpaqueDataBuffer(); + val y = ((BaseCpuDataBuffer) op.y().data()).getOpaqueDataBuffer(); + val z = ((BaseCpuDataBuffer) op.z().data()).getOpaqueDataBuffer(); + switch (op.getOpType()) { case TRANSFORM_ANY: case TRANSFORM_FLOAT: @@ -816,54 +760,46 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { Preconditions.checkArgument(op.x().dataType() == op.y().dataType() || op.y().dataType() == DataType.BOOL, "Op.X and Op.Y must have the same data type, but got " + op.x().dataType() + " vs " + op.y().dataType()); loop.execPairwiseTransform(dummy, op.opNum(), - op.x().data().addressPointer(), (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), - null, null, - op.y().data().addressPointer(), (LongPointer) op.y().shapeInfoDataBuffer().addressPointer(), - null, null, - op.z().data().addressPointer(), (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), - null, null, + x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, + y, (LongPointer) op.y().shapeInfoDataBuffer().addressPointer(), null, + z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, getPointerForExtraArgs(op, op.z().dataType())); break; case TRANSFORM_BOOL: case PAIRWISE_BOOL: loop.execPairwiseTransformBool(dummy, op.opNum(), - op.x().data().addressPointer(), (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), - null, null, - op.y().data().addressPointer(), (LongPointer) op.y().shapeInfoDataBuffer().addressPointer(), - null, null, - op.z().data().addressPointer(), (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), - null, null, + x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, + y, (LongPointer) op.y().shapeInfoDataBuffer().addressPointer(), null, + z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, getPointerForExtraArgs(op, op.x().dataType())); break; } } else { if (op.z() == null) - op.setZ(Nd4j.create(op.resultType(), op.x().shape())); + op.setZ(Nd4j.createUninitialized(op.resultType(), op.x().shape())); op.validateDataTypes(experimentalMode.get()); + val x = ((BaseCpuDataBuffer) op.x().data()).getOpaqueDataBuffer(); + val z = ((BaseCpuDataBuffer) op.z().data()).getOpaqueDataBuffer(); switch (op.getOpType()) { case TRANSFORM_FLOAT: { val xtraz = getPointerForExtraArgs(op, op.z().dataType()); loop.execTransformFloat(dummy, op.opNum(), - op.x().data().addressPointer(), (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), - null, null, - op.z().data().addressPointer(), (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), - null, null, - xtraz); + x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, + z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), + null, xtraz); break; } case TRANSFORM_STRICT: { val xtraz = getPointerForExtraArgs(op, op.z().dataType()); loop.execTransformStrict(dummy, op.opNum(), - op.x().data().addressPointer(), (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), - null, null, - op.z().data().addressPointer(), (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), - null, null, + x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, + z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, xtraz); break; } @@ -871,10 +807,8 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { val xtraz = getPointerForExtraArgs(op, op.z().dataType()); loop.execTransformSame(dummy, op.opNum(), - op.x().data().addressPointer(), (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), - null, null, - op.z().data().addressPointer(), (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), - null, null, + x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, + z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, xtraz); break; } @@ -883,10 +817,8 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { val opNum = op.opNum(); loop.execTransformAny(dummy, opNum, - op.x().data().addressPointer(), (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), - null, null, - op.z().data().addressPointer(), (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), - null, null, + x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, + z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, xtraz); break; } @@ -895,10 +827,8 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { val opNum = op.opNum(); loop.execTransformBool(dummy, opNum, - op.x().data().addressPointer(), (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), - null, null, - op.z().data().addressPointer(), (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), - null, null, + x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, + z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, xtraz); break; } @@ -955,34 +885,25 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { Pointer dimensionAddress = constantHandler.getConstantBuffer(dimension, DataType.INT).addressPointer(); + val x = ((BaseCpuDataBuffer) op.x().data()).getOpaqueDataBuffer(); + val y = ((BaseCpuDataBuffer) op.y().data()).getOpaqueDataBuffer(); + val z = ((BaseCpuDataBuffer) op.z().data()).getOpaqueDataBuffer(); switch (op.getOpType()) { case BROADCAST: loop.execBroadcast(dummy, op.opNum(), - op.x().data().addressPointer(), (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), - null, null, - op.y().data().addressPointer(), (LongPointer) op.y().shapeInfoDataBuffer().addressPointer(), - null, null, - op.z().data().addressPointer(), (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), - null, null, - op.dimensions().data().addressPointer(), - (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), - null, - null); + x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, + y, (LongPointer) op.y().shapeInfoDataBuffer().addressPointer(), null, + z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, + ((BaseCpuDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null); break; case BROADCAST_BOOL: loop.execBroadcastBool(dummy, op.opNum(), - op.x().data().addressPointer(), (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), - null, null, - op.y().data().addressPointer(), (LongPointer) op.y().shapeInfoDataBuffer().addressPointer(), - null, null, - op.z().data().addressPointer(), (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), - null, null, + x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, + y, (LongPointer) op.y().shapeInfoDataBuffer().addressPointer(), null, + z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, null, - op.dimensions().data().addressPointer(), - (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), - null, - null); + ((BaseCpuDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null); break; default: throw new UnsupportedOperationException("Unknown operation type: [" + op.getOpType() + "]"); @@ -1291,29 +1212,27 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { Preconditions.checkArgument(op.z().isR(), "Op.Z must have one of floating point types"); + val x = op.x() == null ? null : ((BaseCpuDataBuffer) op.x().data()).getOpaqueDataBuffer(); + val y = op.y() == null ? null : ((BaseCpuDataBuffer) op.y().data()).getOpaqueDataBuffer(); + val z = op.z() == null ? null : ((BaseCpuDataBuffer) op.z().data()).getOpaqueDataBuffer(); + if (op.x() != null && op.y() != null && op.z() != null) { // triple arg call loop.execRandom3(null, op.opNum(), rng.getStatePointer(), // rng state ptr - op.x().data().addressPointer(), (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), - null, null, - op.y().data().addressPointer(), (LongPointer) op.y().shapeInfoDataBuffer().addressPointer(), - null, null, - op.z().data().addressPointer(), (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), - null, null, + x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, + y, (LongPointer) op.y().shapeInfoDataBuffer().addressPointer(), null, + z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, op.extraArgsDataBuff(op.z().dataType()).addressPointer()); } else if (op.x() != null && op.z() != null) { //double arg call loop.execRandom2(null, op.opNum(), rng.getStatePointer(), // rng state ptr - op.x().data().addressPointer(), (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), - null, null, - op.z().data().addressPointer(), (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), - null, null, + x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, + z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, op.extraArgsDataBuff(op.z().dataType()).addressPointer()); } else { // single arg call loop.execRandom(null, op.opNum(), rng.getStatePointer(), // rng state ptr - op.z().data().addressPointer(), (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), - null, null, + z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, op.extraArgsDataBuff(op.z().dataType()).addressPointer()); } @@ -1678,6 +1597,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { @Override public INDArray[] exec(@NonNull CustomOp op) { + boolean shapeOverride = false; if (op.numOutputArguments() == 0 && !op.isInplaceCall()) { try { val list = this.calculateOutputShape(op); @@ -1686,16 +1606,23 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { for (LongShapeDescriptor shape : list) op.addOutputArgument(Nd4j.create(shape, false)); + + shapeOverride = true; } catch (ND4JIllegalStateException e){ throw e; } catch (Exception e) { - throw new ND4JIllegalStateException("Op name " + op.opName() + " failed to execute. You can't execute non-inplace CustomOp without outputs being specified"); + throw new ND4JIllegalStateException("Op name " + op.opName() + " - no output arrays were provided and calculateOutputShape failed to execute", e); + //throw new RuntimeException(e); } } val name = op.opName(); try (val context = buildContext()) { + // optionally skip shape validation on op execution + if (shapeOverride) + context.shapeFunctionOverride(true); + context.markInplace(op.isInplaceCall()); // transferring rng state @@ -1713,6 +1640,17 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { val result = exec(op, context); val states = context.getRngStates(); + // check if input & output needs update + for (val in:op.inputArguments()) { + if (!in.isEmpty()) + ((BaseCpuDataBuffer) in.data()).actualizePointerAndIndexer(); + } + + for (val out:op.outputArguments()) { + if (!out.isEmpty()) + ((BaseCpuDataBuffer) out.data()).actualizePointerAndIndexer(); + } + // pulling states back Nd4j.getRandom().setStates(states.getFirst(), states.getSecond()); @@ -1795,10 +1733,10 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { } catch (Throwable t){ StringBuilder sb = new StringBuilder(); sb.append("Inputs: [("); - for( int i=0; i 0) sb.append("), ("); - sb.append(Shape.shapeToStringShort(inputArgs[i])); + sb.append(Shape.shapeToStringShort(inputArgs.get(i))); } sb.append(")]"); if(op instanceof DifferentialFunction && ((DifferentialFunction)op).getSameDiff() != null){ @@ -1959,7 +1897,9 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { } @Override - public String getString(Utf8Buffer buffer, long index) { + public String getString(DataBuffer buffer, long index) { + Preconditions.checkArgument(buffer instanceof Utf8Buffer, "Expected Utf8Buffer"); + val addr = ((LongIndexer) buffer.indexer()).get(index); val ptr = new PagedPointer(addr); val str = new Nd4jCpu.utf8string(ptr); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java index 0ba5d1293..cfabc651c 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java @@ -110,6 +110,74 @@ public class Nd4jCpu extends org.nd4j.nativeblas.Nd4jCpuHelper { } } +@Name("std::vector") public static class ConstNDArrayVector extends Pointer { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public ConstNDArrayVector(Pointer p) { super(p); } + public ConstNDArrayVector(NDArray value) { this(1); put(0, value); } + public ConstNDArrayVector(NDArray ... array) { this(array.length); put(array); } + public ConstNDArrayVector() { allocate(); } + public ConstNDArrayVector(long n) { allocate(n); } + private native void allocate(); + private native void allocate(@Cast("size_t") long n); + public native @Name("operator=") @ByRef ConstNDArrayVector put(@ByRef ConstNDArrayVector x); + + public boolean empty() { return size() == 0; } + public native long size(); + public void clear() { resize(0); } + public native void resize(@Cast("size_t") long n); + + @Index(function = "at") public native @Const NDArray get(@Cast("size_t") long i); + public native ConstNDArrayVector put(@Cast("size_t") long i, NDArray value); + + public native @ByVal Iterator insert(@ByVal Iterator pos, @Const NDArray value); + public native @ByVal Iterator erase(@ByVal Iterator pos); + public native @ByVal Iterator begin(); + public native @ByVal Iterator end(); + @NoOffset @Name("iterator") public static class Iterator extends Pointer { + public Iterator(Pointer p) { super(p); } + public Iterator() { } + + public native @Name("operator++") @ByRef Iterator increment(); + public native @Name("operator==") boolean equals(@ByRef Iterator it); + public native @Name("operator*") @Const NDArray get(); + } + + public NDArray[] get() { + NDArray[] array = new NDArray[size() < Integer.MAX_VALUE ? (int)size() : Integer.MAX_VALUE]; + for (int i = 0; i < array.length; i++) { + array[i] = get(i); + } + return array; + } + @Override public String toString() { + return java.util.Arrays.toString(get()); + } + + public NDArray pop_back() { + long size = size(); + NDArray value = get(size - 1); + resize(size - 1); + return value; + } + public ConstNDArrayVector push_back(NDArray value) { + long size = size(); + resize(size + 1); + return put(size, value); + } + public ConstNDArrayVector put(NDArray value) { + if (size() != 1) { resize(1); } + return put(0, value); + } + public ConstNDArrayVector put(NDArray ... array) { + if (size() != array.length) { resize(array.length); } + for (int i = 0; i < array.length; i++) { + put(i, array[i]); + } + return this; + } +} + @Name("std::vector") public static class NDArrayVector extends Pointer { static { Loader.load(); } /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ @@ -261,12 +329,167 @@ public class Nd4jCpu extends org.nd4j.nativeblas.Nd4jCpuHelper { QINT16 = 16, BFLOAT16 = 17, UTF8 = 50, + UTF16 = 51, + UTF32 = 52, ANY = 100, AUTO = 200; // #endif +// Parsed from array/DataBuffer.h + +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// @author Yurii Shyrma (iuriish@yahoo.com) +// + +// #ifndef DEV_TESTS_DATABUFFER_H +// #define DEV_TESTS_DATABUFFER_H + +// #include +// #include +// #include +// #include +// #include +// #include +// #include + +@Namespace("nd4j") @NoOffset public static class DataBuffer extends Pointer { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public DataBuffer(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public DataBuffer(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public DataBuffer position(long position) { + return (DataBuffer)super.position(position); + } + + + public DataBuffer(Pointer primary, Pointer special, + @Cast("const size_t") long lenInBytes, @Cast("const nd4j::DataType") int dataType, + @Cast("const bool") boolean isOwnerPrimary/*=false*/, @Cast("const bool") boolean isOwnerSpecial/*=false*/, + Workspace workspace/*=nullptr*/) { super((Pointer)null); allocate(primary, special, lenInBytes, dataType, isOwnerPrimary, isOwnerSpecial, workspace); } + private native void allocate(Pointer primary, Pointer special, + @Cast("const size_t") long lenInBytes, @Cast("const nd4j::DataType") int dataType, + @Cast("const bool") boolean isOwnerPrimary/*=false*/, @Cast("const bool") boolean isOwnerSpecial/*=false*/, + Workspace workspace/*=nullptr*/); + public DataBuffer(Pointer primary, Pointer special, + @Cast("const size_t") long lenInBytes, @Cast("const nd4j::DataType") int dataType) { super((Pointer)null); allocate(primary, special, lenInBytes, dataType); } + private native void allocate(Pointer primary, Pointer special, + @Cast("const size_t") long lenInBytes, @Cast("const nd4j::DataType") int dataType); + + public DataBuffer(Pointer primary, + @Cast("const size_t") long lenInBytes, @Cast("const nd4j::DataType") int dataType, + @Cast("const bool") boolean isOwnerPrimary/*=false*/, + Workspace workspace/*=nullptr*/) { super((Pointer)null); allocate(primary, lenInBytes, dataType, isOwnerPrimary, workspace); } + private native void allocate(Pointer primary, + @Cast("const size_t") long lenInBytes, @Cast("const nd4j::DataType") int dataType, + @Cast("const bool") boolean isOwnerPrimary/*=false*/, + Workspace workspace/*=nullptr*/); + public DataBuffer(Pointer primary, + @Cast("const size_t") long lenInBytes, @Cast("const nd4j::DataType") int dataType) { super((Pointer)null); allocate(primary, lenInBytes, dataType); } + private native void allocate(Pointer primary, + @Cast("const size_t") long lenInBytes, @Cast("const nd4j::DataType") int dataType); + + public DataBuffer(@Const Pointer hostBuffer, + @Cast("const nd4j::DataType") int dataType, @Cast("const size_t") long lenInBytes, + Workspace workspace/*=nullptr*/) { super((Pointer)null); allocate(hostBuffer, dataType, lenInBytes, workspace); } + private native void allocate(@Const Pointer hostBuffer, + @Cast("const nd4j::DataType") int dataType, @Cast("const size_t") long lenInBytes, + Workspace workspace/*=nullptr*/); + public DataBuffer(@Const Pointer hostBuffer, + @Cast("const nd4j::DataType") int dataType, @Cast("const size_t") long lenInBytes) { super((Pointer)null); allocate(hostBuffer, dataType, lenInBytes); } + private native void allocate(@Const Pointer hostBuffer, + @Cast("const nd4j::DataType") int dataType, @Cast("const size_t") long lenInBytes); + + public DataBuffer(@Cast("const size_t") long lenInBytes, @Cast("const nd4j::DataType") int dataType, Workspace workspace/*=nullptr*/, @Cast("const bool") boolean allocBoth/*=false*/) { super((Pointer)null); allocate(lenInBytes, dataType, workspace, allocBoth); } + private native void allocate(@Cast("const size_t") long lenInBytes, @Cast("const nd4j::DataType") int dataType, Workspace workspace/*=nullptr*/, @Cast("const bool") boolean allocBoth/*=false*/); + public DataBuffer(@Cast("const size_t") long lenInBytes, @Cast("const nd4j::DataType") int dataType) { super((Pointer)null); allocate(lenInBytes, dataType); } + private native void allocate(@Cast("const size_t") long lenInBytes, @Cast("const nd4j::DataType") int dataType); + + public DataBuffer(@Const @ByRef DataBuffer other) { super((Pointer)null); allocate(other); } + private native void allocate(@Const @ByRef DataBuffer other); + public DataBuffer() { super((Pointer)null); allocate(); } + private native void allocate(); + + public native @ByRef @Name("operator =") DataBuffer put(@Const @ByRef DataBuffer other); + + public native @Cast("nd4j::DataType") int getDataType(); + public native void setDataType(@Cast("nd4j::DataType") int dataType); + public native @Cast("size_t") long getLenInBytes(); + + public native Pointer primary(); + public native Pointer special(); + + public native void allocatePrimary(); + public native void allocateSpecial(); + + public native void writePrimary(); + public native void writeSpecial(); + public native void readPrimary(); + public native void readSpecial(); + public native @Cast("bool") boolean isPrimaryActual(); + public native @Cast("bool") boolean isSpecialActual(); + + public native void expand(@Cast("const uint64_t") long size); + + public native int deviceId(); + public native void setDeviceId(int deviceId); + public native void migrate(); + + public native void syncToPrimary(@Const LaunchContext context, @Cast("const bool") boolean forceSync/*=false*/); + public native void syncToPrimary(@Const LaunchContext context); + public native void syncToSpecial(@Cast("const bool") boolean forceSync/*=false*/); + public native void syncToSpecial(); + + public native void setToZeroBuffers(@Cast("const bool") boolean both/*=false*/); + public native void setToZeroBuffers(); + + public native void copyBufferFrom(@Const @ByRef DataBuffer other, @Cast("size_t") long sizeToCopyinBytes/*=0*/, @Cast("const Nd4jLong") long offsetThis/*=0*/, @Cast("const Nd4jLong") long offsetOther/*=0*/); + public native void copyBufferFrom(@Const @ByRef DataBuffer other); + + public static native void memcpy(@Const @ByRef DataBuffer dst, @Const @ByRef DataBuffer src); + + public native void setPrimaryBuffer(Pointer buffer, @Cast("size_t") long length); + public native void setSpecialBuffer(Pointer buffer, @Cast("size_t") long length); + + /** + * This method deletes buffers, if we're owners + */ + public native @Name("close") void _close(); +} +///// IMLEMENTATION OF INLINE METHODS ///// + +//////////////////////////////////////////////////////////////////////// + + +//////////////////////////////////////////////////////////////////////// + + + + + +// #endif //DEV_TESTS_DATABUFFER_H + + // Parsed from array/ConstantDataBuffer.h /******************************************************************************* @@ -350,7 +573,7 @@ public class Nd4jCpu extends org.nd4j.nativeblas.Nd4jCpuHelper { // #define DEV_TESTS_CONSTANTDESCRIPTOR_H // #include -// #include +// #include // #include // #include // #include @@ -524,6 +747,39 @@ public class Nd4jCpu extends org.nd4j.nativeblas.Nd4jCpuHelper { // #endif //DEV_TESTS_ERRORREFERENCE_H +// Parsed from execution/Engine.h + +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +// #ifndef SD_ENGINE_H +// #define SD_ENGINE_H + /** enum samediff::Engine */ + public static final int + ENGINE_CPU = 0, + ENGINE_CUDA = 1; + + +// #endif //SD_ENGINE_H + + // Parsed from Environment.h /******************************************************************************* @@ -555,6 +811,7 @@ public class Nd4jCpu extends org.nd4j.nativeblas.Nd4jCpuHelper { // #include // #include // #include +// #include @Namespace("nd4j") @NoOffset public static class Environment extends Pointer { static { Loader.load(); } /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ @@ -593,10 +850,30 @@ public class Nd4jCpu extends org.nd4j.nativeblas.Nd4jCpuHelper { public native int maxMasterThreads(); public native void setMaxMasterThreads(int max); + /* + * Legacy memory limits API, still used in new API as simplified version + */ public native void setMaxPrimaryMemory(@Cast("uint64_t") long maxBytes); public native void setMaxSpecialyMemory(@Cast("uint64_t") long maxBytes); public native void setMaxDeviceMemory(@Cast("uint64_t") long maxBytes); + public native @Cast("uint64_t") long maxPrimaryMemory(); + public native @Cast("uint64_t") long maxSpecialMemory(); + //////////////////////// + + /* + * Methods for memory limits/counters + */ + public native void setGroupLimit(int group, @Cast("Nd4jLong") long numBytes); + public native void setDeviceLimit(int deviceId, @Cast("Nd4jLong") long numBytes); + + public native @Cast("Nd4jLong") long getGroupLimit(int group); + public native @Cast("Nd4jLong") long getDeviceLimit(int deviceId); + + public native @Cast("Nd4jLong") long getGroupCounter(int group); + public native @Cast("Nd4jLong") long getDeviceCounter(int deviceId); + //////////////////////// + public native @Cast("bool") boolean isUseMKLDNN(); public native void setUseMKLDNN(@Cast("bool") boolean useMKLDNN); @@ -756,6 +1033,7 @@ bool verbose = false; // #include // #include // #include +// #include // #include // #include // #include @@ -763,6 +1041,7 @@ bool verbose = false; // #include // #include // #include +// #include /** * This function returns last error code stored, @@ -804,25 +1083,19 @@ public native void setTADThreshold(int num); */ public native void execIndexReduceScalar(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo); public native void execIndexReduceScalar(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo); public native void execIndexReduceScalar(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo); /** * @@ -837,31 +1110,22 @@ public native void execIndexReduceScalar(@Cast("Nd4jPointer*") PointerPointer ex */ public native void execIndexReduce(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongPointer dDimensionShape); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, @Cast("Nd4jLong*") LongPointer dDimensionShape); public native void execIndexReduce(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongBuffer dDimensionShape); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, @Cast("Nd4jLong*") LongBuffer dDimensionShape); public native void execIndexReduce(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") long[] hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") long[] dDimensionShape); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") long[] hDimensionShape, @Cast("Nd4jLong*") long[] dDimensionShape); /** * @@ -878,74 +1142,50 @@ public native void execIndexReduce(@Cast("Nd4jPointer*") PointerPointer extraPoi public native void execBroadcast( @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, - Pointer hY, @Cast("Nd4jLong*") LongPointer hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") LongPointer dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongPointer dDimensionShape); + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") LongPointer hYShapeInfo, @Cast("Nd4jLong*") LongPointer dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, @Cast("Nd4jLong*") LongPointer dDimensionShape); public native void execBroadcast( @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, - Pointer hY, @Cast("Nd4jLong*") LongBuffer hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") LongBuffer dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongBuffer dDimensionShape); + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") LongBuffer hYShapeInfo, @Cast("Nd4jLong*") LongBuffer dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, @Cast("Nd4jLong*") LongBuffer dDimensionShape); public native void execBroadcast( @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, - Pointer hY, @Cast("Nd4jLong*") long[] hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") long[] dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") long[] hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") long[] dDimensionShape); + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") long[] hYShapeInfo, @Cast("Nd4jLong*") long[] dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") long[] hDimensionShape, @Cast("Nd4jLong*") long[] dDimensionShape); public native void execBroadcastBool( @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, - Pointer hY, @Cast("Nd4jLong*") LongPointer hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") LongPointer dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") LongPointer hYShapeInfo, @Cast("Nd4jLong*") LongPointer dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo, Pointer extraParams, - Pointer hDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongPointer dDimensionShape); + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, @Cast("Nd4jLong*") LongPointer dDimensionShape); public native void execBroadcastBool( @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, - Pointer hY, @Cast("Nd4jLong*") LongBuffer hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") LongBuffer dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") LongBuffer hYShapeInfo, @Cast("Nd4jLong*") LongBuffer dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, Pointer extraParams, - Pointer hDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongBuffer dDimensionShape); + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, @Cast("Nd4jLong*") LongBuffer dDimensionShape); public native void execBroadcastBool( @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, - Pointer hY, @Cast("Nd4jLong*") long[] hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") long[] dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") long[] hYShapeInfo, @Cast("Nd4jLong*") long[] dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo, Pointer extraParams, - Pointer hDimension, @Cast("Nd4jLong*") long[] hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") long[] dDimensionShape); + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") long[] hDimensionShape, @Cast("Nd4jLong*") long[] dDimensionShape); /** * @@ -962,63 +1202,45 @@ public native void execBroadcastBool( public native void execPairwiseTransform( @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, - Pointer hY, @Cast("Nd4jLong*") LongPointer hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") LongPointer dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") LongPointer hYShapeInfo, @Cast("Nd4jLong*") LongPointer dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo, Pointer extraParams); public native void execPairwiseTransform( @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, - Pointer hY, @Cast("Nd4jLong*") LongBuffer hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") LongBuffer dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") LongBuffer hYShapeInfo, @Cast("Nd4jLong*") LongBuffer dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, Pointer extraParams); public native void execPairwiseTransform( @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, - Pointer hY, @Cast("Nd4jLong*") long[] hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") long[] dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") long[] hYShapeInfo, @Cast("Nd4jLong*") long[] dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo, Pointer extraParams); public native void execPairwiseTransformBool( @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, - Pointer hY, @Cast("Nd4jLong*") LongPointer hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") LongPointer dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") LongPointer hYShapeInfo, @Cast("Nd4jLong*") LongPointer dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo, Pointer extraParams); public native void execPairwiseTransformBool( @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, - Pointer hY, @Cast("Nd4jLong*") LongBuffer hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") LongBuffer dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") LongBuffer hYShapeInfo, @Cast("Nd4jLong*") LongBuffer dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, Pointer extraParams); public native void execPairwiseTransformBool( @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, - Pointer hY, @Cast("Nd4jLong*") long[] hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") long[] dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") long[] hYShapeInfo, @Cast("Nd4jLong*") long[] dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo, Pointer extraParams); /** @@ -1032,92 +1254,68 @@ public native void execPairwiseTransformBool( */ public native void execReduceFloat(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo); public native void execReduceFloat(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo); public native void execReduceFloat(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo); public native void execReduceSame(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo); public native void execReduceSame(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo); public native void execReduceSame(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo); public native void execReduceBool(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo); public native void execReduceBool(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo); public native void execReduceBool(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo); public native void execReduceLong(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo); public native void execReduceLong(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo); public native void execReduceLong(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo); /** * @@ -1130,118 +1328,82 @@ public native void execReduceLong(@Cast("Nd4jPointer*") PointerPointer extraPoin */ public native void execReduceFloat2(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongPointer dDimensionShape); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, @Cast("Nd4jLong*") LongPointer dDimensionShape); public native void execReduceFloat2(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongBuffer dDimensionShape); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, @Cast("Nd4jLong*") LongBuffer dDimensionShape); public native void execReduceFloat2(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") long[] hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") long[] dDimensionShape); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") long[] hDimensionShape, @Cast("Nd4jLong*") long[] dDimensionShape); public native void execReduceSame2(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongPointer dDimensionShape); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, @Cast("Nd4jLong*") LongPointer dDimensionShape); public native void execReduceSame2(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongBuffer dDimensionShape); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, @Cast("Nd4jLong*") LongBuffer dDimensionShape); public native void execReduceSame2(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") long[] hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") long[] dDimensionShape); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") long[] hDimensionShape, @Cast("Nd4jLong*") long[] dDimensionShape); public native void execReduceBool2(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongPointer dDimensionShape); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, @Cast("Nd4jLong*") LongPointer dDimensionShape); public native void execReduceBool2(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongBuffer dDimensionShape); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, @Cast("Nd4jLong*") LongBuffer dDimensionShape); public native void execReduceBool2(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") long[] hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") long[] dDimensionShape); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") long[] hDimensionShape, @Cast("Nd4jLong*") long[] dDimensionShape); public native void execReduceLong2(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongPointer dDimensionShape); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, @Cast("Nd4jLong*") LongPointer dDimensionShape); public native void execReduceLong2(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongBuffer dDimensionShape); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, @Cast("Nd4jLong*") LongBuffer dDimensionShape); public native void execReduceLong2(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") long[] hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") long[] dDimensionShape); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") long[] hDimensionShape, @Cast("Nd4jLong*") long[] dDimensionShape); /** * @@ -1256,31 +1418,22 @@ public native void execReduceLong2(@Cast("Nd4jPointer*") PointerPointer extraPoi */ public native void execReduce3(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParamsVals, - Pointer hY, @Cast("Nd4jLong*") LongPointer hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") LongPointer dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo); + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") LongPointer hYShapeInfo, @Cast("Nd4jLong*") LongPointer dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo); public native void execReduce3(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParamsVals, - Pointer hY, @Cast("Nd4jLong*") LongBuffer hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") LongBuffer dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo); + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") LongBuffer hYShapeInfo, @Cast("Nd4jLong*") LongBuffer dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo); public native void execReduce3(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, Pointer extraParamsVals, - Pointer hY, @Cast("Nd4jLong*") long[] hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") long[] dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo); + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") long[] hYShapeInfo, @Cast("Nd4jLong*") long[] dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo); /** * @@ -1293,31 +1446,22 @@ public native void execReduce3(@Cast("Nd4jPointer*") PointerPointer extraPointer */ public native void execReduce3Scalar(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParamsVals, - Pointer hY, @Cast("Nd4jLong*") LongPointer hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") LongPointer dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo); + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") LongPointer hYShapeInfo, @Cast("Nd4jLong*") LongPointer dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo); public native void execReduce3Scalar(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParamsVals, - Pointer hY, @Cast("Nd4jLong*") LongBuffer hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") LongBuffer dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo); + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") LongBuffer hYShapeInfo, @Cast("Nd4jLong*") LongBuffer dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo); public native void execReduce3Scalar(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, Pointer extraParamsVals, - Pointer hY, @Cast("Nd4jLong*") long[] hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") long[] dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo); + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") long[] hYShapeInfo, @Cast("Nd4jLong*") long[] dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo); /** * * @param opNum @@ -1333,82 +1477,58 @@ public native void execReduce3Scalar(@Cast("Nd4jPointer*") PointerPointer extraP */ public native void execReduce3Tad(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParamsVals, - Pointer hY, @Cast("Nd4jLong*") LongPointer hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") LongPointer dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongPointer dDimensionShape, + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") LongPointer hYShapeInfo, @Cast("Nd4jLong*") LongPointer dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, @Cast("Nd4jLong*") LongPointer dDimensionShape, @Cast("Nd4jLong*") LongPointer tadOnlyShapeInfo, @Cast("Nd4jLong*") LongPointer tadOffsets, @Cast("Nd4jLong*") LongPointer yTadOnlyShapeInfo, @Cast("Nd4jLong*") LongPointer yTadOffsets); public native void execReduce3Tad(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParamsVals, - Pointer hY, @Cast("Nd4jLong*") LongBuffer hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") LongBuffer dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongBuffer dDimensionShape, + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") LongBuffer hYShapeInfo, @Cast("Nd4jLong*") LongBuffer dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, @Cast("Nd4jLong*") LongBuffer dDimensionShape, @Cast("Nd4jLong*") LongBuffer tadOnlyShapeInfo, @Cast("Nd4jLong*") LongBuffer tadOffsets, @Cast("Nd4jLong*") LongBuffer yTadOnlyShapeInfo, @Cast("Nd4jLong*") LongBuffer yTadOffsets); public native void execReduce3Tad(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, Pointer extraParamsVals, - Pointer hY, @Cast("Nd4jLong*") long[] hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") long[] dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") long[] hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") long[] dDimensionShape, + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") long[] hYShapeInfo, @Cast("Nd4jLong*") long[] dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") long[] hDimensionShape, @Cast("Nd4jLong*") long[] dDimensionShape, @Cast("Nd4jLong*") long[] tadOnlyShapeInfo, @Cast("Nd4jLong*") long[] tadOffsets, @Cast("Nd4jLong*") long[] yTadOnlyShapeInfo, @Cast("Nd4jLong*") long[] yTadOffsets); public native void execReduce3All(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParamsVals, - Pointer hY, @Cast("Nd4jLong*") LongPointer hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") LongPointer dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongPointer dDimensionShape, + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") LongPointer hYShapeInfo, @Cast("Nd4jLong*") LongPointer dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, @Cast("Nd4jLong*") LongPointer dDimensionShape, @Cast("Nd4jLong*") LongPointer xTadShapeInfo, @Cast("Nd4jLong*") LongPointer xOffsets, @Cast("Nd4jLong*") LongPointer yTadShapeInfo, @Cast("Nd4jLong*") LongPointer yOffsets); public native void execReduce3All(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParamsVals, - Pointer hY, @Cast("Nd4jLong*") LongBuffer hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") LongBuffer dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongBuffer dDimensionShape, + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") LongBuffer hYShapeInfo, @Cast("Nd4jLong*") LongBuffer dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, @Cast("Nd4jLong*") LongBuffer dDimensionShape, @Cast("Nd4jLong*") LongBuffer xTadShapeInfo, @Cast("Nd4jLong*") LongBuffer xOffsets, @Cast("Nd4jLong*") LongBuffer yTadShapeInfo, @Cast("Nd4jLong*") LongBuffer yOffsets); public native void execReduce3All(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, Pointer extraParamsVals, - Pointer hY, @Cast("Nd4jLong*") long[] hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") long[] dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") long[] hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") long[] dDimensionShape, + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") long[] hYShapeInfo, @Cast("Nd4jLong*") long[] dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") long[] hDimensionShape, @Cast("Nd4jLong*") long[] dDimensionShape, @Cast("Nd4jLong*") long[] xTadShapeInfo, @Cast("Nd4jLong*") long[] xOffsets, @Cast("Nd4jLong*") long[] yTadShapeInfo, @Cast("Nd4jLong*") long[] yOffsets); @@ -1425,58 +1545,40 @@ public native void execReduce3All(@Cast("Nd4jPointer*") PointerPointer extraPoin */ public native void execScalar(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, - Pointer hScalar, @Cast("Nd4jLong*") LongPointer hSscalarShapeInfo, - Pointer dScalar, @Cast("Nd4jLong*") LongPointer dSscalarShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbScalar, @Cast("Nd4jLong*") LongPointer hSscalarShapeInfo, @Cast("Nd4jLong*") LongPointer dSscalarShapeInfo, Pointer extraParams); public native void execScalar(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, - Pointer hScalar, @Cast("Nd4jLong*") LongBuffer hSscalarShapeInfo, - Pointer dScalar, @Cast("Nd4jLong*") LongBuffer dSscalarShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbScalar, @Cast("Nd4jLong*") LongBuffer hSscalarShapeInfo, @Cast("Nd4jLong*") LongBuffer dSscalarShapeInfo, Pointer extraParams); public native void execScalar(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, - Pointer hScalar, @Cast("Nd4jLong*") long[] hSscalarShapeInfo, - Pointer dScalar, @Cast("Nd4jLong*") long[] dSscalarShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbScalar, @Cast("Nd4jLong*") long[] hSscalarShapeInfo, @Cast("Nd4jLong*") long[] dSscalarShapeInfo, Pointer extraParams); public native void execScalarBool(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, - Pointer hScalar, @Cast("Nd4jLong*") LongPointer hSscalarShapeInfo, - Pointer dScalar, @Cast("Nd4jLong*") LongPointer dSscalarShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbScalar, @Cast("Nd4jLong*") LongPointer hSscalarShapeInfo, @Cast("Nd4jLong*") LongPointer dSscalarShapeInfo, Pointer extraParams); public native void execScalarBool(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, - Pointer hScalar, @Cast("Nd4jLong*") LongBuffer hSscalarShapeInfo, - Pointer dScalar, @Cast("Nd4jLong*") LongBuffer dSscalarShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbScalar, @Cast("Nd4jLong*") LongBuffer hSscalarShapeInfo, @Cast("Nd4jLong*") LongBuffer dSscalarShapeInfo, Pointer extraParams); public native void execScalarBool(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, - Pointer hScalar, @Cast("Nd4jLong*") long[] hSscalarShapeInfo, - Pointer dScalar, @Cast("Nd4jLong*") long[] dSscalarShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbScalar, @Cast("Nd4jLong*") long[] hSscalarShapeInfo, @Cast("Nd4jLong*") long[] dSscalarShapeInfo, Pointer extraParams); /** @@ -1488,27 +1590,21 @@ public native void execScalarBool(@Cast("Nd4jPointer*") PointerPointer extraPoin */ public native void execSummaryStatsScalar(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo, @Cast("bool") boolean biasCorrected); public native void execSummaryStatsScalar(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, @Cast("bool") boolean biasCorrected); public native void execSummaryStatsScalar(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo, @Cast("bool") boolean biasCorrected); /** * @@ -1521,27 +1617,21 @@ public native void execSummaryStatsScalar(@Cast("Nd4jPointer*") PointerPointer e */ public native void execSummaryStats(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo, @Cast("bool") boolean biasCorrected); public native void execSummaryStats(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, @Cast("bool") boolean biasCorrected); public native void execSummaryStats(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo, @Cast("bool") boolean biasCorrected); /** * @@ -1556,35 +1646,26 @@ public native void execSummaryStats(@Cast("Nd4jPointer*") PointerPointer extraPo */ public native void execSummaryStatsTad(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongPointer dDimensionShape, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, @Cast("Nd4jLong*") LongPointer dDimensionShape, @Cast("bool") boolean biasCorrected, @Cast("Nd4jLong*") LongPointer tadShapeInfo, @Cast("Nd4jLong*") LongPointer tadOffsets); public native void execSummaryStatsTad(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongBuffer dDimensionShape, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, @Cast("Nd4jLong*") LongBuffer dDimensionShape, @Cast("bool") boolean biasCorrected, @Cast("Nd4jLong*") LongBuffer tadShapeInfo, @Cast("Nd4jLong*") LongBuffer tadOffsets); public native void execSummaryStatsTad(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") long[] hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") long[] dDimensionShape, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") long[] hDimensionShape, @Cast("Nd4jLong*") long[] dDimensionShape, @Cast("bool") boolean biasCorrected, @Cast("Nd4jLong*") long[] tadShapeInfo, @Cast("Nd4jLong*") long[] tadOffsets); @@ -1600,112 +1681,82 @@ public native void execSummaryStatsTad(@Cast("Nd4jPointer*") PointerPointer extr */ public native void execTransformFloat(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo, Pointer extraParams); public native void execTransformFloat(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, Pointer extraParams); public native void execTransformFloat(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo, Pointer extraParams); public native void execTransformSame(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo, Pointer extraParams); public native void execTransformSame(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, Pointer extraParams); public native void execTransformSame(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo, Pointer extraParams); public native void execTransformBool(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo, Pointer extraParams); public native void execTransformBool(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, Pointer extraParams); public native void execTransformBool(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo, Pointer extraParams); public native void execTransformAny(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo, Pointer extraParams); public native void execTransformAny(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, Pointer extraParams); public native void execTransformAny(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo, Pointer extraParams); public native void execTransformStrict(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo, Pointer extraParams); public native void execTransformStrict(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, Pointer extraParams); public native void execTransformStrict(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo, Pointer extraParams); /** @@ -1723,81 +1774,57 @@ public native void execTransformStrict(@Cast("Nd4jPointer*") PointerPointer extr */ public native void execScalarTad(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, - Pointer hScalars, @Cast("Nd4jLong*") LongPointer hScalarShapeInfo, - Pointer dScalars, @Cast("Nd4jLong*") LongPointer dScalarShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbScalars, @Cast("Nd4jLong*") LongPointer hScalarShapeInfo, @Cast("Nd4jLong*") LongPointer dScalarShapeInfo, Pointer extraParams, - Pointer hDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongPointer dDimensionShape, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, @Cast("Nd4jLong*") LongPointer dDimensionShape, @Cast("Nd4jLong*") LongPointer tadShapeInfo, @Cast("Nd4jLong*") LongPointer tadOffsets, @Cast("Nd4jLong*") LongPointer tadShapeInfoZ, @Cast("Nd4jLong*") LongPointer tadOffsetsZ); public native void execScalarTad(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, - Pointer hScalars, @Cast("Nd4jLong*") LongBuffer hScalarShapeInfo, - Pointer dScalars, @Cast("Nd4jLong*") LongBuffer dScalarShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbScalars, @Cast("Nd4jLong*") LongBuffer hScalarShapeInfo, @Cast("Nd4jLong*") LongBuffer dScalarShapeInfo, Pointer extraParams, - Pointer hDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongBuffer dDimensionShape, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, @Cast("Nd4jLong*") LongBuffer dDimensionShape, @Cast("Nd4jLong*") LongBuffer tadShapeInfo, @Cast("Nd4jLong*") LongBuffer tadOffsets, @Cast("Nd4jLong*") LongBuffer tadShapeInfoZ, @Cast("Nd4jLong*") LongBuffer tadOffsetsZ); public native void execScalarTad(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, - Pointer hScalars, @Cast("Nd4jLong*") long[] hScalarShapeInfo, - Pointer dScalars, @Cast("Nd4jLong*") long[] dScalarShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbScalars, @Cast("Nd4jLong*") long[] hScalarShapeInfo, @Cast("Nd4jLong*") long[] dScalarShapeInfo, Pointer extraParams, - Pointer hDimension, @Cast("Nd4jLong*") long[] hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") long[] dDimensionShape, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") long[] hDimensionShape, @Cast("Nd4jLong*") long[] dDimensionShape, @Cast("Nd4jLong*") long[] tadShapeInfo, @Cast("Nd4jLong*") long[] tadOffsets, @Cast("Nd4jLong*") long[] tadShapeInfoZ, @Cast("Nd4jLong*") long[] tadOffsetsZ); public native void execScalarBoolTad(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, - Pointer hScalars, @Cast("Nd4jLong*") LongPointer hScalarShapeInfo, - Pointer dScalars, @Cast("Nd4jLong*") LongPointer dScalarShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbScalars, @Cast("Nd4jLong*") LongPointer hScalarShapeInfo, @Cast("Nd4jLong*") LongPointer dScalarShapeInfo, Pointer extraParams, - Pointer hDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongPointer dDimensionShape, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, @Cast("Nd4jLong*") LongPointer dDimensionShape, @Cast("Nd4jLong*") LongPointer tadShapeInfo, @Cast("Nd4jLong*") LongPointer tadOffsets, @Cast("Nd4jLong*") LongPointer tadShapeInfoZ, @Cast("Nd4jLong*") LongPointer tadOffsetsZ); public native void execScalarBoolTad(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, - Pointer hScalars, @Cast("Nd4jLong*") LongBuffer hScalarShapeInfo, - Pointer dScalars, @Cast("Nd4jLong*") LongBuffer dScalarShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbScalars, @Cast("Nd4jLong*") LongBuffer hScalarShapeInfo, @Cast("Nd4jLong*") LongBuffer dScalarShapeInfo, Pointer extraParams, - Pointer hDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongBuffer dDimensionShape, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, @Cast("Nd4jLong*") LongBuffer dDimensionShape, @Cast("Nd4jLong*") LongBuffer tadShapeInfo, @Cast("Nd4jLong*") LongBuffer tadOffsets, @Cast("Nd4jLong*") LongBuffer tadShapeInfoZ, @Cast("Nd4jLong*") LongBuffer tadOffsetsZ); public native void execScalarBoolTad(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, - Pointer hScalars, @Cast("Nd4jLong*") long[] hScalarShapeInfo, - Pointer dScalars, @Cast("Nd4jLong*") long[] dScalarShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbScalars, @Cast("Nd4jLong*") long[] hScalarShapeInfo, @Cast("Nd4jLong*") long[] dScalarShapeInfo, Pointer extraParams, - Pointer hDimension, @Cast("Nd4jLong*") long[] hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") long[] dDimensionShape, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") long[] hDimensionShape, @Cast("Nd4jLong*") long[] dDimensionShape, @Cast("Nd4jLong*") long[] tadShapeInfo, @Cast("Nd4jLong*") long[] tadOffsets, @Cast("Nd4jLong*") long[] tadShapeInfoZ, @Cast("Nd4jLong*") long[] tadOffsetsZ); @@ -2160,10 +2187,8 @@ public native void deleteTadPack(OpaqueTadPack ptr); * @param zTadOffsets */ public native void pullRows(@Cast("Nd4jPointer*") PointerPointer extraPointers, - Pointer x, @Cast("Nd4jLong*") LongPointer xShapeInfo, - Pointer dx, @Cast("Nd4jLong*") LongPointer dxShapeInfo, - Pointer z, @Cast("Nd4jLong*") LongPointer zShapeInfo, - Pointer dz, @Cast("Nd4jLong*") LongPointer dzShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer xShapeInfo, @Cast("Nd4jLong*") LongPointer dxShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer zShapeInfo, @Cast("Nd4jLong*") LongPointer dzShapeInfo, @Cast("Nd4jLong") long n, @Cast("Nd4jLong*") LongPointer indexes, @Cast("Nd4jLong*") LongPointer tadShapeInfo, @@ -2171,10 +2196,8 @@ public native void pullRows(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") LongPointer zTadShapeInfo, @Cast("Nd4jLong*") LongPointer zTadOffsets); public native void pullRows(@Cast("Nd4jPointer*") PointerPointer extraPointers, - Pointer x, @Cast("Nd4jLong*") LongBuffer xShapeInfo, - Pointer dx, @Cast("Nd4jLong*") LongBuffer dxShapeInfo, - Pointer z, @Cast("Nd4jLong*") LongBuffer zShapeInfo, - Pointer dz, @Cast("Nd4jLong*") LongBuffer dzShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer xShapeInfo, @Cast("Nd4jLong*") LongBuffer dxShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer zShapeInfo, @Cast("Nd4jLong*") LongBuffer dzShapeInfo, @Cast("Nd4jLong") long n, @Cast("Nd4jLong*") LongBuffer indexes, @Cast("Nd4jLong*") LongBuffer tadShapeInfo, @@ -2182,10 +2205,8 @@ public native void pullRows(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") LongBuffer zTadShapeInfo, @Cast("Nd4jLong*") LongBuffer zTadOffsets); public native void pullRows(@Cast("Nd4jPointer*") PointerPointer extraPointers, - Pointer x, @Cast("Nd4jLong*") long[] xShapeInfo, - Pointer dx, @Cast("Nd4jLong*") long[] dxShapeInfo, - Pointer z, @Cast("Nd4jLong*") long[] zShapeInfo, - Pointer dz, @Cast("Nd4jLong*") long[] dzShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] xShapeInfo, @Cast("Nd4jLong*") long[] dxShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] zShapeInfo, @Cast("Nd4jLong*") long[] dzShapeInfo, @Cast("Nd4jLong") long n, @Cast("Nd4jLong*") long[] indexes, @Cast("Nd4jLong*") long[] tadShapeInfo, @@ -2451,20 +2472,17 @@ public native void execAggregateBatch(@Cast("Nd4jPointer*") PointerPointer extra public native void execRandom(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, @Cast("Nd4jPointer") Pointer state, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeBuffer, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeBuffer, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeBuffer, @Cast("Nd4jLong*") LongPointer dZShapeBuffer, Pointer extraArguments); public native void execRandom(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, @Cast("Nd4jPointer") Pointer state, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeBuffer, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeBuffer, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeBuffer, @Cast("Nd4jLong*") LongBuffer dZShapeBuffer, Pointer extraArguments); public native void execRandom(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, @Cast("Nd4jPointer") Pointer state, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeBuffer, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeBuffer, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeBuffer, @Cast("Nd4jLong*") long[] dZShapeBuffer, Pointer extraArguments); /** @@ -2483,32 +2501,23 @@ public native void execRandom(@Cast("Nd4jPointer*") PointerPointer extraPointers public native void execRandom3(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, @Cast("Nd4jPointer") Pointer state, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeBuffer, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeBuffer, - Pointer hY, @Cast("Nd4jLong*") LongPointer hYShapeBuffer, - Pointer dY, @Cast("Nd4jLong*") LongPointer dYShapeBuffer, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeBuffer, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeBuffer, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeBuffer, @Cast("Nd4jLong*") LongPointer dXShapeBuffer, + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") LongPointer hYShapeBuffer, @Cast("Nd4jLong*") LongPointer dYShapeBuffer, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeBuffer, @Cast("Nd4jLong*") LongPointer dZShapeBuffer, Pointer extraArguments); public native void execRandom3(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, @Cast("Nd4jPointer") Pointer state, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeBuffer, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeBuffer, - Pointer hY, @Cast("Nd4jLong*") LongBuffer hYShapeBuffer, - Pointer dY, @Cast("Nd4jLong*") LongBuffer dYShapeBuffer, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeBuffer, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeBuffer, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeBuffer, @Cast("Nd4jLong*") LongBuffer dXShapeBuffer, + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") LongBuffer hYShapeBuffer, @Cast("Nd4jLong*") LongBuffer dYShapeBuffer, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeBuffer, @Cast("Nd4jLong*") LongBuffer dZShapeBuffer, Pointer extraArguments); public native void execRandom3(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, @Cast("Nd4jPointer") Pointer state, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeBuffer, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeBuffer, - Pointer hY, @Cast("Nd4jLong*") long[] hYShapeBuffer, - Pointer dY, @Cast("Nd4jLong*") long[] dYShapeBuffer, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeBuffer, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeBuffer, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeBuffer, @Cast("Nd4jLong*") long[] dXShapeBuffer, + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") long[] hYShapeBuffer, @Cast("Nd4jLong*") long[] dYShapeBuffer, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeBuffer, @Cast("Nd4jLong*") long[] dZShapeBuffer, Pointer extraArguments); /** @@ -2525,26 +2534,20 @@ public native void execRandom3(@Cast("Nd4jPointer*") PointerPointer extraPointer public native void execRandom2(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, @Cast("Nd4jPointer") Pointer state, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeBuffer, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeBuffer, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeBuffer, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeBuffer, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeBuffer, @Cast("Nd4jLong*") LongPointer dXShapeBuffer, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeBuffer, @Cast("Nd4jLong*") LongPointer dZShapeBuffer, Pointer extraArguments); public native void execRandom2(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, @Cast("Nd4jPointer") Pointer state, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeBuffer, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeBuffer, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeBuffer, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeBuffer, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeBuffer, @Cast("Nd4jLong*") LongBuffer dXShapeBuffer, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeBuffer, @Cast("Nd4jLong*") LongBuffer dZShapeBuffer, Pointer extraArguments); public native void execRandom2(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, @Cast("Nd4jPointer") Pointer state, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeBuffer, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeBuffer, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeBuffer, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeBuffer, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeBuffer, @Cast("Nd4jLong*") long[] dXShapeBuffer, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeBuffer, @Cast("Nd4jLong*") long[] dZShapeBuffer, Pointer extraArguments); @@ -2587,52 +2590,6 @@ public native void reSeedBuffer(@Cast("Nd4jPointer*") PointerPointer extraPointe */ public native void destroyRandom(@Cast("Nd4jPointer") Pointer ptrRandom); -/** - * Grid operations - */ - - - - -/** - * - * @param extras - * @param opTypeA - * @param opNumA - * @param opTypeB - * @param opNumB - * @param N - * @param dx - * @param xShapeInfo - * @param dy - * @param yShapeInfo - * @param dz - * @param zShapeInfo - * @param extraA - * @param extraB - * @param scalarA - * @param scalarB - */ - /* -ND4J_EXPORT void execMetaPredicateShape(Nd4jPointer *extras, - const int opTypeA, - const int opNumA, - const int opTypeB, - const int opNumB, - Nd4jLong N, - void *hX, Nd4jLong *hXShapeBuffer, - void *dX, Nd4jLong *dXShapeBuffer, - void *hY, Nd4jLong *hYShapeBuffer, - void *dY, Nd4jLong *dYShapeBuffer, - void *hZ, Nd4jLong *hZShapeBuffer, - void *dZ, Nd4jLong *dZShapeBuffer, - void *extraA, - void *extraB, - double scalarA, - double scalarB); - -*/ - /** * * @param data @@ -2795,23 +2752,20 @@ public native @Cast("Nd4jPointer") Pointer pointerForAddress(@Cast("Nd4jLong") l * @return */ public native void tear(@Cast("Nd4jPointer*") PointerPointer extraPointers, - Pointer x, @Cast("Nd4jLong*") LongPointer xShapeInfo, - Pointer dx, @Cast("Nd4jLong*") LongPointer dxShapeInfo, - @Cast("Nd4jPointer*") PointerPointer targets, @Cast("Nd4jLong*") LongPointer zShapeInfo, - @Cast("Nd4jLong*") LongPointer tadShapeInfo, - @Cast("Nd4jLong*") LongPointer tadOffsets); + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer xShapeInfo, @Cast("Nd4jLong*") LongPointer dxShapeInfo, + @Cast("Nd4jPointer*") PointerPointer targets, @Cast("Nd4jLong*") LongPointer zShapeInfo, + @Cast("Nd4jLong*") LongPointer tadShapeInfo, + @Cast("Nd4jLong*") LongPointer tadOffsets); public native void tear(@Cast("Nd4jPointer*") PointerPointer extraPointers, - Pointer x, @Cast("Nd4jLong*") LongBuffer xShapeInfo, - Pointer dx, @Cast("Nd4jLong*") LongBuffer dxShapeInfo, - @Cast("Nd4jPointer*") PointerPointer targets, @Cast("Nd4jLong*") LongBuffer zShapeInfo, - @Cast("Nd4jLong*") LongBuffer tadShapeInfo, - @Cast("Nd4jLong*") LongBuffer tadOffsets); + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer xShapeInfo, @Cast("Nd4jLong*") LongBuffer dxShapeInfo, + @Cast("Nd4jPointer*") PointerPointer targets, @Cast("Nd4jLong*") LongBuffer zShapeInfo, + @Cast("Nd4jLong*") LongBuffer tadShapeInfo, + @Cast("Nd4jLong*") LongBuffer tadOffsets); public native void tear(@Cast("Nd4jPointer*") PointerPointer extraPointers, - Pointer x, @Cast("Nd4jLong*") long[] xShapeInfo, - Pointer dx, @Cast("Nd4jLong*") long[] dxShapeInfo, - @Cast("Nd4jPointer*") PointerPointer targets, @Cast("Nd4jLong*") long[] zShapeInfo, - @Cast("Nd4jLong*") long[] tadShapeInfo, - @Cast("Nd4jLong*") long[] tadOffsets); + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] xShapeInfo, @Cast("Nd4jLong*") long[] dxShapeInfo, + @Cast("Nd4jPointer*") PointerPointer targets, @Cast("Nd4jLong*") long[] zShapeInfo, + @Cast("Nd4jLong*") long[] tadShapeInfo, + @Cast("Nd4jLong*") long[] tadOffsets); public native @Cast("Nd4jLong") long encodeBitmap(@Cast("Nd4jPointer*") PointerPointer extraPointers, Pointer dx, @Cast("Nd4jLong*") LongPointer xShapeInfo, @Cast("Nd4jLong") long N, IntPointer dz, float threshold); public native @Cast("Nd4jLong") long encodeBitmap(@Cast("Nd4jPointer*") PointerPointer extraPointers, Pointer dx, @Cast("Nd4jLong*") LongBuffer xShapeInfo, @Cast("Nd4jLong") long N, IntBuffer dz, float threshold); @@ -3103,10 +3057,13 @@ public native void deleteShapeBuffer(OpaqueConstantDataBuffer ptr); public native OpaqueContext createGraphContext(int nodeId); public native OpaqueRandomGenerator getGraphContextRandomGenerator(OpaqueContext ptr); public native void ctxAllowHelpers(OpaqueContext ptr, @Cast("bool") boolean reallyAllow); +public native void ctxShapeFunctionOverride(OpaqueContext ptr, @Cast("bool") boolean reallyOverride); public native void markGraphContextInplace(OpaqueContext ptr, @Cast("bool") boolean reallyInplace); public native void setGraphContextCudaContext(OpaqueContext ptr, Pointer stream, Pointer reductionPointer, Pointer allocationPointer); public native void setGraphContextInputArray(OpaqueContext ptr, int index, Pointer buffer, Pointer shapeInfo, Pointer specialBuffer, Pointer specialShapeInfo); public native void setGraphContextOutputArray(OpaqueContext ptr, int index, Pointer buffer, Pointer shapeInfo, Pointer specialBuffer, Pointer specialShapeInfo); +public native void setGraphContextInputBuffer(OpaqueContext ptr, int index, OpaqueDataBuffer buffer, Pointer shapeInfo, Pointer specialShapeInfo); +public native void setGraphContextOutputBuffer(OpaqueContext ptr, int index, OpaqueDataBuffer buffer, Pointer shapeInfo, Pointer specialShapeInfo); public native void setGraphContextTArguments(OpaqueContext ptr, DoublePointer arguments, int numberOfArguments); public native void setGraphContextTArguments(OpaqueContext ptr, DoubleBuffer arguments, int numberOfArguments); public native void setGraphContextTArguments(OpaqueContext ptr, double[] arguments, int numberOfArguments); @@ -3139,6 +3096,28 @@ public native @Cast("Nd4jPointer") Pointer lcCopyStream(OpaqueLaunchContext lc); public native @Cast("Nd4jPointer") Pointer lcBlasHandle(OpaqueLaunchContext lc); public native @Cast("Nd4jPointer") Pointer lcSolverHandle(OpaqueLaunchContext lc); +public native OpaqueDataBuffer allocateDataBuffer(@Cast("Nd4jLong") long elements, int dataType, @Cast("bool") boolean allocateBoth); +public native OpaqueDataBuffer dbCreateView(OpaqueDataBuffer dataBuffer, @Cast("Nd4jLong") long length, @Cast("Nd4jLong") long offset); +public native @Cast("Nd4jPointer") Pointer dbPrimaryBuffer(OpaqueDataBuffer dataBuffer); +public native @Cast("Nd4jPointer") Pointer dbSpecialBuffer(OpaqueDataBuffer dataBuffer); +public native void dbExpandBuffer(OpaqueDataBuffer dataBuffer, @Cast("Nd4jLong") long elements); +public native void dbAllocatePrimaryBuffer(OpaqueDataBuffer dataBuffer); +public native void dbAllocateSpecialBuffer(OpaqueDataBuffer dataBuffer); +public native void dbSetPrimaryBuffer(OpaqueDataBuffer dataBuffer, @Cast("Nd4jPointer") Pointer primaryBuffer, @Cast("Nd4jLong") long numBytes); +public native void dbSetSpecialBuffer(OpaqueDataBuffer dataBuffer, @Cast("Nd4jPointer") Pointer specialBuffer, @Cast("Nd4jLong") long numBytes); +public native void dbSyncToSpecial(OpaqueDataBuffer dataBuffer); +public native void dbSyncToPrimary(OpaqueDataBuffer dataBuffer); +public native int dbLocality(OpaqueDataBuffer dataBuffer); +public native int dbDeviceId(OpaqueDataBuffer dataBuffer); +public native void dbSetDeviceId(OpaqueDataBuffer dataBuffer, int deviceId); +public native void dbTickHostRead(OpaqueDataBuffer dataBuffer); +public native void dbTickHostWrite(OpaqueDataBuffer dataBuffer); +public native void dbTickDeviceRead(OpaqueDataBuffer dataBuffer); +public native void dbTickDeviceWrite(OpaqueDataBuffer dataBuffer); +public native void dbClose(OpaqueDataBuffer dataBuffer); +public native void deleteDataBuffer(OpaqueDataBuffer dataBuffer); +public native void dbExpand(OpaqueDataBuffer dataBuffer, @Cast("Nd4jLong") long elements); + public native int binaryLevel(); public native int optimalLevel(); @@ -3635,27 +3614,12 @@ public native @Cast("bool") boolean isOptimalRequirementsMet(); // #include // #include // #include +// #include +// #include +// #include - @Namespace("nd4j") public static native @ByVal @Name("operator -") NDArray subtract(float arg0, @Const @ByRef NDArray arg1); - @Namespace("nd4j") public static native @ByVal @Name("operator -") NDArray subtract(@Cast("const float16") short arg0, @Const @ByRef NDArray arg1); - @Namespace("nd4j") public static native @ByVal @Name("operator -") NDArray subtract(double arg0, @Const @ByRef NDArray arg1); - @Namespace("nd4j") public static native @ByVal @Name("operator -") NDArray subtract(int arg0, @Const @ByRef NDArray arg1); - @Namespace("nd4j") public static native @ByVal @Name("operator +") NDArray add(float arg0, @Const @ByRef NDArray arg1); - @Namespace("nd4j") public static native @ByVal @Name("operator +") NDArray add(@Cast("const float16") short arg0, @Const @ByRef NDArray arg1); - @Namespace("nd4j") public static native @ByVal @Name("operator +") NDArray add(double arg0, @Const @ByRef NDArray arg1); - @Namespace("nd4j") public static native @ByVal @Name("operator +") NDArray add(int arg0, @Const @ByRef NDArray arg1); - - @Namespace("nd4j") public static native @ByVal @Name("operator *") NDArray multiply(float arg0, @Const @ByRef NDArray arg1); - @Namespace("nd4j") public static native @ByVal @Name("operator *") NDArray multiply(@Cast("const float16") short arg0, @Const @ByRef NDArray arg1); - @Namespace("nd4j") public static native @ByVal @Name("operator *") NDArray multiply(double arg0, @Const @ByRef NDArray arg1); - @Namespace("nd4j") public static native @ByVal @Name("operator *") NDArray multiply(int arg0, @Const @ByRef NDArray arg1); - - @Namespace("nd4j") public static native @ByVal @Name("operator /") NDArray divide(float arg0, @Const @ByRef NDArray arg1); - @Namespace("nd4j") public static native @ByVal @Name("operator /") NDArray divide(@Cast("const float16") short arg0, @Const @ByRef NDArray arg1); - @Namespace("nd4j") public static native @ByVal @Name("operator /") NDArray divide(double arg0, @Const @ByRef NDArray arg1); - @Namespace("nd4j") public static native @ByVal @Name("operator /") NDArray divide(int arg0, @Const @ByRef NDArray arg1); @Namespace("nd4j") public static native @ByVal NDArray mmul(@Const @ByRef NDArray arg0, @Const @ByRef NDArray arg1); @@ -3864,10 +3828,13 @@ public native @Cast("bool") boolean isOptimalRequirementsMet(); * @param writeList * @param readList */ - // TODO: it would be nice to have NDArray::registerSpecialUse signature that accepts something else beyond initializer_list - - // TODO: it would be nice to have NDArray::registerSpecialUse signature that accepts something else beyond initializer_list + public static native void registerSpecialUse(@Const @ByRef ConstNDArrayVector writeList, @Const @ByRef ConstNDArrayVector readList); + public static native void prepareSpecialUse(@Const @ByRef ConstNDArrayVector writeList, @Const @ByRef ConstNDArrayVector readList, @Cast("bool") boolean synchronizeWritables/*=false*/); + public static native void prepareSpecialUse(@Const @ByRef ConstNDArrayVector writeList, @Const @ByRef ConstNDArrayVector readList); + public static native void registerPrimaryUse(@Const @ByRef ConstNDArrayVector writeList, @Const @ByRef ConstNDArrayVector readList); + public static native void preparePrimaryUse(@Const @ByRef ConstNDArrayVector writeList, @Const @ByRef ConstNDArrayVector readList, @Cast("bool") boolean synchronizeWritables/*=false*/); + public static native void preparePrimaryUse(@Const @ByRef ConstNDArrayVector writeList, @Const @ByRef ConstNDArrayVector readList); /** * This method returns buffer pointer offset by given number of elements, wrt own data type @@ -3906,9 +3873,9 @@ public native @Cast("bool") boolean isOptimalRequirementsMet(); * axis - axis along which to repeat elements * repeats - number of repetitions */ - public native NDArray repeat(int axis, @StdVector IntPointer repeats); - public native NDArray repeat(int axis, @StdVector IntBuffer repeats); - public native NDArray repeat(int axis, @StdVector int[] repeats); + public native @ByVal NDArray repeat(int axis, @StdVector IntPointer repeats); + public native @ByVal NDArray repeat(int axis, @StdVector IntBuffer repeats); + public native @ByVal NDArray repeat(int axis, @StdVector int[] repeats); /** * This method fills this array with zeros @@ -3921,14 +3888,7 @@ public native @Cast("bool") boolean isOptimalRequirementsMet(); * @param array * @return */ - public static native @ByVal NDArray quantize(@ByRef NDArray array); - - /** - * This method returns quantized copy of given array - * - * @param array - * @return - */ + public static native @ByVal NDArray quantize(@Const @ByRef NDArray array); /** * fill target array by repeating current array @@ -3949,10 +3909,9 @@ public native @Cast("bool") boolean isOptimalRequirementsMet(); /** * cast array elements to given dtype */ + public native @ByVal NDArray cast(@Cast("nd4j::DataType") int dtype); - public native NDArray cast(@Cast("nd4j::DataType") int dtype); - - public native void cast(NDArray target, @Cast("nd4j::DataType") int dtype); + public native void cast(@ByRef NDArray target, @Cast("nd4j::DataType") int dtype); /** * returns _context @@ -4123,26 +4082,12 @@ public native @Cast("bool") boolean isOptimalRequirementsMet(); /** * this method assigns given value to all elements in array */ - public native void assign(double value, @Cast("bool") boolean allowParallelism/*=true*/); - public native void assign(double value); - public native void assign(float value, @Cast("bool") boolean allowParallelism/*=true*/); - public native void assign(float value); - public native void assign(@Cast("const float16") short value, @Cast("bool") boolean allowParallelism/*=true*/); - public native void assign(@Cast("const float16") short value); - public native void assign(@Cast("const Nd4jLong") long value, @Cast("bool") boolean allowParallelism/*=true*/); - public native void assign(@Cast("const Nd4jLong") long value); - public native void assign(int value, @Cast("bool") boolean allowParallelism/*=true*/); - public native void assign(int value); - public native void assign(@Cast("const uint8_t") byte value, @Cast("bool") boolean allowParallelism/*=true*/); - public native void assign(@Cast("const uint8_t") byte value); - public native void assign(@Cast("const bool") boolean value, @Cast("bool") boolean allowParallelism/*=true*/); - public native void assign(@Cast("const bool") boolean value); /** * returns new copy of this array, optionally in different order */ - public native NDArray dup(byte newOrder/*='a'*/); - public native NDArray dup(); + public native @ByVal NDArray dup(byte newOrder/*='a'*/); + public native @ByVal NDArray dup(); /** * returns sum of all elements of array @@ -4179,9 +4124,9 @@ public native @Cast("bool") boolean isOptimalRequirementsMet(); * index - the number of array to be returned among set of possible arrays * dimensions - array of dimensions to point on */ - public native NDArray tensorAlongDimension(@Cast("Nd4jLong") long index, @StdVector IntPointer dimensions); - public native NDArray tensorAlongDimension(@Cast("Nd4jLong") long index, @StdVector IntBuffer dimensions); - public native NDArray tensorAlongDimension(@Cast("Nd4jLong") long index, @StdVector int[] dimensions); + public native @ByVal NDArray tensorAlongDimension(@Cast("Nd4jLong") long index, @StdVector IntPointer dimensions); + public native @ByVal NDArray tensorAlongDimension(@Cast("Nd4jLong") long index, @StdVector IntBuffer dimensions); + public native @ByVal NDArray tensorAlongDimension(@Cast("Nd4jLong") long index, @StdVector int[] dimensions); /** * returns the number of arrays pointing on specified dimension(s) @@ -4203,54 +4148,54 @@ public native @Cast("bool") boolean isOptimalRequirementsMet(); * add given row vector to all rows of this array * row - row vector to add */ - public native void addiRowVector(@Const NDArray row); + public native void addiRowVector(@Const @ByRef NDArray row); /** * add given row vector to all rows of this array, store result in target * row - row vector to add * target - where to store result */ - public native void addRowVector(@Const NDArray row, NDArray target); + public native void addRowVector(@Const @ByRef NDArray row, @ByRef NDArray target); /** * subtract given row vector from all rows of this array, store result in target * row - row vector to subtract * target - where to store result */ - public native void subRowVector(@Const NDArray row, NDArray target); + public native void subRowVector(@Const @ByRef NDArray row, @ByRef NDArray target); /** * multiply all rows of this array on given row vector, store result in target * row - row vector to multiply on * target - where to store result */ - public native void mulRowVector(@Const NDArray row, NDArray target); + public native void mulRowVector(@Const @ByRef NDArray row, @ByRef NDArray target); /** * divide all rows of this array on given row vector, store result in target * row - row vector to divide on * target - where to store result */ - public native void divRowVector(@Const NDArray row, NDArray target); + public native void divRowVector(@Const @ByRef NDArray row, @ByRef NDArray target); /** * add given column vector to all columns of this array, store result in target * column - column vector to add * target - where to store result */ - public native void addColumnVector(@Const NDArray column, NDArray target); + public native void addColumnVector(@Const @ByRef NDArray column, @ByRef NDArray target); /** * add given column vector to all columns of this array, this array becomes affected (in-place operation) * column - column vector to add */ - public native void addiColumnVector(@Const NDArray column); + public native void addiColumnVector(@Const @ByRef NDArray column); /** * multiply all columns of this array on given column vector, this array becomes affected (in-place operation) * column - column vector to multiply on */ - public native void muliColumnVector(@Const NDArray column); + public native void muliColumnVector(@Const @ByRef NDArray column); /** * returns number of bytes used by _buffer & _shapeInfo @@ -4261,6 +4206,7 @@ public native @Cast("bool") boolean isOptimalRequirementsMet(); * these methods suited for FlatBuffers use */ public native @Cast("Nd4jLong*") @StdVector LongPointer getShapeAsVector(); + public native @StdVector IntPointer getShapeAsVectorInt(); public native @Cast("Nd4jLong*") @StdVector LongPointer getShapeInfoAsVector(); public native @Cast("int64_t*") @StdVector LongPointer getShapeInfoAsFlatVector(); public native @Cast("int64_t*") @StdVector LongPointer getShapeAsFlatVector(); @@ -4286,9 +4232,9 @@ public native @Cast("bool") boolean isOptimalRequirementsMet(); * * if permute have been applied before or there are weird strides, then new buffer is allocated for new array */ - public native @ByVal NDArray reshape(byte order, @Cast("Nd4jLong*") @StdVector LongPointer shape); - public native @ByVal NDArray reshape(byte order, @Cast("Nd4jLong*") @StdVector LongBuffer shape); - public native @ByVal NDArray reshape(byte order, @Cast("Nd4jLong*") @StdVector long[] shape); + public native @ByVal NDArray reshape(byte order, @Cast("Nd4jLong*") @StdVector LongPointer shape); + public native @ByVal NDArray reshape(byte order, @Cast("Nd4jLong*") @StdVector LongBuffer shape); + public native @ByVal NDArray reshape(byte order, @Cast("Nd4jLong*") @StdVector long[] shape); /** * calculate strides and set given order @@ -4327,12 +4273,6 @@ public native @Cast("bool") boolean isOptimalRequirementsMet(); */ public native void tile(@ByRef NDArray target); - /** - * returns an array which is result of broadcasting of this and other arrays - * other - input array - */ - public native NDArray broadcast(@Const @ByRef NDArray other); - /** * check whether array is identity matrix */ @@ -4343,7 +4283,6 @@ public native @Cast("bool") boolean isOptimalRequirementsMet(); */ public native @Cast("bool") boolean isUnitary(); - /** * operator returns subarray with buffer pointing at this->_buffer with offset defined by given intervals * idx - intervals of indexes which define the subarrays to point on, idx has form {dim0Start,dim0End, dim1Start,dim1End, ....} and length (2 * this->rankOf()) @@ -4389,25 +4328,6 @@ public native @Cast("bool") boolean isOptimalRequirementsMet(); public native void getSubArrShapeAndOffsets(@StdVector int[] dimsToExclude, @Cast("Nd4jLong*&") @ByPtrRef long[] subArrShapeInfo, @Cast("Nd4jLong*&") @ByPtrRef long[] subArrOffsets, @Cast("bool") boolean keepUnitiesInShape/*=false*/); public native void getSubArrShapeAndOffsets(@StdVector int[] dimsToExclude, @Cast("Nd4jLong*&") @ByPtrRef long[] subArrShapeInfo, @Cast("Nd4jLong*&") @ByPtrRef long[] subArrOffsets); - /** - * addition operator: array + other - * other - input array to add - */ - public native @ByVal @Name("operator +") NDArray add(@Const @ByRef NDArray other); - - /** - * addition operator: array + scalar - * scalar - input scalar to add - */ - - /** - * friend functions which implement addition operator: scalar + array - * scalar - input scalar to add - */ - //template - //friend NDArray nd4j::operator+(const T scalar, const NDArray& arr); - - /** * addition unary operator array += other * other - input array to add @@ -4420,39 +4340,11 @@ public native @Cast("bool") boolean isOptimalRequirementsMet(); */ public native @Name("operator -=") void subtractPut(@Const @ByRef NDArray other); - /** - * subtraction operator: array - other - * other - input array to subtract - */ - public native @ByVal @Name("operator -") NDArray subtract(@Const @ByRef NDArray other); - - /** - * subtraction operator: array - scalar - * scalar - input scalar to subtract - */ - /** * negative operator, it changes sign of all array elements on opposite */ public native @ByVal @Name("operator -") NDArray subtract(); - /** - * friend functions which implement subtraction operator: scalar - array - * scalar - input scalar to subtract - */ - //friend NDArray nd4j::operator-(const float scalar, const NDArray& arr); - - /** - * pairwise multiplication operator: array * other - * other - input array to multiply on - */ - public native @ByVal @Name("operator *") NDArray multiply(@Const @ByRef NDArray other); - - /** - * multiplication operator: array * scalar - * scalar - input scalar to multiply on - */ - /** * pairwise multiplication unary operator array *= other * other - input array to multiply on @@ -4464,17 +4356,6 @@ public native @Cast("bool") boolean isOptimalRequirementsMet(); * scalar - input scalar to multiply on */ - /** - * pairwise division operator: array / other - * other - input array to divide on - */ - public native @ByVal @Name("operator /") NDArray divide(@Const @ByRef NDArray other); - - /** - * division operator: array / scalar - * scalar - input scalar to divide each array element on - */ - /** * pairwise division unary operator: array /= other * other - input array to divide on @@ -4513,7 +4394,7 @@ public native @Cast("bool") boolean isOptimalRequirementsMet(); * return vector with buffer which points on corresponding diagonal elements of array * type - means of vector to be returned: column ('c') or row ('r') */ - public native NDArray diagonal(byte type ); + public native @ByVal NDArray diagonal(byte type ); /** * fill target matrix with given value in one or two directions from main diagonal: @@ -4536,13 +4417,13 @@ public native @Cast("bool") boolean isOptimalRequirementsMet(); public native @ByVal NDArray tileToShape(@Cast("const Nd4jLong*") LongPointer shapeInfo); public native @ByVal NDArray tileToShape(@Cast("const Nd4jLong*") LongBuffer shapeInfo); public native @ByVal NDArray tileToShape(@Cast("const Nd4jLong*") long[] shapeInfo); - public native void tileToShape(@Cast("Nd4jLong*") @StdVector LongPointer shape, NDArray target/*=nullptr*/); - public native void tileToShape(@Cast("Nd4jLong*") @StdVector LongBuffer shape, NDArray target/*=nullptr*/); - public native void tileToShape(@Cast("Nd4jLong*") @StdVector long[] shape, NDArray target/*=nullptr*/); + public native void tileToShape(@Cast("Nd4jLong*") @StdVector LongPointer shape, @ByRef NDArray target); + public native void tileToShape(@Cast("Nd4jLong*") @StdVector LongBuffer shape, @ByRef NDArray target); + public native void tileToShape(@Cast("Nd4jLong*") @StdVector long[] shape, @ByRef NDArray target); // #ifndef __JAVACPP_HACK__ // #endif - public native NDArray asT(@Cast("nd4j::DataType") int dtype); + public native @ByVal NDArray asT(@Cast("nd4j::DataType") int dtype); public native void linspace(double start); @@ -4554,17 +4435,15 @@ public native @Cast("bool") boolean isOptimalRequirementsMet(); */ public native double getTrace(); - public native ResultSet multipleTensorsAlongDimension(@StdVector IntPointer indices, @StdVector IntPointer dimensions); - public native ResultSet multipleTensorsAlongDimension(@StdVector IntBuffer indices, @StdVector IntBuffer dimensions); - public native ResultSet multipleTensorsAlongDimension(@StdVector int[] indices, @StdVector int[] dimensions); + public native @ByVal ResultSet multipleTensorsAlongDimension(@StdVector IntPointer indices, @StdVector IntPointer dimensions); + public native @ByVal ResultSet multipleTensorsAlongDimension(@StdVector IntBuffer indices, @StdVector IntBuffer dimensions); + public native @ByVal ResultSet multipleTensorsAlongDimension(@StdVector int[] indices, @StdVector int[] dimensions); - public native ResultSet allTensorsAlongDimension(@StdVector IntPointer dimensions); - public native ResultSet allTensorsAlongDimension(@StdVector IntBuffer dimensions); - public native ResultSet allTensorsAlongDimension(@StdVector int[] dimensions); + public native @ByVal ResultSet allTensorsAlongDimension(@StdVector IntPointer dimensions); + public native @ByVal ResultSet allTensorsAlongDimension(@StdVector IntBuffer dimensions); + public native @ByVal ResultSet allTensorsAlongDimension(@StdVector int[] dimensions); - //ResultSet allTensorsAlongDims(const std::vector& dimensions) const; - - public native ResultSet allExamples(); + public native @ByVal ResultSet allExamples(); /** * set _shapeInfo @@ -4672,7 +4551,7 @@ public native @Cast("bool") boolean isOptimalRequirementsMet(); /** * returns true if these two NDArrays have same rank, dimensions, strides, ews and order */ - public native @Cast("bool") boolean isSameShapeStrict(@Const NDArray other); + public native @Cast("bool") boolean isSameShapeStrict(@Const @ByRef NDArray other); /** * returns true if buffer && shapeInfo were defined (non nullptr) @@ -4731,11 +4610,6 @@ public native @Cast("bool") boolean isOptimalRequirementsMet(); */ public native void p(@Cast("const Nd4jLong") long i, @Cast("const Nd4jLong") long j, @Cast("const Nd4jLong") long k, @Cast("const Nd4jLong") long l, @Const @ByRef NDArray value); - /** - * creates array which points on certain sub-range of this array, sub-range is defined by given indices - */ - - /** * returns true if array is 2D */ @@ -4806,59 +4680,6 @@ public native @Cast("bool") boolean isOptimalRequirementsMet(); */ public native @Cast("bool") boolean isS(); - /** - * inline accessing operator for matrix, i - absolute index - */ - //FORCEINLINE NDArray operator()(const Nd4jLong i) const; - - /** - * inline modifying operator for matrix, i - absolute index - */ - //FORCEINLINE NDArray& operator()(const Nd4jLong i); - - /** - * inline accessing operator for 2D array, i - row, j - column - */ - //FORCEINLINE NDArray operator()(const Nd4jLong i, const Nd4jLong j) const; - - /** - * inline modifying operator for 2D array, i - row, j - column - */ - //FORCEINLINE NDArray& operator()(const Nd4jLong i, const Nd4jLong j); - - /** - * inline accessing operator for 3D array, i - height, j - width, k - depth - */ - //FORCEINLINE NDArray operator()(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k) const; - - /** - * inline modifying operator for 3D array, i - height, j - width, k - depth - */ - //FORCEINLINE NDArray& operator()(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k); - - /** - * inline modifying operator for 4D array, i - height, j - width, k - depth - */ - //FORCEINLINE NDArray& operator()(const Nd4jLong t, const Nd4jLong u, const Nd4jLong v, const Nd4jLong w); - - /** - * inline accessing operator for 4D array, i - height, j - width, k - depth - */ - //FORCEINLINE NDArray operator()(const Nd4jLong t, const Nd4jLong u, const Nd4jLong v, const Nd4jLong w) const; - - /** - * inline modifying operator for ND array - * idx - array with corresponding indexes, for example {2,10,0,5,...,8}, number of indexes should be equal to array rank - */ - //FORCEINLINE NDArray& operator()(const Nd4jLong* idx); - - /** - * inline accessing operator for ND array - * idx - array with corresponding indexes, for example {2,10,0,5,...,8}, number of indexes should be equal to array rank - */ - //FORCEINLINE NDArray operator()(const Nd4jLong* idx) const; - - public native @Cast("bool") boolean isAttached(); public native NDArray detach(); @@ -4874,268 +4695,75 @@ public native @Cast("bool") boolean isOptimalRequirementsMet(); ////////////////////////////////////////////////////////////////////////// ///// IMLEMENTATION OF INLINE METHODS ///// ////////////////////////////////////////////////////////////////////////// - - - ////////////////////////////////////////////////////////////////////////// - - ////////////////////////////////////////////////////////////////////////// - - - ////////////////////////////////////////////////////////////////////////// - - - ////////////////////////////////////////////////////////////////////////// - - - ////////////////////////////////////////////////////////////////////////// - - - ////////////////////////////////////////////////////////////////////////// - - - ////////////////////////////////////////////////////////////////////////// - - - ////////////////////////////////////////////////////////////////////////// - - - ////////////////////////////////////////////////////////////////////////// - - - ////////////////////////////////////////////////////////////////////////// - - - ////////////////////////////////////////////////////////////////////////// - - - - ////////////////////////////////////////////////////////////////////////// - - - ////////////////////////////////////////////////////////////////////////// - - - ////////////////////////////////////////////////////////////////////////// - - - ////////////////////////////////////////////////////////////////////////// - - - ////////////////////////////////////////////////////////////////////////// - - - ////////////////////////////////////////////////////////////////////////// - - - ////////////////////////////////////////////////////////////////////////// - - - ////////////////////////////////////////////////////////////////////////// - ////////////////////////////////////////////////////////////////////////// -// accessing operator for matrix, i - absolute index -/* -NDArray NDArray::operator()(const Nd4jLong i) const { - if (i >= shape::length(_shapeInfo)) - throw std::invalid_argument("NDArray::operator(i): input index is out of array length !"); - - auto ews = shape::elementWiseStride(_shapeInfo); - char order = ordering(); - - if(ews == 1 && order == 'c') { - auto cast = reinterpret_cast(_buffer) + (i * this->sizeOfT()); - NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace())); - return result; - } else if(ews > 1 && order == 'c') { - auto cast = reinterpret_cast(_buffer) + (i * ews * this->sizeOfT()); - NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace())); - return result; - } else { - Nd4jLong idx[MAX_RANK]; - shape::ind2subC(rankOf(), shapeOf(), i, idx); - auto xOffset = shape::getOffset(getShapeInfo(), idx); - - auto cast = reinterpret_cast(_buffer) + (xOffset * this->sizeOfT()); - NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace())); - return result; - } -} -*/ -////////////////////////////////////////////////////////////////////////// -// modifying operator for matrix, i - absolute index -/* -NDArray& NDArray::operator()(const Nd4jLong i) { - if (i >= shape::length(_shapeInfo)) - throw std::invalid_argument("NDArray::operator(i): input index is out of array length !"); - - auto ews = shape::elementWiseStride(_shapeInfo); - auto order = ordering(); - - if(ews == 1 && order == 'c') { - auto cast = reinterpret_cast(_buffer) + (i * this->sizeOfT()); - NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace())); - - // FIXME: bad - return result; - } else if(ews > 1 && order == 'c') { - auto cast = reinterpret_cast(_buffer) + (i * ews * this->sizeOfT()); - NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace())); - return result; - } else { - Nd4jLong idx[MAX_RANK]; - shape::ind2subC(rankOf(), shapeOf(), i, idx); - auto xOffset = shape::getOffset(getShapeInfo(), idx); - - auto cast = reinterpret_cast(_buffer) + (xOffset * this->sizeOfT()); - NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace())); - return result; - } -}*/ ////////////////////////////////////////////////////////////////////////// -// accessing operator for 2D matrix, i - row, j - column -/* -NDArray NDArray::operator()(const Nd4jLong i, const Nd4jLong j) const { - if (rankOf() != 2 || i >= shapeOf()[0] || j >= shapeOf()[1]) - throw std::invalid_argument("NDArray::operator(i,j): one of input indexes is out of array length or rank!=2 !"); - - Nd4jLong coords[2] = {i, j}; - auto xOffset = shape::getOffset(getShapeInfo(), coords); - - // TODO: do we really want a view here? - auto cast = reinterpret_cast(_buffer) + (xOffset * this->sizeOfT()); - NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace())); - return result; -} -*/ -////////////////////////////////////////////////////////////////////////// -// modifying operator for 2D matrix, i - row, j - column -/* -NDArray& NDArray::operator()(const Nd4jLong i, const Nd4jLong j) { - if (rankOf() != 2 || i >= shapeOf()[0] || j >= shapeOf()[1]) - throw std::invalid_argument("NDArray::operator(i,j): one of input indexes is out of array length or rank!=2 !"); - - Nd4jLong coords[2] = {i, j}; - auto xOffset = shape::getOffset(getShapeInfo(), coords); - - auto cast = reinterpret_cast(_buffer) + (xOffset * this->sizeOfT()); - NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace())); - - //FIXME: bad, will crash! - return result; -} -*/ ////////////////////////////////////////////////////////////////////////// -// accessing operator for 3D array, i - row, j - column -/* -NDArray NDArray::operator()(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k) const { - if (rankOf() != 3 || i >= shapeOf()[0] || j >= shapeOf()[1] || j >= shapeOf()[2]) - throw std::invalid_argument("NDArray::operator(i,j,k): one of input indexes is out of array length or rank!=3 !"); - - Nd4jLong coords[3] = {i, j, k}; - auto xOffset = shape::getOffset(getShapeInfo(), coords); - - auto cast = reinterpret_cast(_buffer) + (xOffset * this->sizeOfT()); - NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace())); - return result; -} -*/ ////////////////////////////////////////////////////////////////////////// -// modifying operator for 3D array -/* -NDArray& NDArray::operator()(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k) { - if (rankOf() != 3 || i >= shapeOf()[0] || j >= shapeOf()[1] || k >= shapeOf()[2]) - throw std::invalid_argument("NDArray::operator(i,j,k): one of input indexes is out of array length or rank!=3 !"); - Nd4jLong coords[3] = {i, j, k}; - auto xOffset = shape::getOffset(getShapeInfo(), coords); - - auto cast = reinterpret_cast(_buffer) + (xOffset * this->sizeOfT()); - NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace())); - - //FIXME: bad, will crash! - return result; -} -*/ -/* -NDArray NDArray::operator()(const Nd4jLong t, const Nd4jLong u, const Nd4jLong v, const Nd4jLong w) const { - - if (rankOf() != 4 || t >= shapeOf()[0] || u >= shapeOf()[1] || v >= shapeOf()[2] || w >= shapeOf()[3]) - throw std::invalid_argument("NDArray::operator(t,u,v,w): one of input indexes is out of array length or rank!=4 !"); - - Nd4jLong coords[4] = {t, u, v, w}; - auto xOffset = shape::getOffset(getShapeInfo(), coords); - - auto cast = reinterpret_cast(_buffer) + (xOffset * this->sizeOfT()); - NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace())); - return result; -} -*/ -/* -NDArray& NDArray::operator()(const Nd4jLong t, const Nd4jLong u, const Nd4jLong v, const Nd4jLong w) { - - if (rankOf() != 4 || t >= shapeOf()[0] || u >= shapeOf()[1] || v >= shapeOf()[2] || w >= shapeOf()[3]) - throw std::invalid_argument("NDArray::operator(t,u,v,w): one of input indexes is out of array length or rank!=4 !"); - - Nd4jLong coords[4] = {t, u, v, w}; - auto xOffset = shape::getOffset(getShapeInfo(), coords); - - // FIXME - auto cast = reinterpret_cast(_buffer) + (xOffset * this->sizeOfT()); - NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace())); - return result; -} -*/ ////////////////////////////////////////////////////////////////////////// -/* -NDArray NDArray::operator()(const Nd4jLong* idx) const { - for(int i = 0; i < rankOf(); ++i) - if (idx[i] >= sizeAt(i)) - throw std::invalid_argument("NDArray::operator(const Nd4jLong* idx): input index is out of dimension length !"); - auto xOffset = shape::getOffset(getShapeInfo(), idx); - - auto cast = reinterpret_cast(_buffer) + (xOffset * this->sizeOfT()); - NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace())); - return result; -} -*/ ////////////////////////////////////////////////////////////////////////// -/* -NDArray& NDArray::operator()(const Nd4jLong* idx) { - - for(int i = 0; i < rankOf(); ++i) - if (idx[i] >= sizeAt(i)) - throw std::invalid_argument("NDArray::operator(const Nd4jLong* idx): input index is out of dimension length !"); - - auto xOffset = shape::getOffset(getShapeInfo(), idx); - - auto cast = reinterpret_cast(_buffer) + (xOffset * this->sizeOfT()); - NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace())); - - // FIXME - return result; -} -*/ - ////////////////////////////////////////////////////////////////////////// - +////////////////////////////////////////////////////////////////////////// + + +////////////////////////////////////////////////////////////////////////// + + +////////////////////////////////////////////////////////////////////////// + + +////////////////////////////////////////////////////////////////////////// + + +////////////////////////////////////////////////////////////////////////// + + + +////////////////////////////////////////////////////////////////////////// + + +////////////////////////////////////////////////////////////////////////// + + +////////////////////////////////////////////////////////////////////////// + + +////////////////////////////////////////////////////////////////////////// + + +////////////////////////////////////////////////////////////////////////// + + +////////////////////////////////////////////////////////////////////////// + + +////////////////////////////////////////////////////////////////////////// + + +////////////////////////////////////////////////////////////////////////// + + + +////////////////////////////////////////////////////////////////////////// + + +////////////////////////////////////////////////////////////////////////// +// still the definition of inline function must be in header file - ////////////////////////////////////////////////////////////////////////// - // still the definition of inline function must be in header file - ////////////////////////////////////////////////////////////////////////// @@ -5254,7 +4882,7 @@ NDArray& NDArray::operator()(const Nd4jLong* idx) { // #include // #include -// #include +// #include // #include // #include // #include @@ -5405,6 +5033,7 @@ NDArray& NDArray::operator()(const Nd4jLong* idx) { // #include // #include // #include +// #include // #ifdef __CUDACC__ // #endif @@ -5856,7 +5485,7 @@ NDArray& NDArray::operator()(const Nd4jLong* idx) { //#include // #include -// #include +// #include // #include // #include // #include @@ -5947,7 +5576,7 @@ NDArray& NDArray::operator()(const Nd4jLong* idx) { // #include // #include // #include -// #include +// #include // #include // #include // #include @@ -6063,7 +5692,7 @@ NDArray& NDArray::operator()(const Nd4jLong* idx) { // #include // #include // #include -// #include +// #include // #include // #include // #include @@ -6619,6 +6248,7 @@ NDArray& NDArray::operator()(const Nd4jLong* idx) { // #include // #include // #include +// #include // CUDA-specific includes // #ifdef __CUDACC__ @@ -6669,12 +6299,13 @@ NDArray& NDArray::operator()(const Nd4jLong* idx) { // this method returns workspace for object allocations public native Workspace oWorkspace(); - public native void setVariableSpace(VariableSpace variableSpace); public native RandomBuffer getRNG(); public native void setRNG(RandomBuffer rng); + public native void setTargetEngine(@Cast("samediff::Engine") int engine); + public native VariableSpace getVariableSpace(); public native LaunchContext launchContext(); @@ -6756,10 +6387,12 @@ NDArray& NDArray::operator()(const Nd4jLong* idx) { public native void setInputArray(int index, NDArray array, @Cast("bool") boolean removable/*=false*/); public native void setInputArray(int index, NDArray array); public native void setInputArray(int index, Pointer buffer, Pointer shapeInfo, Pointer specialBuffer, Pointer specialShapeInfo); + public native void setInputArray(int index, Pointer databuffer, Pointer shapeInfo, Pointer specialShapeInfo); public native void setOutputArray(int index, NDArray array, @Cast("bool") boolean removable/*=false*/); public native void setOutputArray(int index, NDArray array); public native void setOutputArray(int index, Pointer buffer, Pointer shapeInfo, Pointer specialBuffer, Pointer specialShapeInfo); + public native void setOutputArray(int index, Pointer databuffer, Pointer shapeInfo, Pointer specialShapeInfo); public native void setTArguments(DoublePointer arguments, int numberOfArguments); public native void setTArguments(DoubleBuffer arguments, int numberOfArguments); @@ -6781,9 +6414,11 @@ NDArray& NDArray::operator()(const Nd4jLong* idx) { public native void setCudaContext(@Cast("Nd4jPointer") Pointer cudaStream, @Cast("Nd4jPointer") Pointer reductionPointer, @Cast("Nd4jPointer") Pointer allocationPointer); - public native void allowHelpers(@Cast("bool") boolean reallyAllow); public native @Cast("bool") boolean helpersAllowed(); + + public native void setShapeFunctionOverride(@Cast("bool") boolean reallyOverride); + public native @Cast("bool") boolean shapeFunctionOverride(); } @@ -6823,6 +6458,11 @@ NDArray& NDArray::operator()(const Nd4jLong* idx) { // #include // #include // #include +// #include + +// #ifndef __STANDALONE_BUILD__ +// #include +// #endif @Namespace("nd4j::graph") @NoOffset public static class ContextPrototype extends Pointer { static { Loader.load(); } @@ -6868,6 +6508,8 @@ NDArray& NDArray::operator()(const Nd4jLong* idx) { public native @Cast("bool*") @StdVector BooleanPointer getBArguments(); public native @StdVector IntPointer getAxis(); + public native @Cast("samediff::Engine") int engine(); + public native @Cast("size_t") long numT(); public native @Cast("size_t") long numI(); public native @Cast("size_t") long numB(); @@ -9959,7 +9601,7 @@ public static final int PREALLOC_SIZE = 33554432; // #define BROADCAST(NAME) nd4j::BroadcastOpsTuple::custom(nd4j::scalar::NAME, nd4j::pairwise::NAME, nd4j::broadcast::NAME) // #define BROADCAST_BOOL(NAME) nd4j::BroadcastBoolOpsTuple::custom(nd4j::scalar::NAME, nd4j::pairwise::NAME, nd4j::broadcast::NAME) - +public static final int ALL_STRINGS =UTF32; public static final int ALL_INDICES =INT64; public static final int ALL_INTS =UINT64; public static final int ALL_FLOATS =BFLOAT16; @@ -11193,7 +10835,9 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); // #if defined(_MSC_VER) || defined(_WIN64) || defined(_WIN32) || defined(__CLION_IDE__) || defined(__VSCODE__) // #define NOT_EXCLUDED(NAME) 1>0 // #else -// #define NOT_EXCLUDED(NAME) defined(LIBND4J_ALL_OPS) || defined(NAME) +// for now we don't want minifier mechanics working +//#define NOT_EXCLUDED(NAME) defined(LIBND4J_ALL_OPS) || defined(NAME) +// #define NOT_EXCLUDED(NAME) 1>0 // #endif // #ifdef __JAVACPP_HACK__ @@ -11520,6 +11164,10 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); // #define PARAMETRIC_D() [&] (Parameters &p) -> Context* + +// #ifdef __CUDABLAS__ +// #endif + // #endif @@ -11738,6 +11386,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); // #define SD_PLATFORMHELPER_H // #include +// #include // #include // #include // #include @@ -11753,6 +11402,8 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); public native @StdString BytePointer name(); + public native @Cast("samediff::Engine") int engine(); + public native @Cast("Nd4jLong") long hash(); /** @@ -12258,10 +11909,11 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); // #include // #include -// #include +// #include // #include // #include // #include +// #include // handlers part // #include @@ -12299,13 +11951,13 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); public native void registerHelper(PlatformHelper op); - public native @Cast("bool") boolean hasHelper(@Cast("Nd4jLong") long hash); + public native @Cast("bool") boolean hasHelper(@Cast("Nd4jLong") long hash, @Cast("samediff::Engine") int engine); public native DeclarableOp getOperation(@Cast("char*") String name); public native DeclarableOp getOperation(@Cast("char*") BytePointer name); public native DeclarableOp getOperation(@Cast("Nd4jLong") long hash); - public native PlatformHelper getPlatformHelper(@Cast("Nd4jLong") long hash); + public native PlatformHelper getPlatformHelper(@Cast("Nd4jLong") long hash, @Cast("samediff::Engine") int engine); public native @Cast("Nd4jLong*") @StdVector LongPointer getAllHashes(); @@ -12367,7 +12019,11 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); // #include // #include // #include +// #include +// #include +// #include // #include +// #include // #include // #include // #include @@ -14514,6 +14170,21 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); public Pow() { super((Pointer)null); allocate(); } private native void allocate(); } + @Namespace("nd4j::ops") public static class Pow_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public Pow_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public Pow_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public Pow_bp position(long position) { + return (Pow_bp)super.position(position); + } + + public Pow_bp() { super((Pointer)null); allocate(); } + private native void allocate(); + public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); + } // #endif /** @@ -17115,19 +16786,20 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); * This operation calculates hash code, optionally along dimension */ // #if NOT_EXCLUDED(OP_hashcode) - @Namespace("nd4j::ops") public static class hashcode extends DeclarableReductionOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public hashcode(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public hashcode(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public hashcode position(long position) { - return (hashcode)super.position(position); - } - + @Namespace("nd4j::ops") public static class hashcode extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public hashcode(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public hashcode(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public hashcode position(long position) { + return (hashcode)super.position(position); + } + public hashcode() { super((Pointer)null); allocate(); } private native void allocate(); + public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } // #endif @@ -17160,7 +16832,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. - * Copyright (c) 2019 Konduit K.K. + * Copyright (c) 2019-2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -17458,6 +17130,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); * Input : batched tensor with rank >=2 * Output: tensor with rank lesser by 1 from input */ +// #if NOT_EXCLUDED(OP_matrix_diag_part) @Namespace("nd4j::ops") public static class matrix_diag_part extends DeclarableCustomOp { static { Loader.load(); } /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ @@ -17473,7 +17146,36 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } +// #endif + /** + * QR decomposition: A = QR, where Q is ortogonal (Q * QT = I) and R is upper triangular. + * For A (MxN) Q is M x M and R is (NxN). + * + * Input : + * 0 - float (or complex float) tensor with shape {.,..,...,M,N} - batch of float matricies + * + * Output: + * 0 - float tensor with shape {.,..,...,MxN} - batch of ortogonal matricies {Qs} + * 1 - float tensor with shape {.,..,...,NxN} - batch of upper triangular matricies {Rs} + */ +// #if NOT_EXCLUDED(OP_qr) + @Namespace("nd4j::ops") public static class qr extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public qr(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public qr(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public qr position(long position) { + return (qr)super.position(position); + } + + public qr() { super((Pointer)null); allocate(); } + private native void allocate(); + public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); + } +// #endif /** * This operation takes 2 arrays: original values, and values to be excluded. And returns 2 arrays: values left after exclusion, and indices in original array for surivals. @@ -18313,6 +18015,34 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); } // #endif + /** + * This op calculates lgamma function lgamma(x) = log(Gamma(x)) + * + * Input arrays: + * 0: x - input matrix + * + * Output array: + * 0: log of Gamma(x) + * + */ +// #if NOT_EXCLUDED(OP_lgamma) + @Namespace("nd4j::ops") public static class lgamma extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public lgamma(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public lgamma(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public lgamma position(long position) { + return (lgamma)super.position(position); + } + + public lgamma() { super((Pointer)null); allocate(); } + private native void allocate(); + public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); + } +// #endif + /** * This op calculates digamma function psi(x) = derivative of log(Gamma(x)) * @@ -19345,6 +19075,71 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); } // #endif + /** + * triangular_solve op. - reverse Gaussian method for solve systems of linear equations. + * + * input params: + * 0 - the tensor with dimension (x * y * z * ::: * M * M) - left parts of equations + * 1 - the tensor with dimension (x * y * z * ::: * M * K) - right parts of equations + * + * boolean args: + * 0 - lower - default is true (optional) - left part is lower triangular matrix + * 1 - adjoint - default is false (optional) - indicate input matrix or its adjoint (hermitian addition) should be used + * + * return value: + * tensor with dimension (x * y * z * ::: * M * K) with solutions + * + */ +// #if NOT_EXCLUDED(OP_triangular_solve) + @Namespace("nd4j::ops") public static class triangular_solve extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public triangular_solve(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public triangular_solve(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public triangular_solve position(long position) { + return (triangular_solve)super.position(position); + } + + public triangular_solve() { super((Pointer)null); allocate(); } + private native void allocate(); + public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); + } +// #endif + + /** + * lu op. - make LUP decomposition of given batch of 2D square matricies + * + * input params: + * 0 - float tensor with dimension (x * y * z * ::: * M * M) + * + * return value: + * 0 - float tensor with dimension (x * y * z * ::: * M * M) with LU M x M matricies in it + * 1 - int (32 or 64) batched vector of permutations with length M - shape (x * y * z * ::: * M) + * + * int argument: + * 0 - data type of output permutaion vector (int32 or int64), optional, default INT32 + */ + +// #if NOT_EXCLUDED(OP_matrix_inverse) + @Namespace("nd4j::ops") public static class lu extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public lu(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public lu(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public lu position(long position) { + return (lu)super.position(position); + } + + public lu() { super((Pointer)null); allocate(); } + private native void allocate(); + public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); + } +// #endif + /** * sequence_mask op. - make mask for given tensor filled by (j > x[i_1, i_2,...,i_n]) -> z[i_1, i_2,...,i_n,j] * @@ -20776,6 +20571,41 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); } // #endif + /** + * This op make area interpolated resize (as OpenCV INTER_AREA algorithm) for given tensor + * + * input array: + * 0 - images - 4D-Tensor with shape (batch, sizeX, sizeY, channels) + * 1 - size - 1D-Tensor with 2 values (newWidth, newHeight) (if missing a pair of integer args should be provided). + * + * int args: - proveded only when size tensor is missing + * 0 - new height + * 1 - new width + * boolean args: + * 0 - align_corners - optional (default is false) + * + * output array: + * the 4D-Tensor with resized image (shape is {batch, newWidth, newHeight, channels}) + * + */ +// #if NOT_EXCLUDED(OP_resize_area) + @Namespace("nd4j::ops") public static class resize_area extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public resize_area(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public resize_area(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public resize_area position(long position) { + return (resize_area)super.position(position); + } + + public resize_area() { super((Pointer)null); allocate(); } + private native void allocate(); + public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); + } +// #endif + /** * This op make interpolated resize for given tensor with given algorithm. * Supported algorithms are bilinear, bicubic, nearest_neighbor. @@ -21544,6 +21374,36 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } // #endif + /* + * multinomial (categorical) random generator draws samples from a multinomial distribution + * + * Input array: + * 0 - 2D ndarray with unnormalized log-probabilities with shape [batch_size (N), num_classes (K)] + * 1 - array with one int value of samples number, number of independent samples to draw for each experiment 1,N. + * Int arguments: + * 0 - optional argument, corresponds to dimension with batch_size + * 1 - optional argument, integer type to use for the output. Default int64. + * + * Output array: + * 0 - 2D ndarray with the drawn samples of shape [batch_size, num_samples] + */ +// #if NOT_EXCLUDED(OP_random_multinomial) + @Namespace("nd4j::ops") public static class random_multinomial extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public random_multinomial(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public random_multinomial(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public random_multinomial position(long position) { + return (random_multinomial)super.position(position); + } + + public random_multinomial() { super((Pointer)null); allocate(); } + private native void allocate(); + public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); + } +// #endif // #if NOT_EXCLUDED(OP_random_normal) @Namespace("nd4j::ops") public static class random_normal extends DeclarableCustomOp { @@ -23881,7 +23741,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); // #ifndef DEV_TESTS_SHAPEDESCRIPTOR_H // #define DEV_TESTS_SHAPEDESCRIPTOR_H -// #include +// #include // #include // #include // #include diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpuPresets.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpuPresets.java index 9d067b5bc..c2fca8d89 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpuPresets.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpuPresets.java @@ -36,10 +36,12 @@ import java.util.Scanner; value = {@Platform(define = "LIBND4J_ALL_OPS", include = { "memory/MemoryType.h", "array/DataType.h", + "array/DataBuffer.h", "array/ConstantDataBuffer.h", "array/ConstantDescriptor.h", "array/TadPack.h", "execution/ErrorReference.h", + "execution/Engine.h", "Environment.h", "types/utf8string.h", "NativeOps.h", @@ -160,6 +162,7 @@ public class Nd4jCpuPresets implements InfoMapper, BuildEnabled { .put(new Info("OpaqueVariablesSet").pointerTypes("OpaqueVariablesSet")) .put(new Info("OpaqueVariable").pointerTypes("OpaqueVariable")) .put(new Info("OpaqueConstantDataBuffer").pointerTypes("OpaqueConstantDataBuffer")) + .put(new Info("OpaqueDataBuffer").pointerTypes("OpaqueDataBuffer")) .put(new Info("OpaqueContext").pointerTypes("OpaqueContext")) .put(new Info("OpaqueRandomGenerator").pointerTypes("OpaqueRandomGenerator")) .put(new Info("OpaqueLaunchContext").pointerTypes("OpaqueLaunchContext")) @@ -186,6 +189,7 @@ public class Nd4jCpuPresets implements InfoMapper, BuildEnabled { .put(new Info("std::pair").pointerTypes("IntIntPair").define()) .put(new Info("std::vector >").pointerTypes("IntVectorVector").define()) .put(new Info("std::vector >").pointerTypes("LongVectorVector").define()) + .put(new Info("std::vector").pointerTypes("ConstNDArrayVector").define()) .put(new Info("std::vector").pointerTypes("NDArrayVector").define()) .put(new Info("nd4j::graph::ResultWrapper").base("org.nd4j.nativeblas.ResultWrapperAbstraction").define()) .put(new Info("bool").cast().valueTypes("boolean").pointerTypes("BooleanPointer", "boolean[]")) diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/resources/nd4j-native.properties b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/resources/nd4j-native.properties index 4690d54f6..0b2489b53 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/resources/nd4j-native.properties +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/resources/nd4j-native.properties @@ -28,7 +28,7 @@ native.ops= org.nd4j.nativeblas.Nd4jCpu ndarrayfactory.class = org.nd4j.linalg.cpu.nativecpu.CpuNDArrayFactory ndarray.order = c resourcemanager_state = false -databufferfactory = org.nd4j.linalg.api.buffer.factory.DefaultDataBufferFactory +databufferfactory = org.nd4j.linalg.cpu.nativecpu.buffer.DefaultDataBufferFactory workspacemanager = org.nd4j.linalg.cpu.nativecpu.workspace.CpuWorkspaceManager alloc = javacpp fft = org.nd4j.linalg.fft.DefaultFFTInstance diff --git a/nd4j/nd4j-backends/nd4j-tests-tensorflow/pom.xml b/nd4j/nd4j-backends/nd4j-tests-tensorflow/pom.xml index bc468c874..5f5c5fa90 100644 --- a/nd4j/nd4j-backends/nd4j-tests-tensorflow/pom.xml +++ b/nd4j/nd4j-backends/nd4j-tests-tensorflow/pom.xml @@ -87,6 +87,13 @@ ${logback.version} test + + + org.nd4j + nd4j-common-tests + ${project.version} + test + diff --git a/nd4j/nd4j-backends/nd4j-tests-tensorflow/src/test/cpujava/org/nd4j/tensorflow/conversion/GraphRunnerTest.java b/nd4j/nd4j-backends/nd4j-tests-tensorflow/src/test/cpujava/org/nd4j/tensorflow/conversion/GraphRunnerTest.java index c959edbfe..81599cd5b 100644 --- a/nd4j/nd4j-backends/nd4j-tests-tensorflow/src/test/cpujava/org/nd4j/tensorflow/conversion/GraphRunnerTest.java +++ b/nd4j/nd4j-backends/nd4j-tests-tensorflow/src/test/cpujava/org/nd4j/tensorflow/conversion/GraphRunnerTest.java @@ -20,6 +20,7 @@ package org.nd4j.tensorflow.conversion; import junit.framework.TestCase; import org.apache.commons.io.FileUtils; import org.bytedeco.tensorflow.TF_Tensor; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.resources.Resources; import org.nd4j.shade.protobuf.Descriptors; @@ -46,7 +47,17 @@ import java.util.Map; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; -public class GraphRunnerTest { +public class GraphRunnerTest extends BaseND4JTest { + + @Override + public DataType getDataType() { + return DataType.FLOAT; + } + + @Override + public DataType getDefaultFPDataType() { + return DataType.FLOAT; + } public static ConfigProto getConfig(){ String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend"); diff --git a/nd4j/nd4j-backends/nd4j-tests-tensorflow/src/test/cpujava/org/nd4j/tensorflow/conversion/TensorflowConversionTest.java b/nd4j/nd4j-backends/nd4j-tests-tensorflow/src/test/cpujava/org/nd4j/tensorflow/conversion/TensorflowConversionTest.java index eb263f119..fbf4249bd 100644 --- a/nd4j/nd4j-backends/nd4j-tests-tensorflow/src/test/cpujava/org/nd4j/tensorflow/conversion/TensorflowConversionTest.java +++ b/nd4j/nd4j-backends/nd4j-tests-tensorflow/src/test/cpujava/org/nd4j/tensorflow/conversion/TensorflowConversionTest.java @@ -18,6 +18,7 @@ package org.nd4j.tensorflow.conversion; import org.apache.commons.io.IOUtils; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.io.ClassPathResource; @@ -29,7 +30,7 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.fail; -public class TensorflowConversionTest { +public class TensorflowConversionTest extends BaseND4JTest { @Test public void testView() { diff --git a/nd4j/nd4j-backends/nd4j-tests-tensorflow/src/test/gpujava/org/nd4j/tensorflow/conversion/GpuGraphRunnerTest.java b/nd4j/nd4j-backends/nd4j-tests-tensorflow/src/test/gpujava/org/nd4j/tensorflow/conversion/GpuGraphRunnerTest.java index 614330813..b13ca465f 100644 --- a/nd4j/nd4j-backends/nd4j-tests-tensorflow/src/test/gpujava/org/nd4j/tensorflow/conversion/GpuGraphRunnerTest.java +++ b/nd4j/nd4j-backends/nd4j-tests-tensorflow/src/test/gpujava/org/nd4j/tensorflow/conversion/GpuGraphRunnerTest.java @@ -17,6 +17,7 @@ package org.nd4j.tensorflow.conversion; +import org.nd4j.BaseND4JTest; import org.nd4j.shade.protobuf.util.JsonFormat; import org.apache.commons.io.IOUtils; import org.junit.Test; @@ -37,7 +38,7 @@ import java.util.Map; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; -public class GpuGraphRunnerTest { +public class GpuGraphRunnerTest extends BaseND4JTest { @Test public void testGraphRunner() throws Exception { diff --git a/nd4j/nd4j-backends/nd4j-tests/pom.xml b/nd4j/nd4j-backends/nd4j-tests/pom.xml index e6861d257..9d098189f 100644 --- a/nd4j/nd4j-backends/nd4j-tests/pom.xml +++ b/nd4j/nd4j-backends/nd4j-tests/pom.xml @@ -127,6 +127,13 @@ + + + org.nd4j + nd4j-common-tests + ${project.version} + test + diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/OpValidationSuite.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/OpValidationSuite.java index b4bb6e1a4..1acc013c2 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/OpValidationSuite.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/OpValidationSuite.java @@ -63,8 +63,8 @@ import static org.junit.Assume.assumeFalse; TransformOpValidation.class, //TF import tests - TFGraphTestAllSameDiff.class, - TFGraphTestAllLibnd4j.class + TFGraphTestAllSameDiff.class + //TFGraphTestAllLibnd4j.class }) //IMPORTANT: This ignore is added to avoid maven surefire running both the suite AND the individual tests in "mvn test" // With it ignored here, the individual tests will run outside (i.e., separately/independently) of the suite in both "mvn test" and IntelliJ diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/TestSessions.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/TestSessions.java index 09e94acc7..56acfa828 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/TestSessions.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/TestSessions.java @@ -22,6 +22,7 @@ import org.nd4j.autodiff.listeners.Operation; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.internal.AbstractSession; +import org.nd4j.autodiff.samediff.internal.FrameIter; import org.nd4j.autodiff.samediff.internal.InferenceSession; import org.nd4j.autodiff.samediff.internal.memory.NoOpMemoryMgr; import org.nd4j.imports.graphmapper.tf.TFGraphMapper; @@ -112,7 +113,6 @@ public class TestSessions extends BaseNd4jTest { m.put("x", x); m.put("y", y); - System.out.println("----------------------------------"); Map outMap = is.output(Collections.singletonList("d"), m, null, Collections.emptyList(), null, At.defaultAt(Operation.TRAINING)); @@ -144,7 +144,6 @@ public class TestSessions extends BaseNd4jTest { m.put("x", x); m.put("y", y); - System.out.println("----------------------------------"); InferenceSession is = new InferenceSession(sd); // String outName = merge.name(); String outName = outVar.name(); @@ -183,14 +182,14 @@ public class TestSessions extends BaseNd4jTest { InferenceSession is = new InferenceSession(sd); String n = merge.name(); - System.out.println("----------------------------------"); +// System.out.println("----------------------------------"); Map outMap = is.output(Collections.singletonList(n), m, null, Collections.emptyList(), null, At.defaultAt(Operation.TRAINING)); assertEquals(1, outMap.size()); assertEquals(expTrue, outMap.get(n)); - System.out.println("----------------------------------"); +// System.out.println("----------------------------------"); //Check false case: bArr.assign(0); is = new InferenceSession(sd); @@ -217,9 +216,10 @@ public class TestSessions extends BaseNd4jTest { File f = new ClassPathResource("tf_graphs/examples/while1/iter_" + numIter + "/frozen_model.pb").getFile(); SameDiff sd = TFGraphMapper.importGraph(f); - System.out.println(sd.summary()); +// System.out.println(sd.summary()); + sd.summary(); - System.out.println("----------------------------------"); +// System.out.println("----------------------------------"); //This particular test/graph doesn't use placeholders InferenceSession is = new InferenceSession(sd); is.setMmgr(new NoOpMemoryMgr()); //So arrays aren't deallocated during execution @@ -239,17 +239,17 @@ public class TestSessions extends BaseNd4jTest { //Some sanity checks on the internal state: //Check 1: "while/Less" should be executed numIter+1 times... i.e., numIter times through the loop, plus once to exit for( int i=0; i NO_BP_YET = new HashSet<>(); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/MiscOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/MiscOpValidation.java index 7f6daf78f..e02e4b91d 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/MiscOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/MiscOpValidation.java @@ -1648,8 +1648,8 @@ public class MiscOpValidation extends BaseOpValidation { INDArray vArr = gm.get(v.name()); INDArray wArr = gm.get(w.name()); - System.out.println(vArr); - System.out.println(wArr); +// System.out.println(vArr); +// System.out.println(wArr); assertEquals(Nd4j.zeros(DataType.DOUBLE, 3, 4), wArr); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionBpOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionBpOpValidation.java index 0d74f07f3..16f097084 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionBpOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionBpOpValidation.java @@ -30,8 +30,6 @@ import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.ops.transforms.Transforms; import org.nd4j.nativeblas.NativeOpsHolder; -import java.util.Arrays; - import static org.junit.Assert.assertNull; @Slf4j @@ -179,8 +177,8 @@ public class ReductionBpOpValidation extends BaseOpValidation { @Test public void testMeanBP_Rank1() { INDArray dLdOut = Nd4j.scalar(0.5); - INDArray preReduceInput = Nd4j.create(new double[]{2,3,4}, new long[]{3}); - INDArray dLdInExp = Nd4j.valueArrayOf(new long[]{3}, 0.5/3); + INDArray preReduceInput = Nd4j.create(new double[]{2, 3, 4}, new long[]{3}); + INDArray dLdInExp = Nd4j.valueArrayOf(new long[]{3}, 0.5 / 3); INDArray dLdIn = Nd4j.createUninitialized(new long[]{3}); @@ -199,7 +197,7 @@ public class ReductionBpOpValidation extends BaseOpValidation { for (boolean keepDims : new boolean[]{false, true}) { long[] reducedShape_0 = (keepDims ? new long[]{1, 4} : new long[]{4}); - INDArray preReduceInput = Nd4j.linspace(1, 12, 12).reshape('c',3, 4); + INDArray preReduceInput = Nd4j.linspace(1, 12, 12).reshape('c', 3, 4); INDArray dLdOut_0 = Nd4j.create(new double[]{1, 2, 3, 4}, reducedShape_0); INDArray dLdInExpected_0 = Nd4j.createUninitialized(preReduceInput.shape()); for (int i = 0; i < 3; i++) { @@ -524,7 +522,7 @@ public class ReductionBpOpValidation extends BaseOpValidation { @Test public void testStdevBP_Rank1() { INDArray dLdOut = Nd4j.scalar(0.5); - INDArray preReduceInput = Nd4j.create(new double[]{2,3,4}, new long[]{3}); + INDArray preReduceInput = Nd4j.create(new double[]{2, 3, 4}, new long[]{3}); double stdev = preReduceInput.stdNumber(true).doubleValue(); double mean = preReduceInput.meanNumber().doubleValue(); @@ -532,8 +530,8 @@ public class ReductionBpOpValidation extends BaseOpValidation { .subi(mean).divi(stdev * 2) .muli(0.5); //* dL/dOut - System.out.println(dLdInExp.shapeInfoToString()); - System.out.println(Arrays.toString(dLdInExp.data().asFloat())); +// System.out.println(dLdInExp.shapeInfoToString()); +// System.out.println(Arrays.toString(dLdInExp.data().asFloat())); INDArray dLdIn = Nd4j.createUninitialized(new long[]{3}); @@ -577,7 +575,7 @@ public class ReductionBpOpValidation extends BaseOpValidation { INDArray dLdInExpected_1 = preReduceInput.dup(); dLdInExpected_1.subiColumnVector(mean_1) .diviColumnVector(stdev_1.mul(divisor)) - .muliColumnVector(dLdOut_1.reshape(3,1)); + .muliColumnVector(dLdOut_1.reshape(3, 1)); dLdIn = Nd4j.createUninitialized(3, 4); err = OpValidation.validate(new OpTestCase(new StandardDeviationBp(preReduceInput, dLdOut_1, dLdIn, biasCorrected, keepDims, 1)) @@ -653,7 +651,7 @@ public class ReductionBpOpValidation extends BaseOpValidation { INDArray mean_1 = preReduceInput.mean(1); INDArray dLdInExpected_1 = preReduceInput.dup(); dLdInExpected_1.subiColumnVector(mean_1).muli(2.0 / divisor) - .muliColumnVector(dLdOut_1.reshape(3,1)); + .muliColumnVector(dLdOut_1.reshape(3, 1)); dLdIn = Nd4j.createUninitialized(3, 4); @@ -688,17 +686,16 @@ public class ReductionBpOpValidation extends BaseOpValidation { // = cumSumExclusive(dL/dOut_j) - - for(boolean exclusive : new boolean[]{false, true}) { - for(boolean reverse : new boolean[]{false, true}) { + for (boolean exclusive : new boolean[]{false, true}) { + for (boolean reverse : new boolean[]{false, true}) { INDArray preReduceInput = Nd4j.linspace(1, 12, 12).reshape(3, 4); - INDArray dLdOut = Nd4j.valueArrayOf(new long[]{3,4}, 0.5); + INDArray dLdOut = Nd4j.valueArrayOf(new long[]{3, 4}, 0.5); INDArray dLdIn = Nd4j.createUninitialized(3, 4); INDArray dLdInExpected; - if(exclusive){ - if(reverse){ + if (exclusive) { + if (reverse) { dLdInExpected = Nd4j.create(new double[][]{ {0.0, 0.0, 0.0, 0.0}, {0.5, 0.5, 0.5, 0.5}, @@ -710,7 +707,7 @@ public class ReductionBpOpValidation extends BaseOpValidation { {0.0, 0.0, 0.0, 0.0}}); } } else { - if(reverse){ + if (reverse) { dLdInExpected = Nd4j.create(new double[][]{ {0.5, 0.5, 0.5, 0.5}, {1.0, 1.0, 1.0, 1.0}, @@ -727,7 +724,7 @@ public class ReductionBpOpValidation extends BaseOpValidation { String err = OpValidation.validate(new OpTestCase( new CumSumBp(preReduceInput, dLdOut, dLdIn, exclusive, reverse, 0)) .expectedOutput(0, dLdInExpected)); - if(err != null){ + if (err != null) { err = err + " - exclusive=" + exclusive + ", reverse=" + reverse; } assertNull(err); @@ -737,7 +734,7 @@ public class ReductionBpOpValidation extends BaseOpValidation { @Test - public void testNorm2Bp(){ + public void testNorm2Bp() { //dL/dIn = dL/dOut * dOut/dIn // = dL/dOut * x/|x|_2 @@ -797,7 +794,7 @@ public class ReductionBpOpValidation extends BaseOpValidation { } @Test - public void testNorm1Bp(){ + public void testNorm1Bp() { //dL/dIn = dL/dOut * dOut/dIn // = dL/dOut * sgn(in) @@ -856,7 +853,7 @@ public class ReductionBpOpValidation extends BaseOpValidation { } @Test - public void testNormMaxBp(){ + public void testNormMaxBp() { //out = max_i (|in_i|) //dL/dIn = dL/dOut * dOut/dIn // = dL/dOut * (0 if |x_i| is not max; or sgn(x_i) otherwise) @@ -866,8 +863,8 @@ public class ReductionBpOpValidation extends BaseOpValidation { INDArray preReduceInput = Nd4j.linspace(-5, 6, 12).reshape(3, 4); INDArray sgn = Transforms.sign(preReduceInput, true); - INDArray max = Nd4j.create(3,4); - max.putScalar(2,3,1.0); + INDArray max = Nd4j.create(3, 4); + max.putScalar(2, 3, 1.0); INDArray dLdOut; if (keepDims) { @@ -896,7 +893,7 @@ public class ReductionBpOpValidation extends BaseOpValidation { long[] reducedShape_0 = (keepDims ? new long[]{1, 4} : new long[]{4}); INDArray preReduceInput = Nd4j.linspace(1, 12, 12).reshape(3, 4); INDArray sgn = Transforms.sign(preReduceInput, true); - INDArray max_0 = Nd4j.create(3,4); + INDArray max_0 = Nd4j.create(3, 4); max_0.getRow(2).assign(1.0); INDArray dLdOut_0 = Nd4j.create(new double[]{1, 2, 3, 4}, reducedShape_0); INDArray dLdInExpected_0 = sgn.mul(max_0).mulRowVector(dLdOut_0); @@ -910,7 +907,7 @@ public class ReductionBpOpValidation extends BaseOpValidation { long[] reducedShape_1 = (keepDims ? new long[]{3, 1} : new long[]{3}); INDArray dLdOut_1 = Nd4j.create(new double[]{1, 2, 3}, reducedShape_1); - INDArray max_1 = Nd4j.create(3,4); + INDArray max_1 = Nd4j.create(3, 4); max_1.getColumn(3).assign(1.0); INDArray dLdInExpected_1 = sgn.mul(max_1).mulColumnVector(dLdOut_1); dLdIn = Nd4j.createUninitialized(3, 4); @@ -922,3 +919,4 @@ public class ReductionBpOpValidation extends BaseOpValidation { } } } + diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionOpValidation.java index 802ed9be9..6c6ba5b83 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionOpValidation.java @@ -54,8 +54,7 @@ import java.util.Arrays; import java.util.Collections; import java.util.List; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNull; +import static org.junit.Assert.*; @Slf4j @RunWith(Parameterized.class) @@ -935,7 +934,7 @@ public class ReductionOpValidation extends BaseOpValidation { INDArray expOut; SDVariable reduced; String name; - System.out.println(i); +// System.out.println(i); switch (i) { case 0: reduced = sd.math().manhattanDistance(in, in2, reduceDims); @@ -970,7 +969,7 @@ public class ReductionOpValidation extends BaseOpValidation { default: throw new RuntimeException(); } - System.out.println(i + " - end"); +// System.out.println(i + " - end"); long[] expShape; @@ -1011,7 +1010,9 @@ public class ReductionOpValidation extends BaseOpValidation { @Test public void testReductionsBackwards() { - for (int i = 0; i < 7; i++) { +// for (int i = 0; i < 7; i++) { + int i=5; + { SameDiff sd = SameDiff.create(); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RnnOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RnnOpValidation.java index 988b8da69..8a4f8164a 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RnnOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RnnOpValidation.java @@ -194,9 +194,9 @@ public class RnnOpValidation extends BaseOpValidation { INDArray out5 = Nd4j.create(new float[]{-0.17905743f, 0.19177397f}, new int[]{1,2}); //Cell state INDArray out6 = Nd4j.create(new float[]{-0.04025514f, 0.10104967f}, new int[]{1,2}); //Output - for(int i=0; i l = op.calculateOutputShape(); - System.out.println(Arrays.toString(l.get(0).getShape())); +// System.out.println(Arrays.toString(l.get(0).getShape())); assertArrayEquals(new long[]{4, 3}, l.get(0).getShape()); op = DynamicCustomOp.builder("permute") @@ -2382,7 +2382,7 @@ public class ShapeOpValidation extends BaseOpValidation { .addIntegerArguments(1, 0) .build(); l = op.calculateOutputShape(); - System.out.println(Arrays.toString(l.get(0).getShape())); +// System.out.println(Arrays.toString(l.get(0).getShape())); assertArrayEquals(new long[]{4, 3}, l.get(0).getShape()); @@ -2391,7 +2391,7 @@ public class ShapeOpValidation extends BaseOpValidation { Nd4j.createFromArray(1, 2, 0)) .build(); l = op.calculateOutputShape(); - System.out.println(Arrays.toString(l.get(0).getShape())); +// System.out.println(Arrays.toString(l.get(0).getShape())); assertArrayEquals(new long[]{4, 5, 3}, l.get(0).getShape()); } @@ -2419,7 +2419,7 @@ public class ShapeOpValidation extends BaseOpValidation { INDArray in = Nd4j.linspace(DataType.FLOAT, 1, 6, 1).reshape(3,2); INDArray permute = Nd4j.createFromArray(1,0); - System.out.println(in); +// System.out.println(in); SameDiff sd = SameDiff.create(); SDVariable v = sd.var(in); @@ -2434,8 +2434,6 @@ public class ShapeOpValidation extends BaseOpValidation { @Test public void testPermute4(){ - Nd4j.getExecutioner().enableDebugMode(true); - Nd4j.getExecutioner().enableVerboseMode(true); INDArray in = Nd4j.linspace(DataType.FLOAT, 1, 6, 1).reshape(3,2); INDArray permute = Nd4j.createFromArray(1,0); @@ -2457,8 +2455,8 @@ public class ShapeOpValidation extends BaseOpValidation { DynamicCustomOp op = b.build(); Nd4j.exec(op); - System.out.println(in); - System.out.println(op.outputArguments()[0]); +// System.out.println(in); +// System.out.println(op.outputArguments().get(0)); assertEquals(exp, op.getOutputArgument(0)); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/TransformOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/TransformOpValidation.java index 6a42d21e1..154201a35 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/TransformOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/TransformOpValidation.java @@ -20,6 +20,7 @@ import lombok.extern.slf4j.Slf4j; import lombok.val; import org.junit.After; import org.junit.Before; +import org.junit.Ignore; import org.junit.Test; import org.nd4j.OpValidationSuite; import org.nd4j.autodiff.functions.DifferentialFunction; @@ -938,9 +939,7 @@ public class TransformOpValidation extends BaseOpValidation { tc.expectedOutput(t.name(), Transforms.min(ia, 0.5, true)); break; case 65: - t = sd.assign(in, 0.5); - tc.expectedOutput(t.name(), ia.dup().assign(0.5)); - break; + continue; // assign op was removed. case 66: t = sd.scalarFloorMod(in, 0.5); tc.expectedOutput(t.name(), Nd4j.getExecutioner().exec(new ScalarFMod(ia.dup(), 0.5))); @@ -1180,9 +1179,7 @@ public class TransformOpValidation extends BaseOpValidation { tc.expectedOutput(t.name(), Transforms.xor(ia.castTo(DataType.BOOL), ib.castTo(DataType.BOOL))).gradientCheck(false); break; case 18: - t = sd.assign(in1, in2); - tc.expectedOutput(t.name(), ib); - break; + continue; //assign op was removed. case 19: t = sd.math().atan2(in1, in2); tc.expectedOutput(t.name(), Transforms.atan2(ib, ia)); //Note: y,x order for samediff; x,y order for transforms @@ -1465,6 +1462,7 @@ public class TransformOpValidation extends BaseOpValidation { } + @Ignore("12/16/2019 https://github.com/eclipse/deeplearning4j/issues/8540") @Test public void testPad(){ INDArray in = Nd4j.valueArrayOf(new long[]{5}, 1.0); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/ConvConfigTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/ConvConfigTests.java index a6f7b6bea..75f1615d7 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/ConvConfigTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/ConvConfigTests.java @@ -22,10 +22,21 @@ import static org.junit.Assert.fail; import org.junit.Assert; import org.junit.Test; +import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv2D; import org.nd4j.linalg.api.ops.impl.layers.convolution.config.*; +import org.nd4j.linalg.factory.Nd4jBackend; -public class ConvConfigTests { +public class ConvConfigTests extends BaseNd4jTest { + + public ConvConfigTests(Nd4jBackend backend) { + super(backend); + } + + @Override + public char ordering() { + return 'c'; + } @Test public void testDeConv2D(){ diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/LogisticPredictions.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/LogisticPredictions.java deleted file mode 100644 index 8eb63d57e..000000000 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/LogisticPredictions.java +++ /dev/null @@ -1,40 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.autodiff.samediff; - -import org.nd4j.linalg.api.ndarray.INDArray; - -import java.util.Map; - -public class LogisticPredictions implements SameDiffFunctionDefinition { - /** - * @param sameDiff - * @param inputs - * @param variableInputs - * @return - */ - @Override - public SDVariable[] define(SameDiff sameDiff, Map inputs, SDVariable[] variableInputs) { - SDVariable input = sameDiff.var("x",inputs.get("x")); - SDVariable w = sameDiff.var("w",inputs.get("w")); - SDVariable y = sameDiff.var("y",inputs.get("y")); - SDVariable preOutput = sameDiff.mmul(input,w); - SDVariable sigmoid = sameDiff.nn().sigmoid(preOutput); - - return new SDVariable[]{sigmoid}; - } -} diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffSpecifiedLossVarsTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffSpecifiedLossVarsTests.java index 303739ea1..a08a390a9 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffSpecifiedLossVarsTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffSpecifiedLossVarsTests.java @@ -90,7 +90,8 @@ public class SameDiffSpecifiedLossVarsTests extends BaseNd4jTest { SDVariable loss1 = add.std("l1", true); SDVariable loss2 = mmul.mean("l2"); - System.out.println(sd.summary()); +// System.out.println(sd.summary()); + sd.summary(); if(i == 0){ sd.setLossVariables("l1", "l2"); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java index db8d7d551..878289beb 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java @@ -709,9 +709,9 @@ public class SameDiffTests extends BaseNd4jTest { val s = in2.add(5.0); Map map = sd.outputAll(null); - log.info("Result M: {}", map.get(m.name())); - log.info("Result F: {}", map.get(f.name())); - log.info("Result S: {}", map.get(s.name())); +// log.info("Result M: {}", map.get(m.name())); +// log.info("Result F: {}", map.get(f.name())); +// log.info("Result S: {}", map.get(s.name())); } @Test @@ -1654,7 +1654,6 @@ public class SameDiffTests extends BaseNd4jTest { INDArray expOut = Nd4j.create(DataType.BOOL, ia.shape()); Nd4j.exec(new IsStrictlyIncreasing(ia, expOut)); - System.out.println(expOut); } @Test @@ -1997,8 +1996,6 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable varIndices = sd.constant("indices", indices); SDVariable gather = sd.gather(var, varIndices, 0); - System.out.println(in); - INDArray exp = Nd4j.pullRows(in, 1, new int[]{0, 1, 5}); //Along dimension 1 -> equiv to "indexes for axis 0" INDArray act = gather.eval(); @@ -2020,8 +2017,6 @@ public class SameDiffTests extends BaseNd4jTest { Nd4j.exec(op); - System.out.println(out); - INDArray exp = Nd4j.pullRows(in, 1, new int[]{0, 1, 5}); //Along dimension 1 == indexes for dimension 0 assertEquals(exp, out); @@ -2396,13 +2391,14 @@ public class SameDiffTests extends BaseNd4jTest { Map phMap = new HashMap<>(); phMap.put(fn.getGradPlaceholderName(), grad); - log.info("--------------- out.eval() ---------------"); +// log.info("--------------- out.eval() ---------------"); out.eval(); - log.info("--------------- sd.execBackwards() #1 ---------------"); +// log.info("--------------- sd.execBackwards() #1 ---------------"); sd.calculateGradients(phMap, "in", "W", "b"); - log.info("--------------- sd.execBackwards() #2 ---------------"); - System.out.println(sd.getFunction("grad").summary()); +// log.info("--------------- sd.execBackwards() #2 ---------------"); +// System.out.println(sd.getFunction("grad").summary()); + sd.getFunction("grad").summary(); in.setArray(Nd4j.linspace(1, 10, 10).reshape(2, 5)); grad = Nd4j.linspace(1, 8, 8).reshape(2, 4); @@ -3232,7 +3228,8 @@ public class SameDiffTests extends BaseNd4jTest { Map secondBranch = Maps.newHashMap(); secondBranch.put("a", Nd4j.createFromArray(7.0)); - System.out.println(sd.summary()); +// System.out.println(sd.summary()); + sd.summary(); INDArray outArr = sd.output(secondBranch, "out").get("out"); assertEquals(Nd4j.createFromArray(14.0), outArr); @@ -3429,11 +3426,11 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable rand1 = sd1.var("random", new UniformInitScheme('c', 3), DataType.FLOAT, 3, 1); - Nd4j.getRandom().setSeed(0); - System.out.println(rand0.eval()); - - Nd4j.getRandom().setSeed(0); - System.out.println(rand1.eval()); +// Nd4j.getRandom().setSeed(0); +// System.out.println(rand0.eval()); +// +// Nd4j.getRandom().setSeed(0); +// System.out.println(rand1.eval()); INDArray a0 = rand0.eval(); Nd4j.getRandom().setSeed(0); @@ -3520,4 +3517,19 @@ public class SameDiffTests extends BaseNd4jTest { assertEquals(config, fromJson); } } + + @Test + public void testRngSanityCheck(){ + Nd4j.getRandom().setSeed(12345); + for(DataType dt : DataType.values()) { + if (!dt.isNumerical()) + continue; + SameDiff sameDiff = SameDiff.create(); + INDArray indaShape = Nd4j.createFromArray(3, 10); + SDVariable sdShape = sameDiff.constant(indaShape); + SDVariable random = sameDiff.random().uniform("data", 0.0, 10.0, sdShape, dt); + INDArray out = random.eval(); + String s = out.toString(); + } + } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/listeners/CheckpointListenerTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/listeners/CheckpointListenerTest.java index cf99ebbaa..33ac52f16 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/listeners/CheckpointListenerTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/listeners/CheckpointListenerTest.java @@ -63,8 +63,12 @@ public class CheckpointListenerTest extends BaseNd4jTest { return sd; } - public static DataSetIterator getIter(){ - return new IrisDataSetIterator(15, 150); + public static DataSetIterator getIter() { + return getIter(15, 150); + } + + public static DataSetIterator getIter(int batch, int totalExamples){ + return new IrisDataSetIterator(batch, totalExamples); } @@ -125,7 +129,7 @@ public class CheckpointListenerTest extends BaseNd4jTest { boolean[] found = new boolean[names.size()]; for(File f : files){ String s = f.getAbsolutePath(); - System.out.println(s); +// System.out.println(s); for( int i=0; i ph = new HashMap<>(); + ph.put("in", i); + + for( int x=0; x<10; x++ ) { + sd.outputSingle(ph, "predictions"); + } + + String content = FileUtils.readFileToString(f, StandardCharsets.UTF_8); +// System.out.println(content); + assertFalse(content.isEmpty()); + + //Should be 2 begins and 2 ends for each entry + //5 warmup iterations, 5 profile iterations, x2 for both the op name and the op "instance" name + String[] opNames = {"mmul", "add", "softmax"}; + for(String s : opNames){ + assertEquals(s, 10, StringUtils.countMatches(content, s)); + } + + + System.out.println("///////////////////////////////////////////"); + ProfileAnalyzer.summarizeProfile(f, ProfileAnalyzer.ProfileFormat.SAMEDIFF); + + } + + /* + @Test + public void testLoadTfProfile(){ + File f = new File("C:\\Temp\\sd_profiler\\tf_profile.json"); + ProfileAnalyzer.summarizeProfile(f, ProfileAnalyzer.ProfileFormat.TENSORFLOW); + } + + @Test + public void testLoadTfProfileDir(){ + File f = new File("C:\\Temp\\sd_profiler\\tf_multiple_profiles"); + ProfileAnalyzer.summarizeProfileDirectory(f, ProfileAnalyzer.ProfileFormat.TENSORFLOW); + } + + @Test + public void testLoadTfProfileDir2(){ + File f = new File("C:\\DL4J\\Git\\dl4j-dev-tools\\import-tests\\profiling\\mobilenet_v2_1.0_224_batch32_tf-1.15.0"); + ProfileAnalyzer.summarizeProfileDirectory(f, ProfileAnalyzer.ProfileFormat.TENSORFLOW); + } + */ +} diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/ui/UIListenerTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/ui/UIListenerTest.java index 44b465fd3..b02a23771 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/ui/UIListenerTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/ui/UIListenerTest.java @@ -13,10 +13,12 @@ import org.nd4j.graph.UIEvent; import org.nd4j.graph.UIGraphStructure; import org.nd4j.graph.UIStaticInfoRecord; import org.nd4j.graph.ui.LogFileWriter; +import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.IrisDataSetIterator; import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.learning.config.Adam; import org.nd4j.linalg.primitives.Pair; @@ -27,7 +29,16 @@ import java.util.Map; import static org.junit.Assert.*; -public class UIListenerTest { +public class UIListenerTest extends BaseNd4jTest { + + public UIListenerTest(Nd4jBackend backend) { + super(backend); + } + + @Override + public char ordering() { + return 'c'; + } @Rule public TemporaryFolder testDir = new TemporaryFolder(); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvalCustomThreshold.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvalCustomThreshold.java index 7ec752a14..fcdb04e31 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvalCustomThreshold.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvalCustomThreshold.java @@ -29,6 +29,7 @@ import org.nd4j.linalg.ops.transforms.Transforms; import java.util.Random; +import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; /** diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvalJsonTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvalJsonTest.java index df27adf78..4ffaf3162 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvalJsonTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvalJsonTest.java @@ -66,7 +66,7 @@ public class EvalJsonTest extends BaseNd4jTest { @Test public void testSerde() { - boolean print = true; + boolean print = false; Nd4j.getRandom().setSeed(12345); Evaluation evaluation = new Evaluation(); @@ -117,7 +117,7 @@ public class EvalJsonTest extends BaseNd4jTest { @Test public void testSerdeExactRoc() { Nd4j.getRandom().setSeed(12345); - boolean print = true; + boolean print = false; ROC roc = new ROC(0); ROCBinary roc2 = new ROCBinary(0); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvalTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvalTest.java index 2aac466d1..7fe4900cc 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvalTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvalTest.java @@ -225,8 +225,10 @@ public class EvalTest extends BaseNd4jTest { Evaluation evaluation2 = new Evaluation(); evaluation2.evalTimeSeries(labels2, predicted2, labelsMask); - System.out.println(evaluation.stats()); - System.out.println(evaluation2.stats()); +// System.out.println(evaluation.stats()); +// System.out.println(evaluation2.stats()); + evaluation.stats(); + evaluation2.stats(); assertEquals(evaluation.accuracy(), evaluation2.accuracy(), 1e-12); assertEquals(evaluation.f1(), evaluation2.f1(), 1e-12); @@ -369,7 +371,8 @@ public class EvalTest extends BaseNd4jTest { eval.eval(one, one); eval.eval(zero, zero); - System.out.println(eval.stats()); +// System.out.println(eval.stats()); + eval.stats(); assertEquals(0.75, eval.accuracy(), 1e-6); assertEquals(4, eval.getNumRowCounter()); @@ -389,10 +392,8 @@ public class EvalTest extends BaseNd4jTest { e.eval(1, 0); e.eval(1, 1); - System.out.println(e.stats()); - - char c = "\uFFFD".toCharArray()[0]; - System.out.println(c); +// System.out.println(e.stats()); + e.stats(); assertFalse(e.stats().contains("\uFFFD")); } @@ -431,8 +432,10 @@ public class EvalTest extends BaseNd4jTest { assertEquals(1, cm.getCount(3, 3)); assertEquals(2, cm.getCount(3, 0)); - System.out.println(e1.stats()); - System.out.println(e2.stats()); +// System.out.println(e1.stats()); +// System.out.println(e2.stats()); + e1.stats(); + e2.stats(); assertEquals(e1.stats(), e2.stats()); } @@ -494,7 +497,8 @@ public class EvalTest extends BaseNd4jTest { assertEquals(6, e.getTopNCorrectCount()); assertEquals(8, e.getTopNTotalCount()); - System.out.println(e.stats()); +// System.out.println(e.stats()); + e.stats(); } @@ -888,10 +892,11 @@ public class EvalTest extends BaseNd4jTest { assertEquals(exp, s); - System.out.println("============================"); - System.out.println(e.stats()); +// System.out.println("============================"); +// System.out.println(e.stats()); + e.stats(); - System.out.println("\n\n\n\n"); +// System.out.println("\n\n\n\n"); //Test with 21 classes (> threshold) e = new Evaluation(); @@ -899,10 +904,12 @@ public class EvalTest extends BaseNd4jTest { class0.putScalar(0, 1); e.eval(class0, class0); - System.out.println(e.stats()); +// System.out.println(e.stats()); + e.stats(); - System.out.println("\n\n\n\n"); - System.out.println(e.stats(false, true)); +// System.out.println("\n\n\n\n"); +// System.out.println(e.stats(false, true)); + e.stats(false, true); } @Test @@ -1033,7 +1040,7 @@ public class EvalTest extends BaseNd4jTest { e1.eval(one, one); String s1 = e1.stats(); - System.out.println(s1); +// System.out.println(s1); e1.reset(); e1.eval(zero, zero); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvaluationBinaryTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvaluationBinaryTest.java index c864f6004..62679ef6a 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvaluationBinaryTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvaluationBinaryTest.java @@ -240,7 +240,8 @@ public class EvaluationBinaryTest extends BaseNd4jTest { EvaluationBinary eb = new EvaluationBinary(4, 30); eb.eval(l1, p1); - System.out.println(eb.stats()); +// System.out.println(eb.stats()); + eb.stats(); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvaluationCalibrationTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvaluationCalibrationTest.java index 012ba3434..64281c7d6 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvaluationCalibrationTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvaluationCalibrationTest.java @@ -34,8 +34,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Random; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.Assert.*; /** * Created by Alex on 05/07/2017. diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/ROCTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/ROCTest.java index c4ffe13d2..e90a4829c 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/ROCTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/ROCTest.java @@ -33,8 +33,7 @@ import org.nd4j.linalg.indexing.NDArrayIndex; import java.util.*; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.Assert.*; /** * Created by Alex on 04/11/2016. diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/RegressionEvalTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/RegressionEvalTest.java index 1bd6fd22c..b95d5c974 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/RegressionEvalTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/RegressionEvalTest.java @@ -72,7 +72,8 @@ public class RegressionEvalTest extends BaseNd4jTest { eval.eval(rand, rand); } - System.out.println(eval.stats()); +// System.out.println(eval.stats()); + eval.stats(); for (int i = 0; i < nCols; i++) { assertEquals(0.0, eval.meanSquaredError(i), 1e-6); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/ByteOrderTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/ByteOrderTests.java index c2ca2148e..5ac5c9410 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/ByteOrderTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/ByteOrderTests.java @@ -33,6 +33,7 @@ import org.nd4j.nativeblas.NativeOpsHolder; import java.util.Arrays; +import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; @Slf4j diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/ExecutionTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/ExecutionTests.java index d320ad6e3..88298b62b 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/ExecutionTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/ExecutionTests.java @@ -76,7 +76,7 @@ public class ExecutionTests extends BaseNd4jTest { Nd4j.create(1); val tg = TFGraphMapper.importGraphTxt(new ClassPathResource("tf_graphs/reduce_dim.pb.txt").getInputStream(), null, null); - System.out.println(tg.summary()); +// System.out.println(tg.summary()); Map result_0 = tg.outputAll(null); val exp_0 = Nd4j.create(DataType.FLOAT, 3).assign(3.0); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllHelper.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllHelper.java index eae14b230..718a95d5a 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllHelper.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllHelper.java @@ -26,6 +26,7 @@ import org.junit.After; import org.junit.Before; import org.junit.BeforeClass; import org.nd4j.autodiff.execution.NativeGraphExecutioner; + import org.nd4j.autodiff.execution.conf.ExecutionMode; import org.nd4j.autodiff.execution.conf.ExecutorConfiguration; import org.nd4j.autodiff.execution.conf.OutputMode; @@ -228,9 +229,9 @@ public class TFGraphTestAllHelper { String s1 = s.format(tfPred, false); String s2 = s.format(nd4jPred, false); System.out.print("TF: "); - System.out.println(s1); + System.out.println(tfPred.toStringFull()); System.out.print("SD: "); - System.out.println(s2); + System.out.println(nd4jPred.toStringFull()); } } assertTrue("Predictions do not match on " + modelName + ", node " + outputNode, eq); @@ -285,8 +286,7 @@ public class TFGraphTestAllHelper { + " with minAbsError=" + minAbsErrorOverride + "; largest observed relError=" + maxRE, 0, countExceeds); } } - log.info("\n\tTEST {} PASSED with {} arrays compared...", modelName, predictions.keySet().size()); - log.info("\n========================================================\n"); + log.info("TEST {} PASSED with {} arrays compared...", modelName, predictions.keySet().size()); } //Serialize and deserialize, check equality: @@ -392,7 +392,7 @@ public class TFGraphTestAllHelper { public static Pair> getGraphAfterExec(String baseDir, String modelFilename, String modelName, Map inputs, ExecuteWith executeWith, BiFunction graphLoaderFunction, List listeners, Set requiredOutputs, boolean printArraysDebugging) throws IOException { - log.info("\n\tRUNNING TEST " + modelName + "..."); + log.info("RUNNING TEST {}...", modelName); SameDiff graph = graphLoaderFunction.apply(new ClassPathResource(baseDir + "/" + modelName + "/" + modelFilename).getFile(), modelName); if(listeners != null){ graph.setListeners(listeners); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java index 9e3db5b1a..63571a30e 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java @@ -89,12 +89,12 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a "rnn/bstack/d_.*", //2019/05/21 - Failing on AVX2/512 intermittently (Linux, OSX), passing elsewhere - "unsorted_segment/.*", + //"unsorted_segment/.*", //2019/05/21 - Failing on windows-x86_64-cuda-9.2 only - "conv_4", "g_09", - "unsorted_segment/unsorted_segment_mean_rank2", + //"unsorted_segment/unsorted_segment_mean_rank2", //2019/05/28 - JVM crash on ppc64le only - See issue 7657 "g_11", @@ -111,17 +111,22 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a // 2019/11/15 - missing dtype argument in nd4j, tests are useless https://github.com/eclipse/deeplearning4j/issues/8398 "zeros_like/rank2_float32_dtype_int.*", - // 11.26.2019 failing - https://github.com/eclipse/deeplearning4j/issues/8450 - "betainc.*", - // 11.26.2019 failing - https://github.com/eclipse/deeplearning4j/issues/8453 "roll/.*", // 11.26.2019 failing https://github.com/eclipse/deeplearning4j/issues/8455 "matrix_band_part/.*", - // 05.12.2019 failing https://github.com/eclipse/deeplearning4j/issues/8507 - "resize_bicubic/int32.*" + // 12.20.2019 - https://github.com/eclipse/deeplearning4j/issues/8559 + "fused_batch_norm/.*", + + // AB 2020/01/04 - https://github.com/eclipse/deeplearning4j/issues/8592 + "emptyArrayTests/reshape/rank2_shape2-0_2-0--1", + + //AB 2020/01/07 - Known issues + "bitcast/from_float64_to_int64", + "bitcast/from_rank2_float64_to_int64", + "bitcast/from_float64_to_uint64" }; /* As per TFGraphTestList.printArraysDebugging - this field defines a set of regexes for test cases that should have diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestZooModels.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestZooModels.java index d08fb5148..9b823cde7 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestZooModels.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestZooModels.java @@ -60,9 +60,6 @@ public class TFGraphTestZooModels { //Note: Can't extend BaseNd4jTest here as we //2019/07/22 - Result value failure "xlnet_cased_L-24_H-1024_A-16", - // 2019/07/22 - OOM, Passes with sufficient memory (16GB heap, 32GB off-heap tested) - "compression_residual_gru", - // 2019/07/22 - OOM, Passes with sufficient memory (16GB heap, 32GB off-heap tested) "deeplabv3_xception_ade20k_train", @@ -72,15 +69,9 @@ public class TFGraphTestZooModels { //Note: Can't extend BaseNd4jTest here as we // Missing Multinormal op, see https://github.com/eclipse/deeplearning4j/issues/7913 "gpt-2_117M", - //2019/05/15 - "Invalid shape for op shape_of: shape has invalid values <= 0: shape=[0]" - //Also: https://github.com/deeplearning4j/deeplearning4j/issues/7112 + //AB 2020/01/08, all 3 - https://github.com/eclipse/deeplearning4j/issues/8603 "ssd_mobilenet_v1_0.75_depth_300x300_coco14_sync_2018_07_03", - - //2019/05/15 - CUSTOM CONV2D OP: rank of input array must be equal to 4, but got 0 instead ! - //Also: https://github.com/deeplearning4j/deeplearning4j/issues/7112 "ssd_mobilenet_v1_coco_2018_01_28", - - //2019/06/24 - size op dtypes / libnd4j size op issue: https://github.com/eclipse/deeplearning4j/issues/7938 "faster_rcnn_resnet101_coco_2018_01_28", //2019/06/24 - JVM crash on linux-x86_64-cpu-avx2 and -avx512 CI machines only - runs fine elsewhere @@ -256,7 +247,9 @@ public class TFGraphTestZooModels { //Note: Can't extend BaseNd4jTest here as we OpValidationSuite.ignoreFailing(); } -// if(!modelName.startsWith("mobilenet_v2_1.0_224")){ +// if(!modelName.startsWith("ssd_mobilenet_v1_coco_2018_01_28")){ +// if(!modelName.startsWith("ssd_mobilenet_v1_0.75_depth_300x300_coco14_sync_2018_07_03")){ +// if(!modelName.startsWith("faster_rcnn_resnet101_coco_2018_01_28")){ // OpValidationSuite.ignoreFailing(); // } currentTestDir = testDir.newFolder(); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TensorFlowImportTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TensorFlowImportTest.java index 22b8b4492..2a0880906 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TensorFlowImportTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TensorFlowImportTest.java @@ -252,7 +252,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { System.out.println(Arrays.toString(shape)); // this is NHWC weights. will be changed soon. - assertArrayEquals(new int[]{5,5,1,32}, shape); + assertArrayEquals(new long[]{5,5,1,32}, shape); System.out.println(convNode); } @@ -520,7 +520,8 @@ public class TensorFlowImportTest extends BaseNd4jTest { assertEquals(2, in1.first()); assertEquals(0, in1.second()); - System.out.println(tg.summary()); +// System.out.println(tg.summary()); + tg.summary(); int dimensionsLength = nodeSum.dimensionsLength(); assertEquals(1, dimensionsLength); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/BaseNd4jTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/BaseNd4jTest.java index c3c94e1ed..986103ef6 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/BaseNd4jTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/BaseNd4jTest.java @@ -17,14 +17,17 @@ package org.nd4j.linalg; +import lombok.extern.slf4j.Slf4j; import lombok.val; import org.bytedeco.javacpp.Pointer; import org.junit.After; import org.junit.Before; import org.junit.Rule; import org.junit.rules.TestName; +import org.junit.rules.Timeout; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; +import org.nd4j.BaseND4JTest; import org.nd4j.config.ND4JSystemProperties; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; @@ -39,20 +42,16 @@ import org.slf4j.LoggerFactory; import java.lang.management.ManagementFactory; import java.util.*; +import static org.junit.Assume.assumeTrue; + /** * Base Nd4j test * @author Adam Gibson */ @RunWith(Parameterized.class) -public abstract class BaseNd4jTest { - private static Logger log = LoggerFactory.getLogger(BaseNd4jTest.class); - - @Rule - public TestName testName = new TestName(); - - protected long startTime; - protected int threadCountBefore; +@Slf4j +public abstract class BaseNd4jTest extends BaseND4JTest { protected Nd4jBackend backend; protected String name; @@ -69,15 +68,10 @@ public abstract class BaseNd4jTest { public BaseNd4jTest(String name, Nd4jBackend backend) { this.backend = backend; this.name = name; - - //Suppress ND4J initialization - don't need this logged for every test... - System.setProperty(ND4JSystemProperties.LOG_INITIALIZATION, "false"); - System.gc(); } public BaseNd4jTest(Nd4jBackend backend) { this(backend.getClass().getName() + UUID.randomUUID().toString(), backend); - } private static List backends; @@ -92,79 +86,6 @@ public abstract class BaseNd4jTest { if (backend.canRun() && backendsToRun.contains(backend.getClass().getName()) || backendsToRun.isEmpty()) backends.add(backend); } - - } - public static void assertArrayEquals(String string, Object[] expecteds, Object[] actuals) { - org.junit.Assert.assertArrayEquals(string, expecteds, actuals); - } - - public static void assertArrayEquals(Object[] expecteds, Object[] actuals) { - org.junit.Assert.assertArrayEquals(expecteds, actuals); - } - - public static void assertArrayEquals(String string, long[] shapeA, long[] shapeB) { - org.junit.Assert.assertArrayEquals(string, shapeA, shapeB); - } - - public static void assertArrayEquals(String string, byte[] shapeA, byte[] shapeB) { - org.junit.Assert.assertArrayEquals(string, shapeA, shapeB); - } - - public static void assertArrayEquals(byte[] shapeA, byte[] shapeB) { - org.junit.Assert.assertArrayEquals(shapeA, shapeB); - } - - public static void assertArrayEquals(long[] shapeA, long[] shapeB) { - org.junit.Assert.assertArrayEquals(shapeA, shapeB); - } - - public static void assertArrayEquals(String string, int[] shapeA, long[] shapeB) { - org.junit.Assert.assertArrayEquals(string, ArrayUtil.toLongArray(shapeA), shapeB); - } - - public static void assertArrayEquals(int[] shapeA, long[] shapeB) { - org.junit.Assert.assertArrayEquals(ArrayUtil.toLongArray(shapeA), shapeB); - } - - public static void assertArrayEquals(String string, long[] shapeA, int[] shapeB) { - org.junit.Assert.assertArrayEquals(string, shapeA, ArrayUtil.toLongArray(shapeB)); - } - - public static void assertArrayEquals(long[] shapeA, int[] shapeB) { - org.junit.Assert.assertArrayEquals(shapeA, ArrayUtil.toLongArray(shapeB)); - } - - public static void assertArrayEquals(String string, int[] shapeA, int[] shapeB) { - org.junit.Assert.assertArrayEquals(string, shapeA, shapeB); - } - - public static void assertArrayEquals(int[] shapeA, int[] shapeB) { - org.junit.Assert.assertArrayEquals(shapeA, shapeB); - } - - public static void assertArrayEquals(String string, boolean[] shapeA, boolean[] shapeB) { - org.junit.Assert.assertArrayEquals(string, shapeA, shapeB); - } - - public static void assertArrayEquals(boolean[] shapeA, boolean[] shapeB) { - org.junit.Assert.assertArrayEquals(shapeA, shapeB); - } - - - public static void assertArrayEquals(float[] shapeA, float[] shapeB, float delta) { - org.junit.Assert.assertArrayEquals(shapeA, shapeB, delta); - } - - public static void assertArrayEquals(double[] shapeA, double[] shapeB, double delta) { - org.junit.Assert.assertArrayEquals(shapeA, shapeB, delta); - } - - public static void assertArrayEquals(String string, float[] shapeA, float[] shapeB, float delta) { - org.junit.Assert.assertArrayEquals(string, shapeA, shapeB, delta); - } - - public static void assertArrayEquals(String string, double[] shapeA, double[] shapeB, double delta) { - org.junit.Assert.assertArrayEquals(string, shapeA, shapeB, delta); } @Parameterized.Parameters(name = "{index}: backend({0})={1}") @@ -175,6 +96,13 @@ public abstract class BaseNd4jTest { return ret; } + @Override + @Before + public void beforeTest(){ + super.beforeTest(); + Nd4j.factory().setOrder(ordering()); + } + /** * Get the default backend (jblas) * The default backend can be overridden by also passing: @@ -195,106 +123,6 @@ public abstract class BaseNd4jTest { } - - @Before - public void before() throws Exception { - // - log.info("Running {}.{} on {}", getClass().getName(), testName.getMethodName(), backend.getClass().getSimpleName()); - Nd4j.getExecutioner().setProfilingConfig(ProfilerConfig.builder().build()); - Nd4j nd4j = new Nd4j(); - nd4j.initWithBackend(backend); - Nd4j.factory().setOrder(ordering()); - NativeOpsHolder.getInstance().getDeviceNativeOps().enableDebugMode(false); - Nd4j.getExecutioner().enableDebugMode(false); - Nd4j.getExecutioner().enableVerboseMode(false); - Nd4j.setDefaultDataTypes(DataType.DOUBLE, DataType.DOUBLE); - startTime = System.currentTimeMillis(); - threadCountBefore = ManagementFactory.getThreadMXBean().getThreadCount(); - } - - @After - public void after() throws Exception { - long totalTime = System.currentTimeMillis() - startTime; - Nd4j.getMemoryManager().purgeCaches(); - - logTestCompletion(totalTime); - if (System.getProperties().getProperty("backends") != null - && !System.getProperty("backends").contains(backend.getClass().getName())) - return; - Nd4j nd4j = new Nd4j(); - nd4j.initWithBackend(backend); - Nd4j.factory().setOrder(ordering()); - NativeOpsHolder.getInstance().getDeviceNativeOps().enableDebugMode(false); - Nd4j.getExecutioner().enableDebugMode(false); - Nd4j.getExecutioner().enableVerboseMode(false); - - //Attempt to keep workspaces isolated between tests - Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread(); - val currWS = Nd4j.getMemoryManager().getCurrentWorkspace(); - Nd4j.getMemoryManager().setCurrentWorkspace(null); - if(currWS != null){ - //Not really safe to continue testing under this situation... other tests will likely fail with obscure - // errors that are hard to track back to this - log.error("Open workspace leaked from test! Exiting - {}, isOpen = {} - {}", currWS.getId(), currWS.isScopeActive(), currWS); - System.exit(1); - } - } - - public void logTestCompletion( long totalTime){ - StringBuilder sb = new StringBuilder(); - long maxPhys = Pointer.maxPhysicalBytes(); - long maxBytes = Pointer.maxBytes(); - long currPhys = Pointer.physicalBytes(); - long currBytes = Pointer.totalBytes(); - - long jvmTotal = Runtime.getRuntime().totalMemory(); - long jvmMax = Runtime.getRuntime().maxMemory(); - - int threadsAfter = ManagementFactory.getThreadMXBean().getThreadCount(); - sb.append(getClass().getSimpleName()).append(".").append(testName.getMethodName()) - .append(": ").append(totalTime).append(" ms") - .append(", threadCount: (").append(threadCountBefore).append("->").append(threadsAfter).append(")") - .append(", jvmTotal=").append(jvmTotal) - .append(", jvmMax=").append(jvmMax) - .append(", totalBytes=").append(currBytes).append(", maxBytes=").append(maxBytes) - .append(", currPhys=").append(currPhys).append(", maxPhys=").append(maxPhys); - - List ws = Nd4j.getWorkspaceManager().getAllWorkspacesForCurrentThread(); - if(ws != null && ws.size() > 0){ - long currSize = 0; - for(MemoryWorkspace w : ws){ - currSize += w.getCurrentSize(); - } - if(currSize > 0){ - sb.append(", threadWSSize=").append(currSize) - .append(" (").append(ws.size()).append(" WSs)"); - } - } - - - Properties p = Nd4j.getExecutioner().getEnvironmentInformation(); - Object o = p.get("cuda.devicesInformation"); - if(o instanceof List){ - List> l = (List>) o; - if(l.size() > 0) { - - sb.append(" [").append(l.size()) - .append(" GPUs: "); - - for (int i = 0; i < l.size(); i++) { - Map m = l.get(i); - if(i > 0) - sb.append(","); - sb.append("(").append(m.get("cuda.freeMemory")).append(" free, ") - .append(m.get("cuda.totalMemory")).append(" total)"); - } - sb.append("]"); - } - } - log.info(sb.toString()); - } - - /** * The ordering for this test * This test will only be invoked for @@ -303,15 +131,10 @@ public abstract class BaseNd4jTest { * @return the ordering for this test */ public char ordering() { - return 'a'; + return 'c'; } - - public String getFailureMessage() { return "Failed with backend " + backend.getClass().getName() + " and ordering " + ordering(); } - - - } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/LoneTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/LoneTest.java index 577e19ecb..fd1e14423 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/LoneTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/LoneTest.java @@ -19,7 +19,6 @@ package org.nd4j.linalg; import lombok.extern.slf4j.Slf4j; import lombok.val; import org.apache.commons.lang3.RandomUtils; -import org.junit.Ignore; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; @@ -27,6 +26,7 @@ import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax; import org.nd4j.linalg.api.ops.impl.transforms.strict.Tanh; +import org.nd4j.linalg.api.ops.util.PrintVariable; import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.checkutil.NDArrayCreationUtil; import org.nd4j.linalg.dataset.DataSet; @@ -57,9 +57,9 @@ public class LoneTest extends BaseNd4jTest { @Test public void testSoftmaxStability() { INDArray input = Nd4j.create(new double[]{-0.75, 0.58, 0.42, 1.03, -0.61, 0.19, -0.37, -0.40, -1.42, -0.04}).reshape(1, -1).transpose(); - System.out.println("Input transpose " + Shape.shapeToString(input.shapeInfo())); +// System.out.println("Input transpose " + Shape.shapeToString(input.shapeInfo())); INDArray output = Nd4j.create(DataType.DOUBLE, 10, 1); - System.out.println("Element wise stride of output " + output.elementWiseStride()); +// System.out.println("Element wise stride of output " + output.elementWiseStride()); Nd4j.getExecutioner().exec(new SoftMax(input, output)); } @@ -85,11 +85,13 @@ public class LoneTest extends BaseNd4jTest { first = first.get(NDArrayIndex.interval(4, 8), NDArrayIndex.interval(0, 2, 8)); for (int i = 0; i < first.tensorsAlongDimension(0); i++) { - System.out.println(first.tensorAlongDimension(i, 0)); +// System.out.println(first.tensorAlongDimension(i, 0)); + first.tensorAlongDimension(i, 0); } for (int i = 0; i < first.tensorsAlongDimension(1); i++) { - System.out.println(first.tensorAlongDimension(i, 1)); +// System.out.println(first.tensorAlongDimension(i, 1)); + first.tensorAlongDimension(i, 1); } second = second.get(NDArrayIndex.interval(3, 7), NDArrayIndex.all()); third = third.permute(0, 2, 1); @@ -115,7 +117,7 @@ public class LoneTest extends BaseNd4jTest { assertEquals(i + 1,rowVector.getColumn(i).getInt(0)); assertEquals(i + 1,rowVector.get(NDArrayIndex.point(0), NDArrayIndex.interval(i, j)).getInt(0)); assertEquals(i + 1,colVector.get(NDArrayIndex.interval(i, j), NDArrayIndex.point(0)).getInt(0)); - System.out.println("Making sure index interval will not crash with begin/end vals..."); +// System.out.println("Making sure index interval will not crash with begin/end vals..."); jj = colVector.get(NDArrayIndex.interval(i, i + 1)); jj = colVector.get(NDArrayIndex.interval(i, i + 1)); } @@ -164,20 +166,9 @@ public class LoneTest extends BaseNd4jTest { INDArray aD = Nd4j.linspace(-3, 4, 8).reshape(2, 4); INDArray b = Nd4j.getExecutioner().exec(new Tanh(aA)); //Nd4j.getExecutioner().execAndReturn(new TanhDerivative(aD)); - System.out.println(aA); - System.out.println(aD); - System.out.println(b); - } - - @Test(expected = IllegalStateException.class) - @Ignore // test is outdated - public void opsNotAllowed() { - INDArray A = Nd4j.ones(2, 3, 1); - INDArray B = Nd4j.ones(2, 3); - - System.out.println(A.add(B)); - System.out.println(B.add(A)); - +// System.out.println(aA); +// System.out.println(aD); +// System.out.println(b); } @Test @@ -191,14 +182,26 @@ public class LoneTest extends BaseNd4jTest { max = 64; A = Nd4j.linspace(1, max, max).reshape(1, max); currentArgMax = Nd4j.argMax(A).getInt(0); - System.out.println("Returned argMax is " + currentArgMax); +// System.out.println("Returned argMax is " + currentArgMax); assertEquals(max - 1, currentArgMax); } + @Test + public void testRPF() { + val array = Nd4j.createFromArray(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12).reshape(2, 2, 3); + + log.info("--------"); + + val tad = array.tensorAlongDimension(1, 1, 2); + Nd4j.exec(new PrintVariable(tad, false)); + log.info("TAD native shapeInfo: {}", tad.shapeInfoDataBuffer().asLong()); + log.info("TAD Java shapeInfo: {}", tad.shapeInfoJava()); + log.info("TAD:\n{}", tad); + } @Test public void testConcat3D_Vstack_C() { - val shape = new long[]{1, 1000, 150}; + val shape = new long[]{1, 1000, 20}; List cArrays = new ArrayList<>(); List fArrays = new ArrayList<>(); @@ -211,15 +214,17 @@ public class LoneTest extends BaseNd4jTest { Nd4j.getExecutioner().commit(); - long time1 = System.currentTimeMillis(); - INDArray res = Nd4j.vstack(cArrays); - long time2 = System.currentTimeMillis(); + val time1 = System.currentTimeMillis(); + val res = Nd4j.vstack(cArrays); + val time2 = System.currentTimeMillis(); - log.info("Time spent: {} ms", time2 - time1); +// log.info("Time spent: {} ms", time2 - time1); for (int e = 0; e < 32; e++) { - INDArray tad = res.tensorAlongDimension(e, 1, 2); + val tad = res.tensorAlongDimension(e, 1, 2); + assertEquals("Failed for TAD [" + e + "]",(double) e, tad.meanNumber().doubleValue(), 1e-5); + assertEquals((double) e, tad.getDouble(0), 1e-5); } } @@ -248,7 +253,7 @@ public class LoneTest extends BaseNd4jTest { Collections.sort(times); - log.info("p50: {}; avg: {};", times.get(times.size() / 2), time); +// log.info("p50: {}; avg: {};", times.get(times.size() / 2), time); } @Test(expected = Exception.class) @@ -270,25 +275,30 @@ public class LoneTest extends BaseNd4jTest { */ int[] ranksToCheck = new int[]{2, 3, 4, 5}; for (int rank = 0; rank < ranksToCheck.length; rank++) { - log.info("\nRunning through rank " + ranksToCheck[rank]); +// log.info("\nRunning through rank " + ranksToCheck[rank]); List> allF = NDArrayCreationUtil.getTestMatricesWithVaryingShapes(ranksToCheck[rank], 'f', DataType.FLOAT); Iterator> iter = allF.iterator(); while (iter.hasNext()) { Pair currentPair = iter.next(); INDArray origArrayF = currentPair.getFirst(); INDArray sameArrayC = origArrayF.dup('c'); - log.info("\nLooping through slices for shape " + currentPair.getSecond()); - log.info("\nOriginal array:\n" + origArrayF); +// log.info("\nLooping through slices for shape " + currentPair.getSecond()); +// log.info("\nOriginal array:\n" + origArrayF); + origArrayF.toString(); INDArray viewF = origArrayF.slice(0); INDArray viewC = sameArrayC.slice(0); - log.info("\nSlice 0, C order:\n" + viewC.toString()); - log.info("\nSlice 0, F order:\n" + viewF.toString()); +// log.info("\nSlice 0, C order:\n" + viewC.toString()); +// log.info("\nSlice 0, F order:\n" + viewF.toString()); + viewC.toString(); + viewF.toString(); for (int i = 0; i < viewF.slices(); i++) { //assertEquals(viewF.slice(i),viewC.slice(i)); for (int j = 0; j < viewF.slice(i).length(); j++) { //if (j>0) break; - log.info("\nC order slice " + i + ", element 0 :" + viewC.slice(i).getDouble(j)); //C order is fine - log.info("\nF order slice " + i + ", element 0 :" + viewF.slice(i).getDouble(j)); //throws index out of bound err on F order +// log.info("\nC order slice " + i + ", element 0 :" + viewC.slice(i).getDouble(j)); //C order is fine +// log.info("\nF order slice " + i + ", element 0 :" + viewF.slice(i).getDouble(j)); //throws index out of bound err on F order + viewC.slice(i).getDouble(j); + viewF.slice(i).getDouble(j); } } } @@ -300,17 +310,21 @@ public class LoneTest extends BaseNd4jTest { INDArray arr = Nd4j.create(1, 3); INDArray reshaped = arr.reshape('f', 3, 1); for (int i=0;i pair : testInputs) { String msg = pair.getSecond(); INDArray in = pair.getFirst(); - System.out.println("Count " + count); +// System.out.println("Count " + count); INDArray dup = in.dup(); INDArray dupc = in.dup('c'); INDArray dupf = in.dup('f'); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java index 68551e53d..ce1ad388c 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java @@ -66,7 +66,9 @@ import org.nd4j.linalg.api.ops.impl.reduce.same.Sum; import org.nd4j.linalg.api.ops.impl.reduce3.*; import org.nd4j.linalg.api.ops.impl.scalar.LeakyReLU; import org.nd4j.linalg.api.ops.impl.scalar.ReplaceNans; +import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarEquals; import org.nd4j.linalg.api.ops.impl.scatter.ScatterUpdate; +import org.nd4j.linalg.api.ops.impl.shape.Reshape; import org.nd4j.linalg.api.ops.impl.transforms.any.IsMax; import org.nd4j.linalg.api.ops.impl.transforms.bool.MatchConditionTransform; import org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet; @@ -81,6 +83,7 @@ import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.Axpy; import org.nd4j.linalg.api.ops.impl.transforms.same.Sign; import org.nd4j.linalg.api.ops.impl.transforms.strict.ACosh; import org.nd4j.linalg.api.ops.impl.transforms.strict.Tanh; +import org.nd4j.linalg.api.ops.util.PrintVariable; import org.nd4j.linalg.api.shape.LongShapeDescriptor; import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.checkutil.NDArrayCreationUtil; @@ -129,10 +132,14 @@ public class Nd4jTestsC extends BaseNd4jTest { l1 = Nd4j.getBlasWrapper().level1(); } + @Override + public long getTimeoutMilliseconds() { + return 90000; + } @Before public void before() throws Exception { - super.before(); + super.beforeTest(); Nd4j.setDataType(DataType.DOUBLE); Nd4j.getRandom().setSeed(123); Nd4j.getExecutioner().enableDebugMode(false); @@ -141,7 +148,7 @@ public class Nd4jTestsC extends BaseNd4jTest { @After public void after() throws Exception { - super.after(); + super.afterTest(); Nd4j.setDataType(initialType); } @@ -238,8 +245,8 @@ public class Nd4jTestsC extends BaseNd4jTest { INDArray inDup = in.dup(); - System.out.println(in); - System.out.println(inDup); +// System.out.println(in); +// System.out.println(inDup); assertEquals(arr, in); //Passes: Original array "in" is OK, but array "inDup" is not!? assertEquals(in, inDup); //Fails @@ -433,6 +440,7 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test + @Ignore public void testMmulOp() throws Exception { INDArray arr = Nd4j.create(new double[][] {{1, 2, 3}, {4, 5, 6}}); INDArray z = Nd4j.create(2, 2); @@ -451,7 +459,7 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test public void testSubiRowVector() { INDArray oneThroughFour = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape('c', 2, 2); - INDArray row1 = oneThroughFour.getRow(1); + INDArray row1 = oneThroughFour.getRow(1).dup(); oneThroughFour.subiRowVector(row1); INDArray result = Nd4j.create(new double[] {-2, -2, 0, 0}, new long[] {2, 2}); assertEquals(getFailureMessage(), result, oneThroughFour); @@ -576,7 +584,7 @@ public class Nd4jTestsC extends BaseNd4jTest { INDArray outAsc = Nd4j.sortRows(in, i, true); INDArray outDesc = Nd4j.sortRows(in, i, false); - System.out.println("outDesc: " + Arrays.toString(outAsc.data().asFloat())); +// System.out.println("outDesc: " + Arrays.toString(outAsc.data().asFloat())); for (int j = 0; j < nRows; j++) { assertEquals(outAsc.getDouble(j, i), j, 1e-1); int origRowIdxAsc = order.indexOf(j); @@ -810,10 +818,10 @@ public class Nd4jTestsC extends BaseNd4jTest { INDArray outc = Transforms.isMax(orig.dup('c')); assertEquals(exp, outc); - log.info("Orig: {}", orig.dup('f').data().asFloat()); +// log.info("Orig: {}", orig.dup('f').data().asFloat()); INDArray outf = Transforms.isMax(orig.dup('f'), orig.dup('f').ulike()); - log.info("OutF: {}", outf.data().asFloat()); +// log.info("OutF: {}", outf.data().asFloat()); assertEquals(exp, outf); } @@ -872,7 +880,7 @@ public class Nd4jTestsC extends BaseNd4jTest { //1d: col vector - System.out.println("----------------------------------"); +// System.out.println("----------------------------------"); INDArray col = Nd4j.create(new double[] {1, 2, 3, 1}, new long[] {4, 1}); INDArray alongDim0col = Nd4j.getExecutioner().exec(new IsMax(col.dup(), Nd4j.createUninitialized(DataType.BOOL, col.shape()), 0))[0]; INDArray alongDim1col = Nd4j.getExecutioner().exec(new IsMax(col.dup(), Nd4j.createUninitialized(DataType.BOOL, col.shape()),1))[0]; @@ -908,7 +916,7 @@ public class Nd4jTestsC extends BaseNd4jTest { //Along dim 1: //[0 0 1] //[0 1 0] - System.out.println("---------------------"); +// System.out.println("---------------------"); INDArray orig2d = Nd4j.create(new double[][] {{1, 0, 2}, {2, 3, 1}}); INDArray alongDim0c_2d = Nd4j.getExecutioner().exec(new IsMax(orig2d.dup('c'), Nd4j.createUninitialized(DataType.BOOL, orig2d.shape()), 0))[0]; INDArray alongDim0f_2d = Nd4j.getExecutioner().exec(new IsMax(orig2d.dup('f'), Nd4j.createUninitialized(DataType.BOOL, orig2d.shape(), 'f'), 0))[0]; @@ -931,7 +939,7 @@ public class Nd4jTestsC extends BaseNd4jTest { INDArray result = Nd4j.argMax(orig2d.dup('c'), 0); - System.out.println("IMAx result: " + result); +// System.out.println("IMAx result: " + result); } @Test @@ -940,10 +948,10 @@ public class Nd4jTestsC extends BaseNd4jTest { INDArray alongDim0c_2d = Nd4j.getExecutioner().exec(new IsMax(orig2d.dup('c'), Nd4j.createUninitialized(DataType.BOOL, orig2d.shape()), 0))[0]; INDArray expAlong0_2d = Nd4j.create(new boolean[][] {{false, false, true}, {true, true, false}}); - System.out.println("Original shapeInfo: " + orig2d.dup('c').shapeInfoDataBuffer()); +// System.out.println("Original shapeInfo: " + orig2d.dup('c').shapeInfoDataBuffer()); - System.out.println("Expected: " + Arrays.toString(expAlong0_2d.data().asFloat())); - System.out.println("Actual: " + Arrays.toString(alongDim0c_2d.data().asFloat())); +// System.out.println("Expected: " + Arrays.toString(expAlong0_2d.data().asFloat())); +// System.out.println("Actual: " + Arrays.toString(alongDim0c_2d.data().asFloat())); assertEquals(expAlong0_2d, alongDim0c_2d); } @@ -953,7 +961,7 @@ public class Nd4jTestsC extends BaseNd4jTest { INDArray bias = Nd4j.create(1, 3); BroadcastOp op = new BroadcastAddOp(z, bias, z, 3); Nd4j.getExecutioner().exec(op); - System.out.println("First: OK"); +// System.out.println("First: OK"); //OK at this point: executes successfully @@ -961,7 +969,7 @@ public class Nd4jTestsC extends BaseNd4jTest { bias = Nd4j.create(1, 3); op = new BroadcastAddOp(z, bias, z, 3); Nd4j.getExecutioner().exec(op); //Crashing here, when we are doing exactly the same thing as before... - System.out.println("Second: OK"); +// System.out.println("Second: OK"); } @@ -970,19 +978,19 @@ public class Nd4jTestsC extends BaseNd4jTest { INDArray expected = Nd4j.linspace(1, 9, 9, DataType.DOUBLE).reshape(3, 3); for (char order : new char[] {'c', 'f'}) { - System.out.println(order); +// System.out.println(order); INDArray arr1 = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape( 2, 3).dup('c'); INDArray arr2 = Nd4j.linspace(7, 9, 3, DataType.DOUBLE).reshape(1, 3).dup('c'); Nd4j.factory().setOrder(order); - log.info("arr1: {}", arr1.data()); - log.info("arr2: {}", arr2.data()); +// log.info("arr1: {}", arr1.data()); +// log.info("arr2: {}", arr2.data()); INDArray merged = Nd4j.vstack(arr1, arr2); - System.out.println(merged.data()); - System.out.println(expected); +// System.out.println(merged.data()); +// System.out.println(expected); assertEquals("Failed for [" + order + "] order", expected, merged); } @@ -1005,8 +1013,8 @@ public class Nd4jTestsC extends BaseNd4jTest { INDArray expAllZeros = Nd4j.getExecutioner().exec(new Eps(first, second, Nd4j.create(DataType.BOOL, 10))); INDArray expAllOnes = Nd4j.getExecutioner().exec(new Eps(first, first, Nd4j.create(DataType.BOOL, 10))); - System.out.println(expAllZeros); - System.out.println(expAllOnes); +// System.out.println(expAllZeros); +// System.out.println(expAllOnes); val allones = Nd4j.getExecutioner().exec(new All(expAllOnes)).getDouble(0); @@ -1052,7 +1060,7 @@ public class Nd4jTestsC extends BaseNd4jTest { }*/ for (val shape : shapes) { for (int[] dims : sumDims) { - System.out.println("Shape: " + Arrays.toString(shape) + ", sumDims=" + Arrays.toString(dims)); +// System.out.println("Shape: " + Arrays.toString(shape) + ", sumDims=" + Arrays.toString(dims)); int length = ArrayUtil.prod(shape); INDArray inC = Nd4j.linspace(1, length, length, DataType.DOUBLE).reshape('c', shape); INDArray inF = inC.dup('f'); @@ -1085,8 +1093,8 @@ public class Nd4jTestsC extends BaseNd4jTest { for (int alongDimension = 0; alongDimension < rank; alongDimension++) { - System.out.println("Testing rank " + rank + " along dimension " + alongDimension + ", (shape=" - + Arrays.toString(shape) + ")"); +// System.out.println("Testing rank " + rank + " along dimension " + alongDimension + ", (shape=" +// + Arrays.toString(shape) + ")"); INDArray arrC = Nd4j.linspace(1, length, length, DataType.DOUBLE).reshape('c', shape); INDArray arrF = arrC.dup('f'); val resC = Nd4j.getExecutioner().exec(new IsMax(arrC, alongDimension))[0]; @@ -1265,14 +1273,14 @@ public class Nd4jTestsC extends BaseNd4jTest { val dims = new int[][] {{0, 1}, {1, 0}, {0, 2}, {2, 0}, {1, 2}, {2, 1}}; double[][] exp = new double[][] {{16, 20}, {16, 20}, {14, 22}, {14, 22}, {10, 26}, {10, 26}}; - System.out.println("dims\texpected\t\tactual"); +// System.out.println("dims\texpected\t\tactual"); for (int i = 0; i < dims.length; i++) { val d = dims[i]; double[] e = exp[i]; INDArray out = in.sum(d); - System.out.println(Arrays.toString(d) + "\t" + Arrays.toString(e) + "\t" + out); +// System.out.println(Arrays.toString(d) + "\t" + Arrays.toString(e) + "\t" + out); assertEquals(Nd4j.create(e, out.shape()), out); } } @@ -1299,7 +1307,7 @@ public class Nd4jTestsC extends BaseNd4jTest { assertEquals(exp, outC); assertEquals(exp, outF); - System.out.println(Arrays.toString(d) + "\t" + outC + "\t" + outF); +// System.out.println(Arrays.toString(d) + "\t" + outC + "\t" + outF); } } @@ -1335,7 +1343,7 @@ public class Nd4jTestsC extends BaseNd4jTest { zC.setData(Nd4j.linspace(1, 24, 24, DataType.DOUBLE).data()); for (int tad = 0; tad < zC.tensorsAlongDimension(dim); tad++) { INDArray javaTad = zC.tensorAlongDimension(tad, dim); - System.out.println("Tad " + tad + " is " + zC.tensorAlongDimension(tad, dim)); +// System.out.println("Tad " + tad + " is " + zC.tensorAlongDimension(tad, dim)); } INDArray zF = Nd4j.create(shape, 'f'); @@ -1347,10 +1355,10 @@ public class Nd4jTestsC extends BaseNd4jTest { INDArray exp = Nd4j.create(expLinspaced[i], shape, 'c'); INDArray expF = Nd4j.create(shape, 'f'); expF.assign(exp); - for (int tad = 0; tad < zC.tensorsAlongDimension(dim); tad++) { - System.out.println(zC.tensorAlongDimension(tad, dim).offset() + " and f offset is " - + zF.tensorAlongDimension(tad, dim).offset()); - } +// for (int tad = 0; tad < zC.tensorsAlongDimension(dim); tad++) { +// System.out.println(zC.tensorAlongDimension(tad, dim).offset() + " and f offset is " +// + zF.tensorAlongDimension(tad, dim).offset()); +// } Nd4j.getExecutioner().exec(opc); Nd4j.getExecutioner().exec(opf); @@ -2086,7 +2094,8 @@ public class Nd4jTestsC extends BaseNd4jTest { for (int i = 0; i < 6; i++) { //This should fail for i >= 2, but doesn't - System.out.println(arr.size(i)); +// System.out.println(arr.size(i)); + arr.size(i); } } @@ -2100,7 +2109,7 @@ public class Nd4jTestsC extends BaseNd4jTest { allocate.asFloatBuffer().put(new float[] {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); DataBuffer buff = Nd4j.createBuffer(allocate, DataType.FLOAT, 10); float sum = Nd4j.create(buff).sumNumber().floatValue(); - System.out.println(sum); +// System.out.println(sum); assertEquals(55f, sum, 0.001f); Nd4j.setDataType(initialType); @@ -2112,7 +2121,7 @@ public class Nd4jTestsC extends BaseNd4jTest { val res = Nd4j.create(DataType.BOOL, 5); Nd4j.getExecutioner().exec(new Eps(ones, ones, res)); - log.info("Result: {}", res); +// log.info("Result: {}", res); assertTrue(res.all()); } @@ -2125,8 +2134,8 @@ public class Nd4jTestsC extends BaseNd4jTest { INDArray expAllZeros1 = Nd4j.getExecutioner().exec(new Eps(first, second, Nd4j.create(DataType.BOOL, new long[] {1, 10}, 'f'))); INDArray expAllZeros2 = Nd4j.getExecutioner().exec(new Eps(second, first, Nd4j.create(DataType.BOOL, new long[] {1, 10}, 'f'))); - System.out.println(expAllZeros1); - System.out.println(expAllZeros2); +// System.out.println(expAllZeros1); +// System.out.println(expAllZeros2); assertTrue(expAllZeros1.none()); assertTrue(expAllZeros2.none()); @@ -2168,7 +2177,7 @@ public class Nd4jTestsC extends BaseNd4jTest { INDArray assertionRepeat = Nd4j.create(new double[][] {{1, 1, 2, 2}, {3, 3, 4, 4}}); assertArrayEquals(new long[] {2, 4}, assertionRepeat.shape()); assertEquals(assertionRepeat, repeatAlongDimension); - System.out.println(repeatAlongDimension); +// System.out.println(repeatAlongDimension); INDArray ret = Nd4j.create(new double[] {0, 1, 2}).reshape(1, 3); INDArray tile = Nd4j.tile(ret, 2, 2); INDArray assertion = Nd4j.create(new double[][] {{0, 1, 2, 0, 1, 2}, {0, 1, 2, 0, 1, 2}}); @@ -2599,7 +2608,7 @@ public class Nd4jTestsC extends BaseNd4jTest { // vec = vec.dup('c'); // vec = vec.dup('f'); - System.out.println("Vec: " + vec); +// System.out.println("Vec: " + vec); INDArray outC = arrC.muliRowVector(vec); INDArray outF = arrF.muliRowVector(vec); @@ -2639,7 +2648,7 @@ public class Nd4jTestsC extends BaseNd4jTest { double[][] ind = {{5.1, 3.5, 1.4}, {4.9, 3.0, 1.4}, {4.7, 3.2, 1.3}}; INDArray in = Nd4j.create(ind); INDArray stdev = in.std(1); - log.info("StdDev: {}", stdev.toDoubleVector()); +// log.info("StdDev: {}", stdev.toDoubleVector()); INDArray exp = Nd4j.create(new double[] {1.8556220879622372, 1.7521415467935233, 1.7039170558842744}); assertEquals(exp, stdev); } @@ -2856,7 +2865,6 @@ public class Nd4jTestsC extends BaseNd4jTest { assertEquals(25, Nd4j.getBlasWrapper().dot(row, row), 1e-1); } - @Test public void testIdentity() { INDArray eye = Nd4j.eye(5); @@ -2869,10 +2877,10 @@ public class Nd4jTestsC extends BaseNd4jTest { public void testTemp() { Nd4j.getRandom().setSeed(12345); INDArray in = Nd4j.rand(new long[] {2, 2, 2}); - System.out.println("In:\n" + in); +// System.out.println("In:\n" + in); INDArray permuted = in.permute(0, 2, 1); //Permute, so we get correct order after reshaping INDArray out = permuted.reshape(4, 2); - System.out.println("Out:\n" + out); +// System.out.println("Out:\n" + out); int countZero = 0; for (int i = 0; i < 8; i++) @@ -2923,7 +2931,7 @@ public class Nd4jTestsC extends BaseNd4jTest { INDArray columnConcat = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(2, 3); INDArray concatWith = Nd4j.zeros(2, 3); INDArray columnWiseConcat = Nd4j.concat(0, columnConcat, concatWith); - System.out.println(columnConcat); +// System.out.println(columnConcat); } @@ -2955,9 +2963,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test public void testSoftmaxStability() { INDArray input = Nd4j.create(new double[] {-0.75, 0.58, 0.42, 1.03, -0.61, 0.19, -0.37, -0.40, -1.42, -0.04}).reshape(1, -1).transpose(); - System.out.println("Input transpose " + Shape.shapeToString(input.shapeInfo())); +// System.out.println("Input transpose " + Shape.shapeToString(input.shapeInfo())); INDArray output = Nd4j.create(10, 1); - System.out.println("Element wise stride of output " + output.elementWiseStride()); +// System.out.println("Element wise stride of output " + output.elementWiseStride()); Nd4j.getExecutioner().exec(new SoftMax(input, output)); } @@ -3174,7 +3182,7 @@ public class Nd4jTestsC extends BaseNd4jTest { for (int i = 0; i < 20; i++) { INDArray arr1 = Nd4j.zeros(1, 100); Nd4j.getExecutioner().execAndReturn(new SoftMax(arr1)); - System.out.println(Arrays.toString(arr1.data().asFloat())); +// System.out.println(Arrays.toString(arr1.data().asFloat())); } } @@ -3189,8 +3197,8 @@ public class Nd4jTestsC extends BaseNd4jTest { INDArray out = Nd4j.getExecutioner().exec(new LeakyReLU(arr, 0.01)); - System.out.println("Expected: " + Arrays.toString(expected)); - System.out.println("Actual: " + Arrays.toString(out.data().asDouble())); +// System.out.println("Expected: " + Arrays.toString(expected)); +// System.out.println("Actual: " + Arrays.toString(out.data().asDouble())); INDArray exp = Nd4j.create(expected); assertEquals(exp, out); @@ -3309,19 +3317,19 @@ public class Nd4jTestsC extends BaseNd4jTest { INDArray arr2c = Nd4j.create(shape2, 'c'); INDArray arr2f = Nd4j.create(shape2, 'f'); - log.info("2f data: {}", Arrays.toString(arr2f.data().asFloat())); +// log.info("2f data: {}", Arrays.toString(arr2f.data().asFloat())); arr2c.assign(arr); - System.out.println("--------------"); +// System.out.println("--------------"); arr2f.assign(arr); INDArray exp = Nd4j.linspace(1, length, length, DataType.DOUBLE).reshape('c', shape2); - log.info("arr data: {}", Arrays.toString(arr.data().asFloat())); - log.info("2c data: {}", Arrays.toString(arr2c.data().asFloat())); - log.info("2f data: {}", Arrays.toString(arr2f.data().asFloat())); - log.info("2c shape: {}", Arrays.toString(arr2c.shapeInfoDataBuffer().asInt())); - log.info("2f shape: {}", Arrays.toString(arr2f.shapeInfoDataBuffer().asInt())); +// log.info("arr data: {}", Arrays.toString(arr.data().asFloat())); +// log.info("2c data: {}", Arrays.toString(arr2c.data().asFloat())); +// log.info("2f data: {}", Arrays.toString(arr2f.data().asFloat())); +// log.info("2c shape: {}", Arrays.toString(arr2c.shapeInfoDataBuffer().asInt())); +// log.info("2f shape: {}", Arrays.toString(arr2f.shapeInfoDataBuffer().asInt())); assertEquals(exp, arr2c); assertEquals(exp, arr2f); } @@ -3335,21 +3343,21 @@ public class Nd4jTestsC extends BaseNd4jTest { 56.0, 68.0, 80.0, 92.0, 9.0, 21.0, 33.0, 45.0, 57.0, 69.0, 81.0, 93.0, 10.0, 22.0, 34.0, 46.0, 58.0, 70.0, 82.0, 94.0, 11.0, 23.0, 35.0, 47.0, 59.0, 71.0, 83.0, 95.0, 12.0, 24.0, 36.0, 48.0, 60.0, 72.0, 84.0, 96.0}, new long[] {12, 8}, 'f'); - log.info("arr2f shape: {}", Arrays.toString(arr2f.shapeInfoDataBuffer().asInt())); - log.info("arr2f data: {}", Arrays.toString(arr2f.data().asFloat())); - log.info("render: {}", arr2f); +// log.info("arr2f shape: {}", Arrays.toString(arr2f.shapeInfoDataBuffer().asInt())); +// log.info("arr2f data: {}", Arrays.toString(arr2f.data().asFloat())); +// log.info("render: {}", arr2f); - log.info("----------------------"); +// log.info("----------------------"); INDArray array = Nd4j.linspace(1, 96, 96, DataType.DOUBLE).reshape('c', 12, 8); - log.info("array render: {}", array); +// log.info("array render: {}", array); - log.info("----------------------"); +// log.info("----------------------"); INDArray arrayf = array.dup('f'); - log.info("arrayf render: {}", arrayf); - log.info("arrayf shape: {}", Arrays.toString(arrayf.shapeInfoDataBuffer().asInt())); - log.info("arrayf data: {}", Arrays.toString(arrayf.data().asFloat())); +// log.info("arrayf render: {}", arrayf); +// log.info("arrayf shape: {}", Arrays.toString(arrayf.shapeInfoDataBuffer().asInt())); +// log.info("arrayf data: {}", Arrays.toString(arrayf.data().asFloat())); } @Test @@ -3385,7 +3393,7 @@ public class Nd4jTestsC extends BaseNd4jTest { INDArray arr2f = arr.dup('f'); arr2c.addi(arr); - System.out.println("--------------"); +// System.out.println("--------------"); arr2f.addi(arr); INDArray exp = Nd4j.linspace(1, length, length, DataType.DOUBLE).reshape('c', shape2).mul(2.0); @@ -3393,8 +3401,8 @@ public class Nd4jTestsC extends BaseNd4jTest { assertEquals(exp, arr2c); assertEquals(exp, arr2f); - log.info("2c data: {}", Arrays.toString(arr2c.data().asFloat())); - log.info("2f data: {}", Arrays.toString(arr2f.data().asFloat())); +// log.info("2c data: {}", Arrays.toString(arr2c.data().asFloat())); +// log.info("2f data: {}", Arrays.toString(arr2f.data().asFloat())); assertTrue(arrayNotEquals(arr2c.data().asFloat(), arr2f.data().asFloat(), 1e-5f)); } @@ -3410,7 +3418,7 @@ public class Nd4jTestsC extends BaseNd4jTest { INDArray arr2f = arr.dup('f'); arr2c.addi(arr); - System.out.println("--------------"); +// System.out.println("--------------"); arr2f.addi(arr); INDArray exp = Nd4j.linspace(1, length, length, DataType.DOUBLE).reshape('c', shape2).dup('f').mul(2.0); @@ -3418,8 +3426,8 @@ public class Nd4jTestsC extends BaseNd4jTest { assertEquals(exp, arr2c); assertEquals(exp, arr2f); - log.info("2c data: {}", Arrays.toString(arr2c.data().asFloat())); - log.info("2f data: {}", Arrays.toString(arr2f.data().asFloat())); +// log.info("2c data: {}", Arrays.toString(arr2c.data().asFloat())); +// log.info("2f data: {}", Arrays.toString(arr2f.data().asFloat())); assertTrue(arrayNotEquals(arr2c.data().asFloat(), arr2f.data().asFloat(), 1e-5f)); } @@ -3435,7 +3443,7 @@ public class Nd4jTestsC extends BaseNd4jTest { INDArray arr2f = Nd4j.create(shape2, 'f'); arr2c.assign(arr); - System.out.println("--------------"); +// System.out.println("--------------"); arr2f.assign(arr); INDArray exp = Nd4j.linspace(1, length, length, DataType.DOUBLE).reshape('c', shape2); @@ -3465,8 +3473,8 @@ public class Nd4jTestsC extends BaseNd4jTest { INDArray exp = Nd4j.linspace(1, length, length, DataType.DOUBLE).reshape('c', shape2); - System.out.println("Zf data: " + Arrays.toString(z_f.data().asFloat())); - System.out.println("Zc data: " + Arrays.toString(z_c.data().asFloat())); +// System.out.println("Zf data: " + Arrays.toString(z_f.data().asFloat())); +// System.out.println("Zc data: " + Arrays.toString(z_c.data().asFloat())); assertEquals(exp, z_f); assertEquals(exp, z_c); @@ -3527,37 +3535,45 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test public void testVarConst() { INDArray x = Nd4j.linspace(1, 100, 100, DataType.DOUBLE).reshape(10, 10); - System.out.println(x); +// System.out.println(x); assertFalse(Double.isNaN(x.var(0).sumNumber().doubleValue())); - System.out.println(x.var(0)); +// System.out.println(x.var(0)); + x.var(0); assertFalse(Double.isNaN(x.var(1).sumNumber().doubleValue())); - System.out.println(x.var(1)); +// System.out.println(x.var(1)); + x.var(1); - System.out.println("================================="); +// System.out.println("================================="); // 2d array - all elements are the same INDArray a = Nd4j.ones(10, 10).mul(10); - System.out.println(a); +// System.out.println(a); assertFalse(Double.isNaN(a.var(0).sumNumber().doubleValue())); - System.out.println(a.var(0)); +// System.out.println(a.var(0)); + a.var(0); assertFalse(Double.isNaN(a.var(1).sumNumber().doubleValue())); - System.out.println(a.var(1)); +// System.out.println(a.var(1)); + a.var(1); // 2d array - constant in one dimension - System.out.println("================================="); +// System.out.println("================================="); INDArray nums = Nd4j.linspace(1, 10, 10, DataType.DOUBLE); INDArray b = Nd4j.ones(10, 10).mulRowVector(nums); - System.out.println(b); +// System.out.println(b); assertFalse(Double.isNaN((Double) b.var(0).sumNumber())); - System.out.println(b.var(0)); +// System.out.println(b.var(0)); + b.var(0); assertFalse(Double.isNaN((Double) b.var(1).sumNumber())); - System.out.println(b.var(1)); +// System.out.println(b.var(1)); + b.var(1); - System.out.println("================================="); - System.out.println(b.transpose()); +// System.out.println("================================="); +// System.out.println(b.transpose()); assertFalse(Double.isNaN((Double) b.transpose().var(0).sumNumber())); - System.out.println(b.transpose().var(0)); +// System.out.println(b.transpose().var(0)); + b.transpose().var(0); assertFalse(Double.isNaN((Double) b.transpose().var(1).sumNumber())); - System.out.println(b.transpose().var(1)); +// System.out.println(b.transpose().var(1)); + b.transpose().var(1); } @Test @@ -3618,8 +3634,8 @@ public class Nd4jTestsC extends BaseNd4jTest { assertEquals(5, result.columns()); assertEquals(assertion, result); - System.out.println(assertion.toString()); - System.out.println(result.toString()); +// System.out.println(assertion.toString()); +// System.out.println(result.toString()); } @@ -3651,7 +3667,7 @@ public class Nd4jTestsC extends BaseNd4jTest { Nd4j.getExecutioner().exec(new ReplaceNans(array, 0.0)); - System.out.println("Array After: " + array); +// System.out.println("Array After: " + array); assertEquals(assertion, array); } @@ -3754,8 +3770,8 @@ public class Nd4jTestsC extends BaseNd4jTest { IAMax iaMax = new IAMax(arr.dup()); val imax = Nd4j.getExecutioner().execAndReturn(iMax).getFinalResult().intValue(); val iamax = Nd4j.getExecutioner().execAndReturn(iaMax).getFinalResult().intValue(); - System.out.println("IMAX: " + imax); - System.out.println("IAMAX: " + iamax); +// System.out.println("IMAX: " + imax); +// System.out.println("IAMAX: " + iamax); assertEquals(1, iamax); assertEquals(3, imax); } @@ -3769,8 +3785,8 @@ public class Nd4jTestsC extends BaseNd4jTest { IMin iMin = new IMin(arr.dup()); double imin = Nd4j.getExecutioner().execAndReturn(iMin).getFinalResult().doubleValue(); double iamin = Nd4j.getExecutioner().execAndReturn(iaMin).getFinalResult().doubleValue(); - System.out.println("IMin: " + imin); - System.out.println("IAMin: " + iamin); +// System.out.println("IMin: " + imin); +// System.out.println("IAMin: " + iamin); assertEquals(3, iamin, 1e-12); assertEquals(1, imin, 1e-12); } @@ -3782,7 +3798,7 @@ public class Nd4jTestsC extends BaseNd4jTest { for (char orderArr : orders) { for (char orderbc : orders) { - System.out.println(orderArr + "\t" + orderbc); +// System.out.println(orderArr + "\t" + orderbc); INDArray arrOrig = Nd4j.ones(3, 4, 5).dup(orderArr); //Broadcast on dimensions 0,1 @@ -3830,7 +3846,7 @@ public class Nd4jTestsC extends BaseNd4jTest { for (char orderArr : orders) { for (char orderbc : orders) { - System.out.println(orderArr + "\t" + orderbc); +// System.out.println(orderArr + "\t" + orderbc); INDArray arrOrig = Nd4j.ones(3, 4, 5, 6).dup(orderArr); //Broadcast on dimensions 0,1 @@ -4270,7 +4286,7 @@ public class Nd4jTestsC extends BaseNd4jTest { assertEquals(10, matrix.rows()); assertEquals(6, matrix.columns()); - log.info("Result: {}", matrix); +// log.info("Result: {}", matrix); for (int x = 0; x < 10; x++) { assertEquals((double) x, matrix.getRow(x).meanNumber().doubleValue(), 0.1); @@ -4283,7 +4299,7 @@ public class Nd4jTestsC extends BaseNd4jTest { INDArray array = Nd4j.create(10, 3, 96, 96); for (int i = 0; i < 10; i++) { - log.info("Trying i: {}", i); +// log.info("Trying i: {}", i); array.tensorAlongDimension(i, 1, 2, 3).putScalar(1, 2, 3, 1); } } @@ -4367,7 +4383,7 @@ public class Nd4jTestsC extends BaseNd4jTest { // Nd4j.getExecutioner().commit(); val executioner = (GridExecutioner) Nd4j.getExecutioner(); - log.info("Starting: -------------------------------"); +// log.info("Starting: -------------------------------"); //log.info("Point A: [{}]", executioner.getQueueLength()); @@ -4475,16 +4491,16 @@ public class Nd4jTestsC extends BaseNd4jTest { } Nd4j.getExecutioner().commit(); - log.info("original: \n{}", initial); +// log.info("original: \n{}", initial); Nd4j.getExecutioner().exec(new BroadcastLessThan(initial, mask, result, 1)); Nd4j.getExecutioner().commit(); - log.info("Comparison ----------------------------------------------"); +// log.info("Comparison ----------------------------------------------"); for (int i = 0; i < initial.rows(); i++) { val row = result.getRow(i); assertEquals("Failed at row " + i, exp, row); - log.info("-------------------"); +// log.info("-------------------"); } } @@ -4569,7 +4585,7 @@ public class Nd4jTestsC extends BaseNd4jTest { val row = haystack.getRow(1); val drow = row.dup(); - log.info("row shape: {}", row.shapeInfoDataBuffer()); +// log.info("row shape: {}", row.shapeInfoDataBuffer()); assertEquals(needle, drow); } @@ -4583,7 +4599,7 @@ public class Nd4jTestsC extends BaseNd4jTest { -1.25485503673}); INDArray reduced = Nd4j.getExecutioner().exec(new CosineDistance(haystack, needle, 1)); - log.info("Reduced: {}", reduced); +// log.info("Reduced: {}", reduced); INDArray exp = Nd4j.create(new double[] {0.577452, 0.0, 1.80182}); @@ -4666,7 +4682,7 @@ public class Nd4jTestsC extends BaseNd4jTest { .doubleValue(); assertEquals("Failed at " + i, reduced.getDouble(i), res, 0.001); - log.info("Euclidean: {} vs {} is {}", x, needle, res); +// log.info("Euclidean: {} vs {} is {}", x, needle, res); } } @@ -4687,7 +4703,7 @@ public class Nd4jTestsC extends BaseNd4jTest { .doubleValue(); assertEquals("Failed at " + i, reduced.getDouble(i), res, 0.001); - log.info("Euclidean: {} vs {} is {}", x, needle, res); +// log.info("Euclidean: {} vs {} is {}", x, needle, res); } } @@ -4709,7 +4725,7 @@ public class Nd4jTestsC extends BaseNd4jTest { .doubleValue(); assertEquals("Failed at " + i, reduced.getDouble(i), res, 0.001); - log.info("Cosine: {} vs {} is {}", x, needle, res); +// log.info("Cosine: {} vs {} is {}", x, needle, res); } } @@ -4819,7 +4835,7 @@ public class Nd4jTestsC extends BaseNd4jTest { x.getRow(r).putScalar(p, 1); } - log.info("X: {}", x); +// log.info("X: {}", x); INDArray y = Nd4j.create(new double[] {0, 0, 0, 0, 1, 0}); INDArray res = Nd4j.getExecutioner().exec(new HammingDistance(x, y, 1)); @@ -5216,7 +5232,7 @@ public class Nd4jTestsC extends BaseNd4jTest { INDArray array = Nd4j.create(new double[] {10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}); INDArray exp = Nd4j.create(new double[] {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); - log.info("Array shapeInfo: {}", array.shapeInfoJava()); +// log.info("Array shapeInfo: {}", array.shapeInfoJava()); INDArray rev = Nd4j.reverse(array); @@ -5280,7 +5296,7 @@ public class Nd4jTestsC extends BaseNd4jTest { Nd4j.sort(matrix.getColumn(0), true); - log.info("Matrix: {}", matrix); +// log.info("Matrix: {}", matrix); assertEquals(exp, matrix.getColumn(0)); } @@ -5315,7 +5331,8 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test public void testNativeSort3() { - INDArray array = Nd4j.linspace(1, 1048576, 1048576, DataType.DOUBLE).reshape(1, -1); + int length = isIntegrationTests() ? 1048576 : 16484; + INDArray array = Nd4j.linspace(1, length, length, DataType.DOUBLE).reshape(1, -1); INDArray exp = array.dup(); Nd4j.shuffle(array, 0); @@ -5371,8 +5388,8 @@ public class Nd4jTestsC extends BaseNd4jTest { Transforms.reverse(array, false); - log.info("Reversed shapeInfo: {}", array.shapeInfoJava()); - log.info("Reversed: {}", array); +// log.info("Reversed shapeInfo: {}", array.shapeInfoJava()); +// log.info("Reversed: {}", array); Transforms.reverse(array, false); @@ -5389,7 +5406,7 @@ public class Nd4jTestsC extends BaseNd4jTest { val reversed = Transforms.reverse(array, true); - log.info("Reversed: {}", reversed); +// log.info("Reversed: {}", reversed); val rereversed = Transforms.reverse(reversed, true); @@ -5432,7 +5449,7 @@ public class Nd4jTestsC extends BaseNd4jTest { INDArray array = Nd4j.linspace(1, 2017152, 2017152, DataType.DOUBLE).reshape(1, -1); INDArray exp = array.dup(); Transforms.reverse(array, false); - log.info("Reverse: {}", array); +// log.info("Reverse: {}", array); long time1 = System.currentTimeMillis(); @@ -6062,6 +6079,7 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test + //@Ignore public void testMatmul_128by256() { val mA = Nd4j.create(128, 156).assign(1.0f); val mB = Nd4j.create(156, 256).assign(1.0f); @@ -6080,6 +6098,42 @@ public class Nd4jTestsC extends BaseNd4jTest { assertEquals(mE, mC); } + /* + Analog of this TF code: + a = tf.constant([], shape=[0,1]) + b = tf.constant([], shape=[1, 0]) + c = tf.matmul(a, b) + */ + @Test + public void testMatmul_Empty() { + val mA = Nd4j.create(0,1); + val mB = Nd4j.create(1,0); + val mC = Nd4j.create(0,0); + + val op = DynamicCustomOp.builder("matmul") + .addInputs(mA, mB) + .addOutputs(mC) + .build(); + + Nd4j.getExecutioner().exec(op); + assertEquals(Nd4j.create(0,0), mC); + } + + @Test + public void testMatmul_Empty1() { + val mA = Nd4j.create(1,0, 4); + val mB = Nd4j.create(1,4, 0); + val mC = Nd4j.create(1,0, 0); + + val op = DynamicCustomOp.builder("mmul") + .addInputs(mA, mB) + .addOutputs(mC) + .addIntegerArguments(0,0) + .build(); + + Nd4j.getExecutioner().exec(op); + assertEquals(Nd4j.create(1,0,0), mC); + } @Test public void testScalarSqueeze() { @@ -6145,8 +6199,8 @@ public class Nd4jTestsC extends BaseNd4jTest { val vectorN = Nd4j.create(new float[]{1, 2, 3}, new long[]{3}); val matrix = Nd4j.create(new float[]{1, 2, 3, 4, 5, 6, 7, 8, 9}, new long[] {3, 3}); - log.info("vectorN: {}", vectorN); - log.info("vectorL: {}", vectorL); +// log.info("vectorN: {}", vectorN); +// log.info("vectorL: {}", vectorL); val outN = matrix.mmul(vectorN); val outL = matrix.mmul(vectorL); @@ -6193,6 +6247,14 @@ public class Nd4jTestsC extends BaseNd4jTest { } + @Test + public void testScalarPrint_1() { + val scalar = Nd4j.scalar(3.0f); + + Nd4j.exec(new PrintVariable(scalar, true)); + } + + @Test public void testValueArrayOf_1() { val vector = Nd4j.valueArrayOf(new long[] {5}, 2f, DataType.FLOAT); @@ -6564,7 +6626,7 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test public void testSummaryStatsEquality_1() { - log.info("Datatype: {}", Nd4j.dataType()); +// log.info("Datatype: {}", Nd4j.dataType()); for(boolean biasCorrected : new boolean[]{false, true}) { @@ -6573,9 +6635,9 @@ public class Nd4jTestsC extends BaseNd4jTest { val standardDeviation = new org.apache.commons.math3.stat.descriptive.moment.StandardDeviation(biasCorrected); double std2 = standardDeviation.evaluate(indArray1.data().asDouble()); - log.info("Bias corrected = {}", biasCorrected); - log.info("nd4j std: {}", std); - log.info("apache math3 std: {}", std2); +// log.info("Bias corrected = {}", biasCorrected); +// log.info("nd4j std: {}", std); +// log.info("apache math3 std: {}", std2); assertEquals(std, std2, 1e-5); } @@ -6723,18 +6785,11 @@ public class Nd4jTestsC extends BaseNd4jTest { Nd4j.getExecutioner().commit(); - log.info("Result shape: {}", result.shapeInfoDataBuffer().asLong()); +// log.info("Result shape: {}", result.shapeInfoDataBuffer().asLong()); Nd4j.setDataType(dtype); } - @Test - public void testSomething() { - val a = Nd4j.create(10, 20); - - log.info("Shape: {}", a.mean(0).shape()); - } - @Test public void testTranspose_Custom(){ @@ -6942,6 +6997,7 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test + @Ignore public void testMatmul_vs_tf() throws Exception { // uncomment this line to initialize & propagate sgemm/dgemm pointer @@ -7010,7 +7066,7 @@ public class Nd4jTestsC extends BaseNd4jTest { val z = Transforms.greaterThanOrEqual(x, y, true); val str = ez.toString(); - log.info("exp: {}", str); +// log.info("exp: {}", str); assertEquals(ex, x); assertEquals(ey, y); @@ -7124,6 +7180,16 @@ public class Nd4jTestsC extends BaseNd4jTest { } } + @Test + public void testScalarEquality_1() { + val x = Nd4j.scalar(1.0f); + val e = Nd4j.scalar(3.0f); + + x.addi(2.0f); + + assertEquals(e, x); + } + @Test public void testStack(){ INDArray in = Nd4j.linspace(1,12,12, DataType.DOUBLE).reshape(3,4); @@ -7131,19 +7197,19 @@ public class Nd4jTestsC extends BaseNd4jTest { for( int i=-3; i<3; i++ ){ INDArray out = Nd4j.stack(i, in, in2); - int[] expShape; + long[] expShape; switch (i){ case -3: case 0: - expShape = new int[]{2,3,4}; + expShape = new long[]{2,3,4}; break; case -2: case 1: - expShape = new int[]{3,2,4}; + expShape = new long[]{3,2,4}; break; case -1: case 2: - expShape = new int[]{3,4,2}; + expShape = new long[]{3,4,2}; break; default: throw new RuntimeException(String.valueOf(i)); @@ -7537,6 +7603,7 @@ public class Nd4jTestsC extends BaseNd4jTest { String wsName = "testRollingMeanWs"; try { System.gc(); + int iterations1 = isIntegrationTests() ? 5 : 2; for (int e = 0; e < 5; e++) { try (val ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(wsconf, wsName)) { val array = Nd4j.create(DataType.FLOAT, 32, 128, 256, 256); @@ -7544,7 +7611,7 @@ public class Nd4jTestsC extends BaseNd4jTest { } } - int iterations = 20; + int iterations = isIntegrationTests() ? 20 : 3; val timeStart = System.nanoTime(); for (int e = 0; e < iterations; e++) { try (val ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(wsconf, wsName)) { @@ -7657,7 +7724,7 @@ public class Nd4jTestsC extends BaseNd4jTest { public void testGetColumnRowVector(){ INDArray arr = Nd4j.create(1,4); INDArray col = arr.getColumn(0); - System.out.println(Arrays.toString(col.shape())); +// System.out.println(Arrays.toString(col.shape())); assertArrayEquals(new long[]{1}, col.shape()); } @@ -7778,7 +7845,7 @@ public class Nd4jTestsC extends BaseNd4jTest { double[] data = new double[]{15.0, 16.0}; INDArray vector = Nd4j.createFromArray(data).reshape(1,2); INDArray slice = vector.slice(0); - System.out.println(slice.shapeInfoToString()); +// System.out.println(slice.shapeInfoToString()); assertEquals(vector.reshape(2), slice); slice.assign(-1); assertEquals(Nd4j.createFromArray(-1.0, -1.0).reshape(1,2), vector); @@ -7787,9 +7854,11 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test public void testSliceMatrix(){ INDArray arr = Nd4j.arange(4).reshape(2,2); - System.out.println(arr.slice(0)); - System.out.println(); - System.out.println(arr.slice(1)); +// System.out.println(arr.slice(0)); +// System.out.println(); +// System.out.println(arr.slice(1)); + arr.slice(0); + arr.slice(1); } @Test @@ -8100,7 +8169,7 @@ public class Nd4jTestsC extends BaseNd4jTest { List l = c.calculateOutputShape(); - System.out.println(Arrays.toString(l.get(0).getShape())); +// System.out.println(Arrays.toString(l.get(0).getShape())); //from [4,4,3] to [2,4,6] then crop to [2,4,5] } @@ -8164,6 +8233,38 @@ public class Nd4jTestsC extends BaseNd4jTest { assertEquals(e, z); } + @Test + public void testScalarEqualsNoResult(){ + INDArray out = Nd4j.exec(new ScalarEquals(Nd4j.createFromArray(-2, -1, 0, 1, 2), null, 0)); + INDArray exp = Nd4j.createFromArray(false, false, true, false, false); + assertEquals(exp, out); + } + + @Test + public void testPutOverwrite(){ + INDArray arr = Nd4j.create(DataType.DOUBLE, 10); + arr.putScalar(0, 10); + System.out.println(arr); + INDArray arr2 = Nd4j.createFromArray(3.0, 3.0, 3.0); + val view = arr.get(new INDArrayIndex[]{NDArrayIndex.interval(1, 4)}); + view.assign(arr2); + System.out.println(arr); + } + + @Test + public void testEmptyReshapingMinus1(){ + INDArray arr0 = Nd4j.create(DataType.FLOAT, 2, 0); + INDArray arr1 = Nd4j.create(DataType.FLOAT, 0, 1, 2); + + INDArray out0 = Nd4j.exec(new Reshape(arr0, Nd4j.createFromArray(2, 0, -1), Nd4j.create(DataType.FLOAT, 2, 0, 0)))[0]; + INDArray out1 = Nd4j.exec(new Reshape(arr1, Nd4j.createFromArray(-1, 1), Nd4j.create(DataType.FLOAT, 0, 1)))[0]; + INDArray out2 = Nd4j.exec(new Reshape(arr1, Nd4j.createFromArray(10, -1), Nd4j.create(DataType.FLOAT, 10, 0)))[0]; + + assertArrayEquals(new long[]{2, 0, 0}, out0.shape()); + assertArrayEquals(new long[]{0, 1}, out1.shape()); + assertArrayEquals(new long[]{10, 0}, out2.shape()); + } + @Override public char ordering() { return 'c'; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsComparisonC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsComparisonC.java index feb7131fd..946da0fb3 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsComparisonC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsComparisonC.java @@ -57,13 +57,13 @@ public class Nd4jTestsComparisonC extends BaseNd4jTest { @Before public void before() throws Exception { - super.before(); + super.beforeTest(); DataTypeUtil.setDTypeForContext(DataType.DOUBLE); } @After public void after() throws Exception { - super.after(); + super.afterTest(); DataTypeUtil.setDTypeForContext(initialType); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsComparisonFortran.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsComparisonFortran.java index 39f3de324..ef32c1f99 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsComparisonFortran.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsComparisonFortran.java @@ -37,8 +37,7 @@ import org.slf4j.LoggerFactory; import java.util.List; import java.util.Random; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.Assert.*; /** * Tests comparing Nd4j ops to other libraries @@ -59,7 +58,7 @@ public class Nd4jTestsComparisonFortran extends BaseNd4jTest { @Before public void before() throws Exception { - super.before(); + super.beforeTest(); DataTypeUtil.setDTypeForContext(DataType.DOUBLE); Nd4j.getRandom().setSeed(SEED); @@ -67,7 +66,7 @@ public class Nd4jTestsComparisonFortran extends BaseNd4jTest { @After public void after() throws Exception { - super.after(); + super.afterTest(); DataTypeUtil.setDTypeForContext(initialType); } @@ -197,7 +196,7 @@ public class Nd4jTestsComparisonFortran extends BaseNd4jTest { INDArray gemv = m.mmul(v); RealMatrix gemv2 = rm.multiply(rv); - assertArrayEquals(new int[] {rows, 1}, gemv.shape()); + assertArrayEquals(new long[] {rows, 1}, gemv.shape()); assertArrayEquals(new int[] {rows, 1}, new int[] {gemv2.getRowDimension(), gemv2.getColumnDimension()}); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ShufflesTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ShufflesTests.java index f1fdf9c57..bd10bd74d 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ShufflesTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ShufflesTests.java @@ -48,17 +48,15 @@ public class ShufflesTests extends BaseNd4jTest { array.getRow(x).assign(x); } - System.out.println(array); +// System.out.println(array); OrderScanner2D scanner = new OrderScanner2D(array); assertArrayEquals(new float[] {0f, 1f, 2f, 3f, 4f, 5f, 6f, 7f, 8f, 9f}, scanner.getMap(), 0.01f); - System.out.println(); - Nd4j.shuffle(array, 1); - System.out.println(array); +// System.out.println(array); ArrayUtil.argMin(new int[] {}); @@ -71,19 +69,12 @@ public class ShufflesTests extends BaseNd4jTest { for (int x = 0; x < 10; x++) { array.getColumn(x).assign(x); } - - System.out.println(array); +// System.out.println(array); OrderScanner2D scanner = new OrderScanner2D(array); - assertArrayEquals(new float[] {0f, 0f, 0f, 0f, 0f, 0f, 0f, 0f, 0f, 0f}, scanner.getMap(), 0.01f); - - System.out.println(); - Nd4j.shuffle(array, 0); - - System.out.println(array); - +// System.out.println(array); assertTrue(scanner.compareColumn(array)); } @@ -94,20 +85,12 @@ public class ShufflesTests extends BaseNd4jTest { array.getRow(x).assign(x); } - System.out.println(array); - +// System.out.println(array); OrderScanner2D scanner = new OrderScanner2D(array); assertArrayEquals(new float[] {0f, 1f, 2f, 3f, 4f, 5f, 6f, 7f, 8f, 9f, 10f}, scanner.getMap(), 0.01f); - - System.out.println(); - Nd4j.shuffle(array, 1); - - System.out.println(array); - - ArrayUtil.argMin(new int[] {}); - +// System.out.println(array); assertTrue(scanner.compareRow(array)); } @@ -119,26 +102,21 @@ public class ShufflesTests extends BaseNd4jTest { features.getRow(x).assign(x); labels.getRow(x).assign(x); } - - System.out.println(features); +// System.out.println(features); OrderScanner2D scanner = new OrderScanner2D(features); assertArrayEquals(new float[] {0f, 1f, 2f, 3f, 4f, 5f, 6f, 7f, 8f, 9f}, scanner.getMap(), 0.01f); - System.out.println(); - List list = new ArrayList<>(); list.add(features); list.add(labels); Nd4j.shuffle(list, 1); - System.out.println(features); - - System.out.println(); - - System.out.println(labels); +// System.out.println(features); +// System.out.println(); +// System.out.println(labels); ArrayUtil.argMin(new int[] {}); @@ -164,24 +142,20 @@ public class ShufflesTests extends BaseNd4jTest { labels.slice(x).assign(x); } - System.out.println(features); +// System.out.println(features); OrderScanner3D scannerFeatures = new OrderScanner3D(features); OrderScanner3D scannerLabels = new OrderScanner3D(labels); - System.out.println(); - List list = new ArrayList<>(); list.add(features); list.add(labels); Nd4j.shuffle(list, 1, 2); - System.out.println(features); - - System.out.println("------------------"); - - System.out.println(labels); +// System.out.println(features); +// System.out.println("------------------"); +// System.out.println(labels); assertTrue(scannerFeatures.compareSlice(features)); assertTrue(scannerLabels.compareSlice(labels)); @@ -360,7 +334,7 @@ public class ShufflesTests extends BaseNd4jTest { } if (Arrays.equals(map, newMap)) { - System.out.println("Maps are equal"); +// System.out.println("Maps are equal"); return false; } @@ -407,7 +381,7 @@ public class ShufflesTests extends BaseNd4jTest { } if (Arrays.equals(map, newMap)) { - System.out.println("Maps are equal"); +// System.out.println("Maps are equal"); return false; } @@ -433,7 +407,7 @@ public class ShufflesTests extends BaseNd4jTest { } if (Arrays.equals(map, newMap)) { - System.out.println("Maps are equal"); +// System.out.println("Maps are equal"); return false; } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/activations/TestActivationJson.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/activations/TestActivation.java similarity index 63% rename from nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/activations/TestActivationJson.java rename to nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/activations/TestActivation.java index 9586a8160..1240c1213 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/activations/TestActivationJson.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/activations/TestActivation.java @@ -22,7 +22,11 @@ import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.activations.impl.*; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; +import org.nd4j.linalg.primitives.Pair; import org.nd4j.shade.jackson.databind.*; import java.util.ArrayList; @@ -37,9 +41,9 @@ import static org.junit.Assert.assertEquals; * Created by Alex on 30/12/2016. */ @RunWith(Parameterized.class) -public class TestActivationJson extends BaseNd4jTest { +public class TestActivation extends BaseNd4jTest { - public TestActivationJson(Nd4jBackend backend) { + public TestActivation(Nd4jBackend backend) { super(backend); } @@ -59,6 +63,59 @@ public class TestActivationJson extends BaseNd4jTest { mapper.enable(SerializationFeature.INDENT_OUTPUT); } + @Test + public void testRelu(){ + + Double[] max = {null, 6.0, 2.5, 5.0}; + Double[] threshold = {0.0, 0.0, 0.75, 0.2}; + Double[] negativeSlope = {0.0, 0.0, 0.0, 0.3}; + + INDArray in = Nd4j.linspace(-10, 10, 1000, DataType.DOUBLE); + double[] dIn = in.data().asDouble(); + + for( int i=0; i 5000); } @@ -559,10 +558,10 @@ public class IndexingTestsC extends BaseNd4jTest { INDArray arr = Nd4j.linspace(DataType.FLOAT, 1, prod, prod).reshape('c', inShape).dup(order); INDArray sub = arr.get(indexes); - System.out.println(Arrays.toString(indexes)); - System.out.println(arr); - System.out.println(); - System.out.println(sub); +// System.out.println(Arrays.toString(indexes)); +// System.out.println(arr); +// System.out.println(); +// System.out.println(sub); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/resolve/NDArrayIndexResolveTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/resolve/NDArrayIndexResolveTests.java index 00f369756..80ebb203c 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/resolve/NDArrayIndexResolveTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/resolve/NDArrayIndexResolveTests.java @@ -27,8 +27,7 @@ import org.nd4j.linalg.indexing.INDArrayIndex; import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.indexing.PointIndex; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.Assert.*; /** * @author Adam Gibson diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/shape/IndexShapeTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/shape/IndexShapeTests.java index a25a31425..3f18ee5cb 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/shape/IndexShapeTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/shape/IndexShapeTests.java @@ -25,6 +25,8 @@ import org.nd4j.linalg.indexing.INDArrayIndex; import org.nd4j.linalg.indexing.Indices; import org.nd4j.linalg.indexing.NDArrayIndex; +import static org.junit.Assert.assertArrayEquals; + /** * @author Adam Gibson */ diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/shape/IndexShapeTests2d.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/shape/IndexShapeTests2d.java index b74d0b1b4..d106dca0a 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/shape/IndexShapeTests2d.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/shape/IndexShapeTests2d.java @@ -24,6 +24,8 @@ import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.indexing.Indices; import org.nd4j.linalg.indexing.NDArrayIndex; +import static org.junit.Assert.assertArrayEquals; + /** * @author Adam Gibson */ diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/iterator/NDIndexIteratorTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/iterator/NDIndexIteratorTest.java index 65fe355fe..10a640c47 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/iterator/NDIndexIteratorTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/iterator/NDIndexIteratorTest.java @@ -24,6 +24,8 @@ import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.iter.NdIndexIterator; import org.nd4j.linalg.factory.Nd4jBackend; +import static org.junit.Assert.assertArrayEquals; + /** * @author Adam Gibson */ diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/ndarray/TestNdArrReadWriteTxt.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/ndarray/TestNdArrReadWriteTxt.java index 8e1d8bd8f..a44df5868 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/ndarray/TestNdArrReadWriteTxt.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/ndarray/TestNdArrReadWriteTxt.java @@ -57,7 +57,7 @@ public class TestNdArrReadWriteTxt extends BaseNd4jTest { public void compareAfterWrite() throws Exception { int [] ranksToCheck = new int[] {0,1,2,3,4}; for (int i=0; i lsd = op.calculateOutputShape(); assertEquals(1, lsd.size()); assertArrayEquals(new long[]{8, 8, 3}, lsd.get(0).getShape()); @@ -1331,7 +1398,7 @@ public class CustomOpsTests extends BaseNd4jTest { INDArray y = Nd4j.linspace(DataType.FLOAT, -5, 9, 1).reshape(3, 3); val c = Conditions.equals(0.0); - System.out.println("Y:\n" + y); +// System.out.println("Y:\n" + y); INDArray z = x.match(y, c); INDArray exp = Nd4j.createFromArray(new boolean[][]{ @@ -1353,4 +1420,260 @@ public class CustomOpsTests extends BaseNd4jTest { assertEquals(exp, result); } + + // Exact copy of libnd4j test + @Test + @Ignore + public void testRgbToHsv() { + INDArray expected = Nd4j.createFromArray(new float[]{ + 0.545678377f, 0.644941628f, 0.461456001f, 0.588904262f, 0.725874603f, + 0.517642438f, 0.0869259685f, 0.54742825f, 0.413571358f, 0.890151322f, + 0.928968489f, 0.684074104f, 0.52110225f, 0.753103435f, 0.913557053f, + 0.46850124f, 0.761800349f, 0.237176552f, 0.90049392f, 0.965541422f, + 0.486593395f, 0.263826847f, 0.290193319f, 0.148351923f, 0.674094439f, + 0.0361763388f, 0.3721793f, 0.823592246f, 0.524110138f, 0.2204483f, + 0.632020354f, 0.637001634f, 0.216262609f, 0.279114306f, 0.25007084f, + 0.30433768f, 0.0448598303f, 0.586083114f, 0.978048146f, 0.91390729f, + 0.385092884f, 0.218390301f, 0.762684941f, 0.505838513f, 0.366362303f, + 0.931746006f, 0.00208298792f, 0.875348926f, 0.428009957f, 0.270003974f, + 0.313204288f, 0.775881767f, 0.367065936f, 0.164243385f, 0.644775152f, + 0.575452209f, 0.911922634f, 0.0581932105f, 0.437950462f, 0.946475744f + }).reshape(5,4,3); + INDArray input = Nd4j.createFromArray(new float[]{ + 0.262831867f, 0.723622441f, 0.740797927f, 0.717254877f, 0.430244058f, + 0.418478161f, 0.906427443f, 0.199753001f, 0.725874603f, 0.890151322f, + 0.928968489f, 0.684074104f, 0.312434604f, 0.991390795f, 0.163174023f, + 0.268038541f, 0.361258626f, 0.685067773f, 0.682347894f, 0.84635365f, + 0.761800349f, 0.753103435f, 0.913557053f, 0.965541422f, 0.112067183f, + 0.540247589f, 0.280050347f, 0.106776128f, 0.679180562f, 0.870388806f, + 0.604331017f, 0.630475283f, 0.674094439f, 0.279114306f, 0.632020354f, + 0.823592246f, 0.490824632f, 0.75257351f, 0.129888852f, 0.849081645f, + 0.883509099f, 0.765611768f, 0.997870266f, 0.446510047f, 0.385092884f, + 0.931746006f, 0.978048146f, 0.91390729f, 0.685308874f, 0.0834472676f, + 0.396037966f, 0.756701186f, 0.597481251f, 0.784472764f, 0.514242649f, + 0.392005324f, 0.911922634f, 0.270003974f, 0.644775152f, 0.946475744f + }).reshape(5,4,3); + RgbToHsv op = new RgbToHsv(input); + INDArray[] ret = Nd4j.exec(op); + assertEquals(ret[0], expected); + } + + // Exact copy of libnd4j test + @Test + public void testHsvToRgb() { + INDArray input = Nd4j.createFromArray(new float[]{0.705504596f, 0.793608069f, 0.65870738f, 0.848827183f, 0.920532584f, + 0.887555957f, 0.72317636f, 0.563831031f, 0.773604929f, 0.269532293f, + 0.332347751f, 0.111181192f}).reshape(4,3); + + INDArray expected = Nd4j.createFromArray(new float[]{0.257768334f, 0.135951888f, 0.65870738f, 0.887555957f, 0.0705317783f, + 0.811602857f, 0.485313689f, 0.337422464f, 0.773604929f, 0.0883753772f, + 0.111181192f, 0.074230373f}).reshape(4,3); + + HsvToRgb op = new HsvToRgb(input); + INDArray[] ret = Nd4j.exec(op); + assertEquals(ret[0], expected); + } + + @Test + public void testHsvToRgb_1() { + /* Emulation of simple TF test: + image = tf.random_uniform(shape = [1,1,3]) + tf.image.hsv_to_rgb(image)*/ + INDArray image = Nd4j.createFromArray(new float[]{0.778785586f,0.801197767f,0.724374652f}). + reshape(1,1,3); + HsvToRgb op = new HsvToRgb(image); + INDArray[] ret = Nd4j.exec(op); + System.out.println(ret[0].toStringFull()); + INDArray expected = Nd4j.createFromArray(new float[]{ 0.53442812f, 0.144007325f, 0.724374652f}).reshape(1,1,3); + assertEquals(expected, ret[0]); + } + + @Test + public void testRgbToHsv_1() { + /* Emulation of simple TF test: + image = tf.random_uniform(shape = [1,2,3]) + tf.image.rgb_to_hsv(image)*/ + INDArray image = Nd4j.createFromArray(new float[]{0.778785586f,0.801197767f,0.724374652f, + 0.230894327f, 0.727141261f, 0.180390716f }).reshape(2,3); + RgbToHsv op = new RgbToHsv(image); + INDArray[] ret = Nd4j.exec(op); + INDArray expected = Nd4j.createFromArray(new float[]{0.215289578f,0.095885336f,0.801197767f, + 0.317938268f,0.751917899f,0.727141261f}).reshape(2,3); + assertEquals(expected, ret[0]); + } + + @Test + public void testLu() { + INDArray input = Nd4j.createFromArray(new float[]{1.f, 2.f, 3.f, 0.f, 2.f, 3.f, 0.f, 0.f, 7.f}) + .reshape(3,3); + Lu op = new Lu(input); + INDArray[] ret = Nd4j.exec(op); + + INDArray expected = Nd4j.createFromArray(new float[]{1.f, 2.f, 3.f, 0.f, 2.f, 3.f, 0.f, 0.f, 7f}).reshape(3,3); + assertEquals(expected, ret[0]); + } + + @Test + public void testRgbToYiq() { + INDArray image = Nd4j.createFromArray(new float[]{ + 0.48055f , 0.80757356f, 0.2564435f , 0.94277316f, 0.17006584f, + 0.33366168f, 0.41727918f, 0.54528666f, 0.48942474f, 0.3305715f , + 0.98633456f, 0.00158441f, 0.97605824f, 0.02462568f, 0.14837205f, + 0.00112842f, 0.99260217f, 0.9585542f , 0.41196227f, 0.3095014f , + 0.6620493f , 0.30888894f, 0.3122602f , 0.7993488f , 0.86656475f, + 0.5997049f , 0.9776477f , 0.72481847f, 0.7835693f , 0.14649455f, + 0.3573504f , 0.33301765f, 0.7853056f , 0.25830218f, 0.59289205f, + 0.41357264f, 0.5934154f , 0.72647524f, 0.6623308f , 0.96197623f, + 0.0720306f , 0.23853847f, 0.1427159f , 0.19581454f, 0.06766324f, + 0.10614152f, 0.26093867f, 0.9584985f , 0.01258832f, 0.8160156f , + 0.56506383f, 0.08418505f, 0.86440504f, 0.6807802f , 0.20662387f, + 0.4153733f , 0.76146203f, 0.50057423f, 0.08274968f, 0.9521758f + }).reshape(5,4,3); + + INDArray expected = Nd4j.createFromArray(new float[]{ + 0.64696468f, -0.01777124f, -0.24070648f, 0.41975525f, 0.40788622f, + 0.21433232f, 0.50064416f, -0.05832884f, -0.04447775f, 0.67799989f, + -0.07432612f, -0.44518381f, 0.32321111f, 0.52719408f, 0.2397369f , + 0.69227005f, -0.57987869f, -0.22032876f, 0.38032767f, -0.05223263f, + 0.13137188f, 0.3667803f , -0.15853189f, 0.15085728f, 0.72258149f, + 0.03757231f, 0.17403452f, 0.69337627f, 0.16971045f, -0.21071186f, + 0.39185397f, -0.13084008f, 0.145886f , 0.47240727f, -0.1417591f , + -0.12659159f, 0.67937788f, -0.05867803f, -0.04813048f, 0.35710624f, + 0.47681283f, 0.24003804f, 0.1653288f , 0.00953913f, -0.05111816f, + 0.29417614f, -0.31640032f, 0.18433114f, 0.54718234f, -0.39812097f, + -0.24805083f, 0.61018603f, -0.40592682f, -0.22219216f, 0.39241133f, + -0.23560742f, 0.06353694f, 0.3067938f , -0.0304029f , 0.35893188f + }).reshape(5,4,3); + + RgbToYiq op = new RgbToYiq(image); + INDArray[] ret = Nd4j.exec(op); + assertEquals(expected, ret[0]); + } + + @Test + public void testYiqToRgb() { + INDArray image = Nd4j.createFromArray(new float[]{ + 0.775258899f, -0.288912386f, -0.132725924f, 0.0664454922f, -0.212469354f, + 0.455438733f, 0.418221354f, 0.349350512f, 0.145902053f, 0.947576523f, + -0.471601307f, 0.263960421f, 0.700227439f, 0.32434237f, -0.278446227f, + 0.130805135f, -0.438441873f, 0.187127829f, 0.0276055578f, -0.179727226f, + 0.305075705f, 0.716282248f, 0.278215706f, -0.44586885f, 0.76971364f, + 0.131288841f, -0.141177326f, 0.900081575f, -0.0788725987f, 0.14756602f, + 0.387832165f, 0.229834676f, 0.47921446f, 0.632930398f, 0.0443540029f, + -0.268817365f, 0.0977194682f, -0.141669706f, -0.140715122f, 0.946808815f, + -0.52525419f, -0.106209636f, 0.659476519f, 0.391066104f, 0.426448852f, + 0.496989518f, -0.283434421f, -0.177366048f, 0.715208411f, -0.496444523f, + 0.189553142f, 0.616444945f, 0.345852494f, 0.447739422f, 0.224696323f, + 0.451372236f, 0.298027098f, 0.446561724f, -0.187599331f, -0.448159873f + }).reshape(5,4,3); + + INDArray expected = Nd4j.createFromArray(new float[]{ + 0.416663059f, 0.939747555f, 0.868814286f, 0.146075352f, -0.170521997f, + 1.07776645f, 0.842775284f, 0.228765106f, 0.280231822f, 0.660605291f, + 0.905021825f, 1.91936605f, 0.837427991f, 0.792213732f, -0.133271854f, + -0.17216571f, 0.128957025f, 0.934955336f, 0.0451873479f, -0.120952621f, + 0.746436225f, 0.705446224f, 0.929172217f, -0.351493549f, 0.807577594f, + 0.825371955f, 0.383812296f, 0.916293093f, 0.82603058f, 1.23885956f, + 0.905059196f, 0.015164554f, 0.950156781f, 0.508443732f, 0.794845279f, + 0.12571529f, -0.125074273f, 0.227326869f, 0.0147000261f, 0.378735409f, + 1.15842402f, 1.34712305f, 1.2980804f, 0.277102016f, 0.953435072f, + 0.115916842f, 0.688879376f, 0.508405162f, 0.35829352f, 0.727568094f, + 1.58768577f, 1.22504294f, 0.232589777f, 0.996727258f, 0.841224629f, + -0.0909671176f, 0.233051388f, -0.0110094378f, 0.787642119f, -0.109582274f + }).reshape(5,4,3); + + YiqToRgb op = new YiqToRgb(image); + INDArray[] ret = Nd4j.exec(op); + assertEquals(expected, ret[0]); + } + + @Test + public void testRgbToGrayscale() { + INDArray image = Nd4j.createFromArray(new float[]{ + 1.7750e+01f, -7.1062e+01f, -1.0019e+02f,-2.3406e+01f, 5.2094e+01f, + 9.5438e+01f, -6.7461e+00f, 3.8562e+01f, 6.5078e+00f,3.3562e+01f, + -5.8844e+01f, 2.2750e+01f, -1.0477e+01f, 7.7344e+00f, 9.5469e+00f, + 2.1391e+01f, -8.5312e+01f, 7.5830e-01f,2.3125e+01f, 1.8145e+00f, + 1.4602e+01f,-4.5859e+00f, 3.9344e+01f, 1.1617e+01f,-8.6562e+01f, + 1.0038e+02f, 6.7938e+01f,5.9961e+00f, 6.7812e+01f, 2.9734e+01f, + 2.9609e+01f, -6.1438e+01f, 1.7750e+01f,6.8562e+01f, -7.4414e+00f, + 3.9656e+01f,1.1641e+01f, -2.7516e+01f, 6.7562e+01f,7.8438e+01f, + 5.4883e+00f, 2.9438e+01f,-3.1344e+01f, 6.5125e+01f, + 1.2695e+01f,4.0531e+01f, -6.1211e+00f, 6.2219e+01f,4.6812e+01f, + 5.2250e+01f, -1.1414e+01f,1.5404e-02f, 2.9938e+01f, 5.6719e+00f, + -2.0125e+01f, 2.1531e+01f, 6.2500e+01f,7.2188e+01f, 9.3750e+00f, + -4.8125e+01f + }).reshape(5,4,3); + + INDArray expected = Nd4j.createFromArray(new float[]{ + -47.82958221f, 34.46305847f, 21.36137581f, -21.91625023f,2.49686432f, + -43.59792709f, 9.64180183f, 23.04854202f,40.7946167f, 44.98754883f, + -25.19047546f, 20.64586449f,-4.97033119f, 30.0226841f, 30.30688286f, + 15.61459541f,43.36166f, 18.22480774f, 13.74833488f, 21.59387016f + }).reshape(5,4,1); + + RgbToGrayscale op = new RgbToGrayscale(image); + INDArray[] ret = Nd4j.exec(op); + assertEquals(expected, ret[0]); + } + + @Test + public void testRgbToYuv() { + INDArray image = Nd4j.createFromArray(new float[]{ + 10f,50f,200f + }); + + INDArray expected = Nd4j.createFromArray(new float[]{ + 55.14f , 71.2872001f, -39.6005542f + }); + + RgbToYuv op = new RgbToYuv(image); + INDArray[] ret = Nd4j.exec(op); + assertEquals(expected, ret[0]); + } + + @Test + public void testYuvToRgb() { + INDArray image = Nd4j.createFromArray(new float[]{ + 55.14f , 71.2872001f, -39.6005542f + }); + + INDArray expected = Nd4j.createFromArray(new float[]{ + 10f, 50f, 200f + }); + YuvToRgb op = new YuvToRgb(image); + INDArray[] ret = Nd4j.exec(op); + assertEquals(expected, ret[0]); + } + + @Test + public void testRgbToYiqEmpty() { + INDArray image = Nd4j.create(0,4,3); + RgbToYiq op = new RgbToYiq(image); + INDArray[] ret = Nd4j.exec(op); + assertArrayEquals(image.shape(), ret[0].shape()); + } + + @Test + public void testTriangularSolve() { + INDArray a = Nd4j.createFromArray(new float[]{ + 3.f, 0.f, 0.f, 0.f, + 2.f, 1.f, 0.f, 0.f, + 1.f, 0.f, 1.f, 0.f, + 1.f, 1.f, 1.f, 1.f + }).reshape(4, 4); + + INDArray b = Nd4j.createFromArray(new float[]{ + 4.f, 2.f, 4.f, 2.f + }).reshape(4, 1); + + INDArray expected = Nd4j.createFromArray(new float[]{ + 1.333333f, -0.6666667f, 2.6666667f, -1.3333333f + }).reshape(4, 1); + + val op = new TriangularSolve(a, b, true, false); + INDArray[] ret = Nd4j.exec(op); + + assertEquals(expected, ret[0]); + } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/ExpandableOpsTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/ExpandableOpsTests.java new file mode 100644 index 000000000..38a5ab763 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/ExpandableOpsTests.java @@ -0,0 +1,60 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg.custom; + +import lombok.extern.slf4j.Slf4j; +import lombok.val; +import org.junit.Test; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ops.compat.CompatStringSplit; +import org.nd4j.linalg.api.ops.util.PrintVariable; +import org.nd4j.linalg.factory.Nd4j; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; + +/** + * This is special test suit: we test operations that on C++ side modify arrays that come from Java + */ +@Slf4j +public class ExpandableOpsTests { + + @Test + public void testCompatStringSplit_1() throws Exception { + val array = Nd4j.create("first string", "second"); + val delimiter = Nd4j.create(" "); + + val exp0 = Nd4j.createFromArray(new long[] {0,0, 0,1, 1,0}); + val exp1 = Nd4j.create("first", "string", "second"); + + val results = Nd4j.exec(new CompatStringSplit(array, delimiter)); + assertNotNull(results); + assertEquals(2, results.length); + + assertEquals(exp0, results[0]); + assertEquals(exp1, results[1]); + } + + @Test + public void test() { + val arr = Nd4j.createFromArray(0, 1, 2, 3, 4, 5, 6, 7, 8).reshape(3, 3); + Nd4j.exec(new PrintVariable(arr)); + + val row = arr.getRow(1); + Nd4j.exec(new PrintVariable(row)); + } +} diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/DataSetTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/DataSetTest.java index f6098dd3d..a62dc631e 100755 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/DataSetTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/DataSetTest.java @@ -482,7 +482,7 @@ public class DataSetTest extends BaseNd4jTest { //Tests merging of different CNN masks: [mb,1,h,1], [mb,1,1,w], [mb,1,h,w] for( int t=0; t<3; t++) { - log.info("Starting test: {}", t); +// log.info("Starting test: {}", t); int nOut = 3; int width = 5; int height = 4; @@ -808,7 +808,7 @@ public class DataSetTest extends BaseNd4jTest { ds.shuffle(); INDArray fCol = f.getColumn(0); INDArray lCol = l.getColumn(0); - System.out.println(fCol + "\t" + ds.getExampleMetaData()); +// System.out.println(fCol + "\t" + ds.getExampleMetaData()); for (int j = 0; j < nExamples; j++) { int fVal = (int) fCol.getDouble(j); int lVal = (int) lCol.getDouble(j); @@ -836,7 +836,8 @@ public class DataSetTest extends BaseNd4jTest { public void testToString() { org.nd4j.linalg.dataset.api.DataSet ds = new DataSet(); //this should not throw a null pointer - System.out.println(ds); +// System.out.println(ds); + ds.toString(); //Checking printing of masks int numExamples = 10; @@ -853,7 +854,8 @@ public class DataSetTest extends BaseNd4jTest { } ds = DataSet.merge(list); - System.out.println(ds); +// System.out.println(ds); + ds.toString(); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dimensionalityreduction/TestPCA.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dimensionalityreduction/TestPCA.java index a28c026cc..981495eac 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dimensionalityreduction/TestPCA.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dimensionalityreduction/TestPCA.java @@ -149,16 +149,15 @@ public class TestPCA extends BaseNd4jTest { INDArray reduced100 = myPCA.reducedBasis(1.0); assertTrue("100% variance coverage should include all eigenvectors", reduced100.columns() == m.columns()); NDArrayStrings ns = new NDArrayStrings(5); - System.out.println("Eigenvectors:\n" + ns.format(myPCA.getEigenvectors())); - System.out.println("Eigenvalues:\n" + ns.format(myPCA.getEigenvalues())); +// System.out.println("Eigenvectors:\n" + ns.format(myPCA.getEigenvectors())); +// System.out.println("Eigenvalues:\n" + ns.format(myPCA.getEigenvalues())); double variance = 0.0; // sample 1000 of the randomly generated samples with the reduced basis set for (long i = 0; i < 1000; i++) variance += myPCA.estimateVariance(m.getRow(i), reduced70.columns()); variance /= 1000.0; - System.out.println("Fraction of variance using 70% variance with " + reduced70.columns() + " columns: " - + variance); + System.out.println("Fraction of variance using 70% variance with " + reduced70.columns() + " columns: " + variance); assertTrue("Variance does not cover intended 70% variance", variance > 0.70); // create "dummy" data with the same exact trends INDArray testSample = myPCA.generateGaussianSamples(10000); @@ -171,8 +170,8 @@ public class TestPCA extends BaseNd4jTest { 0.5 * myPCA.getEigenvalues().columns())); assertTrue("Eigenvectors are not close enough", myPCA.getEigenvectors() .equalsWithEps(analyzePCA.getEigenvectors(), 0.1 * analyzePCA.getEigenvectors().length())); - System.out.println("Original cov:\n" + ns.format(myPCA.getCovarianceMatrix()) + "\nDummy cov:\n" - + ns.format(analyzePCA.getCovarianceMatrix())); +// System.out.println("Original cov:\n" + ns.format(myPCA.getCovarianceMatrix()) + "\nDummy cov:\n" +// + ns.format(analyzePCA.getCovarianceMatrix())); INDArray testSample2 = analyzePCA.convertBackToFeatures(analyzePCA.convertToComponents(testSample)); assertTrue("Transformation does not work.", testSample.equalsWithEps(testSample2, 1e-5 * testSample.length())); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dimensionalityreduction/TestRandomProjection.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dimensionalityreduction/TestRandomProjection.java index ed93fdc38..a8c211944 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dimensionalityreduction/TestRandomProjection.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dimensionalityreduction/TestRandomProjection.java @@ -34,8 +34,7 @@ import org.nd4j.linalg.ops.transforms.Transforms; import java.util.ArrayList; import java.util.Arrays; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.Assert.*; import static org.nd4j.linalg.dimensionalityreduction.RandomProjection.johnsonLindenStraussMinDim; import static org.nd4j.linalg.dimensionalityreduction.RandomProjection.targetShape; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/factory/Nd4jTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/factory/Nd4jTest.java index c1c64b896..8ddd47eb7 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/factory/Nd4jTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/factory/Nd4jTest.java @@ -41,6 +41,7 @@ import java.util.Arrays; import java.util.List; import java.util.UUID; +import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; /** @@ -229,7 +230,7 @@ public class Nd4jTest extends BaseNd4jTest { ByteBuffer floatBuffer = pointer1.asByteBuffer(); byte[] dataTwo = new byte[floatBuffer.capacity()]; floatBuffer.get(dataTwo); - Assert.assertArrayEquals(originalData,dataTwo); + assertArrayEquals(originalData,dataTwo); floatBuffer.position(0); DataBuffer dataBuffer = Nd4j.createBuffer(new FloatPointer(floatBuffer.asFloatBuffer()),linspace.length(), DataType.FLOAT); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/indexing/BooleanIndexingTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/indexing/BooleanIndexingTest.java index 324efed0e..49d079529 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/indexing/BooleanIndexingTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/indexing/BooleanIndexingTest.java @@ -139,11 +139,8 @@ public class BooleanIndexingTest extends BaseNd4jTest { @Test public void test2dAnd2() { INDArray array = Nd4j.zeros(10, 10); - array.slice(4).putScalar(2, 1e-5f); - - - System.out.println(array); +// System.out.println(array); assertFalse(BooleanIndexing.and(array, Conditions.equals(0f))); @@ -329,7 +326,7 @@ public class BooleanIndexingTest extends BaseNd4jTest { boolean result[] = BooleanIndexing.and(array, Conditions.equals(0.0), 1); boolean comp[] = new boolean[] {false, false, true}; - System.out.println("Result: " + Arrays.toString(result)); +// System.out.println("Result: " + Arrays.toString(result)); assertArrayEquals(comp, result); } @@ -338,12 +335,12 @@ public class BooleanIndexingTest extends BaseNd4jTest { INDArray array = Nd4j.ones(3, 10); array.getRow(2).assign(0.0).putScalar(0, 1.0); - System.out.println("Array: " + array); +// System.out.println("Array: " + array); boolean result[] = BooleanIndexing.or(array, Conditions.lessThan(0.9), 1); boolean comp[] = new boolean[] {false, false, true}; - System.out.println("Result: " + Arrays.toString(result)); +// System.out.println("Result: " + Arrays.toString(result)); assertArrayEquals(comp, result); } @@ -355,7 +352,7 @@ public class BooleanIndexingTest extends BaseNd4jTest { boolean result[] = BooleanIndexing.and(array, Conditions.lessThan(0.0), 1); boolean comp[] = new boolean[] {false, false, false}; - System.out.println("Result: " + Arrays.toString(result)); +// System.out.println("Result: " + Arrays.toString(result)); assertArrayEquals(comp, result); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/memory/DeviceLocalNDArrayTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/memory/DeviceLocalNDArrayTests.java index cf8154ab4..d8c17a70c 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/memory/DeviceLocalNDArrayTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/memory/DeviceLocalNDArrayTests.java @@ -31,6 +31,7 @@ import org.nd4j.linalg.util.DeviceLocalNDArray; import java.util.Arrays; import java.util.concurrent.atomic.AtomicInteger; +import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; @Slf4j diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/mixed/MixedDataTypesTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/mixed/MixedDataTypesTests.java index e92f03c39..15fbac932 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/mixed/MixedDataTypesTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/mixed/MixedDataTypesTests.java @@ -453,8 +453,8 @@ public class MixedDataTypesTests extends BaseNd4jTest { INDArray not = Transforms.not(asBool); // INDArray asFloat = not.castTo(DataType.FLOAT); - System.out.println(not); - System.out.println(asFloat); +// System.out.println(not); +// System.out.println(asFloat); INDArray exp = Nd4j.ones(DataType.FLOAT, 3, 5000); assertEquals(DataType.FLOAT, exp.dataType()); assertEquals(exp.dataType(), asFloat.dataType()); @@ -480,7 +480,7 @@ public class MixedDataTypesTests extends BaseNd4jTest { for(DataType dt : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) { INDArray arr = Nd4j.scalar(dt, 10.0); arr.assign(2.0); - System.out.println(dt + " - value: " + arr + " - " + arr.getDouble(0)); +// System.out.println(dt + " - value: " + arr + " - " + arr.getDouble(0)); } } @@ -488,17 +488,23 @@ public class MixedDataTypesTests extends BaseNd4jTest { public void testSimple(){ Nd4j.create(1); for(DataType dt : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF, DataType.INT, DataType.LONG}) { - System.out.println("----- " + dt + " -----"); +// System.out.println("----- " + dt + " -----"); INDArray arr = Nd4j.ones(dt,1, 5); - System.out.println("Ones: " + arr); +// System.out.println("Ones: " + arr); arr.assign(1.0); - System.out.println("assign(1.0): " + arr); - System.out.println("DIV: " + arr.div(8)); - System.out.println("MUL: " + arr.mul(8)); - System.out.println("SUB: " + arr.sub(8)); - System.out.println("ADD: " + arr.add(8)); - System.out.println("RDIV: " + arr.rdiv(8)); - System.out.println("RSUB: " + arr.rsub(8)); +// System.out.println("assign(1.0): " + arr); +// System.out.println("DIV: " + arr.div(8)); +// System.out.println("MUL: " + arr.mul(8)); +// System.out.println("SUB: " + arr.sub(8)); +// System.out.println("ADD: " + arr.add(8)); +// System.out.println("RDIV: " + arr.rdiv(8)); +// System.out.println("RSUB: " + arr.rsub(8)); + arr.div(8); + arr.mul(8); + arr.sub(8); + arr.add(8); + arr.rdiv(8); + arr.rsub(8); } } @@ -519,7 +525,7 @@ public class MixedDataTypesTests extends BaseNd4jTest { val boolAttached = bool.isAttached(); val doubleAttached = dbl.isAttached(); - System.out.println(i + "\tboolAttached=" + boolAttached + ", doubleAttached=" + doubleAttached ); +// System.out.println(i + "\tboolAttached=" + boolAttached + ", doubleAttached=" + doubleAttached ); //System.out.println("bool: " + bool); //java.lang.IllegalStateException: Indexer must never be null //System.out.println("double: " + dbl); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/multithreading/MultithreadedTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/multithreading/MultithreadedTests.java new file mode 100644 index 000000000..27cdf5a42 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/multithreading/MultithreadedTests.java @@ -0,0 +1,83 @@ +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg.multithreading; + +import lombok.val; +import org.junit.Test; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.concurrent.CopyOnWriteArrayList; + +import static org.junit.Assert.assertEquals; + +/** + * @author raver119@gmail.com + */ +public class MultithreadedTests { + + @Test + public void basicMigrationTest_1() throws Exception { + if (Nd4j.getAffinityManager().getNumberOfDevices() < 2) + return; + + val exp = Nd4j.create(DataType.INT32, 5, 5).assign(2); + + val hash = new HashSet(); + + // we're creating bunch of arrays on different devices + val list = new ArrayList(); + for (int e = 0; e < Nd4j.getAffinityManager().getNumberOfDevices(); e++) { + val t = e; + val thread = new Thread(new Runnable() { + @Override + public void run() { + for (int f = 0; f < 10; f++) { + val array = Nd4j.create(DataType.INT32, 5, 5).assign(1); + + // store current deviceId for further validation + hash.add(Nd4j.getAffinityManager().getDeviceForCurrentThread()); + + // make sure INDArray has proper affinity set + assertEquals(Nd4j.getAffinityManager().getDeviceForCurrentThread(), Nd4j.getAffinityManager().getDeviceForArray(array)); + + list.add(array); + } + }; + }); + + thread.start(); + thread.join(); + } + + // lets make sure all devices covered + assertEquals(Nd4j.getAffinityManager().getNumberOfDevices(), hash.size()); + + // make sure nothing failed in threads + assertEquals(10 * Nd4j.getAffinityManager().getNumberOfDevices(), list.size()); + + // now we're going to use arrays on current device, so data will be migrated + for (val arr:list) { + arr.addi(1); + + assertEquals(exp, arr); + } + } +} diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/nativ/NativeBlasTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/nativ/NativeBlasTests.java index 55d90a6aa..e8c016485 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/nativ/NativeBlasTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/nativ/NativeBlasTests.java @@ -256,7 +256,7 @@ public class NativeBlasTests extends BaseNd4jTest { val exp = A.mmul(B); - log.info("exp: {}", exp); +// log.info("exp: {}", exp); // ? assertEquals(exp, res); @@ -284,7 +284,7 @@ public class NativeBlasTests extends BaseNd4jTest { val exp = A.mmul(B); - log.info("exp mean: {}", exp.meanNumber()); +// log.info("exp mean: {}", exp.meanNumber()); // ? assertEquals(exp, res); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/nativ/OpsMappingTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/nativ/OpsMappingTests.java index ca9de0252..454739496 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/nativ/OpsMappingTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/nativ/OpsMappingTests.java @@ -59,9 +59,9 @@ public class OpsMappingTests extends BaseNd4jTest { return 'c'; } - @Test - public void testCustomOpsMapping() { - Nd4j.create(1); + @Override + public long getTimeoutMilliseconds() { + return 90000L; } @Test diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/DerivativeTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/DerivativeTests.java index 5a51b847d..1a70fa6c1 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/DerivativeTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/DerivativeTests.java @@ -176,9 +176,9 @@ public class DerivativeTests extends BaseNd4jTest { INDArray z = Transforms.hardSigmoid(xArr, true); INDArray zPrime = Nd4j.getExecutioner().exec(new HardSigmoidDerivative(xArr.dup())); - System.out.println(xArr); - System.out.println(z); - System.out.println(zPrime); +// System.out.println(xArr); +// System.out.println(z); +// System.out.println(zPrime); for (int i = 0; i < expHSOut.length; i++) { double relErrorHS = diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTests.java index e04250f69..0fc085abe 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTests.java @@ -111,7 +111,7 @@ public class OpExecutionerTests extends BaseNd4jTest { new EuclideanDistance(distanceInputRow, distanceComp, result, 0)); INDArray euclideanAssertion = Nd4j.ones(4).castTo(DataType.DOUBLE); assertEquals(euclideanAssertion, result); - System.out.println(result); +// System.out.println(result); } @@ -309,7 +309,7 @@ public class OpExecutionerTests extends BaseNd4jTest { val arr = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(1, -1); val softMax = new SoftMax(arr); opExecutioner.exec((CustomOp) softMax); - assertEquals(getFailureMessage(), 1.0, softMax.outputArguments()[0].sumNumber().doubleValue(), 1e-1); + assertEquals(getFailureMessage(), 1.0, softMax.outputArguments().get(0).sumNumber().doubleValue(), 1e-1); } @@ -517,10 +517,10 @@ public class OpExecutionerTests extends BaseNd4jTest { 0.27320877, 0.29476917, 0.29449323, 0.29720396, 0.31319344, 0.2803108, 0.28671616, 0.30462897, 0.3049033, 0.29277474, 0.29136384, 0.30316526, 0.2807459}, new int[] {150, 3}, 'f'); - System.out.println("Data:" + input.data().length()); +// System.out.println("Data:" + input.data().length()); val softMax = new SoftMax(input); Nd4j.getExecutioner().exec((CustomOp) softMax); - assertEquals(assertion, softMax.outputArguments()[0]); + assertEquals(assertion, softMax.outputArguments().get(0)); } @@ -559,7 +559,7 @@ public class OpExecutionerTests extends BaseNd4jTest { INDArray arr = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(1, -1); val softMax = new SoftMax(arr); opExecutioner.exec((CustomOp) softMax); - assertEquals(getFailureMessage(), 1.0, softMax.outputArguments()[0].sumNumber().doubleValue(), 1e-1); + assertEquals(getFailureMessage(), 1.0, softMax.outputArguments().get(0).sumNumber().doubleValue(), 1e-1); } @Test @@ -589,12 +589,12 @@ public class OpExecutionerTests extends BaseNd4jTest { @Test public void testMeanSumSimple() { - System.out.println("3d"); +// System.out.println("3d"); INDArray arr = Nd4j.ones(1, 4, 4); assertEquals(Nd4j.ones(1), arr.mean(1, 2)); assertEquals(Nd4j.ones(1).muli(16), arr.sum(1, 2)); - System.out.println("4d"); +// System.out.println("4d"); INDArray arr4 = Nd4j.ones(1, 1, 4, 4); INDArray arr4m = arr4.mean(2, 3); INDArray arr4s = arr4.sum(2, 3); @@ -603,7 +603,7 @@ public class OpExecutionerTests extends BaseNd4jTest { for (int i = 0; i < arr4s.length(); i++) assertEquals(arr4s.getDouble(i), 16, 1e-1); - System.out.println("5d"); +// System.out.println("5d"); INDArray arr5 = Nd4j.ones(1, 1, 4, 4, 4); INDArray arr5m = arr5.mean(2, 3); INDArray arr5s = arr5.sum(2, 3); @@ -611,7 +611,7 @@ public class OpExecutionerTests extends BaseNd4jTest { assertEquals(arr5m.getDouble(i), 1, 1e-1); for (int i = 0; i < arr5s.length(); i++) assertEquals(arr5s.getDouble(i), 16, 1e-1); - System.out.println("6d"); +// System.out.println("6d"); INDArray arr6 = Nd4j.ones(1, 1, 4, 4, 4, 4); INDArray arr6Tad = arr6.tensorAlongDimension(0, 2, 3); INDArray arr6s = arr6.sum(2, 3); @@ -629,7 +629,7 @@ public class OpExecutionerTests extends BaseNd4jTest { INDArray arr6 = Nd4j.ones(1, 1, 4, 4, 4, 4); INDArray arr6s = arr6.sum(2, 3); - System.out.println("Arr6s: " + arr6.length()); +// System.out.println("Arr6s: " + arr6.length()); for (int i = 0; i < arr6s.length(); i++) assertEquals(16, arr6s.getDouble(i), 1e-1); } @@ -659,10 +659,10 @@ public class OpExecutionerTests extends BaseNd4jTest { } assertEquals("Failed for [" + order + "] order", exp, arr6s); - System.out.println("ORDER: " + order); - for (int i = 0; i < 6; i++) { - System.out.println(arr6s.getDouble(i)); - } +// System.out.println("ORDER: " + order); +// for (int i = 0; i < 6; i++) { +// System.out.println(arr6s.getDouble(i)); +// } } } finally { Nd4j.factory().setOrder(origOrder); @@ -727,8 +727,8 @@ public class OpExecutionerTests extends BaseNd4jTest { DropOut dropOut = new DropOut(array, result, 0.05); Nd4j.getExecutioner().exec(dropOut); - System.out.println("Src array: " + array); - System.out.println("Res array: " + result); +// System.out.println("Src array: " + array); +// System.out.println("Res array: " + result); assertNotEquals(array, result); } @@ -741,8 +741,8 @@ public class OpExecutionerTests extends BaseNd4jTest { DropOutInverted dropOut = new DropOutInverted(array, result, 0.65); Nd4j.getExecutioner().exec(dropOut); - System.out.println("Src array: " + array); - System.out.println("Res array: " + result); +// System.out.println("Src array: " + array); +// System.out.println("Res array: " + result); assertNotEquals(array, result); } @@ -778,8 +778,8 @@ public class OpExecutionerTests extends BaseNd4jTest { assertEquals(5, result.columns()); assertEquals(assertion, result); - System.out.println(assertion.toString()); - System.out.println(result.toString()); +// System.out.println(assertion.toString()); +// System.out.println(result.toString()); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTestsC.java index 72be040c5..4e16544b8 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTestsC.java @@ -126,7 +126,7 @@ public class OpExecutionerTestsC extends BaseNd4jTest { @Test public void testBroadcastMultiDim() { INDArray data = Nd4j.linspace(1, 30, 30, DataType.DOUBLE).reshape(2, 3, 5); - System.out.println(data); +// System.out.println(data); INDArray mask = Nd4j.create(new double[][] {{1.00, 1.00, 1.00, 1.00, 1.00}, {1.00, 1.00, 1.00, 0.00, 0.00}}); Nd4j.getExecutioner().exec(new BroadcastMulOp(data, mask, data, 0, 2)); INDArray assertion = Nd4j.create(new double[] {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, @@ -326,7 +326,7 @@ public class OpExecutionerTestsC extends BaseNd4jTest { INDArray arr = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(1, -1); val softMax = new SoftMax(arr); opExecutioner.exec((CustomOp) softMax); - assertEquals(getFailureMessage(), 1.0, softMax.outputArguments()[0].sumNumber().doubleValue(), 1e-1); + assertEquals(getFailureMessage(), 1.0, softMax.outputArguments().get(0).sumNumber().doubleValue(), 1e-1); } @Test @@ -342,7 +342,8 @@ public class OpExecutionerTestsC extends BaseNd4jTest { public void testTad() { INDArray arr = Nd4j.linspace(1, 12, 12, DataType.DOUBLE).reshape(2, 3, 2); for (int i = 0; i < arr.tensorsAlongDimension(0); i++) { - System.out.println(arr.tensorAlongDimension(i, 0)); +// System.out.println(arr.tensorAlongDimension(i, 0)); + arr.tensorAlongDimension(i, 0); } } @@ -425,12 +426,12 @@ public class OpExecutionerTestsC extends BaseNd4jTest { INDArray arr = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(1, -1); val softMax = new SoftMax(arr); opExecutioner.exec((CustomOp) softMax); - assertEquals(getFailureMessage(), 1.0, softMax.outputArguments()[0].sumNumber().doubleValue(), 1e-1); + assertEquals(getFailureMessage(), 1.0, softMax.outputArguments().get(0).sumNumber().doubleValue(), 1e-1); INDArray linspace = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(2, 3); val softmax = new SoftMax(linspace.dup()); Nd4j.getExecutioner().exec((CustomOp) softmax); - assertEquals(linspace.rows(), softmax.outputArguments()[0].sumNumber().doubleValue(), 1e-1); + assertEquals(linspace.rows(), softmax.outputArguments().get(0).sumNumber().doubleValue(), 1e-1); } @@ -439,7 +440,7 @@ public class OpExecutionerTestsC extends BaseNd4jTest { INDArray linspace = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(2, 3); val max = new SoftMax(linspace); Nd4j.getExecutioner().exec((CustomOp) max); - linspace.assign(max.outputArguments()[0]); + linspace.assign(max.outputArguments().get(0)); assertEquals(getFailureMessage(), linspace.getRow(0).sumNumber().doubleValue(), 1.0, 1e-1); } @@ -503,12 +504,12 @@ public class OpExecutionerTestsC extends BaseNd4jTest { @Test public void testMeanSumSimple() { - System.out.println("3d"); +// System.out.println("3d"); INDArray arr = Nd4j.ones(1, 4, 4); assertEquals(Nd4j.ones(1), arr.mean(1, 2)); assertEquals(Nd4j.ones(1).muli(16), arr.sum(1, 2)); - System.out.println("4d"); +// System.out.println("4d"); INDArray arr4 = Nd4j.ones(1, 1, 4, 4); INDArray arr4m = arr4.mean(2, 3); INDArray arr4s = arr4.sum(2, 3); @@ -516,7 +517,7 @@ public class OpExecutionerTestsC extends BaseNd4jTest { assertEquals(arr4m.getDouble(i), 1, 1e-1); for (int i = 0; i < arr4s.length(); i++) assertEquals(arr4s.getDouble(i), 16, 1e-1); - System.out.println("5d"); +// System.out.println("5d"); INDArray arr5 = Nd4j.ones(1, 1, 4, 4, 4); INDArray arr5s = arr5.sum(2, 3); for (int i = 0; i < arr5s.length(); i++) @@ -525,7 +526,7 @@ public class OpExecutionerTestsC extends BaseNd4jTest { for (int i = 0; i < arr5m.length(); i++) assertEquals(1, arr5m.getDouble(i), 1e-1); - System.out.println("6d"); +// System.out.println("6d"); INDArray arr6 = Nd4j.ones(1, 1, 4, 4, 4, 4); INDArray arr6m = arr6.mean(2, 3); for (int i = 0; i < arr6m.length(); i++) @@ -590,17 +591,17 @@ public class OpExecutionerTestsC extends BaseNd4jTest { @Test public void testSum5d() throws Exception { - System.out.println("5d"); +// System.out.println("5d"); INDArray arr5 = Nd4j.ones(1, 1, 4, 4, 4); INDArray arr5s = arr5.sum(2, 3); Thread.sleep(1000); - System.out.println("5d length: " + arr5s.length()); +// System.out.println("5d length: " + arr5s.length()); for (int i = 0; i < arr5s.length(); i++) assertEquals(16, arr5s.getDouble(i), 1e-1); INDArray arrF = Nd4j.ones(1, 1, 4, 4, 4); - System.out.println("A: " + arrF); +// System.out.println("A: " + arrF); } @@ -643,9 +644,9 @@ public class OpExecutionerTestsC extends BaseNd4jTest { INDArray cOrder = Nd4j.create(new int[] {2, 2}, 'c').assign(toAssign); INDArray fOrder = Nd4j.create(new int[] {2, 2}, 'f').assign(toAssign); - System.out.println(cOrder); - System.out.println(cOrder.sum(0)); //[2,4] -> correct - System.out.println(fOrder.sum(0)); //[2,3] -> incorrect +// System.out.println(cOrder); +// System.out.println(cOrder.sum(0)); //[2,4] -> correct +// System.out.println(fOrder.sum(0)); //[2,3] -> incorrect assertEquals(cOrder, fOrder); assertEquals(cOrder.sum(0), fOrder.sum(0)); @@ -908,7 +909,7 @@ public class OpExecutionerTestsC extends BaseNd4jTest { assertEquals(xDup, x); - log.info("bins: {}", z); +// log.info("bins: {}", z); assertEquals(zExp, z); } @@ -931,8 +932,8 @@ public class OpExecutionerTestsC extends BaseNd4jTest { expManhattanDistance += Math.abs(diff); } double expectedEuclidean = Math.sqrt(sumSquaredDiff); - System.out.println("Expected, Euclidean: " + expectedEuclidean); - System.out.println("Expected, Manhattan: " + expManhattanDistance); +// System.out.println("Expected, Euclidean: " + expectedEuclidean); +// System.out.println("Expected, Manhattan: " + expManhattanDistance); int mb = 2; INDArray firstOrig = Nd4j.create(mb, 2, 2, 2); @@ -959,14 +960,14 @@ public class OpExecutionerTestsC extends BaseNd4jTest { INDArray outManhattan = Nd4j.getExecutioner().exec(new ManhattanDistance(first, second, 1, 2, 3)); - System.out.println("\n\nOrder: " + order); - System.out.println("Euclidean:"); +// System.out.println("\n\nOrder: " + order); +// System.out.println("Euclidean:"); //System.out.println(Arrays.toString(out.getRow(0).dup().data().asDouble())); //System.out.println(Arrays.toString(out.getRow(1).dup().data().asDouble())); assertEquals(out.getDouble(0), out.getDouble(1), 1e-5); - System.out.println("Manhattan:"); +// System.out.println("Manhattan:"); //System.out.println(Arrays.toString(outManhattan.getRow(0).dup().data().asDouble())); //System.out.println(Arrays.toString(outManhattan.getRow(1).dup().data().asDouble())); @@ -1017,7 +1018,7 @@ public class OpExecutionerTestsC extends BaseNd4jTest { for (int i = 0; i < 32; i++) { INDArray tensor = array.tensorAlongDimension(i, 1, 2); - log.info("tad {}: {}", i, array.getDouble(0)); +// log.info("tad {}: {}", i, array.getDouble(0)); assertEquals((float) (100 + i) * (100 * 100), tensor.sumNumber().floatValue(), 0.001f); assertEquals((float) 100 + i, tensor.meanNumber().floatValue(), 0.001f); } @@ -1076,7 +1077,7 @@ public class OpExecutionerTestsC extends BaseNd4jTest { INDArray pile = Nd4j.pile(arrays); - log.info("Pile: {}", pile); +// log.info("Pile: {}", pile); INDArray[] tears = Nd4j.tear(pile, 1, 2); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/profiling/OperationProfilerTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/profiling/OperationProfilerTests.java index d0c61de9b..95bd8a649 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/profiling/OperationProfilerTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/profiling/OperationProfilerTests.java @@ -125,7 +125,7 @@ public class OperationProfilerTests extends BaseNd4jTest { OpProfiler.PenaltyCause[] causes = OpProfiler.getInstance().processOperands(x, y); - log.info("Causes: {}", Arrays.toString(causes)); +// log.info("Causes: {}", Arrays.toString(causes)); assertEquals(1, causes.length); assertTrue(ArrayUtils.contains(causes, OpProfiler.PenaltyCause.MIXED_ORDER)); //assertTrue(ArrayUtils.contains(causes, OpProfiler.PenaltyCause.NON_EWS_ACCESS)); @@ -139,7 +139,7 @@ public class OperationProfilerTests extends BaseNd4jTest { OpProfiler.PenaltyCause[] causes = OpProfiler.getInstance().processOperands(x, y, z); - log.info("Causes: {}", Arrays.toString(causes)); +// log.info("Causes: {}", Arrays.toString(causes)); assertEquals(1, causes.length); assertTrue(ArrayUtils.contains(causes, OpProfiler.PenaltyCause.MIXED_ORDER)); //assertTrue(ArrayUtils.contains(causes, OpProfiler.PenaltyCause.NON_EWS_ACCESS)); @@ -154,7 +154,7 @@ public class OperationProfilerTests extends BaseNd4jTest { OpProfiler.PenaltyCause[] causes = OpProfiler.getInstance().processOperands(w, x, y, z); - log.info("Causes: {}", Arrays.toString(causes)); +// log.info("Causes: {}", Arrays.toString(causes)); assertEquals(1, causes.length); assertTrue(ArrayUtils.contains(causes, OpProfiler.PenaltyCause.MIXED_ORDER)); } @@ -167,7 +167,7 @@ public class OperationProfilerTests extends BaseNd4jTest { OpProfiler.PenaltyCause[] causes = OpProfiler.getInstance().processOperands(x, y); - log.info("Causes: {}", Arrays.toString(causes)); +// log.info("Causes: {}", Arrays.toString(causes)); assertEquals(1, causes.length); assertTrue(ArrayUtils.contains(causes, OpProfiler.PenaltyCause.STRIDED_ACCESS)); } @@ -181,7 +181,7 @@ public class OperationProfilerTests extends BaseNd4jTest { OpProfiler.PenaltyCause[] causes = OpProfiler.getInstance().processTADOperands(pair.getFirst()); - log.info("Causes: {}", Arrays.toString(causes)); +// log.info("Causes: {}", Arrays.toString(causes)); assertEquals(1, causes.length); assertTrue(ArrayUtils.contains(causes, OpProfiler.PenaltyCause.TAD_NON_EWS_ACCESS)); } @@ -195,7 +195,7 @@ public class OperationProfilerTests extends BaseNd4jTest { OpProfiler.PenaltyCause[] causes = OpProfiler.getInstance().processTADOperands(pair.getFirst()); - log.info("Causes: {}", Arrays.toString(causes)); +// log.info("Causes: {}", Arrays.toString(causes)); assertEquals(1, causes.length); assertTrue(ArrayUtils.contains(causes, OpProfiler.PenaltyCause.TAD_NON_EWS_ACCESS)); } @@ -211,7 +211,7 @@ public class OperationProfilerTests extends BaseNd4jTest { OpProfiler.PenaltyCause[] causes = OpProfiler.getInstance().processTADOperands(pair.getFirst()); - log.info("Causes: {}", Arrays.toString(causes)); +// log.info("Causes: {}", Arrays.toString(causes)); assertEquals(1, causes.length); assertTrue(ArrayUtils.contains(causes, OpProfiler.PenaltyCause.TAD_NON_EWS_ACCESS)); } @@ -225,8 +225,8 @@ public class OperationProfilerTests extends BaseNd4jTest { OpProfiler.PenaltyCause[] causes = OpProfiler.getInstance().processTADOperands(pair.getFirst()); - log.info("TAD: {}", Arrays.toString(pair.getFirst().asInt())); - log.info("Causes: {}", Arrays.toString(causes)); +// log.info("TAD: {}", Arrays.toString(pair.getFirst().asInt())); +// log.info("Causes: {}", Arrays.toString(causes)); assertEquals(1, causes.length); assertTrue(ArrayUtils.contains(causes, OpProfiler.PenaltyCause.NONE)); } @@ -239,8 +239,8 @@ public class OperationProfilerTests extends BaseNd4jTest { OpProfiler.PenaltyCause[] causes = OpProfiler.getInstance().processTADOperands(pair.getFirst()); - log.info("TAD: {}", Arrays.toString(pair.getFirst().asInt())); - log.info("Causes: {}", Arrays.toString(causes)); +// log.info("TAD: {}", Arrays.toString(pair.getFirst().asInt())); +// log.info("Causes: {}", Arrays.toString(causes)); assertEquals(1, causes.length); assertTrue(ArrayUtils.contains(causes, OpProfiler.PenaltyCause.TAD_STRIDED_ACCESS)); } @@ -412,7 +412,7 @@ public class OperationProfilerTests extends BaseNd4jTest { val avgA = (nanosB - nanosA) / iterations; - log.info("A: {}; B: {}", avgA, avgB); +// log.info("A: {}; B: {}", avgA, avgB); } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/profiling/PerformanceTrackerTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/profiling/PerformanceTrackerTests.java index 4d48a6a98..093fc2ac1 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/profiling/PerformanceTrackerTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/profiling/PerformanceTrackerTests.java @@ -20,6 +20,7 @@ import lombok.extern.slf4j.Slf4j; import lombok.val; import org.junit.After; import org.junit.Before; +import org.junit.Ignore; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; @@ -105,6 +106,7 @@ public class PerformanceTrackerTests extends BaseNd4jTest { } @Test + @Ignore public void testTrackerCpu_1() { if (!Nd4j.getExecutioner().getClass().getCanonicalName().toLowerCase().contains("native")) return; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/rng/RandomTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/rng/RandomTests.java index ed8f4d441..8a06bd7e9 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/rng/RandomTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/rng/RandomTests.java @@ -31,8 +31,7 @@ import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.util.DataTypeUtil; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition; -import org.nd4j.linalg.api.ops.random.custom.DistributionUniform; -import org.nd4j.linalg.api.ops.random.custom.RandomBernoulli; +import org.nd4j.linalg.api.ops.random.custom.*; import org.nd4j.linalg.api.ops.random.impl.*; import org.nd4j.linalg.api.rng.DefaultRandom; import org.nd4j.linalg.api.rng.Random; @@ -90,7 +89,7 @@ public class RandomTests extends BaseNd4jTest { Nd4j.createUninitialized(shape, Nd4j.order()), mean, standardDeviation), Nd4j.getRandom()); - log.info("arr: {}", arr.data().asDouble()); +// log.info("arr: {}", arr.data().asDouble()); assertEquals(exp, arr); } @@ -107,8 +106,8 @@ public class RandomTests extends BaseNd4jTest { UniformDistribution distribution2 = new UniformDistribution(z2, 1.0, 2.0); Nd4j.getExecutioner().exec(distribution2, random2); - System.out.println("Data: " + z1); - System.out.println("Data: " + z2); +// System.out.println("Data: " + z1); +// System.out.println("Data: " + z2); for (int e = 0; e < z1.length(); e++) { double val = z1.getDouble(e); assertTrue(val >= 1.0 && val <= 2.0); @@ -137,8 +136,8 @@ public class RandomTests extends BaseNd4jTest { log.info("States cpu: {}/{}", random1.rootState(), random1.nodeState()); - System.out.println("Data: " + z1); - System.out.println("Data: " + z2); +// System.out.println("Data: " + z1); +// System.out.println("Data: " + z2); for (int e = 0; e < z1.length(); e++) { double val = z1.getDouble(e); assertTrue(val >= 1.0 && val <= 2.0); @@ -158,8 +157,8 @@ public class RandomTests extends BaseNd4jTest { UniformDistribution distribution2 = new UniformDistribution(z2, 1.0, 2.0); Nd4j.getExecutioner().exec(distribution2, random1); - System.out.println("Data: " + z1); - System.out.println("Data: " + z2); +// System.out.println("Data: " + z1); +// System.out.println("Data: " + z2); assertNotEquals(z1, z2); } @@ -405,7 +404,7 @@ public class RandomTests extends BaseNd4jTest { Distribution nd = new NormalDistribution(random1, 0.0, 1.0); Nd4j.sort(z1, true); - System.out.println("Data for Anderson-Darling: " + z1); +// System.out.println("Data for Anderson-Darling: " + z1); for (int i = 0; i < n; i++) { @@ -435,7 +434,7 @@ public class RandomTests extends BaseNd4jTest { Random random1 = Nd4j.getRandomFactory().getNewRandomInstance(119); - log.info("1: ----------------"); +// log.info("1: ----------------"); INDArray z0 = Nd4j.getExecutioner().exec(new GaussianDistribution(Nd4j.createUninitialized(DataType.DOUBLE, 1000000), 0.0, 1.0)); @@ -444,7 +443,7 @@ public class RandomTests extends BaseNd4jTest { random1.setSeed(119); - log.info("2: ----------------"); +// log.info("2: ----------------"); INDArray z1 = Nd4j.zeros(DataType.DOUBLE, 55000000); INDArray z2 = Nd4j.zeros(DataType.DOUBLE, 55000000); @@ -452,16 +451,16 @@ public class RandomTests extends BaseNd4jTest { GaussianDistribution op1 = new GaussianDistribution(z1, 0.0, 1.0); Nd4j.getExecutioner().exec(op1, random1); - log.info("2: ----------------"); +// log.info("2: ----------------"); //log.info("End: [{}, {}, {}, {}]", z1.getFloat(29000000), z1.getFloat(29000001), z1.getFloat(29000002), z1.getFloat(29000003)); //log.info("Sum: {}", z1.sumNumber().doubleValue()); - log.info("Sum2: {}", z2.sumNumber().doubleValue()); +// log.info("Sum2: {}", z2.sumNumber().doubleValue()); INDArray match = Nd4j.getExecutioner().exec(new MatchCondition(z1, Conditions.isNan())); - log.info("NaNs: {}", match); +// log.info("NaNs: {}", match); assertEquals(0.0f, match.getFloat(0), 0.01f); /* @@ -482,14 +481,14 @@ public class RandomTests extends BaseNd4jTest { public void testSum_119() { INDArray z2 = Nd4j.zeros(DataType.DOUBLE, 55000000); val sum = z2.sumNumber().doubleValue(); - log.info("Sum2: {}", sum); +// log.info("Sum2: {}", sum); assertEquals(0.0, sum, 1e-5); } @Test public void testLegacyDistribution1() { NormalDistribution distribution = new NormalDistribution(new DefaultRandom(), 0.0, 1.0); - INDArray z1 = distribution.sample(new int[] {1, 30000000}); + INDArray z1 = distribution.sample(new int[] {1, 1000000}); assertEquals(0.0, z1.meanNumber().doubleValue(), 0.01); assertEquals(1.0, z1.stdNumber().doubleValue(), 0.01); @@ -1372,7 +1371,7 @@ public class RandomTests extends BaseNd4jTest { val array = dist.sample(new int[] {6, 9}); - log.info("Array: {}", array); +// log.info("Array: {}", array); } @Test @@ -1381,7 +1380,7 @@ public class RandomTests extends BaseNd4jTest { val array = dist.sample(new int[] {9, 6}); - log.info("Array: {}", array); +// log.info("Array: {}", array); } @Test @@ -1390,7 +1389,7 @@ public class RandomTests extends BaseNd4jTest { val array = dist.sample(new int[] {9, 9}); - log.info("Array: {}", array); +// log.info("Array: {}", array); } @Test @@ -1399,7 +1398,7 @@ public class RandomTests extends BaseNd4jTest { int numBatches = 1; for( int t=0; t<10; t++ ) { - System.out.println(t); +// System.out.println(t); numBatches = t; List initial = getList(numBatches); @@ -1426,7 +1425,7 @@ public class RandomTests extends BaseNd4jTest { Nd4j.getRandom().setSeed(12345); INDArray arr = Nd4j.create(DataType.DOUBLE, 100); Nd4j.exec(new BernoulliDistribution(arr, 0.5)); - System.out.println(arr); +// System.out.println(arr); double sum = arr.sumNumber().doubleValue(); assertTrue(String.valueOf(sum), sum > 0.0 && sum < 100.0); } @@ -1473,6 +1472,44 @@ public class RandomTests extends BaseNd4jTest { assertEquals(out1, out2); } + @Test + public void testGamma(){ + Nd4j.getRandom().setSeed(12345); + INDArray shape = Nd4j.createFromArray(new int[] {1,3}); + INDArray alpha = Nd4j.rand(1,3); + val randomGamma = new RandomGamma(shape, alpha, null); + INDArray[] res = Nd4j.exec(randomGamma); + + val randomGamma1 = new RandomGamma(shape, alpha, null); + INDArray[] res1 = Nd4j.exec(randomGamma1); + assertEquals(res[0], res1[0]); + } + + @Test + public void testPoisson(){ + Nd4j.getRandom().setSeed(12345); + INDArray shape = Nd4j.createFromArray(new int[] {1,3}); + INDArray alpha = Nd4j.rand(1,3); + val randomPoisson = new RandomPoisson(shape, alpha); + INDArray[] res = Nd4j.exec(randomPoisson); + + val randomPoisson1 = new RandomPoisson(shape, alpha); + INDArray[] res1 = Nd4j.exec(randomPoisson1); + assertEquals(res[0], res1[0]); + } + + @Test + public void testShuffle(){ + Nd4j.getRandom().setSeed(12345); + INDArray alpha = Nd4j.rand(1,3); + val randomShuffle = new RandomShuffle(alpha); + INDArray[] res = Nd4j.exec(randomShuffle); + + val randomShuffle1 = new RandomShuffle(alpha); + INDArray[] res1 = Nd4j.exec(randomShuffle1); + assertEquals(res[0], res1[0]); + } + @Override public char ordering() { return 'c'; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/serde/LargeSerDeTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/serde/LargeSerDeTests.java index b7516a9f2..7d23a021a 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/serde/LargeSerDeTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/serde/LargeSerDeTests.java @@ -29,6 +29,7 @@ import org.nd4j.linalg.factory.Nd4jBackend; import java.io.*; +import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; @RunWith(Parameterized.class) diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/serde/NumpyFormatTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/serde/NumpyFormatTests.java index 164760dc0..bf7286b91 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/serde/NumpyFormatTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/serde/NumpyFormatTests.java @@ -38,8 +38,7 @@ import java.nio.charset.StandardCharsets; import java.util.Arrays; import java.util.Map; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.Assert.*; @Slf4j public class NumpyFormatTests extends BaseNd4jTest { @@ -70,7 +69,7 @@ public class NumpyFormatTests extends BaseNd4jTest { int lastDot = path.lastIndexOf('.'); int lastUnderscore = path.lastIndexOf('_'); String dtype = path.substring(lastUnderscore+1, lastDot); - System.out.println(path + " : " + dtype); +// System.out.println(path + " : " + dtype); DataType dt = DataType.fromNumpy(dtype); //System.out.println(dt); @@ -120,7 +119,7 @@ public class NumpyFormatTests extends BaseNd4jTest { int lastDot = path.lastIndexOf('.'); int lastUnderscore = path.lastIndexOf('_'); String dtype = path.substring(lastUnderscore+1, lastDot); - System.out.println(path + " : " + dtype); +// System.out.println(path + " : " + dtype); DataType dt = DataType.fromNumpy(dtype); //System.out.println(dt); @@ -173,7 +172,7 @@ public class NumpyFormatTests extends BaseNd4jTest { int lastDot = path.lastIndexOf('.'); int lastSlash = Math.max(path.lastIndexOf('/'), path.lastIndexOf('\\')); String dtype = path.substring(lastSlash+1, lastDot); - System.out.println(path + " : " + dtype); +// System.out.println(path + " : " + dtype); DataType dt = DataType.fromNumpy(dtype); //System.out.println(dt); @@ -236,7 +235,7 @@ public class NumpyFormatTests extends BaseNd4jTest { int lastDot = path.lastIndexOf('.'); int lastUnderscore = path.lastIndexOf('_'); String dtype = path.substring(lastUnderscore + 1, lastDot); - System.out.println(path + " : " + dtype); +// System.out.println(path + " : " + dtype); DataType dt = DataType.fromNumpy(dtype); //System.out.println(dt); @@ -322,8 +321,8 @@ public class NumpyFormatTests extends BaseNd4jTest { @Test public void testNumpyBoolean() { INDArray out = Nd4j.createFromNpyFile(new File("c:/Users/raver/Downloads/error2.npy")); - System.out.println(ArrayUtil.toList(ArrayUtil.toInts(out.shape()))); - System.out.println(out); +// System.out.println(ArrayUtil.toList(ArrayUtil.toInts(out.shape()))); +// System.out.println(out); } @Override diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/LongShapeTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/LongShapeTests.java index b96ebaca9..1eaac2e60 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/LongShapeTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/LongShapeTests.java @@ -25,6 +25,7 @@ import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; +import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; /** diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/NDArrayMathTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/NDArrayMathTests.java index 521515b5f..aa6ce104b 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/NDArrayMathTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/NDArrayMathTests.java @@ -116,10 +116,11 @@ public class NDArrayMathTests extends BaseNd4jTest { INDArray otherTest = Nd4j.linspace(1, 144, 144, DataType.DOUBLE).reshape(6, 3, 2, 2, 2); - System.out.println(otherTest); +// System.out.println(otherTest); INDArray baseArr = Nd4j.linspace(1, 8, 8, DataType.DOUBLE).reshape(2, 2, 2); for (int i = 0; i < baseArr.tensorsAlongDimension(0, 1); i++) { - System.out.println(NDArrayMath.sliceOffsetForTensor(i, baseArr, new int[] {2, 2})); +// System.out.println(NDArrayMath.sliceOffsetForTensor(i, baseArr, new int[] {2, 2})); + NDArrayMath.sliceOffsetForTensor(i, baseArr, new int[] {2, 2}); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/ShapeTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/ShapeTestsC.java index f4f3e67f2..373791a2c 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/ShapeTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/ShapeTestsC.java @@ -126,7 +126,7 @@ public class ShapeTestsC extends BaseNd4jTest { INDArray matrix = Nd4j.create(new double[][] {{1, 2}, {3, 4}}); for (int i = 0; i < matrix.rows(); i++) { INDArray row = matrix.getRow(i); - System.out.println(matrix.getRow(i)); +// System.out.println(matrix.getRow(i)); } matrix.putRow(1, Nd4j.create(new double[] {1, 2})); assertEquals(matrix.getRow(0), matrix.getRow(1)); @@ -187,9 +187,9 @@ public class ShapeTestsC extends BaseNd4jTest { INDArray slice = nd.slice(1, 0); INDArray vector = slice; - for (int i = 0; i < vector.length(); i++) { - System.out.println(vector.getDouble(i)); - } +// for (int i = 0; i < vector.length(); i++) { +// System.out.println(vector.getDouble(i)); +// } assertEquals(Nd4j.create(new double[] {4, 5, 6}), vector); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/StaticShapeTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/StaticShapeTests.java index 4a019c251..28e61a9eb 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/StaticShapeTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/StaticShapeTests.java @@ -34,6 +34,7 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; /** diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/TADTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/TADTests.java index 6f47d00da..2953e2677 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/TADTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/TADTests.java @@ -94,13 +94,13 @@ public class TADTests extends BaseNd4jTest { } } - log.info("3D TADs:"); +// log.info("3D TADs:"); for (char o : order) { INDArray array = Nd4j.create(new int[] {9, 7, 5, 3}, o); for (int[] shape : dim_3) { Arrays.sort(shape); - log.info("About to do shape: " + Arrays.toString(shape) + " for array of shape " - + array.shapeInfoToString()); +// log.info("About to do shape: " + Arrays.toString(shape) + " for array of shape " +// + array.shapeInfoToString()); INDArray assertion = array.tensorAlongDimension(0, shape); INDArray test = array.tensorAlongDimension(0, shape); assertEquals(assertion, test); @@ -128,10 +128,10 @@ public class TADTests extends BaseNd4jTest { Pair tadBuffersC = Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(arrayC, 2, 3); - log.info("Got TADShapeF: {}", Arrays.toString(tadBuffersF.getFirst().asInt()) + " with java " - + javaFTad.shapeInfoDataBuffer()); - log.info("Got TADShapeC: {}", Arrays.toString(tadBuffersC.getFirst().asInt()) + " with java " - + javaCTad.shapeInfoDataBuffer()); +// log.info("Got TADShapeF: {}", Arrays.toString(tadBuffersF.getFirst().asInt()) + " with java " +// + javaFTad.shapeInfoDataBuffer()); +// log.info("Got TADShapeC: {}", Arrays.toString(tadBuffersC.getFirst().asInt()) + " with java " +// + javaCTad.shapeInfoDataBuffer()); } @Test diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/concat/ConcatTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/concat/ConcatTests.java index a68883b4f..30bdbfb37 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/concat/ConcatTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/concat/ConcatTests.java @@ -87,8 +87,6 @@ public class ConcatTests extends BaseNd4jTest { assertTrue(firstRet.isColumnVector()); INDArray secondRet = Nd4j.concat(1, first, second); assertTrue(secondRet.isRowVector()); - - } @@ -138,7 +136,7 @@ public class ConcatTests extends BaseNd4jTest { assertEquals(exp, concat0); - System.out.println("1------------------------"); +// System.out.println("1------------------------"); //ConcatV2, dim 1 second = Nd4j.linspace(24, 32, 8, DataType.DOUBLE).reshape('c', 2, 1, 4); @@ -148,7 +146,7 @@ public class ConcatTests extends BaseNd4jTest { exp.put(new INDArrayIndex[] {NDArrayIndex.all(), NDArrayIndex.point(3), NDArrayIndex.all()}, second); exp.put(new INDArrayIndex[] {NDArrayIndex.all(), NDArrayIndex.interval(4, 6), NDArrayIndex.all()}, third); - System.out.println("2------------------------"); +// System.out.println("2------------------------"); INDArray concat1 = Nd4j.concat(1, first, second, third); @@ -192,7 +190,7 @@ public class ConcatTests extends BaseNd4jTest { INDArray s2 = s.getFirst().assign(second); INDArray t2 = t.getFirst().assign(third); - System.out.println("-------------------------------------------"); +// System.out.println("-------------------------------------------"); INDArray concat0 = Nd4j.concat(0, f2, s2, t2); assertEquals(exp, concat0); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/concat/ConcatTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/concat/ConcatTestsC.java index 806cf4d08..07ef7dcac 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/concat/ConcatTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/concat/ConcatTestsC.java @@ -108,7 +108,7 @@ public class ConcatTestsC extends BaseNd4jTest { assertEquals(3, result.rows()); assertEquals(10, result.columns()); - System.out.println(result); +// System.out.println(result); for (int x = 0; x < 30; x++) { assertEquals(1f, result.getFloat(x), 0.001f); @@ -124,8 +124,8 @@ public class ConcatTestsC extends BaseNd4jTest { INDArray concat1 = Nd4j.concat(1, a, b); INDArray oneAssertion = Nd4j.create(new double[][] {{1, 2, 1, 2}, {3, 4, 3, 4}}); - System.out.println("Assertion: " + Arrays.toString(oneAssertion.data().asFloat())); - System.out.println("Result: " + Arrays.toString(concat1.data().asFloat())); +// System.out.println("Assertion: " + Arrays.toString(oneAssertion.data().asFloat())); +// System.out.println("Result: " + Arrays.toString(concat1.data().asFloat())); assertEquals(oneAssertion, concat1); @@ -186,7 +186,7 @@ public class ConcatTestsC extends BaseNd4jTest { second = Nd4j.linspace(24, 32, 8, Nd4j.dataType()).reshape('c', 2, 1, 4); for (int i = 0; i < second.tensorsAlongDimension(1); i++) { INDArray secondTad = second.tensorAlongDimension(i, 1); - System.out.println(second.tensorAlongDimension(i, 1)); +// System.out.println(second.tensorAlongDimension(i, 1)); } third = Nd4j.linspace(32, 48, 16).reshape('c', 2, 2, 4); @@ -215,7 +215,7 @@ public class ConcatTestsC extends BaseNd4jTest { @Test(expected = ND4JIllegalStateException.class) public void testConcatVector() { - System.out.println(Nd4j.concat(0, Nd4j.ones(1,1000000), Nd4j.create(1, 1))); + Nd4j.concat(0, Nd4j.ones(1,1000000), Nd4j.create(1, 1)); } @Test diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/concat/padding/PaddingTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/concat/padding/PaddingTests.java index 6c98c5afb..a54cbbfbd 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/concat/padding/PaddingTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/concat/padding/PaddingTests.java @@ -25,6 +25,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; +import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; /** diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/concat/padding/PaddingTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/concat/padding/PaddingTestsC.java index 2483f03e6..7efef4551 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/concat/padding/PaddingTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/concat/padding/PaddingTestsC.java @@ -27,6 +27,7 @@ import org.nd4j.linalg.convolution.Convolution; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; +import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; /** @@ -103,7 +104,7 @@ public class PaddingTestsC extends BaseNd4jTest { long outWidth = Convolution.outSize(h, kh, sy, ph, 1, true); long outHeight = Convolution.outSize(w, kw, sx, pw, 1, true); INDArray padded = Nd4j.pad(linspaced, new int[][] {{0, 0}, {0, 0}, {ph, ph + sy - 1}, {pw, pw + sx - 1}}); - System.out.println(padded); +// System.out.println(padded); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/indexing/IndexingTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/indexing/IndexingTests.java index b67f684c7..6c3f911f4 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/indexing/IndexingTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/indexing/IndexingTests.java @@ -33,8 +33,7 @@ import org.nd4j.linalg.indexing.INDArrayIndex; import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.indexing.SpecifiedIndex; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.Assert.*; /** * @author Adam Gibson @@ -52,7 +51,7 @@ public class IndexingTests extends BaseNd4jTest { @Test public void testGet() { - System.out.println("Testing sub-array put and get with a 3D array ..."); +// System.out.println("Testing sub-array put and get with a 3D array ..."); INDArray arr = Nd4j.linspace(0, 124, 125).reshape(5, 5, 5); @@ -99,13 +98,13 @@ public class IndexingTests extends BaseNd4jTest { INDArray whatToPut = arr.get(whereToGet); assertEquals(subArr_A, whatToPut); - System.out.println(whatToPut); +// System.out.println(whatToPut); INDArrayIndex[] whereToPut = new INDArrayIndex[] {NDArrayIndex.all(), NDArrayIndex.all()}; subArr_B.put(whereToPut, whatToPut); assertEquals(subArr_A, subArr_B); - System.out.println("... done"); +// System.out.println("... done"); } /* @@ -154,7 +153,7 @@ public class IndexingTests extends BaseNd4jTest { INDArrayIndex ndi_Slice = NDArrayIndex.point(s); for (int i = 0; i < rows; i++) { for (int j = 0; j < cols; j++) { - log.info("Running for ( {}, {} - {} , {} - {} )", s, i, rows, j, cols); +// log.info("Running for ( {}, {} - {} , {} - {} )", s, i, rows, j, cols); INDArrayIndex ndi_I = NDArrayIndex.interval(i, rows); INDArrayIndex ndi_J = NDArrayIndex.interval(j, cols); INDArray aView = A.get(ndi_Slice, NDArrayIndex.all(), NDArrayIndex.all()).get(ndi_I, ndi_J); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/indexing/IndexingTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/indexing/IndexingTestsC.java index 9593b5a3b..33a64c291 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/indexing/IndexingTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/indexing/IndexingTestsC.java @@ -32,8 +32,7 @@ import org.nd4j.linalg.indexing.INDArrayIndex; import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.indexing.SpecifiedIndex; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.Assert.*; /** * @author Adam Gibson @@ -99,7 +98,7 @@ public class IndexingTestsC extends BaseNd4jTest { public void testPointIndexes() { INDArray arr = Nd4j.create(DataType.DOUBLE, 4, 3, 2); INDArray get = arr.get(NDArrayIndex.all(), NDArrayIndex.point(1), NDArrayIndex.all()); - assertArrayEquals(new int[] {4, 2}, get.shape()); + assertArrayEquals(new long[] {4, 2}, get.shape()); INDArray linspaced = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 3, 2); INDArray assertion = Nd4j.create(new double[][] {{3, 4}, {9, 10}, {15, 16}, {21, 22}}); @@ -108,7 +107,7 @@ public class IndexingTestsC extends BaseNd4jTest { INDArray sliceI = linspacedGet.slice(i); assertEquals(assertion.slice(i), sliceI); } - assertArrayEquals(new int[] {6, 1}, linspacedGet.stride()); + assertArrayEquals(new long[] {6, 1}, linspacedGet.stride()); assertEquals(assertion, linspacedGet); } @@ -131,7 +130,7 @@ public class IndexingTestsC extends BaseNd4jTest { INDArray get = padded.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(i, sy, iLim), NDArrayIndex.interval(j, sx, jLim)); - assertArrayEquals(new int[] {81, 81, 18, 2}, get.stride()); + assertArrayEquals(new long[] {81, 81, 18, 2}, get.stride()); INDArray assertion = Nd4j.create(new double[] {1, 1, 1, 1, 3, 3, 3, 3, 1, 1, 1, 1, 3, 3, 3, 3}, new int[] {1, 1, 4, 4}); assertEquals(assertion, get); @@ -143,7 +142,7 @@ public class IndexingTestsC extends BaseNd4jTest { INDArray assertion2 = Nd4j.create(new double[] {2, 2, 2, 2, 4, 4, 4, 4, 2, 2, 2, 2, 4, 4, 4, 4}, new int[] {1, 1, 4, 4}); - assertArrayEquals(new int[] {81, 81, 18, 2}, get3.stride()); + assertArrayEquals(new long[] {81, 81, 18, 2}, get3.stride()); assertEquals(assertion2, get3); @@ -154,7 +153,7 @@ public class IndexingTestsC extends BaseNd4jTest { j = 1; INDArray get2 = padded.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(i, sy, iLim), NDArrayIndex.interval(j, sx, jLim)); - assertArrayEquals(new int[] {81, 81, 18, 2}, get2.stride()); + assertArrayEquals(new long[] {81, 81, 18, 2}, get2.stride()); assertEquals(assertion, get2); @@ -171,29 +170,29 @@ public class IndexingTestsC extends BaseNd4jTest { } INDArray first10a = row.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 10)); - assertArrayEquals(first10a.shape(), new int[] {10}); + assertArrayEquals(first10a.shape(), new long[] {10}); for (int i = 0; i < 10; i++) assertTrue(first10a.getDouble(i) == i); INDArray first10b = row.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 10)); - assertArrayEquals(first10b.shape(), new int[] {10}); + assertArrayEquals(first10b.shape(), new long[] {10}); for (int i = 0; i < 10; i++) assertTrue(first10b.getDouble(i) == i); INDArray last10a = row.get(NDArrayIndex.point(0), NDArrayIndex.interval(20, 30)); - assertArrayEquals(last10a.shape(), new int[] {10}); + assertArrayEquals(last10a.shape(), new long[] {10}); for (int i = 0; i < 10; i++) assertEquals(i+20, last10a.getDouble(i), 1e-6); INDArray last10b = row.get(NDArrayIndex.point(0), NDArrayIndex.interval(20, 30)); - assertArrayEquals(last10b.shape(), new int[] {10}); + assertArrayEquals(last10b.shape(), new long[] {10}); for (int i = 0; i < 10; i++) assertTrue(last10b.getDouble(i) == 20 + i); } @Test public void testGet() { - System.out.println("Testing sub-array put and get with a 3D array ..."); +// System.out.println("Testing sub-array put and get with a 3D array ..."); INDArray arr = Nd4j.linspace(0, 124, 125).reshape(5, 5, 5); @@ -238,14 +237,14 @@ public class IndexingTestsC extends BaseNd4jTest { INDArrayIndex[] whereToGet = new INDArrayIndex[] {ndi_Slice, ndi_I, ndi_J}; INDArray whatToPut = arr.get(whereToGet); - System.out.println(whatToPut); +// System.out.println(whatToPut); INDArrayIndex[] whereToPut = new INDArrayIndex[] {NDArrayIndex.all(), NDArrayIndex.all()}; subArr_B.put(whereToPut, whatToPut); assertEquals(subArr_A, subArr_B); - System.out.println("... done"); +// System.out.println("... done"); } @Test @@ -286,7 +285,7 @@ public class IndexingTestsC extends BaseNd4jTest { INDArrayIndex ndi_Slice = NDArrayIndex.point(s); for (int i = 0; i < rows; i++) { for (int j = 0; j < cols; j++) { - log.info("Running for ( {}, {} - {} , {} - {} )", s, i, rows, j, cols); +// log.info("Running for ( {}, {} - {} , {} - {} )", s, i, rows, j, cols); INDArrayIndex ndi_I = NDArrayIndex.interval(i, rows); INDArrayIndex ndi_J = NDArrayIndex.interval(j, cols); INDArray aView = A.get(ndi_Slice, NDArrayIndex.all(), NDArrayIndex.all()).get(ndi_I, ndi_J); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/ones/LeadingAndTrailingOnes.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/ones/LeadingAndTrailingOnes.java index 275e9dcd6..9f4d9ec9b 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/ones/LeadingAndTrailingOnes.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/ones/LeadingAndTrailingOnes.java @@ -65,7 +65,7 @@ public class LeadingAndTrailingOnes extends BaseNd4jTest { INDArray arr = Nd4j.create(1, 10, 1, 1); arr.assign(1); arr.toString(); - System.out.println(arr); +// System.out.println(arr); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/ones/LeadingAndTrailingOnesC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/ones/LeadingAndTrailingOnesC.java index d5f5ac361..7c95b9bfe 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/ones/LeadingAndTrailingOnesC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/ones/LeadingAndTrailingOnesC.java @@ -41,18 +41,18 @@ public class LeadingAndTrailingOnesC extends BaseNd4jTest { public void testCreateLeadingAndTrailingOnes() { INDArray arr = Nd4j.create(1, 10, 1, 1); arr.assign(1); - System.out.println(arr); +// System.out.println(arr); } @Test public void testMatrix() { INDArray arr = Nd4j.linspace(1, 4, 4).reshape(2, 2); INDArray slice1 = arr.slice(1); - System.out.println(arr.slice(1)); +// System.out.println(arr.slice(1)); INDArray oneInMiddle = Nd4j.linspace(1, 4, 4).reshape(2, 1, 2); INDArray otherSlice = oneInMiddle.slice(1); assertEquals(2, otherSlice.offset()); - System.out.println(otherSlice); +// System.out.println(otherSlice); INDArray twoOnesInMiddle = Nd4j.linspace(1, 4, 4).reshape(2, 1, 1, 2); INDArray sub = twoOnesInMiddle.get(NDArrayIndex.point(1), NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all()); @@ -65,21 +65,7 @@ public class LeadingAndTrailingOnesC extends BaseNd4jTest { INDArray tensor = Nd4j.linspace(1, 144, 144).reshape(2, 2, 1, 1, 6, 6); INDArray tensorSlice1 = tensor.slice(1); INDArray tensorSlice1Slice1 = tensorSlice1.slice(1); - System.out.println(tensor); - } - - @Test - public void testOnesInMiddleTensor() { - INDArray im2colAssertion = Nd4j.create(new double[] {0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, - 0.0, 0.0, 1.0, 2.0, 0.0, 0.0, 0.0, 0.0, 3.0, 4.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, - 5.0, 6.0, 0.0, 0.0, 0.0, 0.0, 7.0, 8.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 9.0, 10.0, - 0.0, 0.0, 0.0, 0.0, 11.0, 12.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 13.0, 14.0, 0.0, 0.0, - 0.0, 0.0, 15.0, 16.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}, - new int[] {2, 2, 1, 1, 6, 6}); - System.out.println(im2colAssertion); +// System.out.println(tensor); } @Override diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/reshape/ReshapeTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/reshape/ReshapeTests.java index 4c40df4e3..bd698b9fd 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/reshape/ReshapeTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/reshape/ReshapeTests.java @@ -26,6 +26,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; +import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; import static org.junit.Assume.assumeNotNull; @@ -66,7 +67,7 @@ public class ReshapeTests extends BaseNd4jTest { double delta = 1e-1; INDArray arr = Nd4j.create(1, 3); INDArray reshaped = arr.reshape('f', 3, 1); - assertArrayEquals(new int[] {3, 1}, reshaped.shape()); + assertArrayEquals(new long[] {3, 1}, reshaped.shape()); assertEquals(0.0, reshaped.getDouble(1), delta); assertEquals(0.0, reshaped.getDouble(2), delta); log.info("Reshaped: {}", reshaped.shapeInfoDataBuffer().asInt()); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/slicing/SlicingTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/slicing/SlicingTestsC.java index 0e696c884..b54203af2 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/slicing/SlicingTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/slicing/SlicingTestsC.java @@ -27,6 +27,7 @@ import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.indexing.SpecifiedIndex; +import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; /** @@ -43,7 +44,8 @@ public class SlicingTestsC extends BaseNd4jTest { @Test public void testSliceRowVector() { INDArray arr = Nd4j.zeros(5); - System.out.println(arr.slice(1)); +// System.out.println(arr.slice(1)); + arr.slice(1); } @@ -51,10 +53,10 @@ public class SlicingTestsC extends BaseNd4jTest { public void testSliceAssertion() { INDArray arr = Nd4j.linspace(1, 30, 30).reshape(3, 5, 2); INDArray firstRow = arr.slice(0).slice(0); - for (int i = 0; i < firstRow.length(); i++) { - System.out.println(firstRow.getDouble(i)); - } - System.out.println(firstRow); +// for (int i = 0; i < firstRow.length(); i++) { +// System.out.println(firstRow.getDouble(i)); +// } +// System.out.println(firstRow); } @Test @@ -64,19 +66,19 @@ public class SlicingTestsC extends BaseNd4jTest { INDArray sliceZero = arr.slice(0); for (int i = 0; i < sliceZero.rows(); i++) { INDArray row = sliceZero.slice(i); - for (int j = 0; j < row.length(); j++) { - System.out.println(row.getDouble(j)); - } - System.out.println(row); +// for (int j = 0; j < row.length(); j++) { +// System.out.println(row.getDouble(j)); +// } +// System.out.println(row); } INDArray assertion = Nd4j.create(new double[] {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, new int[] {5, 2}); for (int i = 0; i < assertion.rows(); i++) { INDArray row = assertion.slice(i); - for (int j = 0; j < row.length(); j++) { - System.out.println(row.getDouble(j)); - } - System.out.println(row); +// for (int j = 0; j < row.length(); j++) { +// System.out.println(row.getDouble(j)); +// } +// System.out.println(row); } assertArrayEquals(new long[] {5, 2}, sliceZero.shape()); assertEquals(assertion, sliceZero); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/specials/SortCooTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/specials/SortCooTests.java index 33d24cd68..39f5c10a9 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/specials/SortCooTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/specials/SortCooTests.java @@ -36,6 +36,8 @@ import org.nd4j.nativeblas.NativeOpsHolder; import java.util.Arrays; import java.util.stream.LongStream; +import static org.junit.Assert.assertArrayEquals; + /** * @author raver119@gmail.com */ @@ -84,15 +86,15 @@ public class SortCooTests extends BaseNd4jTest { DataBuffer idx = Nd4j.getDataBufferFactory().createLong(indices); DataBuffer val = Nd4j.createBuffer(values); - log.info("Old indices: {}", Arrays.toString(idx.asInt())); +// log.info("Old indices: {}", Arrays.toString(idx.asInt())); NativeOpsHolder.getInstance().getDeviceNativeOps().sortCooIndices(null, (LongPointer) idx.addressPointer(), val.addressPointer(), 4, 3); - log.info("New indices: {}", Arrays.toString(idx.asInt())); +// log.info("New indices: {}", Arrays.toString(idx.asInt())); - assertArrayEquals(expIndices, idx.asInt()); + assertArrayEquals(expIndices, idx.asLong()); assertArrayEquals(expValues, val.asDouble(), 1e-5); } @@ -121,7 +123,7 @@ public class SortCooTests extends BaseNd4jTest { NativeOpsHolder.getInstance().getDeviceNativeOps().sortCooIndices(null, (LongPointer) idx.addressPointer(), val.addressPointer(), 3, 3); - assertArrayEquals(expIndices, idx.asInt()); + assertArrayEquals(expIndices, idx.asLong()); assertArrayEquals(expValues, val.asDouble(), 1e-5); } @@ -277,7 +279,7 @@ public class SortCooTests extends BaseNd4jTest { // just check the indices. sortSparseCooIndicesSort1 and sortSparseCooIndicesSort2 checks that // indices and values are both swapped. This test just makes sure index sort works for larger arrays. - assertArrayEquals(expIndices, idx.asInt()); + assertArrayEquals(expIndices, idx.asLong()); } @Override public char ordering() { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/NDArrayUtilTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/NDArrayUtilTest.java index baef37c13..daa80295f 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/NDArrayUtilTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/NDArrayUtilTest.java @@ -23,6 +23,7 @@ import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4jBackend; +import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; /** diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/ShapeTestC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/ShapeTestC.java index a2d1cf979..e5669595f 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/ShapeTestC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/ShapeTestC.java @@ -30,6 +30,7 @@ import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; +import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; /** diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/ValidationUtilTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/ValidationUtilTests.java index 98f2b7aa8..7e6e73289 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/ValidationUtilTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/ValidationUtilTests.java @@ -44,7 +44,7 @@ public class ValidationUtilTests extends BaseNd4jTest { ValidationResult vr0 = Nd4jCommonValidator.isValidFile(fNonExistent); assertFalse(vr0.isValid()); assertTrue(vr0.getIssues().get(0), vr0.getIssues().get(0).contains("exist")); - System.out.println(vr0.toString()); +// System.out.println(vr0.toString()); //Test empty file: File fEmpty = new File(f, "0.bin"); @@ -52,7 +52,7 @@ public class ValidationUtilTests extends BaseNd4jTest { ValidationResult vr1 = Nd4jCommonValidator.isValidFile(fEmpty); assertFalse(vr1.isValid()); assertTrue(vr1.getIssues().get(0), vr1.getIssues().get(0).contains("empty")); - System.out.println(vr1.toString()); +// System.out.println(vr1.toString()); //Test directory File directory = new File(f, "dir"); @@ -61,14 +61,14 @@ public class ValidationUtilTests extends BaseNd4jTest { ValidationResult vr2 = Nd4jCommonValidator.isValidFile(directory); assertFalse(vr2.isValid()); assertTrue(vr2.getIssues().get(0), vr2.getIssues().get(0).contains("directory")); - System.out.println(vr2.toString()); +// System.out.println(vr2.toString()); //Test valid non-empty file - valid File f3 = new File(f, "1.txt"); FileUtils.writeStringToFile(f3, "Test", StandardCharsets.UTF_8); ValidationResult vr3 = Nd4jCommonValidator.isValidFile(f3); assertTrue(vr3.isValid()); - System.out.println(vr3.toString()); +// System.out.println(vr3.toString()); } @Test @@ -80,7 +80,7 @@ public class ValidationUtilTests extends BaseNd4jTest { ValidationResult vr0 = Nd4jCommonValidator.isValidZipFile(fNonExistent, false); assertFalse(vr0.isValid()); assertTrue(vr0.getIssues().get(0), vr0.getIssues().get(0).contains("exist")); - System.out.println(vr0.toString()); +// System.out.println(vr0.toString()); //Test empty zip: File fEmpty = new ClassPathResource("validation/empty_zip.zip").getFile(); @@ -88,7 +88,7 @@ public class ValidationUtilTests extends BaseNd4jTest { ValidationResult vr1 = Nd4jCommonValidator.isValidZipFile(fEmpty, false); assertFalse(vr1.isValid()); assertTrue(vr1.getIssues().get(0), vr1.getIssues().get(0).contains("empty")); - System.out.println(vr1.toString()); +// System.out.println(vr1.toString()); //Test directory (not zip file) File directory = new File(f, "dir"); @@ -97,7 +97,7 @@ public class ValidationUtilTests extends BaseNd4jTest { ValidationResult vr2 = Nd4jCommonValidator.isValidFile(directory); assertFalse(vr2.isValid()); assertTrue(vr2.getIssues().get(0), vr2.getIssues().get(0).contains("directory")); - System.out.println(vr2.toString()); +// System.out.println(vr2.toString()); //Test non-empty zip - valid File f3 = new File(f, "1.zip"); @@ -108,7 +108,7 @@ public class ValidationUtilTests extends BaseNd4jTest { } ValidationResult vr3 = Nd4jCommonValidator.isValidZipFile(f3, false); assertTrue(vr3.isValid()); - System.out.println(vr3.toString()); +// System.out.println(vr3.toString()); //Test non-empty zip - but missing required entries ValidationResult vr4 = Nd4jCommonValidator.isValidZipFile(f3, false, "content.txt", "someFile1.bin", "someFile2.bin"); @@ -117,7 +117,7 @@ public class ValidationUtilTests extends BaseNd4jTest { String s = vr4.getIssues().get(0); assertTrue(s, s.contains("someFile1.bin") && s.contains("someFile2.bin")); assertFalse(s, s.contains("content.txt")); - System.out.println(vr4.toString()); +// System.out.println(vr4.toString()); } @@ -131,7 +131,7 @@ public class ValidationUtilTests extends BaseNd4jTest { assertFalse(vr0.isValid()); assertEquals("INDArray Text File", vr0.getFormatType()); assertTrue(vr0.getIssues().get(0), vr0.getIssues().get(0).contains("exist")); - System.out.println(vr0.toString()); +// System.out.println(vr0.toString()); //Test empty file: File fEmpty = new File(f, "empty.txt"); @@ -141,7 +141,7 @@ public class ValidationUtilTests extends BaseNd4jTest { assertEquals("INDArray Text File", vr1.getFormatType()); assertFalse(vr1.isValid()); assertTrue(vr1.getIssues().get(0), vr1.getIssues().get(0).contains("empty")); - System.out.println(vr1.toString()); +// System.out.println(vr1.toString()); //Test directory (not zip file) File directory = new File(f, "dir"); @@ -151,7 +151,7 @@ public class ValidationUtilTests extends BaseNd4jTest { assertEquals("INDArray Text File", vr2.getFormatType()); assertFalse(vr2.isValid()); assertTrue(vr2.getIssues().get(0), vr2.getIssues().get(0).contains("directory")); - System.out.println(vr2.toString()); +// System.out.println(vr2.toString()); //Test non-INDArray format: File fText = new File(f, "text.txt"); @@ -161,7 +161,7 @@ public class ValidationUtilTests extends BaseNd4jTest { assertFalse(vr3.isValid()); String s = vr3.getIssues().get(0); assertTrue(s, s.contains("text") && s.contains("INDArray") && s.contains("corrupt")); - System.out.println(vr3.toString()); +// System.out.println(vr3.toString()); //Test corrupted txt format: File fValid = new File(f, "valid.txt"); @@ -179,7 +179,7 @@ public class ValidationUtilTests extends BaseNd4jTest { assertFalse(vr4.isValid()); s = vr4.getIssues().get(0); assertTrue(s, s.contains("text") && s.contains("INDArray") && s.contains("corrupt")); - System.out.println(vr4.toString()); +// System.out.println(vr4.toString()); //Test valid npz format: @@ -188,7 +188,7 @@ public class ValidationUtilTests extends BaseNd4jTest { assertTrue(vr5.isValid()); assertNull(vr5.getIssues()); assertNull(vr5.getException()); - System.out.println(vr4.toString()); +// System.out.println(vr4.toString()); } @@ -204,7 +204,7 @@ public class ValidationUtilTests extends BaseNd4jTest { assertFalse(vr0.isValid()); assertEquals("Numpy .npy File", vr0.getFormatType()); assertTrue(vr0.getIssues().get(0), vr0.getIssues().get(0).contains("exist")); - System.out.println(vr0.toString()); +// System.out.println(vr0.toString()); //Test empty file: File fEmpty = new File(f, "empty.npy"); @@ -214,7 +214,7 @@ public class ValidationUtilTests extends BaseNd4jTest { assertEquals("Numpy .npy File", vr1.getFormatType()); assertFalse(vr1.isValid()); assertTrue(vr1.getIssues().get(0), vr1.getIssues().get(0).contains("empty")); - System.out.println(vr1.toString()); +// System.out.println(vr1.toString()); //Test directory (not zip file) File directory = new File(f, "dir"); @@ -224,7 +224,7 @@ public class ValidationUtilTests extends BaseNd4jTest { assertEquals("Numpy .npy File", vr2.getFormatType()); assertFalse(vr2.isValid()); assertTrue(vr2.getIssues().get(0), vr2.getIssues().get(0).contains("directory")); - System.out.println(vr2.toString()); +// System.out.println(vr2.toString()); //Test non-numpy format: File fText = new File(f, "text.txt"); @@ -234,7 +234,7 @@ public class ValidationUtilTests extends BaseNd4jTest { assertFalse(vr3.isValid()); String s = vr3.getIssues().get(0); assertTrue(s, s.contains("npy") && s.toLowerCase().contains("numpy") && s.contains("corrupt")); - System.out.println(vr3.toString()); +// System.out.println(vr3.toString()); //Test corrupted npy format: File fValid = new ClassPathResource("numpy_arrays/arange_3,4_float32.npy").getFile(); @@ -250,7 +250,7 @@ public class ValidationUtilTests extends BaseNd4jTest { assertFalse(vr4.isValid()); s = vr4.getIssues().get(0); assertTrue(s, s.contains("npy") && s.toLowerCase().contains("numpy") && s.contains("corrupt")); - System.out.println(vr4.toString()); +// System.out.println(vr4.toString()); //Test valid npy format: @@ -259,7 +259,7 @@ public class ValidationUtilTests extends BaseNd4jTest { assertTrue(vr5.isValid()); assertNull(vr5.getIssues()); assertNull(vr5.getException()); - System.out.println(vr4.toString()); +// System.out.println(vr4.toString()); } @Test @@ -273,7 +273,7 @@ public class ValidationUtilTests extends BaseNd4jTest { assertFalse(vr0.isValid()); assertEquals("Numpy .npz File", vr0.getFormatType()); assertTrue(vr0.getIssues().get(0), vr0.getIssues().get(0).contains("exist")); - System.out.println(vr0.toString()); +// System.out.println(vr0.toString()); //Test empty file: File fEmpty = new File(f, "empty.npz"); @@ -283,7 +283,7 @@ public class ValidationUtilTests extends BaseNd4jTest { assertEquals("Numpy .npz File", vr1.getFormatType()); assertFalse(vr1.isValid()); assertTrue(vr1.getIssues().get(0), vr1.getIssues().get(0).contains("empty")); - System.out.println(vr1.toString()); +// System.out.println(vr1.toString()); //Test directory (not zip file) File directory = new File(f, "dir"); @@ -293,7 +293,7 @@ public class ValidationUtilTests extends BaseNd4jTest { assertEquals("Numpy .npz File", vr2.getFormatType()); assertFalse(vr2.isValid()); assertTrue(vr2.getIssues().get(0), vr2.getIssues().get(0).contains("directory")); - System.out.println(vr2.toString()); +// System.out.println(vr2.toString()); //Test non-numpy format: File fText = new File(f, "text.txt"); @@ -303,7 +303,7 @@ public class ValidationUtilTests extends BaseNd4jTest { assertFalse(vr3.isValid()); String s = vr3.getIssues().get(0); assertTrue(s, s.contains("npz") && s.toLowerCase().contains("numpy") && s.contains("corrupt")); - System.out.println(vr3.toString()); +// System.out.println(vr3.toString()); //Test corrupted npz format: File fValid = new ClassPathResource("numpy_arrays/npz/float32.npz").getFile(); @@ -319,7 +319,7 @@ public class ValidationUtilTests extends BaseNd4jTest { assertFalse(vr4.isValid()); s = vr4.getIssues().get(0); assertTrue(s, s.contains("npz") && s.toLowerCase().contains("numpy") && s.contains("corrupt")); - System.out.println(vr4.toString()); +// System.out.println(vr4.toString()); //Test valid npz format: @@ -328,7 +328,7 @@ public class ValidationUtilTests extends BaseNd4jTest { assertTrue(vr5.isValid()); assertNull(vr5.getIssues()); assertNull(vr5.getException()); - System.out.println(vr4.toString()); +// System.out.println(vr4.toString()); } @Test @@ -341,7 +341,7 @@ public class ValidationUtilTests extends BaseNd4jTest { assertFalse(vr0.isValid()); assertEquals("Numpy text file", vr0.getFormatType()); assertTrue(vr0.getIssues().get(0), vr0.getIssues().get(0).contains("exist")); - System.out.println(vr0.toString()); +// System.out.println(vr0.toString()); //Test empty file: File fEmpty = new File(f, "empty.txt"); @@ -351,7 +351,7 @@ public class ValidationUtilTests extends BaseNd4jTest { assertEquals("Numpy text file", vr1.getFormatType()); assertFalse(vr1.isValid()); assertTrue(vr1.getIssues().get(0), vr1.getIssues().get(0).contains("empty")); - System.out.println(vr1.toString()); +// System.out.println(vr1.toString()); //Test directory (not zip file) File directory = new File(f, "dir"); @@ -361,7 +361,7 @@ public class ValidationUtilTests extends BaseNd4jTest { assertEquals("Numpy text file", vr2.getFormatType()); assertFalse(vr2.isValid()); assertTrue(vr2.getIssues().get(0), vr2.getIssues().get(0).contains("directory")); - System.out.println(vr2.toString()); +// System.out.println(vr2.toString()); //Test non-numpy format: File fText = new File(f, "text.txt"); @@ -371,7 +371,7 @@ public class ValidationUtilTests extends BaseNd4jTest { assertFalse(vr3.isValid()); String s = vr3.getIssues().get(0); assertTrue(s, s.contains("text") && s.toLowerCase().contains("numpy") && s.contains("corrupt")); - System.out.println(vr3.toString()); +// System.out.println(vr3.toString()); //Test corrupted txt format: File fValid = new ClassPathResource("numpy_arrays/txt/arange_3,4_float32.txt").getFile(); @@ -387,7 +387,7 @@ public class ValidationUtilTests extends BaseNd4jTest { assertFalse(vr4.isValid()); s = vr4.getIssues().get(0); assertTrue(s, s.contains("text") && s.toLowerCase().contains("numpy") && s.contains("corrupt")); - System.out.println(vr4.toString()); +// System.out.println(vr4.toString()); //Test valid npz format: @@ -396,7 +396,7 @@ public class ValidationUtilTests extends BaseNd4jTest { assertTrue(vr5.isValid()); assertNull(vr5.getIssues()); assertNull(vr5.getException()); - System.out.println(vr4.toString()); +// System.out.println(vr4.toString()); } @Test @@ -418,7 +418,7 @@ public class ValidationUtilTests extends BaseNd4jTest { assertFalse(vr0.isValid()); assertEquals("SameDiff FlatBuffers file", vr0.getFormatType()); assertTrue(vr0.getIssues().get(0), vr0.getIssues().get(0).contains("exist")); - System.out.println(vr0.toString()); +// System.out.println(vr0.toString()); //Test empty file: File fEmpty = new File(f, "empty.fb"); @@ -428,7 +428,7 @@ public class ValidationUtilTests extends BaseNd4jTest { assertEquals("SameDiff FlatBuffers file", vr1.getFormatType()); assertFalse(vr1.isValid()); assertTrue(vr1.getIssues().get(0), vr1.getIssues().get(0).contains("empty")); - System.out.println(vr1.toString()); +// System.out.println(vr1.toString()); //Test directory (not zip file) File directory = new File(f, "dir"); @@ -438,7 +438,7 @@ public class ValidationUtilTests extends BaseNd4jTest { assertEquals("SameDiff FlatBuffers file", vr2.getFormatType()); assertFalse(vr2.isValid()); assertTrue(vr2.getIssues().get(0), vr2.getIssues().get(0).contains("directory")); - System.out.println(vr2.toString()); +// System.out.println(vr2.toString()); //Test non-flatbuffers File fText = new File(f, "text.fb"); @@ -448,7 +448,7 @@ public class ValidationUtilTests extends BaseNd4jTest { assertFalse(vr3.isValid()); String s = vr3.getIssues().get(0); assertTrue(s, s.contains("FlatBuffers") && s.contains("SameDiff") && s.contains("corrupt")); - System.out.println(vr3.toString()); +// System.out.println(vr3.toString()); //Test corrupted flatbuffers format: byte[] fbBytes = FileUtils.readFileToByteArray(fOrig); @@ -463,7 +463,7 @@ public class ValidationUtilTests extends BaseNd4jTest { assertFalse(vr4.isValid()); s = vr4.getIssues().get(0); assertTrue(s, s.contains("FlatBuffers") && s.contains("SameDiff") && s.contains("corrupt")); - System.out.println(vr4.toString()); +// System.out.println(vr4.toString()); //Test valid npz format: @@ -472,7 +472,7 @@ public class ValidationUtilTests extends BaseNd4jTest { assertTrue(vr5.isValid()); assertNull(vr5.getIssues()); assertNull(vr5.getException()); - System.out.println(vr4.toString()); +// System.out.println(vr4.toString()); } @Override diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/BasicWorkspaceTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/BasicWorkspaceTests.java index d98e9218e..e9b3e100f 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/BasicWorkspaceTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/BasicWorkspaceTests.java @@ -17,6 +17,7 @@ package org.nd4j.linalg.workspace; import lombok.extern.slf4j.Slf4j; +import lombok.val; import org.junit.After; import org.junit.Before; import org.junit.Ignore; @@ -29,6 +30,7 @@ import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration; import org.nd4j.linalg.api.memory.enums.*; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.util.PrintVariable; import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.nd4j.linalg.factory.Nd4j; @@ -176,6 +178,7 @@ public class BasicWorkspaceTests extends BaseNd4jTest { @Test public void testLeverageTo2() { + val exp = Nd4j.scalar(15.0); try (Nd4jWorkspace wsOne = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getAndActivateWorkspace(loopOverTimeConfig, "EXT")) { INDArray array1 = Nd4j.create(new double[] {1f, 2f, 3f, 4f, 5f}); @@ -192,6 +195,10 @@ public class BasicWorkspaceTests extends BaseNd4jTest { assertEquals(0, wsOne.getCurrentSize()); assertEquals(15f, array3.sumNumber().floatValue(), 0.01f); + + array2.assign(0); + + assertEquals(15f, array3.sumNumber().floatValue(), 0.01f); } try (Nd4jWorkspace wsTwo = @@ -296,8 +303,8 @@ public class BasicWorkspaceTests extends BaseNd4jTest { INDArray delta = Nd4j.create(new int[] {50, 64, 8, 8}, new int[] {64, 3200, 8, 1}, 'c'); delta = delta.permute(1, 0, 2, 3); - assertArrayEquals(new int[] {64, 50, 8, 8}, delta.shape()); - assertArrayEquals(new int[] {3200, 64, 8, 1}, delta.stride()); + assertArrayEquals(new long[] {64, 50, 8, 8}, delta.shape()); + assertArrayEquals(new long[] {3200, 64, 8, 1}, delta.stride()); INDArray delta2d = Shape.newShapeNoCopy(delta, new int[] {outDepth, miniBatch * outH * outW}, false); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/CyclicWorkspaceTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/CyclicWorkspaceTests.java index fc48044ea..60ed58b76 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/CyclicWorkspaceTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/CyclicWorkspaceTests.java @@ -53,7 +53,7 @@ public class CyclicWorkspaceTests extends BaseNd4jTest { val fArray = Nd4j.create(fShape).assign(e); val lArray = Nd4j.create(lShape).assign(e); - log.info("Current offset: {}; Current size: {};", ws.getCurrentOffset(), ws.getCurrentSize()); +// log.info("Current offset: {}; Current size: {};", ws.getCurrentOffset(), ws.getCurrentSize()); } } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/SpecialWorkspaceTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/SpecialWorkspaceTests.java index ce7a899a5..8f389697d 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/SpecialWorkspaceTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/SpecialWorkspaceTests.java @@ -70,7 +70,7 @@ public class SpecialWorkspaceTests extends BaseNd4jTest { } Nd4jWorkspace workspace = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread("WS1"); - workspace.enableDebug(true); +// workspace.enableDebug(true); assertEquals(0, workspace.getStepNumber()); @@ -172,7 +172,7 @@ public class SpecialWorkspaceTests extends BaseNd4jTest { Nd4jWorkspace workspace = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(configuration, "WS1"); - workspace.enableDebug(true); +// workspace.enableDebug(true); try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(configuration, "WS1")) { Nd4j.create(500); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/systeminfo/TestSystemInfo.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/systeminfo/TestSystemInfo.java index 764fcbd23..ddfffef17 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/systeminfo/TestSystemInfo.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/systeminfo/TestSystemInfo.java @@ -17,9 +17,10 @@ package org.nd4j.systeminfo; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.systeminfo.SystemInfo; -public class TestSystemInfo { +public class TestSystemInfo extends BaseND4JTest { @Test public void testSystemInfo(){ SystemInfo.printSystemInfo(); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/resources/logback.xml b/nd4j/nd4j-backends/nd4j-tests/src/test/resources/logback.xml index 63b2ca84e..9d7518dfb 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/resources/logback.xml +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/resources/logback.xml @@ -34,7 +34,7 @@ - + diff --git a/nd4j/nd4j-buffer/pom.xml b/nd4j/nd4j-buffer/pom.xml deleted file mode 100644 index 72869b16c..000000000 --- a/nd4j/nd4j-buffer/pom.xml +++ /dev/null @@ -1,154 +0,0 @@ - - - - - nd4j - org.nd4j - 1.0.0-SNAPSHOT - - 4.0.0 - - nd4j-buffer - jar - - nd4j-buffer - - - - org.apache.maven.plugins - maven-compiler-plugin - - 7 - 7 - - - - - - - - - - linux - - linux - - - linux - - - - macosx - - mac os x - - - macosx - - - - windows - - windows - - - windows - - - - i386 - - i386 - - - x86_64 - - - - i486 - - i486 - - - x86_64 - - - - i586 - - i586 - - - x86_64 - - - - i686 - - i686 - - - x86_64 - - - - x86 - - x86 - - - x86_64 - - - - amd64 - - amd64 - - - x86_64 - - - - x86-64 - - x86-64 - - - x86_64 - - - - testresources - - - - - - - - org.nd4j - nd4j-context - ${project.version} - - - org.bytedeco - javacpp - ${javacpp.version} - - - diff --git a/nd4j/nd4j-common-tests/pom.xml b/nd4j/nd4j-common-tests/pom.xml new file mode 100644 index 000000000..a242c6644 --- /dev/null +++ b/nd4j/nd4j-common-tests/pom.xml @@ -0,0 +1,42 @@ + + + + nd4j + org.nd4j + 1.0.0-SNAPSHOT + + 4.0.0 + + nd4j-common-tests + + + + junit + junit + provided + + + org.nd4j + nd4j-api + ${project.version} + + + + + + + nd4j-tests-cpu + + + nd4j-tests-cuda + + + testresources + + + nd4j-testresources + + + \ No newline at end of file diff --git a/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/BaseDL4JTest.java b/nd4j/nd4j-common-tests/src/main/java/org/nd4j/BaseND4JTest.java similarity index 66% rename from deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/BaseDL4JTest.java rename to nd4j/nd4j-common-tests/src/main/java/org/nd4j/BaseND4JTest.java index ef26f2848..ae2f56273 100644 --- a/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/BaseDL4JTest.java +++ b/nd4j/nd4j-common-tests/src/main/java/org/nd4j/BaseND4JTest.java @@ -1,5 +1,6 @@ -/******************************************************************************* +/* ****************************************************************************** * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2019-2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -14,7 +15,7 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -package org.deeplearning4j; +package org.nd4j; import lombok.extern.slf4j.Slf4j; import org.bytedeco.javacpp.Pointer; @@ -22,7 +23,9 @@ import org.junit.After; import org.junit.Before; import org.junit.Rule; import org.junit.rules.TestName; -import org.nd4j.linalg.api.buffer.DataBuffer; +import org.junit.rules.Timeout; +import org.nd4j.base.Preconditions; +import org.nd4j.config.ND4JSystemProperties; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.ops.executioner.OpExecutioner; @@ -34,15 +37,26 @@ import java.util.List; import java.util.Map; import java.util.Properties; +import static org.junit.Assume.assumeTrue; + @Slf4j -public class BaseDL4JTest { +public abstract class BaseND4JTest { @Rule public TestName name = new TestName(); + @Rule + public Timeout timeout = Timeout.millis(getTimeoutMilliseconds()); protected long startTime; protected int threadCountBefore; + /** + * Override this method to set the default timeout for methods in the test class + */ + public long getTimeoutMilliseconds(){ + return 30000; + } + /** * Override this to set the profiling mode for the tests defined in the child class */ @@ -57,16 +71,63 @@ public class BaseDL4JTest { return DataType.DOUBLE; } + /** + * Override this to set the datatype of the tests defined in the child class + */ public DataType getDefaultFPDataType(){ return getDataType(); } + private final int DEFAULT_THREADS = Runtime.getRuntime().availableProcessors(); + + /** + * Override this to specify the number of threads for C++ execution, via + * {@link org.nd4j.linalg.factory.Environment#setMaxMasterThreads(int)} + * @return Number of threads to use for C++ op execution + */ + public int numThreads(){ + return DEFAULT_THREADS; + } + + protected Boolean integrationTest; + + /** + * @return True if integration tests maven profile is enabled, false otherwise. + */ + public boolean isIntegrationTests(){ + if(integrationTest == null){ + String prop = System.getenv("DL4J_INTEGRATION_TESTS"); + integrationTest = Boolean.parseBoolean(prop); + } + return integrationTest; + } + + /** + * Call this as the first line of a test in order to skip that test, only when the integration tests maven profile is not enabled. + * This can be used to dynamically skip integration tests when the integration test profile is not enabled. + * Note that the integration test profile is not enabled by default - "integration-tests" profile + */ + public void skipUnlessIntegrationTests(){ + assumeTrue("Skipping integration test - integration profile is not enabled", isIntegrationTests()); + } + @Before public void beforeTest(){ log.info("{}.{}", getClass().getSimpleName(), name.getMethodName()); + //Suppress ND4J initialization - don't need this logged for every test... + System.setProperty(ND4JSystemProperties.LOG_INITIALIZATION, "false"); + System.setProperty(ND4JSystemProperties.ND4J_IGNORE_AVX, "true"); Nd4j.getExecutioner().setProfilingMode(getProfilingMode()); Nd4j.getExecutioner().setProfilingConfig(ProfilerConfig.builder().build()); Nd4j.setDefaultDataTypes(getDataType(), getDefaultFPDataType()); + Nd4j.getExecutioner().setProfilingConfig(ProfilerConfig.builder().build()); + Nd4j.getExecutioner().enableDebugMode(false); + Nd4j.getExecutioner().enableVerboseMode(false); + int numThreads = numThreads(); + Preconditions.checkState(numThreads > 0, "Number of threads must be > 0"); + if(numThreads != Nd4j.getEnvironment().maxMasterThreads()) { + Nd4j.getEnvironment().setMaxMasterThreads(numThreads); + } startTime = System.currentTimeMillis(); threadCountBefore = ManagementFactory.getThreadMXBean().getThreadCount(); } diff --git a/nd4j/nd4j-common/src/main/java/org/nd4j/config/ND4JSystemProperties.java b/nd4j/nd4j-common/src/main/java/org/nd4j/config/ND4JSystemProperties.java index 912c7f1f1..14401d691 100644 --- a/nd4j/nd4j-common/src/main/java/org/nd4j/config/ND4JSystemProperties.java +++ b/nd4j/nd4j-common/src/main/java/org/nd4j/config/ND4JSystemProperties.java @@ -31,6 +31,13 @@ public class ND4JSystemProperties { * initialization information */ public static final String LOG_INITIALIZATION = "org.nd4j.log.initialization"; + + /** + * Applicability: nd4j-native when running non-AVX binary on an AVX compatible CPU
+ * Description: Set to true to avoid logging AVX warnings (i.e., running generic x86 binaries on an AVX2 system) + */ + public static final String ND4J_IGNORE_AVX = "org.nd4j.avx.ignore"; + /** * Applicability: Always
* Description: This system property defines the maximum amount of off-heap memory that can be used. diff --git a/nd4j/nd4j-common/src/main/java/org/nd4j/linalg/util/ArrayUtil.java b/nd4j/nd4j-common/src/main/java/org/nd4j/linalg/util/ArrayUtil.java index e1408e298..35e5607a2 100644 --- a/nd4j/nd4j-common/src/main/java/org/nd4j/linalg/util/ArrayUtil.java +++ b/nd4j/nd4j-common/src/main/java/org/nd4j/linalg/util/ArrayUtil.java @@ -2071,9 +2071,10 @@ public class ArrayUtil { return new boolean[0]; boolean[] ret = new boolean[arr.length * arr[0].length]; int count = 0; - for (int i = 0; i < arr.length; i++) - for (int j = 0; j < arr[i].length; j++) - ret[count++] = arr[i][j]; + for (int i = 0; i < arr.length; i++) { + System.arraycopy(arr[i], 0, ret, count, arr[i].length); + count += arr[i].length; + } return ret; } @@ -2083,11 +2084,12 @@ public class ArrayUtil { boolean[] ret = new boolean[arr.length * arr[0].length * arr[0][0].length]; int count = 0; - for (int i = 0; i < arr.length; i++) - for (int j = 0; j < arr[0].length; j++) - for (int k = 0; k < arr[0][0].length; k++) { - ret[count++] = arr[i][j][k]; - } + for (int i = 0; i < arr.length; i++) { + for (int j = 0; j < arr[0].length; j++) { + System.arraycopy(arr[i][j], 0, ret, count, arr[0][0].length); + count += arr[0][0].length; + } + } return ret; } @@ -2096,24 +2098,27 @@ public class ArrayUtil { return new float[0]; float[] ret = new float[arr.length * arr[0].length]; int count = 0; - for (int i = 0; i < arr.length; i++) - for (int j = 0; j < arr[i].length; j++) - ret[count++] = arr[i][j]; + for (int i = 0; i < arr.length; i++) { + System.arraycopy(arr[i], 0, ret, count, arr[i].length); + count += arr[i].length; + } return ret; } public static float[] flatten(float[][][] arr) { - if(arr.length == 0 || arr[0].length == 0 || arr[0][0].length == 0) + if (arr.length == 0 || arr[0].length == 0 || arr[0][0].length == 0) return new float[0]; float[] ret = new float[arr.length * arr[0].length * arr[0][0].length]; int count = 0; - for (int i = 0; i < arr.length; i++) - for (int j = 0; j < arr[0].length; j++) - for (int k = 0; k < arr[0][0].length; k++) { - ret[count++] = arr[i][j][k]; - } + for (int i = 0; i < arr.length; i++) { + for (int j = 0; j < arr[0].length; j++) { + System.arraycopy(arr[i][j], 0, ret, count, arr[0][0].length); + count += arr[0][0].length; + } + } + return ret; } @@ -2123,11 +2128,12 @@ public class ArrayUtil { double[] ret = new double[arr.length * arr[0].length * arr[0][0].length]; int count = 0; - for (int i = 0; i < arr.length; i++) - for (int j = 0; j < arr[0].length; j++) - for (int k = 0; k < arr[0][0].length; k++) { - ret[count++] = arr[i][j][k]; - } + for (int i = 0; i < arr.length; i++) { + for (int j = 0; j < arr[0].length; j++) { + System.arraycopy(arr[i][j], 0, ret, count, arr[0][0].length); + count += arr[0][0].length; + } + } return ret; } @@ -2137,11 +2143,12 @@ public class ArrayUtil { int[] ret = new int[arr.length * arr[0].length * arr[0][0].length]; int count = 0; - for (int i = 0; i < arr.length; i++) - for (int j = 0; j < arr[0].length; j++) - for (int k = 0; k < arr[0][0].length; k++) { - ret[count++] = arr[i][j][k]; - } + for (int i = 0; i < arr.length; i++) { + for (int j = 0; j < arr[0].length; j++) { + System.arraycopy(arr[i][j], 0, ret, count, arr[0][0].length); + count += arr[0][0].length; + } + } return ret; } @@ -2151,11 +2158,12 @@ public class ArrayUtil { val ret = new short[arr.length * arr[0].length * arr[0][0].length]; int count = 0; - for (int i = 0; i < arr.length; i++) - for (int j = 0; j < arr[0].length; j++) - for (int k = 0; k < arr[0][0].length; k++) { - ret[count++] = arr[i][j][k]; - } + for (int i = 0; i < arr.length; i++) { + for (int j = 0; j < arr[0].length; j++) { + System.arraycopy(arr[i][j], 0, ret, count, arr[0][0].length); + count += arr[0][0].length; + } + } return ret; } @@ -2165,11 +2173,12 @@ public class ArrayUtil { val ret = new byte[arr.length * arr[0].length * arr[0][0].length]; int count = 0; - for (int i = 0; i < arr.length; i++) - for (int j = 0; j < arr[0].length; j++) - for (int k = 0; k < arr[0][0].length; k++) { - ret[count++] = arr[i][j][k]; - } + for (int i = 0; i < arr.length; i++) { + for (int j = 0; j < arr[0].length; j++) { + System.arraycopy(arr[i][j], 0, ret, count, arr[0][0].length); + count += arr[0][0].length; + } + } return ret; } @@ -2177,11 +2186,14 @@ public class ArrayUtil { val ret = new long[arr.length * arr[0].length * arr[0][0].length * arr[0][0][0].length]; int count = 0; - for (int i = 0; i < arr.length; i++) - for (int j = 0; j < arr[0].length; j++) - for (int k = 0; k < arr[0][0].length; k++) - for (int m = 0; m < arr[0][0][0].length; m++) - ret[count++] = arr[i][j][k][m]; + for (int i = 0; i < arr.length; i++) { + for (int j = 0; j < arr[0].length; j++) { + for (int k = 0; k < arr[0][0].length; k++) { + System.arraycopy(arr[i][j][k], 0, ret, count, arr[0][0][0].length); + count += arr[0][0][0].length; + } + } + } return ret; } @@ -2190,11 +2202,14 @@ public class ArrayUtil { val ret = new short[arr.length * arr[0].length * arr[0][0].length * arr[0][0][0].length]; int count = 0; - for (int i = 0; i < arr.length; i++) - for (int j = 0; j < arr[0].length; j++) - for (int k = 0; k < arr[0][0].length; k++) - for (int m = 0; m < arr[0][0][0].length; m++) - ret[count++] = arr[i][j][k][m]; + for (int i = 0; i < arr.length; i++) { + for (int j = 0; j < arr[0].length; j++) { + for (int k = 0; k < arr[0][0].length; k++) { + System.arraycopy(arr[i][j][k], 0, ret, count, arr[0][0][0].length); + count += arr[0][0][0].length; + } + } + } return ret; } @@ -2203,11 +2218,14 @@ public class ArrayUtil { val ret = new byte[arr.length * arr[0].length * arr[0][0].length * arr[0][0][0].length]; int count = 0; - for (int i = 0; i < arr.length; i++) - for (int j = 0; j < arr[0].length; j++) - for (int k = 0; k < arr[0][0].length; k++) - for (int m = 0; m < arr[0][0][0].length; m++) - ret[count++] = arr[i][j][k][m]; + for (int i = 0; i < arr.length; i++) { + for (int j = 0; j < arr[0].length; j++) { + for (int k = 0; k < arr[0][0].length; k++) { + System.arraycopy(arr[i][j][k], 0, ret, count, arr[0][0][0].length); + count += arr[0][0][0].length; + } + } + } return ret; } @@ -2216,11 +2234,14 @@ public class ArrayUtil { val ret = new boolean[arr.length * arr[0].length * arr[0][0].length * arr[0][0][0].length]; int count = 0; - for (int i = 0; i < arr.length; i++) - for (int j = 0; j < arr[0].length; j++) - for (int k = 0; k < arr[0][0].length; k++) - for (int m = 0; m < arr[0][0][0].length; m++) - ret[count++] = arr[i][j][k][m]; + for (int i = 0; i < arr.length; i++) { + for (int j = 0; j < arr[0].length; j++) { + for (int k = 0; k < arr[0][0].length; k++) { + System.arraycopy(arr[i][j][k], 0, ret, count, arr[0][0][0].length); + count += arr[0][0][0].length; + } + } + } return ret; } @@ -2229,11 +2250,14 @@ public class ArrayUtil { float[] ret = new float[arr.length * arr[0].length * arr[0][0].length * arr[0][0][0].length]; int count = 0; - for (int i = 0; i < arr.length; i++) - for (int j = 0; j < arr[0].length; j++) - for (int k = 0; k < arr[0][0].length; k++) - for (int m = 0; m < arr[0][0][0].length; m++) - ret[count++] = arr[i][j][k][m]; + for (int i = 0; i < arr.length; i++) { + for (int j = 0; j < arr[0].length; j++) { + for (int k = 0; k < arr[0][0].length; k++) { + System.arraycopy(arr[i][j][k], 0, ret, count, arr[0][0][0].length); + count += arr[0][0][0].length; + } + } + } return ret; } @@ -2242,11 +2266,14 @@ public class ArrayUtil { double[] ret = new double[arr.length * arr[0].length * arr[0][0].length * arr[0][0][0].length]; int count = 0; - for (int i = 0; i < arr.length; i++) - for (int j = 0; j < arr[0].length; j++) - for (int k = 0; k < arr[0][0].length; k++) - for (int m = 0; m < arr[0][0][0].length; m++) - ret[count++] = arr[i][j][k][m]; + for (int i = 0; i < arr.length; i++) { + for (int j = 0; j < arr[0].length; j++) { + for (int k = 0; k < arr[0][0].length; k++) { + System.arraycopy(arr[i][j][k], 0, ret, count, arr[0][0][0].length); + count += arr[0][0][0].length; + } + } + } return ret; } @@ -2255,11 +2282,14 @@ public class ArrayUtil { int[] ret = new int[arr.length * arr[0].length * arr[0][0].length * arr[0][0][0].length]; int count = 0; - for (int i = 0; i < arr.length; i++) - for (int j = 0; j < arr[0].length; j++) - for (int k = 0; k < arr[0][0].length; k++) - for (int m = 0; m < arr[0][0][0].length; m++) - ret[count++] = arr[i][j][k][m]; + for (int i = 0; i < arr.length; i++) { + for (int j = 0; j < arr[0].length; j++) { + for (int k = 0; k < arr[0][0].length; k++) { + System.arraycopy(arr[i][j][k], 0, ret, count, arr[0][0][0].length); + count += arr[0][0][0].length; + } + } + } return ret; } @@ -2271,11 +2301,8 @@ public class ArrayUtil { int[] ret = new int[arr.length * arr[0].length]; int count = 0; for (int i = 0; i < arr.length; i++) { - if (arr[i].length != arr[0].length) - throw new IllegalStateException("Length of all rows must be equal"); - - for (int j = 0; j < arr[i].length; j++) - ret[count++] = arr[i][j]; + System.arraycopy(arr[i], 0, ret, count, arr[i].length); + count += arr[i].length; } return ret; } @@ -2285,9 +2312,10 @@ public class ArrayUtil { return new short[0]; val ret = new short[arr.length * arr[0].length]; int count = 0; - for (int i = 0; i < arr.length; i++) - for (int j = 0; j < arr[i].length; j++) - ret[count++] = arr[i][j]; + for (int i = 0; i < arr.length; i++) { + System.arraycopy(arr[i], 0, ret, count, arr[i].length); + count += arr[i].length; + } return ret; } @@ -2297,34 +2325,21 @@ public class ArrayUtil { val ret = new byte[arr.length * arr[0].length]; int count = 0; for (int i = 0; i < arr.length; i++) { - if (arr[i].length != arr[0].length) - throw new IllegalStateException("Length of all rows must be equal"); - - for (int j = 0; j < arr[i].length; j++) - ret[count++] = arr[i][j]; + System.arraycopy(arr[i], 0, ret, count, arr[i].length); + count += arr[i].length; } return ret; } - /* - public static boolean[] flatten(boolean[][] arr) { - boolean[] ret = new boolean[arr.length * arr[0].length]; - int count = 0; - for (int i = 0; i < arr.length; i++) - for (int j = 0; j < arr[i].length; j++) - ret[count++] = arr[i][j]; - return ret; - } - */ - public static long[] flatten(long[][] arr) { if(arr.length == 0 || arr[0].length == 0 ) return new long[0]; long[] ret = new long[arr.length * arr[0].length]; int count = 0; - for (int i = 0; i < arr.length; i++) - for (int j = 0; j < arr[i].length; j++) - ret[count++] = arr[i][j]; + for (int i = 0; i < arr.length; i++) { + System.arraycopy(arr[i], 0, ret, count, arr[i].length); + count += arr[i].length; + } return ret; } @@ -2334,11 +2349,12 @@ public class ArrayUtil { long[] ret = new long[arr.length * arr[0].length * arr[0][0].length]; int count = 0; - for (int i = 0; i < arr.length; i++) - for (int j = 0; j < arr[0].length; j++) - for (int k = 0; k < arr[0][0].length; k++) { - ret[count++] = arr[i][j][k]; - } + for (int i = 0; i < arr.length; i++) { + for (int j = 0; j < arr[0].length; j++) { + System.arraycopy(arr[i][j], 0, ret, count, arr[0][0].length); + count += arr[0][0].length; + } + } return ret; } @@ -2354,9 +2370,10 @@ public class ArrayUtil { return new double[0]; double[] ret = new double[arr.length * arr[0].length]; int count = 0; - for (int i = 0; i < arr.length; i++) - for (int j = 0; j < arr[i].length; j++) - ret[count++] = arr[i][j]; + for (int i = 0; i < arr.length; i++) { + System.arraycopy(arr[i], 0, ret, count, arr[i].length); + count += arr[i].length; + } return ret; } diff --git a/nd4j/nd4j-common/src/main/java/org/nd4j/util/ArchiveUtils.java b/nd4j/nd4j-common/src/main/java/org/nd4j/util/ArchiveUtils.java index d9566aabe..f0c6ef318 100644 --- a/nd4j/nd4j-common/src/main/java/org/nd4j/util/ArchiveUtils.java +++ b/nd4j/nd4j-common/src/main/java/org/nd4j/util/ArchiveUtils.java @@ -24,8 +24,6 @@ import org.apache.commons.compress.compressors.gzip.GzipCompressorInputStream; import org.apache.commons.io.FileUtils; import org.apache.commons.io.IOUtils; import org.nd4j.base.Preconditions; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; import java.io.*; import java.util.ArrayList; @@ -56,6 +54,8 @@ public class ArchiveUtils { File target = new File(file); if (!target.exists()) throw new IllegalArgumentException("Archive doesnt exist"); + if (!new File(dest).exists()) + new File(dest).mkdirs(); FileInputStream fin = new FileInputStream(target); int BUFFER = 2048; byte data[] = new byte[BUFFER]; diff --git a/nd4j/nd4j-common/src/test/java/org/nd4j/resources/TestArchiveUtils.java b/nd4j/nd4j-common/src/test/java/org/nd4j/resources/TestArchiveUtils.java new file mode 100644 index 000000000..9a36d7fe3 --- /dev/null +++ b/nd4j/nd4j-common/src/test/java/org/nd4j/resources/TestArchiveUtils.java @@ -0,0 +1,51 @@ +package org.nd4j.resources; + +import org.apache.commons.io.FileUtils; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.nd4j.util.ArchiveUtils; + +import java.io.File; +import java.io.FileInputStream; +import java.io.FileOutputStream; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.zip.ZipEntry; +import java.util.zip.ZipOutputStream; + +public class TestArchiveUtils { + @Rule + public TemporaryFolder testDir = new TemporaryFolder(); + + @Test + public void testUnzipFileTo() throws IOException { + //random txt file + File dir = testDir.newFolder(); + String content = "test file content"; + String path = "myDir/myTestFile.txt"; + File testFile = new File(dir, path); + testFile.getParentFile().mkdir(); + FileUtils.writeStringToFile(testFile, content, StandardCharsets.UTF_8); + + //zip it as test.zip + File zipFile = new File(testFile.getParentFile(),"test.zip"); + FileOutputStream fos = new FileOutputStream(zipFile); + ZipOutputStream zipOut = new ZipOutputStream(fos); + FileInputStream fis = new FileInputStream(testFile); + ZipEntry zipEntry = new ZipEntry(testFile.getName()); + zipOut.putNextEntry(zipEntry); + byte[] bytes = new byte[1024]; + int length; + while((length = fis.read(bytes)) >= 0) { + zipOut.write(bytes, 0, length); + } + zipOut.close(); + fis.close(); + fos.close(); + + //now unzip to a directory that doesn't previously exist + File unzipDir = new File(testFile.getParentFile(),"unzipTo"); + ArchiveUtils.unzipFileTo(zipFile.getAbsolutePath(),unzipDir.getAbsolutePath()); + } +} diff --git a/nd4j/nd4j-context/pom.xml b/nd4j/nd4j-context/pom.xml deleted file mode 100644 index 225b3784c..000000000 --- a/nd4j/nd4j-context/pom.xml +++ /dev/null @@ -1,48 +0,0 @@ - - - - - nd4j - org.nd4j - 1.0.0-SNAPSHOT - - 4.0.0 - - nd4j-context - jar - - nd4j-context - - - - org.nd4j - nd4j-common - ${project.version} - - - - - 1.7 - 1.7 - - - - - testresources - - - diff --git a/nd4j/nd4j-jdbc/nd4j-jdbc-hsql/pom.xml b/nd4j/nd4j-jdbc/nd4j-jdbc-hsql/pom.xml index 2d0cd6afc..b0941300a 100644 --- a/nd4j/nd4j-jdbc/nd4j-jdbc-hsql/pom.xml +++ b/nd4j/nd4j-jdbc/nd4j-jdbc-hsql/pom.xml @@ -55,18 +55,47 @@ hsqldb ${hsqldb.version} + + + org.nd4j + nd4j-common-tests + ${project.version} + test + testresources - + + + + nd4j-testresources + + + + nd4j-tests-cpu + + false + org.nd4j nd4j-native ${project.version} - test + + + + + nd4j-tests-cuda + + false + + + + org.nd4j + nd4j-cuda-10.2 + ${project.version} diff --git a/nd4j/nd4j-jdbc/nd4j-jdbc-hsql/src/test/java/org/nd4j/jdbc/hsql/HSqlLoaderTest.java b/nd4j/nd4j-jdbc/nd4j-jdbc-hsql/src/test/java/org/nd4j/jdbc/hsql/HSqlLoaderTest.java index 68846ae74..4279b7a78 100644 --- a/nd4j/nd4j-jdbc/nd4j-jdbc-hsql/src/test/java/org/nd4j/jdbc/hsql/HSqlLoaderTest.java +++ b/nd4j/nd4j-jdbc/nd4j-jdbc-hsql/src/test/java/org/nd4j/jdbc/hsql/HSqlLoaderTest.java @@ -20,6 +20,7 @@ import org.hsqldb.jdbc.JDBCDataSource; import org.junit.AfterClass; import org.junit.BeforeClass; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -32,7 +33,7 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertThat; -public class HSqlLoaderTest { +public class HSqlLoaderTest extends BaseND4JTest { private static HsqlLoader hsqlLoader; private static DataSource dataSource; @@ -125,7 +126,7 @@ public class HSqlLoaderTest { INDArray load = hsqlLoader.load(hsqlLoader.loadForID("1")); assertNotNull(load); - assertEquals(Nd4j.linspace(1,4,4, Nd4j.dataType()),load); + assertEquals(Nd4j.linspace(1,4,4, load.dataType()),load); } diff --git a/nd4j/nd4j-jdbc/nd4j-jdbc-mysql/pom.xml b/nd4j/nd4j-jdbc/nd4j-jdbc-mysql/pom.xml index 44f146f38..0ece4c8b0 100644 --- a/nd4j/nd4j-jdbc/nd4j-jdbc-mysql/pom.xml +++ b/nd4j/nd4j-jdbc/nd4j-jdbc-mysql/pom.xml @@ -38,6 +38,13 @@ junit test + + + org.nd4j + nd4j-common-tests + ${project.version} + test + diff --git a/nd4j/nd4j-jdbc/nd4j-jdbc-mysql/src/test/java/org/nd4j/jdbc/mysql/MysqlLoaderTest.java b/nd4j/nd4j-jdbc/nd4j-jdbc-mysql/src/test/java/org/nd4j/jdbc/mysql/MysqlLoaderTest.java index db554ea43..59aea4ac0 100644 --- a/nd4j/nd4j-jdbc/nd4j-jdbc-mysql/src/test/java/org/nd4j/jdbc/mysql/MysqlLoaderTest.java +++ b/nd4j/nd4j-jdbc/nd4j-jdbc-mysql/src/test/java/org/nd4j/jdbc/mysql/MysqlLoaderTest.java @@ -19,6 +19,7 @@ package org.nd4j.jdbc.mysql; import com.mchange.v2.c3p0.ComboPooledDataSource; import org.junit.Ignore; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -26,7 +27,7 @@ import java.sql.Blob; import static org.junit.Assert.assertEquals; -public class MysqlLoaderTest { +public class MysqlLoaderTest extends BaseND4JTest { //simple litmus test, unfortunately relies on an external database diff --git a/nd4j/nd4j-jdbc/pom.xml b/nd4j/nd4j-jdbc/pom.xml index 05ef09942..a382cf0e8 100644 --- a/nd4j/nd4j-jdbc/pom.xml +++ b/nd4j/nd4j-jdbc/pom.xml @@ -53,6 +53,18 @@ testresources + + + nd4j-testresources + + + + nd4j-tests-cpu + + + + nd4j-tests-cuda + diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/pom.xml b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/pom.xml index 21b3f6b65..ccabf1b76 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/pom.xml +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/pom.xml @@ -68,6 +68,13 @@ ${logback.version} test + + + org.nd4j + nd4j-common-tests + ${project.version} + test + diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/src/test/java/org/nd4j/parameterserver/BaseNd4jTest.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/src/test/java/org/nd4j/parameterserver/BaseNd4jTest.java deleted file mode 100644 index 36958198d..000000000 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/src/test/java/org/nd4j/parameterserver/BaseNd4jTest.java +++ /dev/null @@ -1,144 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.parameterserver; - - -import lombok.extern.slf4j.Slf4j; -import lombok.val; -import org.bytedeco.javacpp.Pointer; -import org.junit.After; -import org.junit.Before; -import org.junit.Rule; -import org.junit.rules.TestName; -import org.nd4j.config.ND4JSystemProperties; -import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.api.memory.MemoryWorkspace; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.profiler.ProfilerConfig; - -import java.lang.management.ManagementFactory; -import java.util.List; -import java.util.Map; -import java.util.Properties; - - -/** - * Base Nd4j test - * @author Adam Gibson - */ -@Slf4j -public abstract class BaseNd4jTest { - - @Rule - public TestName testName = new TestName(); - - protected long startTime; - protected int threadCountBefore; - - public BaseNd4jTest(){ - //Suppress ND4J initialization - don't need this logged for every test... - System.setProperty(ND4JSystemProperties.LOG_INITIALIZATION, "false"); - System.gc(); - } - - - @Before - public void before() throws Exception { - log.info("Running " + getClass().getName() + "." + testName.getMethodName()); - Nd4j.getExecutioner().setProfilingConfig(ProfilerConfig.builder().build()); - Nd4j.getExecutioner().enableDebugMode(false); - Nd4j.getExecutioner().enableVerboseMode(false); - Nd4j.setDefaultDataTypes(DataType.DOUBLE, DataType.DOUBLE); - startTime = System.currentTimeMillis(); - threadCountBefore = ManagementFactory.getThreadMXBean().getThreadCount(); - } - - @After - public void after() throws Exception { - long totalTime = System.currentTimeMillis() - startTime; - Nd4j.getMemoryManager().purgeCaches(); - - logTestCompletion(totalTime); - Nd4j.getExecutioner().enableDebugMode(false); - Nd4j.getExecutioner().enableVerboseMode(false); - - //Attempt to keep workspaces isolated between tests - Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread(); - val currWS = Nd4j.getMemoryManager().getCurrentWorkspace(); - Nd4j.getMemoryManager().setCurrentWorkspace(null); - if(currWS != null){ - //Not really safe to continue testing under this situation... other tests will likely fail with obscure - // errors that are hard to track back to this - log.error("Open workspace leaked from test! Exiting - {}, isOpen = {} - {}", currWS.getId(), currWS.isScopeActive(), currWS); - System.exit(1); - } - } - - public void logTestCompletion( long totalTime){ - StringBuilder sb = new StringBuilder(); - long maxPhys = Pointer.maxPhysicalBytes(); - long maxBytes = Pointer.maxBytes(); - long currPhys = Pointer.physicalBytes(); - long currBytes = Pointer.totalBytes(); - - long jvmTotal = Runtime.getRuntime().totalMemory(); - long jvmMax = Runtime.getRuntime().maxMemory(); - - int threadsAfter = ManagementFactory.getThreadMXBean().getThreadCount(); - sb.append(getClass().getSimpleName()).append(".").append(testName.getMethodName()) - .append(": ").append(totalTime).append(" ms") - .append(", threadCount: (").append(threadCountBefore).append("->").append(threadsAfter).append(")") - .append(", jvmTotal=").append(jvmTotal) - .append(", jvmMax=").append(jvmMax) - .append(", totalBytes=").append(currBytes).append(", maxBytes=").append(maxBytes) - .append(", currPhys=").append(currPhys).append(", maxPhys=").append(maxPhys); - - List ws = Nd4j.getWorkspaceManager().getAllWorkspacesForCurrentThread(); - if(ws != null && ws.size() > 0){ - long currSize = 0; - for(MemoryWorkspace w : ws){ - currSize += w.getCurrentSize(); - } - if(currSize > 0){ - sb.append(", threadWSSize=").append(currSize) - .append(" (").append(ws.size()).append(" WSs)"); - } - } - - - Properties p = Nd4j.getExecutioner().getEnvironmentInformation(); - Object o = p.get("cuda.devicesInformation"); - if(o instanceof List){ - List> l = (List>) o; - if(l.size() > 0) { - - sb.append(" [").append(l.size()) - .append(" GPUs: "); - - for (int i = 0; i < l.size(); i++) { - Map m = l.get(i); - if(i > 0) - sb.append(","); - sb.append("(").append(m.get("cuda.freeMemory")).append(" free, ") - .append(m.get("cuda.totalMemory")).append(" total)"); - } - sb.append("]"); - } - } - log.info(sb.toString()); - } -} diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/src/test/java/org/nd4j/parameterserver/background/RemoteParameterServerClientTests.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/src/test/java/org/nd4j/parameterserver/background/RemoteParameterServerClientTests.java index 530aa33ff..6398cbab1 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/src/test/java/org/nd4j/parameterserver/background/RemoteParameterServerClientTests.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/src/test/java/org/nd4j/parameterserver/background/RemoteParameterServerClientTests.java @@ -25,10 +25,10 @@ import org.agrona.concurrent.BusySpinIdleStrategy; import org.junit.After; import org.junit.Before; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.aeron.ipc.AeronUtil; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.parameterserver.BaseNd4jTest; import org.nd4j.parameterserver.client.ParameterServerClient; import java.util.concurrent.atomic.AtomicInteger; @@ -39,7 +39,7 @@ import static org.junit.Assert.assertEquals; * Created by agibsonccc on 10/5/16. */ @Slf4j -public class RemoteParameterServerClientTests extends BaseNd4jTest { +public class RemoteParameterServerClientTests extends BaseND4JTest { private int parameterLength = 1000; private Aeron.Context ctx; private MediaDriver mediaDriver; diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/src/test/java/org/nd4j/parameterserver/client/ParameterServerClientPartialTest.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/src/test/java/org/nd4j/parameterserver/client/ParameterServerClientPartialTest.java index d69455a57..3846d6c2d 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/src/test/java/org/nd4j/parameterserver/client/ParameterServerClientPartialTest.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/src/test/java/org/nd4j/parameterserver/client/ParameterServerClientPartialTest.java @@ -24,11 +24,11 @@ import org.agrona.concurrent.BusySpinIdleStrategy; import org.junit.BeforeClass; import org.junit.Ignore; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.aeron.ipc.AeronUtil; import org.nd4j.aeron.ipc.NDArrayMessage; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.parameterserver.BaseNd4jTest; import org.nd4j.parameterserver.ParameterServerListener; import org.nd4j.parameterserver.ParameterServerSubscriber; @@ -40,7 +40,7 @@ import static org.junit.Assert.assertTrue; * Created by agibsonccc on 10/3/16. */ @Slf4j -public class ParameterServerClientPartialTest extends BaseNd4jTest { +public class ParameterServerClientPartialTest extends BaseND4JTest { private static MediaDriver mediaDriver; private static Aeron.Context ctx; private static ParameterServerSubscriber masterNode, slaveNode; diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/src/test/java/org/nd4j/parameterserver/client/ParameterServerClientTest.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/src/test/java/org/nd4j/parameterserver/client/ParameterServerClientTest.java index 1ec08a83c..6aeaadb90 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/src/test/java/org/nd4j/parameterserver/client/ParameterServerClientTest.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/src/test/java/org/nd4j/parameterserver/client/ParameterServerClientTest.java @@ -21,10 +21,10 @@ import io.aeron.driver.MediaDriver; import org.junit.BeforeClass; import org.junit.Ignore; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.aeron.ipc.AeronUtil; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.parameterserver.BaseNd4jTest; import org.nd4j.parameterserver.ParameterServerListener; import org.nd4j.parameterserver.ParameterServerSubscriber; import org.slf4j.Logger; @@ -37,7 +37,7 @@ import static org.junit.Assert.assertTrue; /** * Created by agibsonccc on 10/3/16. */ -public class ParameterServerClientTest extends BaseNd4jTest { +public class ParameterServerClientTest extends BaseND4JTest { private static MediaDriver mediaDriver; private static Logger log = LoggerFactory.getLogger(ParameterServerClientTest.class); private static Aeron aeron; diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/pom.xml b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/pom.xml index f0d44beb6..bfab849da 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/pom.xml +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/pom.xml @@ -79,6 +79,13 @@ rxjava 2.2.0 + + + org.nd4j + nd4j-common-tests + ${project.version} + test + diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/VoidParameterServerStressTest.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/VoidParameterServerStressTest.java index 8b27c87ea..b1d159c1d 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/VoidParameterServerStressTest.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/VoidParameterServerStressTest.java @@ -22,6 +22,7 @@ import org.junit.After; import org.junit.Before; import org.junit.Ignore; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.rng.Random; import org.nd4j.linalg.factory.Nd4j; @@ -60,7 +61,7 @@ import static org.junit.Assert.*; @Slf4j @Ignore @Deprecated -public class VoidParameterServerStressTest { +public class VoidParameterServerStressTest extends BaseND4JTest { private static final int NUM_WORDS = 100000; @Before diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/VoidParameterServerTest.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/VoidParameterServerTest.java index f85e23d91..57c79e3ba 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/VoidParameterServerTest.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/VoidParameterServerTest.java @@ -21,6 +21,7 @@ import org.junit.After; import org.junit.Before; import org.junit.Ignore; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.parameterserver.distributed.conf.VoidConfiguration; @@ -54,7 +55,7 @@ import static org.junit.Assert.*; @Slf4j @Ignore @Deprecated -public class VoidParameterServerTest { +public class VoidParameterServerTest extends BaseND4JTest { private static List localIPs; private static List badIPs; private static final Transport transport = new MulticastTransport(); diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/conf/VoidConfigurationTest.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/conf/VoidConfigurationTest.java index 6253d776b..0cb4b254a 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/conf/VoidConfigurationTest.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/conf/VoidConfigurationTest.java @@ -20,6 +20,7 @@ import org.junit.Before; import org.junit.Rule; import org.junit.Test; import org.junit.rules.Timeout; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.exception.ND4JIllegalStateException; import static org.junit.Assert.*; @@ -27,7 +28,7 @@ import static org.junit.Assert.*; /** * @author raver119@gmail.com */ -public class VoidConfigurationTest { +public class VoidConfigurationTest extends BaseND4JTest { @Rule public Timeout globalTimeout = Timeout.seconds(30); diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/logic/ClipboardTest.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/logic/ClipboardTest.java index 4041bb1e8..26b9da8eb 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/logic/ClipboardTest.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/logic/ClipboardTest.java @@ -20,6 +20,7 @@ import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.RandomUtils; import org.junit.*; import org.junit.rules.Timeout; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.util.ArrayUtil; @@ -40,7 +41,7 @@ import static org.junit.Assert.*; @Slf4j @Ignore @Deprecated -public class ClipboardTest { +public class ClipboardTest extends BaseND4JTest { @Before public void setUp() throws Exception { diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/logic/FrameCompletionHandlerTest.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/logic/FrameCompletionHandlerTest.java index ad3f23c1e..5c9694995 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/logic/FrameCompletionHandlerTest.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/logic/FrameCompletionHandlerTest.java @@ -21,6 +21,7 @@ import org.junit.Ignore; import org.junit.Rule; import org.junit.Test; import org.junit.rules.Timeout; +import org.nd4j.BaseND4JTest; import org.nd4j.parameterserver.distributed.logic.completion.FrameCompletionHandler; import static org.junit.Assert.*; @@ -30,7 +31,7 @@ import static org.junit.Assert.*; */ @Ignore @Deprecated -public class FrameCompletionHandlerTest { +public class FrameCompletionHandlerTest extends BaseND4JTest { @Before public void setUp() throws Exception { diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/logic/routing/InterleavedRouterTest.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/logic/routing/InterleavedRouterTest.java index d2211f9fa..607ee96e9 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/logic/routing/InterleavedRouterTest.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/logic/routing/InterleavedRouterTest.java @@ -21,6 +21,7 @@ import org.junit.Ignore; import org.junit.Rule; import org.junit.Test; import org.junit.rules.Timeout; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.io.StringUtils; import org.nd4j.linalg.util.HashUtil; import org.nd4j.parameterserver.distributed.conf.VoidConfiguration; @@ -39,7 +40,7 @@ import static org.junit.Assert.*; */ @Ignore @Deprecated -public class InterleavedRouterTest { +public class InterleavedRouterTest extends BaseND4JTest { VoidConfiguration configuration; Transport transport; long originator; diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/messages/FrameTest.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/messages/FrameTest.java index 9c4ecec51..0e46bb179 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/messages/FrameTest.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/messages/FrameTest.java @@ -20,6 +20,7 @@ import org.agrona.concurrent.UnsafeBuffer; import org.junit.Before; import org.junit.Ignore; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.parameterserver.distributed.conf.VoidConfiguration; import org.nd4j.parameterserver.distributed.enums.NodeRole; import org.nd4j.parameterserver.distributed.logic.completion.Clipboard; @@ -37,7 +38,7 @@ import static org.junit.Assert.*; */ @Ignore @Deprecated -public class FrameTest { +public class FrameTest extends BaseND4JTest { @Before public void setUp() throws Exception { diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/messages/VoidMessageTest.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/messages/VoidMessageTest.java index eab7aa468..d23ecd047 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/messages/VoidMessageTest.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/messages/VoidMessageTest.java @@ -20,6 +20,7 @@ import org.junit.After; import org.junit.Before; import org.junit.Ignore; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.parameterserver.distributed.messages.requests.SkipGramRequestMessage; import static org.junit.Assert.*; @@ -29,7 +30,7 @@ import static org.junit.Assert.*; */ @Ignore @Deprecated -public class VoidMessageTest { +public class VoidMessageTest extends BaseND4JTest { @Before public void setUp() throws Exception { diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/messages/aggregations/VoidAggregationTest.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/messages/aggregations/VoidAggregationTest.java index f5914cde1..698ab6dc1 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/messages/aggregations/VoidAggregationTest.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/messages/aggregations/VoidAggregationTest.java @@ -19,6 +19,7 @@ package org.nd4j.parameterserver.distributed.messages.aggregations; import lombok.extern.slf4j.Slf4j; import org.junit.*; import org.junit.rules.Timeout; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -33,7 +34,7 @@ import static org.junit.Assert.*; @Slf4j @Ignore @Deprecated -public class VoidAggregationTest { +public class VoidAggregationTest extends BaseND4JTest { private static final short NODES = 100; private static final int ELEMENTS_PER_NODE = 3; diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/transport/RoutedTransportTest.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/transport/RoutedTransportTest.java index 7d9092591..661c74196 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/transport/RoutedTransportTest.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/transport/RoutedTransportTest.java @@ -20,6 +20,7 @@ import org.junit.After; import org.junit.Before; import org.junit.Ignore; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.parameterserver.distributed.conf.VoidConfiguration; import org.nd4j.parameterserver.distributed.enums.NodeRole; import org.nd4j.parameterserver.distributed.logic.ClientRouter; @@ -39,7 +40,7 @@ import static org.junit.Assert.*; */ @Ignore @Deprecated -public class RoutedTransportTest { +public class RoutedTransportTest extends BaseND4JTest { @Before public void setUp() throws Exception { diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/util/NetworkOrganizerTest.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/util/NetworkOrganizerTest.java index 1be838edf..c7ace4db5 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/util/NetworkOrganizerTest.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/util/NetworkOrganizerTest.java @@ -20,6 +20,7 @@ import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.RandomUtils; import org.junit.*; import org.junit.rules.Timeout; +import org.nd4j.BaseND4JTest; import java.util.*; @@ -29,7 +30,7 @@ import static org.junit.Assert.*; * @author raver119@gmail.com */ @Slf4j -public class NetworkOrganizerTest { +public class NetworkOrganizerTest extends BaseND4JTest { @Before public void setUp() throws Exception { diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/DelayedModelParameterServerTest.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/DelayedModelParameterServerTest.java index a75d00909..20831d47c 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/DelayedModelParameterServerTest.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/DelayedModelParameterServerTest.java @@ -23,6 +23,7 @@ import org.apache.commons.lang3.RandomUtils; import org.junit.After; import org.junit.Before; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.primitives.AtomicBoolean; @@ -48,7 +49,7 @@ import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; @Slf4j -public class DelayedModelParameterServerTest { +public class DelayedModelParameterServerTest extends BaseND4JTest { private static final String rootId = "ROOT_NODE"; @Before diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/ModelParameterServerTest.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/ModelParameterServerTest.java index 54acd8adf..c6be1b9c9 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/ModelParameterServerTest.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/ModelParameterServerTest.java @@ -20,6 +20,7 @@ import io.reactivex.functions.Consumer; import lombok.extern.slf4j.Slf4j; import lombok.val; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.primitives.AtomicBoolean; @@ -46,7 +47,7 @@ import java.util.concurrent.atomic.AtomicInteger; import static org.junit.Assert.*; @Slf4j -public class ModelParameterServerTest { +public class ModelParameterServerTest extends BaseND4JTest { private static final String rootId = "ROOT_NODE"; @Test(timeout = 20000L) diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/chunks/impl/FileChunksTrackerTest.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/chunks/impl/FileChunksTrackerTest.java index 86e4b8fa4..0d48d5358 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/chunks/impl/FileChunksTrackerTest.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/chunks/impl/FileChunksTrackerTest.java @@ -19,6 +19,7 @@ package org.nd4j.parameterserver.distributed.v2.chunks.impl; import lombok.extern.slf4j.Slf4j; import lombok.val; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.parameterserver.distributed.v2.chunks.VoidChunk; import org.nd4j.parameterserver.distributed.v2.chunks.impl.FileChunksTracker; @@ -30,7 +31,7 @@ import java.util.ArrayList; import static org.junit.Assert.*; @Slf4j -public class FileChunksTrackerTest { +public class FileChunksTrackerTest extends BaseND4JTest { @Test public void testTracker_1() throws Exception { val array = Nd4j.linspace(1, 100000, 100000).reshape(-1, 1000); diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/chunks/impl/InmemoryChunksTrackerTest.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/chunks/impl/InmemoryChunksTrackerTest.java index aa78d721f..3a2430ed1 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/chunks/impl/InmemoryChunksTrackerTest.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/chunks/impl/InmemoryChunksTrackerTest.java @@ -18,6 +18,7 @@ package org.nd4j.parameterserver.distributed.v2.chunks.impl; import lombok.val; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.parameterserver.distributed.v2.chunks.VoidChunk; import org.nd4j.parameterserver.distributed.v2.messages.impl.GradientsUpdateMessage; @@ -27,7 +28,7 @@ import java.util.ArrayList; import static org.junit.Assert.*; -public class InmemoryChunksTrackerTest { +public class InmemoryChunksTrackerTest extends BaseND4JTest { @Test public void testTracker_1() throws Exception { val array = Nd4j.linspace(1, 100000, 10000).reshape(-1, 1000); diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/messages/VoidMessageTest.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/messages/VoidMessageTest.java index 8380f157c..2b08b0c62 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/messages/VoidMessageTest.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/messages/VoidMessageTest.java @@ -19,13 +19,14 @@ package org.nd4j.parameterserver.distributed.v2.messages; import lombok.extern.slf4j.Slf4j; import lombok.val; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.util.SerializationUtils; import org.nd4j.parameterserver.distributed.v2.messages.pairs.handshake.HandshakeRequest; import static org.junit.Assert.*; @Slf4j -public class VoidMessageTest { +public class VoidMessageTest extends BaseND4JTest { @Test public void testHandshakeSerialization_1() throws Exception { val req = new HandshakeRequest(); diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/messages/history/HashHistoryHolderTest.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/messages/history/HashHistoryHolderTest.java index 901d4b197..c6999e469 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/messages/history/HashHistoryHolderTest.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/messages/history/HashHistoryHolderTest.java @@ -19,11 +19,12 @@ package org.nd4j.parameterserver.distributed.v2.messages.history; import lombok.extern.slf4j.Slf4j; import lombok.val; import org.junit.Test; +import org.nd4j.BaseND4JTest; import static org.junit.Assert.*; @Slf4j -public class HashHistoryHolderTest { +public class HashHistoryHolderTest extends BaseND4JTest { @Test public void testBasicStuff_1() { diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/transport/impl/AeronUdpTransportTest.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/transport/impl/AeronUdpTransportTest.java index 990acbcd7..bd780e0ef 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/transport/impl/AeronUdpTransportTest.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/transport/impl/AeronUdpTransportTest.java @@ -20,13 +20,14 @@ import lombok.extern.slf4j.Slf4j; import lombok.val; import org.junit.Ignore; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.parameterserver.distributed.conf.VoidConfiguration; import org.nd4j.parameterserver.distributed.v2.messages.pairs.ping.PingMessage; import static org.junit.Assert.*; @Slf4j -public class AeronUdpTransportTest { +public class AeronUdpTransportTest extends BaseND4JTest { private static final String IP = "127.0.0.1"; private static final int ROOT_PORT = 40781; diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/transport/impl/DummyTransportTest.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/transport/impl/DummyTransportTest.java index b2f530df4..ba34f974f 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/transport/impl/DummyTransportTest.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/transport/impl/DummyTransportTest.java @@ -19,6 +19,7 @@ package org.nd4j.parameterserver.distributed.v2.transport.impl; import lombok.extern.slf4j.Slf4j; import lombok.val; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.parameterserver.distributed.v2.enums.PropagationMode; import org.nd4j.parameterserver.distributed.v2.messages.VoidMessage; @@ -33,7 +34,7 @@ import java.util.concurrent.atomic.AtomicInteger; import static org.junit.Assert.*; @Slf4j -public class DummyTransportTest { +public class DummyTransportTest extends BaseND4JTest { @Test public void testBasicConnection_1() throws Exception { diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/util/MeshOrganizerTest.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/util/MeshOrganizerTest.java index 63bdb9fc6..72487bf2a 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/util/MeshOrganizerTest.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/util/MeshOrganizerTest.java @@ -19,6 +19,7 @@ package org.nd4j.parameterserver.distributed.v2.util; import lombok.extern.slf4j.Slf4j; import lombok.val; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.util.SerializationUtils; import org.nd4j.parameterserver.distributed.v2.enums.MeshBuildMode; @@ -32,7 +33,7 @@ import static org.junit.Assert.*; * @author raver119@gmail.com */ @Slf4j -public class MeshOrganizerTest { +public class MeshOrganizerTest extends BaseND4JTest { @Test(timeout = 1000L) public void testDescendantsCount_1() { @@ -165,8 +166,8 @@ public class MeshOrganizerTest { nodes.add(node); } - for (val n:nodes) - log.info("Number of downstreams: [{}]", n.numberOfDownstreams()); +// for (val n:nodes) +// log.info("Number of downstreams: [{}]", n.numberOfDownstreams()); log.info("Going for first clone"); val clone1 = mesh.clone(); diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/util/MessageSplitterTest.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/util/MessageSplitterTest.java index 20038d1a6..71761291e 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/util/MessageSplitterTest.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/util/MessageSplitterTest.java @@ -19,6 +19,7 @@ package org.nd4j.parameterserver.distributed.v2.util; import lombok.extern.slf4j.Slf4j; import lombok.val; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.primitives.Atomic; import org.nd4j.linalg.primitives.Optional; @@ -29,7 +30,7 @@ import java.util.ArrayList; import static org.junit.Assert.*; @Slf4j -public class MessageSplitterTest { +public class MessageSplitterTest extends BaseND4JTest { @Test public void testMessageSplit_1() throws Exception { diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/node/ParameterServerNodeTest.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/node/ParameterServerNodeTest.java index 2f0623791..52e84efd3 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/node/ParameterServerNodeTest.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/node/ParameterServerNodeTest.java @@ -22,6 +22,7 @@ import lombok.extern.slf4j.Slf4j; import org.junit.BeforeClass; import org.junit.Ignore; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.aeron.ipc.AeronUtil; import org.nd4j.aeron.ipc.NDArrayMessage; import org.nd4j.linalg.factory.Nd4j; @@ -38,7 +39,7 @@ import static org.junit.Assert.*; @Slf4j @Ignore @Deprecated -public class ParameterServerNodeTest { +public class ParameterServerNodeTest extends BaseND4JTest { private static MediaDriver mediaDriver; private static Aeron aeron; private static ParameterServerNode parameterServerNode; diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-rocksdb-storage/pom.xml b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-rocksdb-storage/pom.xml index 6996672ca..619af0d7b 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-rocksdb-storage/pom.xml +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-rocksdb-storage/pom.xml @@ -46,6 +46,13 @@ junit test + + + org.nd4j + nd4j-common-tests + ${project.version} + test + diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-rocksdb-storage/src/test/java/org/nd4j/parameterserver/updater/storage/UpdaterStorageTests.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-rocksdb-storage/src/test/java/org/nd4j/parameterserver/updater/storage/UpdaterStorageTests.java index 3a7ac9ea3..170e1583f 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-rocksdb-storage/src/test/java/org/nd4j/parameterserver/updater/storage/UpdaterStorageTests.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-rocksdb-storage/src/test/java/org/nd4j/parameterserver/updater/storage/UpdaterStorageTests.java @@ -17,6 +17,7 @@ package org.nd4j.parameterserver.updater.storage; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.aeron.ipc.NDArrayMessage; import org.nd4j.linalg.factory.Nd4j; @@ -25,7 +26,7 @@ import static junit.framework.TestCase.assertEquals; /** * Created by agibsonccc on 12/2/16. */ -public class UpdaterStorageTests { +public class UpdaterStorageTests extends BaseND4JTest { @Test(timeout = 30000L) public void testInMemory() { diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-status/pom.xml b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-status/pom.xml index 62bb98c1c..4ed7c8b7b 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-status/pom.xml +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-status/pom.xml @@ -95,6 +95,13 @@ + + + org.nd4j + nd4j-common-tests + ${project.version} + test + diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-status/src/test/java/org/nd4j/parameterserver/status/play/StatusServerTests.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-status/src/test/java/org/nd4j/parameterserver/status/play/StatusServerTests.java index b821b2571..23ee67d78 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-status/src/test/java/org/nd4j/parameterserver/status/play/StatusServerTests.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-status/src/test/java/org/nd4j/parameterserver/status/play/StatusServerTests.java @@ -17,12 +17,13 @@ package org.nd4j.parameterserver.status.play; import org.junit.Test; +import org.nd4j.BaseND4JTest; import play.server.Server; /** * Created by agibsonccc on 12/1/16. */ -public class StatusServerTests { +public class StatusServerTests extends BaseND4JTest { @Test(timeout = 20000L) public void runStatusServer() { diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-status/src/test/java/org/nd4j/parameterserver/status/play/StorageTests.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-status/src/test/java/org/nd4j/parameterserver/status/play/StorageTests.java index cd9aab453..1ad1e478e 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-status/src/test/java/org/nd4j/parameterserver/status/play/StorageTests.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-status/src/test/java/org/nd4j/parameterserver/status/play/StorageTests.java @@ -17,6 +17,7 @@ package org.nd4j.parameterserver.status.play; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.parameterserver.model.SubscriberState; import static junit.framework.TestCase.assertEquals; @@ -25,7 +26,7 @@ import static org.junit.Assert.assertTrue; /** * Created by agibsonccc on 12/1/16. */ -public class StorageTests { +public class StorageTests extends BaseND4JTest { @Test(timeout = 20000L) public void testMapStorage() throws Exception { diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server/pom.xml b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server/pom.xml index af7316a37..4c60d83e2 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server/pom.xml +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server/pom.xml @@ -58,6 +58,13 @@ unirest-java ${unirest.version} + + + org.nd4j + nd4j-common-tests + ${project.version} + test + diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server/src/test/java/org/nd4j/parameterserver/updater/ParameterServerUpdaterTests.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server/src/test/java/org/nd4j/parameterserver/updater/ParameterServerUpdaterTests.java index f2d47e95b..58627405b 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server/src/test/java/org/nd4j/parameterserver/updater/ParameterServerUpdaterTests.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server/src/test/java/org/nd4j/parameterserver/updater/ParameterServerUpdaterTests.java @@ -17,6 +17,7 @@ package org.nd4j.parameterserver.updater; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.aeron.ipc.NDArrayMessage; import org.nd4j.aeron.ndarrayholder.InMemoryNDArrayHolder; import org.nd4j.linalg.factory.Nd4j; @@ -29,7 +30,7 @@ import static org.junit.Assume.assumeNotNull; /** * Created by agibsonccc on 12/2/16. */ -public class ParameterServerUpdaterTests { +public class ParameterServerUpdaterTests extends BaseND4JTest { @Test(timeout = 30000L) public void synchronousTest() { diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server/src/test/java/org/nd4j/parameterserver/updater/storage/UpdaterStorageTests.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server/src/test/java/org/nd4j/parameterserver/updater/storage/UpdaterStorageTests.java index 5186f1a4f..d666f7cb4 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server/src/test/java/org/nd4j/parameterserver/updater/storage/UpdaterStorageTests.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server/src/test/java/org/nd4j/parameterserver/updater/storage/UpdaterStorageTests.java @@ -17,6 +17,7 @@ package org.nd4j.parameterserver.updater.storage; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.aeron.ipc.NDArrayMessage; import org.nd4j.linalg.factory.Nd4j; @@ -25,7 +26,7 @@ import static junit.framework.TestCase.assertEquals; /** * Created by agibsonccc on 12/2/16. */ -public class UpdaterStorageTests { +public class UpdaterStorageTests extends BaseND4JTest { @Test(expected = UnsupportedOperationException.class) diff --git a/nd4j/nd4j-remote/nd4j-grpc-client/pom.xml b/nd4j/nd4j-remote/nd4j-grpc-client/pom.xml index 93d171535..88af3c166 100644 --- a/nd4j/nd4j-remote/nd4j-grpc-client/pom.xml +++ b/nd4j/nd4j-remote/nd4j-grpc-client/pom.xml @@ -83,6 +83,13 @@ ${logback.version} test + + + org.nd4j + nd4j-common-tests + ${project.version} + test + diff --git a/nd4j/nd4j-remote/nd4j-grpc-client/src/test/java/org/nd4j/graph/GraphInferenceGrpcClientTest.java b/nd4j/nd4j-remote/nd4j-grpc-client/src/test/java/org/nd4j/graph/GraphInferenceGrpcClientTest.java index 1c84da010..bbd6e8cf7 100644 --- a/nd4j/nd4j-remote/nd4j-grpc-client/src/test/java/org/nd4j/graph/GraphInferenceGrpcClientTest.java +++ b/nd4j/nd4j-remote/nd4j-grpc-client/src/test/java/org/nd4j/graph/GraphInferenceGrpcClientTest.java @@ -21,6 +21,7 @@ import lombok.val; import org.apache.commons.lang3.RandomUtils; import org.junit.Ignore; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.autodiff.execution.conf.ExecutorConfiguration; import org.nd4j.autodiff.execution.conf.OutputMode; import org.nd4j.autodiff.execution.input.Operands; @@ -32,7 +33,7 @@ import static org.junit.Assert.*; @Slf4j @Ignore -public class GraphInferenceGrpcClientTest { +public class GraphInferenceGrpcClientTest extends BaseND4JTest { @Test public void testSimpleGraph_1() throws Exception { val exp = Nd4j.create(new double[] {-0.95938617, -1.20301781, 1.22260064, 0.50172403, 0.59972949, 0.78568028, 0.31609724, 1.51674747, 0.68013491, -0.05227458, 0.25903158,1.13243439}, new long[]{3, 1, 4}); diff --git a/nd4j/nd4j-remote/nd4j-json-server/pom.xml b/nd4j/nd4j-remote/nd4j-json-server/pom.xml index 6c307f71c..d7729f179 100644 --- a/nd4j/nd4j-remote/nd4j-json-server/pom.xml +++ b/nd4j/nd4j-remote/nd4j-json-server/pom.xml @@ -130,6 +130,14 @@ ${gson.version} test + + + + org.nd4j + nd4j-common-tests + ${project.version} + test + diff --git a/nd4j/nd4j-remote/nd4j-json-server/src/test/java/org/nd4j/remote/SameDiffJsonModelServerTest.java b/nd4j/nd4j-remote/nd4j-json-server/src/test/java/org/nd4j/remote/SameDiffJsonModelServerTest.java index 290d1c4a1..1abbff6b3 100644 --- a/nd4j/nd4j-remote/nd4j-json-server/src/test/java/org/nd4j/remote/SameDiffJsonModelServerTest.java +++ b/nd4j/nd4j-remote/nd4j-json-server/src/test/java/org/nd4j/remote/SameDiffJsonModelServerTest.java @@ -19,6 +19,7 @@ package org.nd4j.remote; import lombok.extern.slf4j.Slf4j; import lombok.val; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.remote.clients.JsonRemoteInference; @@ -35,7 +36,7 @@ import java.util.concurrent.Future; import static org.junit.Assert.*; @Slf4j -public class SameDiffJsonModelServerTest { +public class SameDiffJsonModelServerTest extends BaseND4JTest { @Test public void basicServingTest_1() throws Exception { diff --git a/nd4j/nd4j-remote/nd4j-json-server/src/test/java/org/nd4j/remote/SameDiffServletTest.java b/nd4j/nd4j-remote/nd4j-json-server/src/test/java/org/nd4j/remote/SameDiffServletTest.java index a26809efc..9a7b94272 100644 --- a/nd4j/nd4j-remote/nd4j-json-server/src/test/java/org/nd4j/remote/SameDiffServletTest.java +++ b/nd4j/nd4j-remote/nd4j-json-server/src/test/java/org/nd4j/remote/SameDiffServletTest.java @@ -7,6 +7,7 @@ import org.apache.http.impl.client.HttpClientBuilder; import org.junit.After; import org.junit.Before; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.MultiDataSet; @@ -18,7 +19,7 @@ import java.io.IOException; import static org.junit.Assert.assertEquals; -public class SameDiffServletTest { +public class SameDiffServletTest extends BaseND4JTest { private SameDiffJsonModelServer server; diff --git a/nd4j/nd4j-remote/nd4j-json-server/src/test/java/org/nd4j/remote/serde/BasicSerdeTests.java b/nd4j/nd4j-remote/nd4j-json-server/src/test/java/org/nd4j/remote/serde/BasicSerdeTests.java index 3aca94b8a..8332918b6 100644 --- a/nd4j/nd4j-remote/nd4j-json-server/src/test/java/org/nd4j/remote/serde/BasicSerdeTests.java +++ b/nd4j/nd4j-remote/nd4j-json-server/src/test/java/org/nd4j/remote/serde/BasicSerdeTests.java @@ -2,12 +2,13 @@ package org.nd4j.remote.serde; import lombok.val; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.remote.clients.serde.impl.*; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; -public class BasicSerdeTests { +public class BasicSerdeTests extends BaseND4JTest { private final static DoubleArraySerde doubleArraySerde = new DoubleArraySerde(); private final static FloatArraySerde floatArraySerde = new FloatArraySerde(); private final static StringSerde stringSerde = new StringSerde(); diff --git a/nd4j/nd4j-serde/nd4j-aeron/pom.xml b/nd4j/nd4j-serde/nd4j-aeron/pom.xml index fcf74c518..87e9347dd 100644 --- a/nd4j/nd4j-serde/nd4j-aeron/pom.xml +++ b/nd4j/nd4j-serde/nd4j-aeron/pom.xml @@ -49,15 +49,105 @@ testresources - + + + + nd4j-tests-cpu + + false + org.nd4j nd4j-native ${project.version} - test + + + + org.apache.maven.plugins + maven-surefire-plugin + + + ${env.LD_LIBRARY_PATH}:${user.dir}:${libnd4jhome}/blasbuild/cpu/blas/ + + src/test/java + + *.java + **/*.java + **/Test*.java + **/*Test.java + **/*TestCase.java + + junit:junit + + org.nd4j.linalg.cpu.nativecpu.CpuBackend + org.nd4j.linalg.cpu.nativecpu.CpuBackend + + + -Ddtype=float -Xmx8g + + + + + + + + nd4j-tests-cuda + + false + + + + org.nd4j + nd4j-cuda-10.2 + ${project.version} + + + + + + org.apache.maven.plugins + maven-surefire-plugin + + + org.apache.maven.surefire + surefire-junit47 + 2.19.1 + + + + + ${env.LD_LIBRARY_PATH}:${user.dir}:${libnd4jhome}/blasbuild/cuda/blas/ + + src/test/java + + *.java + **/*.java + **/Test*.java + **/*Test.java + **/*TestCase.java + + junit:junit + + org.nd4j.linalg.jcublas.JCublasBackend + org.nd4j.linalg.jcublas.JCublasBackend + + + -Ddtype=float -Xmx6g + + + + @@ -78,5 +168,27 @@ junit test + + + ch.qos.logback + logback-classic + ${logback.version} + test + + + + ch.qos.logback + logback-core + ${logback.version} + test + + + + + org.nd4j + nd4j-common-tests + ${project.version} + test + diff --git a/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/AeronNDArraySerdeTest.java b/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/AeronNDArraySerdeTest.java index e05a5b6f9..78145eb8a 100644 --- a/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/AeronNDArraySerdeTest.java +++ b/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/AeronNDArraySerdeTest.java @@ -18,7 +18,9 @@ package org.nd4j.aeron.ipc; import org.agrona.concurrent.UnsafeBuffer; import org.apache.commons.lang3.time.StopWatch; +import org.junit.Ignore; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -32,7 +34,7 @@ import static org.junit.Assert.assertTrue; /** * Created by agibsonccc on 9/23/16. */ -public class AeronNDArraySerdeTest { +public class AeronNDArraySerdeTest extends BaseND4JTest { @Test public void testToAndFrom() { @@ -56,7 +58,10 @@ public class AeronNDArraySerdeTest { @Test + @Ignore // timeout, skip step ignored public void testToAndFromCompressedLarge() { + skipUnlessIntegrationTests(); + INDArray arr = Nd4j.zeros((int) 1e7); INDArray compress = Nd4j.getCompressor().compress(arr, "GZIP"); assertTrue(compress.isCompressed()); diff --git a/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/LargeNdArrayIpcTest.java b/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/LargeNdArrayIpcTest.java index 55e4e13ff..bd3ac9b9c 100644 --- a/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/LargeNdArrayIpcTest.java +++ b/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/LargeNdArrayIpcTest.java @@ -23,6 +23,7 @@ import org.agrona.CloseHelper; import org.junit.After; import org.junit.Before; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -34,29 +35,40 @@ import static org.junit.Assert.assertFalse; * Created by agibsonccc on 9/22/16. */ @Slf4j -public class LargeNdArrayIpcTest { +public class LargeNdArrayIpcTest extends BaseND4JTest { private MediaDriver mediaDriver; private Aeron.Context ctx; private String channel = "aeron:udp?endpoint=localhost:" + (40123 + new java.util.Random().nextInt(130)); private int streamId = 10; private int length = (int) 1e7; + @Override + public long getTimeoutMilliseconds() { + return 180000L; + } + @Before public void before() { - //MediaDriver.loadPropertiesFile("aeron.properties"); - MediaDriver.Context ctx = AeronUtil.getMediaDriverContext(length); - mediaDriver = MediaDriver.launchEmbedded(ctx); - System.out.println("Using media driver directory " + mediaDriver.aeronDirectoryName()); - System.out.println("Launched media driver"); + if(isIntegrationTests()) { + //MediaDriver.loadPropertiesFile("aeron.properties"); + MediaDriver.Context ctx = AeronUtil.getMediaDriverContext(length); + mediaDriver = MediaDriver.launchEmbedded(ctx); + System.out.println("Using media driver directory " + mediaDriver.aeronDirectoryName()); + System.out.println("Launched media driver"); + } } @After public void after() { - CloseHelper.quietClose(mediaDriver); + if(isIntegrationTests()) { + CloseHelper.quietClose(mediaDriver); + } } @Test public void testMultiThreadedIpcBig() throws Exception { + skipUnlessIntegrationTests(); //Long-running test - don't run as part of unit tests by default + int length = (int) 1e7; INDArray arr = Nd4j.ones(length); AeronNDArrayPublisher publisher; diff --git a/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/NDArrayMessageTest.java b/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/NDArrayMessageTest.java index 0cddea16b..be0572c93 100644 --- a/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/NDArrayMessageTest.java +++ b/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/NDArrayMessageTest.java @@ -18,6 +18,7 @@ package org.nd4j.aeron.ipc; import org.agrona.DirectBuffer; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -26,7 +27,7 @@ import static org.junit.Assert.assertEquals; /** * Created by agibsonccc on 11/6/16. */ -public class NDArrayMessageTest { +public class NDArrayMessageTest extends BaseND4JTest { @Test public void testNDArrayMessageToAndFrom() { diff --git a/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/NdArrayIpcTest.java b/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/NdArrayIpcTest.java index 0412b1bed..d1a95ffde 100644 --- a/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/NdArrayIpcTest.java +++ b/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/NdArrayIpcTest.java @@ -22,6 +22,7 @@ import org.agrona.CloseHelper; import org.junit.After; import org.junit.Before; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.slf4j.Logger; @@ -36,7 +37,7 @@ import static org.junit.Assert.assertFalse; /** * Created by agibsonccc on 9/22/16. */ -public class NdArrayIpcTest { +public class NdArrayIpcTest extends BaseND4JTest { private MediaDriver mediaDriver; private static Logger log = LoggerFactory.getLogger(NdArrayIpcTest.class); private Aeron.Context ctx; @@ -44,21 +45,32 @@ public class NdArrayIpcTest { private int streamId = 10; private int length = (int) 1e7; + @Override + public long getTimeoutMilliseconds() { + return 120000L; + } + @Before public void before() { - MediaDriver.Context ctx = AeronUtil.getMediaDriverContext(length); - mediaDriver = MediaDriver.launchEmbedded(ctx); - System.out.println("Using media driver directory " + mediaDriver.aeronDirectoryName()); - System.out.println("Launched media driver"); + if(isIntegrationTests()) { + MediaDriver.Context ctx = AeronUtil.getMediaDriverContext(length); + mediaDriver = MediaDriver.launchEmbedded(ctx); + System.out.println("Using media driver directory " + mediaDriver.aeronDirectoryName()); + System.out.println("Launched media driver"); + } } @After public void after() { - CloseHelper.quietClose(mediaDriver); + if(isIntegrationTests()) { + CloseHelper.quietClose(mediaDriver); + } } @Test public void testMultiThreadedIpc() throws Exception { + skipUnlessIntegrationTests(); //Long-running test - don't run as part of unit tests by default + ExecutorService executorService = Executors.newFixedThreadPool(4); INDArray arr = Nd4j.scalar(1.0); @@ -139,6 +151,8 @@ public class NdArrayIpcTest { @Test public void testIpc() throws Exception { + skipUnlessIntegrationTests(); //Long-running test - don't run as part of unit tests by default + INDArray arr = Nd4j.scalar(1.0); diff --git a/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/chunk/ChunkAccumulatorTests.java b/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/chunk/ChunkAccumulatorTests.java index 8890d3836..d856b8f95 100644 --- a/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/chunk/ChunkAccumulatorTests.java +++ b/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/chunk/ChunkAccumulatorTests.java @@ -17,6 +17,7 @@ package org.nd4j.aeron.ipc.chunk; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.aeron.ipc.NDArrayMessage; import org.nd4j.linalg.factory.Nd4j; @@ -25,7 +26,7 @@ import static org.junit.Assert.assertEquals; /** * Created by agibsonccc on 11/20/16. */ -public class ChunkAccumulatorTests { +public class ChunkAccumulatorTests extends BaseND4JTest { @Test public void testAccumulator() { diff --git a/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/chunk/NDArrayMessageChunkTests.java b/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/chunk/NDArrayMessageChunkTests.java index 3aeb37ee0..314ed0243 100644 --- a/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/chunk/NDArrayMessageChunkTests.java +++ b/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/chunk/NDArrayMessageChunkTests.java @@ -18,6 +18,7 @@ package org.nd4j.aeron.ipc.chunk; import org.agrona.DirectBuffer; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.aeron.ipc.NDArrayMessage; import org.nd4j.aeron.util.BufferUtil; import org.nd4j.linalg.factory.Nd4j; @@ -30,7 +31,7 @@ import static org.junit.Assert.assertEquals; /** * Created by agibsonccc on 11/20/16. */ -public class NDArrayMessageChunkTests { +public class NDArrayMessageChunkTests extends BaseND4JTest { @Test public void testChunkSerialization() { diff --git a/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/response/AeronNDArrayResponseTest.java b/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/response/AeronNDArrayResponseTest.java index f843f3839..b1df7d193 100644 --- a/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/response/AeronNDArrayResponseTest.java +++ b/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/response/AeronNDArrayResponseTest.java @@ -24,6 +24,7 @@ import org.agrona.CloseHelper; import org.agrona.concurrent.BusySpinIdleStrategy; import org.junit.Before; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.aeron.ipc.*; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -37,24 +38,33 @@ import static org.junit.Assert.assertEquals; * Created by agibsonccc on 10/3/16. */ @Slf4j -public class AeronNDArrayResponseTest { +public class AeronNDArrayResponseTest extends BaseND4JTest { private MediaDriver mediaDriver; + @Override + public long getTimeoutMilliseconds() { + return 180000L; + } + @Before public void before() { - final MediaDriver.Context ctx = - new MediaDriver.Context().threadingMode(ThreadingMode.SHARED).dirsDeleteOnStart(true) - .termBufferSparseFile(false).conductorIdleStrategy(new BusySpinIdleStrategy()) - .receiverIdleStrategy(new BusySpinIdleStrategy()) - .senderIdleStrategy(new BusySpinIdleStrategy()); - mediaDriver = MediaDriver.launchEmbedded(ctx); - System.out.println("Using media driver directory " + mediaDriver.aeronDirectoryName()); - System.out.println("Launched media driver"); + if(isIntegrationTests()) { + final MediaDriver.Context ctx = + new MediaDriver.Context().threadingMode(ThreadingMode.SHARED).dirsDeleteOnStart(true) + .termBufferSparseFile(false).conductorIdleStrategy(new BusySpinIdleStrategy()) + .receiverIdleStrategy(new BusySpinIdleStrategy()) + .senderIdleStrategy(new BusySpinIdleStrategy()); + mediaDriver = MediaDriver.launchEmbedded(ctx); + System.out.println("Using media driver directory " + mediaDriver.aeronDirectoryName()); + System.out.println("Launched media driver"); + } } @Test public void testResponse() throws Exception { + skipUnlessIntegrationTests(); //Long-running test - don't run as part of unit tests by default + int streamId = 10; int responderStreamId = 11; String host = "127.0.0.1"; diff --git a/nd4j/nd4j-serde/nd4j-arrow/pom.xml b/nd4j/nd4j-serde/nd4j-arrow/pom.xml index f16583745..ddadc2df1 100644 --- a/nd4j/nd4j-serde/nd4j-arrow/pom.xml +++ b/nd4j/nd4j-serde/nd4j-arrow/pom.xml @@ -54,20 +54,118 @@ arrow-format ${arrow.version} + + + + org.nd4j + nd4j-common-tests + ${project.version} + test + testresources - + + + + nd4j-tests-cpu + + false + org.nd4j nd4j-native ${project.version} - test + + + + org.apache.maven.plugins + maven-surefire-plugin + + + ${env.LD_LIBRARY_PATH}:${user.dir}:${libnd4jhome}/blasbuild/cpu/blas/ + + src/test/java + + *.java + **/*.java + **/Test*.java + **/*Test.java + **/*TestCase.java + + junit:junit + + org.nd4j.linalg.cpu.nativecpu.CpuBackend + org.nd4j.linalg.cpu.nativecpu.CpuBackend + + + -Ddtype=float -Xmx8g + + + + + + + + nd4j-tests-cuda + + false + + + + org.nd4j + nd4j-cuda-10.2 + ${project.version} + + + + + + org.apache.maven.plugins + maven-surefire-plugin + + + org.apache.maven.surefire + surefire-junit47 + 2.19.1 + + + + + ${env.LD_LIBRARY_PATH}:${user.dir}:${libnd4jhome}/blasbuild/cuda/blas/ + + src/test/java + + *.java + **/*.java + **/Test*.java + **/*Test.java + **/*TestCase.java + + junit:junit + + org.nd4j.linalg.jcublas.JCublasBackend + org.nd4j.linalg.jcublas.JCublasBackend + + + -Ddtype=float -Xmx6g + + + + diff --git a/nd4j/nd4j-serde/nd4j-arrow/src/test/java/org/nd4j/arrow/ArrowSerdeTest.java b/nd4j/nd4j-serde/nd4j-arrow/src/test/java/org/nd4j/arrow/ArrowSerdeTest.java index 6c3328ea8..eee1521bd 100644 --- a/nd4j/nd4j-serde/nd4j-arrow/src/test/java/org/nd4j/arrow/ArrowSerdeTest.java +++ b/nd4j/nd4j-serde/nd4j-arrow/src/test/java/org/nd4j/arrow/ArrowSerdeTest.java @@ -18,12 +18,13 @@ package org.nd4j.arrow; import org.apache.arrow.flatbuf.Tensor; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import static org.junit.Assert.assertEquals; -public class ArrowSerdeTest { +public class ArrowSerdeTest extends BaseND4JTest { @Test public void testBackAndForth() { diff --git a/nd4j/nd4j-serde/nd4j-camel-routes/nd4j-kafka/pom.xml b/nd4j/nd4j-serde/nd4j-camel-routes/nd4j-kafka/pom.xml index 8ccd6b574..d7f3d0887 100644 --- a/nd4j/nd4j-serde/nd4j-camel-routes/nd4j-kafka/pom.xml +++ b/nd4j/nd4j-serde/nd4j-camel-routes/nd4j-kafka/pom.xml @@ -76,6 +76,14 @@ + + + + org.nd4j + nd4j-common-tests + ${project.version} + test + diff --git a/nd4j/nd4j-serde/nd4j-camel-routes/nd4j-kafka/src/main/java/org/nd4j/camel/kafka/KafkaConnectionInformation.java b/nd4j/nd4j-serde/nd4j-camel-routes/nd4j-kafka/src/main/java/org/nd4j/camel/kafka/KafkaConnectionInformation.java deleted file mode 100644 index 9f30cb549..000000000 --- a/nd4j/nd4j-serde/nd4j-camel-routes/nd4j-kafka/src/main/java/org/nd4j/camel/kafka/KafkaConnectionInformation.java +++ /dev/null @@ -1,51 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.camel.kafka; - -import kafka.serializer.StringEncoder; -import lombok.Builder; -import lombok.Data; - -import java.io.Serializable; - -/** - * Kafka connection information - * to generate camel uris - * - * @author Adam Gibson - */ -@Builder -@Data -public class KafkaConnectionInformation implements Serializable { - private String zookeeperHost; - private int zookeeperPort; - private String kafkaBrokerList; - private String topicName; - private String groupId; - - /** - * Returns a kafka connection uri - * @return a kafka connection uri - * represented by this connection information - */ - public String kafkaUri() { - return String.format( - "kafka://%s?topic=%s&groupId=%s&zookeeperHost=%s&zookeeperPort=%d&serializerClass=%s&keySerializerClass=%s", - kafkaBrokerList, topicName, groupId, zookeeperHost, zookeeperPort, - StringEncoder.class.getName(), StringEncoder.class.getName()); - } -} diff --git a/nd4j/nd4j-serde/nd4j-camel-routes/nd4j-kafka/src/main/java/org/nd4j/camel/kafka/Nd4jKafkaProducer.java b/nd4j/nd4j-serde/nd4j-camel-routes/nd4j-kafka/src/main/java/org/nd4j/camel/kafka/Nd4jKafkaProducer.java deleted file mode 100644 index 4120abacb..000000000 --- a/nd4j/nd4j-serde/nd4j-camel-routes/nd4j-kafka/src/main/java/org/nd4j/camel/kafka/Nd4jKafkaProducer.java +++ /dev/null @@ -1,48 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.camel.kafka; - -import lombok.AllArgsConstructor; -import lombok.Builder; -import org.apache.camel.CamelContext; -import org.apache.camel.ProducerTemplate; -import org.nd4j.linalg.api.ndarray.INDArray; - -/** - * Created by agibsonccc on 7/19/16. - */ -@AllArgsConstructor -@Builder -public class Nd4jKafkaProducer { - - private KafkaConnectionInformation connectionInformation; - private CamelContext camelContext; - private ProducerTemplate producerTemplate; - - /** - * Publish to a kafka topic - * based on the connection information - * @param arr - */ - public void publish(INDArray arr) { - if (producerTemplate == null) - producerTemplate = camelContext.createProducerTemplate(); - producerTemplate.sendBody("direct:start", arr); - } - - -} diff --git a/nd4j/nd4j-serde/nd4j-camel-routes/nd4j-kafka/src/main/java/org/nd4j/camel/kafka/Nd4jKafkaRoute.java b/nd4j/nd4j-serde/nd4j-camel-routes/nd4j-kafka/src/main/java/org/nd4j/camel/kafka/Nd4jKafkaRoute.java deleted file mode 100644 index 909db800d..000000000 --- a/nd4j/nd4j-serde/nd4j-camel-routes/nd4j-kafka/src/main/java/org/nd4j/camel/kafka/Nd4jKafkaRoute.java +++ /dev/null @@ -1,74 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.camel.kafka; - -import lombok.AllArgsConstructor; -import lombok.Builder; -import org.apache.camel.Exchange; -import org.apache.camel.Processor; -import org.apache.camel.builder.RouteBuilder; -import org.apache.camel.component.kafka.KafkaConstants; -import org.apache.commons.net.util.Base64; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.serde.base64.Nd4jBase64; - -import java.io.ByteArrayOutputStream; -import java.io.DataOutputStream; -import java.util.UUID; - -/** - * Sends a test ndarray - * to kafka - * - * @author Adam Gibson - */ -@AllArgsConstructor -@Builder -public class Nd4jKafkaRoute extends RouteBuilder { - private KafkaConnectionInformation kafkaConnectionInformation; - - @Override - public void configure() throws Exception { - final String kafkaUri = kafkaConnectionInformation.kafkaUri(); - from("direct:start").process(new Processor() { - @Override - public void process(Exchange exchange) throws Exception { - final INDArray arr = (INDArray) exchange.getIn().getBody(); - ByteArrayOutputStream bos = new ByteArrayOutputStream(); - DataOutputStream dos = new DataOutputStream(bos); - Nd4j.write(arr, dos); - byte[] bytes = bos.toByteArray(); - String base64 = Base64.encodeBase64String(bytes); - exchange.getIn().setBody(base64, String.class); - String id = UUID.randomUUID().toString(); - exchange.getIn().setHeader(KafkaConstants.KEY, id); - exchange.getIn().setHeader(KafkaConstants.PARTITION_KEY, id); - } - }).to(kafkaUri); - - from(kafkaUri).process(new Processor() { - @Override - public void process(Exchange exchange) throws Exception { - byte[] body2 = (byte[]) exchange.getIn().getBody(); - String body = new String(body2); - INDArray arr = Nd4jBase64.fromBase64(body); - exchange.getIn().setBody(arr); - } - }).to("direct:receive"); - } -} diff --git a/nd4j/nd4j-serde/nd4j-camel-routes/nd4j-kafka/src/main/test/java/org/nd4j/kafka/EmbeddedKafkaCluster.java b/nd4j/nd4j-serde/nd4j-camel-routes/nd4j-kafka/src/main/test/java/org/nd4j/kafka/EmbeddedKafkaCluster.java deleted file mode 100644 index 3e01b1e6c..000000000 --- a/nd4j/nd4j-serde/nd4j-camel-routes/nd4j-kafka/src/main/test/java/org/nd4j/kafka/EmbeddedKafkaCluster.java +++ /dev/null @@ -1,177 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.kafka; - -import kafka.admin.AdminUtils; -import kafka.server.KafkaConfig; -import kafka.server.KafkaServer; -import org.I0Itec.zkclient.ZkClient; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.io.File; -import java.io.FileNotFoundException; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; -import java.util.Properties; - -public class EmbeddedKafkaCluster { - private static final Logger LOG = LoggerFactory.getLogger(EmbeddedKafkaCluster.class); - - private final List ports; - private final String zkConnection; - private final Properties baseProperties; - - private final String brokerList; - - private final List brokers; - private final List logDirs; - - public EmbeddedKafkaCluster(String zkConnection) { - this(zkConnection, new Properties()); - } - - public EmbeddedKafkaCluster(String zkConnection, Properties baseProperties) { - this(zkConnection, baseProperties, Collections.singletonList(-1)); - } - - public EmbeddedKafkaCluster(String zkConnection, Properties baseProperties, List ports) { - this.zkConnection = zkConnection; - this.ports = resolvePorts(ports); - this.baseProperties = baseProperties; - this.brokers = new ArrayList(); - this.logDirs = new ArrayList(); - - this.brokerList = constructBrokerList(this.ports); - } - - public ZkClient getZkClient() { - for (KafkaServer server : brokers) { - return server.zkClient(); - } - return null; - } - - public void createTopics(String... topics) { - for (String topic : topics) { - AdminUtils.createTopic(getZkClient(), topic, 2, 1, new Properties()); - } - } - - private List resolvePorts(List ports) { - List resolvedPorts = new ArrayList(); - for (Integer port : ports) { - resolvedPorts.add(resolvePort(port)); - } - return resolvedPorts; - } - - private int resolvePort(int port) { - if (port == -1) { - return TestUtils.getAvailablePort(); - } - return port; - } - - private String constructBrokerList(List ports) { - StringBuilder sb = new StringBuilder(); - for (Integer port : ports) { - if (sb.length() > 0) { - sb.append(","); - } - sb.append("localhost:").append(port); - } - return sb.toString(); - } - - public void startup() { - for (int i = 0; i < ports.size(); i++) { - Integer port = ports.get(i); - File logDir = TestUtils.constructTempDir("kafka-local"); - - Properties properties = new Properties(); - properties.putAll(baseProperties); - properties.setProperty("zookeeper.connect", zkConnection); - properties.setProperty("broker.id", String.valueOf(i + 1)); - properties.setProperty("host.opName", "localhost"); - properties.setProperty("port", Integer.toString(port)); - properties.setProperty("log.dir", logDir.getAbsolutePath()); - properties.setProperty("num.partitions", String.valueOf(1)); - properties.setProperty("auto.create.topics.enable", String.valueOf(Boolean.TRUE)); - properties.setProperty("log.flush.interval.messages", String.valueOf(1)); - LOG.info("EmbeddedKafkaCluster: local directory: " + logDir.getAbsolutePath()); - - KafkaServer broker = startBroker(properties); - - brokers.add(broker); - logDirs.add(logDir); - } - } - - - private KafkaServer startBroker(Properties props) { - KafkaServer server = new KafkaServer(new KafkaConfig(props), new SystemTime()); - server.startup(); - return server; - } - - public Properties getProps() { - Properties props = new Properties(); - props.putAll(baseProperties); - props.put("metadata.broker.list", brokerList); - props.put("zookeeper.connect", zkConnection); - return props; - } - - public String getBrokerList() { - return brokerList; - } - - public List getPorts() { - return ports; - } - - public String getZkConnection() { - return zkConnection; - } - - public void shutdown() { - for (KafkaServer broker : brokers) { - try { - broker.shutdown(); - } catch (Exception e) { - e.printStackTrace(); - } - } - for (File logDir : logDirs) { - try { - TestUtils.deleteFile(logDir); - } catch (FileNotFoundException e) { - e.printStackTrace(); - } - } - } - - @Override - public String toString() { - final StringBuilder sb = new StringBuilder("EmbeddedKafkaCluster{"); - sb.append("brokerList='").append(brokerList).append('\''); - sb.append('}'); - return sb.toString(); - } -} diff --git a/nd4j/nd4j-serde/nd4j-camel-routes/nd4j-kafka/src/main/test/java/org/nd4j/kafka/EmbeddedZookeeper.java b/nd4j/nd4j-serde/nd4j-camel-routes/nd4j-kafka/src/main/test/java/org/nd4j/kafka/EmbeddedZookeeper.java deleted file mode 100644 index a48b37e5f..000000000 --- a/nd4j/nd4j-serde/nd4j-camel-routes/nd4j-kafka/src/main/test/java/org/nd4j/kafka/EmbeddedZookeeper.java +++ /dev/null @@ -1,112 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.kafka; - -import org.apache.zookeeper.server.ServerCnxnFactory; -import org.apache.zookeeper.server.ZooKeeperServer; - -import java.io.File; -import java.io.FileNotFoundException; -import java.io.IOException; -import java.net.InetSocketAddress; - -public class EmbeddedZookeeper { - private int port = -1; - private int tickTime = 500; - - private ServerCnxnFactory factory; - private File snapshotDir; - private File logDir; - - public EmbeddedZookeeper() { - this(-1); - } - - public EmbeddedZookeeper(int port) { - this(port, 500); - } - - public EmbeddedZookeeper(int port, int tickTime) { - this.port = resolvePort(port); - this.tickTime = tickTime; - } - - private int resolvePort(int port) { - if (port == -1) { - return TestUtils.getAvailablePort(); - } - return port; - } - - public void startup() throws IOException { - if (this.port == -1) { - this.port = TestUtils.getAvailablePort(); - } - this.factory = ServerCnxnFactory.createFactory(new InetSocketAddress("localhost", port), 1024); - this.snapshotDir = TestUtils.constructTempDir("embeeded-zk/snapshot"); - this.logDir = TestUtils.constructTempDir("embeeded-zk/log"); - - try { - factory.startup(new ZooKeeperServer(snapshotDir, logDir, tickTime)); - } catch (InterruptedException e) { - throw new IOException(e); - } - } - - - public void shutdown() { - factory.shutdown(); - try { - TestUtils.deleteFile(snapshotDir); - } catch (FileNotFoundException e) { - // ignore - } - try { - TestUtils.deleteFile(logDir); - } catch (FileNotFoundException e) { - // ignore - } - } - - public String getConnection() { - return "localhost:" + port; - } - - public void setPort(int port) { - this.port = port; - } - - public void setTickTime(int tickTime) { - this.tickTime = tickTime; - } - - public int getPort() { - return port; - } - - public int getTickTime() { - return tickTime; - } - - @Override - public String toString() { - final StringBuilder sb = new StringBuilder("EmbeddedZookeeper{"); - sb.append("connection=").append(getConnection()); - sb.append('}'); - return sb.toString(); - } -} diff --git a/nd4j/nd4j-serde/nd4j-camel-routes/nd4j-kafka/src/main/test/java/org/nd4j/kafka/Nd4jKafkaRouteTest.java b/nd4j/nd4j-serde/nd4j-camel-routes/nd4j-kafka/src/main/test/java/org/nd4j/kafka/Nd4jKafkaRouteTest.java index 14c6d5875..3b40661bb 100644 --- a/nd4j/nd4j-serde/nd4j-camel-routes/nd4j-kafka/src/main/test/java/org/nd4j/kafka/Nd4jKafkaRouteTest.java +++ b/nd4j/nd4j-serde/nd4j-camel-routes/nd4j-kafka/src/main/test/java/org/nd4j/kafka/Nd4jKafkaRouteTest.java @@ -21,6 +21,7 @@ import org.apache.camel.impl.DefaultCamelContext; import org.junit.After; import org.junit.Before; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.camel.kafka.KafkaConnectionInformation; import org.nd4j.camel.kafka.Nd4jKafkaConsumer; import org.nd4j.camel.kafka.Nd4jKafkaProducer; @@ -32,7 +33,7 @@ import static org.junit.Assert.assertEquals; /** * Created by agibsonccc on 7/19/16. */ -public class Nd4jKafkaRouteTest { +public class Nd4jKafkaRouteTest extends BaseND4JTest { private EmbeddedKafkaCluster kafka; private EmbeddedZookeeper zk; private CamelContext camelContext; diff --git a/nd4j/nd4j-serde/nd4j-camel-routes/nd4j-kafka/src/main/test/java/org/nd4j/kafka/TestUtils.java b/nd4j/nd4j-serde/nd4j-camel-routes/nd4j-kafka/src/main/test/java/org/nd4j/kafka/TestUtils.java deleted file mode 100644 index a93f18adc..000000000 --- a/nd4j/nd4j-serde/nd4j-camel-routes/nd4j-kafka/src/main/test/java/org/nd4j/kafka/TestUtils.java +++ /dev/null @@ -1,64 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.kafka; - -import java.io.File; -import java.io.FileNotFoundException; -import java.io.IOException; -import java.net.ServerSocket; -import java.util.Random; - -public class TestUtils { - private static final Random RANDOM = new Random(); - - private TestUtils() {} - - public static File constructTempDir(String dirPrefix) { - File file = new File(System.getProperty("java.io.tmpdir"), dirPrefix + RANDOM.nextInt(10000000)); - if (!file.mkdirs()) { - throw new RuntimeException("could not create temp directory: " + file.getAbsolutePath()); - } - file.deleteOnExit(); - return file; - } - - public static int getAvailablePort() { - try { - ServerSocket socket = new ServerSocket(0); - try { - return socket.getLocalPort(); - } finally { - socket.close(); - } - } catch (IOException e) { - throw new IllegalStateException("Cannot find available port: " + e.getMessage(), e); - } - } - - public static boolean deleteFile(File path) throws FileNotFoundException { - if (!path.exists()) { - throw new FileNotFoundException(path.getAbsolutePath()); - } - boolean ret = true; - if (path.isDirectory()) { - for (File f : path.listFiles()) { - ret = ret && deleteFile(f); - } - } - return ret && path.delete(); - } -} diff --git a/nd4j/nd4j-serde/nd4j-camel-routes/pom.xml b/nd4j/nd4j-serde/nd4j-camel-routes/pom.xml deleted file mode 100644 index 94be439c3..000000000 --- a/nd4j/nd4j-serde/nd4j-camel-routes/pom.xml +++ /dev/null @@ -1,40 +0,0 @@ - - - - - nd4j-serde - org.nd4j - 1.0.0-SNAPSHOT - - 4.0.0 - - nd4j-camel-routes - pom - - nd4j-camel-routes - https://deeplearning4j.org - - nd4j-kafka - - - - - testresources - - - - diff --git a/nd4j/nd4j-serde/nd4j-gson/pom.xml b/nd4j/nd4j-serde/nd4j-gson/pom.xml index 82770d51a..f7215436a 100644 --- a/nd4j/nd4j-serde/nd4j-gson/pom.xml +++ b/nd4j/nd4j-serde/nd4j-gson/pom.xml @@ -45,20 +45,118 @@ junit test + + + + org.nd4j + nd4j-common-tests + ${project.version} + test + testresources - + + + + nd4j-tests-cpu + + false + org.nd4j nd4j-native ${project.version} - test + + + + org.apache.maven.plugins + maven-surefire-plugin + + + ${env.LD_LIBRARY_PATH}:${user.dir}:${libnd4jhome}/blasbuild/cpu/blas/ + + src/test/java + + *.java + **/*.java + **/Test*.java + **/*Test.java + **/*TestCase.java + + junit:junit + + org.nd4j.linalg.cpu.nativecpu.CpuBackend + org.nd4j.linalg.cpu.nativecpu.CpuBackend + + + -Ddtype=float -Xmx8g + + + + + + + + nd4j-tests-cuda + + false + + + + org.nd4j + nd4j-cuda-10.2 + ${project.version} + + + + + + org.apache.maven.plugins + maven-surefire-plugin + + + org.apache.maven.surefire + surefire-junit47 + 2.19.1 + + + + + ${env.LD_LIBRARY_PATH}:${user.dir}:${libnd4jhome}/blasbuild/cuda/blas/ + + src/test/java + + *.java + **/*.java + **/Test*.java + **/*Test.java + **/*TestCase.java + + junit:junit + + org.nd4j.linalg.jcublas.JCublasBackend + org.nd4j.linalg.jcublas.JCublasBackend + + + -Ddtype=float -Xmx6g + + + + diff --git a/nd4j/nd4j-serde/nd4j-gson/src/main/java/org/nd4j/serde/gson/GsonDeserializationUtils.java b/nd4j/nd4j-serde/nd4j-gson/src/main/java/org/nd4j/serde/gson/GsonDeserializationUtils.java deleted file mode 100644 index 5b087a055..000000000 --- a/nd4j/nd4j-serde/nd4j-gson/src/main/java/org/nd4j/serde/gson/GsonDeserializationUtils.java +++ /dev/null @@ -1,103 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.serde.gson; - -import org.nd4j.shade.guava.primitives.Ints; -import org.nd4j.shade.guava.primitives.Longs; -import com.google.gson.JsonArray; -import com.google.gson.JsonElement; -import com.google.gson.JsonParser; -import org.apache.commons.lang3.StringUtils; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.factory.Nd4j; - -import java.text.NumberFormat; -import java.util.ArrayList; -import java.util.Iterator; -import java.util.List; - -/** - * Gson serialization - * - * @author Alex Black - * @author Adam Gibson - */ -public class GsonDeserializationUtils { - private static final JsonParser JSON_PARSER = new JsonParser(); - - static { - NumberFormat format = NumberFormat.getIntegerInstance(); - format.setGroupingUsed(false); - } - - /** - * Deserialize an ndarray - * form json - * @param serializedRawArray - * @return - */ - public static INDArray deserializeRawJson(String serializedRawArray) { - - //String cleanedRawArray = serializedRawArray.replaceAll("(?<=[\\d])(,)(?=[\\d])", ""); - String cleanedRawArray = serializedRawArray; - JsonArray jsonArray = JSON_PARSER.parse(cleanedRawArray).getAsJsonArray(); - - List dimensions = new ArrayList<>(); - dimensions.add(jsonArray.size()); - getSizeMultiDimensionalArray(jsonArray, dimensions); - - return buildArray(dimensions, cleanedRawArray); - } - - /* - The below method works under the following assumption - which is an INDArray can not have a row such as [ 1 , 2, [3, 4] ] - and either all elements of an INDArray are either INDArrays themselves or scalars. - So if that is the case, then it suffices to only check the first element of each JsonArray - to see if that first element is itself an JsonArray. If it is an array, then we must check - the first element of that array to see if it's a scalar or array. - */ - - private static void getSizeMultiDimensionalArray(JsonArray jsonArray, List dimensions) { - Iterator iterator = jsonArray.iterator(); - - if (iterator.hasNext()) { - JsonElement jsonElement = iterator.next(); - if (jsonElement.isJsonArray()) { - JsonArray shapeArray = jsonElement.getAsJsonArray(); - dimensions.add(shapeArray.size()); - getSizeMultiDimensionalArray(shapeArray, dimensions); - } - } - } - - private static boolean isArrayWithSingleRow(List dimensions) { - return dimensions.size() == 1; - } - - private static INDArray buildArray(List dimensions, String rawArray) { - long[] shape = Longs.toArray(dimensions); - String[] entries = StringUtils.replacePattern(rawArray, "[\\[\\]\\n]", "").split(","); - double[] entryValues = new double[entries.length]; - - for (int i = 0; i < entries.length; i++) { - entryValues[i] = Double.parseDouble(entries[i]); - } - - return Nd4j.create(entryValues, shape, Nd4j.defaultFloatingPointType()); - } -} diff --git a/nd4j/nd4j-serde/nd4j-gson/src/test/java/org/nd4j/serde/gson/GsonDeserializationUtilsTest.java b/nd4j/nd4j-serde/nd4j-gson/src/test/java/org/nd4j/serde/gson/GsonDeserializationUtilsTest.java index 0695a6cb8..a64641fe3 100644 --- a/nd4j/nd4j-serde/nd4j-gson/src/test/java/org/nd4j/serde/gson/GsonDeserializationUtilsTest.java +++ b/nd4j/nd4j-serde/nd4j-gson/src/test/java/org/nd4j/serde/gson/GsonDeserializationUtilsTest.java @@ -17,13 +17,14 @@ package org.nd4j.serde.gson; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import static org.junit.Assert.assertEquals; -public class GsonDeserializationUtilsTest { +public class GsonDeserializationUtilsTest extends BaseND4JTest { @Test public void deserializeRawJson_PassInInRank3Array_ExpectCorrectDeserialization() { String serializedRawArray = "[[[1.00, 11.00, 3.00],\n" + "[13.00, 5.00, 15.00],\n" + "[7.00, 17.00, 9.00]]]"; diff --git a/nd4j/nd4j-serde/nd4j-jackson/pom.xml b/nd4j/nd4j-serde/nd4j-jackson/pom.xml deleted file mode 100644 index f1ef0aa79..000000000 --- a/nd4j/nd4j-serde/nd4j-jackson/pom.xml +++ /dev/null @@ -1,56 +0,0 @@ - - - - - - nd4j-serde - org.nd4j - 1.0.0-SNAPSHOT - - 4.0.0 - - nd4j-jackson - - - com.fasterxml.jackson.core - jackson-databind - ${jackson.version} - - provided - - - org.nd4j - nd4j-api - ${project.version} - provided - - - - - org.nd4j - jackson - ${project.version} - jar - - - - - - testresources - - - diff --git a/nd4j/nd4j-serde/nd4j-jackson/src/main/java/org/nd4j/shade/serde/jackson/VectorDeSerializer.java b/nd4j/nd4j-serde/nd4j-jackson/src/main/java/org/nd4j/shade/serde/jackson/VectorDeSerializer.java deleted file mode 100644 index db4125d16..000000000 --- a/nd4j/nd4j-serde/nd4j-jackson/src/main/java/org/nd4j/shade/serde/jackson/VectorDeSerializer.java +++ /dev/null @@ -1,59 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.shade.serde.jackson; - -import org.nd4j.linalg.api.buffer.DataBuffer; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.shade.jackson.core.JsonParser; -import org.nd4j.shade.jackson.databind.DeserializationContext; -import org.nd4j.shade.jackson.databind.JsonDeserializer; -import org.nd4j.shade.jackson.databind.JsonNode; - -import java.io.IOException; - -/** - * @author Adam Gibson - * @deprecated Use {@link org.nd4j.serde.jackson.VectorDeSerializer} - */ -public class VectorDeSerializer extends JsonDeserializer { - @Override - public INDArray deserialize(JsonParser jp, DeserializationContext deserializationContext) throws IOException { - JsonNode node = jp.getCodec().readTree(jp); - JsonNode arr = node.get("dataBuffer"); - int rank = node.get("rankField").asInt(); - int numElements = node.get("numElements").asInt(); - int offset = node.get("offsetField").asInt(); - JsonNode shape = node.get("shapeField"); - JsonNode stride = node.get("strideField"); - int[] realShape = new int[rank]; - int[] realStride = new int[rank]; - DataBuffer buff = Nd4j.createBuffer(numElements); - for (int i = 0; i < numElements; i++) { - buff.put(i, arr.get(i).asDouble()); - } - - String ordering = node.get("orderingField").asText(); - for (int i = 0; i < rank; i++) { - realShape[i] = shape.get(i).asInt(); - realStride[i] = stride.get(i).asInt(); - } - - INDArray ret = Nd4j.create(buff, realShape, realStride, offset, ordering.charAt(0)); - return ret; - } -} diff --git a/nd4j/nd4j-serde/nd4j-jackson/src/main/java/org/nd4j/shade/serde/jackson/VectorSerializer.java b/nd4j/nd4j-serde/nd4j-jackson/src/main/java/org/nd4j/shade/serde/jackson/VectorSerializer.java deleted file mode 100644 index 651e1980c..000000000 --- a/nd4j/nd4j-serde/nd4j-jackson/src/main/java/org/nd4j/shade/serde/jackson/VectorSerializer.java +++ /dev/null @@ -1,64 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.shade.serde.jackson; - - -import org.nd4j.linalg.api.buffer.DataBuffer; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.shade.jackson.core.JsonGenerator; -import org.nd4j.shade.jackson.databind.JsonSerializer; -import org.nd4j.shade.jackson.databind.SerializerProvider; - -import java.io.IOException; - -/** - * @author Adam Gibson - * @deprecated Use {@link org.nd4j.serde.jackson.VectorSerializer} - */ -public class VectorSerializer extends JsonSerializer { - @Override - public void serialize(INDArray indArray, JsonGenerator jsonGenerator, SerializerProvider serializerProvider) - throws IOException { - if (indArray.isView()) - indArray = indArray.dup(indArray.ordering()); - jsonGenerator.writeStartObject(); - DataBuffer view = indArray.data(); - jsonGenerator.writeArrayFieldStart("dataBuffer"); - for (int i = 0; i < view.length(); i++) { - jsonGenerator.writeNumber(view.getDouble(i)); - } - - jsonGenerator.writeEndArray(); - - jsonGenerator.writeArrayFieldStart("shapeField"); - for (int i = 0; i < indArray.rank(); i++) { - jsonGenerator.writeNumber(indArray.size(i)); - } - jsonGenerator.writeEndArray(); - - jsonGenerator.writeArrayFieldStart("strideField"); - for (int i = 0; i < indArray.rank(); i++) - jsonGenerator.writeNumber(indArray.stride(i)); - jsonGenerator.writeEndArray(); - - jsonGenerator.writeNumberField("offsetField", indArray.offset()); - jsonGenerator.writeNumberField("rankField", indArray.rank()); - jsonGenerator.writeNumberField("numElements", view.length()); - jsonGenerator.writeStringField("orderingField", String.valueOf(indArray.ordering())); - jsonGenerator.writeEndObject(); - } -} diff --git a/nd4j/nd4j-serde/nd4j-jackson/src/main/java/org/nd4j/shade/serde/jackson/ndarray/NDArrayDeSerializer.java b/nd4j/nd4j-serde/nd4j-jackson/src/main/java/org/nd4j/shade/serde/jackson/ndarray/NDArrayDeSerializer.java deleted file mode 100644 index 95560fbb6..000000000 --- a/nd4j/nd4j-serde/nd4j-jackson/src/main/java/org/nd4j/shade/serde/jackson/ndarray/NDArrayDeSerializer.java +++ /dev/null @@ -1,40 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.shade.serde.jackson.ndarray; - -import com.fasterxml.jackson.core.JsonParser; -import com.fasterxml.jackson.databind.DeserializationContext; -import com.fasterxml.jackson.databind.JsonDeserializer; -import com.fasterxml.jackson.databind.JsonNode; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.serde.base64.Nd4jBase64; - -import java.io.IOException; - -/** - * @author Adam Gibson - * @deprecated Use {@link org.nd4j.serde.jackson.ndarray.NDArrayDeSerializer} - */ -public class NDArrayDeSerializer extends JsonDeserializer { - @Override - public INDArray deserialize(JsonParser jp, DeserializationContext deserializationContext) throws IOException { - JsonNode node = jp.getCodec().readTree(jp); - String field = node.get("array").asText(); - INDArray ret = Nd4jBase64.fromBase64(field); - return ret; - } -} diff --git a/nd4j/nd4j-serde/nd4j-jackson/src/main/java/org/nd4j/shade/serde/jackson/ndarray/NDArraySerializer.java b/nd4j/nd4j-serde/nd4j-jackson/src/main/java/org/nd4j/shade/serde/jackson/ndarray/NDArraySerializer.java deleted file mode 100644 index abb7fc209..000000000 --- a/nd4j/nd4j-serde/nd4j-jackson/src/main/java/org/nd4j/shade/serde/jackson/ndarray/NDArraySerializer.java +++ /dev/null @@ -1,44 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.shade.serde.jackson.ndarray; - - -import com.fasterxml.jackson.core.JsonGenerator; -import com.fasterxml.jackson.databind.JsonSerializer; -import com.fasterxml.jackson.databind.SerializerProvider; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.serde.base64.Nd4jBase64; - -import java.io.ByteArrayOutputStream; -import java.io.IOException; - -/** - * @author Adam Gibson - * @deprecated Use {@link org.nd4j.serde.jackson.ndarray.NDArraySerializer} - */ -public class NDArraySerializer extends JsonSerializer { - @Override - public void serialize(INDArray indArray, JsonGenerator jsonGenerator, SerializerProvider serializerProvider) - throws IOException { - String toBase64 = Nd4jBase64.base64String(indArray); - jsonGenerator.writeStartObject(); - jsonGenerator.writeStringField("array", toBase64); - jsonGenerator.writeEndObject(); - - } -} diff --git a/nd4j/nd4j-serde/nd4j-jackson/src/main/java/org/nd4j/shade/serde/jackson/shaded/NDArrayDeSerializer.java b/nd4j/nd4j-serde/nd4j-jackson/src/main/java/org/nd4j/shade/serde/jackson/shaded/NDArrayDeSerializer.java deleted file mode 100644 index e47a241dd..000000000 --- a/nd4j/nd4j-serde/nd4j-jackson/src/main/java/org/nd4j/shade/serde/jackson/shaded/NDArrayDeSerializer.java +++ /dev/null @@ -1,42 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.shade.serde.jackson.shaded; - -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.serde.base64.Nd4jBase64; -import org.nd4j.shade.jackson.core.JsonParser; -import org.nd4j.shade.jackson.databind.DeserializationContext; -import org.nd4j.shade.jackson.databind.JsonDeserializer; -import org.nd4j.shade.jackson.databind.JsonNode; - -import java.io.IOException; - -/** - * @author Adam Gibson - * @deprecated Use {@link org.nd4j.serde.jackson.shaded.NDArrayDeSerializer} - */ -@Deprecated -public class NDArrayDeSerializer extends JsonDeserializer { - @Override - public INDArray deserialize(JsonParser jp, DeserializationContext deserializationContext) throws IOException { - JsonNode node = jp.getCodec().readTree(jp); - String field = node.get("array").asText(); - INDArray ret = Nd4jBase64.fromBase64(field.toString()); - return ret; - - } -} diff --git a/nd4j/nd4j-serde/nd4j-jackson/src/main/java/org/nd4j/shade/serde/jackson/shaded/NDArraySerializer.java b/nd4j/nd4j-serde/nd4j-jackson/src/main/java/org/nd4j/shade/serde/jackson/shaded/NDArraySerializer.java deleted file mode 100644 index 6667e6685..000000000 --- a/nd4j/nd4j-serde/nd4j-jackson/src/main/java/org/nd4j/shade/serde/jackson/shaded/NDArraySerializer.java +++ /dev/null @@ -1,44 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.shade.serde.jackson.shaded; - - -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.serde.base64.Nd4jBase64; -import org.nd4j.shade.jackson.core.JsonGenerator; -import org.nd4j.shade.jackson.databind.JsonSerializer; -import org.nd4j.shade.jackson.databind.SerializerProvider; - -import java.io.ByteArrayOutputStream; -import java.io.IOException; - -/** - * @author Adam Gibson - * @deprecated Use {@link org.nd4j.serde.jackson.shaded.NDArraySerializer} - */ -@Deprecated -public class NDArraySerializer extends JsonSerializer { - @Override - public void serialize(INDArray indArray, JsonGenerator jsonGenerator, SerializerProvider serializerProvider) - throws IOException { - String toBase64 = Nd4jBase64.base64String(indArray); - jsonGenerator.writeStartObject(); - jsonGenerator.writeStringField("array", toBase64); - jsonGenerator.writeEndObject(); - } -} diff --git a/nd4j/nd4j-serde/nd4j-kryo/pom.xml b/nd4j/nd4j-serde/nd4j-kryo/pom.xml index 850413b1d..25acac26f 100644 --- a/nd4j/nd4j-serde/nd4j-kryo/pom.xml +++ b/nd4j/nd4j-serde/nd4j-kryo/pom.xml @@ -113,20 +113,118 @@ junit test + + + + org.nd4j + nd4j-common-tests + ${project.version} + test + testresources - + + + + nd4j-tests-cpu + + false + org.nd4j nd4j-native ${project.version} - test + + + + org.apache.maven.plugins + maven-surefire-plugin + + + ${env.LD_LIBRARY_PATH}:${user.dir}:${libnd4jhome}/blasbuild/cpu/blas/ + + src/test/java + + *.java + **/*.java + **/Test*.java + **/*Test.java + **/*TestCase.java + + junit:junit + + org.nd4j.linalg.cpu.nativecpu.CpuBackend + org.nd4j.linalg.cpu.nativecpu.CpuBackend + + + -Ddtype=float -Xmx8g + + + + + + + + nd4j-tests-cuda + + false + + + + org.nd4j + nd4j-cuda-10.2 + ${project.version} + + + + + + org.apache.maven.plugins + maven-surefire-plugin + + + org.apache.maven.surefire + surefire-junit47 + 2.19.1 + + + + + ${env.LD_LIBRARY_PATH}:${user.dir}:${libnd4jhome}/blasbuild/cuda/blas/ + + src/test/java + + *.java + **/*.java + **/Test*.java + **/*Test.java + **/*TestCase.java + + junit:junit + + org.nd4j.linalg.jcublas.JCublasBackend + org.nd4j.linalg.jcublas.JCublasBackend + + + -Ddtype=float -Xmx6g + + + + diff --git a/nd4j/nd4j-serde/nd4j-kryo/src/test/java/org/nd4j/TestNd4jKryoSerialization.java b/nd4j/nd4j-serde/nd4j-kryo/src/test/java/org/nd4j/TestNd4jKryoSerialization.java index 1d51060ae..44db9d95f 100644 --- a/nd4j/nd4j-serde/nd4j-kryo/src/test/java/org/nd4j/TestNd4jKryoSerialization.java +++ b/nd4j/nd4j-serde/nd4j-kryo/src/test/java/org/nd4j/TestNd4jKryoSerialization.java @@ -42,7 +42,7 @@ import static org.junit.Assert.assertTrue; /** * Created by Alex on 04/07/2016. */ -public class TestNd4jKryoSerialization { +public class TestNd4jKryoSerialization extends BaseND4JTest { private JavaSparkContext sc; diff --git a/nd4j/nd4j-serde/pom.xml b/nd4j/nd4j-serde/pom.xml index d4fc4ff05..aa6f9bed1 100644 --- a/nd4j/nd4j-serde/pom.xml +++ b/nd4j/nd4j-serde/pom.xml @@ -28,10 +28,7 @@ pom nd4j-aeron - nd4j-jackson nd4j-kryo - nd4j-camel-routes - nd4j-gson nd4j-arrow diff --git a/nd4j/nd4j-uberjar/pom.xml b/nd4j/nd4j-uberjar/pom.xml index c3398dea9..84f1c0d4a 100644 --- a/nd4j/nd4j-uberjar/pom.xml +++ b/nd4j/nd4j-uberjar/pom.xml @@ -205,16 +205,6 @@ nd4j-common ${project.version} - - org.nd4j - nd4j-buffer - ${project.version} - - - org.nd4j - nd4j-context - ${project.version} - org.nd4j nd4j-api diff --git a/nd4j/pom.xml b/nd4j/pom.xml index f043d7299..d8e4a58f2 100644 --- a/nd4j/pom.xml +++ b/nd4j/pom.xml @@ -56,13 +56,12 @@ nd4j-jdbc nd4j-serde nd4j-common - nd4j-buffer - nd4j-context nd4j-backends nd4j-parameter-server-parent nd4j-uberjar nd4j-tensorflow nd4j-remote + nd4j-common-tests diff --git a/nd4s/src/main/scala/org/nd4s/ops/FunctionalOpExecutioner.scala b/nd4s/src/main/scala/org/nd4s/ops/FunctionalOpExecutioner.scala index 826264b8f..34136f61a 100644 --- a/nd4s/src/main/scala/org/nd4s/ops/FunctionalOpExecutioner.scala +++ b/nd4s/src/main/scala/org/nd4s/ops/FunctionalOpExecutioner.scala @@ -18,7 +18,7 @@ package org.nd4s.ops import java.util.{ List, Map, Properties } import org.bytedeco.javacpp.Pointer -import org.nd4j.linalg.api.buffer.{ DataBuffer, DataType, Utf8Buffer } +import org.nd4j.linalg.api.buffer.{ DataBuffer, DataType } import org.nd4j.linalg.api.environment.Nd4jEnvironment import org.nd4j.linalg.api.ndarray.{ INDArray, INDArrayStatistics } import org.nd4j.linalg.api.ops.aggregates.{ Aggregate, Batch } @@ -452,7 +452,7 @@ class FunctionalOpExecutioner extends OpExecutioner { * @param index * @return */ - def getString(buffer: Utf8Buffer, index: Long): String = ??? + def getString(buffer: DataBuffer, index: Long): String = ??? /** * This method returns OpContext which can be used (and reused) to execute custom ops diff --git a/nd4s/src/test/scala/org/nd4s/samediff/ConstructionTest.scala b/nd4s/src/test/scala/org/nd4s/samediff/ConstructionTest.scala index 25e8f374f..d1d760286 100644 --- a/nd4s/src/test/scala/org/nd4s/samediff/ConstructionTest.scala +++ b/nd4s/src/test/scala/org/nd4s/samediff/ConstructionTest.scala @@ -122,15 +122,15 @@ class ConstructionTest extends FlatSpec with Matchers { val learning_rate = 0.1 val seed = 7 - val target = Nd4j.createUninitialized(1000) + val target = Nd4j.createUninitialized(DataType.DOUBLE, 1000) val rng = Nd4j.getRandom rng.setSeed(seed) val x1_label1 = Nd4j.randn(3.0, 1.0, target, rng) - val target1 = Nd4j.createUninitialized(1000) + val target1 = Nd4j.createUninitialized(DataType.DOUBLE, 1000) val x2_label1 = Nd4j.randn(2.0, 1.0, target1, rng) - val target2 = Nd4j.createUninitialized(1000) + val target2 = Nd4j.createUninitialized(DataType.DOUBLE, 1000) val x1_label2 = Nd4j.randn(7.0, 1.0, target2, rng) - val target3 = Nd4j.createUninitialized(1000) + val target3 = Nd4j.createUninitialized(DataType.DOUBLE, 1000) val x2_label2 = Nd4j.randn(6.0, 1.0, target3, rng) // np.append, was not able to guess proper method diff --git a/pom.xml b/pom.xml index 3d7082524..8fd933bb3 100644 --- a/pom.xml +++ b/pom.xml @@ -854,5 +854,29 @@ arm + + + + integration-tests + + false + + + + + maven-surefire-plugin + ${maven-surefire-plugin.version} + true + + + true + + + + + +