commit e4ddf109c3
@@ -1,47 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.deeplearning4j.arbiter.util;
-
-import org.slf4j.Logger;
-
-import java.awt.*;
-import java.net.URI;
-
-/**
- * Various utilities for webpages and dealing with browsers
- */
-public class WebUtils {
-
-    public static void tryOpenBrowser(String path, Logger log) {
-        try {
-            WebUtils.openBrowser(new URI(path));
-        } catch (Exception e) {
-            log.error("Could not open browser", e);
-            System.out.println("Browser could not be launched automatically.\nUI path: " + path);
-        }
-    }
-
-    public static void openBrowser(URI uri) throws Exception {
-        if (Desktop.isDesktopSupported()) {
-            Desktop.getDesktop().browse(uri);
-        } else {
-            throw new UnsupportedOperationException(
-                    "Cannot open browser on this platform: Desktop.isDesktopSupported() == false");
-        }
-    }
-
-}
@@ -127,7 +127,7 @@ public class BraninFunction {
             BraninConfig candidate = (BraninConfig) c.getValue();

             double score = scoreFunction.score(candidate, null, (Map) null);
-            System.out.println(candidate.getX1() + "\t" + candidate.getX2() + "\t" + score);
+//            System.out.println(candidate.getX1() + "\t" + candidate.getX2() + "\t" + score);

             Thread.sleep(20);

@@ -54,7 +54,7 @@ public class TestRandomSearch extends BaseDL4JTest {
         runner.execute();


-        System.out.println("----- Complete -----");
+//        System.out.println("----- Complete -----");
     }


@@ -16,8 +16,8 @@

 package org.deeplearning4j.arbiter.optimize.genetic;

+import org.apache.commons.lang3.NotImplementedException;
 import org.apache.commons.math3.random.RandomGenerator;
-import sun.reflect.generics.reflectiveObjects.NotImplementedException;

 public class TestRandomGenerator implements RandomGenerator {
     private final int[] intRandomNumbers;
@@ -63,17 +63,17 @@ public class TestRandomGenerator implements RandomGenerator {

     @Override
     public long nextLong() {
-        throw new NotImplementedException();
+        throw new NotImplementedException("Not implemented");
     }

     @Override
     public boolean nextBoolean() {
-        throw new NotImplementedException();
+        throw new NotImplementedException("Not implemented");
     }

     @Override
     public float nextFloat() {
-        throw new NotImplementedException();
+        throw new NotImplementedException("Not implemented");
     }

     @Override
@@ -83,6 +83,6 @@ public class TestRandomGenerator implements RandomGenerator {

     @Override
     public double nextGaussian() {
-        throw new NotImplementedException();
+        throw new NotImplementedException("Not implemented");
     }
 }
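Note on the hunks above and the similar test-helper hunks below: `sun.reflect.generics.reflectiveObjects.NotImplementedException` is a JDK-internal class that is not part of any supported API and may not be accessible on newer JDKs, so these stubs switch to `org.apache.commons.lang3.NotImplementedException`, whose constructor takes a message. A minimal, self-contained sketch of the resulting stub pattern (class and method names here are illustrative, not taken from the diff):

    import org.apache.commons.lang3.NotImplementedException;

    // Illustrative only: commons-lang3's NotImplementedException requires a message,
    // unlike the no-arg JDK-internal exception it replaces in this commit.
    public class StubExample {
        public double[] buildNextGenes() {
            throw new NotImplementedException("Not implemented");
        }
    }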
@@ -16,6 +16,7 @@

 package org.deeplearning4j.arbiter.optimize.genetic.crossover;

+import org.apache.commons.lang3.NotImplementedException;
 import org.deeplearning4j.BaseDL4JTest;
 import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverResult;
 import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.TwoParentsCrossoverOperator;
@@ -26,7 +27,6 @@ import org.deeplearning4j.arbiter.optimize.genetic.TestParentSelection;
 import org.deeplearning4j.arbiter.optimize.genetic.TestPopulationInitializer;
 import org.junit.Assert;
 import org.junit.Test;
-import sun.reflect.generics.reflectiveObjects.NotImplementedException;

 public class TwoParentsCrossoverOperatorTests extends BaseDL4JTest {

@@ -42,7 +42,7 @@ public class TwoParentsCrossoverOperatorTests extends BaseDL4JTest {

         @Override
         public CrossoverResult crossover() {
-            throw new NotImplementedException();
+            throw new NotImplementedException("Not implemented");
         }
     }

@@ -16,6 +16,7 @@

 package org.deeplearning4j.arbiter.optimize.genetic.culling;

+import org.apache.commons.lang3.NotImplementedException;
 import org.deeplearning4j.BaseDL4JTest;
 import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome;
 import org.deeplearning4j.arbiter.optimize.generator.genetic.culling.RatioCullOperator;
@@ -24,7 +25,6 @@ import org.deeplearning4j.arbiter.optimize.generator.genetic.population.Populati
 import org.deeplearning4j.arbiter.optimize.genetic.TestPopulationInitializer;
 import org.junit.Assert;
 import org.junit.Test;
-import sun.reflect.generics.reflectiveObjects.NotImplementedException;

 import java.util.List;

@@ -46,7 +46,7 @@ public class RatioCullOperatorTests extends BaseDL4JTest {

         @Override
         public void cullPopulation() {
-            throw new NotImplementedException();
+            throw new NotImplementedException("Not implemented");
         }

         public double getCullRatio() {
@@ -16,6 +16,7 @@

 package org.deeplearning4j.arbiter.optimize.genetic.selection;

+import org.apache.commons.lang3.NotImplementedException;
 import org.apache.commons.math3.random.RandomGenerator;
 import org.deeplearning4j.BaseDL4JTest;
 import org.deeplearning4j.arbiter.optimize.generator.genetic.ChromosomeFactory;
@@ -33,7 +34,6 @@ import org.deeplearning4j.arbiter.optimize.genetic.TestPopulationInitializer;
 import org.deeplearning4j.arbiter.optimize.genetic.TestRandomGenerator;
 import org.junit.Assert;
 import org.junit.Test;
-import sun.reflect.generics.reflectiveObjects.NotImplementedException;

 import static org.junit.Assert.assertArrayEquals;

@@ -55,7 +55,7 @@ public class GeneticSelectionOperatorTests extends BaseDL4JTest {

         @Override
         public void cullPopulation() {
-            throw new NotImplementedException();
+            throw new NotImplementedException("Not implemented");
         }

         @Override
@@ -16,6 +16,7 @@

 package org.deeplearning4j.arbiter.optimize.genetic.selection;

+import org.apache.commons.lang3.NotImplementedException;
 import org.deeplearning4j.BaseDL4JTest;
 import org.deeplearning4j.arbiter.optimize.generator.genetic.ChromosomeFactory;
 import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationInitializer;
@@ -24,7 +25,6 @@ import org.deeplearning4j.arbiter.optimize.generator.genetic.selection.Selection
 import org.deeplearning4j.arbiter.optimize.genetic.TestPopulationInitializer;
 import org.junit.Assert;
 import org.junit.Test;
-import sun.reflect.generics.reflectiveObjects.NotImplementedException;

 public class SelectionOperatorTests extends BaseDL4JTest {
     private class TestSelectionOperator extends SelectionOperator {

@@ -39,7 +39,7 @@ public class SelectionOperatorTests extends BaseDL4JTest {

         @Override
         public double[] buildNextGenes() {
-            throw new NotImplementedException();
+            throw new NotImplementedException("Not implemented");
         }
     }

@@ -158,7 +158,7 @@ public class TestComputationGraphSpace extends BaseDL4JTest {
             }
         }

-        System.out.println("ReLU vs. Tanh: " + reluCount + "\t" + tanhCount);
+//        System.out.println("ReLU vs. Tanh: " + reluCount + "\t" + tanhCount);
         assertTrue(reluCount > 0);
         assertTrue(tanhCount > 0);

@@ -162,7 +162,7 @@ public class TestGraphLocalExecution extends BaseDL4JTest {
         List<ResultReference> results = runner.getResults();
         assertTrue(results.size() > 0);

-        System.out.println("----- COMPLETE - " + results.size() + " results -----");
+//        System.out.println("----- COMPLETE - " + results.size() + " results -----");
     }
 }

@@ -165,7 +165,7 @@ public class TestGraphLocalExecutionGenetic extends BaseDL4JTest {
         List<ResultReference> results = runner.getResults();
         assertTrue(results.size() > 0);

-        System.out.println("----- COMPLETE - " + results.size() + " results -----");
+//        System.out.println("----- COMPLETE - " + results.size() + " results -----");
     }
 }

@@ -101,7 +101,7 @@ public class TestLayerSpace extends BaseDL4JTest {
             double l2 = TestUtils.getL2(l);
             IActivation activation = l.getActivationFn();

-            System.out.println(lr + "\t" + l2 + "\t" + activation);
+//            System.out.println(lr + "\t" + l2 + "\t" + activation);

             assertTrue(lr >= 0.3 && lr <= 0.4);
             assertTrue(l2 >= 0.01 && l2 <= 0.1);
@@ -190,7 +190,7 @@ public class TestLayerSpace extends BaseDL4JTest {
             ActivationLayer al = als.getValue(d);
             IActivation activation = al.getActivationFn();

-            System.out.println(activation);
+//            System.out.println(activation);

             assertTrue(containsActivationFunction(actFns, activation));
         }
@@ -228,7 +228,7 @@ public class TestLayerSpace extends BaseDL4JTest {
             IActivation activation = el.getActivationFn();
             long nOut = el.getNOut();

-            System.out.println(activation + "\t" + nOut);
+//            System.out.println(activation + "\t" + nOut);

             assertTrue(containsActivationFunction(actFns, activation));
             assertTrue(nOut >= 10 && nOut <= 20);
@@ -295,7 +295,7 @@ public class TestLayerSpace extends BaseDL4JTest {
             long nOut = el.getNOut();
             double forgetGate = el.getForgetGateBiasInit();

-            System.out.println(activation + "\t" + nOut + "\t" + forgetGate);
+//            System.out.println(activation + "\t" + nOut + "\t" + forgetGate);

             assertTrue(containsActivationFunction(actFns, activation));
             assertTrue(nOut >= 10 && nOut <= 20);
@@ -293,8 +293,8 @@ public class TestMultiLayerSpace extends BaseDL4JTest {
             assertTrue(nLayerCounts[i] >= 5); //Expect approx equal (50/3 each), but some variation randomly
         }

-        System.out.println("Number of layers: " + Arrays.toString(nLayerCounts));
-        System.out.println("ReLU vs. Tanh: " + reluCount + "\t" + tanhCount);
+//        System.out.println("Number of layers: " + Arrays.toString(nLayerCounts));
+//        System.out.println("ReLU vs. Tanh: " + reluCount + "\t" + tanhCount);

     }

@@ -98,7 +98,8 @@ public class ArbiterCLIRunnerTest extends BaseDL4JTest {
         assertEquals(configuration,OptimizationConfiguration.fromJson(configuration.toJson()));

         FileUtils.writeStringToFile(new File(configPath),configuration.toJson());
-        System.out.println(configuration.toJson());
+//        System.out.println(configuration.toJson());
+        configuration.toJson();

         log.info("Starting test");
         cliRunner.runMain(
@@ -0,0 +1,390 @@
+/*
+ * Copyright (c) 2015-2019 Skymind, Inc.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.deeplearning4j.regressiontest;
+
+import org.deeplearning4j.BaseDL4JTest;
+import org.deeplearning4j.TestUtils;
+import org.deeplearning4j.nn.conf.BackpropType;
+import org.deeplearning4j.nn.conf.ConvolutionMode;
+import org.deeplearning4j.nn.conf.graph.LayerVertex;
+import org.deeplearning4j.nn.conf.layers.*;
+import org.deeplearning4j.nn.conf.layers.convolutional.Cropping2D;
+import org.deeplearning4j.nn.conf.layers.recurrent.Bidirectional;
+import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn;
+import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder;
+import org.deeplearning4j.nn.graph.ComputationGraph;
+import org.deeplearning4j.nn.graph.vertex.impl.MergeVertex;
+import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
+import org.deeplearning4j.nn.weights.WeightInitXavier;
+import org.deeplearning4j.regressiontest.customlayer100a.CustomLayer;
+import org.junit.Test;
+import org.nd4j.linalg.activations.impl.*;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.learning.config.Adam;
+import org.nd4j.linalg.learning.config.RmsProp;
+import org.nd4j.linalg.learning.regularization.L2Regularization;
+import org.nd4j.linalg.lossfunctions.impl.LossMAE;
+import org.nd4j.linalg.lossfunctions.impl.LossMCXENT;
+import org.nd4j.resources.Resources;
+
+import java.io.DataInputStream;
+import java.io.File;
+import java.io.FileInputStream;
+
+import static org.junit.Assert.*;
+
+public class RegressionTest100b6 extends BaseDL4JTest {
+
+    @Override
+    public DataType getDataType() {
+        return DataType.FLOAT;
+    }
+
+    @Test
+    public void testCustomLayer() throws Exception {
+
+        for (DataType dtype : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) {
+
+            String dtypeName = dtype.toString().toLowerCase();
+
+            File f = Resources.asFile("regression_testing/100b6/CustomLayerExample_100b6_" + dtypeName + ".bin");
+            MultiLayerNetwork.load(f, true);
+
+            MultiLayerNetwork net = MultiLayerNetwork.load(f, true);
+//            net = net.clone();
+
+            DenseLayer l0 = (DenseLayer) net.getLayer(0).conf().getLayer();
+            assertEquals(new ActivationTanH(), l0.getActivationFn());
+            assertEquals(new L2Regularization(0.03), TestUtils.getL2Reg(l0));
+            assertEquals(new RmsProp(0.95), l0.getIUpdater());
+
+            CustomLayer l1 = (CustomLayer) net.getLayer(1).conf().getLayer();
+            assertEquals(new ActivationTanH(), l1.getActivationFn());
+            assertEquals(new ActivationSigmoid(), l1.getSecondActivationFunction());
+            assertEquals(new RmsProp(0.95), l1.getIUpdater());
+
+            INDArray outExp;
+            File f2 = Resources
+                    .asFile("regression_testing/100b6/CustomLayerExample_Output_100b6_" + dtypeName + ".bin");
+            try (DataInputStream dis = new DataInputStream(new FileInputStream(f2))) {
+                outExp = Nd4j.read(dis);
+            }
+
+            INDArray in;
+            File f3 = Resources.asFile("regression_testing/100b6/CustomLayerExample_Input_100b6_" + dtypeName + ".bin");
+            try (DataInputStream dis = new DataInputStream(new FileInputStream(f3))) {
+                in = Nd4j.read(dis);
+            }
+
+            assertEquals(dtype, in.dataType());
+            assertEquals(dtype, outExp.dataType());
+            assertEquals(dtype, net.params().dataType());
+            assertEquals(dtype, net.getFlattenedGradients().dataType());
+            assertEquals(dtype, net.getUpdater().getStateViewArray().dataType());
+
+            //System.out.println(Arrays.toString(net.params().data().asFloat()));
+
+            INDArray outAct = net.output(in);
+            assertEquals(dtype, outAct.dataType());
+
+            assertEquals(dtype, net.getLayerWiseConfigurations().getDataType());
+            assertEquals(dtype, net.params().dataType());
+            boolean eq = outExp.equalsWithEps(outAct, 0.01);
+            assertTrue(outExp + " vs " + outAct, eq); }
+    }
+
+
+    @Test
+    public void testLSTM() throws Exception {
+
+        File f = Resources.asFile("regression_testing/100b6/GravesLSTMCharModelingExample_100b6.bin");
+        MultiLayerNetwork net = MultiLayerNetwork.load(f, true);
+
+        LSTM l0 = (LSTM) net.getLayer(0).conf().getLayer();
+        assertEquals(new ActivationTanH(), l0.getActivationFn());
+        assertEquals(200, l0.getNOut());
+        assertEquals(new WeightInitXavier(), l0.getWeightInitFn());
+        assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l0));
+        assertEquals(new Adam(0.005), l0.getIUpdater());
+
+        LSTM l1 = (LSTM) net.getLayer(1).conf().getLayer();
+        assertEquals(new ActivationTanH(), l1.getActivationFn());
+        assertEquals(200, l1.getNOut());
+        assertEquals(new WeightInitXavier(), l1.getWeightInitFn());
+        assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l1));
+        assertEquals(new Adam(0.005), l1.getIUpdater());
+
+        RnnOutputLayer l2 = (RnnOutputLayer) net.getLayer(2).conf().getLayer();
+        assertEquals(new ActivationSoftmax(), l2.getActivationFn());
+        assertEquals(77, l2.getNOut());
+        assertEquals(new WeightInitXavier(), l2.getWeightInitFn());
+        assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l2));
+        assertEquals(new Adam(0.005), l2.getIUpdater());
+
+        assertEquals(BackpropType.TruncatedBPTT, net.getLayerWiseConfigurations().getBackpropType());
+        assertEquals(50, net.getLayerWiseConfigurations().getTbpttBackLength());
+        assertEquals(50, net.getLayerWiseConfigurations().getTbpttFwdLength());
+
+        INDArray outExp;
+        File f2 = Resources.asFile("regression_testing/100b6/GravesLSTMCharModelingExample_Output_100b6.bin");
+        try (DataInputStream dis = new DataInputStream(new FileInputStream(f2))) {
+            outExp = Nd4j.read(dis);
+        }
+
+        INDArray in;
+        File f3 = Resources.asFile("regression_testing/100b6/GravesLSTMCharModelingExample_Input_100b6.bin");
+        try (DataInputStream dis = new DataInputStream(new FileInputStream(f3))) {
+            in = Nd4j.read(dis);
+        }
+
+        INDArray outAct = net.output(in);
+
+        assertEquals(outExp, outAct);
+    }
+
+    @Test
+    public void testVae() throws Exception {
+
+        File f = Resources.asFile("regression_testing/100b6/VaeMNISTAnomaly_100b6.bin");
+        MultiLayerNetwork net = MultiLayerNetwork.load(f, true);
+
+        VariationalAutoencoder l0 = (VariationalAutoencoder) net.getLayer(0).conf().getLayer();
+        assertEquals(new ActivationLReLU(), l0.getActivationFn());
+        assertEquals(32, l0.getNOut());
+        assertArrayEquals(new int[]{256, 256}, l0.getEncoderLayerSizes());
+        assertArrayEquals(new int[]{256, 256}, l0.getDecoderLayerSizes());
+        assertEquals(new WeightInitXavier(), l0.getWeightInitFn());
+        assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l0));
+        assertEquals(new Adam(1e-3), l0.getIUpdater());
+
+        INDArray outExp;
+        File f2 = Resources.asFile("regression_testing/100b6/VaeMNISTAnomaly_Output_100b6.bin");
+        try (DataInputStream dis = new DataInputStream(new FileInputStream(f2))) {
+            outExp = Nd4j.read(dis);
+        }
+
+        INDArray in;
+        File f3 = Resources.asFile("regression_testing/100b6/VaeMNISTAnomaly_Input_100b6.bin");
+        try (DataInputStream dis = new DataInputStream(new FileInputStream(f3))) {
+            in = Nd4j.read(dis);
+        }
+
+        INDArray outAct = net.output(in);
+
+        assertEquals(outExp, outAct);
+    }
+
+
+    @Test
+    public void testYoloHouseNumber() throws Exception {
+
+        File f = Resources.asFile("regression_testing/100b6/HouseNumberDetection_100b6.bin");
+        ComputationGraph net = ComputationGraph.load(f, true);
+
+        int nBoxes = 5;
+        int nClasses = 10;
+
+        ConvolutionLayer cl = (ConvolutionLayer) ((LayerVertex) net.getConfiguration().getVertices()
+                .get("convolution2d_9")).getLayerConf().getLayer();
+        assertEquals(nBoxes * (5 + nClasses), cl.getNOut());
+        assertEquals(new ActivationIdentity(), cl.getActivationFn());
+        assertEquals(ConvolutionMode.Same, cl.getConvolutionMode());
+        assertEquals(new WeightInitXavier(), cl.getWeightInitFn());
+        assertArrayEquals(new int[]{1, 1}, cl.getKernelSize());
+
+        INDArray outExp;
+        File f2 = Resources.asFile("regression_testing/100b6/HouseNumberDetection_Output_100b6.bin");
+        try (DataInputStream dis = new DataInputStream(new FileInputStream(f2))) {
+            outExp = Nd4j.read(dis);
+        }
+
+        INDArray in;
+        File f3 = Resources.asFile("regression_testing/100b6/HouseNumberDetection_Input_100b6.bin");
+        try (DataInputStream dis = new DataInputStream(new FileInputStream(f3))) {
+            in = Nd4j.read(dis);
+        }
+
+        INDArray outAct = net.outputSingle(in);
+
+        boolean eq = outExp.equalsWithEps(outAct.castTo(outExp.dataType()), 1e-3);
+        assertTrue(eq);
+    }
+
+    @Test
+    public void testSyntheticCNN() throws Exception {
+
+        File f = Resources.asFile("regression_testing/100b6/SyntheticCNN_100b6.bin");
+        MultiLayerNetwork net = MultiLayerNetwork.load(f, true);
+
+        ConvolutionLayer l0 = (ConvolutionLayer) net.getLayer(0).conf().getLayer();
+        assertEquals(new ActivationReLU(), l0.getActivationFn());
+        assertEquals(4, l0.getNOut());
+        assertEquals(new WeightInitXavier(), l0.getWeightInitFn());
+        assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l0));
+        assertEquals(new Adam(0.005), l0.getIUpdater());
+        assertArrayEquals(new int[]{3, 3}, l0.getKernelSize());
+        assertArrayEquals(new int[]{2, 1}, l0.getStride());
+        assertArrayEquals(new int[]{1, 1}, l0.getDilation());
+        assertArrayEquals(new int[]{0, 0}, l0.getPadding());
+
+        SeparableConvolution2D l1 = (SeparableConvolution2D) net.getLayer(1).conf().getLayer();
+        assertEquals(new ActivationReLU(), l1.getActivationFn());
+        assertEquals(8, l1.getNOut());
+        assertEquals(new WeightInitXavier(), l1.getWeightInitFn());
+        assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l1));
+        assertEquals(new Adam(0.005), l1.getIUpdater());
+        assertArrayEquals(new int[]{3, 3}, l1.getKernelSize());
+        assertArrayEquals(new int[]{1, 1}, l1.getStride());
+        assertArrayEquals(new int[]{1, 1}, l1.getDilation());
+        assertArrayEquals(new int[]{0, 0}, l1.getPadding());
+        assertEquals(ConvolutionMode.Same, l1.getConvolutionMode());
+        assertEquals(1, l1.getDepthMultiplier());
+
+        SubsamplingLayer l2 = (SubsamplingLayer) net.getLayer(2).conf().getLayer();
+        assertArrayEquals(new int[]{3, 3}, l2.getKernelSize());
+        assertArrayEquals(new int[]{2, 2}, l2.getStride());
+        assertArrayEquals(new int[]{1, 1}, l2.getDilation());
+        assertArrayEquals(new int[]{0, 0}, l2.getPadding());
+        assertEquals(PoolingType.MAX, l2.getPoolingType());
+
+        ZeroPaddingLayer l3 = (ZeroPaddingLayer) net.getLayer(3).conf().getLayer();
+        assertArrayEquals(new int[]{4, 4, 4, 4}, l3.getPadding());
+
+        Upsampling2D l4 = (Upsampling2D) net.getLayer(4).conf().getLayer();
+        assertArrayEquals(new int[]{3, 3}, l4.getSize());
+
+        DepthwiseConvolution2D l5 = (DepthwiseConvolution2D) net.getLayer(5).conf().getLayer();
+        assertEquals(new ActivationReLU(), l5.getActivationFn());
+        assertEquals(16, l5.getNOut());
+        assertEquals(new WeightInitXavier(), l5.getWeightInitFn());
+        assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l5));
+        assertEquals(new Adam(0.005), l5.getIUpdater());
+        assertArrayEquals(new int[]{3, 3}, l5.getKernelSize());
+        assertArrayEquals(new int[]{1, 1}, l5.getStride());
+        assertArrayEquals(new int[]{1, 1}, l5.getDilation());
+        assertArrayEquals(new int[]{0, 0}, l5.getPadding());
+        assertEquals(2, l5.getDepthMultiplier());
+
+        SubsamplingLayer l6 = (SubsamplingLayer) net.getLayer(6).conf().getLayer();
+        assertArrayEquals(new int[]{2, 2}, l6.getKernelSize());
+        assertArrayEquals(new int[]{2, 2}, l6.getStride());
+        assertArrayEquals(new int[]{1, 1}, l6.getDilation());
+        assertArrayEquals(new int[]{0, 0}, l6.getPadding());
+        assertEquals(PoolingType.MAX, l6.getPoolingType());
+
+        Cropping2D l7 = (Cropping2D) net.getLayer(7).conf().getLayer();
+        assertArrayEquals(new int[]{3, 3, 2, 2}, l7.getCropping());
+
+        ConvolutionLayer l8 = (ConvolutionLayer) net.getLayer(8).conf().getLayer();
+        assertEquals(4, l8.getNOut());
+        assertEquals(new WeightInitXavier(), l8.getWeightInitFn());
+        assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l8));
+        assertEquals(new Adam(0.005), l8.getIUpdater());
+        assertArrayEquals(new int[]{4, 4}, l8.getKernelSize());
+        assertArrayEquals(new int[]{1, 1}, l8.getStride());
+        assertArrayEquals(new int[]{1, 1}, l8.getDilation());
+        assertArrayEquals(new int[]{0, 0}, l8.getPadding());
+
+        CnnLossLayer l9 = (CnnLossLayer) net.getLayer(9).conf().getLayer();
+        assertEquals(new WeightInitXavier(), l9.getWeightInitFn());
+        assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l9));
+        assertEquals(new Adam(0.005), l9.getIUpdater());
+        assertEquals(new LossMAE(), l9.getLossFn());
+
+        INDArray outExp;
+        File f2 = Resources.asFile("regression_testing/100b6/SyntheticCNN_Output_100b6.bin");
+        try (DataInputStream dis = new DataInputStream(new FileInputStream(f2))) {
+            outExp = Nd4j.read(dis);
+        }
+
+        INDArray in;
+        File f3 = Resources.asFile("regression_testing/100b6/SyntheticCNN_Input_100b6.bin");
+        try (DataInputStream dis = new DataInputStream(new FileInputStream(f3))) {
+            in = Nd4j.read(dis);
+        }
+
+        INDArray outAct = net.output(in);
+
+        //19 layers - CPU vs. GPU difference accumulates notably, but appears to be correct
+        if(Nd4j.getBackend().getClass().getName().toLowerCase().contains("native")){
+            assertEquals(outExp, outAct);
+        } else {
+            boolean eq = outExp.equalsWithEps(outAct, 0.1);
+            assertTrue(eq);
+        }
+    }
+
+    @Test
+    public void testSyntheticBidirectionalRNNGraph() throws Exception {
+
+        File f = Resources.asFile("regression_testing/100b6/SyntheticBidirectionalRNNGraph_100b6.bin");
+        ComputationGraph net = ComputationGraph.load(f, true);
+
+        Bidirectional l0 = (Bidirectional) net.getLayer("rnn1").conf().getLayer();
+
+        LSTM l1 = (LSTM) l0.getFwd();
+        assertEquals(16, l1.getNOut());
+        assertEquals(new ActivationReLU(), l1.getActivationFn());
+        assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l1));
+
+        LSTM l2 = (LSTM) l0.getBwd();
+        assertEquals(16, l2.getNOut());
+        assertEquals(new ActivationReLU(), l2.getActivationFn());
+        assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l2));
+
+        Bidirectional l3 = (Bidirectional) net.getLayer("rnn2").conf().getLayer();
+
+        SimpleRnn l4 = (SimpleRnn) l3.getFwd();
+        assertEquals(16, l4.getNOut());
+        assertEquals(new ActivationReLU(), l4.getActivationFn());
+        assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l4));
+
+        SimpleRnn l5 = (SimpleRnn) l3.getBwd();
+        assertEquals(16, l5.getNOut());
+        assertEquals(new ActivationReLU(), l5.getActivationFn());
+        assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l5));
+
+        MergeVertex mv = (MergeVertex) net.getVertex("concat");
+
+        GlobalPoolingLayer gpl = (GlobalPoolingLayer) net.getLayer("pooling").conf().getLayer();
+        assertEquals(PoolingType.MAX, gpl.getPoolingType());
+        assertArrayEquals(new int[]{2}, gpl.getPoolingDimensions());
+        assertTrue(gpl.isCollapseDimensions());
+
+        OutputLayer outl = (OutputLayer) net.getLayer("out").conf().getLayer();
+        assertEquals(3, outl.getNOut());
+        assertEquals(new LossMCXENT(), outl.getLossFn());
+
+        INDArray outExp;
+        File f2 = Resources.asFile("regression_testing/100b6/SyntheticBidirectionalRNNGraph_Output_100b6.bin");
+        try (DataInputStream dis = new DataInputStream(new FileInputStream(f2))) {
+            outExp = Nd4j.read(dis);
+        }
+
+        INDArray in;
+        File f3 = Resources.asFile("regression_testing/100b6/SyntheticBidirectionalRNNGraph_Input_100b6.bin");
+        try (DataInputStream dis = new DataInputStream(new FileInputStream(f3))) {
+            in = Nd4j.read(dis);
+        }
+
+        INDArray outAct = net.output(in)[0];
+
+        assertEquals(outExp, outAct);
+    }
+}
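Note on the new regression test above: where exact equality cannot be expected across data types and backends, the test compares arrays with `INDArray.equalsWithEps` rather than `equals`. A minimal, self-contained sketch of that comparison pattern (the array values are made up for illustration):

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;

    public class EqualsWithEpsSketch {
        public static void main(String[] args) {
            INDArray expected = Nd4j.create(new double[]{1.0, 2.0, 3.0});
            INDArray actual = expected.add(1e-4);                         // small numerical drift
            boolean closeEnough = expected.equalsWithEps(actual, 1e-3);  // true: within tolerance
            boolean exact = expected.equals(actual);                     // false: exact comparison
            System.out.println(closeEnough + " " + exact);
        }
    }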
@@ -41,7 +41,7 @@ public class TestGraphLoading extends BaseDL4JTest {

         IGraph<String, String> graph = GraphLoader
                 .loadUndirectedGraphEdgeListFile(cpr.getTempFileFromArchive().getAbsolutePath(), 7, ",");
-        System.out.println(graph);
+//        System.out.println(graph);

         assertEquals(graph.numVertices(), 7);
         int[][] edges = {{1, 2}, {0, 2, 4}, {0, 1, 3, 4}, {2, 4, 5}, {1, 2, 3, 5, 6}, {3, 4, 6}, {4, 5}};
@@ -66,7 +66,7 @@ public class TestGraphLoading extends BaseDL4JTest {
                 edgeLineProcessor, vertexFactory, 10, false);


-        System.out.println(graph);
+//        System.out.println(graph);

         for (int i = 0; i < 10; i++) {
             List<Edge<String>> edges = graph.getEdgesOut(i);
@@ -111,7 +111,7 @@ public class TestGraphLoading extends BaseDL4JTest {
         Graph<String, String> graph = GraphLoader.loadGraph(verticesCPR.getTempFileFromArchive().getAbsolutePath(),
                 edgesCPR.getTempFileFromArchive().getAbsolutePath(), vertexLoader, edgeLineProcessor, false);

-        System.out.println(graph);
+//        System.out.println(graph);

         for (int i = 0; i < 10; i++) {
             List<Edge<String>> edges = graph.getEdgesOut(i);
@@ -71,7 +71,7 @@ public class TestGraphLoadingWeighted extends BaseDL4JTest {
             }
         }

-        System.out.println(graph);
+//        System.out.println(graph);
     }


@@ -220,7 +220,7 @@ public class TestGraph extends BaseDL4JTest {
                 sum += transitionProb[i][j];
             for (int j = 0; j < transitionProb[i].length; j++)
                 transitionProb[i][j] /= sum;
-            System.out.println(Arrays.toString(transitionProb[i]));
+//            System.out.println(Arrays.toString(transitionProb[i]));
         }

         //Check that transition probs are essentially correct (within bounds of random variation)
@@ -145,8 +145,8 @@ public class DeepWalkGradientCheck extends BaseDL4JTest {

                 if (relError > MAX_REL_ERROR && absErr > MIN_ABS_ERROR)
                     fail(msg);
-                else
-                    System.out.println(msg);
+//                else
+//                    System.out.println(msg);
             }
         }

@@ -333,10 +333,10 @@ public class DeepWalkGradientCheck extends BaseDL4JTest {

                 if (relError > MAX_REL_ERROR && absErr > MIN_ABS_ERROR)
                     fail(msg);
-                else
-                    System.out.println(msg);
+//                else
+//                    System.out.println(msg);
             }
-            System.out.println();
+//            System.out.println();
         }

     }
@@ -67,7 +67,7 @@ public class TestDeepWalk extends BaseDL4JTest {
         for (int i = 0; i < 7; i++) {
             INDArray vector = deepWalk.getVertexVector(i);
             assertArrayEquals(new long[] {vectorSize}, vector.shape());
-            System.out.println(Arrays.toString(vector.dup().data().asFloat()));
+//            System.out.println(Arrays.toString(vector.dup().data().asFloat()));
         }

         GraphWalkIterator<String> iter = new RandomWalkIterator<>(graph, 8);
@@ -77,11 +77,11 @@ public class TestDeepWalk extends BaseDL4JTest {
         for (int t = 0; t < 5; t++) {
             iter.reset();
             deepWalk.fit(iter);
-            System.out.println("--------------------");
+//            System.out.println("--------------------");
             for (int i = 0; i < 7; i++) {
                 INDArray vector = deepWalk.getVertexVector(i);
                 assertArrayEquals(new long[] {vectorSize}, vector.shape());
-                System.out.println(Arrays.toString(vector.dup().data().asFloat()));
+//                System.out.println(Arrays.toString(vector.dup().data().asFloat()));
             }
         }
     }
@@ -160,7 +160,7 @@ public class TestDeepWalk extends BaseDL4JTest {
                 continue;

             double sim = deepWalk.similarity(i, nearestTo);
-            System.out.println(i + "\t" + nearestTo + "\t" + sim);
+//            System.out.println(i + "\t" + nearestTo + "\t" + sim);
             assertTrue(sim <= minSimNearest);
         }
     }
@@ -211,7 +211,7 @@ public class TestDeepWalk extends BaseDL4JTest {
         Graph<String, String> graph = GraphLoader
                 .loadUndirectedGraphEdgeListFile(cpr.getTempFileFromArchive().getAbsolutePath(), 13, ",");

-        System.out.println(graph);
+//        System.out.println(graph);

         Nd4j.getRandom().setSeed(12345);

@@ -229,11 +229,13 @@ public class TestDeepWalk extends BaseDL4JTest {

         //Calculate similarity(0,i)
         for (int i = 0; i < nVertices; i++) {
-            System.out.println(deepWalk.similarity(0, i));
+//            System.out.println(deepWalk.similarity(0, i));
+            deepWalk.similarity(0, i);
         }

         for (int i = 0; i < nVertices; i++)
-            System.out.println(deepWalk.getVertexVector(i));
+//            System.out.println(deepWalk.getVertexVector(i));
+            deepWalk.getVertexVector(i);
     }

     @Test(timeout = 60000L)
@@ -38,9 +38,11 @@ public class TestGraphHuffman extends BaseDL4JTest {

         gh.buildTree(vertexDegrees);

-        for (int i = 0; i < 7; i++)
-            System.out.println(i + "\t" + gh.getCodeLength(i) + "\t" + gh.getCodeString(i) + "\t\t" + gh.getCode(i)
-                            + "\t\t" + Arrays.toString(gh.getPathInnerNodes(i)));
+        for (int i = 0; i < 7; i++) {
+            String s = i + "\t" + gh.getCodeLength(i) + "\t" + gh.getCodeString(i) + "\t\t" + gh.getCode(i)
+                            + "\t\t" + Arrays.toString(gh.getPathInnerNodes(i));
+//            System.out.println(s);
+        }

         int[] expectedLengths = {3, 2, 2, 5, 4, 2, 5};
         for (int i = 0; i < vertexDegrees.length; i++) {
@@ -3,6 +3,7 @@ package org.deeplearning4j.util;
 import lombok.NonNull;
 import org.apache.commons.io.IOUtils;
 import org.deeplearning4j.nn.api.Model;
+import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
 import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
 import org.deeplearning4j.nn.graph.ComputationGraph;
 import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
@@ -121,7 +122,7 @@ public class DL4JModelValidator {
         }

         try{
-            MultiLayerConfiguration.fromJson(config);
+            ComputationGraphConfiguration.fromJson(config);
         } catch (Throwable t){
             return ValidationResult.builder()
                     .formatType("ComputationGraph")
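Note on the hunk above: the ComputationGraph validation path previously parsed the extracted configuration JSON with `MultiLayerConfiguration.fromJson`, so graph configurations were checked against the wrong schema; it now uses `ComputationGraphConfiguration.fromJson`. A minimal sketch of the corrected check in isolation (the method name and error handling here are placeholders, not the validator's full logic):

    import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;

    public class GraphConfigCheckSketch {
        /** Returns true if the JSON parses as a ComputationGraph configuration. */
        static boolean isValidGraphConfigJson(String config) {
            try {
                ComputationGraphConfiguration.fromJson(config);   // throws on malformed or mismatched JSON
                return true;
            } catch (Throwable t) {
                return false;
            }
        }
    }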
@@ -79,8 +79,9 @@ public class ParameterServerParallelWrapperTest extends BaseDL4JTest {
         model.init();

         ParallelWrapper parameterServerParallelWrapper =
-                new ParallelWrapper.Builder(model).trainerFactory(new ParameterServerTrainerContext())
-                        .workers(Runtime.getRuntime().availableProcessors())
+                new ParallelWrapper.Builder(model)
+                        .workers(Math.min(4, Runtime.getRuntime().availableProcessors()))
+                        .trainerFactory(new ParameterServerTrainerContext())
                         .reportScoreAfterAveraging(true).prefetchBuffer(3).build();
         parameterServerParallelWrapper.fit(mnistTrain);

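Note on the hunk above: the worker count is now capped at four instead of always using every available core, so the test's resource usage no longer scales with the host machine. A minimal sketch of that capped-builder pattern, assuming a DL4J network is passed in (the trainer-factory and data-iterator setup from the test are omitted):

    import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
    import org.deeplearning4j.parallelism.ParallelWrapper;

    public class CappedWrapperSketch {
        // Cap parallel-training workers so behaviour is similar on small CI hosts
        // and on large many-core machines.
        static ParallelWrapper buildWrapper(MultiLayerNetwork model) {
            int workers = Math.min(4, Runtime.getRuntime().availableProcessors());
            return new ParallelWrapper.Builder<>(model)
                    .workers(workers)
                    .prefetchBuffer(3)
                    .reportScoreAfterAveraging(true)
                    .build();
        }
    }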
@@ -104,7 +104,7 @@ public class SparkWord2VecTest extends BaseDL4JTest {
             public void call(ExportContainer<VocabWord> v) throws Exception {
                 assertNotNull(v.getElement());
                 assertNotNull(v.getArray());
-                System.out.println(v.getElement() + " - " + v.getArray());
+//                System.out.println(v.getElement() + " - " + v.getArray());
             }
         }
     }
@@ -66,7 +66,7 @@ public class TestEarlyStoppingSpark extends BaseSparkTest {
                         .lossFunction(LossFunctions.LossFunction.MCXENT).build())
                         .build();
         MultiLayerNetwork net = new MultiLayerNetwork(conf);
-        net.setListeners(new ScoreIterationListener(1));
+        net.setListeners(new ScoreIterationListener(5));


         JavaRDD<DataSet> irisData = getIris();
@@ -119,7 +119,7 @@ public class TestEarlyStoppingSpark extends BaseSparkTest {
                         .lossFunction(LossFunctions.LossFunction.MSE).build())
                         .build();
         MultiLayerNetwork net = new MultiLayerNetwork(conf);
-        net.setListeners(new ScoreIterationListener(1));
+        net.setListeners(new ScoreIterationListener(5));

         JavaRDD<DataSet> irisData = getIris();
         EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();
@@ -155,7 +155,7 @@ public class TestEarlyStoppingSpark extends BaseSparkTest {
                         .lossFunction(LossFunctions.LossFunction.MCXENT).build())
                         .build();
         MultiLayerNetwork net = new MultiLayerNetwork(conf);
-        net.setListeners(new ScoreIterationListener(1));
+        net.setListeners(new ScoreIterationListener(5));

         JavaRDD<DataSet> irisData = getIris();

@@ -198,7 +198,7 @@ public class TestEarlyStoppingSpark extends BaseSparkTest {
                         .lossFunction(LossFunctions.LossFunction.MCXENT).build())
                         .build();
         MultiLayerNetwork net = new MultiLayerNetwork(conf);
-        net.setListeners(new ScoreIterationListener(1));
+        net.setListeners(new ScoreIterationListener(5));

         JavaRDD<DataSet> irisData = getIris();

@@ -231,7 +231,7 @@ public class TestEarlyStoppingSpark extends BaseSparkTest {
                         .lossFunction(LossFunctions.LossFunction.MCXENT).build())
                         .build();
         MultiLayerNetwork net = new MultiLayerNetwork(conf);
-        net.setListeners(new ScoreIterationListener(1));
+        net.setListeners(new ScoreIterationListener(5));


         JavaRDD<DataSet> irisData = getIris();
@@ -69,7 +69,7 @@ public class TestEarlyStoppingSparkCompGraph extends BaseSparkTest {
                         .lossFunction(LossFunctions.LossFunction.MCXENT).build(), "in")
                         .setOutputs("0").build();
         ComputationGraph net = new ComputationGraph(conf);
-        net.setListeners(new ScoreIterationListener(1));
+        net.setListeners(new ScoreIterationListener(5));


         JavaRDD<DataSet> irisData = getIris();
@@ -120,7 +120,7 @@ public class TestEarlyStoppingSparkCompGraph extends BaseSparkTest {
                         .lossFunction(LossFunctions.LossFunction.MSE).build(), "in")
                         .setOutputs("0").build();
         ComputationGraph net = new ComputationGraph(conf);
-        net.setListeners(new ScoreIterationListener(1));
+        net.setListeners(new ScoreIterationListener(5));

         JavaRDD<DataSet> irisData = getIris();
         EarlyStoppingModelSaver<ComputationGraph> saver = new InMemoryModelSaver<>();
@@ -158,7 +158,7 @@ public class TestEarlyStoppingSparkCompGraph extends BaseSparkTest {
                         .lossFunction(LossFunctions.LossFunction.MCXENT).build(), "in")
                         .setOutputs("0").build();
         ComputationGraph net = new ComputationGraph(conf);
-        net.setListeners(new ScoreIterationListener(1));
+        net.setListeners(new ScoreIterationListener(5));

         JavaRDD<DataSet> irisData = getIris();

@@ -203,7 +203,7 @@ public class TestEarlyStoppingSparkCompGraph extends BaseSparkTest {
                         .lossFunction(LossFunctions.LossFunction.MCXENT).build(), "in")
                         .setOutputs("0").build();
         ComputationGraph net = new ComputationGraph(conf);
-        net.setListeners(new ScoreIterationListener(1));
+        net.setListeners(new ScoreIterationListener(5));

         JavaRDD<DataSet> irisData = getIris();

@@ -238,7 +238,7 @@ public class TestEarlyStoppingSparkCompGraph extends BaseSparkTest {
                         .lossFunction(LossFunctions.LossFunction.MCXENT).build(), "in")
                         .setOutputs("0").build();
         ComputationGraph net = new ComputationGraph(conf);
-        net.setListeners(new ScoreIterationListener(1));
+        net.setListeners(new ScoreIterationListener(5));


         JavaRDD<DataSet> irisData = getIris();
@@ -59,7 +59,7 @@ public class TestShuffleExamples extends BaseSparkTest {
         int totalExampleCount = 0;
         for (DataSet ds : shuffledList) {
             totalExampleCount += ds.getFeatures().length();
-            System.out.println(Arrays.toString(ds.getFeatures().data().asFloat()));
+//            System.out.println(Arrays.toString(ds.getFeatures().data().asFloat()));

             assertEquals(ds.getFeatures(), ds.getLabels());
         }
@@ -86,7 +86,7 @@ public class TestExport extends BaseSparkTest {
         for (File file : files) {
             if (!file.getPath().endsWith(".bin"))
                 continue;
-            System.out.println(file);
+//            System.out.println(file);
             DataSet ds = new DataSet();
             ds.load(file);
             assertEquals(minibatchSize, ds.numExamples());
@@ -144,7 +144,7 @@ public class TestExport extends BaseSparkTest {
         for (File file : files) {
             if (!file.getPath().endsWith(".bin"))
                 continue;
-            System.out.println(file);
+//            System.out.println(file);
             MultiDataSet ds = new org.nd4j.linalg.dataset.MultiDataSet();
             ds.load(file);
             assertEquals(minibatchSize, ds.getFeatures(0).size(0));
@@ -92,9 +92,9 @@ public class HashingBalancedPartitionerTest extends BaseSparkTest {

         int[][] colorCountsByPartition = new int[3][2];
         for (final Tuple2<Tuple2<Long, Integer>, String> val : testList) {
-            System.out.println(val);
+//            System.out.println(val);
             Integer partition = hbp.getPartition(val._1());
-            System.out.println(partition);
+//            System.out.println(partition);

             if (val._2().equals("red"))
                 colorCountsByPartition[partition][0] += 1;
@@ -102,9 +102,9 @@ public class HashingBalancedPartitionerTest extends BaseSparkTest {
                 colorCountsByPartition[partition][1] += 1;
         }

-        for (int i = 0; i < 3; i++) {
-            System.out.println(Arrays.toString(colorCountsByPartition[i]));
-        }
+//        for (int i = 0; i < 3; i++) {
+//            System.out.println(Arrays.toString(colorCountsByPartition[i]));
+//        }
         for (int i = 0; i < 3; i++) {
             // avg red per partition : 2.33
             assertTrue(colorCountsByPartition[i][0] >= 1 && colorCountsByPartition[i][0] < 4);
@@ -178,12 +178,12 @@ public class HashingBalancedPartitionerTest extends BaseSparkTest {
                 colorCountsByPartition[partition][1] += 1;
         }

-        for (int i = 0; i < numPartitions; i++) {
-            System.out.println(Arrays.toString(colorCountsByPartition[i]));
-        }
-
-        System.out.println("Ideal red # per partition: " + avgRed);
-        System.out.println("Ideal blue # per partition: " + avgBlue);
+//        for (int i = 0; i < numPartitions; i++) {
+//            System.out.println(Arrays.toString(colorCountsByPartition[i]));
+//        }
+//
+//        System.out.println("Ideal red # per partition: " + avgRed);
+//        System.out.println("Ideal blue # per partition: " + avgBlue);

         for (int i = 0; i < numPartitions; i++) {
             // avg red per partition : 2.33
@ -115,7 +115,7 @@ public class TestSparkComputationGraph extends BaseSparkTest {
|
||||||
TrainingMaster tm = new ParameterAveragingTrainingMaster(true, numExecutors(), 1, 10, 1, 0);
|
TrainingMaster tm = new ParameterAveragingTrainingMaster(true, numExecutors(), 1, 10, 1, 0);
|
||||||
|
|
||||||
SparkComputationGraph scg = new SparkComputationGraph(sc, cg, tm);
|
SparkComputationGraph scg = new SparkComputationGraph(sc, cg, tm);
|
||||||
scg.setListeners(Collections.singleton((TrainingListener) new ScoreIterationListener(1)));
|
scg.setListeners(Collections.singleton((TrainingListener) new ScoreIterationListener(5)));
|
||||||
|
|
||||||
JavaRDD<MultiDataSet> rdd = sc.parallelize(list);
|
JavaRDD<MultiDataSet> rdd = sc.parallelize(list);
|
||||||
scg.fitMultiDataSet(rdd);
|
scg.fitMultiDataSet(rdd);
|
||||||
|
|
|
@ -31,8 +31,11 @@ import org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import org.nd4j.evaluation.classification.Evaluation;
|
import org.nd4j.evaluation.classification.Evaluation;
|
||||||
import org.nd4j.linalg.activations.Activation;
|
import org.nd4j.linalg.activations.Activation;
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.dataset.DataSet;
|
import org.nd4j.linalg.dataset.DataSet;
|
||||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
import org.nd4j.linalg.learning.config.Adam;
|
||||||
import org.nd4j.linalg.learning.config.Nesterovs;
|
import org.nd4j.linalg.learning.config.Nesterovs;
|
||||||
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
||||||
|
|
||||||
|
@ -45,8 +48,24 @@ import static org.junit.Assert.assertTrue;
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class TestSparkDl4jMultiLayer extends BaseSparkTest {
|
public class TestSparkDl4jMultiLayer extends BaseSparkTest {
|
||||||
|
|
||||||
@Test(timeout = 120000L)
|
@Override
|
||||||
|
public long getTimeoutMilliseconds() {
|
||||||
|
return 120000L;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public DataType getDataType() {
|
||||||
|
return DataType.FLOAT;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public DataType getDefaultFPDataType() {
|
||||||
|
return DataType.FLOAT;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
public void testEvaluationSimple() throws Exception {
|
public void testEvaluationSimple() throws Exception {
|
||||||
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
|
||||||
for( int evalWorkers : new int[]{1, 4, 8}) {
|
for( int evalWorkers : new int[]{1, 4, 8}) {
|
||||||
//Simple test to validate DL4J issue 4099 is fixed...
|
//Simple test to validate DL4J issue 4099 is fixed...
|
||||||
|
@ -75,18 +94,18 @@ public class TestSparkDl4jMultiLayer extends BaseSparkTest {
|
||||||
//----------------------------------
|
//----------------------------------
|
||||||
//Create network configuration and conduct network training
|
//Create network configuration and conduct network training
|
||||||
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
|
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
|
||||||
|
.dataType(DataType.FLOAT)
|
||||||
.seed(12345)
|
.seed(12345)
|
||||||
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
|
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
|
||||||
.activation(Activation.LEAKYRELU)
|
.activation(Activation.LEAKYRELU)
|
||||||
.weightInit(WeightInit.XAVIER)
|
.weightInit(WeightInit.XAVIER)
|
||||||
.updater(new Nesterovs(0.02, 0.9))
|
.updater(new Adam(1e-3))
|
||||||
.l2(1e-4)
|
.l2(1e-5)
|
||||||
.list()
|
.list()
|
||||||
.layer(0, new DenseLayer.Builder().nIn(28 * 28).nOut(500).build())
|
.layer(0, new DenseLayer.Builder().nIn(28 * 28).nOut(500).build())
|
||||||
.layer(1, new DenseLayer.Builder().nIn(500).nOut(100).build())
|
.layer(1, new DenseLayer.Builder().nIn(500).nOut(100).build())
|
||||||
.layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
|
.layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
|
||||||
.activation(Activation.SOFTMAX).nIn(100).nOut(10).build())
|
.activation(Activation.SOFTMAX).nIn(100).nOut(10).build())
|
||||||
|
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
//Configuration for Spark training: see https://deeplearning4j.org/docs/latest/deeplearning4j-scaleout-howto for explanation of these configuration options
|
//Configuration for Spark training: see https://deeplearning4j.org/docs/latest/deeplearning4j-scaleout-howto for explanation of these configuration options
|
||||||
|
|
|
@ -333,15 +333,16 @@ public class TestCompareParameterAveragingSparkVsSingleMachine {
|
||||||
sparkNet.fit(rdd);
|
sparkNet.fit(rdd);
|
||||||
}
|
}
|
||||||
|
|
||||||
System.out.println(sparkNet.getSparkTrainingStats().statsAsString());
|
// System.out.println(sparkNet.getSparkTrainingStats().statsAsString());
|
||||||
|
sparkNet.getSparkTrainingStats().statsAsString();
|
||||||
|
|
||||||
INDArray finalSparkParams = sparkNet.getNetwork().params().dup();
|
INDArray finalSparkParams = sparkNet.getNetwork().params().dup();
|
||||||
|
|
||||||
System.out.println("Initial (Local) params: " + Arrays.toString(initialParams.data().asFloat()));
|
// System.out.println("Initial (Local) params: " + Arrays.toString(initialParams.data().asFloat()));
|
||||||
System.out.println("Initial (Spark) params: "
|
// System.out.println("Initial (Spark) params: "
|
||||||
+ Arrays.toString(initialSparkParams.data().asFloat()));
|
// + Arrays.toString(initialSparkParams.data().asFloat()));
|
||||||
System.out.println("Final (Local) params: " + Arrays.toString(finalParams.data().asFloat()));
|
// System.out.println("Final (Local) params: " + Arrays.toString(finalParams.data().asFloat()));
|
||||||
System.out.println("Final (Spark) params: " + Arrays.toString(finalSparkParams.data().asFloat()));
|
// System.out.println("Final (Spark) params: " + Arrays.toString(finalSparkParams.data().asFloat()));
|
||||||
assertEquals(initialParams, initialSparkParams);
|
assertEquals(initialParams, initialSparkParams);
|
||||||
assertNotEquals(initialParams, finalParams);
|
assertNotEquals(initialParams, finalParams);
|
||||||
assertEquals(finalParams, finalSparkParams);
|
assertEquals(finalParams, finalSparkParams);
|
||||||
|
@ -405,15 +406,16 @@ public class TestCompareParameterAveragingSparkVsSingleMachine {
|
||||||
sparkNet.fit(rdd);
|
sparkNet.fit(rdd);
|
||||||
}
|
}
|
||||||
|
|
||||||
System.out.println(sparkNet.getSparkTrainingStats().statsAsString());
|
// System.out.println(sparkNet.getSparkTrainingStats().statsAsString());
|
||||||
|
sparkNet.getSparkTrainingStats().statsAsString();
|
||||||
|
|
||||||
INDArray finalSparkParams = sparkNet.getNetwork().params().dup();
|
INDArray finalSparkParams = sparkNet.getNetwork().params().dup();
|
||||||
|
|
||||||
System.out.println("Initial (Local) params: " + Arrays.toString(initialParams.data().asFloat()));
|
// System.out.println("Initial (Local) params: " + Arrays.toString(initialParams.data().asFloat()));
|
||||||
System.out.println("Initial (Spark) params: "
|
// System.out.println("Initial (Spark) params: "
|
||||||
+ Arrays.toString(initialSparkParams.data().asFloat()));
|
// + Arrays.toString(initialSparkParams.data().asFloat()));
|
||||||
System.out.println("Final (Local) params: " + Arrays.toString(finalParams.data().asFloat()));
|
// System.out.println("Final (Local) params: " + Arrays.toString(finalParams.data().asFloat()));
|
||||||
System.out.println("Final (Spark) params: " + Arrays.toString(finalSparkParams.data().asFloat()));
|
// System.out.println("Final (Spark) params: " + Arrays.toString(finalSparkParams.data().asFloat()));
|
||||||
assertArrayEquals(initialParams.data().asFloat(), initialSparkParams.data().asFloat(), 1e-8f);
|
assertArrayEquals(initialParams.data().asFloat(), initialSparkParams.data().asFloat(), 1e-8f);
|
||||||
assertArrayEquals(finalParams.data().asFloat(), finalSparkParams.data().asFloat(), 1e-6f);
|
assertArrayEquals(finalParams.data().asFloat(), finalSparkParams.data().asFloat(), 1e-6f);
|
||||||
|
|
||||||
|
@ -478,18 +480,19 @@ public class TestCompareParameterAveragingSparkVsSingleMachine {
|
||||||
sparkNet.fit(rdd);
|
sparkNet.fit(rdd);
|
||||||
}
|
}
|
||||||
|
|
||||||
System.out.println(sparkNet.getSparkTrainingStats().statsAsString());
|
// System.out.println(sparkNet.getSparkTrainingStats().statsAsString());
|
||||||
|
sparkNet.getSparkTrainingStats().statsAsString();
|
||||||
|
|
||||||
INDArray finalSparkParams = sparkNet.getNetwork().params().dup();
|
INDArray finalSparkParams = sparkNet.getNetwork().params().dup();
|
||||||
// executioner.addToWatchdog(finalSparkParams, "finalSparkParams");
|
// executioner.addToWatchdog(finalSparkParams, "finalSparkParams");
|
||||||
|
|
||||||
float[] fp = finalParams.data().asFloat();
|
float[] fp = finalParams.data().asFloat();
|
||||||
float[] fps = finalSparkParams.data().asFloat();
|
float[] fps = finalSparkParams.data().asFloat();
|
||||||
System.out.println("Initial (Local) params: " + Arrays.toString(initialParams.data().asFloat()));
|
// System.out.println("Initial (Local) params: " + Arrays.toString(initialParams.data().asFloat()));
|
||||||
System.out.println("Initial (Spark) params: "
|
// System.out.println("Initial (Spark) params: "
|
||||||
+ Arrays.toString(initialSparkParams.data().asFloat()));
|
// + Arrays.toString(initialSparkParams.data().asFloat()));
|
||||||
System.out.println("Final (Local) params: " + Arrays.toString(fp));
|
// System.out.println("Final (Local) params: " + Arrays.toString(fp));
|
||||||
System.out.println("Final (Spark) params: " + Arrays.toString(fps));
|
// System.out.println("Final (Spark) params: " + Arrays.toString(fps));
|
||||||
|
|
||||||
assertEquals(initialParams, initialSparkParams);
|
assertEquals(initialParams, initialSparkParams);
|
||||||
assertNotEquals(initialParams, finalParams);
|
assertNotEquals(initialParams, finalParams);
|
||||||
|
@ -551,14 +554,15 @@ public class TestCompareParameterAveragingSparkVsSingleMachine {
|
||||||
sparkNet.fit(rdd);
|
sparkNet.fit(rdd);
|
||||||
}
|
}
|
||||||
|
|
||||||
System.out.println(sparkNet.getSparkTrainingStats().statsAsString());
|
// System.out.println(sparkNet.getSparkTrainingStats().statsAsString());
|
||||||
|
sparkNet.getSparkTrainingStats().statsAsString();
|
||||||
|
|
||||||
INDArray finalSparkParams = sparkNet.getNetwork().params().dup();
|
INDArray finalSparkParams = sparkNet.getNetwork().params().dup();
|
||||||
|
|
||||||
System.out.println("Initial (Local) params: " + Arrays.toString(initialParams.data().asFloat()));
|
// System.out.println("Initial (Local) params: " + Arrays.toString(initialParams.data().asFloat()));
|
||||||
System.out.println("Initial (Spark) params: " + Arrays.toString(initialSparkParams.data().asFloat()));
|
// System.out.println("Initial (Spark) params: " + Arrays.toString(initialSparkParams.data().asFloat()));
|
||||||
System.out.println("Final (Local) params: " + Arrays.toString(finalParams.data().asFloat()));
|
// System.out.println("Final (Local) params: " + Arrays.toString(finalParams.data().asFloat()));
|
||||||
System.out.println("Final (Spark) params: " + Arrays.toString(finalSparkParams.data().asFloat()));
|
// System.out.println("Final (Spark) params: " + Arrays.toString(finalSparkParams.data().asFloat()));
|
||||||
assertArrayEquals(initialParams.data().asFloat(), initialSparkParams.data().asFloat(), 1e-8f);
|
assertArrayEquals(initialParams.data().asFloat(), initialSparkParams.data().asFloat(), 1e-8f);
|
||||||
assertArrayEquals(finalParams.data().asFloat(), finalSparkParams.data().asFloat(), 1e-6f);
|
assertArrayEquals(finalParams.data().asFloat(), finalSparkParams.data().asFloat(), 1e-6f);
|
||||||
|
|
||||||
|
|
|
@ -37,7 +37,7 @@ public class TestJsonYaml {
|
||||||
String json = tm.toJson();
|
String json = tm.toJson();
|
||||||
String yaml = tm.toYaml();
|
String yaml = tm.toYaml();
|
||||||
|
|
||||||
System.out.println(json);
|
// System.out.println(json);
|
||||||
|
|
||||||
TrainingMaster fromJson = ParameterAveragingTrainingMaster.fromJson(json);
|
TrainingMaster fromJson = ParameterAveragingTrainingMaster.fromJson(json);
|
||||||
TrainingMaster fromYaml = ParameterAveragingTrainingMaster.fromYaml(yaml);
|
TrainingMaster fromYaml = ParameterAveragingTrainingMaster.fromYaml(yaml);
|
||||||
|
|
|
@ -389,7 +389,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
|
||||||
List<EventStats> workerFitStats = stats.getValue("ParameterAveragingWorkerFitTimesMs");
|
List<EventStats> workerFitStats = stats.getValue("ParameterAveragingWorkerFitTimesMs");
|
||||||
for (EventStats e : workerFitStats) {
|
for (EventStats e : workerFitStats) {
|
||||||
ExampleCountEventStats eces = (ExampleCountEventStats) e;
|
ExampleCountEventStats eces = (ExampleCountEventStats) e;
|
||||||
System.out.println(eces.getTotalExampleCount());
|
// System.out.println(eces.getTotalExampleCount());
|
||||||
}
|
}
|
||||||
|
|
||||||
for (EventStats e : workerFitStats) {
|
for (EventStats e : workerFitStats) {
|
||||||
|
@ -457,7 +457,8 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
|
||||||
assertNotEquals(paramsBefore, paramsAfter);
|
assertNotEquals(paramsBefore, paramsAfter);
|
||||||
|
|
||||||
SparkTrainingStats stats = sparkNet.getSparkTrainingStats();
|
SparkTrainingStats stats = sparkNet.getSparkTrainingStats();
|
||||||
System.out.println(stats.statsAsString());
|
// System.out.println(stats.statsAsString());
|
||||||
|
stats.statsAsString();
|
||||||
|
|
||||||
sparkNet.getTrainingMaster().deleteTempFiles(sc);
|
sparkNet.getTrainingMaster().deleteTempFiles(sc);
|
||||||
}
|
}
|
||||||
|
@ -483,7 +484,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
|
||||||
i++;
|
i++;
|
||||||
}
|
}
|
||||||
|
|
||||||
System.out.println("Saved to: " + tempDirF.getAbsolutePath());
|
// System.out.println("Saved to: " + tempDirF.getAbsolutePath());
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -527,7 +528,8 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
|
||||||
SparkTrainingStats stats = sparkNet.getSparkTrainingStats();
|
SparkTrainingStats stats = sparkNet.getSparkTrainingStats();
|
||||||
|
|
||||||
//Expect
|
//Expect
|
||||||
System.out.println(stats.statsAsString());
|
// System.out.println(stats.statsAsString());
|
||||||
|
stats.statsAsString();
|
||||||
assertEquals(numSplits, stats.getValue("ParameterAveragingMasterRepartitionTimesMs").size());
|
assertEquals(numSplits, stats.getValue("ParameterAveragingMasterRepartitionTimesMs").size());
|
||||||
|
|
||||||
List<EventStats> list = stats.getValue("ParameterAveragingWorkerFitTimesMs");
|
List<EventStats> list = stats.getValue("ParameterAveragingWorkerFitTimesMs");
|
||||||
|
@ -566,8 +568,8 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
|
||||||
i++;
|
i++;
|
||||||
}
|
}
|
||||||
|
|
||||||
System.out.println("Saved to: " + tempDirF.getAbsolutePath());
|
// System.out.println("Saved to: " + tempDirF.getAbsolutePath());
|
||||||
System.out.println("Saved to: " + tempDirF2.getAbsolutePath());
|
// System.out.println("Saved to: " + tempDirF2.getAbsolutePath());
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -610,7 +612,8 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
|
||||||
assertNotEquals(paramsBefore, paramsAfter);
|
assertNotEquals(paramsBefore, paramsAfter);
|
||||||
|
|
||||||
SparkTrainingStats stats = sparkNet.getSparkTrainingStats();
|
SparkTrainingStats stats = sparkNet.getSparkTrainingStats();
|
||||||
System.out.println(stats.statsAsString());
|
// System.out.println(stats.statsAsString());
|
||||||
|
stats.statsAsString();
|
||||||
|
|
||||||
//Same thing, but for MultiDataSet objects:
|
//Same thing, but for MultiDataSet objects:
|
||||||
config = new Configuration();
|
config = new Configuration();
|
||||||
|
@ -631,7 +634,8 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
|
||||||
assertNotEquals(paramsBefore, paramsAfter);
|
assertNotEquals(paramsBefore, paramsAfter);
|
||||||
|
|
||||||
stats = sparkNet.getSparkTrainingStats();
|
stats = sparkNet.getSparkTrainingStats();
|
||||||
System.out.println(stats.statsAsString());
|
// System.out.println(stats.statsAsString());
|
||||||
|
stats.statsAsString();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -730,13 +734,13 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
for (int avgFreq : new int[] {1, 5, 10}) {
|
for (int avgFreq : new int[] {1, 5, 10}) {
|
||||||
System.out.println("--- Avg freq " + avgFreq + " ---");
|
// System.out.println("--- Avg freq " + avgFreq + " ---");
|
||||||
SparkDl4jMultiLayer sparkNet = new SparkDl4jMultiLayer(sc, conf.clone(),
|
SparkDl4jMultiLayer sparkNet = new SparkDl4jMultiLayer(sc, conf.clone(),
|
||||||
new ParameterAveragingTrainingMaster.Builder(numExecutors(), dataSetObjSize)
|
new ParameterAveragingTrainingMaster.Builder(numExecutors(), dataSetObjSize)
|
||||||
.batchSizePerWorker(batchSizePerExecutor).averagingFrequency(avgFreq)
|
.batchSizePerWorker(batchSizePerExecutor).averagingFrequency(avgFreq)
|
||||||
.repartionData(Repartition.Always).build());
|
.repartionData(Repartition.Always).build());
|
||||||
|
|
||||||
sparkNet.setListeners(new ScoreIterationListener(1));
|
sparkNet.setListeners(new ScoreIterationListener(5));
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -778,13 +782,13 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
|
||||||
.setOutputs("1").build();
|
.setOutputs("1").build();
|
||||||
|
|
||||||
for (int avgFreq : new int[] {1, 5, 10}) {
|
for (int avgFreq : new int[] {1, 5, 10}) {
|
||||||
System.out.println("--- Avg freq " + avgFreq + " ---");
|
// System.out.println("--- Avg freq " + avgFreq + " ---");
|
||||||
SparkComputationGraph sparkNet = new SparkComputationGraph(sc, conf.clone(),
|
SparkComputationGraph sparkNet = new SparkComputationGraph(sc, conf.clone(),
|
||||||
new ParameterAveragingTrainingMaster.Builder(numExecutors(), dataSetObjSize)
|
new ParameterAveragingTrainingMaster.Builder(numExecutors(), dataSetObjSize)
|
||||||
.batchSizePerWorker(batchSizePerExecutor).averagingFrequency(avgFreq)
|
.batchSizePerWorker(batchSizePerExecutor).averagingFrequency(avgFreq)
|
||||||
.repartionData(Repartition.Always).build());
|
.repartionData(Repartition.Always).build());
|
||||||
|
|
||||||
sparkNet.setListeners(new ScoreIterationListener(1));
|
sparkNet.setListeners(new ScoreIterationListener(5));
|
||||||
|
|
||||||
JavaRDD<DataSet> rdd = sc.parallelize(list);
|
JavaRDD<DataSet> rdd = sc.parallelize(list);
|
||||||
|
|
||||||
|
|
|
@ -107,7 +107,7 @@ public class TestTrainingStatsCollection extends BaseSparkTest {
|
||||||
expectedStatNames.addAll(c);
|
expectedStatNames.addAll(c);
|
||||||
}
|
}
|
||||||
|
|
||||||
System.out.println(expectedStatNames);
|
// System.out.println(expectedStatNames);
|
||||||
|
|
||||||
|
|
||||||
SparkTrainingStats stats = sparkNet.getSparkTrainingStats();
|
SparkTrainingStats stats = sparkNet.getSparkTrainingStats();
|
||||||
|
@ -119,7 +119,7 @@ public class TestTrainingStatsCollection extends BaseSparkTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
String statsAsString = stats.statsAsString();
|
String statsAsString = stats.statsAsString();
|
||||||
System.out.println(statsAsString);
|
// System.out.println(statsAsString);
|
||||||
assertEquals(actualKeySet.size(), statsAsString.split("\n").length); //One line per stat
|
assertEquals(actualKeySet.size(), statsAsString.split("\n").length); //One line per stat
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -35,7 +35,7 @@ public class TestTimeSource {
|
||||||
long systemTime = System.currentTimeMillis();
|
long systemTime = System.currentTimeMillis();
|
||||||
long ntpTime = timeSource.currentTimeMillis();
|
long ntpTime = timeSource.currentTimeMillis();
|
||||||
long offset = ntpTime - systemTime;
|
long offset = ntpTime - systemTime;
|
||||||
System.out.println("System: " + systemTime + "\tNTPTimeSource: " + ntpTime + "\tOffset: " + offset);
|
// System.out.println("System: " + systemTime + "\tNTPTimeSource: " + ntpTime + "\tOffset: " + offset);
|
||||||
Thread.sleep(500);
|
Thread.sleep(500);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -49,7 +49,7 @@ public class TestTimeSource {
|
||||||
long systemTime = System.currentTimeMillis();
|
long systemTime = System.currentTimeMillis();
|
||||||
long ntpTime = timeSource.currentTimeMillis();
|
long ntpTime = timeSource.currentTimeMillis();
|
||||||
long offset = ntpTime - systemTime;
|
long offset = ntpTime - systemTime;
|
||||||
System.out.println("System: " + systemTime + "\tSystemClockTimeSource: " + ntpTime + "\tOffset: " + offset);
|
// System.out.println("System: " + systemTime + "\tSystemClockTimeSource: " + ntpTime + "\tOffset: " + offset);
|
||||||
assertEquals(systemTime, ntpTime, 2); //Should be exact, but we might randomly tick over between one ms and the next
|
assertEquals(systemTime, ntpTime, 2); //Should be exact, but we might randomly tick over between one ms and the next
|
||||||
Thread.sleep(500);
|
Thread.sleep(500);
|
||||||
}
|
}
|
||||||
|
|
|
@ -87,7 +87,7 @@ public class TestListeners extends BaseSparkTest {
|
||||||
net.fit(rdd);
|
net.fit(rdd);
|
||||||
|
|
||||||
List<String> sessions = ss.listSessionIDs();
|
List<String> sessions = ss.listSessionIDs();
|
||||||
System.out.println("Sessions: " + sessions);
|
// System.out.println("Sessions: " + sessions);
|
||||||
assertEquals(1, sessions.size());
|
assertEquals(1, sessions.size());
|
||||||
|
|
||||||
String sid = sessions.get(0);
|
String sid = sessions.get(0);
|
||||||
|
@ -95,15 +95,15 @@ public class TestListeners extends BaseSparkTest {
|
||||||
List<String> typeIDs = ss.listTypeIDsForSession(sid);
|
List<String> typeIDs = ss.listTypeIDsForSession(sid);
|
||||||
List<String> workers = ss.listWorkerIDsForSession(sid);
|
List<String> workers = ss.listWorkerIDsForSession(sid);
|
||||||
|
|
||||||
System.out.println(sid + "\t" + typeIDs + "\t" + workers);
|
// System.out.println(sid + "\t" + typeIDs + "\t" + workers);
|
||||||
|
|
||||||
List<Persistable> lastUpdates = ss.getLatestUpdateAllWorkers(sid, StatsListener.TYPE_ID);
|
List<Persistable> lastUpdates = ss.getLatestUpdateAllWorkers(sid, StatsListener.TYPE_ID);
|
||||||
System.out.println(lastUpdates);
|
// System.out.println(lastUpdates);
|
||||||
|
|
||||||
System.out.println("Static info:");
|
// System.out.println("Static info:");
|
||||||
for (String wid : workers) {
|
for (String wid : workers) {
|
||||||
Persistable staticInfo = ss.getStaticInfo(sid, StatsListener.TYPE_ID, wid);
|
Persistable staticInfo = ss.getStaticInfo(sid, StatsListener.TYPE_ID, wid);
|
||||||
System.out.println(sid + "\t" + wid);
|
// System.out.println(sid + "\t" + wid);
|
||||||
}
|
}
|
||||||
|
|
||||||
assertEquals(1, typeIDs.size());
|
assertEquals(1, typeIDs.size());
|
||||||
|
|
|
@ -63,7 +63,7 @@ public class TestRepartitioning extends BaseSparkTest {
|
||||||
assertEquals(10, rdd2.partitions().size());
|
assertEquals(10, rdd2.partitions().size());
|
||||||
for (int i = 0; i < 10; i++) {
|
for (int i = 0; i < 10; i++) {
|
||||||
List<String> partition = rdd2.collectPartitions(new int[] {i})[0];
|
List<String> partition = rdd2.collectPartitions(new int[] {i})[0];
|
||||||
System.out.println("Partition " + i + " size: " + partition.size());
|
// System.out.println("Partition " + i + " size: " + partition.size());
|
||||||
assertEquals(100, partition.size()); //Should be exactly 100, for the util method (but NOT spark .repartition)
|
assertEquals(100, partition.size()); //Should be exactly 100, for the util method (but NOT spark .repartition)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -170,7 +170,7 @@ public class TestRepartitioning extends BaseSparkTest {
|
||||||
|
|
||||||
List<Tuple2<Integer, Integer>> partitionCounts = initial.values().mapPartitionsWithIndex(new CountPartitionsFunction<Integer>(), true).collect();
|
List<Tuple2<Integer, Integer>> partitionCounts = initial.values().mapPartitionsWithIndex(new CountPartitionsFunction<Integer>(), true).collect();
|
||||||
|
|
||||||
System.out.println(partitionCounts);
|
// System.out.println(partitionCounts);
|
||||||
|
|
||||||
List<Tuple2<Integer,Integer>> initialExpected = Arrays.asList(
|
List<Tuple2<Integer,Integer>> initialExpected = Arrays.asList(
|
||||||
new Tuple2<>(0,29),
|
new Tuple2<>(0,29),
|
||||||
|
@ -185,7 +185,7 @@ public class TestRepartitioning extends BaseSparkTest {
|
||||||
|
|
||||||
JavaRDD<Integer> afterRepartition = SparkUtils.repartitionBalanceIfRequired(initial.values(), Repartition.Always, 2, 112);
|
JavaRDD<Integer> afterRepartition = SparkUtils.repartitionBalanceIfRequired(initial.values(), Repartition.Always, 2, 112);
|
||||||
List<Tuple2<Integer, Integer>> partitionCountsAfter = afterRepartition.mapPartitionsWithIndex(new CountPartitionsFunction<Integer>(), true).collect();
|
List<Tuple2<Integer, Integer>> partitionCountsAfter = afterRepartition.mapPartitionsWithIndex(new CountPartitionsFunction<Integer>(), true).collect();
|
||||||
System.out.println(partitionCountsAfter);
|
// System.out.println(partitionCountsAfter);
|
||||||
|
|
||||||
for(Tuple2<Integer,Integer> t2 : partitionCountsAfter){
|
for(Tuple2<Integer,Integer> t2 : partitionCountsAfter){
|
||||||
assertEquals(2, (int)t2._2());
|
assertEquals(2, (int)t2._2());
|
||||||
|
@ -219,8 +219,8 @@ public class TestRepartitioning extends BaseSparkTest {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
System.out.println("min: " + min + "\t@\t" + minIdx);
|
// System.out.println("min: " + min + "\t@\t" + minIdx);
|
||||||
System.out.println("max: " + max + "\t@\t" + maxIdx);
|
// System.out.println("max: " + max + "\t@\t" + maxIdx);
|
||||||
|
|
||||||
assertEquals(1, min);
|
assertEquals(1, min);
|
||||||
assertEquals(2, max);
|
assertEquals(2, max);
|
||||||
|
@ -244,7 +244,7 @@ public class TestRepartitioning extends BaseSparkTest {
|
||||||
|
|
||||||
for (int i = 0; i < 10; i++) {
|
for (int i = 0; i < 10; i++) {
|
||||||
List<String> partition = rdd2.collectPartitions(new int[] {i})[0];
|
List<String> partition = rdd2.collectPartitions(new int[] {i})[0];
|
||||||
System.out.println("Partition " + i + " size: " + partition.size());
|
// System.out.println("Partition " + i + " size: " + partition.size());
|
||||||
assertTrue(partition.size() >= 90 && partition.size() <= 110);
|
assertTrue(partition.size() >= 90 && partition.size() <= 110);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,7 +5,7 @@ project(mkldnn-download NONE)
|
||||||
include(ExternalProject)
|
include(ExternalProject)
|
||||||
ExternalProject_Add(mkldnn
|
ExternalProject_Add(mkldnn
|
||||||
GIT_REPOSITORY https://github.com/intel/mkl-dnn.git
|
GIT_REPOSITORY https://github.com/intel/mkl-dnn.git
|
||||||
GIT_TAG v1.1.3
|
GIT_TAG v1.2
|
||||||
SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/mkldnn-src"
|
SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/mkldnn-src"
|
||||||
BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/mkldnn-build"
|
BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/mkldnn-build"
|
||||||
CONFIGURE_COMMAND ""
|
CONFIGURE_COMMAND ""
|
||||||
|
|
|
@ -999,14 +999,14 @@ namespace nd4j {
|
||||||
* set new order and shape in case of suitable array length (in-place operation)
|
* set new order and shape in case of suitable array length (in-place operation)
|
||||||
* order - order to set
|
* order - order to set
|
||||||
* shape - shape to set
|
* shape - shape to set
|
||||||
*
|
* copyToNewBuff - if true, the contents of the old buffer are copied into the new buffer whenever a new buffer has to be allocated during reshaping
|
||||||
* if a permute has been applied before or there are weird strides, then a new buffer is allocated for the array
|
* if a permute has been applied before or there are weird strides, then a new buffer is allocated for the array
|
||||||
*/
|
*/
|
||||||
bool reshapei(const char order, const std::initializer_list<Nd4jLong>& shape);
|
bool reshapei(const char order, const std::initializer_list<Nd4jLong>& shape, const bool copyToNewBuff = true);
|
||||||
bool reshapei(const char order, const std::vector<Nd4jLong>& shape);
|
bool reshapei(const char order, const std::vector<Nd4jLong>& shape, const bool copyToNewBuff = true);
|
||||||
|
|
||||||
bool reshapei(const std::initializer_list<Nd4jLong>& shape);
|
bool reshapei(const std::initializer_list<Nd4jLong>& shape, const bool copyToNewBuff = true);
|
||||||
bool reshapei(const std::vector<Nd4jLong>& shape);
|
bool reshapei(const std::vector<Nd4jLong>& shape, const bool copyToNewBuff = true);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* creates new array with corresponding order and shape, new array will point on _buffer of this array
|
* creates new array with corresponding order and shape, new array will point on _buffer of this array
|
||||||
|
@ -1015,8 +1015,8 @@ namespace nd4j {
|
||||||
*
|
*
|
||||||
* if a permute has been applied before or there are weird strides, then a new buffer is allocated for the new array
|
* if a permute has been applied before or there are weird strides, then a new buffer is allocated for the new array
|
||||||
*/
|
*/
|
||||||
NDArray reshape(const char order, const std::vector<Nd4jLong>& shape) const &;
|
NDArray reshape(const char order, const std::vector<Nd4jLong>& shape, const bool copyToNewBuff = true) const &;
|
||||||
NDArray reshape(const char order, const std::vector<Nd4jLong>& shape) &&;
|
NDArray reshape(const char order, const std::vector<Nd4jLong>& shape, const bool copyToNewBuff = true) &&;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* calculate strides and set given order
|
* calculate strides and set given order
|
||||||
|
|
|
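Editor's note: the two hunks above add an optional copyToNewBuff flag to the reshapei/reshape family. A minimal usage sketch follows, assuming the usual order/shape/dtype NDArray constructor; because the flag defaults to true, existing call sites keep their current behaviour.

// Sketch only; assumes the libnd4j NDArray headers and the declarations added above.
NDArray a('c', {4, 6}, nd4j::DataType::FLOAT32);

a.reshapei('c', {2, 12});                          // copyToNewBuff defaults to true: old data is preserved
NDArray b = a.reshape('c', {24});                  // same default on the non-in-place variant

// Passing false skips copying the old contents when a fresh buffer has to be
// allocated (useful when the result is about to be overwritten anyway).
NDArray scratch = a.reshape('c', {3, 8}, /*copyToNewBuff=*/false);

Whether a copy happens at all still depends on whether the reshape can be expressed as a view of the existing buffer, as the implementation hunk further below shows.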
@ -501,7 +501,7 @@ NDArray::NDArray(const std::vector<Nd4jLong>& shape, const std::vector<const cha
|
||||||
auto data = reinterpret_cast<int8_t*>(bufferAsT<int8_t>() + headerLength);
|
auto data = reinterpret_cast<int8_t*>(bufferAsT<int8_t>() + headerLength);
|
||||||
|
|
||||||
auto func = PRAGMA_THREADS_FOR{
|
auto func = PRAGMA_THREADS_FOR{
|
||||||
for (auto e = start; e < stop; e += increment) {
|
for (auto e = start; e < stop; e++) {
|
||||||
auto cdata = data + offsets[e];
|
auto cdata = data + offsets[e];
|
||||||
if (dataType == DataType::UTF16) {
|
if (dataType == DataType::UTF16) {
|
||||||
unicode::utf8to16(string[e], cdata, std::char_traits<char>::length(string[e]));
|
unicode::utf8to16(string[e], cdata, std::char_traits<char>::length(string[e]));
|
||||||
|
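Editor's note: this and the following PRAGMA_THREADS_FOR hunks all make the same change, the per-element step goes from e += increment to e++. Presumably the macro already hands each thread a contiguous [start, stop) slice, so stepping by increment would skip elements. A hedged sketch of the assumed splitting scheme in plain C++ (not the actual macro):

// Illustrative only - the real work splitting lives behind PRAGMA_THREADS_FOR (assumption).
#include <algorithm>
#include <cstdint>
#include <functional>

void forEachChunk(int64_t length, int numThreads,
                  const std::function<void(int64_t, int64_t)>& body) {
    int64_t chunk = (length + numThreads - 1) / numThreads;
    for (int t = 0; t < numThreads; t++) {
        int64_t start = t * chunk;
        int64_t stop  = std::min<int64_t>(length, start + chunk);
        if (start < stop)
            body(start, stop);   // each chunk is contiguous, so the body steps by 1
    }
}

// body shape matching the corrected loops:  for (auto e = start; e < stop; e++) { ... }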
@ -568,7 +568,7 @@ NDArray::NDArray(const std::vector<Nd4jLong>& shape, const std::vector<std::stri
|
||||||
auto data = reinterpret_cast<int8_t*>(bufferAsT<int8_t>() + headerLength);
|
auto data = reinterpret_cast<int8_t*>(bufferAsT<int8_t>() + headerLength);
|
||||||
|
|
||||||
auto func = PRAGMA_THREADS_FOR{
|
auto func = PRAGMA_THREADS_FOR{
|
||||||
for (auto e = start; e < stop; e += increment) {
|
for (auto e = start; e < stop; e++) {
|
||||||
auto cdata = data + offsets[e];
|
auto cdata = data + offsets[e];
|
||||||
if (dataType == DataType::UTF16) {
|
if (dataType == DataType::UTF16) {
|
||||||
unicode::utf8to16(string[e].data(), cdata, string[e].size());
|
unicode::utf8to16(string[e].data(), cdata, string[e].size());
|
||||||
|
@ -635,7 +635,7 @@ NDArray::NDArray(const std::vector<Nd4jLong>& shape, const std::vector<std::u16s
|
||||||
auto data = reinterpret_cast<int8_t*>(bufferAsT<int8_t>() + headerLength);
|
auto data = reinterpret_cast<int8_t*>(bufferAsT<int8_t>() + headerLength);
|
||||||
|
|
||||||
auto func = PRAGMA_THREADS_FOR{
|
auto func = PRAGMA_THREADS_FOR{
|
||||||
for (auto e = start; e < stop; e += increment) {
|
for (auto e = start; e < stop; e++) {
|
||||||
auto cdata = data + offsets[e];
|
auto cdata = data + offsets[e];
|
||||||
if (dtype == DataType::UTF16) {
|
if (dtype == DataType::UTF16) {
|
||||||
memcpy(cdata, string[e].data(), string[e].size() * sizeof(uint16_t));
|
memcpy(cdata, string[e].data(), string[e].size() * sizeof(uint16_t));
|
||||||
|
@ -701,7 +701,7 @@ NDArray::NDArray(const std::vector<Nd4jLong>& shape, const std::vector<const cha
|
||||||
|
|
||||||
|
|
||||||
auto func = PRAGMA_THREADS_FOR{
|
auto func = PRAGMA_THREADS_FOR{
|
||||||
for (auto e = start; e < stop; e += increment) {
|
for (auto e = start; e < stop; e++) {
|
||||||
auto cdata = data + offsets[e];
|
auto cdata = data + offsets[e];
|
||||||
if (dtype == DataType::UTF16) {
|
if (dtype == DataType::UTF16) {
|
||||||
memcpy(cdata, string[e], std::char_traits<char16_t>::length(string[e]) * sizeof(uint16_t));
|
memcpy(cdata, string[e], std::char_traits<char16_t>::length(string[e]) * sizeof(uint16_t));
|
||||||
|
@ -767,7 +767,7 @@ NDArray::NDArray(const std::vector<Nd4jLong>& shape, const std::vector<std::u32s
|
||||||
auto data = reinterpret_cast<int8_t*>(bufferAsT<int8_t>() + headerLength);
|
auto data = reinterpret_cast<int8_t*>(bufferAsT<int8_t>() + headerLength);
|
||||||
|
|
||||||
auto func = PRAGMA_THREADS_FOR{
|
auto func = PRAGMA_THREADS_FOR{
|
||||||
for (auto e = start; e < stop; e += increment) {
|
for (auto e = start; e < stop; e++) {
|
||||||
auto cdata = data + offsets[e];
|
auto cdata = data + offsets[e];
|
||||||
if (dtype == DataType::UTF16) {
|
if (dtype == DataType::UTF16) {
|
||||||
unicode::utf32to16(string[e].data(), cdata, string[e].size());
|
unicode::utf32to16(string[e].data(), cdata, string[e].size());
|
||||||
|
@ -833,7 +833,7 @@ NDArray::NDArray(const std::vector<Nd4jLong>& shape, const std::vector<const cha
|
||||||
auto data = reinterpret_cast<int8_t*>(bufferAsT<int8_t>() + headerLength);
|
auto data = reinterpret_cast<int8_t*>(bufferAsT<int8_t>() + headerLength);
|
||||||
|
|
||||||
auto func = PRAGMA_THREADS_FOR{
|
auto func = PRAGMA_THREADS_FOR{
|
||||||
for (auto e = start; e < stop; e += increment) {
|
for (auto e = start; e < stop; e++) {
|
||||||
auto cdata = data + offsets[e];
|
auto cdata = data + offsets[e];
|
||||||
if (dtype == DataType::UTF16) {
|
if (dtype == DataType::UTF16) {
|
||||||
unicode::utf32to16(string[e], cdata, std::char_traits<char32_t>::length(string[e]));
|
unicode::utf32to16(string[e], cdata, std::char_traits<char32_t>::length(string[e]));
|
||||||
|
@ -1197,8 +1197,8 @@ void NDArray::assign(const NDArray& other, bool allowParallelism) {
|
||||||
throw std::runtime_error("NDArray::assign: lengths of arrays are mismatched");
|
throw std::runtime_error("NDArray::assign: lengths of arrays are mismatched");
|
||||||
}
|
}
|
||||||
|
|
||||||
// memcpy is allowed only for same order && same ews (being equal to 1)
|
// memcpy is allowed only when both arrays have 'c' ordering && same ews (being equal to 1)
|
||||||
if (ordering() == other.ordering() && dataType() == other.dataType() && ews() == 1 && other.ews() == 1)
|
if (ordering() == other.ordering() && ordering() == 'c' && dataType() == other.dataType() && ews() == 1 && other.ews() == 1)
|
||||||
copyBuffersContinuouslyFrom(other, other.lengthOf() * other.sizeOfT());
|
copyBuffersContinuouslyFrom(other, other.lengthOf() * other.sizeOfT());
|
||||||
else {
|
else {
|
||||||
NDArray::prepareSpecialUse({this}, {&other});
|
NDArray::prepareSpecialUse({this}, {&other});
|
||||||
|
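Editor's note: the assign fast path above now also requires 'c' ordering before falling back to a raw byte copy. A sketch of the tightened guard as a free predicate; canMemcpyAssign is a hypothetical helper introduced only for illustration, the member calls are the ones used in the hunk.

// Sketch only; assumes the NDArray header is available.
inline bool canMemcpyAssign(const NDArray& dst, const NDArray& src) {
    return dst.ordering() == src.ordering()
        && dst.ordering() == 'c'                 // newly required: plain row-major layout
        && dst.dataType() == src.dataType()
        && dst.ews() == 1 && src.ews() == 1;
}

Identical 'f'-ordered arrays now take the element-wise path instead of copyBuffersContinuouslyFrom.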
@ -1569,20 +1569,25 @@ Nd4jLong NDArray::tensorsAlongDimension(const std::vector<int>& dimensions) cons
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
void NDArray::printShapeInfo(const char * msg) const {
|
void NDArray::printShapeInfo(const char * msg) const {
|
||||||
//shape::printShapeInfo(_shapeInfo);
|
|
||||||
if (msg == nullptr)
|
int rank = shape::rank(_shapeInfo);
|
||||||
shape::printShapeInfoLinear(_shapeInfo);
|
int lim = shape::shapeInfoLength(rank);
|
||||||
else {
|
|
||||||
int rank = shape::rank(_shapeInfo);
|
if(msg != nullptr)
|
||||||
int lim = shape::shapeInfoLength(rank);
|
printf("shapeInfo %s: [", msg);
|
||||||
printf("%s: [", msg);
|
else
|
||||||
for (int i = 0; i < shape::shapeInfoLength(rank); i++) {
|
printf("shapeInfo: [");
|
||||||
printf("%lld", (long long) _shapeInfo[i]);
|
|
||||||
if (i < lim - 1)
|
printf("%i, ", rank);
|
||||||
printf(", ");
|
for (int i = 1; i < shape::shapeInfoLength(rank) - 3; i++){
|
||||||
}
|
if(i == rank + 1)
|
||||||
printf("]\n");
|
printf(" ");
|
||||||
|
printf("%lld,", _shapeInfo[i]);
|
||||||
}
|
}
|
||||||
|
printf(" %lld,", shape::type(_shapeInfo));
|
||||||
|
printf("%lld,", shape::elementWiseStride(_shapeInfo));
|
||||||
|
printf("%lld]\n", (Nd4jLong)shape::order(_shapeInfo));
|
||||||
|
|
||||||
fflush(stdout);
|
fflush(stdout);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
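Editor's note: the rewritten printShapeInfo walks the flat shapeInfo buffer itself instead of delegating to shape::printShapeInfoLinear, printing rank, shape, strides and then the trailing type/ews/order slots. A rough sketch of the expected output, assuming the usual [rank, shape..., strides..., type, ews, order] layout; the concrete numbers below are illustrative only, not verified output.

NDArray x('c', {2, 3}, nd4j::DataType::FLOAT32);
x.printShapeInfo("x");
// Prints something along the lines of:
//   shapeInfo x: [2,  2,3, 3,1,  <type code>,1,99]
//                 rank, shape, strides, type, ews, order ('c' prints as its character code)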
@ -1855,19 +1860,19 @@ void NDArray::updateStrides(const char order) {
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
// set new order and shape in case of suitable array length
|
// set new order and shape in case of suitable array length
|
||||||
bool NDArray::reshapei(const char order, const std::initializer_list<Nd4jLong>& shape) {
|
bool NDArray::reshapei(const char order, const std::initializer_list<Nd4jLong>& shape, const bool copyToNewBuff) {
|
||||||
std::vector<Nd4jLong> vShape(shape);
|
std::vector<Nd4jLong> vShape(shape);
|
||||||
return reshapei(order, vShape);
|
return reshapei(order, vShape, copyToNewBuff);
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
bool NDArray::reshapei(const std::initializer_list<Nd4jLong>& shape) {
|
bool NDArray::reshapei(const std::initializer_list<Nd4jLong>& shape, const bool copyToNewBuff) {
|
||||||
return reshapei('c', shape);
|
return reshapei(ordering(), shape, copyToNewBuff);
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
bool NDArray::reshapei(const std::vector<Nd4jLong>& shape) {
|
bool NDArray::reshapei(const std::vector<Nd4jLong>& shape, const bool copyToNewBuff) {
|
||||||
return reshapei('c', shape);
|
return reshapei(ordering(), shape, copyToNewBuff);
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
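Editor's note: with this hunk the order-less reshapei overloads keep the array's current ordering instead of silently forcing 'c', and they forward the new copyToNewBuff flag. A small sketch of the behavioural difference, assuming the usual constructors:

NDArray f('f', {4, 3}, nd4j::DataType::FLOAT32);

f.reshapei({6, 2});          // before: treated as reshapei('c', ...); after: stays column-major
f.reshapei({3, 4}, false);   // the flag is forwarded, so no data copy if a new buffer is needed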
@ -1918,18 +1923,18 @@ Nd4jLong NDArray::argMax(std::initializer_list<int> dimensions) {
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
// create new array with corresponding order and shape, new array will point to the same _buffer as this array
|
// create new array with corresponding order and shape, new array will point to the same _buffer as this array
|
||||||
NDArray NDArray::reshape(const char order, const std::vector<Nd4jLong>& shape) const & {
|
NDArray NDArray::reshape(const char order, const std::vector<Nd4jLong>& shape, const bool copyToNewBuff) const & {
|
||||||
|
|
||||||
NDArray newArr(getDataBuffer(), ShapeDescriptor(getShapeInfo()), getContext(), getBufferOffset());
|
NDArray newArr(getDataBuffer(), ShapeDescriptor(getShapeInfo()), getContext(), getBufferOffset());
|
||||||
newArr.reshapei(order, shape);
|
newArr.reshapei(order, shape, copyToNewBuff);
|
||||||
|
|
||||||
return newArr;
|
return newArr;
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
NDArray NDArray::reshape(const char order, const std::vector<Nd4jLong>& shape) && {
|
NDArray NDArray::reshape(const char order, const std::vector<Nd4jLong>& shape, const bool copyToNewBuff) && {
|
||||||
|
|
||||||
this->reshapei(order, shape);
|
this->reshapei(order, shape, copyToNewBuff);
|
||||||
return std::move(*this);
|
return std::move(*this);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
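Editor's note: both reshape overloads gain the same flag but still differ in ownership, which is easy to trip over. A hedged sketch of the two call shapes:

NDArray a('c', {2, 6}, nd4j::DataType::FLOAT32);

NDArray view  = a.reshape('c', {3, 4});            // lvalue overload: new array sharing a's buffer
NDArray moved = std::move(a).reshape('c', {12});   // rvalue overload: reshapes in place, then moves *this out
// 'a' should not be used after the second call.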
@ -1971,7 +1976,7 @@ bool NDArray::permutei(const std::initializer_list<int>& dimensions) {
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
bool NDArray::permutei(const std::vector<int>& dimensions) {
|
bool NDArray::permutei(const std::vector<int>& dimensions) {
|
||||||
return permutei(dimensions.data(), dimensions.size());
|
return permutei(dimensions.data(), rankOf());
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
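Editor's note: this hunk, and the matching permutei/permute overloads further below, replace dimensions.size() with rankOf() as the element count. That quietly assumes the caller always supplies a full permutation of length rank; a sketch of the implication (hypothetical calls, usual constructors assumed):

NDArray t('c', {2, 3, 4}, nd4j::DataType::FLOAT32);

t.permutei(std::vector<int>{2, 0, 1});   // fine: three entries for a rank-3 array
// t.permutei(std::vector<int>{1, 0});   // would now read rankOf() == 3 entries from a 2-element
//                                       // vector - callers must always pass a complete permutation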
@ -1993,7 +1998,7 @@ bool NDArray::permutei(const std::vector<Nd4jLong>& dimensions) {
|
||||||
for (int e = 0; e < dimensions.size(); e++)
|
for (int e = 0; e < dimensions.size(); e++)
|
||||||
ivec[e] = dimensions[e];
|
ivec[e] = dimensions[e];
|
||||||
|
|
||||||
return permutei(ivec.data(), ivec.size());
|
return permutei(ivec.data(), rankOf());
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -2029,9 +2034,8 @@ NDArray NDArray::permute(const Nd4jLong* dimensions, const int rank) && {
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
NDArray NDArray::permute(const std::vector<int>& dimensions) const &{
|
NDArray NDArray::permute(const std::vector<int>& dimensions) const &{
|
||||||
auto data = dimensions.data();
|
|
||||||
auto size = dimensions.size();
|
return permute(dimensions.data(), rankOf());
|
||||||
return permute(data, size);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -2043,7 +2047,8 @@ NDArray NDArray::permute(const std::vector<int>& dimensions) && {
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
NDArray NDArray::permute(const std::vector<Nd4jLong>& dimensions) const & {
|
NDArray NDArray::permute(const std::vector<Nd4jLong>& dimensions) const & {
|
||||||
return permute(dimensions.data(), dimensions.size());
|
|
||||||
|
return permute(dimensions.data(), rankOf());
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -2106,12 +2111,12 @@ void NDArray::permute(const Nd4jLong *dimensions, const int rank, NDArray& targe
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
void NDArray::permute(const std::vector<int>& dimensions, NDArray& target) const {
|
void NDArray::permute(const std::vector<int>& dimensions, NDArray& target) const {
|
||||||
permute(dimensions.data(), dimensions.size(), target);
|
permute(dimensions.data(), rankOf(), target);
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
void NDArray::permute(const std::vector<Nd4jLong>& dimensions, NDArray& target) const {
|
void NDArray::permute(const std::vector<Nd4jLong>& dimensions, NDArray& target) const {
|
||||||
permute(dimensions.data(), dimensions.size(), target);
|
permute(dimensions.data(), rankOf(), target);
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -2362,7 +2367,7 @@ NDArray NDArray::asS() const {
|
||||||
const auto inData = bufferAsT<int8_t>() + offsetsLength;
|
const auto inData = bufferAsT<int8_t>() + offsetsLength;
|
||||||
|
|
||||||
auto func = PRAGMA_THREADS_FOR{
|
auto func = PRAGMA_THREADS_FOR{
|
||||||
for (int e = start; e < stop; e += increment) {
|
for (int e = start; e < stop; e++) {
|
||||||
auto cdata = outData + offsets[e];
|
auto cdata = outData + offsets[e];
|
||||||
auto end = nInputoffsets[e + 1];
|
auto end = nInputoffsets[e + 1];
|
||||||
auto idata = inData + nInputoffsets[e];
|
auto idata = inData + nInputoffsets[e];
|
||||||
|
@ -3221,7 +3226,7 @@ BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT std::vector, NDArray::asVectorT(), LI
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
// set new order and shape in case of suitable array length
|
// set new order and shape in case of suitable array length
|
||||||
bool NDArray::reshapei(const char order, const std::vector<Nd4jLong>& cshape) {
|
bool NDArray::reshapei(const char order, const std::vector<Nd4jLong>& cshape, const bool copyToNewBuff) {
|
||||||
|
|
||||||
// check firstly whether cshape is identical to shape of array, if yes then reshape is unnecessary
|
// check firstly whether cshape is identical to shape of array, if yes then reshape is unnecessary
|
||||||
if(order == ordering() && shape::shapeEquals(rankOf(), shapeOf(), cshape.size(), cshape.data()))
|
if(order == ordering() && shape::shapeEquals(rankOf(), shapeOf(), cshape.size(), cshape.data()))
|
||||||
|
@ -3293,19 +3298,15 @@ bool NDArray::reshapei(const char order, const std::vector<Nd4jLong>& cshape) {
|
||||||
Nd4jLong *shapeInfoNew;
|
Nd4jLong *shapeInfoNew;
|
||||||
ALLOCATE(shapeInfoNew, getContext()->getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong);
|
ALLOCATE(shapeInfoNew, getContext()->getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong);
|
||||||
|
|
||||||
bool canReshape = shape::reshapeC(rankOf(), shapeInfo(), shape.size(), shape.data(), shapeInfoNew);
|
bool canReshape = shape::reshapeC(shapeInfo(), order, shape.size(), shape.data(), shapeInfoNew);
|
||||||
|
|
||||||
// we can do this only if there was no permute applied, or there are no weird strides
|
|
||||||
if (canReshape) {
|
if (canReshape) {
|
||||||
if(ordering() == 'c' && order == 'f')
|
|
||||||
throw std::invalid_argument("NDArray::reshapei(order, shape): in case of reshapeC it doesn't make sense to reshape from c order to f order !");
|
|
||||||
|
|
||||||
shape::setEws(shapeInfoNew, arrLength);
|
|
||||||
setShapeInfo(shapeInfoNew);
|
setShapeInfo(shapeInfoNew);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
NDArray temp(order, shape, dataType(), getContext());
|
NDArray temp(order, shape, dataType(), getContext());
|
||||||
this->applyTransform(transform::Assign, temp, nullptr);
|
if(copyToNewBuff)
|
||||||
|
this->applyTransform(transform::Assign, temp, nullptr);
|
||||||
*this = std::move(temp);
|
*this = std::move(temp);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
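Editor's note: inside the implementation, shape::reshapeC now receives the requested order directly, the old 'c'-to-'f' guard and the manual setEws call are gone, and when a view-style reshape is impossible the old contents are copied into the temporary array only if copyToNewBuff is set. A condensed sketch of the resulting control flow, with names taken from the hunk (not a verbatim copy):

bool canReshape = shape::reshapeC(shapeInfo(), order, shape.size(), shape.data(), shapeInfoNew);

if (canReshape) {
    setShapeInfo(shapeInfoNew);                              // keep the existing buffer as a view
} else {
    NDArray temp(order, shape, dataType(), getContext());    // a new buffer is unavoidable
    if (copyToNewBuff)
        this->applyTransform(transform::Assign, temp, nullptr);  // copy old data only on request
    *this = std::move(temp);                                 // otherwise the contents start undefined
}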
@ -3465,7 +3466,7 @@ NDArray NDArray::dup(const char newOrder) const {
|
||||||
std::vector<std::string> strings(lengthOf());
|
std::vector<std::string> strings(lengthOf());
|
||||||
|
|
||||||
auto func = PRAGMA_THREADS_FOR{
|
auto func = PRAGMA_THREADS_FOR{
|
||||||
for (auto i = start; i < stop; i += increment) {
|
for (auto i = start; i < stop; i++) {
|
||||||
strings[i] = std::move(this->e<std::string>(i));
|
strings[i] = std::move(this->e<std::string>(i));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -3478,7 +3479,7 @@ NDArray NDArray::dup(const char newOrder) const {
|
||||||
std::vector<std::u16string> strings(lengthOf());
|
std::vector<std::u16string> strings(lengthOf());
|
||||||
|
|
||||||
auto func = PRAGMA_THREADS_FOR{
|
auto func = PRAGMA_THREADS_FOR{
|
||||||
for (auto i = start; i < stop; i += increment) {
|
for (auto i = start; i < stop; i++) {
|
||||||
strings[i] = std::move(this->e<std::u16string>(i));
|
strings[i] = std::move(this->e<std::u16string>(i));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -3490,7 +3491,7 @@ NDArray NDArray::dup(const char newOrder) const {
|
||||||
|
|
||||||
std::vector<std::u32string> strings(lengthOf());
|
std::vector<std::u32string> strings(lengthOf());
|
||||||
auto func = PRAGMA_THREADS_FOR{
|
auto func = PRAGMA_THREADS_FOR{
|
||||||
for (auto i = start; i < stop; i += increment) {
|
for (auto i = start; i < stop; i++) {
|
||||||
strings[i] = std::move(this->e<std::u32string>(i));
|
strings[i] = std::move(this->e<std::u32string>(i));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -4846,7 +4847,7 @@ NDArray NDArray::operator()(const std::vector<Nd4jLong>& idx, const bool keepUni
|
||||||
auto shapeOf = shape::shapeOf(newShapeInfo);
|
auto shapeOf = shape::shapeOf(newShapeInfo);
|
||||||
auto stridesOf = shape::stride(newShapeInfo);
|
auto stridesOf = shape::stride(newShapeInfo);
|
||||||
|
|
||||||
Nd4jLong offset(0), subArrLen(1);
|
Nd4jLong offset = 0;
|
||||||
int n(isStrided ? 3 : 2), first, last, stride;
|
int n(isStrided ? 3 : 2), first, last, stride;
|
||||||
|
|
||||||
for (int d = rank - 1; d >= 0; --d) {
|
for (int d = rank - 1; d >= 0; --d) {
|
||||||
|
@ -4863,29 +4864,31 @@ NDArray NDArray::operator()(const std::vector<Nd4jLong>& idx, const bool keepUni
|
||||||
if(shapeOf[d] != 1)
|
if(shapeOf[d] != 1)
|
||||||
stridesOf[d] *= stride;
|
stridesOf[d] *= stride;
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
subArrLen *= shapeOf[d];
|
Nd4jLong *newShapeInfo2 = newShapeInfo;
|
||||||
|
|
||||||
|
if(!keepUnitiesInShape) {
|
||||||
|
|
||||||
|
std::vector<int> dimsWithUnities;
|
||||||
|
|
||||||
|
for (uint d = 0; d < rank; ++d)
|
||||||
|
if(idx[n*d] != idx[n*d+1] && shapeOf[d] == 1)
|
||||||
|
dimsWithUnities.push_back(d);
|
||||||
|
|
||||||
|
if(!dimsWithUnities.empty())
|
||||||
|
newShapeInfo2 = ShapeBuilders::copyShapeInfoWithoutUnites(newShapeInfo, dimsWithUnities.size(), dimsWithUnities.data(), getContext()->getWorkspace());
|
||||||
}
|
}
|
||||||
|
|
||||||
// check if there is possibility to set ews = 1
|
// check if there is possibility to set ews = 1
|
||||||
shape::setEws(newShapeInfo, subArrLen);
|
shape::checkStridesEwsAndOrder(newShapeInfo2);
|
||||||
|
|
||||||
NDArray result(_buffer, ShapeDescriptor(newShapeInfo), getContext(), offset + getBufferOffset());
|
NDArray result(_buffer, ShapeDescriptor(newShapeInfo2), getContext(), offset + getBufferOffset());
|
||||||
result._isView = true;
|
result._isView = true;
|
||||||
|
|
||||||
if(!keepUnitiesInShape) {
|
|
||||||
const int coeff = isStrided ? 3 : 2;
|
|
||||||
std::vector<Nd4jLong> nonUnitDims;
|
|
||||||
|
|
||||||
for (int d = 0; d < rank; ++d)
|
|
||||||
if(!(idx[coeff*d] != idx[coeff*d+1] && newShapeInfo[d+1] == 1))
|
|
||||||
nonUnitDims.push_back(newShapeInfo[d+1]);
|
|
||||||
|
|
||||||
if(nonUnitDims.size() != rank)
|
|
||||||
result.reshapei(nonUnitDims);
|
|
||||||
}
|
|
||||||
|
|
||||||
RELEASE(newShapeInfo, getContext()->getWorkspace());
|
RELEASE(newShapeInfo, getContext()->getWorkspace());
|
||||||
|
if(newShapeInfo != newShapeInfo2)
|
||||||
|
RELEASE(newShapeInfo2, getContext()->getWorkspace());
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
|
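Editor's note: the sub-array operator now strips explicitly indexed unit dimensions before the view is constructed (via ShapeBuilders::copyShapeInfoWithoutUnites) instead of reshaping the result afterwards, and checkStridesEwsAndOrder replaces the manual setEws call. A rough usage sketch, assuming the usual {first, last} pair-per-dimension index convention where {0, 0} selects the whole axis; the shapes in the comments are what the new code appears to produce, not verified output.

NDArray m('c', {3, 1, 4}, nd4j::DataType::FLOAT32);

auto sub = m({1, 2,  0, 0,  0, 0});   // pick a single row along axis 0, keep axes 1 and 2 whole
// keepUnitiesInShape == false (the default): the explicitly indexed unit axis is dropped
// up front, so sub should have shape {1, 4}; with keepUnitiesInShape == true it stays {1, 1, 4}.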
@ -179,7 +179,7 @@ namespace graph {
|
||||||
nd4j_debug("Embedded graph execution finished. %i variable(s) migrated\n", cnt);
|
nd4j_debug("Embedded graph execution finished. %i variable(s) migrated\n", cnt);
|
||||||
|
|
||||||
} else if (node->hasCustomOp()) {
|
} else if (node->hasCustomOp()) {
|
||||||
// if we have something to execute - lets just execute it.
|
// now, if we have something to execute - let's just execute it.
|
||||||
auto status = node->getCustomOp()->execute(&context);
|
auto status = node->getCustomOp()->execute(&context);
|
||||||
if (status != ND4J_STATUS_OK)
|
if (status != ND4J_STATUS_OK)
|
||||||
return status;
|
return status;
|
||||||
|
@ -494,8 +494,10 @@ Nd4jStatus GraphExecutioner::execute(Graph *graph, VariableSpace* variableSpace)
|
||||||
nd4j::memory::MemoryRegistrator::getInstance()->setGraphMemoryFootprintIfGreater(h, m);
|
nd4j::memory::MemoryRegistrator::getInstance()->setGraphMemoryFootprintIfGreater(h, m);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (tempFlow)
|
if (tempFlow) {
|
||||||
delete flowPath;
|
delete flowPath;
|
||||||
|
__variableSpace->setFlowPath(nullptr);
|
||||||
|
}
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
|
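Editor's note: the GraphExecutioner fix above pairs the delete of a temporarily created FlowPath with clearing the pointer stored in the variable space, so nothing downstream dereferences freed memory. The shape of the fix, restated with the names used in the hunk (surrounding class context assumed):

if (tempFlow) {
    delete flowPath;                        // the executioner only owned this FlowPath temporarily
    __variableSpace->setFlowPath(nullptr);  // and must not leave a dangling pointer behind
}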
@ -98,7 +98,7 @@ void NDArray::fillAsTriangular(const float val, int lower, int upper, NDArray& t
|
||||||
|
|
||||||
auto func = PRAGMA_THREADS_FOR {
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
Nd4jLong coords[MAX_RANK];
|
Nd4jLong coords[MAX_RANK];
|
||||||
for (auto i = start; i < stop; i += increment) {
|
for (auto i = start; i < stop; i++) {
|
||||||
shape::index2coords(i, target.getShapeInfo(), coords);
|
shape::index2coords(i, target.getShapeInfo(), coords);
|
||||||
const auto zOffset = shape::getOffset(target.getShapeInfo(), coords);
|
const auto zOffset = shape::getOffset(target.getShapeInfo(), coords);
|
||||||
|
|
||||||
|
@ -152,7 +152,7 @@ static void templatedSwap(void *xBuffer, void *yBuffer, Nd4jLong length) {
|
||||||
auto y = reinterpret_cast<T *>(yBuffer);
|
auto y = reinterpret_cast<T *>(yBuffer);
|
||||||
|
|
||||||
auto func = PRAGMA_THREADS_FOR {
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
for (auto i = start; i < stop; i += increment) {
|
for (auto i = start; i < stop; i++) {
|
||||||
auto temp = x[i];
|
auto temp = x[i];
|
||||||
x[i] = y[i];
|
x[i] = y[i];
|
||||||
y[i] = temp;
|
y[i] = temp;
|
||||||
|
@ -266,7 +266,7 @@ NDArray NDArray::tile(const std::vector<Nd4jLong>& reps) const {
|
||||||
if(result.ordering() == 'c') { // ews == 1 always here
|
if(result.ordering() == 'c') { // ews == 1 always here
|
||||||
|
|
||||||
auto func = PRAGMA_THREADS_FOR {
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
for (auto i = start; i < stop; i += increment) {
|
for (auto i = start; i < stop; i++) {
|
||||||
auto yOffset = shape::subArrayOffset(i, newShapeInfo, getShapeInfo());
|
auto yOffset = shape::subArrayOffset(i, newShapeInfo, getShapeInfo());
|
||||||
BUILD_SINGLE_SELECTOR(xType, this->template templatedAssign,(result.getBuffer(), i, this->getBuffer(), yOffset), LIBND4J_TYPES);
|
BUILD_SINGLE_SELECTOR(xType, this->template templatedAssign,(result.getBuffer(), i, this->getBuffer(), yOffset), LIBND4J_TYPES);
|
||||||
}
|
}
|
||||||
|
@ -277,7 +277,7 @@ NDArray NDArray::tile(const std::vector<Nd4jLong>& reps) const {
|
||||||
else {
|
else {
|
||||||
|
|
||||||
auto func = PRAGMA_THREADS_FOR {
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
for (auto i = start; i < stop; i += increment) {
|
for (auto i = start; i < stop; i++) {
|
||||||
auto xOffset = result.getOffset(i);
|
auto xOffset = result.getOffset(i);
|
||||||
auto yOffset = shape::subArrayOffset(i, newShapeInfo, getShapeInfo());
|
auto yOffset = shape::subArrayOffset(i, newShapeInfo, getShapeInfo());
|
||||||
BUILD_SINGLE_SELECTOR(xType, this->template templatedAssign,(result.getBuffer(), xOffset, this->getBuffer(), yOffset), LIBND4J_TYPES);
|
BUILD_SINGLE_SELECTOR(xType, this->template templatedAssign,(result.getBuffer(), xOffset, this->getBuffer(), yOffset), LIBND4J_TYPES);
|
||||||
|
@@ -377,7 +377,7 @@ static void repeat_(const NDArray& input, NDArray& output, const std::vector<int
    // loop through input array
    auto func = PRAGMA_THREADS_FOR {
        Nd4jLong coords[MAX_RANK];
-        for (auto i = start; i < stop; i += increment) {
+        for (auto i = start; i < stop; i++) {
            shape::index2coords(i, output.getShapeInfo(), coords);

            const auto zOffset = shape::getOffset(output.getShapeInfo(), coords);

@@ -22,7 +22,7 @@ void NDArray::applyTriplewiseLambda(NDArray& second, NDArray& third, const std::
    if (this->ordering() == second.ordering() && this->ordering() == third.ordering() && this->ordering() == target.ordering() && (this->ews() == 1 && target.ews() == 1) && this->ews() == second.ews() && this->ews() == third.ews()) {

        auto loop = PRAGMA_THREADS_FOR {
-            for (auto e = start; e < stop; e += increment)
+            for (auto e = start; e < stop; e++)
                z[e] = func(f[e], s[e], t[e]);
        };

@@ -31,7 +31,7 @@ void NDArray::applyTriplewiseLambda(NDArray& second, NDArray& third, const std::
        if (f == z) {

            auto loop = PRAGMA_THREADS_FOR {
-                for (auto e = start; e < stop; e += increment) {
+                for (auto e = start; e < stop; e++) {
                    auto tOffset = this->getOffset(e);
                    auto uOffset = second.getOffset(e);
                    auto vOffset = third.getOffset(e);

@@ -44,7 +44,7 @@ void NDArray::applyTriplewiseLambda(NDArray& second, NDArray& third, const std::
        } else {

            auto loop = PRAGMA_THREADS_FOR {
-                for (auto e = start; e < stop; e += increment) {
+                for (auto e = start; e < stop; e++) {
                    auto tOffset = this->getOffset(e);
                    auto uOffset = second.getOffset(e);
                    auto vOffset = third.getOffset(e);

@@ -93,7 +93,7 @@ void NDArray::applyPairwiseLambda(const NDArray& other, const std::function<T(T,
    if (this->ordering() == other.ordering() && this->ordering() == target.ordering() && (this->ews() == 1 && target.ews() == 1) && this->ews() == other.ews()) {

        auto loop = PRAGMA_THREADS_FOR {
-            for (auto e = start; e < stop; e += increment)
+            for (auto e = start; e < stop; e++)
                z[e] = func(f[e], s[e]);
        };

@@ -102,7 +102,7 @@ void NDArray::applyPairwiseLambda(const NDArray& other, const std::function<T(T,
        if (f == z) {

            auto loop = PRAGMA_THREADS_FOR {
-                for (auto e = start; e < stop; e += increment) {
+                for (auto e = start; e < stop; e++) {
                    auto xOffset = this->getOffset(e);
                    auto yOffset = other.getOffset(e);

@@ -114,7 +114,7 @@ void NDArray::applyPairwiseLambda(const NDArray& other, const std::function<T(T,
        } else {

            auto loop = PRAGMA_THREADS_FOR {
-                for (auto e = start; e < stop; e += increment) {
+                for (auto e = start; e < stop; e++) {
                    auto xOffset = this->getOffset(e);
                    auto yOffset = other.getOffset(e);
                    auto zOffset = target.getOffset(e);

@@ -156,7 +156,7 @@ void NDArray::applyLambda(const std::function<T(T)>& func, NDArray& target) {
    if (this->ordering() == target.ordering() && (this->ews() == 1 && target.ews() == 1)) {

        auto loop = PRAGMA_THREADS_FOR {
-            for (auto e = start; e < stop; e += increment)
+            for (auto e = start; e < stop; e++)
                z[e] = func(f[e]);
        };

@@ -165,7 +165,7 @@ void NDArray::applyLambda(const std::function<T(T)>& func, NDArray& target) {
        if (f == z) {

            auto loop = PRAGMA_THREADS_FOR {
-                for (auto e = start; e < stop; e += increment) {
+                for (auto e = start; e < stop; e++) {
                    auto xOffset = this->getOffset(e);

                    f[xOffset] = func(f[xOffset]);

@@ -176,7 +176,7 @@ void NDArray::applyLambda(const std::function<T(T)>& func, NDArray& target) {
        } else {

            auto loop = PRAGMA_THREADS_FOR {
-                for (auto e = start; e < stop; e += increment) {
+                for (auto e = start; e < stop; e++) {
                    auto xOffset = this->getOffset(e);
                    auto zOffset = target.getOffset(e);

@@ -217,7 +217,7 @@ void NDArray::applyIndexedLambda(const std::function<T(Nd4jLong, T)>& func, NDAr
    if (this->ordering() == target.ordering() && (this->ews() == 1 && target.ews() == 1)) {

        auto loop = PRAGMA_THREADS_FOR {
-            for (auto e = start; e < stop; e += increment)
+            for (auto e = start; e < stop; e++)
                z[e] = func(e, f[e]);
        };

@@ -226,7 +226,7 @@ void NDArray::applyIndexedLambda(const std::function<T(Nd4jLong, T)>& func, NDAr
        if (f == z) {

            auto loop = PRAGMA_THREADS_FOR {
-                for (auto e = start; e < stop; e += increment) {
+                for (auto e = start; e < stop; e++) {
                    auto xOffset = this->getOffset(e);

                    f[xOffset] = func(e, f[xOffset]);

@@ -237,7 +237,7 @@ void NDArray::applyIndexedLambda(const std::function<T(Nd4jLong, T)>& func, NDAr
        } else {

            auto loop = PRAGMA_THREADS_FOR {
-                for (auto e = start; e < stop; e += increment) {
+                for (auto e = start; e < stop; e++) {
                    auto xOffset = this->getOffset(e);
                    auto zOffset = target.getOffset(e);

@@ -283,7 +283,7 @@ void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function<T(N
    if (this->ordering() == other.ordering() && this->ordering() == target.ordering() && (this->ews() == 1 && target.ews() == 1) && this->ews() == other.ews()) {

        auto loop = PRAGMA_THREADS_FOR {
-            for (auto e = start; e < stop; e += increment)
+            for (auto e = start; e < stop; e++)
                z[e] = func((Nd4jLong) e, f[e], s[e]);
        };

@@ -292,7 +292,7 @@ void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function<T(N
        if (f == z) {

            auto loop = PRAGMA_THREADS_FOR {
-                for (auto e = start; e < stop; e += increment) {
+                for (auto e = start; e < stop; e++) {
                    auto xOffset = this->getOffset(e);
                    auto yOffset = other.getOffset(e);

@@ -304,7 +304,7 @@ void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function<T(N
        } else {

            auto loop = PRAGMA_THREADS_FOR {
-                for (auto e = start; e < stop; e += increment) {
+                for (auto e = start; e < stop; e++) {
                    auto xOffset = this->getOffset(e);
                    auto yOffset = other.getOffset(e);
                    auto zOffset = target.getOffset(e);
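The applyLambda family above keeps two paths: a direct z[e] = func(f[e]) loop when all arrays are contiguous (ews == 1, matching ordering), and a per-element getOffset path otherwise. A minimal standalone sketch of that two-path structure, with the stride handling reduced to a single element-wise stride (an assumption for illustration; the real code walks full shapeInfo buffers):

    #include <cstdint>
    #include <functional>
    #include <vector>

    // Applies func element-wise. If both buffers are contiguous (stride 1),
    // index directly; otherwise compute an offset per element.
    template <typename T>
    void applyLambdaSketch(const T* x, int64_t xStride,
                           T* z, int64_t zStride,
                           int64_t length, const std::function<T(T)>& func) {
        if (xStride == 1 && zStride == 1) {
            for (int64_t e = 0; e < length; e++)        // fast path: unit stride
                z[e] = func(x[e]);
        } else {
            for (int64_t e = 0; e < length; e++) {      // general path: per-element offsets
                const auto xOffset = e * xStride;
                const auto zOffset = e * zStride;
                z[zOffset] = func(x[xOffset]);
            }
        }
    }

    int main() {
        std::vector<float> in{1, 2, 3, 4}, out(4);
        applyLambdaSketch<float>(in.data(), 1, out.data(), 1, 4,
                                 [](float v) { return v * v; });
    }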
@@ -163,15 +163,44 @@ void NativeOpExecutioner::execBroadcast(nd4j::LaunchContext *lc,
    BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functions::broadcast::Broadcast, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, LIBND4J_TYPES);
#else

+    auto loopKind = nd4j::LoopKind::deduceKindOfLoopBroadcast(hXShapeInfo, hYShapeInfo, hZShapeInfo);
+
    auto func = PRAGMA_THREADS_FOR {
-        BUILD_SINGLE_SELECTOR_THRICE(xType, functions::broadcast::Broadcast, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ, start, stop), LIBND4J_TYPES);
+        BUILD_SINGLE_SELECTOR_THRICE(xType, functions::broadcast::Broadcast, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ, loopKind, start, stop), LIBND4J_TYPES);
    };

-    auto xLen = shape::length(hXShapeInfo);
-    auto yLen = shape::length(hYShapeInfo);
-    auto numTads = xLen / yLen;
+    Nd4jLong numTads = 0;
+    switch (loopKind) {
+        case nd4j::LoopKind::BROADCAST_SCALAR_X: {
+                numTads = shape::length(hXShapeInfo);
+            }
+            break;
+        case nd4j::LoopKind::BROADCAST_SCALAR_Y: {
+                numTads = shape::length(hYShapeInfo);
+            }
+            break;
+        case nd4j::LoopKind::BROADCAST_3D: {
+                numTads = shape::sizeAt(hZShapeInfo, 0);
+            }
+            break;
+        case nd4j::LoopKind::BROADCAST_4D: {
+                numTads = shape::sizeAt(hZShapeInfo, 0) * shape::sizeAt(hZShapeInfo, 1);
+            }
+            break;
+        case nd4j::LoopKind::BROADCAST_5D: {
+                numTads = shape::sizeAt(hZShapeInfo, 0) * shape::sizeAt(hZShapeInfo, 1);
+            }
+            break;
+        default: {
+                auto xLen = shape::length(hXShapeInfo);
+                auto yLen = shape::length(hYShapeInfo);
+                numTads = xLen / yLen;
+            }
+    }

    samediff::Threads::parallel_tad(func, 0, numTads);

#endif
}

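After this change the number of tasks handed to parallel_tad follows the deduced loop kind: the full X or Y length for the scalar-broadcast kinds, the leading output dimension(s) for the 3D/4D/5D kinds, and the old xLen / yLen ratio otherwise. A standalone sketch of that selection with plain shape vectors (simplified; the real code reads these counts from shapeInfo buffers):

    #include <cstdint>
    #include <numeric>
    #include <vector>

    enum class BroadcastKind { SCALAR_X, SCALAR_Y, B3D, B4D, B5D, COMMON };

    static int64_t length(const std::vector<int64_t>& shape) {
        return std::accumulate(shape.begin(), shape.end(), int64_t{1},
                               [](int64_t a, int64_t b) { return a * b; });
    }

    // Mirrors the switch in execBroadcast: how many independent tasks to schedule.
    static int64_t numTadsFor(BroadcastKind kind,
                              const std::vector<int64_t>& xShape,
                              const std::vector<int64_t>& yShape,
                              const std::vector<int64_t>& zShape) {
        switch (kind) {
            case BroadcastKind::SCALAR_X: return length(xShape);
            case BroadcastKind::SCALAR_Y: return length(yShape);
            case BroadcastKind::B3D:      return zShape[0];
            case BroadcastKind::B4D:
            case BroadcastKind::B5D:      return zShape[0] * zShape[1];
            default:                      return length(xShape) / length(yShape);
        }
    }

    int main() {
        std::vector<int64_t> x{8, 4, 16}, y{1, 4, 16}, z{8, 4, 16};
        (void) numTadsFor(BroadcastKind::B3D, x, y, z);   // 8 tasks, one per leading slice
    }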
@@ -1291,7 +1291,7 @@ void pullRowsGeneric(void *vx,
    _threads = nd4j::math::nd4j_min<int>(_threads, nd4j::Environment::getInstance()->maxThreads());

    auto func = PRAGMA_THREADS_FOR {
-        for (auto idx = start; idx < stop; idx += increment) {
+        for (auto idx = start; idx < stop; idx++) {
            auto xTadOffsetForBlock = tadOffsets[indexes[idx]];
            auto zTadOffsetForBlock = zTadOffsets[idx];

@@ -1356,7 +1356,7 @@ void tearGeneric(void *vx,
    auto numTads = shape::length(hXShapeInfo) / tadLength;

    auto func = PRAGMA_THREADS_FOR {
-        for (auto i = start; i < stop; i += increment) {
+        for (auto i = start; i < stop; i++) {
            auto hZ = reinterpret_cast<T *>(targets[i]);
            auto s = hX + tadOffsets[i];

@@ -1478,7 +1478,7 @@ void shuffleGeneric(void **hX, Nd4jLong **hXShapeInfo, void **dz, Nd4jLong **hZS
    auto dZ = reinterpret_cast<T **>(dz);

    auto func = PRAGMA_THREADS_FOR {
-        for (auto f = start; f < stop; f += increment) {
+        for (auto f = start; f < stop; f++) {
            auto hX = reinterpret_cast<T *>(dX[f]);
            //auto hZ = reinterpret_cast<T *>(dZ[f]);

@@ -52,7 +52,7 @@ namespace nd4j {
            TypeCast::convertGeneric<T2, T>(nullptr, tmp, length, buffer);
#else
            auto func = PRAGMA_THREADS_FOR {
-                for (auto e = start; e < stop; e += increment)
+                for (auto e = start; e < stop; e++)
                    buffer[e] = canKeep ? static_cast<T>(tmp[e]) : BitwiseUtils::swap_bytes<T>(static_cast<T>(tmp[e]));
            };

@@ -110,7 +110,7 @@ namespace nd4j {
            TypeCast::convertGeneric<float, T>(nullptr, tmp, length, buffer);
#else
            auto func = PRAGMA_THREADS_FOR {
-                for (auto e = start; e < stop; e += increment)
+                for (auto e = start; e < stop; e++)
                    buffer[e] = canKeep ? static_cast<T>(tmp[e]) : BitwiseUtils::swap_bytes<T>(static_cast<T>(tmp[e]));
            };

@@ -138,7 +138,7 @@ namespace nd4j {

#else
            auto func = PRAGMA_THREADS_FOR {
-                for (auto e = start; e < stop; e += increment)
+                for (auto e = start; e < stop; e++)
                    buffer[e] = canKeep ? static_cast<T>(tmp[e]) : BitwiseUtils::swap_bytes<T>(static_cast<T>(tmp[e]));
            };

@@ -164,7 +164,7 @@ namespace nd4j {
            TypeCast::convertGeneric<float16, T>(nullptr, tmp, length, buffer);
#else
            auto func = PRAGMA_THREADS_FOR {
-                for (auto e = start; e < stop; e += increment)
+                for (auto e = start; e < stop; e++)
                    buffer[e] = canKeep ? static_cast<T>(tmp[e]) : BitwiseUtils::swap_bytes<T>(static_cast<T>(tmp[e]));
            };
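The TypeCast loops above convert a possibly foreign-endianness buffer element by element: keep the value when canKeep is true, otherwise swap its bytes. A standalone sketch of the same idea with a generic byte swap (illustrative; BitwiseUtils and the surrounding #else machinery are not reproduced here):

    #include <cstdint>
    #include <cstring>
    #include <utility>
    #include <vector>

    // Reverse the byte order of a trivially copyable value.
    template <typename T>
    T swapBytes(T value) {
        unsigned char bytes[sizeof(T)];
        std::memcpy(bytes, &value, sizeof(T));
        for (size_t i = 0; i < sizeof(T) / 2; i++)
            std::swap(bytes[i], bytes[sizeof(T) - 1 - i]);
        std::memcpy(&value, bytes, sizeof(T));
        return value;
    }

    // Copy src into dst, swapping bytes when the stored endianness differs.
    template <typename T>
    void convertBuffer(const T* src, T* dst, int64_t length, bool canKeep) {
        for (int64_t e = 0; e < length; e++)
            dst[e] = canKeep ? src[e] : swapBytes(src[e]);
    }

    int main() {
        std::vector<std::uint32_t> src{0x11223344u}, dst(1);
        convertBuffer(src.data(), dst.data(), 1, /*canKeep=*/false);  // dst[0] == 0x44332211
    }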
@@ -58,6 +58,7 @@ namespace nd4j {
            virtual void putVariable(int id, Variable *variable);
            virtual void putVariable(int id, NDArray *array);
            virtual void putVariable(int id, int idx, NDArray *array);
+            virtual void putVariable(int id, int idx, NDArray &array);
            virtual void putVariable(int id, int idx, Variable *array);

            virtual void replaceVariable(Variable *variable);

@@ -100,6 +100,7 @@ namespace nd4j {
            virtual void putVariable(int id, Variable *variable);
            virtual void putVariable(int id, NDArray *array);
            virtual void putVariable(int id, int idx, NDArray *array);
+            virtual void putVariable(int id, int idx, NDArray &array);
            virtual void putVariable(int id, int idx, Variable *array);

            virtual void dropVariable(std::pair<int,int> &pair);

@@ -1088,8 +1088,23 @@ namespace nd4j {
                if (e < node->input()->size() - 1)
                    nd4j_printf(", ", "");
            }

+            if (node->opType() == OpType_CUSTOM) {
+                auto ctx = node->protoContext();
+                if (ctx->getIArguments()->size() > 0) {
+                    printf("]; iArgs: [");
+
+                    for (int e = 0; e < ctx->getIArguments()->size(); e++) {
+                        printf("%i", ctx->getIArguments()->at(e));
+                        if (e < ctx->getIArguments()->size() - 1)
+                            nd4j_printf(", ", "");
+                    }
+                }
+            }
+
            nd4j_printf("]; \n", "");

            // printf("\n");
            fflush(stdout);
        }

@@ -60,8 +60,11 @@ namespace nd4j {
            result->_name = this->_name;
            result->_index = this->_index;

-            if (this->_ndarray != nullptr)
+            if (this->_ndarray != nullptr) {
                result->_ndarray = new NDArray(this->_ndarray->dup(this->_ndarray->ordering()));
+                result->_readOnly = false;
+                result->_removable = true;
+            }

            if (this->_list != nullptr)
                result->_list = this->_list->clone();

@@ -191,6 +191,9 @@ namespace nd4j {
            _current->putVariable(id, array);
        }

+        void nd4j::graph::VariableProxy::putVariable(int id, int idx, NDArray &array) {
+            _current->putVariable(id, idx, array);
+        }

        void VariableProxy::putVariable(int id, int idx, NDArray *array) {
            _current->putVariable(id, idx, array);

@@ -263,19 +263,19 @@ namespace nd4j {
        void nd4j::graph::VariableSpace::putVariable(int id, Variable *variable) {
            // we don't want to add variables more then once
            if (_variables.count(id) > 0 || _temporary.count(id) > 0) {
-                // nd4j_verbose("Trying to update variable for node_%i\n", id);

                auto local = id < 0 ? _variables.at(id) : _temporary.at(id);

                if (!local->hasNDArray() && variable->hasNDArray()) {
-                    // nd4j_verbose("Saving variable for node_%i\n", id);
                    local->setNDArray(variable->getNDArray());

+                    // we're inheriting this from Variable
+                    local->markReadOnly(variable->isReadOnly());
+                    local->markRemovable(variable->isRemovable());
                }

                return;
            }

-            //nd4j_debug("Adding Variable to Space: id: %i; Array is null: %i;\n", id, variable->getNDArray() == nullptr);

            _varmap.lock();

            _handles->emplace_back(variable);

@@ -314,6 +314,21 @@ namespace nd4j {
            }
        }

+        void nd4j::graph::VariableSpace::putVariable(int id, int idx, NDArray &array) {
+            auto *var = new nd4j::graph::Variable(&array, "", id, idx);
+            var->markRemovable(false);
+            var->markReadOnly(true);
+
+            // let's see if this op needs
+            bool d = this->hasVariable(id, idx);
+
+            this->putVariable(id, var);
+
+            // if var for this nodeid already exists - we'll just delete variable
+            if (d)
+                delete var;
+        }
+
        void nd4j::graph::VariableSpace::putVariable(int id, NDArray *array) {
            auto *var = new nd4j::graph::Variable(array);
            this->putVariable(id, var);

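The new putVariable(int id, int idx, NDArray &array) wraps the array in a freshly allocated Variable, checks whether the (id, idx) slot is already occupied, delegates the insert, and deletes the wrapper again when the slot already existed. A standalone sketch of that insert-or-discard ownership pattern with a plain map (illustrative stand-in types, not the graph API):

    #include <map>
    #include <string>
    #include <utility>

    struct Entry { std::string payload; };

    class Registry {
        std::map<std::pair<int, int>, Entry*> _slots;
    public:
        bool has(int id, int idx) const { return _slots.count({id, idx}) > 0; }

        // Existing slots keep their entry; only new slots take ownership of e.
        void put(int id, int idx, Entry* e) {
            if (!has(id, idx))
                _slots[{id, idx}] = e;
        }

        // Mirrors putVariable(int, int, NDArray&): create, try to insert,
        // delete the wrapper again if the slot already existed.
        void put(int id, int idx, const std::string& payload) {
            auto* e = new Entry{payload};
            const bool existed = has(id, idx);
            put(id, idx, e);
            if (existed)
                delete e;
        }

        ~Registry() { for (auto& kv : _slots) delete kv.second; }
    };

    int main() {
        Registry r;
        r.put(1, 0, "first");
        r.put(1, 0, "ignored");   // slot exists, the new wrapper is discarded
    }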
@@ -24,6 +24,7 @@
#include <pointercast.h>
#include <dll.h>
#include <string>
+#include <vector>

namespace nd4j {
    namespace graph {

@@ -65,6 +66,9 @@ namespace nd4j {

            // total amount of memory used during execution
            Nd4jLong _memoryTotal = 0L;

+            std::vector<std::string> _inputShapes;
+            std::vector<std::string> _outputShapes;
+
        public:
            NodeProfile() = default;
            ~NodeProfile() = default;

@@ -84,10 +88,15 @@ namespace nd4j {
            void setObjectsSize(Nd4jLong bytes);
            void setTotalSize(Nd4jLong bytes);

-            Nd4jLong getActivationsSize();
-            Nd4jLong getTemporarySize();
-            Nd4jLong getObjectsSize();
-            Nd4jLong getTotalSize();
+            void addInputShape(Nd4jLong *shapeInfo);
+            void addOutputShape(Nd4jLong *shapeInfo);
+
+            Nd4jLong getActivationsSize() const;
+            Nd4jLong getTemporarySize() const;
+            Nd4jLong getObjectsSize() const;
+            Nd4jLong getTotalSize() const;
+
+            Nd4jLong getExecutionTime() const;

            std::string& name();

@@ -21,6 +21,8 @@
#include <graph/profiling/GraphProfile.h>
#include <helpers/logger.h>
#include <chrono>
+#include <templatemath.h>
+#include <algorithm>

namespace nd4j {
    namespace graph {

@@ -184,8 +186,25 @@ namespace nd4j {
            if (_profiles.empty())
                nd4j_printf("No nodes in graph\n","");

-            for (auto v: _profiles)
+            // printint out stuff
+            std::vector<NodeProfile*> sorted;
+            for (auto v: _profiles) {
                v->printOut();
+                sorted.emplace_back(v);
+            }
+
+            if (_profiles.size() > 1) {
+                // building hot spots
+                std::sort(sorted.begin(), sorted.end(), [](const NodeProfile *a, const NodeProfile *b) -> bool {
+                    return a->getExecutionTime() > b->getExecutionTime();
+                });
+
+                nd4j_printf("\nTop 30 reports by EXEC:\n", "");
+                auto limit = nd4j::math::nd4j_min<int>(30, sorted.size());
+                for (int e = 0; e < limit; e++) {
+                    sorted[e]->printOut();
+                }
+            }

            nd4j_printf("\nSpecial timers:\n", "");
            if (_timings.empty())

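The reworked printOut keeps the per-node report but additionally sorts the collected profiles by execution time and prints at most the 30 slowest. A standalone sketch of the same report (plain structs and std::printf in place of NodeProfile and nd4j_printf):

    #include <algorithm>
    #include <cstdio>
    #include <string>
    #include <vector>

    struct Profile {
        std::string name;
        long long executionTimeNs;
    };

    // Print up to `top` profiles, slowest first.
    static void printHotSpots(std::vector<Profile> profiles, int top) {
        std::sort(profiles.begin(), profiles.end(),
                  [](const Profile& a, const Profile& b) {
                      return a.executionTimeNs > b.executionTimeNs;
                  });
        const int limit = std::min<int>(top, static_cast<int>(profiles.size()));
        std::printf("\nTop %d reports by EXEC:\n", limit);
        for (int e = 0; e < limit; e++)
            std::printf("  %s: %lld ns\n", profiles[e].name.c_str(), profiles[e].executionTimeNs);
    }

    int main() {
        printHotSpots({{"conv2d", 120000}, {"relu", 3000}, {"matmul", 450000}}, 30);
    }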
@@ -32,7 +32,7 @@ namespace nd4j {
            // graph->printOut();

            // warm up
-            for (int e = 0; e < 1000; e++) {
+            for (int e = 0; e < iterations; e++) {
                FlowPath fp;

                auto _vs = varSpace->clone();

@@ -20,6 +20,7 @@

#include <helpers/logger.h>
#include <graph/profiling/NodeProfile.h>
+#include <helpers/ShapeUtils.h>

namespace nd4j {
    namespace graph {

@@ -35,9 +36,23 @@ namespace nd4j {
            nd4j_printf(" Memory: ACT: %lld; TMP: %lld; OBJ: %lld; TTL: %lld;\n", _memoryActivations / _merges, _memoryTemporary / _merges, _memoryObjects / _merges, _memoryTotal / _merges);
            nd4j_printf(" Time: PREP: %lld ns; EXEC: %lld ns; TTL: %lld ns;\n", _preparationTime / _merges, _executionTime / _merges, _totalTime / _merges);
            nd4j_printf(" PREP: INPUT: %lld ns; SHAPE: %lld ns; ARRAY: %lld ns;\n", _inputTime / _merges, _shapeTime / _merges, _arrayTime / _merges);
+
+            std::string inputs;
+            std::string outputs;
+
+            int cnt = 0;
+            for (const auto &v: _inputShapes)
+                inputs += v + " ";
+
+            for (const auto &v: _outputShapes)
+                outputs += v + " ";
+
+            nd4j_printf(" Inputs: %s\n", inputs.c_str());
+            nd4j_printf(" Outputs: %s\n", outputs.c_str());
        };

-        Nd4jLong NodeProfile::getActivationsSize() {
+        Nd4jLong NodeProfile::getActivationsSize() const {
            return _memoryActivations;
        }

@@ -53,15 +68,15 @@ namespace nd4j {
            _inputTime = time;
        }

-        Nd4jLong NodeProfile::getTemporarySize() {
+        Nd4jLong NodeProfile::getTemporarySize() const{
            return _memoryTemporary;
        }

-        Nd4jLong NodeProfile::getObjectsSize() {
+        Nd4jLong NodeProfile::getObjectsSize() const{
            return _memoryObjects;
        }

-        Nd4jLong NodeProfile::getTotalSize() {
+        Nd4jLong NodeProfile::getTotalSize() const{
            return _memoryTotal;
        }

@@ -97,6 +112,18 @@ namespace nd4j {
            _memoryTotal = bytes;
        }

+        Nd4jLong NodeProfile::getExecutionTime() const {
+            return _executionTime;
+        }
+
+        void NodeProfile::addInputShape(Nd4jLong *shapeInfo) {
+            _inputShapes.emplace_back(ShapeUtils::shapeInfoAsString(shapeInfo));
+        }
+
+        void NodeProfile::addOutputShape(Nd4jLong *shapeInfo) {
+            _outputShapes.emplace_back(ShapeUtils::shapeInfoAsString(shapeInfo));
+        }
+
        void NodeProfile::merge(NodeProfile *other) {
            _merges += other->_merges;
            _memoryObjects += other->_memoryObjects;

@@ -110,6 +137,9 @@ namespace nd4j {
            _shapeTime += other->_shapeTime;
            _arrayTime += other->_arrayTime;
            _inputTime += other->_inputTime;
+
+            _inputShapes = other->_inputShapes;
+            _outputShapes = other->_outputShapes;
        }

        std::string& NodeProfile::name() {

@@ -129,6 +159,9 @@ namespace nd4j {
            _shapeTime = other->_shapeTime;
            _arrayTime = other->_arrayTime;
            _inputTime = other->_inputTime;
+
+            _inputShapes = other->_inputShapes;
+            _outputShapes = other->_outputShapes;
        }
    }
}
@@ -37,12 +37,13 @@ namespace nd4j {
    class ND4J_EXPORT LoopKind {

        public:
-            enum Kind {SMALLARR2DX, EWS1, EWSNONZERO, RANK1, RANK2, RANK3, RANK4, RANK5, X_EWSNONZERO, Y_EWSNONZERO, Z_EWSNONZERO, COMMON};
+            enum Kind { SMALLARR2DX, EWS1, EWSNONZERO, RANK1, RANK2, RANK3, RANK4, RANK5, X_EWSNONZERO, Y_EWSNONZERO, Z_EWSNONZERO, COMMON, BROADCAST_SCALAR_X, BROADCAST_SCALAR_Y, BROADCAST_3D, BROADCAST_4D, BROADCAST_5D };

            static FORCEINLINE Kind deduceKindOfLoopXZ(const Nd4jLong* xShapeInfo, const Nd4jLong* zShapeInfo);
            static FORCEINLINE Kind deduceKindOfLoopXYZ(const Nd4jLong* xShapeInfo, const Nd4jLong* yShapeInfo, const Nd4jLong* zShapeInfo);
            static FORCEINLINE Kind deduceKindOfLoopTadXZ(const Nd4jLong* xShapeInfo, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo);
            static FORCEINLINE Kind deduceKindOfLoopTadXYZ(const Nd4jLong* xTadShapeInfo, const Nd4jLong* yTadShapeInfo, const Nd4jLong* zShapeInfo);
+            static FORCEINLINE Kind deduceKindOfLoopBroadcast(const Nd4jLong* xShapeInfo, const Nd4jLong* yShapeInfo, const Nd4jLong* zShapeInfo);

    };

@@ -82,6 +83,57 @@ LoopKind::Kind LoopKind::deduceKindOfLoopXZ(const Nd4jLong* xShapeInfo, const Nd
    return COMMON;
}

+LoopKind::Kind LoopKind::deduceKindOfLoopBroadcast(const Nd4jLong* xShapeInfo, const Nd4jLong* yShapeInfo, const Nd4jLong* zShapeInfo) {
+    auto xRank = shape::rank(xShapeInfo);
+    auto yRank = shape::rank(yShapeInfo);
+    auto zRank = shape::rank(zShapeInfo);
+
+    auto xOrder = shape::order(xShapeInfo);
+    auto yOrder = shape::order(yShapeInfo);
+    auto zOrder = shape::order(zShapeInfo);
+
+    auto xEws = shape::elementWiseStride(xShapeInfo);
+    auto yEws = shape::elementWiseStride(yShapeInfo);
+    auto zEws = shape::elementWiseStride(zShapeInfo);
+
+    bool bNDLoopsRanks = (xRank == zRank && yRank <= xRank && yRank >= 2);
+
+    int countUnityDimsInY = 0, countUnityDimsInX = 0;
+    for (int i = 0; i < xRank; i++) {
+        if (i < yRank)
+            countUnityDimsInY += (1 == shape::sizeAt(yShapeInfo, i)) ? 1 : 0;
+        countUnityDimsInX += (1 == shape::sizeAt(xShapeInfo, i)) ? 1 : 0;
+    }
+
+    bool bNotCommonVectorCase = (countUnityDimsInY != yRank - 1) && (countUnityDimsInX != xRank - 1);
+
+    if (3 == xRank && bNDLoopsRanks && bNotCommonVectorCase)
+        return nd4j::LoopKind::BROADCAST_3D;
+    if (4 == xRank && bNDLoopsRanks && bNotCommonVectorCase)
+        return nd4j::LoopKind::BROADCAST_4D;
+    if (5 == xRank && bNDLoopsRanks && bNotCommonVectorCase)
+        return nd4j::LoopKind::BROADCAST_5D;
+
+    if (xRank == yRank && xRank == zRank && xOrder == 'c' && yOrder == 'c' && zOrder == 'c' && xEws == 1 && yEws == 1 && zEws == 1 && xRank >= 2) {
+        // we validate that shapes are equal till the last dim
+        for (int e = 0; e < xRank - 1; e++) {
+            if (xShapeInfo[e+1] != yShapeInfo[e+1])
+                return COMMON;
+        }
+
+        // now, if one of the shapes has 1 as last dim
+        auto detect = xShapeInfo[xRank] == 1 ? -1 : (yShapeInfo[xRank] == 1) ? 1 : 0;
+
+        if (detect == 1)
+            return nd4j::LoopKind::BROADCAST_SCALAR_Y;
+        else if (detect == -1)
+            return nd4j::LoopKind::BROADCAST_SCALAR_X;
+    }
+
+    return nd4j::LoopKind::COMMON;
+}
+
//////////////////////////////////////////////////////////////////////////////
LoopKind::Kind LoopKind::deduceKindOfLoopXYZ(const Nd4jLong* xShapeInfo, const Nd4jLong* yShapeInfo, const Nd4jLong* zShapeInfo) {

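deduceKindOfLoopBroadcast picks a specialized kind when x, y and z are C-order, ews == 1 arrays whose shapes agree on all but the last dimension (a trailing 1 on x or y selects the scalar kinds), and routes rank-3/4/5 shapes that are not effectively vectors to the BROADCAST_3D/4D/5D kinds. A simplified standalone sketch of that decision over shape vectors (the order and ews checks of the real code are assumed to hold and are omitted):

    #include <cstdint>
    #include <vector>

    enum class LoopKind { BROADCAST_SCALAR_X, BROADCAST_SCALAR_Y,
                          BROADCAST_3D, BROADCAST_4D, BROADCAST_5D, COMMON };

    static LoopKind deduceBroadcastKind(const std::vector<int64_t>& x,
                                        const std::vector<int64_t>& y,
                                        const std::vector<int64_t>& z) {
        const int xRank = (int) x.size(), yRank = (int) y.size(), zRank = (int) z.size();

        const bool ndRanksOk = (xRank == zRank && yRank <= xRank && yRank >= 2);

        // count unit dimensions: shapes that are "really" vectors stay on the common path
        int unitY = 0, unitX = 0;
        for (int i = 0; i < xRank; i++) {
            if (i < yRank && y[i] == 1) unitY++;
            if (x[i] == 1) unitX++;
        }
        const bool notVectorCase = (unitY != yRank - 1) && (unitX != xRank - 1);

        if (ndRanksOk && notVectorCase) {
            if (xRank == 3) return LoopKind::BROADCAST_3D;
            if (xRank == 4) return LoopKind::BROADCAST_4D;
            if (xRank == 5) return LoopKind::BROADCAST_5D;
        }

        // same rank, shapes equal up to the last dimension: a trailing 1 selects the scalar kinds
        if (xRank >= 2 && xRank == yRank && xRank == zRank) {
            for (int e = 0; e < xRank - 1; e++)
                if (x[e] != y[e]) return LoopKind::COMMON;
            if (x[xRank - 1] == 1) return LoopKind::BROADCAST_SCALAR_X;
            if (y[xRank - 1] == 1) return LoopKind::BROADCAST_SCALAR_Y;
        }
        return LoopKind::COMMON;
    }

    int main() {
        (void) deduceBroadcastKind({64, 32}, {64, 1}, {64, 32});  // BROADCAST_SCALAR_Y
    }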
@@ -51,6 +51,13 @@ namespace nd4j {
        static Nd4jLong* copyShapeInfoAndType(const Nd4jLong* inShapeInfo, const DataType dtype, const bool copyStrides, memory::Workspace* workspace = nullptr);
        static Nd4jLong* copyShapeInfoAndType(const Nd4jLong* inShapeInfo, const Nd4jLong* shapeInfoToGetTypeFrom, const bool copyStrides, memory::Workspace* workspace = nullptr);

+        /**
+        * allocates memory for new shapeInfo and copy all information from inShapeInfo to new shapeInfo except dimensions in dimsToExclude (unit dimensions) and corresponding strides
+        * for example inShapeInfo is {3, 2,1,3,1,4, 12,12,4,4,1, 16384,1,99}, dimsToExclude = {2,3}, dimsSize = 2
+        * then outShapeInfo will contain {3, 2,3,4, 12,4,1, 16384,1,99}
+        */
+        static Nd4jLong* copyShapeInfoWithoutUnites(const Nd4jLong* inShapeInfo, const int dimsSize, const int* dimsToExclude, memory::Workspace* workspace = nullptr);
+
        static Nd4jLong* emptyShapeInfo(const nd4j::DataType dataType, memory::Workspace* workspace = nullptr);

        static Nd4jLong* emptyShapeInfo(const nd4j::DataType dataType, const char order, const std::vector<Nd4jLong> &shape, memory::Workspace* workspace = nullptr);

@@ -50,11 +50,13 @@ namespace nd4j {
        static std::vector<Nd4jLong> evalRepeatShape(int axis, const std::vector<int>& repeats, const NDArray& arr);

        // evaluate shapeInfo of permuted array
-        static Nd4jLong* evalPermShapeInfo(const int* dimensions, const int rank, const NDArray& arr, nd4j::memory::Workspace* workspace);
+        // if setContigStrides = true, then set contiguous strides in output shapeInfo in accordance with arr order
+        static Nd4jLong* evalPermShapeInfo(const int* dimensions, const int rank, const NDArray& arr, nd4j::memory::Workspace* workspace, const bool setContigStrides = false);
        static Nd4jLong* evalPermShapeInfo(const Nd4jLong* dimensions, const int rank, const NDArray& arr, nd4j::memory::Workspace* workspace);

        // evaluate shapeInfo of transposed array
-        static Nd4jLong* evalTranspShapeInfo(const NDArray& arr, nd4j::memory::Workspace* workspace);
+        // if setContigStrides = true, then set contiguous strides in output shapeInfo in accordance with arr order
+        static Nd4jLong* evalTranspShapeInfo(const NDArray& arr, nd4j::memory::Workspace* workspace, const bool setContigStrides = false);

        static bool copyVectorPart(std::vector<int>& target, std::vector<int>& source, int rank, int offset);

@@ -97,6 +99,8 @@ namespace nd4j {
        static std::string shapeAsString(const int rank, const Nd4jLong* shapeInfo);
        static std::string strideAsString(const NDArray* array);

+        static std::string shapeInfoAsString(const Nd4jLong* shapeInfo);
+
        static std::vector<Nd4jLong> shapeAsVector(const Nd4jLong* shapeInfo);

        // evaluate shapeInfo for diagonal array which is made using input arr elements as diagonal

@@ -176,6 +180,17 @@ namespace nd4j {
            return (numStrings + 1) * sizeof(Nd4jLong);
        }

+        /**
+        * This method selects strides based on dimentions required for broadcasting
+        * @param const pointer to input (Y) shape info for strides selection
+        * @param rank of input (X) to broadcasting
+        * @param dimentions size
+        * @param const pointer to dimentions for broadcasting
+        * @param pointer to output strides have to be pre allocated by 0
+        * @return
+        */
+        static void copyCertainStridesFromShapeInfo(const Nd4jLong* inShapeInfo, const int nRank, const int dimsSize, const int* dims, Nd4jLong* outStrides);
+
        /*
        * check whether arr1/arr2 is sub-array of arr2/arr1,
        * this method do not evaluate what array is sub-array, it returns true if arr1 is sub-array of arr2 or arr2 is sub-array of arr1

@@ -68,7 +68,7 @@ namespace nd4j {
        const Nd4jLong numOfSubArrs = ShapeUtils::getNumOfSubArrs(shapeInfo, dimsToExclude);
        const int subArrRank = (rank == dimsToExclude.size() || descriptor.areUnitiesinShape()) ? rank : rank - dimsToExclude.size();

-        auto sPtr = new Nd4jLong[shape::shapeInfoLength(subArrRank)];
+        auto sPtr = new Nd4jLong[shape::shapeInfoLength(subArrRank)]; // shape of sub-arrays (same for all for them)
        auto oPtr = new Nd4jLong[numOfSubArrs];

        if (numOfSubArrs > 0)

@@ -49,7 +49,7 @@ void nd4j::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
        case nd4j::LoopKind::EWS1: {

            auto func = PRAGMA_THREADS_FOR {
-                for (auto i = start; i < stop; i += increment) {
+                for (auto i = start; i < stop; i++) {
                    auto tad = const_cast<X *>(x) + tadOffsets[i];
                    auto indexValue = OpType::startingIndexValue(tad);

@@ -70,7 +70,7 @@ void nd4j::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
        case nd4j::LoopKind::EWSNONZERO: {

            auto func = PRAGMA_THREADS_FOR {
-                for (auto i = start; i < stop; i += increment) {
+                for (auto i = start; i < stop; i++) {
                    auto tad = const_cast<X *>(x) + tadOffsets[i];
                    auto indexValue = OpType::startingIndexValue(tad);

@@ -91,7 +91,7 @@ void nd4j::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
        case nd4j::LoopKind::RANK1: {

            auto func = PRAGMA_THREADS_FOR {
-                for (auto i = start; i < stop; i += increment) {
+                for (auto i = start; i < stop; i++) {
                    auto tad = const_cast<X *>(x) + tadOffsets[i];
                    auto indexValue = OpType::startingIndexValue(tad);

@@ -114,7 +114,7 @@ void nd4j::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
            shape::updateStrides(2, tadShape, newStride, 'c');

            auto func = PRAGMA_THREADS_FOR {
-                for (auto i = start; i < stop; i += increment) {
+                for (auto i = start; i < stop; i++) {
                    auto tad = const_cast<X *>(x) + tadOffsets[i];
                    auto indexValue = OpType::startingIndexValue(tad);

@@ -141,7 +141,7 @@ void nd4j::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
            shape::updateStrides(3, tadShape, newStride, 'c');

            auto func = PRAGMA_THREADS_FOR {
-                for (auto i = start; i < stop; i += increment) {
+                for (auto i = start; i < stop; i++) {
                    auto tad = const_cast<X *>(x) + tadOffsets[i];
                    auto indexValue = OpType::startingIndexValue(tad);

@@ -170,7 +170,7 @@ void nd4j::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
            shape::updateStrides(4, tadShape, newStride, 'c');

            auto func = PRAGMA_THREADS_FOR {
-                for (auto i = start; i < stop; i += increment) {
+                for (auto i = start; i < stop; i++) {
                    auto tad = const_cast<X *>(x) + tadOffsets[i];
                    auto indexValue = OpType::startingIndexValue(tad);

@@ -201,7 +201,7 @@ void nd4j::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
            shape::updateStrides(5, tadShape, newStride, 'c');

            auto func = PRAGMA_THREADS_FOR {
-                for (auto i = start; i < stop; i += increment) {
+                for (auto i = start; i < stop; i++) {
                    auto tad = const_cast<X *>(x) + tadOffsets[i];
                    auto indexValue = OpType::startingIndexValue(tad);

@@ -234,7 +234,7 @@ void nd4j::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
            const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo<uint>(zShapeInfo, castZShapeInfo);

            auto func = PRAGMA_THREADS_FOR {
-                for (auto i = start; i < stop; i += increment) {
+                for (auto i = start; i < stop; i++) {
                    auto tad = const_cast<X *>(x) + tadOffsets[i];
                    auto indexValue = OpType::startingIndexValue(tad);

@@ -258,7 +258,7 @@ void nd4j::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
            const bool canCastTad = nd4j::DataTypeUtils::castShapeInfo<uint>(tadShapeInfo, castTadShapeInfo);

            auto func = PRAGMA_THREADS_FOR {
-                for (auto i = start; i < stop; i += increment) {
+                for (auto i = start; i < stop; i++) {
                    auto tad = const_cast<X *>(x) + tadOffsets[i];
                    auto indexValue = OpType::startingIndexValue(tad);

@@ -284,7 +284,7 @@ void nd4j::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
            const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo<uint>(zShapeInfo, castZShapeInfo);

            auto func = PRAGMA_THREADS_FOR {
-                for (auto i = start; i < stop; i += increment) {
+                for (auto i = start; i < stop; i++) {
                    auto tad = const_cast<X *>(x) + tadOffsets[i];
                    auto indexValue = OpType::startingIndexValue(tad);
@@ -43,23 +43,30 @@ nd4j::NDArray* nd4j::MmulHelper::tensorDot(const nd4j::NDArray* a, const nd4j::N

    auto outShape = ShapeUtils::evalShapeForTensorDot(a, b, axes_0, axes_1, permutAt, permutBt, shapeAt, shapeBt);

-    NDArray aPR = a->permute(permutAt);
-    NDArray bPR = b->permute(permutBt);
+    // check whether permutation is necessary
+    const NDArray* aP = permutAt.empty() ? a : new NDArray(a->permute(permutAt));
+    const NDArray* bP = permutBt.empty() ? b : new NDArray(b->permute(permutBt));

    // check whether reshape is necessary
-    if(!aPR.isSameShape(shapeAt))
-        aPR.reshapei( shapeAt);
-    if(!bPR.isSameShape(shapeBt))
-        bPR.reshapei( shapeBt);
+    const NDArray* aPR = aP->isSameShape(shapeAt) ? aP : new NDArray(aP->reshape(aP->ordering(), shapeAt));
+    const NDArray* bPR = bP->isSameShape(shapeAt) ? bP : new NDArray(bP->reshape(bP->ordering(), shapeBt));

-    NDArray* c = mmul(&aPR, &bPR, nullptr, 1.0, 0.0);
+    NDArray* c = mmul(aPR, bPR, nullptr, 1.0, 0.0);

    c->reshapei(outShape);

+    if(aP != aPR)
+        delete aPR;
+    if(bP != bPR)
+        delete bPR;
+    if(a != aP)
+        delete aP;
+    if(b != bP)
+        delete bP;
+
    return c;
}


//////////////////////////////////////////////////////////////////////////
void nd4j::MmulHelper::tensorDot(const nd4j::NDArray* a, const nd4j::NDArray* b, nd4j::NDArray* c, const std::vector<int>& axes_a, const std::vector<int>& axes_b, const std::vector<int>& permutForC) {

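The reworked tensorDot only materializes a permuted or reshaped copy when one is actually needed and otherwise keeps pointing at the caller's array, which is why every exit path compares pointers before deleting. A minimal standalone sketch of that conditional-copy-with-cleanup pattern (Array is an illustrative stand-in type, not NDArray):

    #include <vector>

    struct Array {
        std::vector<long long> shape;
        bool isSameShape(const std::vector<long long>& s) const { return shape == s; }
        Array reshaped(const std::vector<long long>& s) const { return Array{s}; }
    };

    // Use the original when no reshape is needed; otherwise own a temporary copy
    // and release it afterwards - mirroring the aP/aPR bookkeeping in tensorDot.
    static void useMaybeReshaped(const Array* a, const std::vector<long long>& target) {
        const Array* aR = a->isSameShape(target) ? a : new Array(a->reshaped(target));

        // ... work with aR ...

        if (aR != a)       // delete only what this function allocated
            delete aR;
    }

    int main() {
        Array a{{2, 3, 4}};
        useMaybeReshaped(&a, {6, 4});     // allocates a temporary
        useMaybeReshaped(&a, {2, 3, 4});  // reuses a directly
    }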
@@ -67,32 +74,38 @@ void nd4j::MmulHelper::tensorDot(const nd4j::NDArray* a, const nd4j::NDArray* b,
    std::vector<Nd4jLong> shapeAt, shapeBt;
    ShapeUtils::evalShapeForTensorDot(a, b, axes_a, axes_b, permutAt, permutBt, shapeAt, shapeBt);

-    NDArray *cP(c), *cPR(c);
-
    // check whether permutation is required
-    if(!permutForC.empty())
-        cP = new NDArray(c->permute(permutForC));
+    NDArray* cP = permutForC.empty() ? c : new NDArray(c->permute(permutForC));

-    auto aPR = a->permute(permutAt);
-    auto bPR = b->permute(permutBt);
+    // check whether permutation is necessary
+    const NDArray* aP = permutAt.empty() ? a : new NDArray(a->permute(permutAt));
+    const NDArray* bP = permutBt.empty() ? b : new NDArray(b->permute(permutBt));

    // check whether reshape is necessary
-    if(!aPR.isSameShape(shapeAt))
-        aPR.reshapei(shapeAt);
-    if(!bPR.isSameShape(shapeBt))
-        bPR.reshapei(shapeBt);
+    const NDArray* aPR = aP->isSameShape(shapeAt) ? aP : new NDArray(aP->reshape(aP->ordering(), shapeAt));
+    const NDArray* bPR = bP->isSameShape(shapeAt) ? bP : new NDArray(bP->reshape(bP->ordering(), shapeBt));

-    if(!cP->isSameShape({aPR.sizeAt(0), bPR.sizeAt(1)}))
-        cPR = new NDArray(cP->reshape(cP->ordering(), {aPR.sizeAt(0), bPR.sizeAt(1)}));
+    std::vector<Nd4jLong> requiredCshape = {aPR->sizeAt(0), bPR->sizeAt(1)};

-    mmul(&aPR, &bPR, cPR, 1.0, 0.0);
+    NDArray* cPR = cP->isSameShape(requiredCshape) ? cP : new NDArray(cP->reshape(cP->ordering(), requiredCshape, false));
+
+    mmul(aPR, bPR, cPR, 1.0, 0.0);

    if(cPR->getBuffer() != cP->getBuffer() || cPR->getSpecialBuffer() != cP->getSpecialBuffer() ) // this means both permute and reshape have been performed on c, cP always points on c->getBuffer()
        cP->assign(cPR);

-    if(cPR != c)
+    if(aP != aPR)
+        delete aPR;
+    if(bP != bPR)
+        delete bPR;
+    if(a != aP)
+        delete aP;
+    if(b != bP)
+        delete bP;
+
+    if(cP != cPR)
        delete cPR;
-    if(cP != c)
+    if(c != cP)
        delete cP;
}

@@ -129,7 +142,7 @@ void nd4j::MmulHelper::tensorDot(const NDArray* a, const NDArray* b, NDArray* c,
    if(!whatToDoWithC.empty()) {
        cArrs = std::vector<NDArray*>(whatToDoWithC.size()+1, c);
        for(int i = 0; i < cArrs.size()-1; ++i)
-            cArrs[i+1] = (whatToDoWithC[i] == 'p') ? new NDArray(cArrs[i]->permute(modifC[i])) : new NDArray(cArrs[i]->reshape(c->ordering(), modifC[i])); // since we ignore first element in cArrs (that is cArrs[0]) then it is always equal to c
+            cArrs[i+1] = (whatToDoWithC[i] == 'p') ? new NDArray(cArrs[i]->permute(modifC[i])) : new NDArray(cArrs[i]->reshape(c->ordering(), modifC[i], false)); // since we ignore first element in cArrs (that is cArrs[0]) then it is always equal to c
    }

    mmul(aPR, bPR, cArrs[cArrs.size()-1], 1.0, 0.0);

@@ -208,7 +221,7 @@ nd4j::NDArray* MmulHelper::mmul(const nd4j::NDArray* A, const nd4j::NDArray* B,
    // vector x matrix, A{M} x B{M,N} = C{N} -> reduce to matrix x matrix A2{1,M} x B{M,N} = C2{1,N}, since there is no corresponding blas operation sgevm
    if(isAVector && bRank == 2) {
        NDArray* A2 = new NDArray(A->reshape(A->ordering(), {1, A->lengthOf()})); // A{M} -> A2{1,M}
-        NDArray* C2 = C ? new NDArray(C->reshape(C->ordering(), {1, C->lengthOf()})) : nullptr; // C{N} -> C2{1,N}
+        NDArray* C2 = C ? new NDArray(C->reshape(C->ordering(), {1, C->lengthOf()}, false)) : nullptr; // C{N} -> C2{1,N}
        auto result = mmulMxM(A2, B, C2, alpha, beta, outOrder); // result{1,N}
        delete A2;
        delete C2;

@@ -139,5 +139,15 @@ namespace nd4j {
        return ShapeBuilders::copyShapeInfoAndType(inShapeInfo, ArrayOptions::dataType(shapeInfoToGetTypeFrom), copyStrides, workspace);
    }

+    ////////////////////////////////////////////////////////////////////////////////
+    Nd4jLong* ShapeBuilders::copyShapeInfoWithoutUnites(const Nd4jLong* inShapeInfo, const int dimsSize, const int* dimsToExclude, memory::Workspace* workspace) {
+
+        Nd4jLong *outShapeInfo = nullptr;
+        ALLOCATE(outShapeInfo, workspace, shape::shapeInfoLength(inShapeInfo[0] - dimsSize), Nd4jLong);
+
+        shape::excludeUnitiesFromShapeInfo(inShapeInfo, dimsSize, dimsToExclude, outShapeInfo);
+
+        return outShapeInfo;
+    }
+
}
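copyShapeInfoWithoutUnites allocates a smaller shapeInfo and drops the listed unit dimensions together with their strides, as the header comment's example shows. A standalone sketch of the same transformation on an explicit dims/strides pair (plain vectors, not the shapeInfo buffer layout; the exclude indices below are chosen to match the unit axes of that example):

    #include <cstdint>
    #include <vector>

    struct ShapeDesc {
        std::vector<int64_t> dims;
        std::vector<int64_t> strides;
    };

    // Drop the dimensions listed in unitDims (expected to have extent 1)
    // together with their strides, as copyShapeInfoWithoutUnites does for shapeInfo buffers.
    static ShapeDesc withoutUnitDims(const ShapeDesc& in, const std::vector<int>& unitDims) {
        ShapeDesc out;
        for (int d = 0; d < (int) in.dims.size(); d++) {
            bool excluded = false;
            for (int u : unitDims)
                if (u == d) { excluded = true; break; }
            if (!excluded) {
                out.dims.push_back(in.dims[d]);
                out.strides.push_back(in.strides[d]);
            }
        }
        return out;
    }

    int main() {
        // shape {2,1,3,1,4} with strides {12,12,4,4,1}; dropping the unit axes
        // leaves shape {2,3,4} with strides {12,4,1}, matching the header comment's example.
        ShapeDesc in{{2, 1, 3, 1, 4}, {12, 12, 4, 4, 1}};
        ShapeDesc out = withoutUnitDims(in, {1, 3});
        (void) out;
    }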
@ -75,10 +75,23 @@ std::vector<Nd4jLong> ShapeUtils::evalShapeForTensorDot(const Nd4jLong* aShapeIn
|
||||||
permutBt = axesB;
|
permutBt = axesB;
|
||||||
permutBt.insert(permutBt.end(), list_B.begin(), list_B.end());
|
permutBt.insert(permutBt.end(), list_B.begin(), list_B.end());
|
||||||
|
|
||||||
|
// if permut contains something like {0,1,2,..rank-1}, then there is no need to make permutation and we return empty vector in this case
|
||||||
|
uint i1, i2;
|
||||||
|
for(i1 = 0; i1 < aRank; ++i1)
|
||||||
|
if(permutAt[i1] != i1)
|
||||||
|
break;
|
||||||
|
if(i1 == aRank)
|
||||||
|
permutAt = {};
|
||||||
|
for(i2 = 0; i2 < bRank; ++i2)
|
||||||
|
if(permutBt[i2] != i2)
|
||||||
|
break;
|
||||||
|
if(i2 == bRank)
|
||||||
|
permutBt = {};
|
||||||
|
|
||||||
Nd4jLong n2 = 1;
|
Nd4jLong n2 = 1;
|
||||||
for (int i = 0; i < axeAsize; i++)
|
for (int i = 0; i < axeAsize; i++)
|
||||||
n2 *= aShapeInfo[axesA[i] + 1];
|
n2 *= aShapeInfo[axesA[i] + 1];
|
||||||
shapeAt = {-1, n2};
|
shapeAt = {shape::length(aShapeInfo) / n2, n2};
|
||||||
|
|
||||||
std::vector<Nd4jLong> oldShapeA;
|
std::vector<Nd4jLong> oldShapeA;
|
||||||
oldShapeA.resize(list_A.size());
|
oldShapeA.resize(list_A.size());
|
||||||
|
@ -89,7 +102,7 @@ std::vector<Nd4jLong> ShapeUtils::evalShapeForTensorDot(const Nd4jLong* aShapeIn
|
||||||
Nd4jLong n3 = 1;
|
Nd4jLong n3 = 1;
|
||||||
for (int i = 0; i < axeBsize; i++)
|
for (int i = 0; i < axeBsize; i++)
|
||||||
n3 *= bShapeInfo[axesB[i] + 1];
|
n3 *= bShapeInfo[axesB[i] + 1];
|
||||||
shapeBt = {n3, -1};
|
shapeBt = {n3, shape::length(bShapeInfo) / n3};
|
||||||
|
|
||||||
std::vector<Nd4jLong> oldShapeB;
|
std::vector<Nd4jLong> oldShapeB;
|
||||||
oldShapeB.resize(list_B.size());
|
oldShapeB.resize(list_B.size());
|
||||||
|
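Two points from this hunk, shown as a standalone sketch (plain vectors instead of shapeInfo buffers; the helper names are assumptions): an identity permutation {0, 1, ..., rank-1} can be dropped entirely, and the 2-D shape used for the tensordot matmul is {totalLength / n2, n2}, with n2 the product of the contracted extents, rather than a -1 placeholder.

#include <vector>

// Returns true when perm is the identity permutation {0, 1, ..., n-1},
// in which case no permute step is needed at all.
bool isIdentityPermutation(const std::vector<int>& perm) {
    for (int i = 0; i < (int)perm.size(); ++i)
        if (perm[i] != i)
            return false;
    return true;
}

// The contracted axes of A collapse into one dimension of extent n2;
// everything else collapses into totalLength / n2, giving a plain 2-D matmul shape.
std::vector<long long> tensordotShapeForA(const std::vector<long long>& aShape,
                                          const std::vector<int>& axesA) {
    long long total = 1, n2 = 1;
    for (auto d : aShape)  total *= d;
    for (auto ax : axesA)  n2 *= aShape[ax];
    return { total / n2, n2 };   // e.g. A{2,3,4}, axesA = {2}  ->  {6, 4}
}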
@ -300,32 +313,37 @@ std::vector<Nd4jLong> ShapeUtils::evalRepeatShape(int axis, const std::vector<in
|
||||||
return outShape;
|
return outShape;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
// evaluate shapeInfo of permuted array
|
// evaluate shapeInfo of permuted array
|
||||||
Nd4jLong* ShapeUtils::evalPermShapeInfo(const int* dimensions, const int rank, const NDArray& arr, nd4j::memory::Workspace* workspace) {
|
Nd4jLong* ShapeUtils::evalPermShapeInfo(const int* dimensions, const int rank, const NDArray& arr, nd4j::memory::Workspace* workspace, const bool setContigStrides) {
|
||||||
|
|
||||||
if (!arr.nonNull())
|
if (!arr.nonNull())
|
||||||
throw std::runtime_error("ShapeUtils::evalPermShapeInfo static method: wrong arguments in pn/termute method: either array is nullptr!");
|
throw std::runtime_error("ShapeUtils::evalPermShapeInfo static method: wrong arguments: array is nullptr!");
|
||||||
|
|
||||||
if (rank != arr.rankOf())
|
if (rank != arr.rankOf())
|
||||||
throw std::runtime_error("ShapeUtils::evalPermShapeInfo static method: wrong arguments in pn/termute method: rank is not suitable!");
|
throw std::runtime_error("ShapeUtils::evalPermShapeInfo static method: wrong arguments: rank is not suitable!");
|
||||||
|
|
||||||
auto shapeInfoLength = shape::shapeInfoLength(rank);
|
auto shapeInfoLength = shape::shapeInfoLength(rank);
|
||||||
// allocate memory for new array - shapeInfo
|
|
||||||
|
|
||||||
Nd4jLong *shapeInfoNew = nullptr;
|
// allocate memory for new array - shapeInfo
|
||||||
ALLOCATE(shapeInfoNew, workspace, shapeInfoLength, Nd4jLong);
|
Nd4jLong *shapeInfoNew = nullptr;
|
||||||
// copy arr _shapeInfo into new array
|
ALLOCATE(shapeInfoNew, workspace, shapeInfoLength, Nd4jLong);
|
||||||
memcpy(shapeInfoNew, arr.getShapeInfo(), shape::shapeInfoByteLength(rank));
|
|
||||||
// perform buffer permutation
|
|
||||||
shape::doPermuteShapeInfo(shapeInfoNew, dimensions);
|
|
||||||
|
|
||||||
ShapeDescriptor descriptor(shapeInfoNew);
|
// copy arr _shapeInfo into new array
|
||||||
RELEASE(shapeInfoNew, workspace);
|
memcpy(shapeInfoNew, arr.getShapeInfo(), shape::shapeInfoByteLength(rank));
|
||||||
return ConstantShapeHelper::getInstance()->bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
// perform buffer permutation
|
||||||
|
shape::doPermuteShapeInfo(shapeInfoNew, dimensions, arr.lengthOf());
|
||||||
|
|
||||||
|
if(setContigStrides)
|
||||||
|
shape::updateStrides(shapeInfoNew, arr.ordering());
|
||||||
|
|
||||||
|
ShapeDescriptor descriptor(shapeInfoNew);
|
||||||
|
|
||||||
|
RELEASE(shapeInfoNew, workspace);
|
||||||
|
|
||||||
|
return ConstantShapeHelper::getInstance()->bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
|
||||||
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
// evaluate shapeInfo of permuted array
|
// evaluate shapeInfo of permuted array
|
||||||
|
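A hedged standalone sketch of the permutation logic, assuming plain vectors for shape and strides instead of the packed shapeInfo buffer: reorder both by the permutation, and when setContigStrides is requested rebuild dense C-order strides for the permuted shape instead of keeping the permuted originals.

#include <vector>

// Reorder shape and strides according to perm; if setContigStrides is true,
// discard the permuted strides and rebuild contiguous C-order strides instead.
void permuteShape(std::vector<long long>& shape, std::vector<long long>& strides,
                  const std::vector<int>& perm, bool setContigStrides) {
    const int rank = (int)shape.size();
    std::vector<long long> newShape(rank), newStrides(rank);
    for (int i = 0; i < rank; ++i) {
        newShape[i]   = shape[perm[i]];
        newStrides[i] = strides[perm[i]];
    }
    if (setContigStrides) {                    // dense strides for the *new* shape
        long long s = 1;
        for (int i = rank - 1; i >= 0; --i) {
            newStrides[i] = s;
            s *= newShape[i];
        }
    }
    shape = newShape;
    strides = newStrides;
}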
@ -337,14 +355,14 @@ std::vector<Nd4jLong> ShapeUtils::evalRepeatShape(int axis, const std::vector<in
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
// evaluate shapeInfo of transposed array
|
// evaluate shapeInfo of transposed array
|
||||||
Nd4jLong* ShapeUtils::evalTranspShapeInfo(const NDArray& arr, nd4j::memory::Workspace* workspace) {
|
Nd4jLong* ShapeUtils::evalTranspShapeInfo(const NDArray& arr, nd4j::memory::Workspace* workspace, const bool setContigStrides) {
|
||||||
|
|
||||||
int rank = arr.rankOf();
|
int rank = arr.rankOf();
|
||||||
std::vector<int> dimensions(rank);
|
std::vector<int> dimensions(rank);
|
||||||
for (int i = 0; i < rank; ++i)
|
for (int i = 0; i < rank; ++i)
|
||||||
dimensions[i] = rank - 1 - i;
|
dimensions[i] = rank - 1 - i;
|
||||||
|
|
||||||
return evalPermShapeInfo(dimensions.data(), dimensions.size(), arr, workspace);
|
return evalPermShapeInfo(dimensions.data(), dimensions.size(), arr, workspace, setContigStrides);
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
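The transpose case reduces to the reversed-index permutation; a tiny standalone sketch, with the function name as an assumption:

#include <vector>
#include <numeric>
#include <algorithm>

// Transposing an array of the given rank is the permutation {rank-1, ..., 1, 0}.
std::vector<int> transposePermutation(int rank) {
    std::vector<int> dims(rank);
    std::iota(dims.begin(), dims.end(), 0);   // {0, 1, ..., rank-1}
    std::reverse(dims.begin(), dims.end());   // {rank-1, ..., 1, 0}
    return dims;
}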
@ -653,6 +671,26 @@ Nd4jLong* ShapeUtils::evalTileShapeInfo(const NDArray& arr, const std::vector<Nd
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::string ShapeUtils::shapeInfoAsString(const Nd4jLong* shapeInfo) {
|
||||||
|
|
||||||
|
if(!shapeInfo)
|
||||||
|
throw std::runtime_error("ShapeUtils::shapeAsString method: input shapeInfo must not be nullptr !");
|
||||||
|
|
||||||
|
std::string result;
|
||||||
|
|
||||||
|
int len = shape::shapeInfoLength(shapeInfo[0]);
|
||||||
|
|
||||||
|
result.append("[");
|
||||||
|
for (int e = 0; e < len; e++) {
|
||||||
|
result += flatbuffers::NumToString(shapeInfo[e]);
|
||||||
|
if (e < len - 1)
|
||||||
|
result.append(", ");
|
||||||
|
}
|
||||||
|
result.append("]");
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
std::string ShapeUtils::shapeAsString(const int rank, const Nd4jLong* shapeInfo) {
|
std::string ShapeUtils::shapeAsString(const int rank, const Nd4jLong* shapeInfo) {
|
||||||
if(!shapeInfo)
|
if(!shapeInfo)
|
||||||
|
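A self-contained version of the same string formatting, using std::to_string instead of flatbuffers::NumToString; the vector-based signature is an assumption for illustration:

#include <string>
#include <vector>

// Render a buffer of values as "[v0, v1, ..., vN]".
std::string bufferAsString(const std::vector<long long>& values) {
    std::string result = "[";
    for (size_t e = 0; e < values.size(); ++e) {
        result += std::to_string(values[e]);
        if (e + 1 < values.size())
            result += ", ";
    }
    result += "]";
    return result;
}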
@ -1019,6 +1057,29 @@ std::vector<int> ShapeUtils::tadAxesForSimpleBroadcast(const NDArray& max, const
|
||||||
return numOfMinTads == 1 ? maxTadDims : std::vector<int>();
|
return numOfMinTads == 1 ? maxTadDims : std::vector<int>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void ShapeUtils::copyCertainStridesFromShapeInfo(const Nd4jLong* inShapeInfo, const int nRank, const int dimsSize, const int* dims, Nd4jLong* outStrides) {
|
||||||
|
|
||||||
|
int yRank = shape::rank(inShapeInfo);
|
||||||
|
auto yOrigStride = shape::stride(inShapeInfo);
|
||||||
|
|
||||||
|
if (yRank == nRank) {
|
||||||
|
for (int i = 0; i < yRank; ++i) {
|
||||||
|
// x[2,3,4] * y[2,1,4] = z[2,3,4]
|
||||||
|
outStrides[i] = (1 == shape::sizeAt(inShapeInfo, i)) ? 0 : yOrigStride[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
|
||||||
|
auto dimEx = nd4j::ShapeUtils::evalDimsToExclude(nRank, dimsSize, dims);
|
||||||
|
|
||||||
|
for (int i = 0, it = 0; i < nRank; ++i) {
|
||||||
|
auto nCount = std::count(dimEx.cbegin(), dimEx.cend(), i);
|
||||||
|
outStrides[i] = (0 == nCount) ? yOrigStride[it++] : 0;
|
||||||
|
if (it == yRank)
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
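The stride rule applied above (a broadcast dimension of extent 1 advances with stride 0) in a standalone form; plain vectors stand in for the shapeInfo buffer and the evalDimsToExclude handling:

#include <vector>

// For an input broadcast against a larger output of the same rank, any dimension
// of extent 1 advances by stride 0; all other strides are kept as-is.
// Example: x[2,3,4] * y[2,1,4] -> y's effective strides are {s0, 0, s2}.
std::vector<long long> broadcastStrides(const std::vector<long long>& shape,
                                        const std::vector<long long>& strides) {
    std::vector<long long> out(shape.size());
    for (size_t i = 0; i < shape.size(); ++i)
        out[i] = (shape[i] == 1) ? 0 : strides[i];
    return out;
}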
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
/*
|
/*
|
||||||
bool ShapeUtils::isSubArrayCase(const NDArray& arr1, const NDArray& arr2, std::vector<int>& sameDims) {
|
bool ShapeUtils::isSubArrayCase(const NDArray& arr1, const NDArray& arr2, std::vector<int>& sameDims) {
|
||||||
|
|
File diff suppressed because it is too large
|
@ -40,6 +40,7 @@
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#include <helpers/TAD.h>
|
#include <helpers/TAD.h>
|
||||||
|
#include <helpers/LoopKind.h>
|
||||||
|
|
||||||
#include "legacy_ops.h"
|
#include "legacy_ops.h"
|
||||||
|
|
||||||
|
@ -122,6 +123,7 @@ namespace functions {
|
||||||
Nd4jLong *tadOffset,
|
Nd4jLong *tadOffset,
|
||||||
Nd4jLong *tadShapeInfoZ,
|
Nd4jLong *tadShapeInfoZ,
|
||||||
Nd4jLong *tadOffsetZ,
|
Nd4jLong *tadOffsetZ,
|
||||||
|
nd4j::LoopKind::Kind loopKind,
|
||||||
uint64_t start,
|
uint64_t start,
|
||||||
uint64_t stop);
|
uint64_t stop);
|
||||||
|
|
||||||
|
@ -149,6 +151,7 @@ namespace functions {
|
||||||
Nd4jLong *tadOffset,
|
Nd4jLong *tadOffset,
|
||||||
Nd4jLong *tadShapeInfoZ,
|
Nd4jLong *tadShapeInfoZ,
|
||||||
Nd4jLong *tadOffsetZ,
|
Nd4jLong *tadOffsetZ,
|
||||||
|
nd4j::LoopKind::Kind loopKind,
|
||||||
uint64_t start,
|
uint64_t start,
|
||||||
uint64_t stop);
|
uint64_t stop);
|
||||||
|
|
||||||
|
|
|
@ -14,9 +14,9 @@
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
******************************************************************************/
|
******************************************************************************/
|
||||||
|
|
||||||
//
|
//
|
||||||
// @author Yurii Shyrma (iuriish@yahoo.com)
|
// @author Yurii Shyrma (iuriish@yahoo.com)
|
||||||
//
|
//
|
||||||
|
|
||||||
#include <loops/TrueBroadcastHelper.h>
|
#include <loops/TrueBroadcastHelper.h>
|
||||||
#include <ops/ops.h>
|
#include <ops/ops.h>
|
||||||
|
@ -24,226 +24,268 @@
|
||||||
|
|
||||||
using namespace simdOps;
|
using namespace simdOps;
|
||||||
|
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
namespace helpers {
|
namespace helpers {
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
template <typename X, typename Y, typename Z>
|
template <typename X, typename Y, typename Z>
|
||||||
template<typename OpType>
|
template<typename OpType>
|
||||||
void TrueBroadcastHelper<X, Y, Z>::exec(const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
|
void TrueBroadcastHelper<X, Y, Z>::exec(const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
|
||||||
|
|
||||||
|
|
||||||
const X* x = reinterpret_cast<X*>(xArr.getBuffer());
|
const X* x = reinterpret_cast<X*>(xArr.getBuffer());
|
||||||
const Y* y = reinterpret_cast<Y*>(yArr.getBuffer());
|
const Y* y = reinterpret_cast<Y*>(yArr.getBuffer());
|
||||||
Z* z = reinterpret_cast<Z*>(zArr.getBuffer());
|
Z* z = reinterpret_cast<Z*>(zArr.getBuffer());
|
||||||
|
|
||||||
const auto xShapeInfo = xArr.getShapeInfo();
|
const auto xShapeInfo = xArr.getShapeInfo();
|
||||||
const auto yShapeInfo = yArr.getShapeInfo();
|
const auto yShapeInfo = yArr.getShapeInfo();
|
||||||
const auto zShapeInfo = zArr.getShapeInfo();
|
const auto zShapeInfo = zArr.getShapeInfo();
|
||||||
|
|
||||||
const int xRank = xArr.rankOf();
|
const int xRank = xArr.rankOf();
|
||||||
const int yRank = yArr.rankOf();
|
const int yRank = yArr.rankOf();
|
||||||
const int zRank = zArr.rankOf();
|
const int zRank = zArr.rankOf();
|
||||||
|
|
||||||
bool bSpecialCase = (1 == xArr.ews() && 'c' == xArr.ordering() && 1 == yRank &&
|
bool bSpecialCase = (1 == xArr.ews() && 'c' == xArr.ordering() &&
|
||||||
1 == yArr.ews() && 'c' == yArr.ordering() &&
|
1 == yArr.ews() && 'c' == yArr.ordering() &&
|
||||||
1 == zArr.ews() && 'c' == zArr.ordering());
|
1 == zArr.ews() && 'c' == zArr.ordering());
|
||||||
|
|
||||||
if (bSpecialCase) {
|
if (bSpecialCase && yArr.isColumnVector() && 1 == xArr.sizeAt(-1) ) {
|
||||||
auto yLen = (uint32_t)yArr.lengthOf();
|
auto yLen = (uint32_t)yArr.lengthOf();
|
||||||
auto func = PRAGMA_THREADS_FOR{
|
auto func = PRAGMA_THREADS_FOR{
|
||||||
for (uint32_t i = start; i < stop; i++) {
|
for (uint32_t i = start; i < stop; i++) {
|
||||||
auto rZ = z + (i * yLen);
|
auto rZ = z + (i * yLen);
|
||||||
auto v = x[i];
|
auto v = x[i];
|
||||||
for (uint32_t j = 0; j < yLen; j++) {
|
for (uint32_t j = 0; j < yLen; j++) {
|
||||||
rZ[j] = OpType::op(v, y[j]);
|
rZ[j] = OpType::op(v, y[j]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
samediff::Threads::parallel_tad(func, 0, xArr.lengthOf());
|
samediff::Threads::parallel_tad(func, 0, xArr.lengthOf());
|
||||||
return;
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
auto yShapeInt = yArr.getShapeAsVectorInt();
|
||||||
|
auto xShapeInt = xArr.getShapeAsVectorInt();
|
||||||
|
auto nCountY = std::count_if(yShapeInt.cbegin(), yShapeInt.cend(), [](int i) { return i == 1; });
|
||||||
|
auto nCountX = std::count_if(xShapeInt.cbegin(), xShapeInt.cend(), [](int i) { return i == 1; });
|
||||||
|
|
||||||
|
bool bSpecialCase2 = (xRank == zRank && yRank == zRank && 1 == xArr.sizeAt(-1) && 1 == yArr.sizeAt(-2) && 1 == nCountY && 1 == nCountX);
|
||||||
|
|
||||||
|
if (bSpecialCase && bSpecialCase2) {
|
||||||
|
|
||||||
|
int zDim1 = zArr.sizeAt(-2);
|
||||||
|
int zDim2 = zArr.sizeAt(-1);
|
||||||
|
|
||||||
|
int nLen = zArr.lengthOf() / yArr.sizeAt(-1);
|
||||||
|
|
||||||
|
auto func = PRAGMA_THREADS_FOR{
|
||||||
|
for (uint32_t total = start; total < stop; total++) {
|
||||||
|
|
||||||
|
uint32_t i = total / zDim1;
|
||||||
|
uint32_t j = total % zDim1;
|
||||||
|
|
||||||
|
uint32_t index = (i * zDim1) + j;
|
||||||
|
auto rZ = z + (index * zDim2);
|
||||||
|
auto rY = y + (i * zDim2);
|
||||||
|
auto rX = x[index];
|
||||||
|
|
||||||
|
for (uint32_t n = 0; n < zDim2; n++) {
|
||||||
|
rZ[n] = OpType::op(rX, rY[n]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
samediff::Threads::parallel_tad(func, 0, nLen, 1);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
const Nd4jLong zLen = zArr.lengthOf();
|
||||||
|
auto func = PRAGMA_THREADS_FOR{
|
||||||
|
std::vector<Nd4jLong> xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf());
|
||||||
|
|
||||||
|
for (auto i = start; i < stop; ++i) {
|
||||||
|
|
||||||
|
shape::index2coords(i, zShapeInfo, zCoords.data());
|
||||||
|
|
||||||
|
for (int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) {
|
||||||
|
|
||||||
|
if (ix >= 0) {
|
||||||
|
if (xShapeInfo[ix + 1] == zShapeInfo[iz + 1]) {
|
||||||
|
xCoords[ix--] = zCoords[iz];
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
xCoords[ix--] = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (iy >= 0) {
|
||||||
|
if (yShapeInfo[iy + 1] == zShapeInfo[iz + 1]) {
|
||||||
|
yCoords[iy--] = zCoords[iz];
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
yCoords[iy--] = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto xOffset = shape::getOffset(xShapeInfo, xCoords.data());
|
||||||
|
const auto yOffset = shape::getOffset(yShapeInfo, yCoords.data());
|
||||||
|
const auto zOffset = shape::getOffset(zShapeInfo, zCoords.data());
|
||||||
|
|
||||||
|
z[zOffset] = OpType::op(x[xOffset], y[yOffset]);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_for(func, 0, zLen);
|
||||||
|
}
|
||||||
|
|
||||||
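The coordinate mapping used by the general path above, written standalone: ranks are aligned on their trailing dimensions, and wherever an input extent is 1 its coordinate collapses to 0, otherwise it follows the output coordinate. The vector-based signature is an assumption for illustration.

#include <vector>

// Map an output coordinate to a coordinate on a broadcast input, aligning trailing
// dimensions (the input may have smaller rank). Valid broadcasts only have two cases
// per dimension: extents match, or the input extent is 1 and collapses to coordinate 0.
std::vector<long long> inputCoords(const std::vector<long long>& outCoords,
                                   const std::vector<long long>& inShape) {
    const int outRank = (int)outCoords.size();
    const int inRank  = (int)inShape.size();
    std::vector<long long> in(inRank);
    for (int ii = inRank - 1, oi = outRank - 1; ii >= 0; --ii, --oi)
        in[ii] = (inShape[ii] == 1) ? 0 : outCoords[oi];
    return in;
}

// Example: output coords {1,2,3} over z{2,3,4} with y shaped {3,1} -> y coords {2,0}.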
|
template <typename X, typename Y, typename Z>
|
||||||
|
void TrueBroadcastHelper<X, Y, Z>::exec(const nd4j::broadcast::Ops opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
|
||||||
|
DISPATCH_BY_OPNUM_TTT(exec, PARAMS(xArr, yArr, zArr), BROADCAST_OPS);
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
template <typename X, typename Z>
|
||||||
|
template<typename OpType>
|
||||||
|
void TrueBroadcastBoolHelper<X, Z>::exec(const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
|
||||||
|
|
||||||
|
const X* x = reinterpret_cast<X*>(xArr.getBuffer());
|
||||||
|
const X* y = reinterpret_cast<X*>(yArr.getBuffer());
|
||||||
|
Z* z = reinterpret_cast<Z*>(zArr.getBuffer());
|
||||||
|
|
||||||
|
const auto xShapeInfo = xArr.getShapeInfo();
|
||||||
|
const auto yShapeInfo = yArr.getShapeInfo();
|
||||||
|
const auto zShapeInfo = zArr.getShapeInfo();
|
||||||
|
|
||||||
|
const int xRank = xArr.rankOf();
|
||||||
|
const int yRank = yArr.rankOf();
|
||||||
|
const int zRank = zArr.rankOf();
|
||||||
|
|
||||||
|
const Nd4jLong zLen = zArr.lengthOf();
|
||||||
|
|
||||||
|
auto func = PRAGMA_THREADS_FOR{
|
||||||
|
std::vector<Nd4jLong> xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf());
|
||||||
|
|
||||||
|
for (auto i = start; i < stop; ++i) {
|
||||||
|
|
||||||
|
shape::index2coords(i, zShapeInfo, zCoords.data());
|
||||||
|
|
||||||
|
for (int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) {
|
||||||
|
|
||||||
|
if (ix >= 0) {
|
||||||
|
if (xShapeInfo[ix + 1] == zShapeInfo[iz + 1]) {
|
||||||
|
xCoords[ix--] = zCoords[iz];
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
xCoords[ix--] = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (iy >= 0) {
|
||||||
|
if (yShapeInfo[iy + 1] == zShapeInfo[iz + 1]) {
|
||||||
|
yCoords[iy--] = zCoords[iz];
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
yCoords[iy--] = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto xOffset = shape::getOffset(xShapeInfo, xCoords.data());
|
||||||
|
const auto yOffset = shape::getOffset(yShapeInfo, yCoords.data());
|
||||||
|
const auto zOffset = shape::getOffset(zShapeInfo, zCoords.data());
|
||||||
|
|
||||||
|
z[zOffset] = OpType::op(x[xOffset], y[yOffset], nullptr);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_for(func, 0, zLen);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename X, typename Y>
|
||||||
|
void TrueBroadcastBoolHelper<X, Y>::exec(const nd4j::broadcast::BoolOps opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
|
||||||
|
DISPATCH_BY_OPNUM_TT(exec, PARAMS(xArr, yArr, zArr), BROADCAST_BOOL_OPS);
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
template <typename X>
|
||||||
|
template<typename OpType>
|
||||||
|
void TrueBroadcastIntHelper<X>::exec(const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
|
||||||
|
|
||||||
|
const X* x = reinterpret_cast<X*>(xArr.getBuffer());
|
||||||
|
const X* y = reinterpret_cast<X*>(yArr.getBuffer());
|
||||||
|
X* z = reinterpret_cast<X*>(zArr.getBuffer());
|
||||||
|
|
||||||
|
const auto xShapeInfo = xArr.getShapeInfo();
|
||||||
|
const auto yShapeInfo = yArr.getShapeInfo();
|
||||||
|
const auto zShapeInfo = zArr.getShapeInfo();
|
||||||
|
|
||||||
|
const int xRank = xArr.rankOf();
|
||||||
|
const int yRank = yArr.rankOf();
|
||||||
|
const int zRank = zArr.rankOf();
|
||||||
|
|
||||||
|
const Nd4jLong zLen = zArr.lengthOf();
|
||||||
|
|
||||||
|
auto func = PRAGMA_THREADS_FOR{
|
||||||
|
std::vector<Nd4jLong> xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf());
|
||||||
|
|
||||||
|
for (auto i = start; i < stop; ++i) {
|
||||||
|
|
||||||
|
shape::index2coords(i, zShapeInfo, zCoords.data());
|
||||||
|
|
||||||
|
for (int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) {
|
||||||
|
|
||||||
|
if (ix >= 0) {
|
||||||
|
if (xShapeInfo[ix + 1] == zShapeInfo[iz + 1]) {
|
||||||
|
xCoords[ix--] = zCoords[iz];
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
xCoords[ix--] = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (iy >= 0) {
|
||||||
|
if (yShapeInfo[iy + 1] == zShapeInfo[iz + 1]) {
|
||||||
|
yCoords[iy--] = zCoords[iz];
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
yCoords[iy--] = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto xOffset = shape::getOffset(xShapeInfo, xCoords.data());
|
||||||
|
const auto yOffset = shape::getOffset(yShapeInfo, yCoords.data());
|
||||||
|
const auto zOffset = shape::getOffset(zShapeInfo, zCoords.data());
|
||||||
|
|
||||||
|
z[zOffset] = OpType::op(x[xOffset], y[yOffset]);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_for(func, 0, zLen);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename X>
|
||||||
|
void TrueBroadcastIntHelper<X>::exec(const nd4j::broadcast::IntOps opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
|
||||||
|
DISPATCH_BY_OPNUM_T(exec, PARAMS(xArr, yArr, zArr), BROADCAST_INT_OPS);
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_0);
|
||||||
|
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_1);
|
||||||
|
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_2);
|
||||||
|
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_3);
|
||||||
|
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_4);
|
||||||
|
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_5);
|
||||||
|
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_6);
|
||||||
|
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_7);
|
||||||
|
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_8);
|
||||||
|
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_9);
|
||||||
|
|
||||||
|
BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastBoolHelper, , LIBND4J_TYPES, BOOL_TYPES);
|
||||||
|
|
||||||
|
BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastIntHelper, , INTEGER_TYPES);
|
||||||
|
*/
|
||||||
}
|
}
|
||||||
|
|
||||||
const Nd4jLong zLen = zArr.lengthOf();
|
|
||||||
auto func = PRAGMA_THREADS_FOR {
|
|
||||||
std::vector<Nd4jLong> xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf());
|
|
||||||
|
|
||||||
for (auto i = start; i < stop; ++i) {
|
|
||||||
|
|
||||||
shape::index2coords(i, zShapeInfo, zCoords.data());
|
|
||||||
|
|
||||||
for (int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) {
|
|
||||||
|
|
||||||
if (ix >= 0) {
|
|
||||||
if (xShapeInfo[ix + 1] == zShapeInfo[iz + 1]) {
|
|
||||||
xCoords[ix--] = zCoords[iz];
|
|
||||||
} else {
|
|
||||||
xCoords[ix--] = 0;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (iy >= 0) {
|
|
||||||
if (yShapeInfo[iy + 1] == zShapeInfo[iz + 1]) {
|
|
||||||
yCoords[iy--] = zCoords[iz];
|
|
||||||
} else {
|
|
||||||
yCoords[iy--] = 0;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const auto xOffset = shape::getOffset(xShapeInfo, xCoords.data());
|
|
||||||
const auto yOffset = shape::getOffset(yShapeInfo, yCoords.data());
|
|
||||||
const auto zOffset = shape::getOffset(zShapeInfo, zCoords.data());
|
|
||||||
|
|
||||||
z[zOffset] = OpType::op(x[xOffset], y[yOffset]);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, zLen);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename X, typename Y, typename Z>
|
|
||||||
void TrueBroadcastHelper<X, Y, Z>::exec(const nd4j::broadcast::Ops opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
|
|
||||||
DISPATCH_BY_OPNUM_TTT(exec, PARAMS(xArr, yArr, zArr), BROADCAST_OPS);
|
|
||||||
}
|
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
|
||||||
template <typename X, typename Z>
|
|
||||||
template<typename OpType>
|
|
||||||
void TrueBroadcastBoolHelper<X, Z>::exec(const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
|
|
||||||
|
|
||||||
const X* x = reinterpret_cast<X*>(xArr.getBuffer());
|
|
||||||
const X* y = reinterpret_cast<X*>(yArr.getBuffer());
|
|
||||||
Z* z = reinterpret_cast<Z*>(zArr.getBuffer());
|
|
||||||
|
|
||||||
const auto xShapeInfo = xArr.getShapeInfo();
|
|
||||||
const auto yShapeInfo = yArr.getShapeInfo();
|
|
||||||
const auto zShapeInfo = zArr.getShapeInfo();
|
|
||||||
|
|
||||||
const int xRank = xArr.rankOf();
|
|
||||||
const int yRank = yArr.rankOf();
|
|
||||||
const int zRank = zArr.rankOf();
|
|
||||||
|
|
||||||
const Nd4jLong zLen = zArr.lengthOf();
|
|
||||||
|
|
||||||
auto func = PRAGMA_THREADS_FOR {
|
|
||||||
std::vector<Nd4jLong> xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf());
|
|
||||||
|
|
||||||
for (auto i = start; i < stop; ++i) {
|
|
||||||
|
|
||||||
shape::index2coords(i, zShapeInfo, zCoords.data());
|
|
||||||
|
|
||||||
for (int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) {
|
|
||||||
|
|
||||||
if (ix >= 0) {
|
|
||||||
if (xShapeInfo[ix + 1] == zShapeInfo[iz + 1]) {
|
|
||||||
xCoords[ix--] = zCoords[iz];
|
|
||||||
} else {
|
|
||||||
xCoords[ix--] = 0;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (iy >= 0) {
|
|
||||||
if (yShapeInfo[iy + 1] == zShapeInfo[iz + 1]) {
|
|
||||||
yCoords[iy--] = zCoords[iz];
|
|
||||||
} else {
|
|
||||||
yCoords[iy--] = 0;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const auto xOffset = shape::getOffset(xShapeInfo, xCoords.data());
|
|
||||||
const auto yOffset = shape::getOffset(yShapeInfo, yCoords.data());
|
|
||||||
const auto zOffset = shape::getOffset(zShapeInfo, zCoords.data());
|
|
||||||
|
|
||||||
z[zOffset] = OpType::op(x[xOffset], y[yOffset], nullptr);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, zLen);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename X, typename Y>
|
|
||||||
void TrueBroadcastBoolHelper<X, Y>::exec(const nd4j::broadcast::BoolOps opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
|
|
||||||
DISPATCH_BY_OPNUM_TT(exec, PARAMS(xArr, yArr, zArr), BROADCAST_BOOL_OPS);
|
|
||||||
}
|
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
|
||||||
template <typename X>
|
|
||||||
template<typename OpType>
|
|
||||||
void TrueBroadcastIntHelper<X>::exec(const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
|
|
||||||
|
|
||||||
const X* x = reinterpret_cast<X*>(xArr.getBuffer());
|
|
||||||
const X* y = reinterpret_cast<X*>(yArr.getBuffer());
|
|
||||||
X* z = reinterpret_cast<X*>(zArr.getBuffer());
|
|
||||||
|
|
||||||
const auto xShapeInfo = xArr.getShapeInfo();
|
|
||||||
const auto yShapeInfo = yArr.getShapeInfo();
|
|
||||||
const auto zShapeInfo = zArr.getShapeInfo();
|
|
||||||
|
|
||||||
const int xRank = xArr.rankOf();
|
|
||||||
const int yRank = yArr.rankOf();
|
|
||||||
const int zRank = zArr.rankOf();
|
|
||||||
|
|
||||||
const Nd4jLong zLen = zArr.lengthOf();
|
|
||||||
|
|
||||||
auto func = PRAGMA_THREADS_FOR {
|
|
||||||
std::vector<Nd4jLong> xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf());
|
|
||||||
|
|
||||||
for (auto i = start; i < stop; ++i) {
|
|
||||||
|
|
||||||
shape::index2coords(i, zShapeInfo, zCoords.data());
|
|
||||||
|
|
||||||
for (int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) {
|
|
||||||
|
|
||||||
if (ix >= 0) {
|
|
||||||
if (xShapeInfo[ix + 1] == zShapeInfo[iz + 1]) {
|
|
||||||
xCoords[ix--] = zCoords[iz];
|
|
||||||
} else {
|
|
||||||
xCoords[ix--] = 0;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (iy >= 0) {
|
|
||||||
if (yShapeInfo[iy + 1] == zShapeInfo[iz + 1]) {
|
|
||||||
yCoords[iy--] = zCoords[iz];
|
|
||||||
} else {
|
|
||||||
yCoords[iy--] = 0;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const auto xOffset = shape::getOffset(xShapeInfo, xCoords.data());
|
|
||||||
const auto yOffset = shape::getOffset(yShapeInfo, yCoords.data());
|
|
||||||
const auto zOffset = shape::getOffset(zShapeInfo, zCoords.data());
|
|
||||||
|
|
||||||
z[zOffset] = OpType::op(x[xOffset], y[yOffset]);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, zLen);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename X>
|
|
||||||
void TrueBroadcastIntHelper<X>::exec(const nd4j::broadcast::IntOps opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
|
|
||||||
DISPATCH_BY_OPNUM_T(exec, PARAMS(xArr, yArr, zArr), BROADCAST_INT_OPS);
|
|
||||||
}
|
|
||||||
|
|
||||||
/*
|
|
||||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_0);
|
|
||||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_1);
|
|
||||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_2);
|
|
||||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_3);
|
|
||||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_4);
|
|
||||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_5);
|
|
||||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_6);
|
|
||||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_7);
|
|
||||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_8);
|
|
||||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_9);
|
|
||||||
|
|
||||||
BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastBoolHelper, , LIBND4J_TYPES, BOOL_TYPES);
|
|
||||||
|
|
||||||
BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastIntHelper, , INTEGER_TYPES);
|
|
||||||
*/
|
|
||||||
}
|
|
||||||
}
|
}
|
|
@ -25,6 +25,7 @@
|
||||||
#include <LoopKind.h>
|
#include <LoopKind.h>
|
||||||
#include <helpers/ConstantTadHelper.h>
|
#include <helpers/ConstantTadHelper.h>
|
||||||
#include <execution/Threads.h>
|
#include <execution/Threads.h>
|
||||||
|
#include <helpers/ShapeUtils.h>
|
||||||
|
|
||||||
using namespace simdOps;
|
using namespace simdOps;
|
||||||
|
|
||||||
|
@ -75,6 +76,7 @@ namespace functions {
|
||||||
Nd4jLong *xTadOffset,
|
Nd4jLong *xTadOffset,
|
||||||
Nd4jLong *zTadShapeInfo,
|
Nd4jLong *zTadShapeInfo,
|
||||||
Nd4jLong *zTadOffset,
|
Nd4jLong *zTadOffset,
|
||||||
|
nd4j::LoopKind::Kind loopKind,
|
||||||
uint64_t start,
|
uint64_t start,
|
||||||
uint64_t stop) {
|
uint64_t stop) {
|
||||||
DISPATCH_BY_OPNUM_TTT(exec, PARAMS(x,
|
DISPATCH_BY_OPNUM_TTT(exec, PARAMS(x,
|
||||||
|
@ -88,7 +90,7 @@ namespace functions {
|
||||||
xTadShapeInfo,
|
xTadShapeInfo,
|
||||||
xTadOffset,
|
xTadOffset,
|
||||||
zTadShapeInfo,
|
zTadShapeInfo,
|
||||||
zTadOffset, start, stop), BROADCAST_OPS);
|
zTadOffset, loopKind, start, stop), BROADCAST_OPS);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename X, typename Y, typename Z>
|
template <typename X, typename Y, typename Z>
|
||||||
|
@ -105,6 +107,7 @@ namespace functions {
|
||||||
Nd4jLong *xTadOffset,
|
Nd4jLong *xTadOffset,
|
||||||
Nd4jLong *zTadShapeInfo,
|
Nd4jLong *zTadShapeInfo,
|
||||||
Nd4jLong *zTadOffset,
|
Nd4jLong *zTadOffset,
|
||||||
|
nd4j::LoopKind::Kind loopKind,
|
||||||
uint64_t start,
|
uint64_t start,
|
||||||
uint64_t stop) {
|
uint64_t stop) {
|
||||||
|
|
||||||
|
@ -142,7 +145,14 @@ namespace functions {
|
||||||
auto yEws = shape::elementWiseStride(yShapeInfo);
|
auto yEws = shape::elementWiseStride(yShapeInfo);
|
||||||
auto zEws = shape::elementWiseStride(zTadShapeInfo);
|
auto zEws = shape::elementWiseStride(zTadShapeInfo);
|
||||||
|
|
||||||
const nd4j::LoopKind::Kind kindOfLoop = nd4j::LoopKind::deduceKindOfLoopXYZ(xTadShapeShapeInfo, yShapeInfo, zTadShapeInfo);
|
|
||||||
|
const nd4j::LoopKind::Kind kindOfLoop =
|
||||||
|
(loopKind == nd4j::LoopKind::BROADCAST_SCALAR_X ||
|
||||||
|
loopKind == nd4j::LoopKind::BROADCAST_SCALAR_Y ||
|
||||||
|
loopKind == nd4j::LoopKind::BROADCAST_3D ||
|
||||||
|
loopKind == nd4j::LoopKind::BROADCAST_4D ||
|
||||||
|
loopKind == nd4j::LoopKind::BROADCAST_5D)
|
||||||
|
? loopKind : nd4j::LoopKind::deduceKindOfLoopXYZ(xTadShapeShapeInfo, yShapeInfo, zTadShapeInfo);
|
||||||
|
|
||||||
if (kindOfLoop == nd4j::LoopKind::EWS1) {
|
if (kindOfLoop == nd4j::LoopKind::EWS1) {
|
||||||
for (auto i = start; i < stop; i++) {
|
for (auto i = start; i < stop; i++) {
|
||||||
|
@ -163,6 +173,131 @@ namespace functions {
|
||||||
for (unsigned int f = 0; f < tadLength; f++)
|
for (unsigned int f = 0; f < tadLength; f++)
|
||||||
oZ[f * zEws] = OpType::op(oX[f * xEws], y[f * yEws]);
|
oZ[f * zEws] = OpType::op(oX[f * xEws], y[f * yEws]);
|
||||||
}
|
}
|
||||||
|
} else if(kindOfLoop == nd4j::LoopKind::BROADCAST_SCALAR_X){
|
||||||
|
// this loop effectively turns broadcast into series of scalar ops
|
||||||
|
auto loopLength = yShapeInfo[shape::rank(yShapeInfo)];
|
||||||
|
|
||||||
|
for (auto i = start; i < stop; i++) {
|
||||||
|
auto oY = y + (i * loopLength);
|
||||||
|
auto oZ = z + (i * loopLength);
|
||||||
|
|
||||||
|
const auto oX = x[i];
|
||||||
|
|
||||||
|
PRAGMA_OMP_SIMD
|
||||||
|
for (unsigned int f = 0; f < loopLength; f++)
|
||||||
|
oZ[f] = OpType::op(oX, oY[f]);
|
||||||
|
}
|
||||||
|
} else if(kindOfLoop == nd4j::LoopKind::BROADCAST_SCALAR_Y){
|
||||||
|
// this loop effectively turns broadcast into series of scalar ops
|
||||||
|
auto loopLength = xShapeInfo[shape::rank(xShapeInfo)];
|
||||||
|
|
||||||
|
for (auto i = start; i < stop; i++) {
|
||||||
|
auto oX = x + (i * loopLength);
|
||||||
|
auto oZ = z + (i * loopLength);
|
||||||
|
|
||||||
|
const auto oY = y[i];
|
||||||
|
|
||||||
|
PRAGMA_OMP_SIMD
|
||||||
|
for (unsigned int f = 0; f < loopLength; f++)
|
||||||
|
oZ[f] = OpType::op(oX[f], oY);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else if (kindOfLoop == nd4j::LoopKind::BROADCAST_3D) {
|
||||||
|
|
||||||
|
int xRank = shape::rank(xShapeInfo);
|
||||||
|
int yRank = shape::rank(yShapeInfo);
|
||||||
|
|
||||||
|
auto xStrides = shape::stride(xShapeInfo);
|
||||||
|
auto zStrides = shape::stride(zShapeInfo);
|
||||||
|
|
||||||
|
Nd4jLong yStrides[3] = { 0,0,0 };
|
||||||
|
nd4j::ShapeUtils::copyCertainStridesFromShapeInfo(yShapeInfo, xRank, dimensionLength, dimension, yStrides);
|
||||||
|
|
||||||
|
uint32_t nSize1 = shape::sizeAt(zShapeInfo, 1);
|
||||||
|
uint32_t nSize2 = shape::sizeAt(zShapeInfo, 2);
|
||||||
|
|
||||||
|
for (uint32_t index0 = start; index0 < stop; index0++) {
|
||||||
|
|
||||||
|
PRAGMA_OMP_SIMD
|
||||||
|
for (uint32_t index1 = 0; index1 < nSize1; index1++) {
|
||||||
|
for (uint32_t index2 = 0; index2 < nSize2; index2++) {
|
||||||
|
auto rX = x + (xStrides[0] * index0 + xStrides[1] * index1 + xStrides[2] * index2);
|
||||||
|
auto rY = y + (yStrides[0] * index0 + yStrides[1] * index1 + yStrides[2] * index2);
|
||||||
|
auto rZ = z + (zStrides[0] * index0 + zStrides[1] * index1 + zStrides[2] * index2);
|
||||||
|
*rZ = OpType::op(*rX, *rY);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
else if (kindOfLoop == nd4j::LoopKind::BROADCAST_4D) {
|
||||||
|
|
||||||
|
int xRank = shape::rank(xShapeInfo);
|
||||||
|
int yRank = shape::rank(yShapeInfo);
|
||||||
|
|
||||||
|
auto xStrides = shape::stride(xShapeInfo);
|
||||||
|
auto zStrides = shape::stride(zShapeInfo);
|
||||||
|
|
||||||
|
Nd4jLong yStrides[4] = { 0,0,0,0 };
|
||||||
|
nd4j::ShapeUtils::copyCertainStridesFromShapeInfo(yShapeInfo, xRank, dimensionLength, dimension, yStrides);
|
||||||
|
|
||||||
|
uint32_t nSize1 = shape::sizeAt(zShapeInfo, 1);
|
||||||
|
uint32_t nSize2 = shape::sizeAt(zShapeInfo, 2);
|
||||||
|
uint32_t nSize3 = shape::sizeAt(zShapeInfo, 3);
|
||||||
|
|
||||||
|
for (uint32_t i = start; i < stop; i++) {
|
||||||
|
|
||||||
|
uint32_t index0 = i / nSize1;
|
||||||
|
uint32_t index1 = i % nSize1;
|
||||||
|
|
||||||
|
PRAGMA_OMP_SIMD
|
||||||
|
for (uint32_t index2 = 0; index2 < nSize2; index2++) {
|
||||||
|
for (uint32_t index3 = 0; index3 < nSize3; index3++) {
|
||||||
|
auto rX = x + (xStrides[0] * index0 + xStrides[1] * index1 + xStrides[2] * index2 + xStrides[3] * index3);
|
||||||
|
auto rY = y + (yStrides[0] * index0 + yStrides[1] * index1 + yStrides[2] * index2 + yStrides[3] * index3);
|
||||||
|
auto rZ = z + (zStrides[0] * index0 + zStrides[1] * index1 + zStrides[2] * index2 + zStrides[3] * index3);
|
||||||
|
*rZ = OpType::op(*rX, *rY);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
else if (kindOfLoop == nd4j::LoopKind::BROADCAST_5D) {
|
||||||
|
|
||||||
|
int xRank = shape::rank(xShapeInfo);
|
||||||
|
int yRank = shape::rank(yShapeInfo);
|
||||||
|
|
||||||
|
auto xStrides = shape::stride(xShapeInfo);
|
||||||
|
auto zStrides = shape::stride(zShapeInfo);
|
||||||
|
|
||||||
|
Nd4jLong yStrides[5] = { 0,0,0,0,0 };
|
||||||
|
nd4j::ShapeUtils::copyCertainStridesFromShapeInfo(yShapeInfo, xRank, dimensionLength, dimension, yStrides);
|
||||||
|
|
||||||
|
uint32_t nSize1 = shape::sizeAt(zShapeInfo, 1);
|
||||||
|
uint32_t nSize2 = shape::sizeAt(zShapeInfo, 2);
|
||||||
|
uint32_t nSize3 = shape::sizeAt(zShapeInfo, 3);
|
||||||
|
uint32_t nSize4 = shape::sizeAt(zShapeInfo, 4);
|
||||||
|
|
||||||
|
for (uint32_t i = start; i < stop; i++) {
|
||||||
|
|
||||||
|
uint32_t index0 = i / nSize1;
|
||||||
|
uint32_t index1 = i % nSize1;
|
||||||
|
|
||||||
|
PRAGMA_OMP_SIMD
|
||||||
|
for (uint32_t index2 = 0; index2 < nSize2; index2++) {
|
||||||
|
for (uint32_t index3 = 0; index3 < nSize3; index3++) {
|
||||||
|
for (uint32_t index4 = 0; index4 < nSize4; index4++) {
|
||||||
|
auto rX = x + (xStrides[0] * index0 + xStrides[1] * index1 + xStrides[2] * index2 + xStrides[3] * index3 + xStrides[4] * index4);
|
||||||
|
auto rY = y + (yStrides[0] * index0 + yStrides[1] * index1 + yStrides[2] * index2 + yStrides[3] * index3 + yStrides[4] * index4);
|
||||||
|
auto rZ = z + (zStrides[0] * index0 + zStrides[1] * index1 + zStrides[2] * index2 + zStrides[3] * index3 + zStrides[4] * index4);
|
||||||
|
|
||||||
|
*rZ = OpType::op(*rX, *rY);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
else if(shape::haveSameShapeAndStrides(xTadShapeShapeInfo, yShapeInfo) && shape::haveSameShapeAndStrides(xTadShapeShapeInfo, zTadShapeInfo)) {
|
else if(shape::haveSameShapeAndStrides(xTadShapeShapeInfo, yShapeInfo) && shape::haveSameShapeAndStrides(xTadShapeShapeInfo, zTadShapeInfo)) {
|
||||||
uint tadShapeShapeInfoCast[MAX_RANK];
|
uint tadShapeShapeInfoCast[MAX_RANK];
|
||||||
|
|
|
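The BROADCAST_SCALAR_X specialization commented above ("turns broadcast into series of scalar ops") in a minimal standalone form, with addition standing in for OpType::op and the shapes chosen only for illustration:

#include <vector>
#include <cstdio>

// Each element of x is treated as a scalar applied to one contiguous row of y,
// so the broadcast degenerates into a series of scalar ops over dense rows.
void broadcastScalarX(const std::vector<float>& x, const std::vector<float>& y,
                      std::vector<float>& z, size_t rowLength) {
    for (size_t i = 0; i < x.size(); ++i) {
        const float xi  = x[i];
        const float* oY = y.data() + i * rowLength;
        float*       oZ = z.data() + i * rowLength;
        for (size_t f = 0; f < rowLength; ++f)
            oZ[f] = xi + oY[f];
    }
}

int main() {
    std::vector<float> x = {1.f, 10.f};                     // shape {2,1}, one scalar per row
    std::vector<float> y = {1.f, 2.f, 3.f, 4.f, 5.f, 6.f};  // shape {2,3}
    std::vector<float> z(6);
    broadcastScalarX(x, y, z, 3);
    for (float v : z) printf("%g ", v);                      // 2 3 4 14 15 16
    printf("\n");
    return 0;
}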
@ -73,7 +73,7 @@ Nd4jLong IndexReduce<X, Y>::execScalar(void *vx, Nd4jLong *xShapeInfo, void *vex
|
||||||
auto func = PRAGMA_THREADS_FOR {
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
intermediatery[thread_id] = OpType::startingIndexValue(x);
|
intermediatery[thread_id] = OpType::startingIndexValue(x);
|
||||||
|
|
||||||
for (auto i = start; i < stop; i += increment) {
|
for (auto i = start; i < stop; i++) {
|
||||||
IndexValue<X> curr(x[i], i);
|
IndexValue<X> curr(x[i], i);
|
||||||
intermediatery[thread_id] = OpType::update(intermediatery[thread_id], curr, extraParams);
|
intermediatery[thread_id] = OpType::update(intermediatery[thread_id], curr, extraParams);
|
||||||
}
|
}
|
||||||
|
@ -88,7 +88,7 @@ Nd4jLong IndexReduce<X, Y>::execScalar(void *vx, Nd4jLong *xShapeInfo, void *vex
|
||||||
auto func = PRAGMA_THREADS_FOR {
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
intermediatery[thread_id] = OpType::startingIndexValue(x);
|
intermediatery[thread_id] = OpType::startingIndexValue(x);
|
||||||
|
|
||||||
for (auto i = start; i < stop; i += increment) {
|
for (auto i = start; i < stop; i++) {
|
||||||
auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
|
auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
|
||||||
IndexValue<X> curr(x[offset], i);
|
IndexValue<X> curr(x[offset], i);
|
||||||
intermediatery[thread_id] = OpType::update(intermediatery[thread_id], curr, extraParams);
|
intermediatery[thread_id] = OpType::update(intermediatery[thread_id], curr, extraParams);
|
||||||
|
|
|
@ -75,7 +75,7 @@ namespace functions {
|
||||||
|
|
||||||
auto func = PRAGMA_THREADS_FOR {
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
PRAGMA_OMP_SIMD
|
PRAGMA_OMP_SIMD
|
||||||
for (auto i = start; i < stop; i += increment) {
|
for (auto i = start; i < stop; i++) {
|
||||||
auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
|
auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
|
||||||
z[offset] = OpClass::op(x[offset], y[offset], i, length, rng, extraArguments);
|
z[offset] = OpClass::op(x[offset], y[offset], i, length, rng, extraArguments);
|
||||||
}
|
}
|
||||||
|
@ -93,7 +93,7 @@ namespace functions {
|
||||||
|
|
||||||
auto func = PRAGMA_THREADS_FOR {
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
PRAGMA_OMP_SIMD
|
PRAGMA_OMP_SIMD
|
||||||
for (uint64_t i = start; i < stop; i += increment) {
|
for (uint64_t i = start; i < stop; i++) {
|
||||||
auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
|
auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
|
||||||
auto zOffset = shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ);
|
auto zOffset = shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ);
|
||||||
z[zOffset] = OpClass::op(x[offset], y[offset], i, length, rng, extraArguments);
|
z[zOffset] = OpClass::op(x[offset], y[offset], i, length, rng, extraArguments);
|
||||||
|
@ -111,7 +111,7 @@ namespace functions {
|
||||||
|
|
||||||
auto func = PRAGMA_THREADS_FOR {
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
PRAGMA_OMP_SIMD
|
PRAGMA_OMP_SIMD
|
||||||
for (uint64_t i = start; i < stop; i += increment) {
|
for (uint64_t i = start; i < stop; i++) {
|
||||||
auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
|
auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
|
||||||
auto yOffset = shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY);
|
auto yOffset = shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY);
|
||||||
z[offset] = OpClass::op(x[offset], y[yOffset], i, length, rng, extraArguments);
|
z[offset] = OpClass::op(x[offset], y[yOffset], i, length, rng, extraArguments);
|
||||||
|
@ -129,7 +129,7 @@ namespace functions {
|
||||||
|
|
||||||
auto func = PRAGMA_THREADS_FOR {
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
PRAGMA_OMP_SIMD
|
PRAGMA_OMP_SIMD
|
||||||
for (uint64_t i = start; i < stop; i += increment) {
|
for (uint64_t i = start; i < stop; i++) {
|
||||||
auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
|
auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
|
||||||
auto offset = shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY);
|
auto offset = shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY);
|
||||||
z[offset] = OpClass::op(x[xOffset], y[offset], i, length, rng, extraArguments);
|
z[offset] = OpClass::op(x[xOffset], y[offset], i, length, rng, extraArguments);
|
||||||
|
@ -149,7 +149,7 @@ namespace functions {
|
||||||
|
|
||||||
auto func = PRAGMA_THREADS_FOR {
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
PRAGMA_OMP_SIMD
|
PRAGMA_OMP_SIMD
|
||||||
for (uint64_t i = start; i < stop; i += increment) {
|
for (uint64_t i = start; i < stop; i++) {
|
||||||
auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
|
auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
|
||||||
auto yOffset = shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY);
|
auto yOffset = shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY);
|
||||||
auto zOffset = shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ);
|
auto zOffset = shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ);
|
||||||
|
@ -197,7 +197,7 @@ namespace functions {
|
||||||
else{
|
else{
|
||||||
auto func = PRAGMA_THREADS_FOR {
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
PRAGMA_OMP_SIMD
|
PRAGMA_OMP_SIMD
|
||||||
for (uint64_t i = start; i < stop; i += increment) {
|
for (uint64_t i = start; i < stop; i++) {
|
||||||
auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
|
auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
|
||||||
z[offset] = OpClass::op(x[offset], i, length, rng, extraArguments);
|
z[offset] = OpClass::op(x[offset], i, length, rng, extraArguments);
|
||||||
}
|
}
|
||||||
|
@ -213,7 +213,7 @@ namespace functions {
|
||||||
|
|
||||||
auto func = PRAGMA_THREADS_FOR {
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
PRAGMA_OMP_SIMD
|
PRAGMA_OMP_SIMD
|
||||||
for (uint64_t i = start; i < stop; i += increment) {
|
for (uint64_t i = start; i < stop; i++) {
|
||||||
auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
|
auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
|
||||||
auto zOffset = shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ);
|
auto zOffset = shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ);
|
||||||
z[zOffset] = OpClass::op(x[xOffset], i, length, rng, extraArguments);
|
z[zOffset] = OpClass::op(x[xOffset], i, length, rng, extraArguments);
|
||||||
|
@ -255,7 +255,7 @@ namespace functions {
|
||||||
|
|
||||||
auto func = PRAGMA_THREADS_FOR {
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
PRAGMA_OMP_SIMD
|
PRAGMA_OMP_SIMD
|
||||||
for (uint64_t i = start; i < stop; i += increment) {
|
for (uint64_t i = start; i < stop; i++) {
|
||||||
auto offset = shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ);
|
auto offset = shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ);
|
||||||
z[offset] = OpClass::op(i, length, rng, extraArguments);
|
z[offset] = OpClass::op(i, length, rng, extraArguments);
|
||||||
}
|
}
|
||||||
|
|
|
@ -88,7 +88,7 @@ void Reduce3<X,Z>::execScalar(void *vx, Nd4jLong *xShapeInfo,
|
||||||
|
|
||||||
if (kindOfLoop == nd4j::LoopKind::EWS1) {
|
if (kindOfLoop == nd4j::LoopKind::EWS1) {
|
||||||
auto func = PRAGMA_THREADS_FOR {
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
for (auto i = start; i < stop; i += increment) {
|
for (auto i = start; i < stop; i++) {
|
||||||
intermediate[thread_id] = OpType::update(intermediate[thread_id], OpType::op(x[i], y[i], extraParamsLocal + 3 * thread_id), extraParamsLocal + 3 * thread_id);
|
intermediate[thread_id] = OpType::update(intermediate[thread_id], OpType::op(x[i], y[i], extraParamsLocal + 3 * thread_id), extraParamsLocal + 3 * thread_id);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -98,7 +98,7 @@ void Reduce3<X,Z>::execScalar(void *vx, Nd4jLong *xShapeInfo,
|
||||||
} else if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo)) {
|
} else if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo)) {
|
||||||
|
|
||||||
auto func = PRAGMA_THREADS_FOR {
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
for (auto i = start; i < stop; i += increment) {
|
for (auto i = start; i < stop; i++) {
|
||||||
auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
|
auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
|
||||||
intermediate[thread_id] = OpType::update(intermediate[thread_id], OpType::op(x[offset], y[offset], extraParamsLocal + 3 * thread_id), extraParamsLocal + 3 * thread_id);
|
intermediate[thread_id] = OpType::update(intermediate[thread_id], OpType::op(x[offset], y[offset], extraParamsLocal + 3 * thread_id), extraParamsLocal + 3 * thread_id);
|
||||||
}
|
}
|
||||||
|
@ -110,7 +110,7 @@ void Reduce3<X,Z>::execScalar(void *vx, Nd4jLong *xShapeInfo,
|
||||||
const bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);
|
const bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);
|
||||||
|
|
||||||
auto func = PRAGMA_THREADS_FOR {
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
for (auto i = start; i < stop; i += increment) {
|
for (auto i = start; i < stop; i++) {
|
||||||
auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
|
auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
|
||||||
auto yOffset = shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY);
|
auto yOffset = shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY);
|
||||||
intermediate[thread_id] = OpType::update(intermediate[thread_id], OpType::op(x[xOffset], y[yOffset], extraParamsLocal + 3 * thread_id), extraParamsLocal + 3 * thread_id);
|
intermediate[thread_id] = OpType::update(intermediate[thread_id], OpType::op(x[xOffset], y[yOffset], extraParamsLocal + 3 * thread_id), extraParamsLocal + 3 * thread_id);
|
||||||
|
|
|
@ -158,7 +158,7 @@ namespace functions {
|
||||||
const bool canCast = tadEWS == 1 && tadOrder == 'c' ? false : nd4j::DataTypeUtils::castShapeInfo<uint>(tadShapeShapeInfo, tadShapeShapeInfoCast);
|
const bool canCast = tadEWS == 1 && tadOrder == 'c' ? false : nd4j::DataTypeUtils::castShapeInfo<uint>(tadShapeShapeInfo, tadShapeShapeInfoCast);
|
||||||
|
|
||||||
auto func = PRAGMA_THREADS_FOR {
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
for (auto r = start; r < stop; r += increment) {
|
for (auto r = start; r < stop; r++) {
|
||||||
|
|
||||||
auto tadOffsetForBlock = tadPack.primaryOffsets()[r];
|
auto tadOffsetForBlock = tadPack.primaryOffsets()[r];
|
||||||
auto tx = x + tadOffsetForBlock;
|
auto tx = x + tadOffsetForBlock;
|
||||||
|
|
|
@ -84,7 +84,7 @@ namespace functions {
|
||||||
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
int totalThreads = gridDim.x * blockDim.x;
|
int totalThreads = gridDim.x * blockDim.x;
|
||||||
|
|
||||||
if(xEws > 0 && zEws > 0 && xOrder == zOrder) {
|
if(xEws > 0 && zEws > 0 && xOrder == zOrder && xOrder == 'c') {
|
||||||
|
|
||||||
for (int i = tid; i < length; i += totalThreads)
|
for (int i = tid; i < length; i += totalThreads)
|
||||||
z[i * zEws] = OpType::op(x[i * xEws], params);
|
z[i * zEws] = OpType::op(x[i * xEws], params);
|
||||||
|
|
|
@ -89,7 +89,7 @@ namespace functions {
|
||||||
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
int totalThreads = gridDim.x * blockDim.x;
|
int totalThreads = gridDim.x * blockDim.x;
|
||||||
|
|
||||||
if(xEws > 0 && zEws > 0 && xOrder == zOrder) {
|
if(xEws > 0 && zEws > 0 && xOrder == zOrder && xOrder == 'c') {
|
||||||
|
|
||||||
for (int i = tid; i < length; i += totalThreads)
|
for (int i = tid; i < length; i += totalThreads)
|
||||||
z[i * zEws] = OpType::op(x[i * xEws], params);
|
z[i * zEws] = OpType::op(x[i * xEws], params);
|
||||||
|
|
|
@ -97,7 +97,7 @@ namespace functions {
|
||||||
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
int totalThreads = gridDim.x * blockDim.x;
|
int totalThreads = gridDim.x * blockDim.x;
|
||||||
|
|
||||||
if(xEws > 0 && zEws > 0 && xOrder == zOrder) {
|
if(xEws > 0 && zEws > 0 && xOrder == zOrder && xOrder == 'c') {
|
||||||
|
|
||||||
for (Nd4jLong i = tid; i < length; i += totalThreads)
|
for (Nd4jLong i = tid; i < length; i += totalThreads)
|
||||||
z[i * zEws] = OpType::op(x[i * xEws], params);
|
z[i * zEws] = OpType::op(x[i * xEws], params);
|
||||||
|
|
|
@ -87,7 +87,7 @@ namespace functions {
|
||||||
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
int totalThreads = gridDim.x * blockDim.x;
|
int totalThreads = gridDim.x * blockDim.x;
|
||||||
|
|
||||||
if(xEws > 0 && zEws > 0 && xOrder == zOrder) {
|
if(xEws > 0 && zEws > 0 && xOrder == zOrder && xOrder == 'c') {
|
||||||
|
|
||||||
for (int i = tid; i < length; i += totalThreads)
|
for (int i = tid; i < length; i += totalThreads)
|
||||||
z[i * zEws] = OpType::op(x[i * xEws], params);
|
z[i * zEws] = OpType::op(x[i * xEws], params);
|
||||||
|
|
|
@ -89,7 +89,7 @@ namespace functions {
|
||||||
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
int totalThreads = gridDim.x * blockDim.x;
|
int totalThreads = gridDim.x * blockDim.x;
|
||||||
|
|
||||||
if(xEws > 0 && zEws > 0 && xOrder == zOrder) {
|
if(xEws > 0 && zEws > 0 && xOrder == zOrder && xOrder == 'c') {
|
||||||
|
|
||||||
for (int i = tid; i < length; i += totalThreads)
|
for (int i = tid; i < length; i += totalThreads)
|
||||||
z[i * zEws] = OpType::op(x[i * xEws], params);
|
z[i * zEws] = OpType::op(x[i * xEws], params);
|
||||||
|
|
|
@ -81,7 +81,7 @@ namespace nd4j {
|
||||||
|
|
||||||
// now we actually apply quantization
|
// now we actually apply quantization
|
||||||
auto func = PRAGMA_THREADS_FOR {
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
for (auto e = start; e < stop; e += increment) {
|
for (auto e = start; e < stop; e++) {
|
||||||
rz[e] = static_cast<char>(nd4j::math::nd4j_round<float, char>( 1.0f * static_cast<float>(x[e]) / nd4j::math::nd4j_max<float>(amax, amin) * max_byte));
|
rz[e] = static_cast<char>(nd4j::math::nd4j_round<float, char>( 1.0f * static_cast<float>(x[e]) / nd4j::math::nd4j_max<float>(amax, amin) * max_byte));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
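The per-element quantization rule from this loop as a standalone sketch: each value is scaled by the larger of the two absolute bounds and rounded into the signed-byte range. maxByte = 127 is an assumption matching an int8 target, and amax/amin are assumed to be absolute bounds computed over the buffer beforehand.

#include <cmath>
#include <algorithm>
#include <cstdint>

// Quantize one float into a signed byte: scale by the dominant absolute bound, then round.
int8_t quantizeElement(float x, float amax, float amin, float maxByte = 127.0f) {
    const float scale = std::max(amax, amin);
    return static_cast<int8_t>(std::lround(x / scale * maxByte));
}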
@ -177,7 +177,7 @@ PRAGMA_OMP_ATOMIC_ARGS(write)
|
||||||
int flimit = limit + 4;
|
int flimit = limit + 4;
|
||||||
|
|
||||||
auto func = PRAGMA_THREADS_FOR {
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
for (auto e = start; e < stop; e += increment) {
|
for (auto e = start; e < stop; e++) {
|
||||||
int el = x[e];
|
int el = x[e];
|
||||||
int ael = nd4j::math::nd4j_abs<int>(el) - 1;
|
int ael = nd4j::math::nd4j_abs<int>(el) - 1;
|
||||||
z[ael] += el > 0 ? static_cast<T>(threshold) : static_cast<T>(-threshold);
|
z[ael] += el > 0 ? static_cast<T>(threshold) : static_cast<T>(-threshold);
|
||||||
|
@ -202,7 +202,7 @@ PRAGMA_OMP_ATOMIC_ARGS(write)
|
||||||
auto z = reinterpret_cast<T *>(dz);
|
auto z = reinterpret_cast<T *>(dz);
|
||||||
|
|
||||||
auto func = PRAGMA_THREADS_FOR {
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
for (auto i = start; i < stop; i += increment) {
|
for (auto i = start; i < stop; i++) {
|
||||||
z[i] = static_cast<T>(static_cast<float>(x[i]));
|
z[i] = static_cast<T>(static_cast<float>(x[i]));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
|
@ -147,6 +147,9 @@ namespace nd4j {
|
||||||
// returns TRUE if this op allows in-place execution
|
// returns TRUE if this op allows in-place execution
|
||||||
bool allowsInplace();
|
bool allowsInplace();
|
||||||
|
|
||||||
|
// this method allows you to enable/disable inplace call for a given op
|
||||||
|
void allowInplace(bool reallyAllow);
|
||||||
|
|
||||||
// this method returns opNum (applicable for legacy XYZ ops only)
|
// this method returns opNum (applicable for legacy XYZ ops only)
|
||||||
int getOpNum();
|
int getOpNum();
|
||||||
|
|
||||||
|
|
|
@ -27,12 +27,10 @@ namespace nd4j {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
OP_IMPL(identity, 1, 1, true) {
|
OP_IMPL(identity, 1, 1, true) {
|
||||||
auto first = INPUT_VARIABLE(0);
|
auto first = INPUT_VARIABLE(0);
|
||||||
auto z = this->getZ(block);
|
auto z = OUTPUT_VARIABLE(0);
|
||||||
|
|
||||||
// just for lulz
|
if (!block.isInplace())
|
||||||
first->applyTransform(nd4j::transform::Identity, *z);
|
first->applyTransform(nd4j::transform::Identity, *z);
|
||||||
|
|
||||||
STORE_RESULT(*z);
|
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -60,8 +58,8 @@ namespace nd4j {
|
||||||
DECLARE_TYPES(identity_bp) {
|
DECLARE_TYPES(identity_bp) {
|
||||||
getOpDescriptor()
|
getOpDescriptor()
|
||||||
->setAllowedInputTypes(0, DataType::ANY)
|
->setAllowedInputTypes(0, DataType::ANY)
|
||||||
->setAllowedInputTypes(1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF})
|
->setAllowedInputTypes(1, {ALL_FLOATS})
|
||||||
->setAllowedOutputTypes(0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF});
|
->setAllowedOutputTypes(0, {ALL_FLOATS});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -29,142 +29,128 @@
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
|
||||||
CUSTOM_OP_IMPL(matmul, 2, 1, false, 0, -2) {
|
//////////////////////////////////////////////////////////////////////
|
||||||
auto x = INPUT_VARIABLE(0);
|
CUSTOM_OP_IMPL(matmul, 2, 1, false, 0, -2) {
|
||||||
auto y = INPUT_VARIABLE(1);
|
|
||||||
auto z = OUTPUT_VARIABLE(0);
|
|
||||||
|
|
||||||
const int iSize = (int) block.getIArguments()->size();
|
auto x = INPUT_VARIABLE(0);
|
||||||
int transX = iSize > 0 ? INT_ARG(0) : 0;
|
auto y = INPUT_VARIABLE(1);
|
||||||
int transY = iSize > 1 ? INT_ARG(1) : 0;
|
auto z = OUTPUT_VARIABLE(0);
|
||||||
const int transZ = iSize > 2 ? INT_ARG(2) : 0;
|
|
||||||
|
|
||||||
const int xRank = x->rankOf();
|
const int iSize = (int) block.getIArguments()->size();
|
||||||
const int yRank = y->rankOf();
|
int transX = iSize > 0 ? INT_ARG(0) : 0;
|
||||||
const int zRank = z->rankOf();
|
int transY = iSize > 1 ? INT_ARG(1) : 0;
|
||||||
|
const int transZ = iSize > 2 ? INT_ARG(2) : 0;
|
||||||
|
|
||||||
if (transZ) {
|
const int xRank = x->rankOf();
|
||||||
x = INPUT_VARIABLE(1);
|
const int yRank = y->rankOf();
|
||||||
y = INPUT_VARIABLE(0);
|
const int zRank = z->rankOf();
|
||||||
bool temp = transX;
|
|
||||||
transX = !transY;
|
|
||||||
transY = !temp;
|
|
||||||
}
|
|
||||||
|
|
||||||
const int xLastDim = transX ? -2 : -1;
|
if (transZ) {
|
||||||
const int yLastDim = transY ? -2 : -1;
|
x = INPUT_VARIABLE(1);
|
||||||
const int xLastButOneDim = transX ? -1 : -2;
|
y = INPUT_VARIABLE(0);
|
||||||
const int yLastButOneDim = transY ? -1 : -2;
|
bool temp = transX;
|
||||||
|
transX = !transY;
|
||||||
|
transY = !temp;
|
||||||
|
}
|
||||||
|
|
||||||
// ******* input validation ******* //
|
const int xLastDim = transX ? -2 : -1;
|
||||||
REQUIRE_TRUE(xRank > 0 && yRank > 0, 0,
|
const int yLastDim = transY ? -2 : -1;
|
||||||
"MATMUL OP: input arrays must have rank bigger than 0 (should not be scalars), but got instead: x rank = %i, y rank = %i !",
|
const int xLastButOneDim = transX ? -1 : -2;
|
||||||
xRank, yRank);
|
const int yLastButOneDim = transY ? -1 : -2;
|
||||||
|
|
||||||
if (xRank == 1 && yRank == 1) { // dot case, output is scalar (or vector with length = 1)
|
// ******* input validation ******* //
|
||||||
REQUIRE_TRUE(x->lengthOf() == y->lengthOf(), 0,
|
REQUIRE_TRUE(xRank > 0 && yRank > 0, 0, "MATMUL OP: input arrays must have rank bigger than 0 (should not be scalars), but got instead: x rank = %i, y rank = %i !", xRank, yRank);
|
||||||
"MATMUL OP: since input arrays are vectors they must have the same length, but got x length = %i, y length = %i !",
|
|
||||||
x->lengthOf(), y->lengthOf());
|
|
||||||
} else if (xRank == 1 && yRank == 2) { // vector x matrix, i.e. [4] x [4,5] = [5], output is vector
|
|
||||||
REQUIRE_TRUE(x->lengthOf() == y->sizeAt(yLastButOneDim), 0,
|
|
||||||
"MATMUL OP: input arrays have inconsistent shapes for vector-matrix product: x %s, y %s !",
|
|
||||||
ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str());
|
|
||||||
} else if (xRank == 2 && yRank == 1) { // matrix x vector , i.e. [4,5] x [5] = [4], output is vector
|
|
||||||
REQUIRE_TRUE(x->sizeAt(xLastDim) == y->lengthOf(), 0,
|
|
||||||
"MATMUL OP: input arrays have inconsistent shapes for matrix-vector product: x %s, y %s !",
|
|
||||||
ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str());
|
|
||||||
} else {
|
|
||||||
REQUIRE_TRUE(xRank == yRank && yRank == zRank, 0,
|
|
||||||
"MATMUL OP: input and output arrays must have the same rank, but got instead: x rank = %i, y rank = %i, z rank = %i !",
|
|
||||||
xRank, yRank, zRank);
|
|
||||||
REQUIRE_TRUE(x->sizeAt(xLastDim) == y->sizeAt(yLastButOneDim) &&
|
|
||||||
x->sizeAt(xLastButOneDim) == z->sizeAt(-2) && y->sizeAt(yLastDim) == z->sizeAt(-1), 0,
|
|
||||||
"MATMUL OP: input/output arrays have inconsistent shapes for matrix product: x %s, y %s, z %s !",
|
|
||||||
ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str(),
|
|
||||||
ShapeUtils::shapeAsString(z).c_str());
|
|
||||||
|
|
||||||
if (xRank > 2) // outer dims must be the same
|
if (xRank == 1 && yRank == 1) { // dot case, output is scalar (or vector with length = 1)
|
||||||
for (int i = 0; i < xRank - 2; ++i)
|
REQUIRE_TRUE(x->lengthOf() == y->lengthOf(), 0, "MATMUL OP: since input arrays are vectors they must have the same length, but got x length = %i, y length = %i !", x->lengthOf(), y->lengthOf());
|
||||||
REQUIRE_TRUE(x->sizeAt(i) == y->sizeAt(i) && y->sizeAt(i) == z->sizeAt(i), 0,
|
} else if (xRank == 1 && yRank == 2) { // vector x matrix, i.e. [4] x [4,5] = [5], output is vector
|
||||||
"MATMUL OP: input/output arrays have inconsistent shapes for matrix product: x %s, y %s, z %s !",
|
REQUIRE_TRUE(x->lengthOf() == y->sizeAt(yLastButOneDim), 0, "MATMUL OP: input arrays have inconsistent shapes for vector-matrix product: x %s, y %s !", ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str());
|
||||||
ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str(),
|
} else if (xRank == 2 && yRank == 1) { // matrix x vector , i.e. [4,5] x [5] = [4], output is vector
|
||||||
ShapeUtils::shapeAsString(z).c_str());
|
REQUIRE_TRUE(x->sizeAt(xLastDim) == y->lengthOf(), 0, "MATMUL OP: input arrays have inconsistent shapes for matrix-vector product: x %s, y %s !", ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str());
|
||||||
}
|
} else {
|
||||||
// ******* end of input validation ******* //
|
REQUIRE_TRUE(xRank == yRank && yRank == zRank, 0, "MATMUL OP: input and output arrays must have the same rank, but got instead: x rank = %i, y rank = %i, z rank = %i !", xRank, yRank, zRank);
|
||||||
|
REQUIRE_TRUE(x->sizeAt(xLastDim) == y->sizeAt(yLastButOneDim) && x->sizeAt(xLastButOneDim) == z->sizeAt(-2) && y->sizeAt(yLastDim) == z->sizeAt(-1), 0, "MATMUL OP: input/output arrays have inconsistent shapes for matrix product: x %s, y %s, z %s !", ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str(), ShapeUtils::shapeAsString(z).c_str());
|
||||||
|
|
||||||
MmulHelper::matmul(x, y, z, transX, transY);
|
if (xRank > 2) // outer dims must be the same
|
||||||
|
for (int i = 0; i < xRank - 2; ++i)
|
||||||
|
REQUIRE_TRUE(x->sizeAt(i) == y->sizeAt(i) && y->sizeAt(i) == z->sizeAt(i), 0, "MATMUL OP: input/output arrays have inconsistent shapes for matrix product: x %s, y %s, z %s !", ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str(), ShapeUtils::shapeAsString(z).c_str());
|
||||||
|
}
|
||||||
|
// ******* end of input validation ******* //
|
||||||
|
|
||||||
return Status::OK();
|
MmulHelper::matmul(x, y, z, transX, transY);
|
||||||
}
|
|
||||||
|
|
||||||
DECLARE_SYN(mMul, matmul);
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
DECLARE_SYN(mmul, matmul);
|
DECLARE_SYN(mMul, matmul);
|
||||||
|
|
||||||
DECLARE_SYN(gemm, matmul);
|
DECLARE_SYN(mmul, matmul);
|
||||||
|
|
||||||
DECLARE_SYN(gemv, matmul);
|
DECLARE_SYN(gemm, matmul);
|
||||||
|
|
||||||
DECLARE_SYN(dot, matmul);
|
DECLARE_SYN(gemv, matmul);
|
||||||
|
|
||||||
|
DECLARE_SYN(dot, matmul);
|
||||||
|
|
||||||
DECLARE_SHAPE_FN(matmul) {
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
DECLARE_SHAPE_FN(matmul) {
|
||||||
|
|
||||||
auto xShapeInfo = inputShape->at(0);
|
auto xShapeInfo = inputShape->at(0);
|
||||||
auto yShapeInfo = inputShape->at(1);
|
auto yShapeInfo = inputShape->at(1);
|
||||||
|
|
||||||
const int iSize = (int) block.getIArguments()->size();
|
const int iSize = (int) block.getIArguments()->size();
|
||||||
int transX = iSize > 0 ? INT_ARG(0) : 0;
|
int transX = iSize > 0 ? INT_ARG(0) : 0;
|
||||||
int transY = iSize > 1 ? INT_ARG(1) : 0;
|
int transY = iSize > 1 ? INT_ARG(1) : 0;
|
||||||
const int transZ = iSize > 2 ? INT_ARG(2) : 0;
|
const int transZ = iSize > 2 ? INT_ARG(2) : 0;
|
||||||
|
|
||||||
REQUIRE_TRUE(xShapeInfo[0] > 0 && yShapeInfo[0] > 0, 0,
|
REQUIRE_TRUE(xShapeInfo[0] > 0 && yShapeInfo[0] > 0, 0,
|
||||||
"MATMUL OP: input arrays must have rank bigger than 0 (should not be scalars), but got instead: x rank = %i, y rank = %i !",
|
"MATMUL OP: input arrays must have rank bigger than 0 (should not be scalars), but got instead: x rank = %i, y rank = %i !",
|
||||||
xShapeInfo[0], yShapeInfo[0]);
|
xShapeInfo[0], yShapeInfo[0]);
|
||||||
|
|
||||||
if (transZ) {
|
if (transZ) {
|
||||||
xShapeInfo = inputShape->at(1);
|
xShapeInfo = inputShape->at(1);
|
||||||
yShapeInfo = inputShape->at(0);
|
yShapeInfo = inputShape->at(0);
|
||||||
bool temp = transX;
|
bool temp = transX;
|
||||||
transX = !transY;
|
transX = !transY;
|
||||||
transY = !temp;
|
transY = !temp;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto zShapeOnly = ShapeUtils::evalShapeForMatmul(xShapeInfo, yShapeInfo, transX, transY);
|
auto zShapeOnly = ShapeUtils::evalShapeForMatmul(xShapeInfo, yShapeInfo, transX, transY);
|
||||||
|
|
||||||
auto dtypeX = ArrayOptions::dataType(xShapeInfo);
|
auto dtypeX = ArrayOptions::dataType(xShapeInfo);
|
||||||
auto dtypeY = ArrayOptions::dataType(yShapeInfo);
|
auto dtypeY = ArrayOptions::dataType(yShapeInfo);
|
||||||
|
|
||||||
auto xOrder = shape::order(xShapeInfo);
|
auto xOrder = shape::order(xShapeInfo);
|
||||||
auto yOrder = shape::order(yShapeInfo);
|
auto yOrder = shape::order(yShapeInfo);
|
||||||
auto zOrder = xOrder == 'c' && yOrder == 'c' ? 'c' : 'f';
|
auto zOrder = xOrder == 'c' && yOrder == 'c' ? 'c' : 'f';
|
||||||
|
|
||||||
// we just pick the higher data type out of X and Y
|
// we just pick the higher data type out of X and Y
|
||||||
auto dtypeZ = dtypeX > dtypeY ? dtypeX : dtypeY;
|
auto dtypeZ = dtypeX > dtypeY ? dtypeX : dtypeY;
|
||||||
|
|
||||||
auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(dtypeZ, zOrder, zShapeOnly);
|
auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(dtypeZ, zOrder, zShapeOnly);
|
||||||
return SHAPELIST(newShape);
|
return SHAPELIST(newShape);
|
||||||
}
|
}
|
||||||
|
|
||||||
DECLARE_TYPES(matmul) {
|
//////////////////////////////////////////////////////////////////////
|
||||||
getOpDescriptor()
|
DECLARE_TYPES(matmul) {
|
||||||
->setAllowedInputTypes(0, {ALL_FLOATS})
|
getOpDescriptor()
|
||||||
->setAllowedInputTypes(1, {ALL_FLOATS})
|
->setAllowedInputTypes(0, {ALL_FLOATS})
|
||||||
->setAllowedOutputTypes(0, {ALL_FLOATS});
|
->setAllowedInputTypes(1, {ALL_FLOATS})
|
||||||
}
|
->setAllowedOutputTypes(0, {ALL_FLOATS});
|
||||||
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
CUSTOM_OP_IMPL(matmul_bp, 3, 2, false, 0, -2) {
|
||||||
|
auto x = INPUT_VARIABLE(0);
|
||||||
|
auto y = INPUT_VARIABLE(1);
|
||||||
|
auto eps = INPUT_VARIABLE(2);
|
||||||
|
auto dldx = OUTPUT_VARIABLE(0);
|
||||||
|
auto dldy = OUTPUT_VARIABLE(1);
|
||||||
|
|
||||||
CUSTOM_OP_IMPL(matmul_bp, 3, 2, false, 0, -2) {
|
const int iSize = (int) block.getIArguments()->size();
|
||||||
auto x = INPUT_VARIABLE(0);
|
int transX = iSize > 0 ? INT_ARG(0) : 0;
|
||||||
auto y = INPUT_VARIABLE(1);
|
int transY = iSize > 1 ? INT_ARG(1) : 0;
|
||||||
auto eps = INPUT_VARIABLE(2);
|
const int transZ = iSize > 2 ? INT_ARG(2) : 0;
|
||||||
auto dldx = OUTPUT_VARIABLE(0);
|
|
||||||
auto dldy = OUTPUT_VARIABLE(1);
|
|
||||||
|
|
||||||
const int iSize = (int) block.getIArguments()->size();
|
|
||||||
int transX = iSize > 0 ? INT_ARG(0) : 0;
|
|
||||||
int transY = iSize > 1 ? INT_ARG(1) : 0;
|
|
||||||
const int transZ = iSize > 2 ? INT_ARG(2) : 0;
|
|
||||||
|
|
||||||
/*
|
/*
|
||||||
In: x=[a,b], y=[b,c]
|
In: x=[a,b], y=[b,c]
|
||||||
|
@ -177,34 +163,35 @@ F F T [a,b] [b,c] [c,a] [c,a]
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
|
||||||
nd4j::ops::matmul op;
|
nd4j::ops::matmul op;
|
||||||
op.execute({eps, y}, {dldx}, {}, {transZ, !transY, transX}, {});
|
op.execute({eps, y}, {dldx}, {}, {transZ, !transY, transX}, {});
|
||||||
op.execute({x, eps}, {dldy}, {}, {!transX, transZ, transY}, {});
|
op.execute({x, eps}, {dldy}, {}, {!transX, transZ, transY}, {});
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
DECLARE_SHAPE_FN(matmul_bp) {
|
||||||
|
Nd4jLong *xShapeInfo;
|
||||||
|
Nd4jLong *yShapeInfo;
|
||||||
|
|
||||||
DECLARE_SHAPE_FN(matmul_bp) {
|
COPY_SHAPE(inputShape->at(0), xShapeInfo);
|
||||||
Nd4jLong *xShapeInfo;
|
COPY_SHAPE(inputShape->at(1), yShapeInfo);
|
||||||
Nd4jLong *yShapeInfo;
|
|
||||||
|
|
||||||
COPY_SHAPE(inputShape->at(0), xShapeInfo);
|
return SHAPELIST(CONSTANT(xShapeInfo), CONSTANT(yShapeInfo));
|
||||||
COPY_SHAPE(inputShape->at(1), yShapeInfo);
|
}
|
||||||
|
|
||||||
return SHAPELIST(CONSTANT(xShapeInfo), CONSTANT(yShapeInfo));
|
//////////////////////////////////////////////////////////////////////
|
||||||
}
|
DECLARE_TYPES(matmul_bp) {
|
||||||
|
getOpDescriptor()
|
||||||
|
->setAllowedInputTypes(0, {ALL_FLOATS})
|
||||||
|
->setAllowedInputTypes(1, {ALL_FLOATS})
|
||||||
|
->setAllowedInputTypes(2, {ALL_FLOATS})
|
||||||
|
->setAllowedOutputTypes(0, {ALL_FLOATS})
|
||||||
|
->setAllowedOutputTypes(1, {ALL_FLOATS});
|
||||||
|
}
|
||||||
|
|
||||||
DECLARE_TYPES(matmul_bp) {
|
}
|
||||||
getOpDescriptor()
|
|
||||||
->setAllowedInputTypes(0, {ALL_FLOATS})
|
|
||||||
->setAllowedInputTypes(1, {ALL_FLOATS})
|
|
||||||
->setAllowedInputTypes(2, {ALL_FLOATS})
|
|
||||||
->setAllowedOutputTypes(0, {ALL_FLOATS})
|
|
||||||
->setAllowedOutputTypes(1, {ALL_FLOATS});
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -21,70 +21,174 @@
|
||||||
#include <op_boilerplate.h>
|
#include <op_boilerplate.h>
|
||||||
#if NOT_EXCLUDED(OP_tensormmul)
|
#if NOT_EXCLUDED(OP_tensormmul)
|
||||||
|
|
||||||
|
#include <numeric>
|
||||||
#include <helpers/ShapeUtils.h>
|
#include <helpers/ShapeUtils.h>
|
||||||
#include <ops/declarable/CustomOperations.h>
|
#include <ops/declarable/CustomOperations.h>
|
||||||
#include <MmulHelper.h>
|
#include <MmulHelper.h>
|
||||||
|
|
||||||
|
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
CUSTOM_OP_IMPL(tensormmul, 2, 1, false, 0, -1) {
|
|
||||||
auto a = INPUT_VARIABLE(0);
|
|
||||||
auto b = INPUT_VARIABLE(1);
|
|
||||||
|
|
||||||
auto c = OUTPUT_VARIABLE(0); //
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
CUSTOM_OP_IMPL(tensormmul, 2, 1, false, 0, -1) {
|
||||||
|
|
||||||
REQUIRE_TRUE(a->dataType() == b->dataType(), 0, "tensormmul: A, B and C data types must be the same");
|
auto a = INPUT_VARIABLE(0);
|
||||||
|
auto b = INPUT_VARIABLE(1);
|
||||||
|
|
||||||
// building axes
|
auto c = OUTPUT_VARIABLE(0);
|
||||||
int axe0_size = INT_ARG(0);
|
|
||||||
int axe1_size = INT_ARG(axe0_size+1);
|
|
||||||
std::vector<int> axes_0(axe0_size), axes_1(axe1_size);
|
|
||||||
for (int e = 0; e < axe0_size; e++)
|
|
||||||
axes_0[e] = (int) INT_ARG(e+1);
|
|
||||||
|
|
||||||
for (int e = 0; e < axe1_size; e++)
|
REQUIRE_TRUE(a->dataType() == b->dataType(), 0, "tensormmul: A, B and C data types must be the same");
|
||||||
axes_1[e] = (int) INT_ARG(e + axe0_size + 2);
|
|
||||||
|
|
||||||
nd4j_verbose("axe0: %i; axe1: %i;\n", axes_0.size(), axes_1.size());
|
// building axes
|
||||||
|
int axe0_size = INT_ARG(0);
|
||||||
|
int axe1_size = INT_ARG(axe0_size+1);
|
||||||
|
std::vector<int> axes_0(axe0_size), axes_1(axe1_size);
|
||||||
|
for (int e = 0; e < axe0_size; e++)
|
||||||
|
axes_0[e] = (int)INT_ARG(e + 1);
|
||||||
|
|
||||||
MmulHelper::tensorDot(a, b, c, axes_0, axes_1);
|
for (int e = 0; e < axe1_size; e++)
|
||||||
return Status::OK();
|
axes_1[e] = (int)INT_ARG(e + axe0_size + 2);
|
||||||
}
|
|
||||||
DECLARE_SYN(tensordot, tensormmul);
|
|
||||||
|
|
||||||
|
nd4j_verbose("axe0: %i; axe1: %i;\n", axes_0.size(), axes_1.size());
|
||||||
|
|
||||||
DECLARE_SHAPE_FN(tensormmul) {
|
MmulHelper::tensorDot(a, b, c, axes_0, axes_1);
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
DECLARE_SYN(tensordot, tensormmul);
|
||||||
|
|
||||||
auto aShapeInfo = inputShape->at(0);
|
////////////////////////////////////////////////////////////////////////
|
||||||
auto bShapeInfo = inputShape->at(1);
|
DECLARE_SHAPE_FN(tensormmul) {
|
||||||
|
|
||||||
REQUIRE_TRUE(ArrayOptions::dataType(aShapeInfo) == ArrayOptions::dataType(bShapeInfo), 0, "tensormmul: A and B data types must be the same");
|
auto aShapeInfo = inputShape->at(0);
|
||||||
|
auto bShapeInfo = inputShape->at(1);
|
||||||
|
|
||||||
// building axes
|
REQUIRE_TRUE(ArrayOptions::dataType(aShapeInfo) == ArrayOptions::dataType(bShapeInfo), 0, "tensormmul: A and B data types must be the same");
|
||||||
int axe0_size = INT_ARG(0);
|
|
||||||
int axe1_size = INT_ARG(axe0_size+1);
|
|
||||||
std::vector<int> axes_0(axe0_size), axes_1(axe1_size);
|
|
||||||
for (int e = 0; e < axe0_size; e++)
|
|
||||||
axes_0[e] = (int) INT_ARG(e+1);
|
|
||||||
|
|
||||||
for (int e = 0; e < axe1_size; e++)
|
// building axes
|
||||||
axes_1[e] = (int) INT_ARG(e + axe0_size + 2);
|
int axe0_size = INT_ARG(0);
|
||||||
|
int axe1_size = INT_ARG(axe0_size+1);
|
||||||
|
std::vector<int> axes_0(axe0_size), axes_1(axe1_size);
|
||||||
|
for (int e = 0; e < axe0_size; e++)
|
||||||
|
axes_0[e] = (int) INT_ARG(e+1);
|
||||||
|
|
||||||
// evaluate shapes
|
for (int e = 0; e < axe1_size; e++)
|
||||||
std::vector<int> permutAt, permutBt;
|
axes_1[e] = (int) INT_ARG(e + axe0_size + 2);
|
||||||
std::vector<Nd4jLong> shapeAt, shapeBt;
|
|
||||||
auto outShape = nd4j::ShapeUtils::evalShapeForTensorDot(aShapeInfo, bShapeInfo, axes_0, axes_1, permutAt, permutBt, shapeAt, shapeBt);
|
|
||||||
|
|
||||||
return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(ArrayOptions::dataType(aShapeInfo), 'c', outShape)));
|
// evaluate shapes
|
||||||
}
|
std::vector<int> permutAt, permutBt;
|
||||||
|
std::vector<Nd4jLong> shapeAt, shapeBt;
|
||||||
|
auto outShape = nd4j::ShapeUtils::evalShapeForTensorDot(aShapeInfo, bShapeInfo, axes_0, axes_1, permutAt, permutBt, shapeAt, shapeBt);
|
||||||
|
|
||||||
DECLARE_TYPES(tensormmul) {
|
return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(ArrayOptions::dataType(aShapeInfo), 'c', outShape)));
|
||||||
getOpDescriptor()
|
}
|
||||||
->setAllowedInputTypes(0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF})
|
|
||||||
->setAllowedInputTypes(1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF})
|
////////////////////////////////////////////////////////////////////////
|
||||||
->setAllowedOutputTypes(0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF});
|
DECLARE_TYPES(tensormmul) {
|
||||||
}
|
getOpDescriptor()
|
||||||
|
->setAllowedInputTypes(0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF})
|
||||||
|
->setAllowedInputTypes(1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF})
|
||||||
|
->setAllowedOutputTypes(0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF});
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
CUSTOM_OP_IMPL(tensormmul_bp, 3, 2, false, 0, -1) {
|
||||||
|
|
||||||
|
auto A = INPUT_VARIABLE(0);
|
||||||
|
auto B = INPUT_VARIABLE(1);
|
||||||
|
|
||||||
|
auto dLdC = INPUT_VARIABLE(2);
|
||||||
|
|
||||||
|
auto dLdA = OUTPUT_VARIABLE(0);
|
||||||
|
auto dLdB = OUTPUT_VARIABLE(1);
|
||||||
|
|
||||||
|
REQUIRE_TRUE( (A->dataType() == B->dataType() && (dLdC->dataType() == A->dataType())), 0, "tensormmul_bp: A, B and dLdC data types must be the same");
|
||||||
|
|
||||||
|
int axe0Size = INT_ARG(0);
|
||||||
|
int axe1Size = INT_ARG(axe0Size + 1);
|
||||||
|
|
||||||
|
auto Arank = A->rankOf();
|
||||||
|
auto Brank = B->rankOf();
|
||||||
|
auto dLdCrank = dLdC->rankOf();
|
||||||
|
|
||||||
|
REQUIRE_TRUE((Arank >= axe0Size), 0, "tensormmul_bp: A rank must be the higher or same as input axes 0");
|
||||||
|
|
||||||
|
REQUIRE_TRUE((Brank >= axe1Size), 0, "tensormmul_bp: B rank must be the higher or same as input axes 1");
|
||||||
|
|
||||||
|
// building axes
|
||||||
|
std::vector<int> axes0(axe0Size), axes1(axe1Size);
|
||||||
|
for (uint e = 0; e < axe0Size; e++)
|
||||||
|
axes0[e] = (int)INT_ARG(e + 1);
|
||||||
|
for (uint e = 0; e < axe1Size; e++)
|
||||||
|
axes1[e] = (int)INT_ARG(e + axe0Size + 2);
|
||||||
|
|
||||||
|
std::vector<int> permutAt, permutBt;
|
||||||
|
std::vector<Nd4jLong> shapeAt, shapeBt;
|
||||||
|
|
||||||
|
ShapeUtils::evalShapeForTensorDot(A, B, axes0, axes1, permutAt, permutBt, shapeAt, shapeBt);
|
||||||
|
|
||||||
|
// special case for scalar value
|
||||||
|
if (dLdC->isScalar()) {
|
||||||
|
|
||||||
|
dLdA->assign((*dLdC) * *B);
|
||||||
|
dLdB->assign((*dLdC) * *A);
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::vector<int> axesA = ShapeUtils::evalDimsToExclude(Arank, axes0);
|
||||||
|
std::vector<int> axesB = ShapeUtils::evalDimsToExclude(Brank, axes1);
|
||||||
|
|
||||||
|
// rank always have to be divided by 2
|
||||||
|
std::vector<int> axesAdLdC, axesBdLdC;
|
||||||
|
if (dLdCrank > 1) {
|
||||||
|
axesAdLdC.resize(dLdCrank / 2);
|
||||||
|
std::iota(axesAdLdC.begin(), axesAdLdC.end(), 0);
|
||||||
|
axesBdLdC = ShapeUtils::evalDimsToExclude(dLdCrank, axesAdLdC);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
axesAdLdC.push_back(0);
|
||||||
|
axesBdLdC.push_back(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
// calculate dLdA
|
||||||
|
MmulHelper::tensorDot(dLdC, B, dLdA, axesBdLdC, axesB, permutAt);
|
||||||
|
|
||||||
|
// calculate dLdB
|
||||||
|
MmulHelper::tensorDot(A, dLdC, dLdB, axesA, axesAdLdC, permutBt);
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
DECLARE_SHAPE_FN(tensormmul_bp) {
|
||||||
|
|
||||||
|
auto aShapeInfo = inputShape->at(0);
|
||||||
|
auto bShapeInfo = inputShape->at(1);
|
||||||
|
auto dLShapeInfo = inputShape->at(2);
|
||||||
|
|
||||||
|
REQUIRE_TRUE((ArrayOptions::dataType(aShapeInfo) == ArrayOptions::dataType(bShapeInfo) &&
|
||||||
|
(ArrayOptions::dataType(dLShapeInfo) == ArrayOptions::dataType(aShapeInfo))), 0, "tensormmul_bp: A, B and dLdC data types must be the same");
|
||||||
|
|
||||||
|
Nd4jLong* dLdAShapeInfo = nullptr;
|
||||||
|
Nd4jLong* dLdBShapeInfo = nullptr;
|
||||||
|
|
||||||
|
COPY_SHAPE(aShapeInfo, dLdAShapeInfo);
|
||||||
|
COPY_SHAPE(bShapeInfo, dLdBShapeInfo);
|
||||||
|
|
||||||
|
return SHAPELIST(CONSTANT(dLdAShapeInfo), CONSTANT(dLdBShapeInfo));
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
DECLARE_TYPES(tensormmul_bp) {
|
||||||
|
getOpDescriptor()
|
||||||
|
->setAllowedInputTypes(0, { DataType::FLOAT32, DataType::DOUBLE, DataType::HALF }) // maybe better ALL_FLOATS
|
||||||
|
->setAllowedInputTypes(1, { DataType::FLOAT32, DataType::DOUBLE, DataType::HALF })
|
||||||
|
->setAllowedInputTypes(2, { DataType::FLOAT32, DataType::DOUBLE, DataType::HALF })
|
||||||
|
->setAllowedOutputTypes(0, { DataType::FLOAT32, DataType::DOUBLE, DataType::HALF })
|
||||||
|
->setAllowedOutputTypes(1, { DataType::FLOAT32, DataType::DOUBLE, DataType::HALF });
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif
|
#endif
|
|
@ -79,7 +79,7 @@ CUSTOM_OP_IMPL(conv1d, 2, 1, false, 0, 5) {
|
||||||
}
|
}
|
||||||
|
|
||||||
auto inputReshaped = input ->reshape(input->ordering(), reshapeForInput);
|
auto inputReshaped = input ->reshape(input->ordering(), reshapeForInput);
|
||||||
auto outputReshaped = output ->reshape(output->ordering(), reshapeForOutput);
|
auto outputReshaped = output ->reshape(output->ordering(), reshapeForOutput, false);
|
||||||
auto weightsReshaped = weights->reshape(weights->ordering(), {1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}); // [kW, iC, oC] -> [1, kW, iC, oC]
|
auto weightsReshaped = weights->reshape(weights->ordering(), {1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}); // [kW, iC, oC] -> [1, kW, iC, oC]
|
||||||
|
|
||||||
nd4j::ops::conv2d conv2d;
|
nd4j::ops::conv2d conv2d;
|
||||||
|
@ -216,10 +216,10 @@ CUSTOM_OP_IMPL(conv1d_bp, 3, 2, false, 0, 5) {
|
||||||
}
|
}
|
||||||
|
|
||||||
auto inputReshaped = input ->reshape(input->ordering(), reshapeForInput);
|
auto inputReshaped = input ->reshape(input->ordering(), reshapeForInput);
|
||||||
auto gradIReshaped = gradI ->reshape(gradI->ordering(), reshapeForInput);
|
auto gradIReshaped = gradI ->reshape(gradI->ordering(), reshapeForInput, false);
|
||||||
auto gradOReshaped = gradO ->reshape(gradO->ordering(), reshapeForGradO);
|
auto gradOReshaped = gradO ->reshape(gradO->ordering(), reshapeForGradO);
|
||||||
auto weightsReshaped = weights->reshape(weights->ordering(),{1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}); // [kW, iC, oC] -> [1, kW, iC, oC]
|
auto weightsReshaped = weights->reshape(weights->ordering(),{1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}); // [kW, iC, oC] -> [1, kW, iC, oC]
|
||||||
auto gradWReshaped = gradW ->reshape(gradW->ordering(), {1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}); // [kW, iC, oC] -> [1, kW, iC, oC]
|
auto gradWReshaped = gradW ->reshape(gradW->ordering(), {1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}, false);// [kW, iC, oC] -> [1, kW, iC, oC]
|
||||||
|
|
||||||
nd4j::ops::conv2d_bp conv2dBP;
|
nd4j::ops::conv2d_bp conv2dBP;
|
||||||
auto status = conv2dBP.execute({&inputReshaped, &weightsReshaped, bias, &gradOReshaped}, {&gradIReshaped, &gradWReshaped, gradB}, {}, {1,kW, 1,sW, 0,pW, 1,dW, paddingMode, !isNCW}, {});
|
auto status = conv2dBP.execute({&inputReshaped, &weightsReshaped, bias, &gradOReshaped}, {&gradIReshaped, &gradWReshaped, gradB}, {}, {1,kW, 1,sW, 0,pW, 1,dW, paddingMode, !isNCW}, {});
|
||||||
|
|
|
@ -239,7 +239,7 @@ CUSTOM_OP_IMPL(conv3dnew_bp, 3, 2, false, 0, 13) {
|
||||||
//----- calculation of gradO -----//
|
//----- calculation of gradO -----//
|
||||||
if(gradB) {
|
if(gradB) {
|
||||||
if(gradB->rankOf() == 2)
|
if(gradB->rankOf() == 2)
|
||||||
gradB = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()}));
|
gradB = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()}, false));
|
||||||
gradO->reduceAlongDimension(reduce::Sum, *gradB, gradOaxesForDot); // sum over bS oD oH oW
|
gradO->reduceAlongDimension(reduce::Sum, *gradB, gradOaxesForDot); // sum over bS oD oH oW
|
||||||
if(gradB != OUTPUT_VARIABLE(2))
|
if(gradB != OUTPUT_VARIABLE(2))
|
||||||
delete gradB;
|
delete gradB;
|
||||||
|
|
|
@ -233,7 +233,7 @@ CUSTOM_OP_IMPL(deconv2d_bp, 3, 2, false, 0, 9) {
|
||||||
// ----- calculation of gradB ----- //
|
// ----- calculation of gradB ----- //
|
||||||
if(gradB) {
|
if(gradB) {
|
||||||
if(gradB->rankOf() == 2)
|
if(gradB->rankOf() == 2)
|
||||||
gradB = new NDArray(gradB->reshape(gradB->ordering(), {gradB->lengthOf()}));
|
gradB = new NDArray(gradB->reshape(gradB->ordering(), {gradB->lengthOf()}, false));
|
||||||
gradO->reduceAlongDimension(reduce::Sum, *gradB, {0, 2, 3}); // sum over bS, oH, oW
|
gradO->reduceAlongDimension(reduce::Sum, *gradB, {0, 2, 3}); // sum over bS, oH, oW
|
||||||
if(gradB != OUTPUT_VARIABLE(2))
|
if(gradB != OUTPUT_VARIABLE(2))
|
||||||
delete gradB;
|
delete gradB;
|
||||||
|
|
|
@ -243,7 +243,7 @@ CUSTOM_OP_IMPL(deconv3d_bp, 3, 2, false, 0, 13) {
|
||||||
// ----- calculation of gradB ----- //
|
// ----- calculation of gradB ----- //
|
||||||
if(gradB) {
|
if(gradB) {
|
||||||
if(gradB->rankOf() == 2)
|
if(gradB->rankOf() == 2)
|
||||||
gradB = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()}));
|
gradB = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()}, false));
|
||||||
gradO->reduceAlongDimension(reduce::Sum, *gradB, {0, 2, 3, 4}); // sum over bS, oD, oH, oW
|
gradO->reduceAlongDimension(reduce::Sum, *gradB, {0, 2, 3, 4}); // sum over bS, oD, oH, oW
|
||||||
if(gradB != OUTPUT_VARIABLE(2))
|
if(gradB != OUTPUT_VARIABLE(2))
|
||||||
delete gradB;
|
delete gradB;
|
||||||
|
|
|
@ -31,22 +31,17 @@ namespace nd4j {
|
||||||
REQUIRE_TRUE(w->isMatrix(), 0, "relu_layer: weights argument should be a 2D tensor, but got rank %i instead!", w->rankOf());
|
REQUIRE_TRUE(w->isMatrix(), 0, "relu_layer: weights argument should be a 2D tensor, but got rank %i instead!", w->rankOf());
|
||||||
REQUIRE_TRUE(b->isVector(), 0, "relu_layer: biases argument should be a 1D tensor, but got rank %i instead!", b->rankOf());
|
REQUIRE_TRUE(b->isVector(), 0, "relu_layer: biases argument should be a 1D tensor, but got rank %i instead!", b->rankOf());
|
||||||
REQUIRE_TRUE(b->lengthOf() == w->sizeAt(1), 0, "relu_layer: biases array length should match to columns of weights matrix, however got length = %i and columns = %i!", b->lengthOf(), w->sizeAt(1));
|
REQUIRE_TRUE(b->lengthOf() == w->sizeAt(1), 0, "relu_layer: biases array length should match to columns of weights matrix, however got length = %i and columns = %i!", b->lengthOf(), w->sizeAt(1));
|
||||||
REQUIRE_TRUE(x->sizeAt(1) == w->sizeAt(0), 0, "relu_layer: number of x columns should match to row number of weights matrix, but got x_columns = %i and weights_rows = %i!",
|
REQUIRE_TRUE(x->sizeAt(1) == w->sizeAt(0), 0, "relu_layer: number of x columns should match to row number of weights matrix, but got x_columns = %i and weights_rows = %i!", x->sizeAt(1), w->sizeAt(0));
|
||||||
x->sizeAt(1), w->sizeAt(0));
|
|
||||||
|
|
||||||
|
|
||||||
auto output = OUTPUT_VARIABLE(0);
|
auto output = OUTPUT_VARIABLE(0);
|
||||||
//T bound = (T)0.f;
|
|
||||||
//nd4j_printf("Matrix x(%ix%i), Matrix w(%ix%i), b(1x%i)\n", x->sizeAt(0), x->sizeAt(1), w->sizeAt(0), w->sizeAt(1), b->lengthOf());
|
|
||||||
|
|
||||||
nd4j::ops::xw_plus_b op;
|
nd4j::ops::xw_plus_b op;
|
||||||
std::unique_ptr<ResultSet> result(op.evaluate({x, w, b}));
|
auto status = op.execute({x, w, b}, {output});
|
||||||
REQUIRE_TRUE(Status::OK() == result->status(), 0, "relu_layer: xw_plus_b op failed on input data.");
|
REQUIRE_TRUE(Status::OK() == status, 0, "relu_layer: xw_plus_b op failed on input data.");
|
||||||
|
|
||||||
auto scalar = block.numT() > 0 ? block.getTArguments()->at(0) : 0.0;
|
auto scalar = block.numT() > 0 ? block.getTArguments()->at(0) : 0.0;
|
||||||
|
|
||||||
auto xw = result->at(0);
|
output->applyScalar(nd4j::scalar::RELU, scalar, *output);
|
||||||
xw->applyScalar(nd4j::scalar::RELU, scalar, *output);
|
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
|
@ -23,7 +23,8 @@
|
||||||
|
|
||||||
//#include <ops/declarable/headers/parity_ops.h>
|
//#include <ops/declarable/headers/parity_ops.h>
|
||||||
#include <ops/declarable/CustomOperations.h>
|
#include <ops/declarable/CustomOperations.h>
|
||||||
#include <ops/declarable/helpers/image_resize.h>
|
#include <ops/declarable/helpers/crop_and_resize.h>
|
||||||
|
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
CUSTOM_OP_IMPL(crop_and_resize, 4, 1, false, 0, 0) {
|
CUSTOM_OP_IMPL(crop_and_resize, 4, 1, false, 0, 0) {
|
||||||
|
|
|
@ -61,7 +61,7 @@ namespace nd4j {
|
||||||
}
|
}
|
||||||
|
|
||||||
auto source = inRank == 4?image->reshape(image->ordering(), {image->sizeAt(0), image->sizeAt(1), image->sizeAt(2), image->sizeAt(3)}):image->reshape(image->ordering(), {1, image->sizeAt(0), image->sizeAt(1), image->sizeAt(2)});
|
auto source = inRank == 4?image->reshape(image->ordering(), {image->sizeAt(0), image->sizeAt(1), image->sizeAt(2), image->sizeAt(3)}):image->reshape(image->ordering(), {1, image->sizeAt(0), image->sizeAt(1), image->sizeAt(2)});
|
||||||
auto target = inRank == 4?output->reshape(output->ordering(), {output->sizeAt(0), output->sizeAt(1), output->sizeAt(2), output->sizeAt(3)}):output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)});
|
auto target = inRank == 4?output->reshape(output->ordering(), {output->sizeAt(0), output->sizeAt(1), output->sizeAt(2), output->sizeAt(3)}, false) : output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)}, false);
|
||||||
|
|
||||||
return helpers::resizeAreaFunctor(block.launchContext(), &source, width, height, alignCorners, &target);
|
return helpers::resizeAreaFunctor(block.launchContext(), &source, width, height, alignCorners, &target);
|
||||||
}
|
}
|
||||||
|
|
|
@ -62,7 +62,7 @@ namespace nd4j {
|
||||||
REQUIRE_TRUE(!halfPixelAlign || (halfPixelAlign && !alignCorners), 0, "resize_bicubic: `half_pixel_centers' should be false or true only when `align_corners' is false");
|
REQUIRE_TRUE(!halfPixelAlign || (halfPixelAlign && !alignCorners), 0, "resize_bicubic: `half_pixel_centers' should be false or true only when `align_corners' is false");
|
||||||
|
|
||||||
auto source = inRank == 4?image->reshape(image->ordering(), {image->sizeAt(0), image->sizeAt(1), image->sizeAt(2), image->sizeAt(3)}):image->reshape(image->ordering(), {1, image->sizeAt(0), image->sizeAt(1), image->sizeAt(2)});
|
auto source = inRank == 4?image->reshape(image->ordering(), {image->sizeAt(0), image->sizeAt(1), image->sizeAt(2), image->sizeAt(3)}):image->reshape(image->ordering(), {1, image->sizeAt(0), image->sizeAt(1), image->sizeAt(2)});
|
||||||
auto target = inRank == 4?output->reshape(output->ordering(), {output->sizeAt(0), output->sizeAt(1), output->sizeAt(2), output->sizeAt(3)}):output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)});
|
auto target = inRank == 4?output->reshape(output->ordering(), {output->sizeAt(0), output->sizeAt(1), output->sizeAt(2), output->sizeAt(3)}, false) : output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)}, false);
|
||||||
|
|
||||||
return helpers::resizeBicubicFunctorA(block.launchContext(), &source, width, height, alignCorners, halfPixelAlign, &target);
|
return helpers::resizeBicubicFunctorA(block.launchContext(), &source, width, height, alignCorners, halfPixelAlign, &target);
|
||||||
}
|
}
|
||||||
|
|
|
@ -43,7 +43,7 @@ namespace nd4j {
|
||||||
REQUIRE_TRUE(inRank == output->rankOf(), 0, "resize_bilinear: Input and output ranks should be equals, but %i and %i occured.", inRank, output->rankOf());
|
REQUIRE_TRUE(inRank == output->rankOf(), 0, "resize_bilinear: Input and output ranks should be equals, but %i and %i occured.", inRank, output->rankOf());
|
||||||
|
|
||||||
auto source = inRank == 4?image->reshape(image->ordering(), {image->sizeAt(0), image->sizeAt(1), image->sizeAt(2), image->sizeAt(3)}):image->reshape(image->ordering(), {1, image->sizeAt(0), image->sizeAt(1), image->sizeAt(2)});
|
auto source = inRank == 4?image->reshape(image->ordering(), {image->sizeAt(0), image->sizeAt(1), image->sizeAt(2), image->sizeAt(3)}):image->reshape(image->ordering(), {1, image->sizeAt(0), image->sizeAt(1), image->sizeAt(2)});
|
||||||
auto target = inRank == 4?output->reshape(output->ordering(), {output->sizeAt(0), output->sizeAt(1), output->sizeAt(2), output->sizeAt(3)}):output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)});
|
auto target = inRank == 4?output->reshape(output->ordering(), {output->sizeAt(0), output->sizeAt(1), output->sizeAt(2), output->sizeAt(3)}, false) : output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)}, false);
|
||||||
|
|
||||||
if (block.width() > 1) {
|
if (block.width() > 1) {
|
||||||
auto newImageSize = INPUT_VARIABLE(1);
|
auto newImageSize = INPUT_VARIABLE(1);
|
||||||
|
|
|
@ -63,7 +63,7 @@ namespace nd4j {
|
||||||
REQUIRE_TRUE(((alignCorners && height > 2) || (height > 0)) && ((alignCorners && width > 1) || (width > 0)), 0, "resize_nearest_neighbor: Wrong input or output size to resize (width = %d, height = %d)", width, height);
|
REQUIRE_TRUE(((alignCorners && height > 2) || (height > 0)) && ((alignCorners && width > 1) || (width > 0)), 0, "resize_nearest_neighbor: Wrong input or output size to resize (width = %d, height = %d)", width, height);
|
||||||
|
|
||||||
auto source = inRank == 4?*image:image->reshape(image->ordering(), {1, image->sizeAt(0), image->sizeAt(1), image->sizeAt(2)});
|
auto source = inRank == 4?*image:image->reshape(image->ordering(), {1, image->sizeAt(0), image->sizeAt(1), image->sizeAt(2)});
|
||||||
auto target = inRank == 4?*output:output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)});
|
auto target = inRank == 4 ? *output : output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)}, false);
|
||||||
|
|
||||||
return helpers::resizeNeighborFunctor(block.launchContext(), inRank==4?image:&source, width, height, alignCorners, halfPixelCenter, inRank == 4 ? output : &target);
|
return helpers::resizeNeighborFunctor(block.launchContext(), inRank==4?image:&source, width, height, alignCorners, halfPixelCenter, inRank == 4 ? output : &target);
|
||||||
}
|
}
|
||||||
|
|
|
@ -47,11 +47,12 @@ namespace nd4j {
|
||||||
|
|
||||||
shape.insert(shape.begin() + axis, 1);
|
shape.insert(shape.begin() + axis, 1);
|
||||||
|
|
||||||
auto tmp = input->reshape(input->ordering(), shape);
|
if (input->ews() == 1 && output->ews() == 1 && input->ordering() == output->ordering()) {
|
||||||
output->assign(tmp);
|
output->dataBuffer()->copyBufferFrom(*input->dataBuffer().get(), output->lengthOf() * DataTypeUtils::sizeOfElement(output->dataType()), 0, input->bufferOffset());
|
||||||
|
} else {
|
||||||
STORE_RESULT(output);
|
auto tmp = input->reshape(input->ordering(), shape);
|
||||||
|
output->assign(tmp);
|
||||||
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -15,7 +15,8 @@
|
||||||
******************************************************************************/
|
******************************************************************************/
|
||||||
|
|
||||||
//
|
//
|
||||||
// Created by raver119 on 29/10/17.
|
// @author raver119@gmail.com
|
||||||
|
// @author Yurii Shyrma (iuriish@yahoo.com)
|
||||||
//
|
//
|
||||||
|
|
||||||
#include <op_boilerplate.h>
|
#include <op_boilerplate.h>
|
||||||
|
@ -29,80 +30,52 @@ namespace nd4j {
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
// here iArgs is int vector of ordered set of dimensions to be permuted
|
// here iArgs is int vector of ordered set of dimensions to be permuted
|
||||||
CUSTOM_OP_IMPL(permute, 1, 1, true, 0, -2) {
|
CUSTOM_OP_IMPL(permute, 1, 1, true, 0, -2) {
|
||||||
auto x = INPUT_VARIABLE(0);
|
|
||||||
|
|
||||||
bool replace = false;
|
auto x = INPUT_VARIABLE(0);
|
||||||
|
auto z = OUTPUT_VARIABLE(0);
|
||||||
|
|
||||||
auto origArgs = block.width() > 1 ? INPUT_VARIABLE(1)->asVectorT<int>() : *block.getIArguments();
|
if (x->isEmpty()) {
|
||||||
std::vector<int> arguments({});
|
REQUIRE_TRUE(z->isEmpty(), 0, "PERMUTE OP: when input is empty, output must also be empty");
|
||||||
if(origArgs.size() > 0){
|
return Status::OK(); //No op
|
||||||
for (int e = 0; e < origArgs.size(); e++) {
|
|
||||||
int ax = origArgs[e];
|
|
||||||
if (ax < 0)
|
|
||||||
ax += x->rankOf();
|
|
||||||
|
|
||||||
arguments.emplace_back(ax);
|
|
||||||
}
|
|
||||||
|
|
||||||
replace = true;
|
|
||||||
} else {
|
|
||||||
for (int e = x->rankOf() - 1; e >= 0; e--)
|
|
||||||
arguments.emplace_back(e);
|
|
||||||
}
|
|
||||||
|
|
||||||
// 0D edge case
|
|
||||||
if (x->rankOf() == 0) {
|
|
||||||
REQUIRE_TRUE(arguments.size() == 1, 0, "Permute: only one axis is allowed for scalar");
|
|
||||||
auto output = OUTPUT_VARIABLE(0);
|
|
||||||
if (!block.isInplace())
|
|
||||||
output->assign(x);
|
|
||||||
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
if(block.isInplace()) { // in-place
|
|
||||||
x->permutei(arguments);
|
|
||||||
STORE_RESULT(x);
|
|
||||||
} else {
|
|
||||||
auto output = OUTPUT_VARIABLE(0);
|
|
||||||
auto result = x->permute(arguments);
|
|
||||||
output->assign(result);
|
|
||||||
STORE_RESULT(output);
|
|
||||||
}
|
|
||||||
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
DECLARE_TYPES(permute) {
|
|
||||||
getOpDescriptor()
|
|
||||||
->setAllowedInputTypes(0, nd4j::DataType::ANY)
|
|
||||||
->setAllowedInputTypes(1, {ALL_INTS})
|
|
||||||
->setSameMode(true);
|
|
||||||
}
|
|
||||||
|
|
||||||
DECLARE_SHAPE_FN(permute) {
|
|
||||||
auto shapeList = SHAPELIST();
|
|
||||||
auto arguments = block.width() > 1 ? INPUT_VARIABLE(1)->asVectorT<int>() : *block.getIArguments();
|
|
||||||
|
|
||||||
if (shape::rank(inputShape->at(0)) == 0) {
|
|
||||||
shapeList->push_back(ConstantShapeHelper::getInstance()->scalarShapeInfo(ArrayOptions::dataType(inputShape->at(0))));
|
|
||||||
} else if (inputShape->size() == 1 && !arguments.empty()) {
|
|
||||||
shapeList->push_back(ShapeUtils::evalPermShapeInfo(arguments.data(), arguments.size(), *INPUT_VARIABLE(0), block.workspace()));
|
|
||||||
} else {
|
|
||||||
if(arguments.size() == 0){
|
|
||||||
//Reverse dimensions
|
|
||||||
int rank = shape::rank(inputShape->at(0));
|
|
||||||
for (int e = rank - 1; e >= 0; e--)
|
|
||||||
arguments.emplace_back(e);
|
|
||||||
}
|
|
||||||
|
|
||||||
shapeList->push_back(ShapeUtils::evalPermShapeInfo(arguments.data(), arguments.size(), *INPUT_VARIABLE(0), block.workspace()));
|
|
||||||
}
|
|
||||||
|
|
||||||
return shapeList;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (block.width() == 1 && block.getIArguments()->size() == 0) {
|
||||||
|
z->assign(x->transpose());
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<int> permutationVector = block.width() > 1 ? INPUT_VARIABLE(1)->asVectorT<int>() : *block.getIArguments();
|
||||||
|
|
||||||
|
z->assign(x->permute(permutationVector));
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
DECLARE_TYPES(permute) {
|
||||||
|
getOpDescriptor()
|
||||||
|
->setAllowedInputTypes(0, nd4j::DataType::ANY)
|
||||||
|
->setAllowedInputTypes(1, {ALL_INTS})
|
||||||
|
->setSameMode(true);
|
||||||
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
DECLARE_SHAPE_FN(permute) {
|
||||||
|
|
||||||
|
auto x = INPUT_VARIABLE(0);
|
||||||
|
|
||||||
|
if (block.width() == 1 && block.getIArguments()->size() == 0)
|
||||||
|
return SHAPELIST(ShapeUtils::evalTranspShapeInfo(*x, block.workspace(), true));
|
||||||
|
|
||||||
|
std::vector<int> permutationVector = block.width() > 1 ? INPUT_VARIABLE(1)->asVectorT<int>() : *block.getIArguments();
|
||||||
|
|
||||||
|
auto outputShapeInfo = ShapeUtils::evalPermShapeInfo(permutationVector.data(), x->rankOf(), *x, block.workspace(), true);
|
||||||
|
|
||||||
|
return SHAPELIST(outputShapeInfo);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif
|
#endif
|
|
@ -24,254 +24,240 @@
|
||||||
#include <ops/declarable/CustomOperations.h>
|
#include <ops/declarable/CustomOperations.h>
|
||||||
|
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
//////////////////////////////////////////////////////////////////////////
|
|
||||||
// here iArgs is a vector with (optional) negative of order as first element:
|
|
||||||
// ({-order, dim1, dim2, dim3, ...})
|
|
||||||
CUSTOM_OP_IMPL(reshape, 1, 1, true, 0, -2) {
|
|
||||||
auto x = INPUT_VARIABLE(0);
|
|
||||||
|
|
||||||
if (block.width() == 1) {
|
//////////////////////////////////////////////////////////////////////////
|
||||||
auto arguments = block.getIArguments();
|
// here iArgs is a vector with (optional) negative of order as first element:
|
||||||
int argsSize = arguments->size();
|
// ({-order, dim1, dim2, dim3, ...})
|
||||||
|
CUSTOM_OP_IMPL(reshape, 1, 1, false, 0, -2) {
|
||||||
|
|
||||||
//Special case: empty.reshape(<other empty shape>) -> return empty
|
auto x = INPUT_VARIABLE(0);
|
||||||
if (x->isEmpty()) {
|
auto z = OUTPUT_VARIABLE(0);
|
||||||
REQUIRE_TRUE(OUTPUT_VARIABLE(0)->isEmpty(), 0, "Reshape: when input is empty, output must also be empty");
|
|
||||||
return ND4J_STATUS_OK; //No op
|
//Special case: empty.reshape(<other empty shape>) -> return empty
|
||||||
|
if (x->isEmpty()) {
|
||||||
|
REQUIRE_TRUE(z->isEmpty(), 0, "Reshape: when input is empty, output must also be empty");
|
||||||
|
return Status::OK(); //No op
|
||||||
|
}
|
||||||
|
|
||||||
|
if (block.width() == 1) {
|
||||||
|
|
||||||
|
auto arguments = block.getIArguments();
|
||||||
|
int argsSize = arguments->size();
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
int e = 1;
|
||||||
|
char order = (char) -(*arguments)[0];
|
||||||
|
if (order != 'c' && order != 'f') {
|
||||||
|
order = 'c'; //x->ordering();
|
||||||
|
e = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
REQUIRE_TRUE(argsSize - e >= 1, 0, "Reshape arguments should have at least 1 dimension");
|
||||||
|
|
||||||
|
std::vector<Nd4jLong> shapeNew;
|
||||||
|
int e2 = e;
|
||||||
|
for (; e < (int) arguments->size(); e++) {
|
||||||
|
if (arguments->at(e) == -1){
|
||||||
|
Nd4jLong shapeLength = 1;
|
||||||
|
for(; e2 < e; e2++){
|
||||||
|
shapeLength *= arguments->at(e2);
|
||||||
}
|
}
|
||||||
|
for(e2 = e + 1; e2 < arguments->size(); e2++){
|
||||||
int e = 1;
|
shapeLength *= arguments->at(e2);
|
||||||
char order = (char) -(*arguments)[0];
|
|
||||||
if (order != 'c' && order != 'f') {
|
|
||||||
order = 'c'; //x->ordering();
|
|
||||||
e = 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
REQUIRE_TRUE(argsSize - e >= 1, 0, "Reshape arguments should have at least 1 dimension");
|
|
||||||
|
|
||||||
std::vector<Nd4jLong> shapeNew;
|
|
||||||
int e2 = e;
|
|
||||||
for (; e < (int) arguments->size(); e++) {
|
|
||||||
if (arguments->at(e) == -1){
|
|
||||||
Nd4jLong shapeLength = 1;
|
|
||||||
for(; e2 < e; e2++){
|
|
||||||
shapeLength *= arguments->at(e2);
|
|
||||||
}
|
|
||||||
for(e2 = e + 1; e2 < arguments->size(); e2++){
|
|
||||||
shapeLength *= arguments->at(e2);
|
|
||||||
}
|
|
||||||
Nd4jLong realShape = x->lengthOf() / shapeLength;
|
|
||||||
shapeNew.push_back(realShape);
|
|
||||||
}
|
|
||||||
else{
|
|
||||||
shapeNew.push_back(arguments->at(e));
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
auto len = shape::prodLong(shapeNew.data(), shapeNew.size());
|
|
||||||
REQUIRE_TRUE(len == x->lengthOf(), 0, "Reshape: lengths before and after reshape should match, but got %i vs %i", x->lengthOf(), len);
|
|
||||||
|
|
||||||
if (Environment::getInstance()->isDebugAndVerbose()) {
|
|
||||||
nd4j_printv("Reshape: new shape", shapeNew);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (block.isInplace()) {
|
|
||||||
if (x->reshapei(order, shapeNew)) {
|
|
||||||
STORE_RESULT(*x);
|
|
||||||
return ND4J_STATUS_OK;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
auto ret = OUTPUT_VARIABLE(0);
|
|
||||||
auto xr = x->reshape(order, shapeNew);
|
|
||||||
ret->assign(xr);
|
|
||||||
STORE_RESULT(*ret);
|
|
||||||
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
} else if (block.width() == 2) {
|
|
||||||
auto s = INPUT_VARIABLE(1);
|
|
||||||
|
|
||||||
//Special case: empty.reshape(-1) -> return empty
|
|
||||||
if (x->isEmpty()) {
|
|
||||||
//REQUIRE_TRUE(s->lengthOf() == 1 && s->e<Nd4jLong>(0) == -1, 0, "Reshape: when input is empty, shape must be [-1]");
|
|
||||||
REQUIRE_TRUE(OUTPUT_VARIABLE(0)->isEmpty(), 0, "Reshape: when input is empty, output must also be empty");
|
|
||||||
return Status::OK(); //No op
|
|
||||||
}
|
|
||||||
|
|
||||||
char order = 'c';
|
|
||||||
if (block.numI() > 0)
|
|
||||||
order = (char) -INT_ARG(0);
|
|
||||||
|
|
||||||
std::vector<Nd4jLong> shapeNew(s->lengthOf());
|
|
||||||
|
|
||||||
for (int e = 0; e < (int) s->lengthOf(); e++) {
|
|
||||||
auto dim = s->e<Nd4jLong >(e);
|
|
||||||
if (dim == -1){
|
|
||||||
Nd4jLong shapeLength = 1;
|
|
||||||
for(int e2 = 0; e2 < e; e2++){
|
|
||||||
shapeLength *= s->e<Nd4jLong>(e2);
|
|
||||||
}
|
|
||||||
for(int e2 = e + 1; e2 < (int) s->lengthOf(); e2++){
|
|
||||||
REQUIRE_TRUE(s->e<Nd4jLong>(e2) != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed.");
|
|
||||||
shapeLength *= s->e<Nd4jLong>(e2);
|
|
||||||
}
|
|
||||||
Nd4jLong realShape = x->lengthOf() / shapeLength;
|
|
||||||
shapeNew[e] = realShape;
|
|
||||||
}
|
|
||||||
else{
|
|
||||||
shapeNew[e] = dim;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (Environment::getInstance()->isDebugAndVerbose()) {
|
|
||||||
nd4j_printv("Reshape: new shape", shapeNew);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (block.isInplace()) {
|
|
||||||
if (x->reshapei(order, shapeNew)) {
|
|
||||||
STORE_RESULT(*x);
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
auto ret = OUTPUT_VARIABLE(0);
|
|
||||||
if (s->isEmpty()) {
|
|
||||||
// just a scalar
|
|
||||||
ret->assign(x);
|
|
||||||
} else {
|
|
||||||
auto xr = x->reshape(order, shapeNew);
|
|
||||||
ret->assign(xr);
|
|
||||||
}
|
|
||||||
|
|
||||||
return Status::OK();
|
|
||||||
}
|
}
|
||||||
|
Nd4jLong realShape = x->lengthOf() / shapeLength;
|
||||||
|
shapeNew.push_back(realShape);
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
shapeNew.push_back(arguments->at(e));
|
||||||
}
|
}
|
||||||
|
|
||||||
return ND4J_STATUS_BAD_INPUT;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
auto len = shape::prodLong(shapeNew.data(), shapeNew.size());
|
||||||
|
REQUIRE_TRUE(len == x->lengthOf(), 0, "Reshape: lengths before and after reshape should match, but got %i vs %i", x->lengthOf(), len);
|
||||||
|
|
||||||
DECLARE_TYPES(reshape) {
|
if (Environment::getInstance()->isDebugAndVerbose()) {
|
||||||
getOpDescriptor()
|
nd4j_printv("Reshape: new shape", shapeNew);
|
||||||
->setAllowedInputTypes(0, nd4j::DataType::ANY)
|
|
||||||
->setAllowedInputTypes(1, {ALL_INTS})
|
|
||||||
->setSameMode(true);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
DECLARE_SHAPE_FN(reshape) {
|
auto xr = x->reshape(order, shapeNew);
|
||||||
auto inp = inputShape->at(0);
|
z->assign(xr);
|
||||||
|
STORE_RESULT(*z);
|
||||||
|
|
||||||
// we can launch op using Int arguments
|
return Status::OK();
|
||||||
if (inputShape->size() == 1) {
|
|
||||||
REQUIRE_TRUE(block.numI() > 0, 0, "Reshape: new shape should be provided as NDArray or int arguments, but nothing was defined");
|
|
||||||
std::vector<int> *arguments = block.getIArguments();
|
|
||||||
|
|
||||||
int e = 1;
|
} else if (block.width() == 2) {
|
||||||
char order = (char) -(*arguments)[0];
|
|
||||||
if (order != 'c' && order != 'f') {
|
auto s = INPUT_VARIABLE(1);
|
||||||
order = shape::order(inp);
|
|
||||||
e = 0;
|
char order = 'c';
|
||||||
|
if (block.numI() > 0)
|
||||||
|
order = (char) -INT_ARG(0);
|
||||||
|
|
||||||
|
std::vector<Nd4jLong> shapeNew(s->lengthOf());
|
||||||
|
|
||||||
|
for (int e = 0; e < (int) s->lengthOf(); e++) {
|
||||||
|
auto dim = s->e<Nd4jLong >(e);
|
||||||
|
if (dim == -1){
|
||||||
|
Nd4jLong shapeLength = 1;
|
||||||
|
for(int e2 = 0; e2 < e; e2++){
|
||||||
|
shapeLength *= s->e<Nd4jLong>(e2);
|
||||||
}
|
}
|
||||||
|
for(int e2 = e + 1; e2 < (int) s->lengthOf(); e2++){
|
||||||
std::vector<Nd4jLong> shapeNew;
|
REQUIRE_TRUE(s->e<Nd4jLong>(e2) != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed.");
|
||||||
|
shapeLength *= s->e<Nd4jLong>(e2);
|
||||||
int e2 = e;
|
|
||||||
for (; e < (int) arguments->size(); e++) {
|
|
||||||
if ((int) arguments->at(e) == -1){
|
|
||||||
|
|
||||||
Nd4jLong shapeLength = 1;
|
|
||||||
for(; e2 < e; e2 ++){
|
|
||||||
shapeLength *= arguments->at(e2);
|
|
||||||
}
|
|
||||||
for(e2 = e + 1; e2 < arguments->size(); e2++){
|
|
||||||
REQUIRE_TRUE(arguments->at(e2) != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed.");
|
|
||||||
shapeLength *= arguments->at(e2);
|
|
||||||
}
|
|
||||||
|
|
||||||
if(shapeLength == 0){
|
|
||||||
//Edge case for empty:
|
|
||||||
shapeNew.push_back(0);
|
|
||||||
} else {
|
|
||||||
//Standard case
|
|
||||||
Nd4jLong realShape = shape::length(inp) / shapeLength;
|
|
||||||
shapeNew.push_back(realShape);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
else{
|
|
||||||
shapeNew.push_back(arguments->at(e));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
Nd4jLong realShape = x->lengthOf() / shapeLength;
|
||||||
return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(ArrayOptions::dataType(inp), order, shapeNew)));
|
shapeNew[e] = realShape;
|
||||||
} else {
|
}
|
||||||
// or, with second input "as shape"
|
else{
|
||||||
auto x = INPUT_VARIABLE(0);
|
shapeNew[e] = dim;
|
||||||
auto y = INPUT_VARIABLE(1);
|
|
||||||
|
|
||||||
// special case here
|
|
||||||
if (y->isEmpty()) {
|
|
||||||
REQUIRE_TRUE(x->lengthOf() == 1, 0, "Reshape: new length doesn't match existing array");
|
|
||||||
return SHAPELIST(ConstantShapeHelper::getInstance()->scalarShapeInfo(ArrayOptions::dataType(inp)));
|
|
||||||
}
|
|
||||||
//Special case: empty.reshape(-1) -> return empty
|
|
||||||
if (x->isEmpty()) {
|
|
||||||
//REQUIRE_TRUE(y->lengthOf() == 1 && y->e<Nd4jLong>(0) == -1, 0, "Reshape: when input is empty, shape must be [-1]");
|
|
||||||
auto shapeOf = y->getBufferAsVector<Nd4jLong>();
|
|
||||||
Nd4jLong prod = 1;
|
|
||||||
bool hasNegs = false;
|
|
||||||
for (auto v:shapeOf) {
|
|
||||||
if (v < 0) {
|
|
||||||
hasNegs = true;
|
|
||||||
v = 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
prod *= v;
|
|
||||||
}
|
|
||||||
|
|
||||||
REQUIRE_TRUE(prod == 0, 0, "Reshape: in case of empty arrays reshape must return empty array as well");
|
|
||||||
|
|
||||||
// if there are -1s - we turn them into zeros
|
|
||||||
if (hasNegs) {
|
|
||||||
for (int e = 0; e < shapeOf.size(); e++)
|
|
||||||
if (shapeOf[e] < 0)
|
|
||||||
shapeOf[e] = 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto newShape = ShapeBuilders::createShapeInfo(ArrayOptions::dataType(inp), shape::order(inp), y->lengthOf(), shapeOf.data());
|
|
||||||
return SHAPELIST(CONSTANT(newShape));
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<Nd4jLong> shapeNew(y->lengthOf());
|
|
||||||
|
|
||||||
for (int e = 0; e < (int) y->lengthOf(); e++) {
|
|
||||||
auto dim = y->e<Nd4jLong>(e);
|
|
||||||
if (dim == -1){
|
|
||||||
Nd4jLong shapeLength = 1;
|
|
||||||
for(int e2 = 0; e2 < e; e2++){
|
|
||||||
shapeLength *= y->e<Nd4jLong>(e2);
|
|
||||||
}
|
|
||||||
for(int e2 = e + 1; e2 < (int)y->lengthOf(); e2++){
|
|
||||||
REQUIRE_TRUE(y->e<Nd4jLong>(e2) != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed.");
|
|
||||||
shapeLength *= y->e<Nd4jLong>(e2);
|
|
||||||
}
|
|
||||||
|
|
||||||
if(shapeLength == 0){
|
|
||||||
//Edge case for empty:
|
|
||||||
shapeNew[e] = 0;
|
|
||||||
} else {
|
|
||||||
Nd4jLong realShape = shape::length(inp) / shapeLength;
|
|
||||||
shapeNew[e] = realShape;
|
|
||||||
}
|
|
||||||
}else {
|
|
||||||
shapeNew[e] = dim;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inp), 'c', shapeNew));
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (Environment::getInstance()->isDebugAndVerbose()) {
|
||||||
|
nd4j_printv("Reshape: new shape", shapeNew);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (s->isEmpty()) {
|
||||||
|
// just a scalar
|
||||||
|
z->assign(x);
|
||||||
|
} else {
|
||||||
|
auto xr = x->reshape(order, shapeNew);
|
||||||
|
z->assign(xr);
|
||||||
|
}
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return ND4J_STATUS_BAD_INPUT;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
DECLARE_TYPES(reshape) {
|
||||||
|
getOpDescriptor()
|
||||||
|
->setAllowedInputTypes(0, nd4j::DataType::ANY)
|
||||||
|
->setAllowedInputTypes(1, {ALL_INTS})
|
||||||
|
->setSameMode(true);
|
||||||
|
}
|
||||||
|
|
||||||
|
DECLARE_SHAPE_FN(reshape) {
|
||||||
|
auto inp = inputShape->at(0);
|
||||||
|
|
||||||
|
// we can launch op using Int arguments
|
||||||
|
if (inputShape->size() == 1) {
|
||||||
|
REQUIRE_TRUE(block.numI() > 0, 0, "Reshape: new shape should be provided as NDArray or int arguments, but nothing was defined");
|
||||||
|
std::vector<int> *arguments = block.getIArguments();
|
||||||
|
|
||||||
|
int e = 1;
|
||||||
|
char order = (char) -(*arguments)[0];
|
||||||
|
if (order != 'c' && order != 'f') {
|
||||||
|
order = shape::order(inp);
|
||||||
|
e = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<Nd4jLong> shapeNew;
|
||||||
|
|
||||||
|
int e2 = e;
|
||||||
|
for (; e < (int) arguments->size(); e++) {
|
||||||
|
if ((int) arguments->at(e) == -1){
|
||||||
|
|
||||||
|
Nd4jLong shapeLength = 1;
|
||||||
|
for(; e2 < e; e2 ++){
|
||||||
|
shapeLength *= arguments->at(e2);
|
||||||
|
}
|
||||||
|
for(e2 = e + 1; e2 < arguments->size(); e2++){
|
||||||
|
REQUIRE_TRUE(arguments->at(e2) != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed.");
|
||||||
|
shapeLength *= arguments->at(e2);
|
||||||
|
}
|
||||||
|
|
||||||
|
if(shapeLength == 0){
|
||||||
|
//Edge case for empty:
|
||||||
|
shapeNew.push_back(0);
|
||||||
|
} else {
|
||||||
|
//Standard case
|
||||||
|
Nd4jLong realShape = shape::length(inp) / shapeLength;
|
||||||
|
shapeNew.push_back(realShape);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
shapeNew.push_back(arguments->at(e));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(ArrayOptions::dataType(inp), order, shapeNew)));
|
||||||
|
} else {
|
||||||
|
// or, with second input "as shape"
|
||||||
|
auto x = INPUT_VARIABLE(0);
|
||||||
|
auto y = INPUT_VARIABLE(1);
|
||||||
|
|
||||||
|
// special case here
|
||||||
|
if (y->isEmpty()) {
|
||||||
|
REQUIRE_TRUE(x->lengthOf() == 1, 0, "Reshape: new length doesn't match existing array");
|
||||||
|
return SHAPELIST(ConstantShapeHelper::getInstance()->scalarShapeInfo(ArrayOptions::dataType(inp)));
|
||||||
|
}
|
||||||
|
//Special case: empty.reshape(-1) -> return empty
|
||||||
|
if (x->isEmpty()) {
|
||||||
|
//REQUIRE_TRUE(y->lengthOf() == 1 && y->e<Nd4jLong>(0) == -1, 0, "Reshape: when input is empty, shape must be [-1]");
|
||||||
|
auto shapeOf = y->getBufferAsVector<Nd4jLong>();
|
||||||
|
Nd4jLong prod = 1;
|
||||||
|
bool hasNegs = false;
|
||||||
|
for (auto v:shapeOf) {
|
||||||
|
if (v < 0) {
|
||||||
|
hasNegs = true;
|
||||||
|
v = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
prod *= v;
|
||||||
|
}
|
||||||
|
|
||||||
|
REQUIRE_TRUE(prod == 0, 0, "Reshape: in case of empty arrays reshape must return empty array as well");
|
||||||
|
|
||||||
|
// if there are -1s - we turn them into zeros
|
||||||
|
if (hasNegs) {
|
||||||
|
for (int e = 0; e < shapeOf.size(); e++)
|
||||||
|
if (shapeOf[e] < 0)
|
||||||
|
shapeOf[e] = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto newShape = ShapeBuilders::createShapeInfo(ArrayOptions::dataType(inp), shape::order(inp), y->lengthOf(), shapeOf.data());
|
||||||
|
return SHAPELIST(CONSTANT(newShape));
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<Nd4jLong> shapeNew(y->lengthOf());
|
||||||
|
|
||||||
|
for (int e = 0; e < (int) y->lengthOf(); e++) {
|
||||||
|
auto dim = y->e<Nd4jLong>(e);
|
||||||
|
if (dim == -1){
|
||||||
|
Nd4jLong shapeLength = 1;
|
||||||
|
for(int e2 = 0; e2 < e; e2++){
|
||||||
|
shapeLength *= y->e<Nd4jLong>(e2);
|
||||||
|
}
|
||||||
|
for(int e2 = e + 1; e2 < (int)y->lengthOf(); e2++){
|
||||||
|
REQUIRE_TRUE(y->e<Nd4jLong>(e2) != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed.");
|
||||||
|
shapeLength *= y->e<Nd4jLong>(e2);
|
||||||
|
}
|
||||||
|
|
||||||
|
if(shapeLength == 0){
|
||||||
|
//Edge case for empty:
|
||||||
|
shapeNew[e] = 0;
|
||||||
|
} else {
|
||||||
|
Nd4jLong realShape = shape::length(inp) / shapeLength;
|
||||||
|
shapeNew[e] = realShape;
|
||||||
|
}
|
||||||
|
}else {
|
||||||
|
shapeNew[e] = dim;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inp), 'c', shapeNew));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif
|
#endif
|
|
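Both branches of the reshape shape function above infer a single -1 dimension the same way: multiply the known dimensions, divide the input's total length by that product, and fall back to 0 when the product is 0 (the empty-array edge case), while rejecting a second -1. A compact standalone sketch of that rule, using illustrative names rather than the library's helpers:

    #include <cstdint>
    #include <stdexcept>
    #include <vector>

    using Nd4jLong = int64_t;  // assumption: same integer type as in the shape functions above

    // Resolve at most one -1 entry in a requested shape, given the total number of elements.
    static std::vector<Nd4jLong> resolveMinusOne(std::vector<Nd4jLong> shape, Nd4jLong totalLength) {
        int unknown = -1;
        Nd4jLong known = 1;
        for (int i = 0; i < (int) shape.size(); ++i) {
            if (shape[i] == -1) {
                if (unknown >= 0)
                    throw std::invalid_argument("only one unknown dimension (-1) is allowed");
                unknown = i;
            } else {
                known *= shape[i];
            }
        }
        if (unknown >= 0)
            shape[unknown] = (known == 0) ? 0                    // empty-array edge case
                                          : totalLength / known; // standard case
        return shape;
    }
    // e.g. resolveMinusOne({2, -1, 4}, 24) -> {2, 3, 4}; resolveMinusOne({0, -1}, 0) -> {0, 0}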
@ -28,18 +28,16 @@ namespace nd4j {
|
||||||
|
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
CUSTOM_OP_IMPL(reshapeas, 2, 1, true, 0, 0) {
|
CUSTOM_OP_IMPL(reshapeas, 2, 1, false, 0, 0) {
|
||||||
|
|
||||||
auto x = INPUT_VARIABLE(0);
|
auto x = INPUT_VARIABLE(0);
|
||||||
auto y = INPUT_VARIABLE(1);
|
auto y = INPUT_VARIABLE(1);
|
||||||
|
|
||||||
auto z = OUTPUT_VARIABLE(0);
|
auto z = OUTPUT_VARIABLE(0);
|
||||||
std::vector<Nd4jLong> shapeNew(y->shapeOf(), y->shapeOf() + y->rankOf());
|
|
||||||
char order = y->ordering();
|
|
||||||
|
|
||||||
if (x->reshapei(order, shapeNew)) {
|
if (x->reshapei(y->ordering(), y->getShapeAsVector())) {
|
||||||
*z = *x;
|
|
||||||
STORE_RESULT(*z);
|
z->assign(x);
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -49,14 +47,8 @@ namespace nd4j {
|
||||||
|
|
||||||
DECLARE_SHAPE_FN(reshapeas) {
|
DECLARE_SHAPE_FN(reshapeas) {
|
||||||
|
|
||||||
auto inputShapeInfo = inputShape->at(1);
|
return SHAPELIST(ShapeBuilders::copyShapeInfo(INPUT_VARIABLE(1)->getShapeInfo(), false, block.workspace()));
|
||||||
int shapeInfoLength = inputShapeInfo[0]*2 + 4;
|
}
|
||||||
|
|
||||||
Nd4jLong* outputShapeInfo(nullptr);
|
|
||||||
COPY_SHAPE(inputShapeInfo, outputShapeInfo);
|
|
||||||
|
|
||||||
return SHAPELIST(CONSTANT(outputShapeInfo));
|
|
||||||
}
|
|
||||||
|
|
||||||
DECLARE_TYPES(reshapeas) {
|
DECLARE_TYPES(reshapeas) {
|
||||||
getOpDescriptor()
|
getOpDescriptor()
|
||||||
|
|
|
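The rewritten reshape_as takes its output shape and ordering directly from the second input and copies the first input's values into it, so its shape function simply duplicates the second input's shape info. The implicit contract is that both inputs cover the same number of elements; a minimal standalone check of that contract (illustrative helper, not the library API):

    #include <cstdint>
    #include <functional>
    #include <numeric>
    #include <vector>

    using Nd4jLong = int64_t;  // assumption: same integer type as above

    // reshape_as(x, y) is only well defined when x and y describe the same element count.
    static bool sameElementCount(const std::vector<Nd4jLong>& a, const std::vector<Nd4jLong>& b) {
        auto prod = [](const std::vector<Nd4jLong>& s) {
            return std::accumulate(s.begin(), s.end(), (Nd4jLong) 1, std::multiplies<Nd4jLong>());
        };
        return prod(a) == prod(b);
    }
    // e.g. sameElementCount({4, 3}, {2, 6}) -> true, so a {4,3} array can be viewed as {2,6}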
@ -25,7 +25,7 @@
|
||||||
|
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
CUSTOM_OP_IMPL(squeeze, 1, 1, true, 0, -2) {
|
CUSTOM_OP_IMPL(squeeze, 1, 1, false, 0, -2) {
|
||||||
auto input = INPUT_VARIABLE(0);
|
auto input = INPUT_VARIABLE(0);
|
||||||
auto output = OUTPUT_VARIABLE(0);
|
auto output = OUTPUT_VARIABLE(0);
|
||||||
|
|
||||||
|
@ -71,10 +71,14 @@ namespace nd4j {
|
||||||
}
|
}
|
||||||
|
|
||||||
if (block.isInplace()) {
|
if (block.isInplace()) {
|
||||||
output->reshapei(input->ordering(), shape);
|
output->reshapei(input->ordering(), shape, false);
|
||||||
} else {
|
} else {
|
||||||
auto tmp = input->reshape(input->ordering(), shape);
|
if (input->ews() == 1 && output->ews() == 1 && input->ordering() == output->ordering()) {
|
||||||
output->assign(tmp);
|
output->dataBuffer()->copyBufferFrom(*input->dataBuffer().get(), output->lengthOf() * DataTypeUtils::sizeOfElement(output->dataType()), 0, input->bufferOffset());
|
||||||
|
} else {
|
||||||
|
auto tmp = input->reshape(input->ordering(), shape);
|
||||||
|
output->assign(tmp);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
|
|
|
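The new else-branch in squeeze distinguishes two copy strategies: when both arrays are dense (ews() == 1) and share ordering, the raw buffer is copied in one pass, which is safe because dropping size-1 dimensions never changes the linear layout of a dense buffer; otherwise the op still goes through a reshaped view plus assign. A standalone sketch of the shape rule behind that shortcut (illustrative, not library code):

    #include <cstdint>
    #include <vector>

    using Nd4jLong = int64_t;  // assumption: same integer type as above

    // Dropping unit dimensions keeps the element count (and, for a dense buffer, the layout).
    static std::vector<Nd4jLong> squeezedShape(const std::vector<Nd4jLong>& shape) {
        std::vector<Nd4jLong> out;
        for (auto d : shape)
            if (d != 1)
                out.push_back(d);  // keep every non-unit dimension
        return out;                // may come back empty, i.e. scalar, when all dims were 1
    }
    // e.g. squeezedShape({1, 3, 1, 4}) -> {3, 4}: 12 elements before and after, same buffer order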
@ -25,7 +25,7 @@
|
||||||
|
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
CUSTOM_OP_IMPL(tile_to_shape, 1, 1, true, 0, -1) {
|
CUSTOM_OP_IMPL(tile_to_shape, 1, 1, false, 0, -1) {
|
||||||
|
|
||||||
auto input = INPUT_VARIABLE(0);
|
auto input = INPUT_VARIABLE(0);
|
||||||
auto output = OUTPUT_VARIABLE(0);
|
auto output = OUTPUT_VARIABLE(0);
|
||||||
|
|
|
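One change that repeats across reshape_as, squeeze, tile_to_shape and transpose in this diff is the fourth CUSTOM_OP_IMPL argument flipping from true to false; that position appears to be the op's allow-inplace flag, so these ops now always materialise a separate output instead of optionally mutating their input. A toy sketch of what such a flag typically gates (hypothetical helper, not the macro's actual expansion):

    #include <vector>

    // With the flag on, the op may mutate the caller's buffer directly; with the flag off,
    // the result always lands in a freshly allocated output buffer.
    static std::vector<float> doubleValues(std::vector<float>& input, bool allowInplace) {
        if (allowInplace) {
            for (auto& v : input) v *= 2.0f;   // mutate the caller's buffer in place
            return input;                      // returned copy only keeps the signature uniform
        }
        std::vector<float> output;
        output.reserve(input.size());
        for (float v : input)
            output.push_back(v * 2.0f);        // separate output, input left untouched
        return output;
    }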
@ -15,7 +15,8 @@
|
||||||
******************************************************************************/
|
******************************************************************************/
|
||||||
|
|
||||||
//
|
//
|
||||||
// Created by raver119 on 29/10/17.
|
// @author raver119@gmail.com
|
||||||
|
// @author Yurii Shyrma (iuriish@yahoo.com)
|
||||||
//
|
//
|
||||||
|
|
||||||
#include <op_boilerplate.h>
|
#include <op_boilerplate.h>
|
||||||
|
@ -25,113 +26,52 @@
|
||||||
#include <helpers/ShapeUtils.h>
|
#include <helpers/ShapeUtils.h>
|
||||||
|
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
CUSTOM_OP_IMPL(transpose, 1, 1, true, 0, 0) {
|
CUSTOM_OP_IMPL(transpose, 1, 1, false, 0, 0) {
|
||||||
auto x = INPUT_VARIABLE(0);
|
|
||||||
if (block.width() == 1) {
|
|
||||||
if (block.isInplace()) {
|
|
||||||
x->transposei();
|
|
||||||
STORE_RESULT(*x);
|
|
||||||
} else {
|
|
||||||
auto output = OUTPUT_VARIABLE(0);
|
|
||||||
auto t = x->transpose();
|
|
||||||
output->assign(t);
|
|
||||||
STORE_RESULT(*output);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// this is tf-mode transpose, that's nd4j permute
|
|
||||||
bool replace = false;
|
|
||||||
std::vector<int> arguments(*block.getIArguments());
|
|
||||||
|
|
||||||
auto w = block.width();
|
auto x = INPUT_VARIABLE(0);
|
||||||
auto a = arguments.size();
|
auto z = OUTPUT_VARIABLE(0);
|
||||||
|
|
||||||
if (w == 2 && a == 0) {
|
//Special case: transposing an empty array -> return empty
|
||||||
auto axis = INPUT_VARIABLE(1);
|
if (x->isEmpty()) {
|
||||||
for (int e = 0; e < axis->lengthOf(); e++) {
|
REQUIRE_TRUE(z->isEmpty(), 0, "TRANSPOSE OP: when input is empty, output must also be empty");
|
||||||
auto ax = axis->e<int>(e);
|
return Status::OK(); //No op
|
||||||
if (ax < 0)
|
}
|
||||||
ax += x->rankOf();
|
|
||||||
|
|
||||||
arguments.emplace_back(ax);
|
if (block.width() == 1 && block.getIArguments()->size() == 0) {
|
||||||
}
|
z->assign(x->transpose());
|
||||||
|
|
||||||
replace = true;
|
|
||||||
} else if (a == 0) {
|
|
||||||
for (int e = x->rankOf() - 1; e >= 0; e--)
|
|
||||||
arguments.emplace_back(e);
|
|
||||||
}
|
|
||||||
|
|
||||||
// 0D edge case
|
|
||||||
if (x->rankOf() == 0) {
|
|
||||||
REQUIRE_TRUE(arguments.size() == 1, 0, "Permute: only one axis is allowed for scalar");
|
|
||||||
auto output = OUTPUT_VARIABLE(0);
|
|
||||||
if (!block.isInplace())
|
|
||||||
output->assign(x);
|
|
||||||
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
if(block.isInplace()) { // in-place
|
|
||||||
x->permutei(arguments);
|
|
||||||
STORE_RESULT(x);
|
|
||||||
} else {
|
|
||||||
auto input = x->permute(arguments);
|
|
||||||
|
|
||||||
auto output = OUTPUT_VARIABLE(0);
|
|
||||||
output->assign(input);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
DECLARE_TYPES(transpose) {
|
std::vector<int> permutationVector = block.width() > 1 ? INPUT_VARIABLE(1)->asVectorT<int>() : *block.getIArguments();
|
||||||
getOpDescriptor()
|
|
||||||
->setAllowedInputTypes(nd4j::DataType::ANY)
|
|
||||||
->setSameMode(true);
|
|
||||||
}
|
|
||||||
|
|
||||||
DECLARE_SHAPE_FN(transpose) {
|
z->assign(x->permute(permutationVector));
|
||||||
if (block.width() == 1) {
|
|
||||||
auto outputShapeInfo = ShapeUtils::evalTranspShapeInfo(*INPUT_VARIABLE(0), block.workspace());
|
|
||||||
return SHAPELIST(outputShapeInfo);
|
|
||||||
} else {
|
|
||||||
// this is basically permute mode
|
|
||||||
auto shapeList = SHAPELIST();
|
|
||||||
auto arguments = block.getIArguments();
|
|
||||||
if (shape::rank(inputShape->at(0)) == 0) {
|
|
||||||
Nd4jLong *newshape;
|
|
||||||
ALLOCATE(newshape, block.getWorkspace(), shape::shapeInfoLength(inputShape->at(0)), Nd4jLong);
|
|
||||||
newshape[0] = 0;
|
|
||||||
newshape[1] = 0;
|
|
||||||
newshape[2] = 1;
|
|
||||||
newshape[3] = 99;
|
|
||||||
ArrayOptions::copyDataType(newshape, inputShape->at(0));
|
|
||||||
shapeList->push_back(newshape);
|
|
||||||
} else if (arguments->size() > 0 || inputShape->size() > 1) {
|
|
||||||
auto axis = arguments->size() > 0 ? *arguments : (INPUT_VARIABLE(1))->template asVectorT<int>();
|
|
||||||
auto outputShapeInfo = ShapeUtils::evalPermShapeInfo(axis.data(), axis.size(), *INPUT_VARIABLE(0), block.workspace());
|
|
||||||
shapeList->push_back(outputShapeInfo);
|
|
||||||
} else if (inputShape->size() == 2) {
|
|
||||||
// dead end
|
|
||||||
auto axis = INPUT_VARIABLE(1);
|
|
||||||
auto axisV = axis->template asVectorT<Nd4jLong>();
|
|
||||||
auto newshape = ShapeUtils::evalPermShapeInfo(axisV.data(), axisV.size(), *INPUT_VARIABLE(0), block.workspace());
|
|
||||||
shapeList->push_back(newshape);
|
|
||||||
} else {
|
|
||||||
int rank = shape::rank(inputShape->at(0));
|
|
||||||
for (int e = rank - 1; e >= 0; e--)
|
|
||||||
arguments->emplace_back(e);
|
|
||||||
|
|
||||||
auto outputShapeInfo = ShapeUtils::evalPermShapeInfo(arguments->data(), arguments->size(), *INPUT_VARIABLE(0), block.workspace());
|
return Status::OK();
|
||||||
shapeList->push_back(outputShapeInfo);
|
}
|
||||||
}
|
|
||||||
|
DECLARE_TYPES(transpose) {
|
||||||
|
getOpDescriptor()
|
||||||
|
->setAllowedInputTypes(nd4j::DataType::ANY)
|
||||||
|
->setSameMode(true);
|
||||||
|
}
|
||||||
|
|
||||||
|
DECLARE_SHAPE_FN(transpose) {
|
||||||
|
|
||||||
|
auto x = INPUT_VARIABLE(0);
|
||||||
|
|
||||||
|
if (block.width() == 1 && block.getIArguments()->size() == 0)
|
||||||
|
return SHAPELIST(ShapeUtils::evalTranspShapeInfo(*x, block.workspace(), true));
|
||||||
|
|
||||||
|
std::vector<int> permutationVector = block.width() > 1 ? INPUT_VARIABLE(1)->asVectorT<int>() : *block.getIArguments();
|
||||||
|
|
||||||
|
auto outputShapeInfo = ShapeUtils::evalPermShapeInfo(permutationVector.data(), x->rankOf(), *x, block.workspace(), true);
|
||||||
|
|
||||||
|
return SHAPELIST(outputShapeInfo);
|
||||||
|
}
|
||||||
|
|
||||||
return shapeList;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
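The rewritten transpose collapses the old dual-path implementation: with no axes supplied it falls back to a plain transpose (reverse axis order), otherwise it permutes according to axes taken either from the second input or from the integer arguments, and an empty input simply yields an empty output. A standalone sketch of how such a permutation is applied to a shape, including the reverse-order default and negative-axis normalisation (illustrative names, not the library code):

    #include <cstdint>
    #include <vector>

    using Nd4jLong = int64_t;  // assumption: same integer type as above

    // Apply a permutation to a shape; an empty permutation means "reverse all axes",
    // matching the plain-transpose default used when no axes are provided.
    static std::vector<Nd4jLong> permuteShape(const std::vector<Nd4jLong>& shape,
                                              std::vector<int> perm) {
        const int rank = (int) shape.size();
        if (perm.empty()) {
            perm.resize(rank);
            for (int i = 0; i < rank; ++i)
                perm[i] = rank - 1 - i;                         // default: reverse axis order
        }
        std::vector<Nd4jLong> out(rank);
        for (int i = 0; i < rank; ++i) {
            int axis = perm[i] < 0 ? perm[i] + rank : perm[i];  // normalise negative axes
            out[i] = shape[axis];
        }
        return out;
    }
    // e.g. permuteShape({2, 3, 4}, {}) -> {4, 3, 2}; permuteShape({2, 3, 4}, {0, 2, 1}) -> {2, 4, 3}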