commit e4ddf109c3

@@ -1,47 +0,0 @@
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.deeplearning4j.arbiter.util;

import org.slf4j.Logger;

import java.awt.*;
import java.net.URI;

/**
 * Various utilities for webpages and dealing with browsers
 */
public class WebUtils {

    public static void tryOpenBrowser(String path, Logger log) {
        try {
            WebUtils.openBrowser(new URI(path));
        } catch (Exception e) {
            log.error("Could not open browser", e);
            System.out.println("Browser could not be launched automatically.\nUI path: " + path);
        }
    }

    public static void openBrowser(URI uri) throws Exception {
        if (Desktop.isDesktopSupported()) {
            Desktop.getDesktop().browse(uri);
        } else {
            throw new UnsupportedOperationException(
                    "Cannot open browser on this platform: Desktop.isDesktopSupported() == false");
        }
    }

}
@@ -127,7 +127,7 @@ public class BraninFunction {
            BraninConfig candidate = (BraninConfig) c.getValue();

            double score = scoreFunction.score(candidate, null, (Map) null);
            System.out.println(candidate.getX1() + "\t" + candidate.getX2() + "\t" + score);
//            System.out.println(candidate.getX1() + "\t" + candidate.getX2() + "\t" + score);

            Thread.sleep(20);


@@ -54,7 +54,7 @@ public class TestRandomSearch extends BaseDL4JTest {
        runner.execute();


        System.out.println("----- Complete -----");
//        System.out.println("----- Complete -----");
    }


@@ -16,8 +16,8 @@

package org.deeplearning4j.arbiter.optimize.genetic;

import org.apache.commons.lang3.NotImplementedException;
import org.apache.commons.math3.random.RandomGenerator;
import sun.reflect.generics.reflectiveObjects.NotImplementedException;

public class TestRandomGenerator implements RandomGenerator {
    private final int[] intRandomNumbers;

@@ -63,17 +63,17 @@ public class TestRandomGenerator implements RandomGenerator {

    @Override
    public long nextLong() {
        throw new NotImplementedException();
        throw new NotImplementedException("Not implemented");
    }

    @Override
    public boolean nextBoolean() {
        throw new NotImplementedException();
        throw new NotImplementedException("Not implemented");
    }

    @Override
    public float nextFloat() {
        throw new NotImplementedException();
        throw new NotImplementedException("Not implemented");
    }

    @Override

@@ -83,6 +83,6 @@ public class TestRandomGenerator implements RandomGenerator {

    @Override
    public double nextGaussian() {
        throw new NotImplementedException();
        throw new NotImplementedException("Not implemented");
    }
}
@@ -16,6 +16,7 @@

package org.deeplearning4j.arbiter.optimize.genetic.crossover;

import org.apache.commons.lang3.NotImplementedException;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverResult;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.TwoParentsCrossoverOperator;

@@ -26,7 +27,6 @@ import org.deeplearning4j.arbiter.optimize.genetic.TestParentSelection;
import org.deeplearning4j.arbiter.optimize.genetic.TestPopulationInitializer;
import org.junit.Assert;
import org.junit.Test;
import sun.reflect.generics.reflectiveObjects.NotImplementedException;

public class TwoParentsCrossoverOperatorTests extends BaseDL4JTest {


@@ -42,7 +42,7 @@ public class TwoParentsCrossoverOperatorTests extends BaseDL4JTest {

        @Override
        public CrossoverResult crossover() {
            throw new NotImplementedException();
            throw new NotImplementedException("Not implemented");
        }
    }


@@ -16,6 +16,7 @@

package org.deeplearning4j.arbiter.optimize.genetic.culling;

import org.apache.commons.lang3.NotImplementedException;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome;
import org.deeplearning4j.arbiter.optimize.generator.genetic.culling.RatioCullOperator;

@@ -24,7 +25,6 @@ import org.deeplearning4j.arbiter.optimize.generator.genetic.population.Populati
import org.deeplearning4j.arbiter.optimize.genetic.TestPopulationInitializer;
import org.junit.Assert;
import org.junit.Test;
import sun.reflect.generics.reflectiveObjects.NotImplementedException;

import java.util.List;


@@ -46,7 +46,7 @@ public class RatioCullOperatorTests extends BaseDL4JTest {

        @Override
        public void cullPopulation() {
            throw new NotImplementedException();
            throw new NotImplementedException("Not implemented");
        }

        public double getCullRatio() {

@@ -16,6 +16,7 @@

package org.deeplearning4j.arbiter.optimize.genetic.selection;

import org.apache.commons.lang3.NotImplementedException;
import org.apache.commons.math3.random.RandomGenerator;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.optimize.generator.genetic.ChromosomeFactory;

@@ -33,7 +34,6 @@ import org.deeplearning4j.arbiter.optimize.genetic.TestPopulationInitializer;
import org.deeplearning4j.arbiter.optimize.genetic.TestRandomGenerator;
import org.junit.Assert;
import org.junit.Test;
import sun.reflect.generics.reflectiveObjects.NotImplementedException;

import static org.junit.Assert.assertArrayEquals;


@@ -55,7 +55,7 @@ public class GeneticSelectionOperatorTests extends BaseDL4JTest {

        @Override
        public void cullPopulation() {
            throw new NotImplementedException();
            throw new NotImplementedException("Not implemented");
        }

        @Override

@@ -16,6 +16,7 @@

package org.deeplearning4j.arbiter.optimize.genetic.selection;

import org.apache.commons.lang3.NotImplementedException;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.optimize.generator.genetic.ChromosomeFactory;
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationInitializer;

@@ -24,7 +25,6 @@ import org.deeplearning4j.arbiter.optimize.generator.genetic.selection.Selection
import org.deeplearning4j.arbiter.optimize.genetic.TestPopulationInitializer;
import org.junit.Assert;
import org.junit.Test;
import sun.reflect.generics.reflectiveObjects.NotImplementedException;

public class SelectionOperatorTests extends BaseDL4JTest {
    private class TestSelectionOperator extends SelectionOperator {

@@ -39,7 +39,7 @@ public class SelectionOperatorTests extends BaseDL4JTest {

        @Override
        public double[] buildNextGenes() {
            throw new NotImplementedException();
            throw new NotImplementedException("Not implemented");
        }
    }

@@ -158,7 +158,7 @@ public class TestComputationGraphSpace extends BaseDL4JTest {
            }
        }

        System.out.println("ReLU vs. Tanh: " + reluCount + "\t" + tanhCount);
//        System.out.println("ReLU vs. Tanh: " + reluCount + "\t" + tanhCount);
        assertTrue(reluCount > 0);
        assertTrue(tanhCount > 0);


@@ -162,7 +162,7 @@ public class TestGraphLocalExecution extends BaseDL4JTest {
            List<ResultReference> results = runner.getResults();
            assertTrue(results.size() > 0);

            System.out.println("----- COMPLETE - " + results.size() + " results -----");
//            System.out.println("----- COMPLETE - " + results.size() + " results -----");
        }
    }


@@ -165,7 +165,7 @@ public class TestGraphLocalExecutionGenetic extends BaseDL4JTest {
            List<ResultReference> results = runner.getResults();
            assertTrue(results.size() > 0);

            System.out.println("----- COMPLETE - " + results.size() + " results -----");
//            System.out.println("----- COMPLETE - " + results.size() + " results -----");
        }
    }


@@ -101,7 +101,7 @@ public class TestLayerSpace extends BaseDL4JTest {
            double l2 = TestUtils.getL2(l);
            IActivation activation = l.getActivationFn();

            System.out.println(lr + "\t" + l2 + "\t" + activation);
//            System.out.println(lr + "\t" + l2 + "\t" + activation);

            assertTrue(lr >= 0.3 && lr <= 0.4);
            assertTrue(l2 >= 0.01 && l2 <= 0.1);

@@ -190,7 +190,7 @@ public class TestLayerSpace extends BaseDL4JTest {
            ActivationLayer al = als.getValue(d);
            IActivation activation = al.getActivationFn();

            System.out.println(activation);
//            System.out.println(activation);

            assertTrue(containsActivationFunction(actFns, activation));
        }

@@ -228,7 +228,7 @@ public class TestLayerSpace extends BaseDL4JTest {
            IActivation activation = el.getActivationFn();
            long nOut = el.getNOut();

            System.out.println(activation + "\t" + nOut);
//            System.out.println(activation + "\t" + nOut);

            assertTrue(containsActivationFunction(actFns, activation));
            assertTrue(nOut >= 10 && nOut <= 20);

@@ -295,7 +295,7 @@ public class TestLayerSpace extends BaseDL4JTest {
            long nOut = el.getNOut();
            double forgetGate = el.getForgetGateBiasInit();

            System.out.println(activation + "\t" + nOut + "\t" + forgetGate);
//            System.out.println(activation + "\t" + nOut + "\t" + forgetGate);

            assertTrue(containsActivationFunction(actFns, activation));
            assertTrue(nOut >= 10 && nOut <= 20);


@@ -293,8 +293,8 @@ public class TestMultiLayerSpace extends BaseDL4JTest {
            assertTrue(nLayerCounts[i] >= 5); //Expect approx equal (50/3 each), but some variation randomly
        }

        System.out.println("Number of layers: " + Arrays.toString(nLayerCounts));
        System.out.println("ReLU vs. Tanh: " + reluCount + "\t" + tanhCount);
//        System.out.println("Number of layers: " + Arrays.toString(nLayerCounts));
//        System.out.println("ReLU vs. Tanh: " + reluCount + "\t" + tanhCount);

    }


@@ -98,7 +98,8 @@ public class ArbiterCLIRunnerTest extends BaseDL4JTest {
        assertEquals(configuration,OptimizationConfiguration.fromJson(configuration.toJson()));

        FileUtils.writeStringToFile(new File(configPath),configuration.toJson());
        System.out.println(configuration.toJson());
//        System.out.println(configuration.toJson());
        configuration.toJson();

        log.info("Starting test");
        cliRunner.runMain(

@@ -0,0 +1,390 @@
/*
 * Copyright (c) 2015-2019 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 */

package org.deeplearning4j.regressiontest;

import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.TestUtils;
import org.deeplearning4j.nn.conf.BackpropType;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.graph.LayerVertex;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.layers.convolutional.Cropping2D;
import org.deeplearning4j.nn.conf.layers.recurrent.Bidirectional;
import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn;
import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.graph.vertex.impl.MergeVertex;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInitXavier;
import org.deeplearning4j.regressiontest.customlayer100a.CustomLayer;
import org.junit.Test;
import org.nd4j.linalg.activations.impl.*;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.RmsProp;
import org.nd4j.linalg.learning.regularization.L2Regularization;
import org.nd4j.linalg.lossfunctions.impl.LossMAE;
import org.nd4j.linalg.lossfunctions.impl.LossMCXENT;
import org.nd4j.resources.Resources;

import java.io.DataInputStream;
import java.io.File;
import java.io.FileInputStream;

import static org.junit.Assert.*;

public class RegressionTest100b6 extends BaseDL4JTest {

    @Override
    public DataType getDataType() {
        return DataType.FLOAT;
    }

    @Test
    public void testCustomLayer() throws Exception {

        for (DataType dtype : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) {

            String dtypeName = dtype.toString().toLowerCase();

            File f = Resources.asFile("regression_testing/100b6/CustomLayerExample_100b6_" + dtypeName + ".bin");
            MultiLayerNetwork.load(f, true);

            MultiLayerNetwork net = MultiLayerNetwork.load(f, true);
            // net = net.clone();

            DenseLayer l0 = (DenseLayer) net.getLayer(0).conf().getLayer();
            assertEquals(new ActivationTanH(), l0.getActivationFn());
            assertEquals(new L2Regularization(0.03), TestUtils.getL2Reg(l0));
            assertEquals(new RmsProp(0.95), l0.getIUpdater());

            CustomLayer l1 = (CustomLayer) net.getLayer(1).conf().getLayer();
            assertEquals(new ActivationTanH(), l1.getActivationFn());
            assertEquals(new ActivationSigmoid(), l1.getSecondActivationFunction());
            assertEquals(new RmsProp(0.95), l1.getIUpdater());

            INDArray outExp;
            File f2 = Resources
                    .asFile("regression_testing/100b6/CustomLayerExample_Output_100b6_" + dtypeName + ".bin");
            try (DataInputStream dis = new DataInputStream(new FileInputStream(f2))) {
                outExp = Nd4j.read(dis);
            }

            INDArray in;
            File f3 = Resources.asFile("regression_testing/100b6/CustomLayerExample_Input_100b6_" + dtypeName + ".bin");
            try (DataInputStream dis = new DataInputStream(new FileInputStream(f3))) {
                in = Nd4j.read(dis);
            }

            assertEquals(dtype, in.dataType());
            assertEquals(dtype, outExp.dataType());
            assertEquals(dtype, net.params().dataType());
            assertEquals(dtype, net.getFlattenedGradients().dataType());
            assertEquals(dtype, net.getUpdater().getStateViewArray().dataType());

            //System.out.println(Arrays.toString(net.params().data().asFloat()));

            INDArray outAct = net.output(in);
            assertEquals(dtype, outAct.dataType());

            assertEquals(dtype, net.getLayerWiseConfigurations().getDataType());
            assertEquals(dtype, net.params().dataType());
            boolean eq = outExp.equalsWithEps(outAct, 0.01);
            assertTrue(outExp + " vs " + outAct, eq);
        }
    }


    @Test
    public void testLSTM() throws Exception {

        File f = Resources.asFile("regression_testing/100b6/GravesLSTMCharModelingExample_100b6.bin");
        MultiLayerNetwork net = MultiLayerNetwork.load(f, true);

        LSTM l0 = (LSTM) net.getLayer(0).conf().getLayer();
        assertEquals(new ActivationTanH(), l0.getActivationFn());
        assertEquals(200, l0.getNOut());
        assertEquals(new WeightInitXavier(), l0.getWeightInitFn());
        assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l0));
        assertEquals(new Adam(0.005), l0.getIUpdater());

        LSTM l1 = (LSTM) net.getLayer(1).conf().getLayer();
        assertEquals(new ActivationTanH(), l1.getActivationFn());
        assertEquals(200, l1.getNOut());
        assertEquals(new WeightInitXavier(), l1.getWeightInitFn());
        assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l1));
        assertEquals(new Adam(0.005), l1.getIUpdater());

        RnnOutputLayer l2 = (RnnOutputLayer) net.getLayer(2).conf().getLayer();
        assertEquals(new ActivationSoftmax(), l2.getActivationFn());
        assertEquals(77, l2.getNOut());
        assertEquals(new WeightInitXavier(), l2.getWeightInitFn());
        assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l2));
        assertEquals(new Adam(0.005), l2.getIUpdater());

        assertEquals(BackpropType.TruncatedBPTT, net.getLayerWiseConfigurations().getBackpropType());
        assertEquals(50, net.getLayerWiseConfigurations().getTbpttBackLength());
        assertEquals(50, net.getLayerWiseConfigurations().getTbpttFwdLength());

        INDArray outExp;
        File f2 = Resources.asFile("regression_testing/100b6/GravesLSTMCharModelingExample_Output_100b6.bin");
        try (DataInputStream dis = new DataInputStream(new FileInputStream(f2))) {
            outExp = Nd4j.read(dis);
        }

        INDArray in;
        File f3 = Resources.asFile("regression_testing/100b6/GravesLSTMCharModelingExample_Input_100b6.bin");
        try (DataInputStream dis = new DataInputStream(new FileInputStream(f3))) {
            in = Nd4j.read(dis);
        }

        INDArray outAct = net.output(in);

        assertEquals(outExp, outAct);
    }

    @Test
    public void testVae() throws Exception {

        File f = Resources.asFile("regression_testing/100b6/VaeMNISTAnomaly_100b6.bin");
        MultiLayerNetwork net = MultiLayerNetwork.load(f, true);

        VariationalAutoencoder l0 = (VariationalAutoencoder) net.getLayer(0).conf().getLayer();
        assertEquals(new ActivationLReLU(), l0.getActivationFn());
        assertEquals(32, l0.getNOut());
        assertArrayEquals(new int[]{256, 256}, l0.getEncoderLayerSizes());
        assertArrayEquals(new int[]{256, 256}, l0.getDecoderLayerSizes());
        assertEquals(new WeightInitXavier(), l0.getWeightInitFn());
        assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l0));
        assertEquals(new Adam(1e-3), l0.getIUpdater());

        INDArray outExp;
        File f2 = Resources.asFile("regression_testing/100b6/VaeMNISTAnomaly_Output_100b6.bin");
        try (DataInputStream dis = new DataInputStream(new FileInputStream(f2))) {
            outExp = Nd4j.read(dis);
        }

        INDArray in;
        File f3 = Resources.asFile("regression_testing/100b6/VaeMNISTAnomaly_Input_100b6.bin");
        try (DataInputStream dis = new DataInputStream(new FileInputStream(f3))) {
            in = Nd4j.read(dis);
        }

        INDArray outAct = net.output(in);

        assertEquals(outExp, outAct);
    }


    @Test
    public void testYoloHouseNumber() throws Exception {

        File f = Resources.asFile("regression_testing/100b6/HouseNumberDetection_100b6.bin");
        ComputationGraph net = ComputationGraph.load(f, true);

        int nBoxes = 5;
        int nClasses = 10;

        ConvolutionLayer cl = (ConvolutionLayer) ((LayerVertex) net.getConfiguration().getVertices()
                .get("convolution2d_9")).getLayerConf().getLayer();
        assertEquals(nBoxes * (5 + nClasses), cl.getNOut());
        assertEquals(new ActivationIdentity(), cl.getActivationFn());
        assertEquals(ConvolutionMode.Same, cl.getConvolutionMode());
        assertEquals(new WeightInitXavier(), cl.getWeightInitFn());
        assertArrayEquals(new int[]{1, 1}, cl.getKernelSize());

        INDArray outExp;
        File f2 = Resources.asFile("regression_testing/100b6/HouseNumberDetection_Output_100b6.bin");
        try (DataInputStream dis = new DataInputStream(new FileInputStream(f2))) {
            outExp = Nd4j.read(dis);
        }

        INDArray in;
        File f3 = Resources.asFile("regression_testing/100b6/HouseNumberDetection_Input_100b6.bin");
        try (DataInputStream dis = new DataInputStream(new FileInputStream(f3))) {
            in = Nd4j.read(dis);
        }

        INDArray outAct = net.outputSingle(in);

        boolean eq = outExp.equalsWithEps(outAct.castTo(outExp.dataType()), 1e-3);
        assertTrue(eq);
    }

    @Test
    public void testSyntheticCNN() throws Exception {

        File f = Resources.asFile("regression_testing/100b6/SyntheticCNN_100b6.bin");
        MultiLayerNetwork net = MultiLayerNetwork.load(f, true);

        ConvolutionLayer l0 = (ConvolutionLayer) net.getLayer(0).conf().getLayer();
        assertEquals(new ActivationReLU(), l0.getActivationFn());
        assertEquals(4, l0.getNOut());
        assertEquals(new WeightInitXavier(), l0.getWeightInitFn());
        assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l0));
        assertEquals(new Adam(0.005), l0.getIUpdater());
        assertArrayEquals(new int[]{3, 3}, l0.getKernelSize());
        assertArrayEquals(new int[]{2, 1}, l0.getStride());
        assertArrayEquals(new int[]{1, 1}, l0.getDilation());
        assertArrayEquals(new int[]{0, 0}, l0.getPadding());

        SeparableConvolution2D l1 = (SeparableConvolution2D) net.getLayer(1).conf().getLayer();
        assertEquals(new ActivationReLU(), l1.getActivationFn());
        assertEquals(8, l1.getNOut());
        assertEquals(new WeightInitXavier(), l1.getWeightInitFn());
        assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l1));
        assertEquals(new Adam(0.005), l1.getIUpdater());
        assertArrayEquals(new int[]{3, 3}, l1.getKernelSize());
        assertArrayEquals(new int[]{1, 1}, l1.getStride());
        assertArrayEquals(new int[]{1, 1}, l1.getDilation());
        assertArrayEquals(new int[]{0, 0}, l1.getPadding());
        assertEquals(ConvolutionMode.Same, l1.getConvolutionMode());
        assertEquals(1, l1.getDepthMultiplier());

        SubsamplingLayer l2 = (SubsamplingLayer) net.getLayer(2).conf().getLayer();
        assertArrayEquals(new int[]{3, 3}, l2.getKernelSize());
        assertArrayEquals(new int[]{2, 2}, l2.getStride());
        assertArrayEquals(new int[]{1, 1}, l2.getDilation());
        assertArrayEquals(new int[]{0, 0}, l2.getPadding());
        assertEquals(PoolingType.MAX, l2.getPoolingType());

        ZeroPaddingLayer l3 = (ZeroPaddingLayer) net.getLayer(3).conf().getLayer();
        assertArrayEquals(new int[]{4, 4, 4, 4}, l3.getPadding());

        Upsampling2D l4 = (Upsampling2D) net.getLayer(4).conf().getLayer();
        assertArrayEquals(new int[]{3, 3}, l4.getSize());

        DepthwiseConvolution2D l5 = (DepthwiseConvolution2D) net.getLayer(5).conf().getLayer();
        assertEquals(new ActivationReLU(), l5.getActivationFn());
        assertEquals(16, l5.getNOut());
        assertEquals(new WeightInitXavier(), l5.getWeightInitFn());
        assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l5));
        assertEquals(new Adam(0.005), l5.getIUpdater());
        assertArrayEquals(new int[]{3, 3}, l5.getKernelSize());
        assertArrayEquals(new int[]{1, 1}, l5.getStride());
        assertArrayEquals(new int[]{1, 1}, l5.getDilation());
        assertArrayEquals(new int[]{0, 0}, l5.getPadding());
        assertEquals(2, l5.getDepthMultiplier());

        SubsamplingLayer l6 = (SubsamplingLayer) net.getLayer(6).conf().getLayer();
        assertArrayEquals(new int[]{2, 2}, l6.getKernelSize());
        assertArrayEquals(new int[]{2, 2}, l6.getStride());
        assertArrayEquals(new int[]{1, 1}, l6.getDilation());
        assertArrayEquals(new int[]{0, 0}, l6.getPadding());
        assertEquals(PoolingType.MAX, l6.getPoolingType());

        Cropping2D l7 = (Cropping2D) net.getLayer(7).conf().getLayer();
        assertArrayEquals(new int[]{3, 3, 2, 2}, l7.getCropping());

        ConvolutionLayer l8 = (ConvolutionLayer) net.getLayer(8).conf().getLayer();
        assertEquals(4, l8.getNOut());
        assertEquals(new WeightInitXavier(), l8.getWeightInitFn());
        assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l8));
        assertEquals(new Adam(0.005), l8.getIUpdater());
        assertArrayEquals(new int[]{4, 4}, l8.getKernelSize());
        assertArrayEquals(new int[]{1, 1}, l8.getStride());
        assertArrayEquals(new int[]{1, 1}, l8.getDilation());
        assertArrayEquals(new int[]{0, 0}, l8.getPadding());

        CnnLossLayer l9 = (CnnLossLayer) net.getLayer(9).conf().getLayer();
        assertEquals(new WeightInitXavier(), l9.getWeightInitFn());
        assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l9));
        assertEquals(new Adam(0.005), l9.getIUpdater());
        assertEquals(new LossMAE(), l9.getLossFn());

        INDArray outExp;
        File f2 = Resources.asFile("regression_testing/100b6/SyntheticCNN_Output_100b6.bin");
        try (DataInputStream dis = new DataInputStream(new FileInputStream(f2))) {
            outExp = Nd4j.read(dis);
        }

        INDArray in;
        File f3 = Resources.asFile("regression_testing/100b6/SyntheticCNN_Input_100b6.bin");
        try (DataInputStream dis = new DataInputStream(new FileInputStream(f3))) {
            in = Nd4j.read(dis);
        }

        INDArray outAct = net.output(in);

        //19 layers - CPU vs. GPU difference accumulates notably, but appears to be correct
        if (Nd4j.getBackend().getClass().getName().toLowerCase().contains("native")) {
            assertEquals(outExp, outAct);
        } else {
            boolean eq = outExp.equalsWithEps(outAct, 0.1);
            assertTrue(eq);
        }
    }

    @Test
    public void testSyntheticBidirectionalRNNGraph() throws Exception {

        File f = Resources.asFile("regression_testing/100b6/SyntheticBidirectionalRNNGraph_100b6.bin");
        ComputationGraph net = ComputationGraph.load(f, true);

        Bidirectional l0 = (Bidirectional) net.getLayer("rnn1").conf().getLayer();

        LSTM l1 = (LSTM) l0.getFwd();
        assertEquals(16, l1.getNOut());
        assertEquals(new ActivationReLU(), l1.getActivationFn());
        assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l1));

        LSTM l2 = (LSTM) l0.getBwd();
        assertEquals(16, l2.getNOut());
        assertEquals(new ActivationReLU(), l2.getActivationFn());
        assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l2));

        Bidirectional l3 = (Bidirectional) net.getLayer("rnn2").conf().getLayer();

        SimpleRnn l4 = (SimpleRnn) l3.getFwd();
        assertEquals(16, l4.getNOut());
        assertEquals(new ActivationReLU(), l4.getActivationFn());
        assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l4));

        SimpleRnn l5 = (SimpleRnn) l3.getBwd();
        assertEquals(16, l5.getNOut());
        assertEquals(new ActivationReLU(), l5.getActivationFn());
        assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l5));

        MergeVertex mv = (MergeVertex) net.getVertex("concat");

        GlobalPoolingLayer gpl = (GlobalPoolingLayer) net.getLayer("pooling").conf().getLayer();
        assertEquals(PoolingType.MAX, gpl.getPoolingType());
        assertArrayEquals(new int[]{2}, gpl.getPoolingDimensions());
        assertTrue(gpl.isCollapseDimensions());

        OutputLayer outl = (OutputLayer) net.getLayer("out").conf().getLayer();
        assertEquals(3, outl.getNOut());
        assertEquals(new LossMCXENT(), outl.getLossFn());

        INDArray outExp;
        File f2 = Resources.asFile("regression_testing/100b6/SyntheticBidirectionalRNNGraph_Output_100b6.bin");
        try (DataInputStream dis = new DataInputStream(new FileInputStream(f2))) {
            outExp = Nd4j.read(dis);
        }

        INDArray in;
        File f3 = Resources.asFile("regression_testing/100b6/SyntheticBidirectionalRNNGraph_Input_100b6.bin");
        try (DataInputStream dis = new DataInputStream(new FileInputStream(f3))) {
            in = Nd4j.read(dis);
        }

        INDArray outAct = net.output(in)[0];

        assertEquals(outExp, outAct);
    }
}
@@ -41,7 +41,7 @@ public class TestGraphLoading extends BaseDL4JTest {

        IGraph<String, String> graph = GraphLoader
                .loadUndirectedGraphEdgeListFile(cpr.getTempFileFromArchive().getAbsolutePath(), 7, ",");
        System.out.println(graph);
//        System.out.println(graph);

        assertEquals(graph.numVertices(), 7);
        int[][] edges = {{1, 2}, {0, 2, 4}, {0, 1, 3, 4}, {2, 4, 5}, {1, 2, 3, 5, 6}, {3, 4, 6}, {4, 5}};

@@ -66,7 +66,7 @@ public class TestGraphLoading extends BaseDL4JTest {
                        edgeLineProcessor, vertexFactory, 10, false);


        System.out.println(graph);
//        System.out.println(graph);

        for (int i = 0; i < 10; i++) {
            List<Edge<String>> edges = graph.getEdgesOut(i);

@@ -111,7 +111,7 @@ public class TestGraphLoading extends BaseDL4JTest {
        Graph<String, String> graph = GraphLoader.loadGraph(verticesCPR.getTempFileFromArchive().getAbsolutePath(),
                        edgesCPR.getTempFileFromArchive().getAbsolutePath(), vertexLoader, edgeLineProcessor, false);

        System.out.println(graph);
//        System.out.println(graph);

        for (int i = 0; i < 10; i++) {
            List<Edge<String>> edges = graph.getEdgesOut(i);


@@ -71,7 +71,7 @@ public class TestGraphLoadingWeighted extends BaseDL4JTest {
            }
        }

        System.out.println(graph);
//        System.out.println(graph);
    }


@@ -220,7 +220,7 @@ public class TestGraph extends BaseDL4JTest {
                sum += transitionProb[i][j];
            for (int j = 0; j < transitionProb[i].length; j++)
                transitionProb[i][j] /= sum;
            System.out.println(Arrays.toString(transitionProb[i]));
//            System.out.println(Arrays.toString(transitionProb[i]));
        }

        //Check that transition probs are essentially correct (within bounds of random variation)


@@ -145,8 +145,8 @@ public class DeepWalkGradientCheck extends BaseDL4JTest {

                if (relError > MAX_REL_ERROR && absErr > MIN_ABS_ERROR)
                    fail(msg);
                else
                    System.out.println(msg);
//                else
//                    System.out.println(msg);
            }
        }


@@ -333,10 +333,10 @@ public class DeepWalkGradientCheck extends BaseDL4JTest {

                if (relError > MAX_REL_ERROR && absErr > MIN_ABS_ERROR)
                    fail(msg);
                else
                    System.out.println(msg);
//                else
//                    System.out.println(msg);
            }
            System.out.println();
//            System.out.println();
        }

    }


@@ -67,7 +67,7 @@ public class TestDeepWalk extends BaseDL4JTest {
        for (int i = 0; i < 7; i++) {
            INDArray vector = deepWalk.getVertexVector(i);
            assertArrayEquals(new long[] {vectorSize}, vector.shape());
            System.out.println(Arrays.toString(vector.dup().data().asFloat()));
//            System.out.println(Arrays.toString(vector.dup().data().asFloat()));
        }

        GraphWalkIterator<String> iter = new RandomWalkIterator<>(graph, 8);

@@ -77,11 +77,11 @@ public class TestDeepWalk extends BaseDL4JTest {
        for (int t = 0; t < 5; t++) {
            iter.reset();
            deepWalk.fit(iter);
            System.out.println("--------------------");
//            System.out.println("--------------------");
            for (int i = 0; i < 7; i++) {
                INDArray vector = deepWalk.getVertexVector(i);
                assertArrayEquals(new long[] {vectorSize}, vector.shape());
                System.out.println(Arrays.toString(vector.dup().data().asFloat()));
//                System.out.println(Arrays.toString(vector.dup().data().asFloat()));
            }
        }
    }

@@ -160,7 +160,7 @@ public class TestDeepWalk extends BaseDL4JTest {
                continue;

            double sim = deepWalk.similarity(i, nearestTo);
            System.out.println(i + "\t" + nearestTo + "\t" + sim);
//            System.out.println(i + "\t" + nearestTo + "\t" + sim);
            assertTrue(sim <= minSimNearest);
        }
    }

@@ -211,7 +211,7 @@ public class TestDeepWalk extends BaseDL4JTest {
        Graph<String, String> graph = GraphLoader
                        .loadUndirectedGraphEdgeListFile(cpr.getTempFileFromArchive().getAbsolutePath(), 13, ",");

        System.out.println(graph);
//        System.out.println(graph);

        Nd4j.getRandom().setSeed(12345);


@@ -229,11 +229,13 @@ public class TestDeepWalk extends BaseDL4JTest {

        //Calculate similarity(0,i)
        for (int i = 0; i < nVertices; i++) {
            System.out.println(deepWalk.similarity(0, i));
//            System.out.println(deepWalk.similarity(0, i));
            deepWalk.similarity(0, i);
        }

        for (int i = 0; i < nVertices; i++)
            System.out.println(deepWalk.getVertexVector(i));
//            System.out.println(deepWalk.getVertexVector(i));
            deepWalk.getVertexVector(i);
    }

    @Test(timeout = 60000L)

@@ -38,9 +38,11 @@ public class TestGraphHuffman extends BaseDL4JTest {

        gh.buildTree(vertexDegrees);

        for (int i = 0; i < 7; i++)
            System.out.println(i + "\t" + gh.getCodeLength(i) + "\t" + gh.getCodeString(i) + "\t\t" + gh.getCode(i)
                            + "\t\t" + Arrays.toString(gh.getPathInnerNodes(i)));
        for (int i = 0; i < 7; i++) {
            String s = i + "\t" + gh.getCodeLength(i) + "\t" + gh.getCodeString(i) + "\t\t" + gh.getCode(i)
                            + "\t\t" + Arrays.toString(gh.getPathInnerNodes(i));
//            System.out.println(s);
        }

        int[] expectedLengths = {3, 2, 2, 5, 4, 2, 5};
        for (int i = 0; i < vertexDegrees.length; i++) {


@@ -3,6 +3,7 @@ package org.deeplearning4j.util;
import lombok.NonNull;
import org.apache.commons.io.IOUtils;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;

@@ -121,7 +122,7 @@ public class DL4JModelValidator {
        }

        try {
            MultiLayerConfiguration.fromJson(config);
            ComputationGraphConfiguration.fromJson(config);
        } catch (Throwable t) {
            return ValidationResult.builder()
                    .formatType("ComputationGraph")


@@ -79,8 +79,9 @@ public class ParameterServerParallelWrapperTest extends BaseDL4JTest {
        model.init();

        ParallelWrapper parameterServerParallelWrapper =
                        new ParallelWrapper.Builder(model).trainerFactory(new ParameterServerTrainerContext())
                                        .workers(Runtime.getRuntime().availableProcessors())
                        new ParallelWrapper.Builder(model)
                                        .workers(Math.min(4, Runtime.getRuntime().availableProcessors()))
                                        .trainerFactory(new ParameterServerTrainerContext())
                                        .reportScoreAfterAveraging(true).prefetchBuffer(3).build();
        parameterServerParallelWrapper.fit(mnistTrain);


@@ -104,7 +104,7 @@ public class SparkWord2VecTest extends BaseDL4JTest {
            public void call(ExportContainer<VocabWord> v) throws Exception {
                assertNotNull(v.getElement());
                assertNotNull(v.getArray());
                System.out.println(v.getElement() + " - " + v.getArray());
//                System.out.println(v.getElement() + " - " + v.getArray());
            }
        }
    }

@@ -66,7 +66,7 @@ public class TestEarlyStoppingSpark extends BaseSparkTest {
                        .lossFunction(LossFunctions.LossFunction.MCXENT).build())
                        .build();
        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.setListeners(new ScoreIterationListener(1));
        net.setListeners(new ScoreIterationListener(5));


        JavaRDD<DataSet> irisData = getIris();

@@ -119,7 +119,7 @@ public class TestEarlyStoppingSpark extends BaseSparkTest {
                        .lossFunction(LossFunctions.LossFunction.MSE).build())
                        .build();
        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.setListeners(new ScoreIterationListener(1));
        net.setListeners(new ScoreIterationListener(5));

        JavaRDD<DataSet> irisData = getIris();
        EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();

@@ -155,7 +155,7 @@ public class TestEarlyStoppingSpark extends BaseSparkTest {
                        .lossFunction(LossFunctions.LossFunction.MCXENT).build())
                        .build();
        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.setListeners(new ScoreIterationListener(1));
        net.setListeners(new ScoreIterationListener(5));

        JavaRDD<DataSet> irisData = getIris();


@@ -198,7 +198,7 @@ public class TestEarlyStoppingSpark extends BaseSparkTest {
                        .lossFunction(LossFunctions.LossFunction.MCXENT).build())
                        .build();
        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.setListeners(new ScoreIterationListener(1));
        net.setListeners(new ScoreIterationListener(5));

        JavaRDD<DataSet> irisData = getIris();


@@ -231,7 +231,7 @@ public class TestEarlyStoppingSpark extends BaseSparkTest {
                        .lossFunction(LossFunctions.LossFunction.MCXENT).build())
                        .build();
        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.setListeners(new ScoreIterationListener(1));
        net.setListeners(new ScoreIterationListener(5));


        JavaRDD<DataSet> irisData = getIris();


@@ -69,7 +69,7 @@ public class TestEarlyStoppingSparkCompGraph extends BaseSparkTest {
                        .lossFunction(LossFunctions.LossFunction.MCXENT).build(), "in")
                        .setOutputs("0").build();
        ComputationGraph net = new ComputationGraph(conf);
        net.setListeners(new ScoreIterationListener(1));
        net.setListeners(new ScoreIterationListener(5));


        JavaRDD<DataSet> irisData = getIris();

@@ -120,7 +120,7 @@ public class TestEarlyStoppingSparkCompGraph extends BaseSparkTest {
                        .lossFunction(LossFunctions.LossFunction.MSE).build(), "in")
                        .setOutputs("0").build();
        ComputationGraph net = new ComputationGraph(conf);
        net.setListeners(new ScoreIterationListener(1));
        net.setListeners(new ScoreIterationListener(5));

        JavaRDD<DataSet> irisData = getIris();
        EarlyStoppingModelSaver<ComputationGraph> saver = new InMemoryModelSaver<>();

@@ -158,7 +158,7 @@ public class TestEarlyStoppingSparkCompGraph extends BaseSparkTest {
                        .lossFunction(LossFunctions.LossFunction.MCXENT).build(), "in")
                        .setOutputs("0").build();
        ComputationGraph net = new ComputationGraph(conf);
        net.setListeners(new ScoreIterationListener(1));
        net.setListeners(new ScoreIterationListener(5));

        JavaRDD<DataSet> irisData = getIris();


@@ -203,7 +203,7 @@ public class TestEarlyStoppingSparkCompGraph extends BaseSparkTest {
                        .lossFunction(LossFunctions.LossFunction.MCXENT).build(), "in")
                        .setOutputs("0").build();
        ComputationGraph net = new ComputationGraph(conf);
        net.setListeners(new ScoreIterationListener(1));
        net.setListeners(new ScoreIterationListener(5));

        JavaRDD<DataSet> irisData = getIris();


@@ -238,7 +238,7 @@ public class TestEarlyStoppingSparkCompGraph extends BaseSparkTest {
                        .lossFunction(LossFunctions.LossFunction.MCXENT).build(), "in")
                        .setOutputs("0").build();
        ComputationGraph net = new ComputationGraph(conf);
        net.setListeners(new ScoreIterationListener(1));
        net.setListeners(new ScoreIterationListener(5));


        JavaRDD<DataSet> irisData = getIris();
@@ -59,7 +59,7 @@ public class TestShuffleExamples extends BaseSparkTest {
        int totalExampleCount = 0;
        for (DataSet ds : shuffledList) {
            totalExampleCount += ds.getFeatures().length();
            System.out.println(Arrays.toString(ds.getFeatures().data().asFloat()));
//            System.out.println(Arrays.toString(ds.getFeatures().data().asFloat()));

            assertEquals(ds.getFeatures(), ds.getLabels());
        }


@@ -86,7 +86,7 @@ public class TestExport extends BaseSparkTest {
        for (File file : files) {
            if (!file.getPath().endsWith(".bin"))
                continue;
            System.out.println(file);
//            System.out.println(file);
            DataSet ds = new DataSet();
            ds.load(file);
            assertEquals(minibatchSize, ds.numExamples());

@@ -144,7 +144,7 @@ public class TestExport extends BaseSparkTest {
        for (File file : files) {
            if (!file.getPath().endsWith(".bin"))
                continue;
            System.out.println(file);
//            System.out.println(file);
            MultiDataSet ds = new org.nd4j.linalg.dataset.MultiDataSet();
            ds.load(file);
            assertEquals(minibatchSize, ds.getFeatures(0).size(0));


@@ -92,9 +92,9 @@ public class HashingBalancedPartitionerTest extends BaseSparkTest {

        int[][] colorCountsByPartition = new int[3][2];
        for (final Tuple2<Tuple2<Long, Integer>, String> val : testList) {
            System.out.println(val);
//            System.out.println(val);
            Integer partition = hbp.getPartition(val._1());
            System.out.println(partition);
//            System.out.println(partition);

            if (val._2().equals("red"))
                colorCountsByPartition[partition][0] += 1;

@@ -102,9 +102,9 @@ public class HashingBalancedPartitionerTest extends BaseSparkTest {
                colorCountsByPartition[partition][1] += 1;
        }

        for (int i = 0; i < 3; i++) {
            System.out.println(Arrays.toString(colorCountsByPartition[i]));
        }
//        for (int i = 0; i < 3; i++) {
//            System.out.println(Arrays.toString(colorCountsByPartition[i]));
//        }
        for (int i = 0; i < 3; i++) {
            // avg red per partition : 2.33
            assertTrue(colorCountsByPartition[i][0] >= 1 && colorCountsByPartition[i][0] < 4);

@@ -178,12 +178,12 @@ public class HashingBalancedPartitionerTest extends BaseSparkTest {
                colorCountsByPartition[partition][1] += 1;
        }

        for (int i = 0; i < numPartitions; i++) {
            System.out.println(Arrays.toString(colorCountsByPartition[i]));
        }

        System.out.println("Ideal red # per partition: " + avgRed);
        System.out.println("Ideal blue # per partition: " + avgBlue);
//        for (int i = 0; i < numPartitions; i++) {
//            System.out.println(Arrays.toString(colorCountsByPartition[i]));
//        }
//
//        System.out.println("Ideal red # per partition: " + avgRed);
//        System.out.println("Ideal blue # per partition: " + avgBlue);

        for (int i = 0; i < numPartitions; i++) {
            // avg red per partition : 2.33

@@ -115,7 +115,7 @@ public class TestSparkComputationGraph extends BaseSparkTest {
        TrainingMaster tm = new ParameterAveragingTrainingMaster(true, numExecutors(), 1, 10, 1, 0);

        SparkComputationGraph scg = new SparkComputationGraph(sc, cg, tm);
        scg.setListeners(Collections.singleton((TrainingListener) new ScoreIterationListener(1)));
        scg.setListeners(Collections.singleton((TrainingListener) new ScoreIterationListener(5)));

        JavaRDD<MultiDataSet> rdd = sc.parallelize(list);
        scg.fitMultiDataSet(rdd);


@@ -31,8 +31,11 @@ import org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster;
import org.junit.Test;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions;


@@ -45,8 +48,24 @@ import static org.junit.Assert.assertTrue;
@Slf4j
public class TestSparkDl4jMultiLayer extends BaseSparkTest {

    @Test(timeout = 120000L)
    @Override
    public long getTimeoutMilliseconds() {
        return 120000L;
    }

    @Override
    public DataType getDataType() {
        return DataType.FLOAT;
    }

    @Override
    public DataType getDefaultFPDataType() {
        return DataType.FLOAT;
    }

    @Test
    public void testEvaluationSimple() throws Exception {
        Nd4j.getRandom().setSeed(12345);

        for( int evalWorkers : new int[]{1, 4, 8}) {
            //Simple test to validate DL4J issue 4099 is fixed...

@@ -75,18 +94,18 @@ public class TestSparkDl4jMultiLayer extends BaseSparkTest {
        //----------------------------------
        //Create network configuration and conduct network training
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .dataType(DataType.FLOAT)
                .seed(12345)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .activation(Activation.LEAKYRELU)
                .weightInit(WeightInit.XAVIER)
                .updater(new Nesterovs(0.02, 0.9))
                .l2(1e-4)
                .updater(new Adam(1e-3))
                .l2(1e-5)
                .list()
                .layer(0, new DenseLayer.Builder().nIn(28 * 28).nOut(500).build())
                .layer(1, new DenseLayer.Builder().nIn(500).nOut(100).build())
                .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                        .activation(Activation.SOFTMAX).nIn(100).nOut(10).build())

                .build();

        //Configuration for Spark training: see https://deeplearning4j.org/docs/latest/deeplearning4j-scaleout-howto for explanation of these configuration options

@@ -333,15 +333,16 @@ public class TestCompareParameterAveragingSparkVsSingleMachine {
                sparkNet.fit(rdd);
            }

            System.out.println(sparkNet.getSparkTrainingStats().statsAsString());
//            System.out.println(sparkNet.getSparkTrainingStats().statsAsString());
            sparkNet.getSparkTrainingStats().statsAsString();

            INDArray finalSparkParams = sparkNet.getNetwork().params().dup();

            System.out.println("Initial (Local) params: " + Arrays.toString(initialParams.data().asFloat()));
            System.out.println("Initial (Spark) params: "
                            + Arrays.toString(initialSparkParams.data().asFloat()));
            System.out.println("Final (Local) params: " + Arrays.toString(finalParams.data().asFloat()));
            System.out.println("Final (Spark) params: " + Arrays.toString(finalSparkParams.data().asFloat()));
//            System.out.println("Initial (Local) params: " + Arrays.toString(initialParams.data().asFloat()));
//            System.out.println("Initial (Spark) params: "
//                            + Arrays.toString(initialSparkParams.data().asFloat()));
//            System.out.println("Final (Local) params: " + Arrays.toString(finalParams.data().asFloat()));
//            System.out.println("Final (Spark) params: " + Arrays.toString(finalSparkParams.data().asFloat()));
            assertEquals(initialParams, initialSparkParams);
            assertNotEquals(initialParams, finalParams);
            assertEquals(finalParams, finalSparkParams);

@@ -405,15 +406,16 @@ public class TestCompareParameterAveragingSparkVsSingleMachine {
                sparkNet.fit(rdd);
            }

            System.out.println(sparkNet.getSparkTrainingStats().statsAsString());
//            System.out.println(sparkNet.getSparkTrainingStats().statsAsString());
            sparkNet.getSparkTrainingStats().statsAsString();

            INDArray finalSparkParams = sparkNet.getNetwork().params().dup();

            System.out.println("Initial (Local) params: " + Arrays.toString(initialParams.data().asFloat()));
            System.out.println("Initial (Spark) params: "
                            + Arrays.toString(initialSparkParams.data().asFloat()));
            System.out.println("Final (Local) params: " + Arrays.toString(finalParams.data().asFloat()));
            System.out.println("Final (Spark) params: " + Arrays.toString(finalSparkParams.data().asFloat()));
//            System.out.println("Initial (Local) params: " + Arrays.toString(initialParams.data().asFloat()));
//            System.out.println("Initial (Spark) params: "
//                            + Arrays.toString(initialSparkParams.data().asFloat()));
//            System.out.println("Final (Local) params: " + Arrays.toString(finalParams.data().asFloat()));
//            System.out.println("Final (Spark) params: " + Arrays.toString(finalSparkParams.data().asFloat()));
            assertArrayEquals(initialParams.data().asFloat(), initialSparkParams.data().asFloat(), 1e-8f);
            assertArrayEquals(finalParams.data().asFloat(), finalSparkParams.data().asFloat(), 1e-6f);


@@ -478,18 +480,19 @@ public class TestCompareParameterAveragingSparkVsSingleMachine {
                sparkNet.fit(rdd);
            }

            System.out.println(sparkNet.getSparkTrainingStats().statsAsString());
//            System.out.println(sparkNet.getSparkTrainingStats().statsAsString());
            sparkNet.getSparkTrainingStats().statsAsString();

            INDArray finalSparkParams = sparkNet.getNetwork().params().dup();
            // executioner.addToWatchdog(finalSparkParams, "finalSparkParams");

            float[] fp = finalParams.data().asFloat();
            float[] fps = finalSparkParams.data().asFloat();
            System.out.println("Initial (Local) params: " + Arrays.toString(initialParams.data().asFloat()));
            System.out.println("Initial (Spark) params: "
                            + Arrays.toString(initialSparkParams.data().asFloat()));
            System.out.println("Final (Local) params: " + Arrays.toString(fp));
            System.out.println("Final (Spark) params: " + Arrays.toString(fps));
//            System.out.println("Initial (Local) params: " + Arrays.toString(initialParams.data().asFloat()));
//            System.out.println("Initial (Spark) params: "
//                            + Arrays.toString(initialSparkParams.data().asFloat()));
//            System.out.println("Final (Local) params: " + Arrays.toString(fp));
//            System.out.println("Final (Spark) params: " + Arrays.toString(fps));

            assertEquals(initialParams, initialSparkParams);
            assertNotEquals(initialParams, finalParams);

@@ -551,14 +554,15 @@ public class TestCompareParameterAveragingSparkVsSingleMachine {
                sparkNet.fit(rdd);
            }

            System.out.println(sparkNet.getSparkTrainingStats().statsAsString());
//            System.out.println(sparkNet.getSparkTrainingStats().statsAsString());
            sparkNet.getSparkTrainingStats().statsAsString();

            INDArray finalSparkParams = sparkNet.getNetwork().params().dup();

            System.out.println("Initial (Local) params: " + Arrays.toString(initialParams.data().asFloat()));
            System.out.println("Initial (Spark) params: " + Arrays.toString(initialSparkParams.data().asFloat()));
            System.out.println("Final (Local) params: " + Arrays.toString(finalParams.data().asFloat()));
            System.out.println("Final (Spark) params: " + Arrays.toString(finalSparkParams.data().asFloat()));
//            System.out.println("Initial (Local) params: " + Arrays.toString(initialParams.data().asFloat()));
//            System.out.println("Initial (Spark) params: " + Arrays.toString(initialSparkParams.data().asFloat()));
//            System.out.println("Final (Local) params: " + Arrays.toString(finalParams.data().asFloat()));
//            System.out.println("Final (Spark) params: " + Arrays.toString(finalSparkParams.data().asFloat()));
            assertArrayEquals(initialParams.data().asFloat(), initialSparkParams.data().asFloat(), 1e-8f);
            assertArrayEquals(finalParams.data().asFloat(), finalSparkParams.data().asFloat(), 1e-6f);

@ -37,7 +37,7 @@ public class TestJsonYaml {
|
|||
String json = tm.toJson();
|
||||
String yaml = tm.toYaml();
|
||||
|
||||
System.out.println(json);
|
||||
// System.out.println(json);
|
||||
|
||||
TrainingMaster fromJson = ParameterAveragingTrainingMaster.fromJson(json);
|
||||
TrainingMaster fromYaml = ParameterAveragingTrainingMaster.fromYaml(yaml);
|
||||
|
|
|
@ -389,7 +389,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
|
|||
List<EventStats> workerFitStats = stats.getValue("ParameterAveragingWorkerFitTimesMs");
|
||||
for (EventStats e : workerFitStats) {
|
||||
ExampleCountEventStats eces = (ExampleCountEventStats) e;
|
||||
System.out.println(eces.getTotalExampleCount());
|
||||
// System.out.println(eces.getTotalExampleCount());
|
||||
}
|
||||
|
||||
for (EventStats e : workerFitStats) {
|
||||
|
@ -457,7 +457,8 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
|
|||
assertNotEquals(paramsBefore, paramsAfter);
|
||||
|
||||
SparkTrainingStats stats = sparkNet.getSparkTrainingStats();
|
||||
System.out.println(stats.statsAsString());
|
||||
// System.out.println(stats.statsAsString());
|
||||
stats.statsAsString();
|
||||
|
||||
sparkNet.getTrainingMaster().deleteTempFiles(sc);
|
||||
}
|
||||
|
@ -483,7 +484,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
|
|||
i++;
|
||||
}
|
||||
|
||||
System.out.println("Saved to: " + tempDirF.getAbsolutePath());
|
||||
// System.out.println("Saved to: " + tempDirF.getAbsolutePath());
|
||||
|
||||
|
||||
|
||||
|
@ -527,7 +528,8 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
|
|||
SparkTrainingStats stats = sparkNet.getSparkTrainingStats();
|
||||
|
||||
//Expect
|
||||
System.out.println(stats.statsAsString());
|
||||
// System.out.println(stats.statsAsString());
|
||||
stats.statsAsString();
|
||||
assertEquals(numSplits, stats.getValue("ParameterAveragingMasterRepartitionTimesMs").size());
|
||||
|
||||
List<EventStats> list = stats.getValue("ParameterAveragingWorkerFitTimesMs");
|
||||
|
@ -566,8 +568,8 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
|
|||
i++;
|
||||
}
|
||||
|
||||
System.out.println("Saved to: " + tempDirF.getAbsolutePath());
|
||||
System.out.println("Saved to: " + tempDirF2.getAbsolutePath());
|
||||
// System.out.println("Saved to: " + tempDirF.getAbsolutePath());
|
||||
// System.out.println("Saved to: " + tempDirF2.getAbsolutePath());
|
||||
|
||||
|
||||
|
||||
|
@ -610,7 +612,8 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
|
|||
assertNotEquals(paramsBefore, paramsAfter);
|
||||
|
||||
SparkTrainingStats stats = sparkNet.getSparkTrainingStats();
|
||||
System.out.println(stats.statsAsString());
|
||||
// System.out.println(stats.statsAsString());
|
||||
stats.statsAsString();
|
||||
|
||||
//Same thing, buf for MultiDataSet objects:
|
||||
config = new Configuration();
|
||||
|
@ -631,7 +634,8 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
|
|||
assertNotEquals(paramsBefore, paramsAfter);
|
||||
|
||||
stats = sparkNet.getSparkTrainingStats();
|
||||
System.out.println(stats.statsAsString());
|
||||
// System.out.println(stats.statsAsString());
|
||||
stats.statsAsString();
|
||||
}
@ -730,13 +734,13 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
|
|||
.build();
|
||||
|
||||
for (int avgFreq : new int[] {1, 5, 10}) {
|
||||
System.out.println("--- Avg freq " + avgFreq + " ---");
|
||||
// System.out.println("--- Avg freq " + avgFreq + " ---");
|
||||
SparkDl4jMultiLayer sparkNet = new SparkDl4jMultiLayer(sc, conf.clone(),
|
||||
new ParameterAveragingTrainingMaster.Builder(numExecutors(), dataSetObjSize)
|
||||
.batchSizePerWorker(batchSizePerExecutor).averagingFrequency(avgFreq)
|
||||
.repartionData(Repartition.Always).build());
|
||||
|
||||
sparkNet.setListeners(new ScoreIterationListener(1));
|
||||
sparkNet.setListeners(new ScoreIterationListener(5));
|
||||
@ -778,13 +782,13 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
|
|||
.setOutputs("1").build();
|
||||
|
||||
for (int avgFreq : new int[] {1, 5, 10}) {
|
||||
System.out.println("--- Avg freq " + avgFreq + " ---");
|
||||
// System.out.println("--- Avg freq " + avgFreq + " ---");
|
||||
SparkComputationGraph sparkNet = new SparkComputationGraph(sc, conf.clone(),
|
||||
new ParameterAveragingTrainingMaster.Builder(numExecutors(), dataSetObjSize)
|
||||
.batchSizePerWorker(batchSizePerExecutor).averagingFrequency(avgFreq)
|
||||
.repartionData(Repartition.Always).build());
|
||||
|
||||
sparkNet.setListeners(new ScoreIterationListener(1));
|
||||
sparkNet.setListeners(new ScoreIterationListener(5));
|
||||
|
||||
JavaRDD<DataSet> rdd = sc.parallelize(list);
|
||||
@ -107,7 +107,7 @@ public class TestTrainingStatsCollection extends BaseSparkTest {
|
|||
expectedStatNames.addAll(c);
|
||||
}
|
||||
|
||||
System.out.println(expectedStatNames);
|
||||
// System.out.println(expectedStatNames);
|
||||
|
||||
|
||||
SparkTrainingStats stats = sparkNet.getSparkTrainingStats();
|
||||
|
@ -119,7 +119,7 @@ public class TestTrainingStatsCollection extends BaseSparkTest {
|
|||
}
|
||||
|
||||
String statsAsString = stats.statsAsString();
|
||||
System.out.println(statsAsString);
|
||||
// System.out.println(statsAsString);
|
||||
assertEquals(actualKeySet.size(), statsAsString.split("\n").length); //One line per stat
|
||||
@ -35,7 +35,7 @@ public class TestTimeSource {
|
|||
long systemTime = System.currentTimeMillis();
|
||||
long ntpTime = timeSource.currentTimeMillis();
|
||||
long offset = ntpTime - systemTime;
|
||||
System.out.println("System: " + systemTime + "\tNTPTimeSource: " + ntpTime + "\tOffset: " + offset);
|
||||
// System.out.println("System: " + systemTime + "\tNTPTimeSource: " + ntpTime + "\tOffset: " + offset);
|
||||
Thread.sleep(500);
|
||||
}
|
||||
}
|
||||
|
@ -49,7 +49,7 @@ public class TestTimeSource {
|
|||
long systemTime = System.currentTimeMillis();
|
||||
long ntpTime = timeSource.currentTimeMillis();
|
||||
long offset = ntpTime - systemTime;
|
||||
System.out.println("System: " + systemTime + "\tSystemClockTimeSource: " + ntpTime + "\tOffset: " + offset);
|
||||
// System.out.println("System: " + systemTime + "\tSystemClockTimeSource: " + ntpTime + "\tOffset: " + offset);
|
||||
assertEquals(systemTime, ntpTime, 2); //Should be exact, but we might randomly tick over between one ms and the next
|
||||
Thread.sleep(500);
|
||||
}
|
||||
|
|
|
@ -87,7 +87,7 @@ public class TestListeners extends BaseSparkTest {
|
|||
net.fit(rdd);
|
||||
|
||||
List<String> sessions = ss.listSessionIDs();
|
||||
System.out.println("Sessions: " + sessions);
|
||||
// System.out.println("Sessions: " + sessions);
|
||||
assertEquals(1, sessions.size());
|
||||
|
||||
String sid = sessions.get(0);
|
||||
|
@ -95,15 +95,15 @@ public class TestListeners extends BaseSparkTest {
|
|||
List<String> typeIDs = ss.listTypeIDsForSession(sid);
|
||||
List<String> workers = ss.listWorkerIDsForSession(sid);
|
||||
|
||||
System.out.println(sid + "\t" + typeIDs + "\t" + workers);
|
||||
// System.out.println(sid + "\t" + typeIDs + "\t" + workers);
|
||||
|
||||
List<Persistable> lastUpdates = ss.getLatestUpdateAllWorkers(sid, StatsListener.TYPE_ID);
|
||||
System.out.println(lastUpdates);
|
||||
// System.out.println(lastUpdates);
|
||||
|
||||
System.out.println("Static info:");
|
||||
// System.out.println("Static info:");
|
||||
for (String wid : workers) {
|
||||
Persistable staticInfo = ss.getStaticInfo(sid, StatsListener.TYPE_ID, wid);
|
||||
System.out.println(sid + "\t" + wid);
|
||||
// System.out.println(sid + "\t" + wid);
|
||||
}
|
||||
|
||||
assertEquals(1, typeIDs.size());
@ -63,7 +63,7 @@ public class TestRepartitioning extends BaseSparkTest {
|
|||
assertEquals(10, rdd2.partitions().size());
|
||||
for (int i = 0; i < 10; i++) {
|
||||
List<String> partition = rdd2.collectPartitions(new int[] {i})[0];
|
||||
System.out.println("Partition " + i + " size: " + partition.size());
|
||||
// System.out.println("Partition " + i + " size: " + partition.size());
|
||||
assertEquals(100, partition.size()); //Should be exactly 100, for the util method (but NOT spark .repartition)
|
||||
}
|
||||
}
|
||||
|
@ -170,7 +170,7 @@ public class TestRepartitioning extends BaseSparkTest {
|
|||
|
||||
List<Tuple2<Integer, Integer>> partitionCounts = initial.values().mapPartitionsWithIndex(new CountPartitionsFunction<Integer>(), true).collect();
|
||||
|
||||
System.out.println(partitionCounts);
|
||||
// System.out.println(partitionCounts);
|
||||
|
||||
List<Tuple2<Integer,Integer>> initialExpected = Arrays.asList(
|
||||
new Tuple2<>(0,29),
|
||||
|
@ -185,7 +185,7 @@ public class TestRepartitioning extends BaseSparkTest {
|
|||
|
||||
JavaRDD<Integer> afterRepartition = SparkUtils.repartitionBalanceIfRequired(initial.values(), Repartition.Always, 2, 112);
|
||||
List<Tuple2<Integer, Integer>> partitionCountsAfter = afterRepartition.mapPartitionsWithIndex(new CountPartitionsFunction<Integer>(), true).collect();
|
||||
System.out.println(partitionCountsAfter);
|
||||
// System.out.println(partitionCountsAfter);
|
||||
|
||||
for(Tuple2<Integer,Integer> t2 : partitionCountsAfter){
|
||||
assertEquals(2, (int)t2._2());
|
||||
|
@ -219,8 +219,8 @@ public class TestRepartitioning extends BaseSparkTest {
|
|||
}
|
||||
}
|
||||
|
||||
System.out.println("min: " + min + "\t@\t" + minIdx);
|
||||
System.out.println("max: " + max + "\t@\t" + maxIdx);
|
||||
// System.out.println("min: " + min + "\t@\t" + minIdx);
|
||||
// System.out.println("max: " + max + "\t@\t" + maxIdx);
|
||||
|
||||
assertEquals(1, min);
|
||||
assertEquals(2, max);
|
||||
|
@ -244,7 +244,7 @@ public class TestRepartitioning extends BaseSparkTest {
|
|||
|
||||
for (int i = 0; i < 10; i++) {
|
||||
List<String> partition = rdd2.collectPartitions(new int[] {i})[0];
|
||||
System.out.println("Partition " + i + " size: " + partition.size());
|
||||
// System.out.println("Partition " + i + " size: " + partition.size());
|
||||
assertTrue(partition.size() >= 90 && partition.size() <= 110);
|
||||
}
|
||||
}
@ -5,7 +5,7 @@ project(mkldnn-download NONE)
|
|||
include(ExternalProject)
|
||||
ExternalProject_Add(mkldnn
|
||||
GIT_REPOSITORY https://github.com/intel/mkl-dnn.git
|
||||
GIT_TAG v1.1.3
|
||||
GIT_TAG v1.2
|
||||
SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/mkldnn-src"
|
||||
BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/mkldnn-build"
|
||||
CONFIGURE_COMMAND ""
@ -999,14 +999,14 @@ namespace nd4j {
|
|||
* set new order and shape in case of suitable array length (in-place operation)
|
||||
* order - order to set
|
||||
* shape - shape to set
|
||||
*
|
||||
* copyToNewBuff - if true, the contents of the old buffer are copied into the new buffer whenever a new buffer has to be allocated during reshaping
|
||||
* if a permute was applied before, or the strides are non-standard, a new buffer is allocated for the array
|
||||
*/
|
||||
bool reshapei(const char order, const std::initializer_list<Nd4jLong>& shape);
|
||||
bool reshapei(const char order, const std::vector<Nd4jLong>& shape);
|
||||
bool reshapei(const char order, const std::initializer_list<Nd4jLong>& shape, const bool copyToNewBuff = true);
|
||||
bool reshapei(const char order, const std::vector<Nd4jLong>& shape, const bool copyToNewBuff = true);
|
||||
|
||||
bool reshapei(const std::initializer_list<Nd4jLong>& shape);
|
||||
bool reshapei(const std::vector<Nd4jLong>& shape);
|
||||
bool reshapei(const std::initializer_list<Nd4jLong>& shape, const bool copyToNewBuff = true);
|
||||
bool reshapei(const std::vector<Nd4jLong>& shape, const bool copyToNewBuff = true);
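As context for the new copyToNewBuff parameter documented above, here is a minimal usage sketch. Only the reshapei signatures come from this header; the helper function, the shapes and the surrounding setup are illustrative assumptions.

// Sketch only: assumes an already-constructed NDArray; includes and build setup omitted.
void reshapeInPlaceExample(nd4j::NDArray& arr) {
    // In-place reshape. If the existing buffer cannot be reused (e.g. a permuted view or
    // non-standard strides), a new buffer is allocated and, since copyToNewBuff defaults
    // to true, the old contents are copied into it.
    arr.reshapei('c', std::vector<Nd4jLong>{4, 3});

    // If the caller is going to overwrite the array right away, the copy can be skipped.
    arr.reshapei('c', std::vector<Nd4jLong>{2, 6}, /*copyToNewBuff=*/false);
}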
|
||||
|
||||
/**
|
||||
* creates a new array with the given order and shape; the new array will point to the _buffer of this array
|
||||
|
@ -1015,8 +1015,8 @@ namespace nd4j {
|
|||
*
|
||||
* if a permute has been applied before, or the strides are non-standard, a new buffer is allocated for the new array
|
||||
*/
|
||||
NDArray reshape(const char order, const std::vector<Nd4jLong>& shape) const &;
|
||||
NDArray reshape(const char order, const std::vector<Nd4jLong>& shape) &&;
|
||||
NDArray reshape(const char order, const std::vector<Nd4jLong>& shape, const bool copyToNewBuff = true) const &;
|
||||
NDArray reshape(const char order, const std::vector<Nd4jLong>& shape, const bool copyToNewBuff = true) &&;
|
||||
|
||||
/**
|
||||
* calculate strides and set given order
|
||||
|
@ -1493,7 +1493,7 @@ namespace nd4j {
|
|||
* @return
|
||||
*/
|
||||
bool isS() const;
|
||||
|
||||
|
||||
template <typename T>
|
||||
std::vector<T> asVectorT();
|
||||
@ -42,7 +42,7 @@ ND4J_EXPORT std::u32string NDArray::e(const Nd4jLong i) const;
|
|||
////////////////////////////////////////////////////////////////////////
|
||||
// copy constructor
|
||||
NDArray::NDArray(const NDArray& other) {
|
||||
|
||||
|
||||
_context = other._context;
|
||||
_offset = 0;
|
||||
|
||||
|
@ -308,7 +308,7 @@ NDArray::NDArray(const std::u16string& u16string, nd4j::DataType dtype, nd4j::La
|
|||
if (!unicode::isStringValidU16(u16string.data(), u16string.data() + u16string.size())) {
|
||||
throw std::invalid_argument("NDArray::NDArray: invalid character in input string");
|
||||
}
|
||||
|
||||
|
||||
// only one word is stored, hence the argument 1
|
||||
Nd4jLong headerLength = ShapeUtils::stringBufferHeaderRequirements(1);
|
||||
|
||||
|
@ -435,11 +435,11 @@ NDArray::NDArray(const std::string& str, nd4j::DataType dtype, nd4j::LaunchConte
|
|||
_offset = 0;
|
||||
|
||||
setShapeInfo(ShapeDescriptor::scalarDescriptor(dtype));
|
||||
|
||||
|
||||
memcpy(bufferAsT<int8_t>(), &offsets[0], 2 * sizeof(Nd4jLong));
|
||||
|
||||
|
||||
auto data = reinterpret_cast<int8_t*>(bufferAsT<int8_t>() + headerLength);
|
||||
|
||||
|
||||
if (dtype == DataType::UTF8) {
|
||||
memcpy(data, str.data(), str.size());
|
||||
}
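To make the header/offsets layout used by this constructor concrete, here is a small standalone sketch that builds the same kind of buffer for a single UTF-8 string. The assumption that the header holds two Nd4jLong offsets for a scalar string array is taken from the memcpy of 2 * sizeof(Nd4jLong) above; everything else (values, plain std::vector storage) is illustrative.

#include <cstdint>
#include <cstring>
#include <string>
#include <vector>

int main() {
    std::string str = "hello";                       // payload is invented
    const size_t headerLength = 2 * sizeof(int64_t); // two offsets: start and end

    std::vector<int8_t> buffer(headerLength + str.size());
    int64_t offsets[2] = {0, static_cast<int64_t>(str.size())};

    std::memcpy(buffer.data(), offsets, sizeof(offsets));              // offsets header first
    std::memcpy(buffer.data() + headerLength, str.data(), str.size()); // raw UTF-8 bytes after it
    return 0;
}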
|
||||
|
@ -456,13 +456,13 @@ NDArray::NDArray(const std::string& str, nd4j::DataType dtype, nd4j::LaunchConte
|
|||
/////////////////////////////////////////////////////////////////////////
|
||||
// constructors for vector of strings
|
||||
NDArray::NDArray(const std::vector<Nd4jLong>& shape, const std::vector<const char*>& string, const nd4j::DataType dataType, nd4j::LaunchContext* context) {
|
||||
|
||||
|
||||
if (!DataTypeUtils::isS(dataType))
|
||||
throw std::invalid_argument("NDArray::NDArray: invalid DataType, only string dataTypes have to be used");
|
||||
|
||||
if (shape::prodLong(shape.data(), shape.size()) != string.size())
|
||||
throw std::invalid_argument("NDArray::NDArray: Number of strings should match length of array");
|
||||
|
||||
|
||||
for (const auto& str : string) {
|
||||
if (!unicode::isStringValidU8(str, str + std::char_traits<char>::length(str)) ) {
|
||||
throw std::invalid_argument("NDArray::NDArray: invalid character in input string");
|
||||
|
@ -497,11 +497,11 @@ NDArray::NDArray(const std::vector<Nd4jLong>& shape, const std::vector<const cha
|
|||
setAttached(context->getWorkspace() != nullptr);
|
||||
|
||||
memcpy(bufferAsT<int8_t>(), offsets.data(), offsets.size() * sizeof(Nd4jLong));
|
||||
|
||||
|
||||
auto data = reinterpret_cast<int8_t*>(bufferAsT<int8_t>() + headerLength);
|
||||
|
||||
auto func = PRAGMA_THREADS_FOR{
|
||||
for (auto e = start; e < stop; e += increment) {
|
||||
for (auto e = start; e < stop; e++) {
|
||||
auto cdata = data + offsets[e];
|
||||
if (dataType == DataType::UTF16) {
|
||||
unicode::utf8to16(string[e], cdata, std::char_traits<char>::length(string[e]));
|
||||
|
@ -568,7 +568,7 @@ NDArray::NDArray(const std::vector<Nd4jLong>& shape, const std::vector<std::stri
|
|||
auto data = reinterpret_cast<int8_t*>(bufferAsT<int8_t>() + headerLength);
|
||||
|
||||
auto func = PRAGMA_THREADS_FOR{
|
||||
for (auto e = start; e < stop; e += increment) {
|
||||
for (auto e = start; e < stop; e++) {
|
||||
auto cdata = data + offsets[e];
|
||||
if (dataType == DataType::UTF16) {
|
||||
unicode::utf8to16(string[e].data(), cdata, string[e].size());
|
||||
|
@ -631,11 +631,11 @@ NDArray::NDArray(const std::vector<Nd4jLong>& shape, const std::vector<std::u16s
|
|||
setAttached(context->getWorkspace() != nullptr);
|
||||
|
||||
memcpy(bufferAsT<int8_t>(), offsets.data(), offsets.size() * sizeof(Nd4jLong));
|
||||
|
||||
|
||||
auto data = reinterpret_cast<int8_t*>(bufferAsT<int8_t>() + headerLength);
|
||||
|
||||
|
||||
auto func = PRAGMA_THREADS_FOR{
|
||||
for (auto e = start; e < stop; e += increment) {
|
||||
for (auto e = start; e < stop; e++) {
|
||||
auto cdata = data + offsets[e];
|
||||
if (dtype == DataType::UTF16) {
|
||||
memcpy(cdata, string[e].data(), string[e].size() * sizeof(uint16_t));
|
||||
|
@ -699,9 +699,9 @@ NDArray::NDArray(const std::vector<Nd4jLong>& shape, const std::vector<const cha
|
|||
|
||||
auto data = reinterpret_cast<int8_t*>(bufferAsT<int8_t>() + headerLength);
|
||||
|
||||
|
||||
|
||||
auto func = PRAGMA_THREADS_FOR{
|
||||
for (auto e = start; e < stop; e += increment) {
|
||||
for (auto e = start; e < stop; e++) {
|
||||
auto cdata = data + offsets[e];
|
||||
if (dtype == DataType::UTF16) {
|
||||
memcpy(cdata, string[e], std::char_traits<char16_t>::length(string[e]) * sizeof(uint16_t));
|
||||
|
@ -715,7 +715,7 @@ NDArray::NDArray(const std::vector<Nd4jLong>& shape, const std::vector<const cha
|
|||
}
|
||||
};
|
||||
samediff::Threads::parallel_for(func, 0, lengthOf(), 1);
|
||||
|
||||
|
||||
tickWriteHost();
|
||||
syncToDevice();
|
||||
}
|
||||
|
@ -764,10 +764,10 @@ NDArray::NDArray(const std::vector<Nd4jLong>& shape, const std::vector<std::u32s
|
|||
|
||||
memcpy(bufferAsT<int8_t>(), offsets.data(), offsets.size() * sizeof(Nd4jLong));
|
||||
|
||||
auto data = reinterpret_cast<int8_t*>(bufferAsT<int8_t>() + headerLength);
|
||||
|
||||
auto data = reinterpret_cast<int8_t*>(bufferAsT<int8_t>() + headerLength);
|
||||
|
||||
auto func = PRAGMA_THREADS_FOR{
|
||||
for (auto e = start; e < stop; e += increment) {
|
||||
for (auto e = start; e < stop; e++) {
|
||||
auto cdata = data + offsets[e];
|
||||
if (dtype == DataType::UTF16) {
|
||||
unicode::utf32to16(string[e].data(), cdata, string[e].size());
|
||||
|
@ -781,7 +781,7 @@ NDArray::NDArray(const std::vector<Nd4jLong>& shape, const std::vector<std::u32s
|
|||
}
|
||||
};
|
||||
samediff::Threads::parallel_for(func, 0, lengthOf(), 1);
|
||||
|
||||
|
||||
tickWriteHost();
|
||||
syncToDevice();
|
||||
}
|
||||
|
@ -831,9 +831,9 @@ NDArray::NDArray(const std::vector<Nd4jLong>& shape, const std::vector<const cha
|
|||
memcpy(bufferAsT<int8_t>(), offsets.data(), offsets.size() * sizeof(Nd4jLong));
|
||||
|
||||
auto data = reinterpret_cast<int8_t*>(bufferAsT<int8_t>() + headerLength);
|
||||
|
||||
|
||||
auto func = PRAGMA_THREADS_FOR{
|
||||
for (auto e = start; e < stop; e += increment) {
|
||||
for (auto e = start; e < stop; e++) {
|
||||
auto cdata = data + offsets[e];
|
||||
if (dtype == DataType::UTF16) {
|
||||
unicode::utf32to16(string[e], cdata, std::char_traits<char32_t>::length(string[e]));
|
||||
|
@ -847,7 +847,7 @@ NDArray::NDArray(const std::vector<Nd4jLong>& shape, const std::vector<const cha
|
|||
}
|
||||
};
|
||||
samediff::Threads::parallel_for(func, 0, lengthOf(), 1);
|
||||
|
||||
|
||||
tickWriteHost();
|
||||
syncToDevice();
|
||||
}
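The loops above all follow the same pattern: PRAGMA_THREADS_FOR defines a chunked lambda that receives a start/stop range (the increment argument is what this change drops in favour of plain e++), and samediff::Threads::parallel_for dispatches it over [0, lengthOf()). The following standalone analogue uses plain std::thread instead of the samediff machinery, so the macro expansion and scheduler behaviour are assumptions, not a reproduction.

#include <algorithm>
#include <thread>
#include <vector>

void parallelFill(std::vector<float>& data, float value, int numThreads) {
    if (numThreads < 1)
        numThreads = 1;

    std::vector<std::thread> pool;
    const long long len = static_cast<long long>(data.size());
    const long long chunk = (len + numThreads - 1) / numThreads;

    for (int t = 0; t < numThreads; t++) {
        const long long start = t * chunk;
        const long long stop = std::min(len, start + chunk);
        // each worker gets its own [start, stop) range, mirroring the (start, stop)
        // arguments the PRAGMA_THREADS_FOR lambdas receive
        pool.emplace_back([&data, value, start, stop] {
            for (long long e = start; e < stop; e++) // plain e++, as in the updated loops
                data[e] = value;
        });
    }
    for (auto& th : pool)
        th.join();
}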
|
||||
|
@ -887,8 +887,8 @@ bool NDArray::isC() const {
|
|||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
bool NDArray::isS() const {
|
||||
return (dataType() == DataType::UTF8 ||
|
||||
dataType() == DataType::UTF16 ||
|
||||
return (dataType() == DataType::UTF8 ||
|
||||
dataType() == DataType::UTF16 ||
|
||||
dataType() == DataType::UTF32);
|
||||
}
|
||||
|
||||
|
@ -1197,8 +1197,8 @@ void NDArray::assign(const NDArray& other, bool allowParallelism) {
|
|||
throw std::runtime_error("NDArray::assign: lengths of arrays are mismatched");
|
||||
}
|
||||
|
||||
// memcpy is allowed only for same order && same ews (being equal to 1)
|
||||
if (ordering() == other.ordering() && dataType() == other.dataType() && ews() == 1 && other.ews() == 1)
|
||||
// memcpy is allowed only when both arrays are 'c' ordered, share a data type and have ews == 1
|
||||
if (ordering() == other.ordering() && ordering() == 'c' && dataType() == other.dataType() && ews() == 1 && other.ews() == 1)
|
||||
copyBuffersContinuouslyFrom(other, other.lengthOf() * other.sizeOfT());
|
||||
else {
|
||||
NDArray::prepareSpecialUse({this}, {&other});
|
||||
|
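The tightened condition above is what gates the raw-buffer fast path: both arrays must be 'c' ordered, share a data type, and have ews == 1 on both sides before copyBuffersContinuouslyFrom is used instead of the generic transform. A hedged restatement of that gate as a standalone predicate (the helper itself is invented; ordering(), dataType() and ews() are the accessors used in the diff):

// Hypothetical helper, not part of the codebase - it only mirrors the fast-path check.
static bool canCopyBuffersDirectly(const nd4j::NDArray& dst, const nd4j::NDArray& src) {
    return dst.ordering() == src.ordering()
        && dst.ordering() == 'c'            // both arrays in c order
        && dst.dataType() == src.dataType()
        && dst.ews() == 1                   // contiguous, element-wise stride of 1
        && src.ews() == 1;
}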
@ -1569,20 +1569,25 @@ Nd4jLong NDArray::tensorsAlongDimension(const std::vector<int>& dimensions) cons
|
|||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
void NDArray::printShapeInfo(const char * msg) const {
|
||||
//shape::printShapeInfo(_shapeInfo);
|
||||
if (msg == nullptr)
|
||||
shape::printShapeInfoLinear(_shapeInfo);
|
||||
else {
|
||||
int rank = shape::rank(_shapeInfo);
|
||||
int lim = shape::shapeInfoLength(rank);
|
||||
printf("%s: [", msg);
|
||||
for (int i = 0; i < shape::shapeInfoLength(rank); i++) {
|
||||
printf("%lld", (long long) _shapeInfo[i]);
|
||||
if (i < lim - 1)
|
||||
printf(", ");
|
||||
}
|
||||
printf("]\n");
|
||||
|
||||
int rank = shape::rank(_shapeInfo);
|
||||
int lim = shape::shapeInfoLength(rank);
|
||||
|
||||
if(msg != nullptr)
|
||||
printf("shapeInfo %s: [", msg);
|
||||
else
|
||||
printf("shapeInfo: [");
|
||||
|
||||
printf("%i, ", rank);
|
||||
for (int i = 1; i < shape::shapeInfoLength(rank) - 3; i++){
|
||||
if(i == rank + 1)
|
||||
printf(" ");
|
||||
printf("%lld,", _shapeInfo[i]);
|
||||
}
|
||||
printf(" %lld,", shape::type(_shapeInfo));
|
||||
printf("%lld,", shape::elementWiseStride(_shapeInfo));
|
||||
printf("%lld]\n", (Nd4jLong)shape::order(_shapeInfo));
|
||||
|
||||
fflush(stdout);
|
||||
}
|
||||
|
||||
|
@ -1624,7 +1629,7 @@ void NDArray::printBuffer(const char* msg, Nd4jLong limit, const bool sync) cons
|
|||
if (e < limit - 1)
|
||||
printf(", ");
|
||||
}
|
||||
}
|
||||
}
|
||||
else if (this->isS()) {
|
||||
// todo do we need this print offsets
|
||||
/*
|
||||
|
@ -1773,7 +1778,7 @@ void NDArray::printIndexedBuffer(const char* msg, Nd4jLong limit) const {
|
|||
printf("%s\n", this->e<bool>(0)?"true":"false");
|
||||
}
|
||||
else if (this->isS()) {
|
||||
// todo do we need this
|
||||
// todo do we need this
|
||||
// printf("\"%lld\"\n", this->getOffset(e));
|
||||
printf("\"%s\"\n", this->e<std::string>(0).c_str());
|
||||
}
|
||||
|
@ -1855,19 +1860,19 @@ void NDArray::updateStrides(const char order) {
|
|||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
// set new order and shape in case of suitable array length
|
||||
bool NDArray::reshapei(const char order, const std::initializer_list<Nd4jLong>& shape) {
|
||||
bool NDArray::reshapei(const char order, const std::initializer_list<Nd4jLong>& shape, const bool copyToNewBuff) {
|
||||
std::vector<Nd4jLong> vShape(shape);
|
||||
return reshapei(order, vShape);
|
||||
return reshapei(order, vShape, copyToNewBuff);
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
bool NDArray::reshapei(const std::initializer_list<Nd4jLong>& shape) {
|
||||
return reshapei('c', shape);
|
||||
bool NDArray::reshapei(const std::initializer_list<Nd4jLong>& shape, const bool copyToNewBuff) {
|
||||
return reshapei(ordering(), shape, copyToNewBuff);
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
bool NDArray::reshapei(const std::vector<Nd4jLong>& shape) {
|
||||
return reshapei('c', shape);
|
||||
bool NDArray::reshapei(const std::vector<Nd4jLong>& shape, const bool copyToNewBuff) {
|
||||
return reshapei(ordering(), shape, copyToNewBuff);
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
|
@ -1918,18 +1923,18 @@ Nd4jLong NDArray::argMax(std::initializer_list<int> dimensions) {
|
|||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
// create new array with corresponding order and shape, new array will point to the same _buffer as this array
|
||||
NDArray NDArray::reshape(const char order, const std::vector<Nd4jLong>& shape) const & {
|
||||
NDArray NDArray::reshape(const char order, const std::vector<Nd4jLong>& shape, const bool copyToNewBuff) const & {
|
||||
|
||||
NDArray newArr(getDataBuffer(), ShapeDescriptor(getShapeInfo()), getContext(), getBufferOffset());
|
||||
newArr.reshapei(order, shape);
|
||||
newArr.reshapei(order, shape, copyToNewBuff);
|
||||
|
||||
return newArr;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
NDArray NDArray::reshape(const char order, const std::vector<Nd4jLong>& shape) && {
|
||||
NDArray NDArray::reshape(const char order, const std::vector<Nd4jLong>& shape, const bool copyToNewBuff) && {
|
||||
|
||||
this->reshapei(order, shape);
|
||||
this->reshapei(order, shape, copyToNewBuff);
|
||||
return std::move(*this);
|
||||
}
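The two overloads above split reshape by value category: the const-lvalue version builds a new NDArray over the same buffer and reshapes that wrapper, while the rvalue version reshapes *this directly and moves it out. A small sketch of how the distinction plays out at a call site (shapes and the dup() call are illustrative):

nd4j::NDArray reshapeDemo(const nd4j::NDArray& src) {
    // lvalue overload: 'view' is a new wrapper over src's buffer (or a fresh copy
    // when the existing strides make an in-place reshape impossible)
    auto view = src.reshape('c', {6, 2});

    // rvalue overload: the temporary returned by dup() is reshaped in place and moved out
    return src.dup('c').reshape('c', {3, 4});
}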
|
||||
|
||||
|
@ -1971,7 +1976,7 @@ bool NDArray::permutei(const std::initializer_list<int>& dimensions) {
|
|||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
bool NDArray::permutei(const std::vector<int>& dimensions) {
|
||||
return permutei(dimensions.data(), dimensions.size());
|
||||
return permutei(dimensions.data(), rankOf());
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
|
@ -1993,7 +1998,7 @@ bool NDArray::permutei(const std::vector<Nd4jLong>& dimensions) {
|
|||
for (int e = 0; e < dimensions.size(); e++)
|
||||
ivec[e] = dimensions[e];
|
||||
|
||||
return permutei(ivec.data(), ivec.size());
|
||||
return permutei(ivec.data(), rankOf());
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
|
@ -2029,9 +2034,8 @@ NDArray NDArray::permute(const Nd4jLong* dimensions, const int rank) && {
|
|||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
NDArray NDArray::permute(const std::vector<int>& dimensions) const &{
|
||||
auto data = dimensions.data();
|
||||
auto size = dimensions.size();
|
||||
return permute(data, size);
|
||||
|
||||
return permute(dimensions.data(), rankOf());
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
|
@ -2043,7 +2047,8 @@ NDArray NDArray::permute(const std::vector<int>& dimensions) && {
|
|||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
NDArray NDArray::permute(const std::vector<Nd4jLong>& dimensions) const & {
|
||||
return permute(dimensions.data(), dimensions.size());
|
||||
|
||||
return permute(dimensions.data(), rankOf());
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
|
@ -2106,12 +2111,12 @@ void NDArray::permute(const Nd4jLong *dimensions, const int rank, NDArray& targe
|
|||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
void NDArray::permute(const std::vector<int>& dimensions, NDArray& target) const {
|
||||
permute(dimensions.data(), dimensions.size(), target);
|
||||
permute(dimensions.data(), rankOf(), target);
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
void NDArray::permute(const std::vector<Nd4jLong>& dimensions, NDArray& target) const {
|
||||
permute(dimensions.data(), dimensions.size(), target);
|
||||
permute(dimensions.data(), rankOf(), target);
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
|
@ -2280,7 +2285,7 @@ template <typename T>
|
|||
NDArray NDArray::asT() const{
|
||||
|
||||
auto result = isScalar() ? NDArray('c', {}, std::vector<double>{0.}, DataTypeUtils::fromT<T>(), this->getContext()) : NDArray(ordering(), getShapeAsVector(), DataTypeUtils::fromT<T>(), this->getContext());
|
||||
|
||||
|
||||
NDArray::prepareSpecialUse({&result}, {this});
|
||||
NativeOpExecutioner::execTransformAny(getContext(), transform::AnyOps::Assign, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), result.getBuffer(), result.getShapeInfo(), result.getSpecialBuffer(), result.getSpecialShapeInfo(), nullptr, nullptr, nullptr);
|
||||
NDArray::registerSpecialUse({&result}, {this});
|
||||
|
@ -2298,15 +2303,15 @@ NDArray NDArray::asS() const {
|
|||
|
||||
auto dtype = DataTypeUtils::fromT<T>();
|
||||
|
||||
if (!(DataTypeUtils::isS(dtype)))
|
||||
if (!(DataTypeUtils::isS(dtype)))
|
||||
throw std::invalid_argument("NDArray::asS: invalid DataType used");
|
||||
|
||||
|
||||
if (dtype == dataType()) {
|
||||
|
||||
|
||||
Nd4jLong offsetsLength = ShapeUtils::stringBufferHeaderRequirements(lengthOf());
|
||||
const auto nInputoffsets = bufferAsT<Nd4jLong>();
|
||||
std::shared_ptr<DataBuffer> pBuffer = std::make_shared<DataBuffer>(offsetsLength + nInputoffsets[lengthOf()], dtype, getContext()->getWorkspace(), true);
|
||||
|
||||
|
||||
NDArray res(pBuffer, ShapeDescriptor(dtype, ordering(), getShapeAsVector()), getContext());
|
||||
res.setAttached(getContext()->getWorkspace() != nullptr);
|
||||
|
||||
|
@ -2319,7 +2324,7 @@ NDArray NDArray::asS() const {
|
|||
registerPrimaryUse({ &res }, { this });
|
||||
return res;
|
||||
}
|
||||
|
||||
|
||||
Nd4jLong offsetsLength = ShapeUtils::stringBufferHeaderRequirements(lengthOf());
|
||||
|
||||
std::vector<Nd4jLong> offsets(lengthOf() + 1);
|
||||
|
@ -2353,7 +2358,7 @@ NDArray NDArray::asS() const {
|
|||
|
||||
NDArray res(pBuffer, ShapeDescriptor(dtype, ordering(), getShapeAsVector()), getContext());
|
||||
res.setAttached(getContext()->getWorkspace() != nullptr);
|
||||
|
||||
|
||||
preparePrimaryUse({ &res }, { this });
|
||||
|
||||
memcpy(res.bufferAsT<int8_t>(), offsets.data(), offsets.size() * sizeof(Nd4jLong));
|
||||
|
@ -2362,7 +2367,7 @@ NDArray NDArray::asS() const {
|
|||
const auto inData = bufferAsT<int8_t>() + offsetsLength;
|
||||
|
||||
auto func = PRAGMA_THREADS_FOR{
|
||||
for (int e = start; e < stop; e += increment) {
|
||||
for (int e = start; e < stop; e++) {
|
||||
auto cdata = outData + offsets[e];
|
||||
auto end = nInputoffsets[e + 1];
|
||||
auto idata = inData + nInputoffsets[e];
|
||||
|
@ -2403,7 +2408,7 @@ BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT NDArray NDArray::asS, () const, LIBND
|
|||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
NDArray NDArray::asT(DataType dtype) const {
|
||||
|
||||
|
||||
if (isS() && !DataTypeUtils::isS(dtype))
|
||||
throw std::runtime_error("NDArray::asT: you can't use this method on String array with not string DataType!");
|
||||
|
||||
|
@ -3221,7 +3226,7 @@ BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT std::vector, NDArray::asVectorT(), LI
|
|||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
// set new order and shape in case of suitable array length
|
||||
bool NDArray::reshapei(const char order, const std::vector<Nd4jLong>& cshape) {
|
||||
bool NDArray::reshapei(const char order, const std::vector<Nd4jLong>& cshape, const bool copyToNewBuff) {
|
||||
|
||||
// check firstly whether cshape is identical to shape of array, if yes then reshape is unnecessary
|
||||
if(order == ordering() && shape::shapeEquals(rankOf(), shapeOf(), cshape.size(), cshape.data()))
|
||||
|
@ -3293,19 +3298,15 @@ bool NDArray::reshapei(const char order, const std::vector<Nd4jLong>& cshape) {
|
|||
Nd4jLong *shapeInfoNew;
|
||||
ALLOCATE(shapeInfoNew, getContext()->getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong);
|
||||
|
||||
bool canReshape = shape::reshapeC(rankOf(), shapeInfo(), shape.size(), shape.data(), shapeInfoNew);
|
||||
bool canReshape = shape::reshapeC(shapeInfo(), order, shape.size(), shape.data(), shapeInfoNew);
|
||||
|
||||
// we can do this only if there was no permute applied, or there are no weird strides
|
||||
if (canReshape) {
|
||||
if(ordering() == 'c' && order == 'f')
|
||||
throw std::invalid_argument("NDArray::reshapei(order, shape): in case of reshapeC it doesn't make sense to reshape from c order to f order !");
|
||||
|
||||
shape::setEws(shapeInfoNew, arrLength);
|
||||
setShapeInfo(shapeInfoNew);
|
||||
}
|
||||
else {
|
||||
NDArray temp(order, shape, dataType(), getContext());
|
||||
this->applyTransform(transform::Assign, temp, nullptr);
|
||||
if(copyToNewBuff)
|
||||
this->applyTransform(transform::Assign, temp, nullptr);
|
||||
*this = std::move(temp);
|
||||
}
|
||||
|
||||
|
@ -3463,9 +3464,9 @@ NDArray NDArray::dup(const char newOrder) const {
|
|||
if (isS()) {
|
||||
if (dataType() == DataType::UTF8) {
|
||||
std::vector<std::string> strings(lengthOf());
|
||||
|
||||
|
||||
auto func = PRAGMA_THREADS_FOR{
|
||||
for (auto i = start; i < stop; i += increment) {
|
||||
for (auto i = start; i < stop; i++) {
|
||||
strings[i] = std::move(this->e<std::string>(i));
|
||||
}
|
||||
};
|
||||
|
@ -3478,7 +3479,7 @@ NDArray NDArray::dup(const char newOrder) const {
|
|||
std::vector<std::u16string> strings(lengthOf());
|
||||
|
||||
auto func = PRAGMA_THREADS_FOR{
|
||||
for (auto i = start; i < stop; i += increment) {
|
||||
for (auto i = start; i < stop; i++) {
|
||||
strings[i] = std::move(this->e<std::u16string>(i));
|
||||
}
|
||||
};
|
||||
|
@ -3490,7 +3491,7 @@ NDArray NDArray::dup(const char newOrder) const {
|
|||
|
||||
std::vector<std::u32string> strings(lengthOf());
|
||||
auto func = PRAGMA_THREADS_FOR{
|
||||
for (auto i = start; i < stop; i += increment) {
|
||||
for (auto i = start; i < stop; i++) {
|
||||
strings[i] = std::move(this->e<std::u32string>(i));
|
||||
}
|
||||
};
|
||||
|
@ -3521,7 +3522,7 @@ bool NDArray::equalsTo(const NDArray *other, double eps) const {
|
|||
|
||||
if (isS()) {
|
||||
// string is special case, we'll compare them one by one, considering both arrays are guaranteed to have the same length
|
||||
|
||||
|
||||
if (dataType() == DataType::UTF8) {
|
||||
for (int e = 0; e < this->lengthOf(); e++) {
|
||||
auto s1 = this->e<std::string>(e);
|
||||
|
@ -3585,7 +3586,7 @@ std::string NDArray::e(const Nd4jLong i) const {
|
|||
if (i == lengthOf())
|
||||
throw std::runtime_error("Can't get std::string for index out of range");
|
||||
|
||||
|
||||
|
||||
if (this->dataType() == DataType::UTF16) {
|
||||
auto u16 = this->e<std::u16string>(i);
|
||||
std::string s;
|
||||
|
@ -4846,7 +4847,7 @@ NDArray NDArray::operator()(const std::vector<Nd4jLong>& idx, const bool keepUni
|
|||
auto shapeOf = shape::shapeOf(newShapeInfo);
|
||||
auto stridesOf = shape::stride(newShapeInfo);
|
||||
|
||||
Nd4jLong offset(0), subArrLen(1);
|
||||
Nd4jLong offset = 0;
|
||||
int n(isStrided ? 3 : 2), first, last, stride;
|
||||
|
||||
for (int d = rank - 1; d >= 0; --d) {
|
||||
|
@ -4863,29 +4864,31 @@ NDArray NDArray::operator()(const std::vector<Nd4jLong>& idx, const bool keepUni
|
|||
if(shapeOf[d] != 1)
|
||||
stridesOf[d] *= stride;
|
||||
}
|
||||
}
|
||||
|
||||
subArrLen *= shapeOf[d];
|
||||
Nd4jLong *newShapeInfo2 = newShapeInfo;
|
||||
|
||||
if(!keepUnitiesInShape) {
|
||||
|
||||
std::vector<int> dimsWithUnities;
|
||||
|
||||
for (uint d = 0; d < rank; ++d)
|
||||
if(idx[n*d] != idx[n*d+1] && shapeOf[d] == 1)
|
||||
dimsWithUnities.push_back(d);
|
||||
|
||||
if(!dimsWithUnities.empty())
|
||||
newShapeInfo2 = ShapeBuilders::copyShapeInfoWithoutUnites(newShapeInfo, dimsWithUnities.size(), dimsWithUnities.data(), getContext()->getWorkspace());
|
||||
}
|
||||
|
||||
// check if there is possibility to set ews = 1
|
||||
shape::setEws(newShapeInfo, subArrLen);
|
||||
shape::checkStridesEwsAndOrder(newShapeInfo2);
|
||||
|
||||
NDArray result(_buffer, ShapeDescriptor(newShapeInfo), getContext(), offset + getBufferOffset());
|
||||
NDArray result(_buffer, ShapeDescriptor(newShapeInfo2), getContext(), offset + getBufferOffset());
|
||||
result._isView = true;
|
||||
|
||||
if(!keepUnitiesInShape) {
|
||||
const int coeff = isStrided ? 3 : 2;
|
||||
std::vector<Nd4jLong> nonUnitDims;
|
||||
|
||||
for (int d = 0; d < rank; ++d)
|
||||
if(!(idx[coeff*d] != idx[coeff*d+1] && newShapeInfo[d+1] == 1))
|
||||
nonUnitDims.push_back(newShapeInfo[d+1]);
|
||||
|
||||
if(nonUnitDims.size() != rank)
|
||||
result.reshapei(nonUnitDims);
|
||||
}
|
||||
|
||||
RELEASE(newShapeInfo, getContext()->getWorkspace());
|
||||
if(newShapeInfo != newShapeInfo2)
|
||||
RELEASE(newShapeInfo2, getContext()->getWorkspace());
|
||||
|
||||
return result;
|
||||
}
@ -179,7 +179,7 @@ namespace graph {
|
|||
nd4j_debug("Embedded graph execution finished. %i variable(s) migrated\n", cnt);
|
||||
|
||||
} else if (node->hasCustomOp()) {
|
||||
// if we have something to execute - lets just execute it.
|
||||
// now, if we have something to execute - lets just execute it.
|
||||
auto status = node->getCustomOp()->execute(&context);
|
||||
if (status != ND4J_STATUS_OK)
|
||||
return status;
|
||||
|
@ -494,8 +494,10 @@ Nd4jStatus GraphExecutioner::execute(Graph *graph, VariableSpace* variableSpace)
|
|||
nd4j::memory::MemoryRegistrator::getInstance()->setGraphMemoryFootprintIfGreater(h, m);
|
||||
}
|
||||
|
||||
if (tempFlow)
|
||||
if (tempFlow) {
|
||||
delete flowPath;
|
||||
__variableSpace->setFlowPath(nullptr);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
@ -98,7 +98,7 @@ void NDArray::fillAsTriangular(const float val, int lower, int upper, NDArray& t
|
|||
|
||||
auto func = PRAGMA_THREADS_FOR {
|
||||
Nd4jLong coords[MAX_RANK];
|
||||
for (auto i = start; i < stop; i += increment) {
|
||||
for (auto i = start; i < stop; i++) {
|
||||
shape::index2coords(i, target.getShapeInfo(), coords);
|
||||
const auto zOffset = shape::getOffset(target.getShapeInfo(), coords);
|
||||
|
||||
|
@ -152,7 +152,7 @@ static void templatedSwap(void *xBuffer, void *yBuffer, Nd4jLong length) {
|
|||
auto y = reinterpret_cast<T *>(yBuffer);
|
||||
|
||||
auto func = PRAGMA_THREADS_FOR {
|
||||
for (auto i = start; i < stop; i += increment) {
|
||||
for (auto i = start; i < stop; i++) {
|
||||
auto temp = x[i];
|
||||
x[i] = y[i];
|
||||
y[i] = temp;
|
||||
|
@ -266,7 +266,7 @@ NDArray NDArray::tile(const std::vector<Nd4jLong>& reps) const {
|
|||
if(result.ordering() == 'c') { // ews == 1 always here
|
||||
|
||||
auto func = PRAGMA_THREADS_FOR {
|
||||
for (auto i = start; i < stop; i += increment) {
|
||||
for (auto i = start; i < stop; i++) {
|
||||
auto yOffset = shape::subArrayOffset(i, newShapeInfo, getShapeInfo());
|
||||
BUILD_SINGLE_SELECTOR(xType, this->template templatedAssign,(result.getBuffer(), i, this->getBuffer(), yOffset), LIBND4J_TYPES);
|
||||
}
|
||||
|
@ -277,7 +277,7 @@ NDArray NDArray::tile(const std::vector<Nd4jLong>& reps) const {
|
|||
else {
|
||||
|
||||
auto func = PRAGMA_THREADS_FOR {
|
||||
for (auto i = start; i < stop; i += increment) {
|
||||
for (auto i = start; i < stop; i++) {
|
||||
auto xOffset = result.getOffset(i);
|
||||
auto yOffset = shape::subArrayOffset(i, newShapeInfo, getShapeInfo());
|
||||
BUILD_SINGLE_SELECTOR(xType, this->template templatedAssign,(result.getBuffer(), xOffset, this->getBuffer(), yOffset), LIBND4J_TYPES);
|
||||
|
@ -377,7 +377,7 @@ static void repeat_(const NDArray& input, NDArray& output, const std::vector<int
|
|||
// loop through input array
|
||||
auto func = PRAGMA_THREADS_FOR {
|
||||
Nd4jLong coords[MAX_RANK];
|
||||
for (auto i = start; i < stop; i += increment) {
|
||||
for (auto i = start; i < stop; i++) {
|
||||
shape::index2coords(i, output.getShapeInfo(), coords);
|
||||
|
||||
const auto zOffset = shape::getOffset(output.getShapeInfo(), coords);
|
||||
|
|
|
@ -22,7 +22,7 @@ void NDArray::applyTriplewiseLambda(NDArray& second, NDArray& third, const std::
|
|||
if (this->ordering() == second.ordering() && this->ordering() == third.ordering() && this->ordering() == target.ordering() && (this->ews() == 1 && target.ews() == 1) && this->ews() == second.ews() && this->ews() == third.ews()) {
|
||||
|
||||
auto loop = PRAGMA_THREADS_FOR {
|
||||
for (auto e = start; e < stop; e += increment)
|
||||
for (auto e = start; e < stop; e++)
|
||||
z[e] = func(f[e], s[e], t[e]);
|
||||
};
|
||||
|
||||
|
@ -31,7 +31,7 @@ void NDArray::applyTriplewiseLambda(NDArray& second, NDArray& third, const std::
|
|||
if (f == z) {
|
||||
|
||||
auto loop = PRAGMA_THREADS_FOR {
|
||||
for (auto e = start; e < stop; e += increment) {
|
||||
for (auto e = start; e < stop; e++) {
|
||||
auto tOffset = this->getOffset(e);
|
||||
auto uOffset = second.getOffset(e);
|
||||
auto vOffset = third.getOffset(e);
|
||||
|
@ -44,7 +44,7 @@ void NDArray::applyTriplewiseLambda(NDArray& second, NDArray& third, const std::
|
|||
} else {
|
||||
|
||||
auto loop = PRAGMA_THREADS_FOR {
|
||||
for (auto e = start; e < stop; e += increment) {
|
||||
for (auto e = start; e < stop; e++) {
|
||||
auto tOffset = this->getOffset(e);
|
||||
auto uOffset = second.getOffset(e);
|
||||
auto vOffset = third.getOffset(e);
|
||||
|
@ -93,7 +93,7 @@ void NDArray::applyPairwiseLambda(const NDArray& other, const std::function<T(T,
|
|||
if (this->ordering() == other.ordering() && this->ordering() == target.ordering() && (this->ews() == 1 && target.ews() == 1) && this->ews() == other.ews()) {
|
||||
|
||||
auto loop = PRAGMA_THREADS_FOR {
|
||||
for (auto e = start; e < stop; e += increment)
|
||||
for (auto e = start; e < stop; e++)
|
||||
z[e] = func(f[e], s[e]);
|
||||
};
|
||||
|
||||
|
@ -102,7 +102,7 @@ void NDArray::applyPairwiseLambda(const NDArray& other, const std::function<T(T,
|
|||
if (f == z) {
|
||||
|
||||
auto loop = PRAGMA_THREADS_FOR {
|
||||
for (auto e = start; e < stop; e += increment) {
|
||||
for (auto e = start; e < stop; e++) {
|
||||
auto xOffset = this->getOffset(e);
|
||||
auto yOffset = other.getOffset(e);
|
||||
|
||||
|
@ -114,7 +114,7 @@ void NDArray::applyPairwiseLambda(const NDArray& other, const std::function<T(T,
|
|||
} else {
|
||||
|
||||
auto loop = PRAGMA_THREADS_FOR {
|
||||
for (auto e = start; e < stop; e += increment) {
|
||||
for (auto e = start; e < stop; e++) {
|
||||
auto xOffset = this->getOffset(e);
|
||||
auto yOffset = other.getOffset(e);
|
||||
auto zOffset = target.getOffset(e);
|
||||
|
@ -156,7 +156,7 @@ void NDArray::applyLambda(const std::function<T(T)>& func, NDArray& target) {
|
|||
if (this->ordering() == target.ordering() && (this->ews() == 1 && target.ews() == 1)) {
|
||||
|
||||
auto loop = PRAGMA_THREADS_FOR {
|
||||
for (auto e = start; e < stop; e += increment)
|
||||
for (auto e = start; e < stop; e++)
|
||||
z[e] = func(f[e]);
|
||||
};
|
||||
|
||||
|
@ -165,7 +165,7 @@ void NDArray::applyLambda(const std::function<T(T)>& func, NDArray& target) {
|
|||
if (f == z) {
|
||||
|
||||
auto loop = PRAGMA_THREADS_FOR {
|
||||
for (auto e = start; e < stop; e += increment) {
|
||||
for (auto e = start; e < stop; e++) {
|
||||
auto xOffset = this->getOffset(e);
|
||||
|
||||
f[xOffset] = func(f[xOffset]);
|
||||
|
@ -176,7 +176,7 @@ void NDArray::applyLambda(const std::function<T(T)>& func, NDArray& target) {
|
|||
} else {
|
||||
|
||||
auto loop = PRAGMA_THREADS_FOR {
|
||||
for (auto e = start; e < stop; e += increment) {
|
||||
for (auto e = start; e < stop; e++) {
|
||||
auto xOffset = this->getOffset(e);
|
||||
auto zOffset = target.getOffset(e);
|
||||
|
||||
|
@ -217,7 +217,7 @@ void NDArray::applyIndexedLambda(const std::function<T(Nd4jLong, T)>& func, NDAr
|
|||
if (this->ordering() == target.ordering() && (this->ews() == 1 && target.ews() == 1)) {
|
||||
|
||||
auto loop = PRAGMA_THREADS_FOR {
|
||||
for (auto e = start; e < stop; e += increment)
|
||||
for (auto e = start; e < stop; e++)
|
||||
z[e] = func(e, f[e]);
|
||||
};
|
||||
|
||||
|
@ -226,7 +226,7 @@ void NDArray::applyIndexedLambda(const std::function<T(Nd4jLong, T)>& func, NDAr
|
|||
if (f == z) {
|
||||
|
||||
auto loop = PRAGMA_THREADS_FOR {
|
||||
for (auto e = start; e < stop; e += increment) {
|
||||
for (auto e = start; e < stop; e++) {
|
||||
auto xOffset = this->getOffset(e);
|
||||
|
||||
f[xOffset] = func(e, f[xOffset]);
|
||||
|
@ -237,7 +237,7 @@ void NDArray::applyIndexedLambda(const std::function<T(Nd4jLong, T)>& func, NDAr
|
|||
} else {
|
||||
|
||||
auto loop = PRAGMA_THREADS_FOR {
|
||||
for (auto e = start; e < stop; e += increment) {
|
||||
for (auto e = start; e < stop; e++) {
|
||||
auto xOffset = this->getOffset(e);
|
||||
auto zOffset = target.getOffset(e);
|
||||
|
||||
|
@ -283,7 +283,7 @@ void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function<T(N
|
|||
if (this->ordering() == other.ordering() && this->ordering() == target.ordering() && (this->ews() == 1 && target.ews() == 1) && this->ews() == other.ews()) {
|
||||
|
||||
auto loop = PRAGMA_THREADS_FOR {
|
||||
for (auto e = start; e < stop; e += increment)
|
||||
for (auto e = start; e < stop; e++)
|
||||
z[e] = func((Nd4jLong) e, f[e], s[e]);
|
||||
};
|
||||
|
||||
|
@ -292,7 +292,7 @@ void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function<T(N
|
|||
if (f == z) {
|
||||
|
||||
auto loop = PRAGMA_THREADS_FOR {
|
||||
for (auto e = start; e < stop; e += increment) {
|
||||
for (auto e = start; e < stop; e++) {
|
||||
auto xOffset = this->getOffset(e);
|
||||
auto yOffset = other.getOffset(e);
|
||||
|
||||
|
@ -304,7 +304,7 @@ void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function<T(N
|
|||
} else {
|
||||
|
||||
auto loop = PRAGMA_THREADS_FOR {
|
||||
for (auto e = start; e < stop; e += increment) {
|
||||
for (auto e = start; e < stop; e++) {
|
||||
auto xOffset = this->getOffset(e);
|
||||
auto yOffset = other.getOffset(e);
|
||||
auto zOffset = target.getOffset(e);
|
||||
|
|
|
@ -163,15 +163,44 @@ void NativeOpExecutioner::execBroadcast(nd4j::LaunchContext *lc,
|
|||
BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functions::broadcast::Broadcast, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, LIBND4J_TYPES);
|
||||
#else
|
||||
|
||||
auto loopKind = nd4j::LoopKind::deduceKindOfLoopBroadcast(hXShapeInfo, hYShapeInfo, hZShapeInfo);
|
||||
|
||||
auto func = PRAGMA_THREADS_FOR {
|
||||
BUILD_SINGLE_SELECTOR_THRICE(xType, functions::broadcast::Broadcast, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ, start, stop), LIBND4J_TYPES);
|
||||
BUILD_SINGLE_SELECTOR_THRICE(xType, functions::broadcast::Broadcast, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ, loopKind, start, stop), LIBND4J_TYPES);
|
||||
};
|
||||
|
||||
auto xLen = shape::length(hXShapeInfo);
|
||||
auto yLen = shape::length(hYShapeInfo);
|
||||
auto numTads = xLen / yLen;
|
||||
Nd4jLong numTads = 0;
|
||||
|
||||
switch (loopKind) {
|
||||
case nd4j::LoopKind::BROADCAST_SCALAR_X: {
|
||||
numTads = shape::length(hXShapeInfo);
|
||||
}
|
||||
break;
|
||||
case nd4j::LoopKind::BROADCAST_SCALAR_Y: {
|
||||
numTads = shape::length(hYShapeInfo);
|
||||
}
|
||||
break;
|
||||
case nd4j::LoopKind::BROADCAST_3D: {
|
||||
numTads = shape::sizeAt(hZShapeInfo, 0);
|
||||
}
|
||||
break;
|
||||
case nd4j::LoopKind::BROADCAST_4D: {
|
||||
numTads = shape::sizeAt(hZShapeInfo, 0) * shape::sizeAt(hZShapeInfo, 1);
|
||||
}
|
||||
break;
|
||||
case nd4j::LoopKind::BROADCAST_5D: {
|
||||
numTads = shape::sizeAt(hZShapeInfo, 0) * shape::sizeAt(hZShapeInfo, 1);
|
||||
}
|
||||
break;
|
||||
default: {
|
||||
auto xLen = shape::length(hXShapeInfo);
|
||||
auto yLen = shape::length(hYShapeInfo);
|
||||
numTads = xLen / yLen;
|
||||
}
|
||||
}
|
||||
|
||||
samediff::Threads::parallel_tad(func, 0, numTads);
|
||||
|
||||
#endif
|
||||
}
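To make the dispatch above concrete, a small worked example with invented shapes (none of these values come from the diff): for BROADCAST_3D the number of TADs is the size of z along dimension 0, while the default case uses xLen / yLen.

#include <cassert>
#include <cstdint>

int main() {
    // BROADCAST_3D: z shaped [4, 8, 8] -> numTads = sizeAt(z, 0) = 4
    const int64_t zShape[] = {4, 8, 8};
    assert(zShape[0] == 4);

    // default case: x of length 96 broadcast against y of length 8 -> 96 / 8 = 12 TADs
    const int64_t xLen = 96, yLen = 8;
    assert(xLen / yLen == 12);
    return 0;
}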
|
||||
|
||||
|
|
|
@ -1291,7 +1291,7 @@ void pullRowsGeneric(void *vx,
|
|||
_threads = nd4j::math::nd4j_min<int>(_threads, nd4j::Environment::getInstance()->maxThreads());
|
||||
|
||||
auto func = PRAGMA_THREADS_FOR {
|
||||
for (auto idx = start; idx < stop; idx += increment) {
|
||||
for (auto idx = start; idx < stop; idx++) {
|
||||
auto xTadOffsetForBlock = tadOffsets[indexes[idx]];
|
||||
auto zTadOffsetForBlock = zTadOffsets[idx];
|
||||
|
||||
|
@ -1356,7 +1356,7 @@ void tearGeneric(void *vx,
|
|||
auto numTads = shape::length(hXShapeInfo) / tadLength;
|
||||
|
||||
auto func = PRAGMA_THREADS_FOR {
|
||||
for (auto i = start; i < stop; i += increment) {
|
||||
for (auto i = start; i < stop; i++) {
|
||||
auto hZ = reinterpret_cast<T *>(targets[i]);
|
||||
auto s = hX + tadOffsets[i];
|
||||
|
||||
|
@ -1478,7 +1478,7 @@ void shuffleGeneric(void **hX, Nd4jLong **hXShapeInfo, void **dz, Nd4jLong **hZS
|
|||
auto dZ = reinterpret_cast<T **>(dz);
|
||||
|
||||
auto func = PRAGMA_THREADS_FOR {
|
||||
for (auto f = start; f < stop; f += increment) {
|
||||
for (auto f = start; f < stop; f++) {
|
||||
auto hX = reinterpret_cast<T *>(dX[f]);
|
||||
//auto hZ = reinterpret_cast<T *>(dZ[f]);
|
||||
|
||||
|
|
|
@ -52,7 +52,7 @@ namespace nd4j {
|
|||
TypeCast::convertGeneric<T2, T>(nullptr, tmp, length, buffer);
|
||||
#else
|
||||
auto func = PRAGMA_THREADS_FOR {
|
||||
for (auto e = start; e < stop; e += increment)
|
||||
for (auto e = start; e < stop; e++)
|
||||
buffer[e] = canKeep ? static_cast<T>(tmp[e]) : BitwiseUtils::swap_bytes<T>(static_cast<T>(tmp[e]));
|
||||
};
|
||||
|
||||
|
@ -110,7 +110,7 @@ namespace nd4j {
|
|||
TypeCast::convertGeneric<float, T>(nullptr, tmp, length, buffer);
|
||||
#else
|
||||
auto func = PRAGMA_THREADS_FOR {
|
||||
for (auto e = start; e < stop; e += increment)
|
||||
for (auto e = start; e < stop; e++)
|
||||
buffer[e] = canKeep ? static_cast<T>(tmp[e]) : BitwiseUtils::swap_bytes<T>(static_cast<T>(tmp[e]));
|
||||
};
|
||||
|
||||
|
@ -138,7 +138,7 @@ namespace nd4j {
|
|||
|
||||
#else
|
||||
auto func = PRAGMA_THREADS_FOR {
|
||||
for (auto e = start; e < stop; e += increment)
|
||||
for (auto e = start; e < stop; e++)
|
||||
buffer[e] = canKeep ? static_cast<T>(tmp[e]) : BitwiseUtils::swap_bytes<T>(static_cast<T>(tmp[e]));
|
||||
};
|
||||
|
||||
|
@ -164,7 +164,7 @@ namespace nd4j {
|
|||
TypeCast::convertGeneric<float16, T>(nullptr, tmp, length, buffer);
|
||||
#else
|
||||
auto func = PRAGMA_THREADS_FOR {
|
||||
for (auto e = start; e < stop; e += increment)
|
||||
for (auto e = start; e < stop; e++)
|
||||
buffer[e] = canKeep ? static_cast<T>(tmp[e]) : BitwiseUtils::swap_bytes<T>(static_cast<T>(tmp[e]));
|
||||
};
|
||||
|
||||
|
|
|
@ -58,6 +58,7 @@ namespace nd4j {
|
|||
virtual void putVariable(int id, Variable *variable);
|
||||
virtual void putVariable(int id, NDArray *array);
|
||||
virtual void putVariable(int id, int idx, NDArray *array);
|
||||
virtual void putVariable(int id, int idx, NDArray &array);
|
||||
virtual void putVariable(int id, int idx, Variable *array);
|
||||
|
||||
virtual void replaceVariable(Variable *variable);
@ -100,6 +100,7 @@ namespace nd4j {
|
|||
virtual void putVariable(int id, Variable *variable);
|
||||
virtual void putVariable(int id, NDArray *array);
|
||||
virtual void putVariable(int id, int idx, NDArray *array);
|
||||
virtual void putVariable(int id, int idx, NDArray &array);
|
||||
virtual void putVariable(int id, int idx, Variable *array);
|
||||
|
||||
virtual void dropVariable(std::pair<int,int> &pair);
|
||||
|
|
|
@ -1088,8 +1088,23 @@ namespace nd4j {
|
|||
if (e < node->input()->size() - 1)
|
||||
nd4j_printf(", ", "");
|
||||
}
|
||||
|
||||
if (node->opType() == OpType_CUSTOM) {
|
||||
auto ctx = node->protoContext();
|
||||
if (ctx->getIArguments()->size() > 0) {
|
||||
printf("]; iArgs: [");
|
||||
|
||||
for (int e = 0; e < ctx->getIArguments()->size(); e++) {
|
||||
printf("%i", ctx->getIArguments()->at(e));
|
||||
if (e < ctx->getIArguments()->size() - 1)
|
||||
nd4j_printf(", ", "");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
nd4j_printf("]; \n", "");
|
||||
|
||||
|
||||
// printf("\n");
|
||||
fflush(stdout);
|
||||
}
|
||||
|
|
|
@ -60,8 +60,11 @@ namespace nd4j {
|
|||
result->_name = this->_name;
|
||||
result->_index = this->_index;
|
||||
|
||||
if (this->_ndarray != nullptr)
|
||||
if (this->_ndarray != nullptr) {
|
||||
result->_ndarray = new NDArray(this->_ndarray->dup(this->_ndarray->ordering()));
|
||||
result->_readOnly = false;
|
||||
result->_removable = true;
|
||||
}
|
||||
|
||||
if (this->_list != nullptr)
|
||||
result->_list = this->_list->clone();
|
||||
|
|
|
@ -191,6 +191,9 @@ namespace nd4j {
|
|||
_current->putVariable(id, array);
|
||||
}
|
||||
|
||||
void nd4j::graph::VariableProxy::putVariable(int id, int idx, NDArray &array) {
|
||||
_current->putVariable(id, idx, array);
|
||||
}
|
||||
|
||||
void VariableProxy::putVariable(int id, int idx, NDArray *array) {
|
||||
_current->putVariable(id, idx, array);
|
||||
|
|
|
@ -263,19 +263,19 @@ namespace nd4j {
|
|||
void nd4j::graph::VariableSpace::putVariable(int id, Variable *variable) {
|
||||
// we don't want to add variables more than once
|
||||
if (_variables.count(id) > 0 || _temporary.count(id) > 0) {
|
||||
// nd4j_verbose("Trying to update variable for node_%i\n", id);
|
||||
|
||||
auto local = id < 0 ? _variables.at(id) : _temporary.at(id);
|
||||
|
||||
if (!local->hasNDArray() && variable->hasNDArray()) {
|
||||
// nd4j_verbose("Saving variable for node_%i\n", id);
|
||||
local->setNDArray(variable->getNDArray());
|
||||
|
||||
// we're inheriting this from Variable
|
||||
local->markReadOnly(variable->isReadOnly());
|
||||
local->markRemovable(variable->isRemovable());
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
//nd4j_debug("Adding Variable to Space: id: %i; Array is null: %i;\n", id, variable->getNDArray() == nullptr);
|
||||
|
||||
_varmap.lock();
|
||||
|
||||
_handles->emplace_back(variable);
|
||||
|
@ -314,6 +314,21 @@ namespace nd4j {
|
|||
}
|
||||
}
|
||||
|
||||
void nd4j::graph::VariableSpace::putVariable(int id, int idx, NDArray &array) {
|
||||
auto *var = new nd4j::graph::Variable(&array, "", id, idx);
|
||||
var->markRemovable(false);
|
||||
var->markReadOnly(true);
|
||||
|
||||
// let's see if this op needs
|
||||
bool d = this->hasVariable(id, idx);
|
||||
|
||||
this->putVariable(id, var);
|
||||
|
||||
// if a variable for this node id already exists - we'll just delete the temporary variable
|
||||
if (d)
|
||||
delete var;
|
||||
}
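A short usage sketch for the new by-reference overload (the ids and the calling function are invented): because the wrapper Variable is marked non-removable, the VariableSpace does not take ownership of the passed array.

void storeNodeOutput(nd4j::graph::VariableSpace& space, nd4j::NDArray& output) {
    // by-reference overload: 'output' stays owned by the caller; the space only
    // keeps a read-only, non-removable Variable wrapper around it
    space.putVariable(/*id=*/1, /*idx=*/0, output);
}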
|
||||
|
||||
void nd4j::graph::VariableSpace::putVariable(int id, NDArray *array) {
|
||||
auto *var = new nd4j::graph::Variable(array);
|
||||
this->putVariable(id, var);
@ -24,6 +24,7 @@
|
|||
#include <pointercast.h>
|
||||
#include <dll.h>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
namespace nd4j {
|
||||
namespace graph {
|
||||
|
@ -65,6 +66,9 @@ namespace nd4j {
|
|||
|
||||
// total amount of memory used during execution
|
||||
Nd4jLong _memoryTotal = 0L;
|
||||
|
||||
std::vector<std::string> _inputShapes;
|
||||
std::vector<std::string> _outputShapes;
|
||||
public:
|
||||
NodeProfile() = default;
|
||||
~NodeProfile() = default;
|
||||
|
@ -84,10 +88,15 @@ namespace nd4j {
|
|||
void setObjectsSize(Nd4jLong bytes);
|
||||
void setTotalSize(Nd4jLong bytes);
|
||||
|
||||
Nd4jLong getActivationsSize();
|
||||
Nd4jLong getTemporarySize();
|
||||
Nd4jLong getObjectsSize();
|
||||
Nd4jLong getTotalSize();
|
||||
void addInputShape(Nd4jLong *shapeInfo);
|
||||
void addOutputShape(Nd4jLong *shapeInfo);
|
||||
|
||||
Nd4jLong getActivationsSize() const;
|
||||
Nd4jLong getTemporarySize() const;
|
||||
Nd4jLong getObjectsSize() const;
|
||||
Nd4jLong getTotalSize() const;
|
||||
|
||||
Nd4jLong getExecutionTime() const;
|
||||
|
||||
std::string& name();
|
||||
|
||||
|
|
|
@ -21,6 +21,8 @@
|
|||
#include <graph/profiling/GraphProfile.h>
|
||||
#include <helpers/logger.h>
|
||||
#include <chrono>
|
||||
#include <templatemath.h>
|
||||
#include <algorithm>
|
||||
|
||||
namespace nd4j {
|
||||
namespace graph {
|
||||
|
@ -184,9 +186,26 @@ namespace nd4j {
|
|||
if (_profiles.empty())
|
||||
nd4j_printf("No nodes in graph\n","");
|
||||
|
||||
for (auto v: _profiles)
|
||||
// printing out per-node reports
|
||||
std::vector<NodeProfile*> sorted;
|
||||
for (auto v: _profiles) {
|
||||
v->printOut();
|
||||
|
||||
sorted.emplace_back(v);
|
||||
}
|
||||
|
||||
if (_profiles.size() > 1) {
|
||||
// building hot spots
|
||||
std::sort(sorted.begin(), sorted.end(), [](const NodeProfile *a, const NodeProfile *b) -> bool {
|
||||
return a->getExecutionTime() > b->getExecutionTime();
|
||||
});
|
||||
|
||||
nd4j_printf("\nTop 30 reports by EXEC:\n", "");
|
||||
auto limit = nd4j::math::nd4j_min<int>(30, sorted.size());
|
||||
for (int e = 0; e < limit; e++) {
|
||||
sorted[e]->printOut();
|
||||
}
|
||||
}
|
||||
|
||||
nd4j_printf("\nSpecial timers:\n", "");
|
||||
if (_timings.empty())
|
||||
nd4j_printf("No special timers were set\n","");
|
||||
|
|
|
@ -32,7 +32,7 @@ namespace nd4j {
|
|||
// graph->printOut();
|
||||
|
||||
// warm up
|
||||
for (int e = 0; e < 1000; e++) {
|
||||
for (int e = 0; e < iterations; e++) {
|
||||
FlowPath fp;
|
||||
|
||||
auto _vs = varSpace->clone();
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
|
||||
#include <helpers/logger.h>
|
||||
#include <graph/profiling/NodeProfile.h>
|
||||
#include <helpers/ShapeUtils.h>
|
||||
|
||||
namespace nd4j {
|
||||
namespace graph {
|
||||
|
@ -35,9 +36,23 @@ namespace nd4j {
|
|||
nd4j_printf(" Memory: ACT: %lld; TMP: %lld; OBJ: %lld; TTL: %lld;\n", _memoryActivations / _merges, _memoryTemporary / _merges, _memoryObjects / _merges, _memoryTotal / _merges);
|
||||
nd4j_printf(" Time: PREP: %lld ns; EXEC: %lld ns; TTL: %lld ns;\n", _preparationTime / _merges, _executionTime / _merges, _totalTime / _merges);
|
||||
nd4j_printf(" PREP: INPUT: %lld ns; SHAPE: %lld ns; ARRAY: %lld ns;\n", _inputTime / _merges, _shapeTime / _merges, _arrayTime / _merges);
|
||||
|
||||
std::string inputs;
|
||||
std::string outputs;
|
||||
|
||||
int cnt = 0;
|
||||
for (const auto &v: _inputShapes)
|
||||
inputs += v + " ";
|
||||
|
||||
for (const auto &v: _outputShapes)
|
||||
outputs += v + " ";
|
||||
|
||||
|
||||
nd4j_printf(" Inputs: %s\n", inputs.c_str());
|
||||
nd4j_printf(" Outputs: %s\n", outputs.c_str());
|
||||
};
|
||||
|
||||
Nd4jLong NodeProfile::getActivationsSize() {
|
||||
Nd4jLong NodeProfile::getActivationsSize() const {
|
||||
return _memoryActivations;
|
||||
}
|
||||
|
||||
|
@ -53,15 +68,15 @@ namespace nd4j {
|
|||
_inputTime = time;
|
||||
}
|
||||
|
||||
Nd4jLong NodeProfile::getTemporarySize() {
|
||||
Nd4jLong NodeProfile::getTemporarySize() const{
|
||||
return _memoryTemporary;
|
||||
}
|
||||
|
||||
Nd4jLong NodeProfile::getObjectsSize() {
|
||||
Nd4jLong NodeProfile::getObjectsSize() const{
|
||||
return _memoryObjects;
|
||||
}
|
||||
|
||||
Nd4jLong NodeProfile::getTotalSize() {
|
||||
Nd4jLong NodeProfile::getTotalSize() const{
|
||||
return _memoryTotal;
|
||||
}
|
||||
|
||||
|
@ -97,6 +112,18 @@ namespace nd4j {
|
|||
_memoryTotal = bytes;
|
||||
}
|
||||
|
||||
Nd4jLong NodeProfile::getExecutionTime() const {
|
||||
return _executionTime;
|
||||
}
|
||||
|
||||
void NodeProfile::addInputShape(Nd4jLong *shapeInfo) {
|
||||
_inputShapes.emplace_back(ShapeUtils::shapeInfoAsString(shapeInfo));
|
||||
}
|
||||
|
||||
void NodeProfile::addOutputShape(Nd4jLong *shapeInfo) {
|
||||
_outputShapes.emplace_back(ShapeUtils::shapeInfoAsString(shapeInfo));
|
||||
}
|
||||
|
||||
void NodeProfile::merge(NodeProfile *other) {
|
||||
_merges += other->_merges;
|
||||
_memoryObjects += other->_memoryObjects;
|
||||
|
@ -110,6 +137,9 @@ namespace nd4j {
|
|||
_shapeTime += other->_shapeTime;
|
||||
_arrayTime += other->_arrayTime;
|
||||
_inputTime += other->_inputTime;
|
||||
|
||||
_inputShapes = other->_inputShapes;
|
||||
_outputShapes = other->_outputShapes;
|
||||
}
|
||||
|
||||
std::string& NodeProfile::name() {
|
||||
|
@ -129,6 +159,9 @@ namespace nd4j {
|
|||
_shapeTime = other->_shapeTime;
|
||||
_arrayTime = other->_arrayTime;
|
||||
_inputTime = other->_inputTime;
|
||||
|
||||
_inputShapes = other->_inputShapes;
|
||||
_outputShapes = other->_outputShapes;
|
||||
}
|
||||
}
|
||||
}
|
|
@ -37,12 +37,13 @@ namespace nd4j {
|
|||
class ND4J_EXPORT LoopKind {
|
||||
|
||||
public:
|
||||
enum Kind {SMALLARR2DX, EWS1, EWSNONZERO, RANK1, RANK2, RANK3, RANK4, RANK5, X_EWSNONZERO, Y_EWSNONZERO, Z_EWSNONZERO, COMMON};
|
||||
enum Kind { SMALLARR2DX, EWS1, EWSNONZERO, RANK1, RANK2, RANK3, RANK4, RANK5, X_EWSNONZERO, Y_EWSNONZERO, Z_EWSNONZERO, COMMON, BROADCAST_SCALAR_X, BROADCAST_SCALAR_Y, BROADCAST_3D, BROADCAST_4D, BROADCAST_5D };
|
||||
|
||||
static FORCEINLINE Kind deduceKindOfLoopXZ(const Nd4jLong* xShapeInfo, const Nd4jLong* zShapeInfo);
|
||||
static FORCEINLINE Kind deduceKindOfLoopXYZ(const Nd4jLong* xShapeInfo, const Nd4jLong* yShapeInfo, const Nd4jLong* zShapeInfo);
|
||||
static FORCEINLINE Kind deduceKindOfLoopTadXZ(const Nd4jLong* xShapeInfo, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo);
|
||||
static FORCEINLINE Kind deduceKindOfLoopTadXYZ(const Nd4jLong* xTadShapeInfo, const Nd4jLong* yTadShapeInfo, const Nd4jLong* zShapeInfo);
|
||||
static FORCEINLINE Kind deduceKindOfLoopBroadcast(const Nd4jLong* xShapeInfo, const Nd4jLong* yShapeInfo, const Nd4jLong* zShapeInfo);
|
||||
|
||||
};
|
||||
|
||||
|
@ -82,6 +83,57 @@ LoopKind::Kind LoopKind::deduceKindOfLoopXZ(const Nd4jLong* xShapeInfo, const Nd
|
|||
return COMMON;
|
||||
}
|
||||
|
||||
LoopKind::Kind LoopKind::deduceKindOfLoopBroadcast(const Nd4jLong* xShapeInfo, const Nd4jLong* yShapeInfo, const Nd4jLong* zShapeInfo) {
|
||||
auto xRank = shape::rank(xShapeInfo);
|
||||
auto yRank = shape::rank(yShapeInfo);
|
||||
auto zRank = shape::rank(zShapeInfo);
|
||||
|
||||
auto xOrder = shape::order(xShapeInfo);
|
||||
auto yOrder = shape::order(yShapeInfo);
|
||||
auto zOrder = shape::order(zShapeInfo);
|
||||
|
||||
auto xEws = shape::elementWiseStride(xShapeInfo);
|
||||
auto yEws = shape::elementWiseStride(yShapeInfo);
|
||||
auto zEws = shape::elementWiseStride(zShapeInfo);
|
||||
|
||||
bool bNDLoopsRanks = (xRank == zRank && yRank <= xRank && yRank >= 2);
|
||||
|
||||
int countUnityDimsInY = 0, countUnityDimsInX = 0;
|
||||
for (int i = 0; i < xRank; i++) {
|
||||
if (i < yRank)
|
||||
countUnityDimsInY += (1 == shape::sizeAt(yShapeInfo, i)) ? 1 : 0;
|
||||
countUnityDimsInX += (1 == shape::sizeAt(xShapeInfo, i)) ? 1 : 0;
|
||||
}
|
||||
|
||||
bool bNotCommonVectorCase = (countUnityDimsInY != yRank - 1) && (countUnityDimsInX != xRank - 1);
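                // i.e. neither x nor y collapses to a vector (all dims equal to 1 except one); when either side
                // is such a vector, the ND BROADCAST_3D/4D/5D kinds are skipped and the decision falls through
                // to the trailing-dimension checks / COMMON below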
|
||||
|
||||
if (3 == xRank && bNDLoopsRanks && bNotCommonVectorCase)
|
||||
return nd4j::LoopKind::BROADCAST_3D;
|
||||
if (4 == xRank && bNDLoopsRanks && bNotCommonVectorCase)
|
||||
return nd4j::LoopKind::BROADCAST_4D;
|
||||
if (5 == xRank && bNDLoopsRanks && bNotCommonVectorCase)
|
||||
return nd4j::LoopKind::BROADCAST_5D;
|
||||
|
||||
|
||||
if (xRank == yRank && xRank == zRank && xOrder == 'c' && yOrder == 'c' && zOrder == 'c' && xEws == 1 && yEws == 1 && zEws == 1 && xRank >= 2) {
|
||||
// we validate that the shapes are equal in all dims except the last one
|
||||
for (int e = 0; e < xRank - 1; e++) {
|
||||
if (xShapeInfo[e+1] != yShapeInfo[e+1])
|
||||
return COMMON;
|
||||
}
|
||||
|
||||
// now, if one of the shapes has 1 as last dim
|
||||
auto detect = xShapeInfo[xRank] == 1 ? -1 : (yShapeInfo[xRank] == 1) ? 1 : 0;
|
||||
|
||||
if (detect == 1)
|
||||
return nd4j::LoopKind::BROADCAST_SCALAR_Y;
|
||||
else if (detect == -1)
|
||||
return nd4j::LoopKind::BROADCAST_SCALAR_X;
|
||||
}
|
||||
|
||||
return nd4j::LoopKind::COMMON;
|
||||
}
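// A minimal standalone sketch of the trailing-dimension test above (editorial illustration only,
// using plain std::vector shapes instead of the packed Nd4jLong* shapeInfo buffers): once both
// operands have the same rank and identical leading dims, a trailing dim of 1 on one side reduces
// the broadcast to per-row scalar ops.
#include <string>
#include <vector>

static std::string classifyTrailingDim(const std::vector<long long>& xShape, const std::vector<long long>& yShape) {
    if (xShape.size() != yShape.size() || xShape.size() < 2)
        return "COMMON";
    for (size_t e = 0; e + 1 < xShape.size(); ++e)   // all leading dims must match
        if (xShape[e] != yShape[e])
            return "COMMON";
    if (xShape.back() == 1)                          // x contributes one value per row
        return "BROADCAST_SCALAR_X";
    if (yShape.back() == 1)                          // y contributes one value per row
        return "BROADCAST_SCALAR_Y";
    return "COMMON";
}
// e.g. classifyTrailingDim({4,3,1}, {4,3,5}) returns "BROADCAST_SCALAR_X"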
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
LoopKind::Kind LoopKind::deduceKindOfLoopXYZ(const Nd4jLong* xShapeInfo, const Nd4jLong* yShapeInfo, const Nd4jLong* zShapeInfo) {
|
||||
|
||||
|
|
|
@ -30,15 +30,15 @@
|
|||
|
||||
namespace nd4j {
|
||||
class ND4J_EXPORT ShapeBuilders {
|
||||
public:
|
||||
public:
|
||||
static Nd4jLong* createScalarShapeInfo(nd4j::DataType dataType, nd4j::memory::Workspace* workspace = nullptr);
|
||||
|
||||
|
||||
static Nd4jLong* createVectorShapeInfo(const nd4j::DataType dataType, const Nd4jLong length, nd4j::memory::Workspace* workspace = nullptr);
|
||||
|
||||
/**
|
||||
* create shapeInfo for the given order, based on the shape stored in the shapeOnly vector
* memory for shapeInfo is allocated on the given workspace
|
||||
*/
|
||||
*/
|
||||
static Nd4jLong* createShapeInfo(const nd4j::DataType dataType, const char order, int rank, const Nd4jLong* shapeOnly, memory::Workspace* workspace = nullptr);
|
||||
static Nd4jLong* createShapeInfo(const nd4j::DataType dataType, const char order, const std::vector<Nd4jLong>& shapeOnly, memory::Workspace* workspace = nullptr);
|
||||
static Nd4jLong* createShapeInfo(const nd4j::DataType dataType, const char order, const std::initializer_list<Nd4jLong>& shapeOnly, memory::Workspace* workspace = nullptr);
|
||||
|
@ -51,6 +51,13 @@ namespace nd4j {
|
|||
static Nd4jLong* copyShapeInfoAndType(const Nd4jLong* inShapeInfo, const DataType dtype, const bool copyStrides, memory::Workspace* workspace = nullptr);
|
||||
static Nd4jLong* copyShapeInfoAndType(const Nd4jLong* inShapeInfo, const Nd4jLong* shapeInfoToGetTypeFrom, const bool copyStrides, memory::Workspace* workspace = nullptr);
|
||||
|
||||
/**
|
||||
* allocates memory for a new shapeInfo and copies all information from inShapeInfo to the new shapeInfo, except the dimensions listed in dimsToExclude (unit dimensions) and their corresponding strides
|
||||
* for example inShapeInfo is {3, 2,1,3,1,4, 12,12,4,4,1, 16384,1,99}, dimsToExclude = {2,3}, dimsSize = 2
|
||||
* then outShapeInfo will contain {3, 2,3,4, 12,4,1, 16384,1,99}
|
||||
*/
|
||||
static Nd4jLong* copyShapeInfoWithoutUnites(const Nd4jLong* inShapeInfo, const int dimsSize, const int* dimsToExclude, memory::Workspace* workspace = nullptr);
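        // A hedged standalone illustration of the example above (editorial sketch, not part of this header):
        // dropping dimensions by axis index from a plain shape vector; the real method performs the exclusion
        // on the packed shapeInfo buffer, including the matching strides, and follows its own dimsToExclude convention.
        //
        //   #include <algorithm>
        //   #include <vector>
        //   // e.g. shape {2,1,3,1,4} with the unit axes excluded -> {2,3,4}
        //   inline std::vector<Nd4jLong> dropDims(const std::vector<Nd4jLong>& shape, const std::vector<int>& axesToExclude) {
        //       std::vector<Nd4jLong> out;
        //       for (int i = 0; i < (int) shape.size(); ++i)
        //           if (std::find(axesToExclude.begin(), axesToExclude.end(), i) == axesToExclude.end())
        //               out.push_back(shape[i]);
        //       return out;
        //   }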
|
||||
|
||||
static Nd4jLong* emptyShapeInfo(const nd4j::DataType dataType, memory::Workspace* workspace = nullptr);
|
||||
|
||||
static Nd4jLong* emptyShapeInfo(const nd4j::DataType dataType, const char order, const std::vector<Nd4jLong> &shape, memory::Workspace* workspace = nullptr);
|
||||
|
|
|
@ -50,11 +50,13 @@ namespace nd4j {
|
|||
static std::vector<Nd4jLong> evalRepeatShape(int axis, const std::vector<int>& repeats, const NDArray& arr);
|
||||
|
||||
// evaluate shapeInfo of permuted array
|
||||
static Nd4jLong* evalPermShapeInfo(const int* dimensions, const int rank, const NDArray& arr, nd4j::memory::Workspace* workspace);
|
||||
// if setContigStrides = true, then set contiguous strides in output shapeInfo in accordance with arr order
|
||||
static Nd4jLong* evalPermShapeInfo(const int* dimensions, const int rank, const NDArray& arr, nd4j::memory::Workspace* workspace, const bool setContigStrides = false);
|
||||
static Nd4jLong* evalPermShapeInfo(const Nd4jLong* dimensions, const int rank, const NDArray& arr, nd4j::memory::Workspace* workspace);
|
||||
|
||||
// evaluate shapeInfo of transposed array
|
||||
static Nd4jLong* evalTranspShapeInfo(const NDArray& arr, nd4j::memory::Workspace* workspace);
|
||||
// if setContigStrides = true, then set contiguous strides in output shapeInfo in accordance with arr order
|
||||
static Nd4jLong* evalTranspShapeInfo(const NDArray& arr, nd4j::memory::Workspace* workspace, const bool setContigStrides = false);
|
||||
|
||||
static bool copyVectorPart(std::vector<int>& target, std::vector<int>& source, int rank, int offset);
|
||||
|
||||
|
@ -97,6 +99,8 @@ namespace nd4j {
|
|||
static std::string shapeAsString(const int rank, const Nd4jLong* shapeInfo);
|
||||
static std::string strideAsString(const NDArray* array);
|
||||
|
||||
static std::string shapeInfoAsString(const Nd4jLong* shapeInfo);
|
||||
|
||||
static std::vector<Nd4jLong> shapeAsVector(const Nd4jLong* shapeInfo);
|
||||
|
||||
// evaluate shapeInfo for diagonal array which is made using input arr elements as diagonal
|
||||
|
@ -176,6 +180,17 @@ namespace nd4j {
|
|||
return (numStrings + 1) * sizeof(Nd4jLong);
|
||||
}
|
||||
|
||||
/**
 * This method selects strides based on the dimensions required for broadcasting
 * @param inShapeInfo const pointer to the input (Y) shape info from which strides are selected
 * @param nRank rank of the input (X) being broadcast against
 * @param dimsSize number of dimensions used for broadcasting
 * @param dims const pointer to the dimensions used for broadcasting
 * @param outStrides pointer to the output strides; must be pre-allocated and zero-initialized
 */
|
||||
static void copyCertainStridesFromShapeInfo(const Nd4jLong* inShapeInfo, const int nRank, const int dimsSize, const int* dims, Nd4jLong* outStrides);
|
||||
|
||||
/*
|
||||
* check whether arr1/arr2 is a sub-array of arr2/arr1;
* this method does not evaluate which array is the sub-array, it returns true if arr1 is a sub-array of arr2 or arr2 is a sub-array of arr1
|
||||
|
|
|
@ -68,7 +68,7 @@ namespace nd4j {
|
|||
const Nd4jLong numOfSubArrs = ShapeUtils::getNumOfSubArrs(shapeInfo, dimsToExclude);
|
||||
const int subArrRank = (rank == dimsToExclude.size() || descriptor.areUnitiesinShape()) ? rank : rank - dimsToExclude.size();
|
||||
|
||||
auto sPtr = new Nd4jLong[shape::shapeInfoLength(subArrRank)];
|
||||
auto sPtr = new Nd4jLong[shape::shapeInfoLength(subArrRank)]; // shape of sub-arrays (same for all of them)
|
||||
auto oPtr = new Nd4jLong[numOfSubArrs];
|
||||
|
||||
if (numOfSubArrs > 0)
|
||||
|
|
|
@ -49,7 +49,7 @@ void nd4j::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
|
|||
case nd4j::LoopKind::EWS1: {
|
||||
|
||||
auto func = PRAGMA_THREADS_FOR {
|
||||
for (auto i = start; i < stop; i += increment) {
|
||||
for (auto i = start; i < stop; i++) {
|
||||
auto tad = const_cast<X *>(x) + tadOffsets[i];
|
||||
auto indexValue = OpType::startingIndexValue(tad);
|
||||
|
||||
|
@ -70,7 +70,7 @@ void nd4j::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
|
|||
case nd4j::LoopKind::EWSNONZERO: {
|
||||
|
||||
auto func = PRAGMA_THREADS_FOR {
|
||||
for (auto i = start; i < stop; i += increment) {
|
||||
for (auto i = start; i < stop; i++) {
|
||||
auto tad = const_cast<X *>(x) + tadOffsets[i];
|
||||
auto indexValue = OpType::startingIndexValue(tad);
|
||||
|
||||
|
@ -91,7 +91,7 @@ void nd4j::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
|
|||
case nd4j::LoopKind::RANK1: {
|
||||
|
||||
auto func = PRAGMA_THREADS_FOR {
|
||||
for (auto i = start; i < stop; i += increment) {
|
||||
for (auto i = start; i < stop; i++) {
|
||||
auto tad = const_cast<X *>(x) + tadOffsets[i];
|
||||
auto indexValue = OpType::startingIndexValue(tad);
|
||||
|
||||
|
@ -114,7 +114,7 @@ void nd4j::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
|
|||
shape::updateStrides(2, tadShape, newStride, 'c');
|
||||
|
||||
auto func = PRAGMA_THREADS_FOR {
|
||||
for (auto i = start; i < stop; i += increment) {
|
||||
for (auto i = start; i < stop; i++) {
|
||||
auto tad = const_cast<X *>(x) + tadOffsets[i];
|
||||
auto indexValue = OpType::startingIndexValue(tad);
|
||||
|
||||
|
@ -141,7 +141,7 @@ void nd4j::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
|
|||
shape::updateStrides(3, tadShape, newStride, 'c');
|
||||
|
||||
auto func = PRAGMA_THREADS_FOR {
|
||||
for (auto i = start; i < stop; i += increment) {
|
||||
for (auto i = start; i < stop; i++) {
|
||||
auto tad = const_cast<X *>(x) + tadOffsets[i];
|
||||
auto indexValue = OpType::startingIndexValue(tad);
|
||||
|
||||
|
@ -170,7 +170,7 @@ void nd4j::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
|
|||
shape::updateStrides(4, tadShape, newStride, 'c');
|
||||
|
||||
auto func = PRAGMA_THREADS_FOR {
|
||||
for (auto i = start; i < stop; i += increment) {
|
||||
for (auto i = start; i < stop; i++) {
|
||||
auto tad = const_cast<X *>(x) + tadOffsets[i];
|
||||
auto indexValue = OpType::startingIndexValue(tad);
|
||||
|
||||
|
@ -201,7 +201,7 @@ void nd4j::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
|
|||
shape::updateStrides(5, tadShape, newStride, 'c');
|
||||
|
||||
auto func = PRAGMA_THREADS_FOR {
|
||||
for (auto i = start; i < stop; i += increment) {
|
||||
for (auto i = start; i < stop; i++) {
|
||||
auto tad = const_cast<X *>(x) + tadOffsets[i];
|
||||
auto indexValue = OpType::startingIndexValue(tad);
|
||||
|
||||
|
@ -234,7 +234,7 @@ void nd4j::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
|
|||
const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo<uint>(zShapeInfo, castZShapeInfo);
|
||||
|
||||
auto func = PRAGMA_THREADS_FOR {
|
||||
for (auto i = start; i < stop; i += increment) {
|
||||
for (auto i = start; i < stop; i++) {
|
||||
auto tad = const_cast<X *>(x) + tadOffsets[i];
|
||||
auto indexValue = OpType::startingIndexValue(tad);
|
||||
|
||||
|
@ -258,7 +258,7 @@ void nd4j::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
|
|||
const bool canCastTad = nd4j::DataTypeUtils::castShapeInfo<uint>(tadShapeInfo, castTadShapeInfo);
|
||||
|
||||
auto func = PRAGMA_THREADS_FOR {
|
||||
for (auto i = start; i < stop; i += increment) {
|
||||
for (auto i = start; i < stop; i++) {
|
||||
auto tad = const_cast<X *>(x) + tadOffsets[i];
|
||||
auto indexValue = OpType::startingIndexValue(tad);
|
||||
|
||||
|
@ -284,7 +284,7 @@ void nd4j::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
|
|||
const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo<uint>(zShapeInfo, castZShapeInfo);
|
||||
|
||||
auto func = PRAGMA_THREADS_FOR {
|
||||
for (auto i = start; i < stop; i += increment) {
|
||||
for (auto i = start; i < stop; i++) {
|
||||
auto tad = const_cast<X *>(x) + tadOffsets[i];
|
||||
auto indexValue = OpType::startingIndexValue(tad);
|
||||
|
||||
|
|
|
@ -43,23 +43,30 @@ nd4j::NDArray* nd4j::MmulHelper::tensorDot(const nd4j::NDArray* a, const nd4j::N
|
|||
|
||||
auto outShape = ShapeUtils::evalShapeForTensorDot(a, b, axes_0, axes_1, permutAt, permutBt, shapeAt, shapeBt);
|
||||
|
||||
NDArray aPR = a->permute(permutAt);
|
||||
NDArray bPR = b->permute(permutBt);
|
||||
// check whether permutation is necessary
|
||||
const NDArray* aP = permutAt.empty() ? a : new NDArray(a->permute(permutAt));
|
||||
const NDArray* bP = permutBt.empty() ? b : new NDArray(b->permute(permutBt));
|
||||
|
||||
// check whether reshape is necessary
|
||||
if(!aPR.isSameShape(shapeAt))
|
||||
aPR.reshapei( shapeAt);
|
||||
if(!bPR.isSameShape(shapeBt))
|
||||
bPR.reshapei( shapeBt);
|
||||
const NDArray* aPR = aP->isSameShape(shapeAt) ? aP : new NDArray(aP->reshape(aP->ordering(), shapeAt));
|
||||
const NDArray* bPR = bP->isSameShape(shapeAt) ? bP : new NDArray(bP->reshape(bP->ordering(), shapeBt));
|
||||
|
||||
NDArray* c = mmul(&aPR, &bPR, nullptr, 1.0, 0.0);
|
||||
NDArray* c = mmul(aPR, bPR, nullptr, 1.0, 0.0);
|
||||
|
||||
c->reshapei(outShape);
|
||||
|
||||
if(aP != aPR)
|
||||
delete aPR;
|
||||
if(bP != bPR)
|
||||
delete bPR;
|
||||
if(a != aP)
|
||||
delete aP;
|
||||
if(b != bP)
|
||||
delete bP;
|
||||
|
||||
return c;
|
||||
}
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
void nd4j::MmulHelper::tensorDot(const nd4j::NDArray* a, const nd4j::NDArray* b, nd4j::NDArray* c, const std::vector<int>& axes_a, const std::vector<int>& axes_b, const std::vector<int>& permutForC) {
|
||||
|
||||
|
@ -67,32 +74,38 @@ void nd4j::MmulHelper::tensorDot(const nd4j::NDArray* a, const nd4j::NDArray* b,
|
|||
std::vector<Nd4jLong> shapeAt, shapeBt;
|
||||
ShapeUtils::evalShapeForTensorDot(a, b, axes_a, axes_b, permutAt, permutBt, shapeAt, shapeBt);
|
||||
|
||||
NDArray *cP(c), *cPR(c);
|
||||
|
||||
// check whether permutation is required
|
||||
if(!permutForC.empty())
|
||||
cP = new NDArray(c->permute(permutForC));
|
||||
NDArray* cP = permutForC.empty() ? c : new NDArray(c->permute(permutForC));
|
||||
|
||||
auto aPR = a->permute(permutAt);
|
||||
auto bPR = b->permute(permutBt);
|
||||
// check whether permutation is necessary
|
||||
const NDArray* aP = permutAt.empty() ? a : new NDArray(a->permute(permutAt));
|
||||
const NDArray* bP = permutBt.empty() ? b : new NDArray(b->permute(permutBt));
|
||||
|
||||
// check whether reshape is necessary
|
||||
if(!aPR.isSameShape(shapeAt))
|
||||
aPR.reshapei(shapeAt);
|
||||
if(!bPR.isSameShape(shapeBt))
|
||||
bPR.reshapei(shapeBt);
|
||||
const NDArray* aPR = aP->isSameShape(shapeAt) ? aP : new NDArray(aP->reshape(aP->ordering(), shapeAt));
|
||||
const NDArray* bPR = bP->isSameShape(shapeAt) ? bP : new NDArray(bP->reshape(bP->ordering(), shapeBt));
|
||||
|
||||
if(!cP->isSameShape({aPR.sizeAt(0), bPR.sizeAt(1)}))
|
||||
cPR = new NDArray(cP->reshape(cP->ordering(), {aPR.sizeAt(0), bPR.sizeAt(1)}));
|
||||
std::vector<Nd4jLong> requiredCshape = {aPR->sizeAt(0), bPR->sizeAt(1)};
|
||||
|
||||
mmul(&aPR, &bPR, cPR, 1.0, 0.0);
|
||||
NDArray* cPR = cP->isSameShape(requiredCshape) ? cP : new NDArray(cP->reshape(cP->ordering(), requiredCshape, false));
|
||||
|
||||
mmul(aPR, bPR, cPR, 1.0, 0.0);
|
||||
|
||||
if(cPR->getBuffer() != cP->getBuffer() || cPR->getSpecialBuffer() != cP->getSpecialBuffer() ) // this means both permute and reshape have been performed on c, cP always points on c->getBuffer()
|
||||
cP->assign(cPR);
|
||||
|
||||
if(cPR != c)
|
||||
if(aP != aPR)
|
||||
delete aPR;
|
||||
if(bP != bPR)
|
||||
delete bPR;
|
||||
if(a != aP)
|
||||
delete aP;
|
||||
if(b != bP)
|
||||
delete bP;
|
||||
|
||||
if(cP != cPR)
|
||||
delete cPR;
|
||||
if(cP != c)
|
||||
if(c != cP)
|
||||
delete cP;
|
||||
}
|
||||
|
||||
|
@ -129,7 +142,7 @@ void nd4j::MmulHelper::tensorDot(const NDArray* a, const NDArray* b, NDArray* c,
|
|||
if(!whatToDoWithC.empty()) {
|
||||
cArrs = std::vector<NDArray*>(whatToDoWithC.size()+1, c);
|
||||
for(int i = 0; i < cArrs.size()-1; ++i)
|
||||
cArrs[i+1] = (whatToDoWithC[i] == 'p') ? new NDArray(cArrs[i]->permute(modifC[i])) : new NDArray(cArrs[i]->reshape(c->ordering(), modifC[i])); // since we ignore first element in cArrs (that is cArrs[0]) then it is always equal to c
|
||||
cArrs[i+1] = (whatToDoWithC[i] == 'p') ? new NDArray(cArrs[i]->permute(modifC[i])) : new NDArray(cArrs[i]->reshape(c->ordering(), modifC[i], false)); // since we ignore first element in cArrs (that is cArrs[0]) then it is always equal to c
|
||||
}
|
||||
|
||||
mmul(aPR, bPR, cArrs[cArrs.size()-1], 1.0, 0.0);
|
||||
|
@ -208,7 +221,7 @@ nd4j::NDArray* MmulHelper::mmul(const nd4j::NDArray* A, const nd4j::NDArray* B,
|
|||
// vector x matrix, A{M} x B{M,N} = C{N} -> reduce to matrix x matrix A2{1,M} x B{M,N} = C2{1,N}, since there is no corresponding blas operation sgevm
|
||||
if(isAVector && bRank == 2) {
|
||||
NDArray* A2 = new NDArray(A->reshape(A->ordering(), {1, A->lengthOf()})); // A{M} -> A2{1,M}
|
||||
NDArray* C2 = C ? new NDArray(C->reshape(C->ordering(), {1, C->lengthOf()})) : nullptr; // C{N} -> C2{1,N}
|
||||
NDArray* C2 = C ? new NDArray(C->reshape(C->ordering(), {1, C->lengthOf()}, false)) : nullptr; // C{N} -> C2{1,N}
|
||||
auto result = mmulMxM(A2, B, C2, alpha, beta, outOrder); // result{1,N}
|
||||
delete A2;
|
||||
delete C2;
|
||||
|
|
|
@ -139,5 +139,15 @@ namespace nd4j {
|
|||
return ShapeBuilders::copyShapeInfoAndType(inShapeInfo, ArrayOptions::dataType(shapeInfoToGetTypeFrom), copyStrides, workspace);
|
||||
}
|
||||
|
||||
    ////////////////////////////////////////////////////////////////////////////////
    Nd4jLong* ShapeBuilders::copyShapeInfoWithoutUnites(const Nd4jLong* inShapeInfo, const int dimsSize, const int* dimsToExclude, memory::Workspace* workspace) {

        Nd4jLong *outShapeInfo = nullptr;
        ALLOCATE(outShapeInfo, workspace, shape::shapeInfoLength(inShapeInfo[0] - dimsSize), Nd4jLong);

        shape::excludeUnitiesFromShapeInfo(inShapeInfo, dimsSize, dimsToExclude, outShapeInfo);

        return outShapeInfo;
    }
|
||||
|
||||
}
|
|
@ -75,10 +75,23 @@ std::vector<Nd4jLong> ShapeUtils::evalShapeForTensorDot(const Nd4jLong* aShapeIn
|
|||
permutBt = axesB;
|
||||
permutBt.insert(permutBt.end(), list_B.begin(), list_B.end());
|
||||
|
||||
// if the permutation is the identity {0,1,2,...,rank-1}, there is no need to permute, and we return an empty vector in that case
|
||||
uint i1, i2;
|
||||
for(i1 = 0; i1 < aRank; ++i1)
|
||||
if(permutAt[i1] != i1)
|
||||
break;
|
||||
if(i1 == aRank)
|
||||
permutAt = {};
|
||||
for(i2 = 0; i2 < bRank; ++i2)
|
||||
if(permutBt[i2] != i2)
|
||||
break;
|
||||
if(i2 == bRank)
|
||||
permutBt = {};
|
||||
|
||||
Nd4jLong n2 = 1;
|
||||
for (int i = 0; i < axeAsize; i++)
|
||||
n2 *= aShapeInfo[axesA[i] + 1];
|
||||
shapeAt = {-1, n2};
|
||||
shapeAt = {shape::length(aShapeInfo) / n2, n2};
|
||||
|
||||
std::vector<Nd4jLong> oldShapeA;
|
||||
oldShapeA.resize(list_A.size());
|
||||
|
@ -89,7 +102,7 @@ std::vector<Nd4jLong> ShapeUtils::evalShapeForTensorDot(const Nd4jLong* aShapeIn
|
|||
Nd4jLong n3 = 1;
|
||||
for (int i = 0; i < axeBsize; i++)
|
||||
n3 *= bShapeInfo[axesB[i] + 1];
|
||||
shapeBt = {n3, -1};
|
||||
shapeBt = {n3, shape::length(bShapeInfo) / n3};
|
||||
|
||||
std::vector<Nd4jLong> oldShapeB;
|
||||
oldShapeB.resize(list_B.size());
|
||||
|
@ -300,32 +313,37 @@ std::vector<Nd4jLong> ShapeUtils::evalRepeatShape(int axis, const std::vector<in
|
|||
return outShape;
|
||||
}
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
// evaluate shapeInfo of permuted array
|
||||
Nd4jLong* ShapeUtils::evalPermShapeInfo(const int* dimensions, const int rank, const NDArray& arr, nd4j::memory::Workspace* workspace) {
|
||||
Nd4jLong* ShapeUtils::evalPermShapeInfo(const int* dimensions, const int rank, const NDArray& arr, nd4j::memory::Workspace* workspace, const bool setContigStrides) {
|
||||
|
||||
if (!arr.nonNull())
|
||||
throw std::runtime_error("ShapeUtils::evalPermShapeInfo static method: wrong arguments in pn/termute method: either array is nullptr!");
|
||||
if (!arr.nonNull())
|
||||
throw std::runtime_error("ShapeUtils::evalPermShapeInfo static method: wrong arguments: array is nullptr!");
|
||||
|
||||
if (rank != arr.rankOf())
|
||||
throw std::runtime_error("ShapeUtils::evalPermShapeInfo static method: wrong arguments in pn/termute method: rank is not suitable!");
|
||||
if (rank != arr.rankOf())
|
||||
throw std::runtime_error("ShapeUtils::evalPermShapeInfo static method: wrong arguments: rank is not suitable!");
|
||||
|
||||
auto shapeInfoLength = shape::shapeInfoLength(rank);
|
||||
// allocate memory for new array - shapeInfo
|
||||
auto shapeInfoLength = shape::shapeInfoLength(rank);
|
||||
|
||||
Nd4jLong *shapeInfoNew = nullptr;
|
||||
ALLOCATE(shapeInfoNew, workspace, shapeInfoLength, Nd4jLong);
|
||||
// copy arr _shapeInfo into new array
|
||||
memcpy(shapeInfoNew, arr.getShapeInfo(), shape::shapeInfoByteLength(rank));
|
||||
// perform buffer permutation
|
||||
shape::doPermuteShapeInfo(shapeInfoNew, dimensions);
|
||||
// allocate memory for new array - shapeInfo
|
||||
Nd4jLong *shapeInfoNew = nullptr;
|
||||
ALLOCATE(shapeInfoNew, workspace, shapeInfoLength, Nd4jLong);
|
||||
|
||||
ShapeDescriptor descriptor(shapeInfoNew);
|
||||
RELEASE(shapeInfoNew, workspace);
|
||||
return ConstantShapeHelper::getInstance()->bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
|
||||
}
|
||||
// copy arr _shapeInfo into new array
|
||||
memcpy(shapeInfoNew, arr.getShapeInfo(), shape::shapeInfoByteLength(rank));
|
||||
|
||||
// perform buffer permutation
|
||||
shape::doPermuteShapeInfo(shapeInfoNew, dimensions, arr.lengthOf());
|
||||
|
||||
if(setContigStrides)
|
||||
shape::updateStrides(shapeInfoNew, arr.ordering());
|
||||
|
||||
ShapeDescriptor descriptor(shapeInfoNew);
|
||||
|
||||
RELEASE(shapeInfoNew, workspace);
|
||||
|
||||
return ConstantShapeHelper::getInstance()->bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
// evaluate shapeInfo of permuted array
|
||||
|
@ -337,14 +355,14 @@ std::vector<Nd4jLong> ShapeUtils::evalRepeatShape(int axis, const std::vector<in
|
|||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
// evaluate shapeInfo of transposed array
|
||||
Nd4jLong* ShapeUtils::evalTranspShapeInfo(const NDArray& arr, nd4j::memory::Workspace* workspace) {
|
||||
Nd4jLong* ShapeUtils::evalTranspShapeInfo(const NDArray& arr, nd4j::memory::Workspace* workspace, const bool setContigStrides) {
|
||||
|
||||
int rank = arr.rankOf();
|
||||
std::vector<int> dimensions(rank);
|
||||
for (int i = 0; i < rank; ++i)
|
||||
dimensions[i] = rank - 1 - i;
|
||||
|
||||
return evalPermShapeInfo(dimensions.data(), dimensions.size(), arr, workspace);
|
||||
return evalPermShapeInfo(dimensions.data(), dimensions.size(), arr, workspace, setContigStrides);
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
|
@ -653,6 +671,26 @@ Nd4jLong* ShapeUtils::evalTileShapeInfo(const NDArray& arr, const std::vector<Nd
|
|||
return result;
|
||||
}
|
||||
|
||||
std::string ShapeUtils::shapeInfoAsString(const Nd4jLong* shapeInfo) {
|
||||
|
||||
if(!shapeInfo)
|
||||
throw std::runtime_error("ShapeUtils::shapeAsString method: input shapeInfo must not be nullptr !");
|
||||
|
||||
std::string result;
|
||||
|
||||
int len = shape::shapeInfoLength(shapeInfo[0]);
|
||||
|
||||
result.append("[");
|
||||
for (int e = 0; e < len; e++) {
|
||||
result += flatbuffers::NumToString(shapeInfo[e]);
|
||||
if (e < len - 1)
|
||||
result.append(", ");
|
||||
}
|
||||
result.append("]");
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
std::string ShapeUtils::shapeAsString(const int rank, const Nd4jLong* shapeInfo) {
|
||||
if(!shapeInfo)
|
||||
|
@ -1019,6 +1057,29 @@ std::vector<int> ShapeUtils::tadAxesForSimpleBroadcast(const NDArray& max, const
|
|||
return numOfMinTads == 1 ? maxTadDims : std::vector<int>();
|
||||
}
|
||||
|
||||
void ShapeUtils::copyCertainStridesFromShapeInfo(const Nd4jLong* inShapeInfo, const int nRank, const int dimsSize, const int* dims, Nd4jLong* outStrides) {

    int yRank = shape::rank(inShapeInfo);
    auto yOrigStride = shape::stride(inShapeInfo);

    if (yRank == nRank) {
        for (int i = 0; i < yRank; ++i) {
            // x[2,3,4] * y[2,1,4] = z[2,3,4]
            outStrides[i] = (1 == shape::sizeAt(inShapeInfo, i)) ? 0 : yOrigStride[i];
        }
    }
    else {

        auto dimEx = nd4j::ShapeUtils::evalDimsToExclude(nRank, dimsSize, dims);

        for (int i = 0, it = 0; i < nRank; ++i) {
            auto nCount = std::count(dimEx.cbegin(), dimEx.cend(), i);
            outStrides[i] = (0 == nCount) ? yOrigStride[it++] : 0;
            if (it == yRank)
                break;
        }
    }
}
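// A hedged illustration of the equal-rank branch above (editorial sketch, not library code):
// for x[2,3,4] (op) y[2,1,4] -> z[2,3,4], y's stride on the size-1 axis is forced to 0, so every
// coordinate along that axis reads the same y element while x and z advance normally.
static Nd4jLong broadcastOffsetExample(const Nd4jLong coords[3], const Nd4jLong strides[3]) {
    // with strides such as {4, 0, 1} for y[2,1,4], coords {i, j, k} all map onto y[i][0][k]
    return coords[0] * strides[0] + coords[1] * strides[1] + coords[2] * strides[2];
}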
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
/*
|
||||
bool ShapeUtils::isSubArrayCase(const NDArray& arr1, const NDArray& arr2, std::vector<int>& sameDims) {
|
||||
|
|
File diff suppressed because it is too large
|
@ -40,6 +40,7 @@
|
|||
#endif
|
||||
|
||||
#include <helpers/TAD.h>
|
||||
#include <helpers/LoopKind.h>
|
||||
|
||||
#include "legacy_ops.h"
|
||||
|
||||
|
@ -122,6 +123,7 @@ namespace functions {
|
|||
Nd4jLong *tadOffset,
|
||||
Nd4jLong *tadShapeInfoZ,
|
||||
Nd4jLong *tadOffsetZ,
|
||||
nd4j::LoopKind::Kind loopKind,
|
||||
uint64_t start,
|
||||
uint64_t stop);
|
||||
|
||||
|
@ -149,6 +151,7 @@ namespace functions {
|
|||
Nd4jLong *tadOffset,
|
||||
Nd4jLong *tadShapeInfoZ,
|
||||
Nd4jLong *tadOffsetZ,
|
||||
nd4j::LoopKind::Kind loopKind,
|
||||
uint64_t start,
|
||||
uint64_t stop);
|
||||
|
||||
|
|
|
@ -14,9 +14,9 @@
|
|||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author Yurii Shyrma (iuriish@yahoo.com)
|
||||
//
|
||||
//
|
||||
// @author Yurii Shyrma (iuriish@yahoo.com)
|
||||
//
|
||||
|
||||
#include <loops/TrueBroadcastHelper.h>
|
||||
#include <ops/ops.h>
|
||||
|
@ -24,226 +24,268 @@
|
|||
|
||||
using namespace simdOps;
|
||||
|
||||
namespace nd4j {
|
||||
namespace helpers {
|
||||
namespace nd4j {
|
||||
namespace helpers {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
template <typename X, typename Y, typename Z>
|
||||
template<typename OpType>
|
||||
void TrueBroadcastHelper<X, Y, Z>::exec(const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
template <typename X, typename Y, typename Z>
|
||||
template<typename OpType>
|
||||
void TrueBroadcastHelper<X, Y, Z>::exec(const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
|
||||
|
||||
|
||||
const X* x = reinterpret_cast<X*>(xArr.getBuffer());
|
||||
const Y* y = reinterpret_cast<Y*>(yArr.getBuffer());
|
||||
Z* z = reinterpret_cast<Z*>(zArr.getBuffer());
|
||||
const X* x = reinterpret_cast<X*>(xArr.getBuffer());
|
||||
const Y* y = reinterpret_cast<Y*>(yArr.getBuffer());
|
||||
Z* z = reinterpret_cast<Z*>(zArr.getBuffer());
|
||||
|
||||
const auto xShapeInfo = xArr.getShapeInfo();
|
||||
const auto yShapeInfo = yArr.getShapeInfo();
|
||||
const auto zShapeInfo = zArr.getShapeInfo();
|
||||
const auto xShapeInfo = xArr.getShapeInfo();
|
||||
const auto yShapeInfo = yArr.getShapeInfo();
|
||||
const auto zShapeInfo = zArr.getShapeInfo();
|
||||
|
||||
const int xRank = xArr.rankOf();
|
||||
const int yRank = yArr.rankOf();
|
||||
const int zRank = zArr.rankOf();
|
||||
const int xRank = xArr.rankOf();
|
||||
const int yRank = yArr.rankOf();
|
||||
const int zRank = zArr.rankOf();
|
||||
|
||||
bool bSpecialCase = (1 == xArr.ews() && 'c' == xArr.ordering() && 1 == yRank &&
|
||||
1 == yArr.ews() && 'c' == yArr.ordering() &&
|
||||
1 == zArr.ews() && 'c' == zArr.ordering());
|
||||
bool bSpecialCase = (1 == xArr.ews() && 'c' == xArr.ordering() &&
|
||||
1 == yArr.ews() && 'c' == yArr.ordering() &&
|
||||
1 == zArr.ews() && 'c' == zArr.ordering());
|
||||
|
||||
if (bSpecialCase) {
|
||||
auto yLen = (uint32_t)yArr.lengthOf();
|
||||
auto func = PRAGMA_THREADS_FOR{
|
||||
for (uint32_t i = start; i < stop; i++) {
|
||||
auto rZ = z + (i * yLen);
|
||||
auto v = x[i];
|
||||
for (uint32_t j = 0; j < yLen; j++) {
|
||||
rZ[j] = OpType::op(v, y[j]);
|
||||
}
|
||||
}
|
||||
};
|
||||
samediff::Threads::parallel_tad(func, 0, xArr.lengthOf());
|
||||
return;
|
||||
if (bSpecialCase && yArr.isColumnVector() && 1 == xArr.sizeAt(-1) ) {
|
||||
auto yLen = (uint32_t)yArr.lengthOf();
|
||||
auto func = PRAGMA_THREADS_FOR{
|
||||
for (uint32_t i = start; i < stop; i++) {
|
||||
auto rZ = z + (i * yLen);
|
||||
auto v = x[i];
|
||||
for (uint32_t j = 0; j < yLen; j++) {
|
||||
rZ[j] = OpType::op(v, y[j]);
|
||||
}
|
||||
}
|
||||
};
|
||||
samediff::Threads::parallel_tad(func, 0, xArr.lengthOf());
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
auto yShapeInt = yArr.getShapeAsVectorInt();
|
||||
auto xShapeInt = xArr.getShapeAsVectorInt();
|
||||
auto nCountY = std::count_if(yShapeInt.cbegin(), yShapeInt.cend(), [](int i) { return i == 1; });
|
||||
auto nCountX = std::count_if(xShapeInt.cbegin(), xShapeInt.cend(), [](int i) { return i == 1; });
|
||||
|
||||
bool bSpecialCase2 = (xRank == zRank && yRank == zRank && 1 == xArr.sizeAt(-1) && 1 == yArr.sizeAt(-2) && 1 == nCountY && 1 == nCountX);
|
||||
|
||||
if (bSpecialCase && bSpecialCase2) {
|
||||
|
||||
int zDim1 = zArr.sizeAt(-2);
|
||||
int zDim2 = zArr.sizeAt(-1);
|
||||
|
||||
int nLen = zArr.lengthOf() / yArr.sizeAt(-1);
|
||||
|
||||
auto func = PRAGMA_THREADS_FOR{
|
||||
for (uint32_t total = start; total < stop; total++) {
|
||||
|
||||
uint32_t i = total / zDim1;
|
||||
uint32_t j = total % zDim1;
|
||||
|
||||
uint32_t index = (i * zDim1) + j;
|
||||
auto rZ = z + (index * zDim2);
|
||||
auto rY = y + (i * zDim2);
|
||||
auto rX = x[index];
|
||||
|
||||
for (uint32_t n = 0; n < zDim2; n++) {
|
||||
rZ[n] = OpType::op(rX, rY[n]);
|
||||
}
|
||||
}
|
||||
};
|
||||
samediff::Threads::parallel_tad(func, 0, nLen, 1);
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
const Nd4jLong zLen = zArr.lengthOf();
|
||||
auto func = PRAGMA_THREADS_FOR{
|
||||
std::vector<Nd4jLong> xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf());
|
||||
|
||||
for (auto i = start; i < stop; ++i) {
|
||||
|
||||
shape::index2coords(i, zShapeInfo, zCoords.data());
|
||||
|
||||
for (int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) {
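                    // walk z's coordinates from the last axis backwards and map each one onto x and y:
                    // where an input dim matches z's dim, z's coordinate is reused; where the input dim
                    // is a broadcast (size-1 or missing) axis, the coordinate collapses to 0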
|
||||
|
||||
if (ix >= 0) {
|
||||
if (xShapeInfo[ix + 1] == zShapeInfo[iz + 1]) {
|
||||
xCoords[ix--] = zCoords[iz];
|
||||
}
|
||||
else {
|
||||
xCoords[ix--] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
if (iy >= 0) {
|
||||
if (yShapeInfo[iy + 1] == zShapeInfo[iz + 1]) {
|
||||
yCoords[iy--] = zCoords[iz];
|
||||
}
|
||||
else {
|
||||
yCoords[iy--] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const auto xOffset = shape::getOffset(xShapeInfo, xCoords.data());
|
||||
const auto yOffset = shape::getOffset(yShapeInfo, yCoords.data());
|
||||
const auto zOffset = shape::getOffset(zShapeInfo, zCoords.data());
|
||||
|
||||
z[zOffset] = OpType::op(x[xOffset], y[yOffset]);
|
||||
}
|
||||
};
|
||||
|
||||
samediff::Threads::parallel_for(func, 0, zLen);
|
||||
}
|
||||
|
||||
template <typename X, typename Y, typename Z>
|
||||
void TrueBroadcastHelper<X, Y, Z>::exec(const nd4j::broadcast::Ops opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
|
||||
DISPATCH_BY_OPNUM_TTT(exec, PARAMS(xArr, yArr, zArr), BROADCAST_OPS);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
template <typename X, typename Z>
|
||||
template<typename OpType>
|
||||
void TrueBroadcastBoolHelper<X, Z>::exec(const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
|
||||
|
||||
const X* x = reinterpret_cast<X*>(xArr.getBuffer());
|
||||
const X* y = reinterpret_cast<X*>(yArr.getBuffer());
|
||||
Z* z = reinterpret_cast<Z*>(zArr.getBuffer());
|
||||
|
||||
const auto xShapeInfo = xArr.getShapeInfo();
|
||||
const auto yShapeInfo = yArr.getShapeInfo();
|
||||
const auto zShapeInfo = zArr.getShapeInfo();
|
||||
|
||||
const int xRank = xArr.rankOf();
|
||||
const int yRank = yArr.rankOf();
|
||||
const int zRank = zArr.rankOf();
|
||||
|
||||
const Nd4jLong zLen = zArr.lengthOf();
|
||||
|
||||
auto func = PRAGMA_THREADS_FOR{
|
||||
std::vector<Nd4jLong> xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf());
|
||||
|
||||
for (auto i = start; i < stop; ++i) {
|
||||
|
||||
shape::index2coords(i, zShapeInfo, zCoords.data());
|
||||
|
||||
for (int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) {
|
||||
|
||||
if (ix >= 0) {
|
||||
if (xShapeInfo[ix + 1] == zShapeInfo[iz + 1]) {
|
||||
xCoords[ix--] = zCoords[iz];
|
||||
}
|
||||
else {
|
||||
xCoords[ix--] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
if (iy >= 0) {
|
||||
if (yShapeInfo[iy + 1] == zShapeInfo[iz + 1]) {
|
||||
yCoords[iy--] = zCoords[iz];
|
||||
}
|
||||
else {
|
||||
yCoords[iy--] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const auto xOffset = shape::getOffset(xShapeInfo, xCoords.data());
|
||||
const auto yOffset = shape::getOffset(yShapeInfo, yCoords.data());
|
||||
const auto zOffset = shape::getOffset(zShapeInfo, zCoords.data());
|
||||
|
||||
z[zOffset] = OpType::op(x[xOffset], y[yOffset], nullptr);
|
||||
}
|
||||
};
|
||||
|
||||
samediff::Threads::parallel_for(func, 0, zLen);
|
||||
}
|
||||
|
||||
template <typename X, typename Y>
|
||||
void TrueBroadcastBoolHelper<X, Y>::exec(const nd4j::broadcast::BoolOps opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
|
||||
DISPATCH_BY_OPNUM_TT(exec, PARAMS(xArr, yArr, zArr), BROADCAST_BOOL_OPS);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
template <typename X>
|
||||
template<typename OpType>
|
||||
void TrueBroadcastIntHelper<X>::exec(const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
|
||||
|
||||
const X* x = reinterpret_cast<X*>(xArr.getBuffer());
|
||||
const X* y = reinterpret_cast<X*>(yArr.getBuffer());
|
||||
X* z = reinterpret_cast<X*>(zArr.getBuffer());
|
||||
|
||||
const auto xShapeInfo = xArr.getShapeInfo();
|
||||
const auto yShapeInfo = yArr.getShapeInfo();
|
||||
const auto zShapeInfo = zArr.getShapeInfo();
|
||||
|
||||
const int xRank = xArr.rankOf();
|
||||
const int yRank = yArr.rankOf();
|
||||
const int zRank = zArr.rankOf();
|
||||
|
||||
const Nd4jLong zLen = zArr.lengthOf();
|
||||
|
||||
auto func = PRAGMA_THREADS_FOR{
|
||||
std::vector<Nd4jLong> xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf());
|
||||
|
||||
for (auto i = start; i < stop; ++i) {
|
||||
|
||||
shape::index2coords(i, zShapeInfo, zCoords.data());
|
||||
|
||||
for (int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) {
|
||||
|
||||
if (ix >= 0) {
|
||||
if (xShapeInfo[ix + 1] == zShapeInfo[iz + 1]) {
|
||||
xCoords[ix--] = zCoords[iz];
|
||||
}
|
||||
else {
|
||||
xCoords[ix--] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
if (iy >= 0) {
|
||||
if (yShapeInfo[iy + 1] == zShapeInfo[iz + 1]) {
|
||||
yCoords[iy--] = zCoords[iz];
|
||||
}
|
||||
else {
|
||||
yCoords[iy--] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const auto xOffset = shape::getOffset(xShapeInfo, xCoords.data());
|
||||
const auto yOffset = shape::getOffset(yShapeInfo, yCoords.data());
|
||||
const auto zOffset = shape::getOffset(zShapeInfo, zCoords.data());
|
||||
|
||||
z[zOffset] = OpType::op(x[xOffset], y[yOffset]);
|
||||
}
|
||||
};
|
||||
|
||||
samediff::Threads::parallel_for(func, 0, zLen);
|
||||
}
|
||||
|
||||
template <typename X>
|
||||
void TrueBroadcastIntHelper<X>::exec(const nd4j::broadcast::IntOps opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
|
||||
DISPATCH_BY_OPNUM_T(exec, PARAMS(xArr, yArr, zArr), BROADCAST_INT_OPS);
|
||||
}
|
||||
|
||||
/*
|
||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_0);
|
||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_1);
|
||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_2);
|
||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_3);
|
||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_4);
|
||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_5);
|
||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_6);
|
||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_7);
|
||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_8);
|
||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_9);
|
||||
|
||||
BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastBoolHelper, , LIBND4J_TYPES, BOOL_TYPES);
|
||||
|
||||
BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastIntHelper, , INTEGER_TYPES);
|
||||
*/
|
||||
}
|
||||
|
||||
const Nd4jLong zLen = zArr.lengthOf();
|
||||
auto func = PRAGMA_THREADS_FOR {
|
||||
std::vector<Nd4jLong> xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf());
|
||||
|
||||
for (auto i = start; i < stop; ++i) {
|
||||
|
||||
shape::index2coords(i, zShapeInfo, zCoords.data());
|
||||
|
||||
for (int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) {
|
||||
|
||||
if (ix >= 0) {
|
||||
if (xShapeInfo[ix + 1] == zShapeInfo[iz + 1]) {
|
||||
xCoords[ix--] = zCoords[iz];
|
||||
} else {
|
||||
xCoords[ix--] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
if (iy >= 0) {
|
||||
if (yShapeInfo[iy + 1] == zShapeInfo[iz + 1]) {
|
||||
yCoords[iy--] = zCoords[iz];
|
||||
} else {
|
||||
yCoords[iy--] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const auto xOffset = shape::getOffset(xShapeInfo, xCoords.data());
|
||||
const auto yOffset = shape::getOffset(yShapeInfo, yCoords.data());
|
||||
const auto zOffset = shape::getOffset(zShapeInfo, zCoords.data());
|
||||
|
||||
z[zOffset] = OpType::op(x[xOffset], y[yOffset]);
|
||||
}
|
||||
};
|
||||
|
||||
samediff::Threads::parallel_for(func, 0, zLen);
|
||||
}
|
||||
|
||||
template <typename X, typename Y, typename Z>
|
||||
void TrueBroadcastHelper<X, Y, Z>::exec(const nd4j::broadcast::Ops opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
|
||||
DISPATCH_BY_OPNUM_TTT(exec, PARAMS(xArr, yArr, zArr), BROADCAST_OPS);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
template <typename X, typename Z>
|
||||
template<typename OpType>
|
||||
void TrueBroadcastBoolHelper<X, Z>::exec(const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
|
||||
|
||||
const X* x = reinterpret_cast<X*>(xArr.getBuffer());
|
||||
const X* y = reinterpret_cast<X*>(yArr.getBuffer());
|
||||
Z* z = reinterpret_cast<Z*>(zArr.getBuffer());
|
||||
|
||||
const auto xShapeInfo = xArr.getShapeInfo();
|
||||
const auto yShapeInfo = yArr.getShapeInfo();
|
||||
const auto zShapeInfo = zArr.getShapeInfo();
|
||||
|
||||
const int xRank = xArr.rankOf();
|
||||
const int yRank = yArr.rankOf();
|
||||
const int zRank = zArr.rankOf();
|
||||
|
||||
const Nd4jLong zLen = zArr.lengthOf();
|
||||
|
||||
auto func = PRAGMA_THREADS_FOR {
|
||||
std::vector<Nd4jLong> xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf());
|
||||
|
||||
for (auto i = start; i < stop; ++i) {
|
||||
|
||||
shape::index2coords(i, zShapeInfo, zCoords.data());
|
||||
|
||||
for (int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) {
|
||||
|
||||
if (ix >= 0) {
|
||||
if (xShapeInfo[ix + 1] == zShapeInfo[iz + 1]) {
|
||||
xCoords[ix--] = zCoords[iz];
|
||||
} else {
|
||||
xCoords[ix--] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
if (iy >= 0) {
|
||||
if (yShapeInfo[iy + 1] == zShapeInfo[iz + 1]) {
|
||||
yCoords[iy--] = zCoords[iz];
|
||||
} else {
|
||||
yCoords[iy--] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const auto xOffset = shape::getOffset(xShapeInfo, xCoords.data());
|
||||
const auto yOffset = shape::getOffset(yShapeInfo, yCoords.data());
|
||||
const auto zOffset = shape::getOffset(zShapeInfo, zCoords.data());
|
||||
|
||||
z[zOffset] = OpType::op(x[xOffset], y[yOffset], nullptr);
|
||||
}
|
||||
};
|
||||
|
||||
samediff::Threads::parallel_for(func, 0, zLen);
|
||||
}
|
||||
|
||||
template <typename X, typename Y>
|
||||
void TrueBroadcastBoolHelper<X, Y>::exec(const nd4j::broadcast::BoolOps opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
|
||||
DISPATCH_BY_OPNUM_TT(exec, PARAMS(xArr, yArr, zArr), BROADCAST_BOOL_OPS);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
template <typename X>
|
||||
template<typename OpType>
|
||||
void TrueBroadcastIntHelper<X>::exec(const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
|
||||
|
||||
const X* x = reinterpret_cast<X*>(xArr.getBuffer());
|
||||
const X* y = reinterpret_cast<X*>(yArr.getBuffer());
|
||||
X* z = reinterpret_cast<X*>(zArr.getBuffer());
|
||||
|
||||
const auto xShapeInfo = xArr.getShapeInfo();
|
||||
const auto yShapeInfo = yArr.getShapeInfo();
|
||||
const auto zShapeInfo = zArr.getShapeInfo();
|
||||
|
||||
const int xRank = xArr.rankOf();
|
||||
const int yRank = yArr.rankOf();
|
||||
const int zRank = zArr.rankOf();
|
||||
|
||||
const Nd4jLong zLen = zArr.lengthOf();
|
||||
|
||||
auto func = PRAGMA_THREADS_FOR {
|
||||
std::vector<Nd4jLong> xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf());
|
||||
|
||||
for (auto i = start; i < stop; ++i) {
|
||||
|
||||
shape::index2coords(i, zShapeInfo, zCoords.data());
|
||||
|
||||
for (int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) {
|
||||
|
||||
if (ix >= 0) {
|
||||
if (xShapeInfo[ix + 1] == zShapeInfo[iz + 1]) {
|
||||
xCoords[ix--] = zCoords[iz];
|
||||
} else {
|
||||
xCoords[ix--] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
if (iy >= 0) {
|
||||
if (yShapeInfo[iy + 1] == zShapeInfo[iz + 1]) {
|
||||
yCoords[iy--] = zCoords[iz];
|
||||
} else {
|
||||
yCoords[iy--] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const auto xOffset = shape::getOffset(xShapeInfo, xCoords.data());
|
||||
const auto yOffset = shape::getOffset(yShapeInfo, yCoords.data());
|
||||
const auto zOffset = shape::getOffset(zShapeInfo, zCoords.data());
|
||||
|
||||
z[zOffset] = OpType::op(x[xOffset], y[yOffset]);
|
||||
}
|
||||
};
|
||||
|
||||
samediff::Threads::parallel_for(func, 0, zLen);
|
||||
}
|
||||
|
||||
template <typename X>
|
||||
void TrueBroadcastIntHelper<X>::exec(const nd4j::broadcast::IntOps opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
|
||||
DISPATCH_BY_OPNUM_T(exec, PARAMS(xArr, yArr, zArr), BROADCAST_INT_OPS);
|
||||
}
|
||||
|
||||
/*
|
||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_0);
|
||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_1);
|
||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_2);
|
||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_3);
|
||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_4);
|
||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_5);
|
||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_6);
|
||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_7);
|
||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_8);
|
||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_9);
|
||||
|
||||
BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastBoolHelper, , LIBND4J_TYPES, BOOL_TYPES);
|
||||
|
||||
BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastIntHelper, , INTEGER_TYPES);
|
||||
*/
|
||||
}
|
||||
}
|
|
@ -25,6 +25,7 @@
|
|||
#include <LoopKind.h>
|
||||
#include <helpers/ConstantTadHelper.h>
|
||||
#include <execution/Threads.h>
|
||||
#include <helpers/ShapeUtils.h>
|
||||
|
||||
using namespace simdOps;
|
||||
|
||||
|
@ -75,6 +76,7 @@ namespace functions {
|
|||
Nd4jLong *xTadOffset,
|
||||
Nd4jLong *zTadShapeInfo,
|
||||
Nd4jLong *zTadOffset,
|
||||
nd4j::LoopKind::Kind loopKind,
|
||||
uint64_t start,
|
||||
uint64_t stop) {
|
||||
DISPATCH_BY_OPNUM_TTT(exec, PARAMS(x,
|
||||
|
@ -88,7 +90,7 @@ namespace functions {
|
|||
xTadShapeInfo,
|
||||
xTadOffset,
|
||||
zTadShapeInfo,
|
||||
zTadOffset, start, stop), BROADCAST_OPS);
|
||||
zTadOffset, loopKind, start, stop), BROADCAST_OPS);
|
||||
}
|
||||
|
||||
template <typename X, typename Y, typename Z>
|
||||
|
@ -105,6 +107,7 @@ namespace functions {
|
|||
Nd4jLong *xTadOffset,
|
||||
Nd4jLong *zTadShapeInfo,
|
||||
Nd4jLong *zTadOffset,
|
||||
nd4j::LoopKind::Kind loopKind,
|
||||
uint64_t start,
|
||||
uint64_t stop) {
|
||||
|
||||
|
@ -142,7 +145,14 @@ namespace functions {
|
|||
auto yEws = shape::elementWiseStride(yShapeInfo);
|
||||
auto zEws = shape::elementWiseStride(zTadShapeInfo);
|
||||
|
||||
const nd4j::LoopKind::Kind kindOfLoop = nd4j::LoopKind::deduceKindOfLoopXYZ(xTadShapeShapeInfo, yShapeInfo, zTadShapeInfo);
|
||||
|
||||
const nd4j::LoopKind::Kind kindOfLoop =
|
||||
(loopKind == nd4j::LoopKind::BROADCAST_SCALAR_X ||
|
||||
loopKind == nd4j::LoopKind::BROADCAST_SCALAR_Y ||
|
||||
loopKind == nd4j::LoopKind::BROADCAST_3D ||
|
||||
loopKind == nd4j::LoopKind::BROADCAST_4D ||
|
||||
loopKind == nd4j::LoopKind::BROADCAST_5D)
|
||||
? loopKind : nd4j::LoopKind::deduceKindOfLoopXYZ(xTadShapeShapeInfo, yShapeInfo, zTadShapeInfo);
|
||||
|
||||
if (kindOfLoop == nd4j::LoopKind::EWS1) {
|
||||
for (auto i = start; i < stop; i++) {
|
||||
|
@ -163,6 +173,131 @@ namespace functions {
|
|||
for (unsigned int f = 0; f < tadLength; f++)
|
||||
oZ[f * zEws] = OpType::op(oX[f * xEws], y[f * yEws]);
|
||||
}
|
||||
} else if(kindOfLoop == nd4j::LoopKind::BROADCAST_SCALAR_X){
|
||||
// this loop effectively turns broadcast into series of scalar ops
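                        // i.e. BROADCAST_SCALAR_X: each x TAD holds a single value x[i], which is applied
                        // against a contiguous row of y (length = y's last dimension) and written to the
                        // matching row of z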
|
||||
auto loopLength = yShapeInfo[shape::rank(yShapeInfo)];
|
||||
|
||||
for (auto i = start; i < stop; i++) {
|
||||
auto oY = y + (i * loopLength);
|
||||
auto oZ = z + (i * loopLength);
|
||||
|
||||
const auto oX = x[i];
|
||||
|
||||
PRAGMA_OMP_SIMD
|
||||
for (unsigned int f = 0; f < loopLength; f++)
|
||||
oZ[f] = OpType::op(oX, oY[f]);
|
||||
}
|
||||
} else if(kindOfLoop == nd4j::LoopKind::BROADCAST_SCALAR_Y){
|
||||
// this loop effectively turns broadcast into series of scalar ops
|
||||
auto loopLength = xShapeInfo[shape::rank(xShapeInfo)];
|
||||
|
||||
for (auto i = start; i < stop; i++) {
|
||||
auto oX = x + (i * loopLength);
|
||||
auto oZ = z + (i * loopLength);
|
||||
|
||||
const auto oY = y[i];
|
||||
|
||||
PRAGMA_OMP_SIMD
|
||||
for (unsigned int f = 0; f < loopLength; f++)
|
||||
oZ[f] = OpType::op(oX[f], oY);
|
||||
}
|
||||
}
|
||||
else if (kindOfLoop == nd4j::LoopKind::BROADCAST_3D) {
|
||||
|
||||
int xRank = shape::rank(xShapeInfo);
|
||||
int yRank = shape::rank(yShapeInfo);
|
||||
|
||||
auto xStrides = shape::stride(xShapeInfo);
|
||||
auto zStrides = shape::stride(zShapeInfo);
|
||||
|
||||
Nd4jLong yStrides[3] = { 0,0,0 };
|
||||
nd4j::ShapeUtils::copyCertainStridesFromShapeInfo(yShapeInfo, xRank, dimensionLength, dimension, yStrides);
|
||||
|
||||
uint32_t nSize1 = shape::sizeAt(zShapeInfo, 1);
|
||||
uint32_t nSize2 = shape::sizeAt(zShapeInfo, 2);
|
||||
|
||||
for (uint32_t index0 = start; index0 < stop; index0++) {
|
||||
|
||||
PRAGMA_OMP_SIMD
|
||||
for (uint32_t index1 = 0; index1 < nSize1; index1++) {
|
||||
for (uint32_t index2 = 0; index2 < nSize2; index2++) {
|
||||
auto rX = x + (xStrides[0] * index0 + xStrides[1] * index1 + xStrides[2] * index2);
|
||||
auto rY = y + (yStrides[0] * index0 + yStrides[1] * index1 + yStrides[2] * index2);
|
||||
auto rZ = z + (zStrides[0] * index0 + zStrides[1] * index1 + zStrides[2] * index2);
|
||||
*rZ = OpType::op(*rX, *rY);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
else if (kindOfLoop == nd4j::LoopKind::BROADCAST_4D) {
|
||||
|
||||
int xRank = shape::rank(xShapeInfo);
|
||||
int yRank = shape::rank(yShapeInfo);
|
||||
|
||||
auto xStrides = shape::stride(xShapeInfo);
|
||||
auto zStrides = shape::stride(zShapeInfo);
|
||||
|
||||
Nd4jLong yStrides[4] = { 0,0,0,0 };
|
||||
nd4j::ShapeUtils::copyCertainStridesFromShapeInfo(yShapeInfo, xRank, dimensionLength, dimension, yStrides);
|
||||
|
||||
uint32_t nSize1 = shape::sizeAt(zShapeInfo, 1);
|
||||
uint32_t nSize2 = shape::sizeAt(zShapeInfo, 2);
|
||||
uint32_t nSize3 = shape::sizeAt(zShapeInfo, 3);
|
||||
|
||||
for (uint32_t i = start; i < stop; i++) {
|
||||
|
||||
uint32_t index0 = i / nSize1;
|
||||
uint32_t index1 = i % nSize1;
|
||||
|
||||
PRAGMA_OMP_SIMD
|
||||
for (uint32_t index2 = 0; index2 < nSize2; index2++) {
|
||||
for (uint32_t index3 = 0; index3 < nSize3; index3++) {
|
||||
auto rX = x + (xStrides[0] * index0 + xStrides[1] * index1 + xStrides[2] * index2 + xStrides[3] * index3);
|
||||
auto rY = y + (yStrides[0] * index0 + yStrides[1] * index1 + yStrides[2] * index2 + yStrides[3] * index3);
|
||||
auto rZ = z + (zStrides[0] * index0 + zStrides[1] * index1 + zStrides[2] * index2 + zStrides[3] * index3);
|
||||
*rZ = OpType::op(*rX, *rY);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
else if (kindOfLoop == nd4j::LoopKind::BROADCAST_5D) {
|
||||
|
||||
int xRank = shape::rank(xShapeInfo);
|
||||
int yRank = shape::rank(yShapeInfo);
|
||||
|
||||
auto xStrides = shape::stride(xShapeInfo);
|
||||
auto zStrides = shape::stride(zShapeInfo);
|
||||
|
||||
Nd4jLong yStrides[5] = { 0,0,0,0,0 };
|
||||
nd4j::ShapeUtils::copyCertainStridesFromShapeInfo(yShapeInfo, xRank, dimensionLength, dimension, yStrides);
|
||||
|
||||
uint32_t nSize1 = shape::sizeAt(zShapeInfo, 1);
|
||||
uint32_t nSize2 = shape::sizeAt(zShapeInfo, 2);
|
||||
uint32_t nSize3 = shape::sizeAt(zShapeInfo, 3);
|
||||
uint32_t nSize4 = shape::sizeAt(zShapeInfo, 4);
|
||||
|
||||
for (uint32_t i = start; i < stop; i++) {
|
||||
|
||||
uint32_t index0 = i / nSize1;
|
||||
uint32_t index1 = i % nSize1;
|
||||
|
||||
PRAGMA_OMP_SIMD
|
||||
for (uint32_t index2 = 0; index2 < nSize2; index2++) {
|
||||
for (uint32_t index3 = 0; index3 < nSize3; index3++) {
|
||||
for (uint32_t index4 = 0; index4 < nSize4; index4++) {
|
||||
auto rX = x + (xStrides[0] * index0 + xStrides[1] * index1 + xStrides[2] * index2 + xStrides[3] * index3 + xStrides[4] * index4);
|
||||
auto rY = y + (yStrides[0] * index0 + yStrides[1] * index1 + yStrides[2] * index2 + yStrides[3] * index3 + yStrides[4] * index4);
|
||||
auto rZ = z + (zStrides[0] * index0 + zStrides[1] * index1 + zStrides[2] * index2 + zStrides[3] * index3 + zStrides[4] * index4);
|
||||
|
||||
*rZ = OpType::op(*rX, *rY);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
else if(shape::haveSameShapeAndStrides(xTadShapeShapeInfo, yShapeInfo) && shape::haveSameShapeAndStrides(xTadShapeShapeInfo, zTadShapeInfo)) {
|
||||
uint tadShapeShapeInfoCast[MAX_RANK];
|
||||
|
|
|
@ -73,7 +73,7 @@ Nd4jLong IndexReduce<X, Y>::execScalar(void *vx, Nd4jLong *xShapeInfo, void *vex
|
|||
auto func = PRAGMA_THREADS_FOR {
|
||||
intermediatery[thread_id] = OpType::startingIndexValue(x);
|
||||
|
||||
for (auto i = start; i < stop; i += increment) {
|
||||
for (auto i = start; i < stop; i++) {
|
||||
IndexValue<X> curr(x[i], i);
|
||||
intermediatery[thread_id] = OpType::update(intermediatery[thread_id], curr, extraParams);
|
||||
}
|
||||
|
@ -88,7 +88,7 @@ Nd4jLong IndexReduce<X, Y>::execScalar(void *vx, Nd4jLong *xShapeInfo, void *vex
|
|||
auto func = PRAGMA_THREADS_FOR {
|
||||
intermediatery[thread_id] = OpType::startingIndexValue(x);
|
||||
|
||||
for (auto i = start; i < stop; i += increment) {
|
||||
for (auto i = start; i < stop; i++) {
|
||||
auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
|
||||
IndexValue<X> curr(x[offset], i);
|
||||
intermediatery[thread_id] = OpType::update(intermediatery[thread_id], curr, extraParams);
|
||||
|
|
|
@@ -75,7 +75,7 @@ namespace functions {
auto func = PRAGMA_THREADS_FOR {
PRAGMA_OMP_SIMD
for (auto i = start; i < stop; i += increment) {
for (auto i = start; i < stop; i++) {
auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
z[offset] = OpClass::op(x[offset], y[offset], i, length, rng, extraArguments);
}

@@ -93,7 +93,7 @@ namespace functions {
auto func = PRAGMA_THREADS_FOR {
PRAGMA_OMP_SIMD
for (uint64_t i = start; i < stop; i += increment) {
for (uint64_t i = start; i < stop; i++) {
auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
auto zOffset = shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ);
z[zOffset] = OpClass::op(x[offset], y[offset], i, length, rng, extraArguments);

@@ -111,7 +111,7 @@ namespace functions {
auto func = PRAGMA_THREADS_FOR {
PRAGMA_OMP_SIMD
for (uint64_t i = start; i < stop; i += increment) {
for (uint64_t i = start; i < stop; i++) {
auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
auto yOffset = shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY);
z[offset] = OpClass::op(x[offset], y[yOffset], i, length, rng, extraArguments);

@@ -129,7 +129,7 @@ namespace functions {
auto func = PRAGMA_THREADS_FOR {
PRAGMA_OMP_SIMD
for (uint64_t i = start; i < stop; i += increment) {
for (uint64_t i = start; i < stop; i++) {
auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
auto offset = shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY);
z[offset] = OpClass::op(x[xOffset], y[offset], i, length, rng, extraArguments);

@@ -149,7 +149,7 @@ namespace functions {
auto func = PRAGMA_THREADS_FOR {
PRAGMA_OMP_SIMD
for (uint64_t i = start; i < stop; i += increment) {
for (uint64_t i = start; i < stop; i++) {
auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
auto yOffset = shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY);
auto zOffset = shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ);

@@ -197,7 +197,7 @@ namespace functions {
else{
auto func = PRAGMA_THREADS_FOR {
PRAGMA_OMP_SIMD
for (uint64_t i = start; i < stop; i += increment) {
for (uint64_t i = start; i < stop; i++) {
auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
z[offset] = OpClass::op(x[offset], i, length, rng, extraArguments);
}

@@ -213,7 +213,7 @@ namespace functions {
auto func = PRAGMA_THREADS_FOR {
PRAGMA_OMP_SIMD
for (uint64_t i = start; i < stop; i += increment) {
for (uint64_t i = start; i < stop; i++) {
auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
auto zOffset = shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ);
z[zOffset] = OpClass::op(x[xOffset], i, length, rng, extraArguments);

@@ -255,7 +255,7 @@ namespace functions {
auto func = PRAGMA_THREADS_FOR {
PRAGMA_OMP_SIMD
for (uint64_t i = start; i < stop; i += increment) {
for (uint64_t i = start; i < stop; i++) {
auto offset = shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ);
z[offset] = OpClass::op(i, length, rng, extraArguments);
}
@@ -88,7 +88,7 @@ void Reduce3<X,Z>::execScalar(void *vx, Nd4jLong *xShapeInfo,
if (kindOfLoop == nd4j::LoopKind::EWS1) {
auto func = PRAGMA_THREADS_FOR {
for (auto i = start; i < stop; i += increment) {
for (auto i = start; i < stop; i++) {
intermediate[thread_id] = OpType::update(intermediate[thread_id], OpType::op(x[i], y[i], extraParamsLocal + 3 * thread_id), extraParamsLocal + 3 * thread_id);
}
};

@@ -98,7 +98,7 @@ void Reduce3<X,Z>::execScalar(void *vx, Nd4jLong *xShapeInfo,
} else if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo)) {
auto func = PRAGMA_THREADS_FOR {
for (auto i = start; i < stop; i += increment) {
for (auto i = start; i < stop; i++) {
auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
intermediate[thread_id] = OpType::update(intermediate[thread_id], OpType::op(x[offset], y[offset], extraParamsLocal + 3 * thread_id), extraParamsLocal + 3 * thread_id);
}

@@ -110,7 +110,7 @@ void Reduce3<X,Z>::execScalar(void *vx, Nd4jLong *xShapeInfo,
const bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);
auto func = PRAGMA_THREADS_FOR {
for (auto i = start; i < stop; i += increment) {
for (auto i = start; i < stop; i++) {
auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
auto yOffset = shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY);
intermediate[thread_id] = OpType::update(intermediate[thread_id], OpType::op(x[xOffset], y[yOffset], extraParamsLocal + 3 * thread_id), extraParamsLocal + 3 * thread_id);
@@ -158,7 +158,7 @@ namespace functions {
const bool canCast = tadEWS == 1 && tadOrder == 'c' ? false : nd4j::DataTypeUtils::castShapeInfo<uint>(tadShapeShapeInfo, tadShapeShapeInfoCast);
auto func = PRAGMA_THREADS_FOR {
for (auto r = start; r < stop; r += increment) {
for (auto r = start; r < stop; r++) {
auto tadOffsetForBlock = tadPack.primaryOffsets()[r];
auto tx = x + tadOffsetForBlock;
@@ -84,7 +84,7 @@ namespace functions {
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
int totalThreads = gridDim.x * blockDim.x;
if(xEws > 0 && zEws > 0 && xOrder == zOrder) {
if(xEws > 0 && zEws > 0 && xOrder == zOrder && xOrder == 'c') {
for (int i = tid; i < length; i += totalThreads)
z[i * zEws] = OpType::op(x[i * xEws], params);

@@ -89,7 +89,7 @@ namespace functions {
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
int totalThreads = gridDim.x * blockDim.x;
if(xEws > 0 && zEws > 0 && xOrder == zOrder) {
if(xEws > 0 && zEws > 0 && xOrder == zOrder && xOrder == 'c') {
for (int i = tid; i < length; i += totalThreads)
z[i * zEws] = OpType::op(x[i * xEws], params);

@@ -97,7 +97,7 @@ namespace functions {
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
int totalThreads = gridDim.x * blockDim.x;
if(xEws > 0 && zEws > 0 && xOrder == zOrder) {
if(xEws > 0 && zEws > 0 && xOrder == zOrder && xOrder == 'c') {
for (Nd4jLong i = tid; i < length; i += totalThreads)
z[i * zEws] = OpType::op(x[i * xEws], params);

@@ -87,7 +87,7 @@ namespace functions {
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
int totalThreads = gridDim.x * blockDim.x;
if(xEws > 0 && zEws > 0 && xOrder == zOrder) {
if(xEws > 0 && zEws > 0 && xOrder == zOrder && xOrder == 'c') {
for (int i = tid; i < length; i += totalThreads)
z[i * zEws] = OpType::op(x[i * xEws], params);

@@ -89,7 +89,7 @@ namespace functions {
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
int totalThreads = gridDim.x * blockDim.x;
if(xEws > 0 && zEws > 0 && xOrder == zOrder) {
if(xEws > 0 && zEws > 0 && xOrder == zOrder && xOrder == 'c') {
for (int i = tid; i < length; i += totalThreads)
z[i * zEws] = OpType::op(x[i * xEws], params);
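The five hunks above all tighten the same element-wise fast path so that it also requires 'c' ordering, not just positive strides. For illustration only, a small host-side C++ sketch of what that fast path relies on, with hypothetical names rather than the actual CUDA kernels: when both buffers expose a positive element-wise stride, the element at logical index i lives at i * ews and no shape/offset math is needed.

#include <vector>
#include <cstdio>

// stand-in for the strided fast path; OpType::op(...) is replaced by a scale
void applyOpEws(const float* x, float* z, long length, long xEws, long zEws) {
    for (long i = 0; i < length; i++)
        z[i * zEws] = x[i * xEws] * 2.0f;
}

int main() {
    std::vector<float> x = {1, 2, 3, 4}, z(4);
    applyOpEws(x.data(), z.data(), 4, 1, 1);
    std::printf("%f %f %f %f\n", z[0], z[1], z[2], z[3]);
    return 0;
}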
@@ -81,7 +81,7 @@ namespace nd4j {
// now we actually apply quantization
auto func = PRAGMA_THREADS_FOR {
for (auto e = start; e < stop; e += increment) {
for (auto e = start; e < stop; e++) {
rz[e] = static_cast<char>(nd4j::math::nd4j_round<float, char>( 1.0f * static_cast<float>(x[e]) / nd4j::math::nd4j_max<float>(amax, amin) * max_byte));
}
};

@@ -177,7 +177,7 @@ PRAGMA_OMP_ATOMIC_ARGS(write)
int flimit = limit + 4;
auto func = PRAGMA_THREADS_FOR {
for (auto e = start; e < stop; e += increment) {
for (auto e = start; e < stop; e++) {
int el = x[e];
int ael = nd4j::math::nd4j_abs<int>(el) - 1;
z[ael] += el > 0 ? static_cast<T>(threshold) : static_cast<T>(-threshold);

@@ -202,7 +202,7 @@ PRAGMA_OMP_ATOMIC_ARGS(write)
auto z = reinterpret_cast<T *>(dz);
auto func = PRAGMA_THREADS_FOR {
for (auto i = start; i < stop; i += increment) {
for (auto i = start; i < stop; i++) {
z[i] = static_cast<T>(static_cast<float>(x[i]));
}
};
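For orientation, a self-contained sketch of the rounding quantization used in the first hunk above, in plain C++ with hypothetical names; the byte-range constant is an assumption, since max_byte is defined outside the shown context.

#include <cmath>
#include <algorithm>
#include <cstdio>
#include <vector>

// assumed signed-byte range; the real max_byte comes from the surrounding code
static const float kMaxByte = 127.0f;

std::vector<signed char> quantize(const std::vector<float>& x, float amax, float amin) {
    const float scale = std::max(amax, amin);   // mirrors nd4j_max<float>(amax, amin)
    std::vector<signed char> out(x.size());
    for (size_t e = 0; e < x.size(); e++)
        out[e] = static_cast<signed char>(std::lround(x[e] / scale * kMaxByte));
    return out;
}

int main() {
    auto q = quantize({0.5f, -1.0f, 0.25f}, 1.0f, 1.0f);
    std::printf("%d %d %d\n", q[0], q[1], q[2]);
    return 0;
}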
@@ -147,6 +147,9 @@ namespace nd4j {
// returns TRUE if this op allows in-place execution
bool allowsInplace();
// this method allows you to enable/disable inplace call for a given op
void allowInplace(bool reallyAllow);
// this method returns opNum (applicable for legacy XYZ ops only)
int getOpNum();
@@ -27,12 +27,10 @@ namespace nd4j {
namespace ops {
OP_IMPL(identity, 1, 1, true) {
auto first = INPUT_VARIABLE(0);
auto z = this->getZ(block);
auto z = OUTPUT_VARIABLE(0);
// just for lulz
first->applyTransform(nd4j::transform::Identity, *z);
STORE_RESULT(*z);
if (!block.isInplace())
first->applyTransform(nd4j::transform::Identity, *z);
return Status::OK();
}

@@ -60,8 +58,8 @@ namespace nd4j {
DECLARE_TYPES(identity_bp) {
getOpDescriptor()
->setAllowedInputTypes(0, DataType::ANY)
->setAllowedInputTypes(1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF})
->setAllowedOutputTypes(0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF});
->setAllowedInputTypes(1, {ALL_FLOATS})
->setAllowedOutputTypes(0, {ALL_FLOATS});
}
}
}
@@ -20,7 +20,7 @@
// @author Yurii Shyrma (iuriish@yahoo.com), fully rewritten
//
#include <op_boilerplate.h>
#include <op_boilerplate.h>
#if NOT_EXCLUDED(OP_matmul)
#include <ops/declarable/CustomOperations.h>
@@ -29,142 +29,128 @@
|
|||
namespace nd4j {
|
||||
namespace ops {
|
||||
|
||||
CUSTOM_OP_IMPL(matmul, 2, 1, false, 0, -2) {
|
||||
auto x = INPUT_VARIABLE(0);
|
||||
auto y = INPUT_VARIABLE(1);
|
||||
auto z = OUTPUT_VARIABLE(0);
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
CUSTOM_OP_IMPL(matmul, 2, 1, false, 0, -2) {
|
||||
|
||||
const int iSize = (int) block.getIArguments()->size();
|
||||
int transX = iSize > 0 ? INT_ARG(0) : 0;
|
||||
int transY = iSize > 1 ? INT_ARG(1) : 0;
|
||||
const int transZ = iSize > 2 ? INT_ARG(2) : 0;
|
||||
auto x = INPUT_VARIABLE(0);
|
||||
auto y = INPUT_VARIABLE(1);
|
||||
auto z = OUTPUT_VARIABLE(0);
|
||||
|
||||
const int xRank = x->rankOf();
|
||||
const int yRank = y->rankOf();
|
||||
const int zRank = z->rankOf();
|
||||
const int iSize = (int) block.getIArguments()->size();
|
||||
int transX = iSize > 0 ? INT_ARG(0) : 0;
|
||||
int transY = iSize > 1 ? INT_ARG(1) : 0;
|
||||
const int transZ = iSize > 2 ? INT_ARG(2) : 0;
|
||||
|
||||
if (transZ) {
|
||||
x = INPUT_VARIABLE(1);
|
||||
y = INPUT_VARIABLE(0);
|
||||
bool temp = transX;
|
||||
transX = !transY;
|
||||
transY = !temp;
|
||||
}
|
||||
const int xRank = x->rankOf();
|
||||
const int yRank = y->rankOf();
|
||||
const int zRank = z->rankOf();
|
||||
|
||||
const int xLastDim = transX ? -2 : -1;
|
||||
const int yLastDim = transY ? -2 : -1;
|
||||
const int xLastButOneDim = transX ? -1 : -2;
|
||||
const int yLastButOneDim = transY ? -1 : -2;
|
||||
if (transZ) {
|
||||
x = INPUT_VARIABLE(1);
|
||||
y = INPUT_VARIABLE(0);
|
||||
bool temp = transX;
|
||||
transX = !transY;
|
||||
transY = !temp;
|
||||
}
|
||||
|
||||
// ******* input validation ******* //
|
||||
REQUIRE_TRUE(xRank > 0 && yRank > 0, 0,
|
||||
"MATMUL OP: input arrays must have rank bigger than 0 (should not be scalars), but got instead: x rank = %i, y rank = %i !",
|
||||
xRank, yRank);
|
||||
const int xLastDim = transX ? -2 : -1;
|
||||
const int yLastDim = transY ? -2 : -1;
|
||||
const int xLastButOneDim = transX ? -1 : -2;
|
||||
const int yLastButOneDim = transY ? -1 : -2;
|
||||
|
||||
if (xRank == 1 && yRank == 1) { // dot case, output is scalar (or vector with length = 1)
|
||||
REQUIRE_TRUE(x->lengthOf() == y->lengthOf(), 0,
|
||||
"MATMUL OP: since input arrays are vectors they must have the same length, but got x length = %i, y length = %i !",
|
||||
x->lengthOf(), y->lengthOf());
|
||||
} else if (xRank == 1 && yRank == 2) { // vector x matrix, i.e. [4] x [4,5] = [5], output is vector
|
||||
REQUIRE_TRUE(x->lengthOf() == y->sizeAt(yLastButOneDim), 0,
|
||||
"MATMUL OP: input arrays have inconsistent shapes for vector-matrix product: x %s, y %s !",
|
||||
ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str());
|
||||
} else if (xRank == 2 && yRank == 1) { // matrix x vector , i.e. [4,5] x [5] = [4], output is vector
|
||||
REQUIRE_TRUE(x->sizeAt(xLastDim) == y->lengthOf(), 0,
|
||||
"MATMUL OP: input arrays have inconsistent shapes for matrix-vector product: x %s, y %s !",
|
||||
ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str());
|
||||
} else {
|
||||
REQUIRE_TRUE(xRank == yRank && yRank == zRank, 0,
|
||||
"MATMUL OP: input and output arrays must have the same rank, but got instead: x rank = %i, y rank = %i, z rank = %i !",
|
||||
xRank, yRank, zRank);
|
||||
REQUIRE_TRUE(x->sizeAt(xLastDim) == y->sizeAt(yLastButOneDim) &&
|
||||
x->sizeAt(xLastButOneDim) == z->sizeAt(-2) && y->sizeAt(yLastDim) == z->sizeAt(-1), 0,
|
||||
"MATMUL OP: input/output arrays have inconsistent shapes for matrix product: x %s, y %s, z %s !",
|
||||
ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str(),
|
||||
ShapeUtils::shapeAsString(z).c_str());
|
||||
// ******* input validation ******* //
|
||||
REQUIRE_TRUE(xRank > 0 && yRank > 0, 0, "MATMUL OP: input arrays must have rank bigger than 0 (should not be scalars), but got instead: x rank = %i, y rank = %i !", xRank, yRank);
|
||||
|
||||
if (xRank > 2) // outer dims must be the same
|
||||
for (int i = 0; i < xRank - 2; ++i)
|
||||
REQUIRE_TRUE(x->sizeAt(i) == y->sizeAt(i) && y->sizeAt(i) == z->sizeAt(i), 0,
|
||||
"MATMUL OP: input/output arrays have inconsistent shapes for matrix product: x %s, y %s, z %s !",
|
||||
ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str(),
|
||||
ShapeUtils::shapeAsString(z).c_str());
|
||||
}
|
||||
// ******* end of input validation ******* //
|
||||
if (xRank == 1 && yRank == 1) { // dot case, output is scalar (or vector with length = 1)
|
||||
REQUIRE_TRUE(x->lengthOf() == y->lengthOf(), 0, "MATMUL OP: since input arrays are vectors they must have the same length, but got x length = %i, y length = %i !", x->lengthOf(), y->lengthOf());
|
||||
} else if (xRank == 1 && yRank == 2) { // vector x matrix, i.e. [4] x [4,5] = [5], output is vector
|
||||
REQUIRE_TRUE(x->lengthOf() == y->sizeAt(yLastButOneDim), 0, "MATMUL OP: input arrays have inconsistent shapes for vector-matrix product: x %s, y %s !", ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str());
|
||||
} else if (xRank == 2 && yRank == 1) { // matrix x vector , i.e. [4,5] x [5] = [4], output is vector
|
||||
REQUIRE_TRUE(x->sizeAt(xLastDim) == y->lengthOf(), 0, "MATMUL OP: input arrays have inconsistent shapes for matrix-vector product: x %s, y %s !", ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str());
|
||||
} else {
|
||||
REQUIRE_TRUE(xRank == yRank && yRank == zRank, 0, "MATMUL OP: input and output arrays must have the same rank, but got instead: x rank = %i, y rank = %i, z rank = %i !", xRank, yRank, zRank);
|
||||
REQUIRE_TRUE(x->sizeAt(xLastDim) == y->sizeAt(yLastButOneDim) && x->sizeAt(xLastButOneDim) == z->sizeAt(-2) && y->sizeAt(yLastDim) == z->sizeAt(-1), 0, "MATMUL OP: input/output arrays have inconsistent shapes for matrix product: x %s, y %s, z %s !", ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str(), ShapeUtils::shapeAsString(z).c_str());
|
||||
|
||||
MmulHelper::matmul(x, y, z, transX, transY);
|
||||
if (xRank > 2) // outer dims must be the same
|
||||
for (int i = 0; i < xRank - 2; ++i)
|
||||
REQUIRE_TRUE(x->sizeAt(i) == y->sizeAt(i) && y->sizeAt(i) == z->sizeAt(i), 0, "MATMUL OP: input/output arrays have inconsistent shapes for matrix product: x %s, y %s, z %s !", ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str(), ShapeUtils::shapeAsString(z).c_str());
|
||||
}
|
||||
// ******* end of input validation ******* //
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
MmulHelper::matmul(x, y, z, transX, transY);
|
||||
|
||||
DECLARE_SYN(mMul, matmul);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
DECLARE_SYN(mmul, matmul);
|
||||
DECLARE_SYN(mMul, matmul);
|
||||
|
||||
DECLARE_SYN(gemm, matmul);
|
||||
DECLARE_SYN(mmul, matmul);
|
||||
|
||||
DECLARE_SYN(gemv, matmul);
|
||||
DECLARE_SYN(gemm, matmul);
|
||||
|
||||
DECLARE_SYN(dot, matmul);
|
||||
DECLARE_SYN(gemv, matmul);
|
||||
|
||||
DECLARE_SYN(dot, matmul);
|
||||
|
||||
DECLARE_SHAPE_FN(matmul) {
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
DECLARE_SHAPE_FN(matmul) {
|
||||
|
||||
auto xShapeInfo = inputShape->at(0);
|
||||
auto yShapeInfo = inputShape->at(1);
|
||||
auto xShapeInfo = inputShape->at(0);
|
||||
auto yShapeInfo = inputShape->at(1);
|
||||
|
||||
const int iSize = (int) block.getIArguments()->size();
|
||||
int transX = iSize > 0 ? INT_ARG(0) : 0;
|
||||
int transY = iSize > 1 ? INT_ARG(1) : 0;
|
||||
const int transZ = iSize > 2 ? INT_ARG(2) : 0;
|
||||
const int iSize = (int) block.getIArguments()->size();
|
||||
int transX = iSize > 0 ? INT_ARG(0) : 0;
|
||||
int transY = iSize > 1 ? INT_ARG(1) : 0;
|
||||
const int transZ = iSize > 2 ? INT_ARG(2) : 0;
|
||||
|
||||
REQUIRE_TRUE(xShapeInfo[0] > 0 && yShapeInfo[0] > 0, 0,
|
||||
"MATMUL OP: input arrays must have rank bigger than 0 (should not be scalars), but got instead: x rank = %i, y rank = %i !",
|
||||
xShapeInfo[0], yShapeInfo[0]);
|
||||
REQUIRE_TRUE(xShapeInfo[0] > 0 && yShapeInfo[0] > 0, 0,
|
||||
"MATMUL OP: input arrays must have rank bigger than 0 (should not be scalars), but got instead: x rank = %i, y rank = %i !",
|
||||
xShapeInfo[0], yShapeInfo[0]);
|
||||
|
||||
if (transZ) {
|
||||
xShapeInfo = inputShape->at(1);
|
||||
yShapeInfo = inputShape->at(0);
|
||||
bool temp = transX;
|
||||
transX = !transY;
|
||||
transY = !temp;
|
||||
}
|
||||
if (transZ) {
|
||||
xShapeInfo = inputShape->at(1);
|
||||
yShapeInfo = inputShape->at(0);
|
||||
bool temp = transX;
|
||||
transX = !transY;
|
||||
transY = !temp;
|
||||
}
|
||||
|
||||
auto zShapeOnly = ShapeUtils::evalShapeForMatmul(xShapeInfo, yShapeInfo, transX, transY);
|
||||
auto zShapeOnly = ShapeUtils::evalShapeForMatmul(xShapeInfo, yShapeInfo, transX, transY);
|
||||
|
||||
auto dtypeX = ArrayOptions::dataType(xShapeInfo);
|
||||
auto dtypeY = ArrayOptions::dataType(yShapeInfo);
|
||||
auto dtypeX = ArrayOptions::dataType(xShapeInfo);
|
||||
auto dtypeY = ArrayOptions::dataType(yShapeInfo);
|
||||
|
||||
auto xOrder = shape::order(xShapeInfo);
|
||||
auto yOrder = shape::order(yShapeInfo);
|
||||
auto zOrder = xOrder == 'c' && yOrder == 'c' ? 'c' : 'f';
|
||||
auto xOrder = shape::order(xShapeInfo);
|
||||
auto yOrder = shape::order(yShapeInfo);
|
||||
auto zOrder = xOrder == 'c' && yOrder == 'c' ? 'c' : 'f';
|
||||
|
||||
// we just pick the higher data type out of X and Y
|
||||
auto dtypeZ = dtypeX > dtypeY ? dtypeX : dtypeY;
|
||||
// we just pick the higher data type out of X and Y
|
||||
auto dtypeZ = dtypeX > dtypeY ? dtypeX : dtypeY;
|
||||
|
||||
auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(dtypeZ, zOrder, zShapeOnly);
|
||||
return SHAPELIST(newShape);
|
||||
}
|
||||
auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(dtypeZ, zOrder, zShapeOnly);
|
||||
return SHAPELIST(newShape);
|
||||
}
|
||||
|
||||
DECLARE_TYPES(matmul) {
|
||||
getOpDescriptor()
|
||||
->setAllowedInputTypes(0, {ALL_FLOATS})
|
||||
->setAllowedInputTypes(1, {ALL_FLOATS})
|
||||
->setAllowedOutputTypes(0, {ALL_FLOATS});
|
||||
}
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
DECLARE_TYPES(matmul) {
|
||||
getOpDescriptor()
|
||||
->setAllowedInputTypes(0, {ALL_FLOATS})
|
||||
->setAllowedInputTypes(1, {ALL_FLOATS})
|
||||
->setAllowedOutputTypes(0, {ALL_FLOATS});
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
CUSTOM_OP_IMPL(matmul_bp, 3, 2, false, 0, -2) {
|
||||
auto x = INPUT_VARIABLE(0);
|
||||
auto y = INPUT_VARIABLE(1);
|
||||
auto eps = INPUT_VARIABLE(2);
|
||||
auto dldx = OUTPUT_VARIABLE(0);
|
||||
auto dldy = OUTPUT_VARIABLE(1);
|
||||
|
||||
CUSTOM_OP_IMPL(matmul_bp, 3, 2, false, 0, -2) {
|
||||
auto x = INPUT_VARIABLE(0);
|
||||
auto y = INPUT_VARIABLE(1);
|
||||
auto eps = INPUT_VARIABLE(2);
|
||||
auto dldx = OUTPUT_VARIABLE(0);
|
||||
auto dldy = OUTPUT_VARIABLE(1);
|
||||
|
||||
const int iSize = (int) block.getIArguments()->size();
|
||||
int transX = iSize > 0 ? INT_ARG(0) : 0;
|
||||
int transY = iSize > 1 ? INT_ARG(1) : 0;
|
||||
const int transZ = iSize > 2 ? INT_ARG(2) : 0;
|
||||
const int iSize = (int) block.getIArguments()->size();
|
||||
int transX = iSize > 0 ? INT_ARG(0) : 0;
|
||||
int transY = iSize > 1 ? INT_ARG(1) : 0;
|
||||
const int transZ = iSize > 2 ? INT_ARG(2) : 0;
|
||||
|
||||
/*
|
||||
In: x=[a,b], y=[b,c]
|
||||
|
@@ -177,34 +163,35 @@ F F T [a,b] [b,c] [c,a] [c,a]
|
|||
*/
|
||||
|
||||
|
||||
nd4j::ops::matmul op;
|
||||
op.execute({eps, y}, {dldx}, {}, {transZ, !transY, transX}, {});
|
||||
op.execute({x, eps}, {dldy}, {}, {!transX, transZ, transY}, {});
|
||||
nd4j::ops::matmul op;
|
||||
op.execute({eps, y}, {dldx}, {}, {transZ, !transY, transX}, {});
|
||||
op.execute({x, eps}, {dldy}, {}, {!transX, transZ, transY}, {});
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
DECLARE_SHAPE_FN(matmul_bp) {
|
||||
Nd4jLong *xShapeInfo;
|
||||
Nd4jLong *yShapeInfo;
|
||||
|
||||
DECLARE_SHAPE_FN(matmul_bp) {
|
||||
Nd4jLong *xShapeInfo;
|
||||
Nd4jLong *yShapeInfo;
|
||||
COPY_SHAPE(inputShape->at(0), xShapeInfo);
|
||||
COPY_SHAPE(inputShape->at(1), yShapeInfo);
|
||||
|
||||
COPY_SHAPE(inputShape->at(0), xShapeInfo);
|
||||
COPY_SHAPE(inputShape->at(1), yShapeInfo);
|
||||
return SHAPELIST(CONSTANT(xShapeInfo), CONSTANT(yShapeInfo));
|
||||
}
|
||||
|
||||
return SHAPELIST(CONSTANT(xShapeInfo), CONSTANT(yShapeInfo));
|
||||
}
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
DECLARE_TYPES(matmul_bp) {
|
||||
getOpDescriptor()
|
||||
->setAllowedInputTypes(0, {ALL_FLOATS})
|
||||
->setAllowedInputTypes(1, {ALL_FLOATS})
|
||||
->setAllowedInputTypes(2, {ALL_FLOATS})
|
||||
->setAllowedOutputTypes(0, {ALL_FLOATS})
|
||||
->setAllowedOutputTypes(1, {ALL_FLOATS});
|
||||
}
|
||||
|
||||
DECLARE_TYPES(matmul_bp) {
|
||||
getOpDescriptor()
|
||||
->setAllowedInputTypes(0, {ALL_FLOATS})
|
||||
->setAllowedInputTypes(1, {ALL_FLOATS})
|
||||
->setAllowedInputTypes(2, {ALL_FLOATS})
|
||||
->setAllowedOutputTypes(0, {ALL_FLOATS})
|
||||
->setAllowedOutputTypes(1, {ALL_FLOATS});
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@@ -21,70 +21,174 @@
|
|||
#include <op_boilerplate.h>
|
||||
#if NOT_EXCLUDED(OP_tensormmul)
|
||||
|
||||
#include <numeric>
|
||||
#include <helpers/ShapeUtils.h>
|
||||
#include <ops/declarable/CustomOperations.h>
|
||||
#include <MmulHelper.h>
|
||||
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
CUSTOM_OP_IMPL(tensormmul, 2, 1, false, 0, -1) {
|
||||
auto a = INPUT_VARIABLE(0);
|
||||
auto b = INPUT_VARIABLE(1);
|
||||
namespace ops {
|
||||
|
||||
auto c = OUTPUT_VARIABLE(0); //
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
CUSTOM_OP_IMPL(tensormmul, 2, 1, false, 0, -1) {
|
||||
|
||||
REQUIRE_TRUE(a->dataType() == b->dataType(), 0, "tensormmul: A, B and C data types must be the same");
|
||||
auto a = INPUT_VARIABLE(0);
|
||||
auto b = INPUT_VARIABLE(1);
|
||||
|
||||
// building axes
|
||||
int axe0_size = INT_ARG(0);
|
||||
int axe1_size = INT_ARG(axe0_size+1);
|
||||
std::vector<int> axes_0(axe0_size), axes_1(axe1_size);
|
||||
for (int e = 0; e < axe0_size; e++)
|
||||
axes_0[e] = (int) INT_ARG(e+1);
|
||||
auto c = OUTPUT_VARIABLE(0);
|
||||
|
||||
for (int e = 0; e < axe1_size; e++)
|
||||
axes_1[e] = (int) INT_ARG(e + axe0_size + 2);
|
||||
REQUIRE_TRUE(a->dataType() == b->dataType(), 0, "tensormmul: A, B and C data types must be the same");
|
||||
|
||||
nd4j_verbose("axe0: %i; axe1: %i;\n", axes_0.size(), axes_1.size());
|
||||
// building axes
|
||||
int axe0_size = INT_ARG(0);
|
||||
int axe1_size = INT_ARG(axe0_size+1);
|
||||
std::vector<int> axes_0(axe0_size), axes_1(axe1_size);
|
||||
for (int e = 0; e < axe0_size; e++)
|
||||
axes_0[e] = (int)INT_ARG(e + 1);
|
||||
|
||||
MmulHelper::tensorDot(a, b, c, axes_0, axes_1);
|
||||
return Status::OK();
|
||||
}
|
||||
DECLARE_SYN(tensordot, tensormmul);
|
||||
for (int e = 0; e < axe1_size; e++)
|
||||
axes_1[e] = (int)INT_ARG(e + axe0_size + 2);
|
||||
|
||||
nd4j_verbose("axe0: %i; axe1: %i;\n", axes_0.size(), axes_1.size());
|
||||
|
||||
DECLARE_SHAPE_FN(tensormmul) {
|
||||
|
||||
auto aShapeInfo = inputShape->at(0);
|
||||
auto bShapeInfo = inputShape->at(1);
|
||||
MmulHelper::tensorDot(a, b, c, axes_0, axes_1);
|
||||
return Status::OK();
|
||||
}
|
||||
DECLARE_SYN(tensordot, tensormmul);
|
||||
|
||||
REQUIRE_TRUE(ArrayOptions::dataType(aShapeInfo) == ArrayOptions::dataType(bShapeInfo), 0, "tensormmul: A and B data types must be the same");
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
DECLARE_SHAPE_FN(tensormmul) {
|
||||
|
||||
// building axes
|
||||
int axe0_size = INT_ARG(0);
|
||||
int axe1_size = INT_ARG(axe0_size+1);
|
||||
std::vector<int> axes_0(axe0_size), axes_1(axe1_size);
|
||||
for (int e = 0; e < axe0_size; e++)
|
||||
axes_0[e] = (int) INT_ARG(e+1);
|
||||
auto aShapeInfo = inputShape->at(0);
|
||||
auto bShapeInfo = inputShape->at(1);
|
||||
|
||||
for (int e = 0; e < axe1_size; e++)
|
||||
axes_1[e] = (int) INT_ARG(e + axe0_size + 2);
|
||||
REQUIRE_TRUE(ArrayOptions::dataType(aShapeInfo) == ArrayOptions::dataType(bShapeInfo), 0, "tensormmul: A and B data types must be the same");
|
||||
|
||||
// evaluate shapes
|
||||
std::vector<int> permutAt, permutBt;
|
||||
std::vector<Nd4jLong> shapeAt, shapeBt;
|
||||
auto outShape = nd4j::ShapeUtils::evalShapeForTensorDot(aShapeInfo, bShapeInfo, axes_0, axes_1, permutAt, permutBt, shapeAt, shapeBt);
|
||||
// building axes
|
||||
int axe0_size = INT_ARG(0);
|
||||
int axe1_size = INT_ARG(axe0_size+1);
|
||||
std::vector<int> axes_0(axe0_size), axes_1(axe1_size);
|
||||
for (int e = 0; e < axe0_size; e++)
|
||||
axes_0[e] = (int) INT_ARG(e+1);
|
||||
|
||||
return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(ArrayOptions::dataType(aShapeInfo), 'c', outShape)));
|
||||
}
|
||||
for (int e = 0; e < axe1_size; e++)
|
||||
axes_1[e] = (int) INT_ARG(e + axe0_size + 2);
|
||||
|
||||
DECLARE_TYPES(tensormmul) {
|
||||
getOpDescriptor()
|
||||
->setAllowedInputTypes(0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF})
|
||||
->setAllowedInputTypes(1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF})
|
||||
->setAllowedOutputTypes(0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF});
|
||||
}
|
||||
// evaluate shapes
|
||||
std::vector<int> permutAt, permutBt;
|
||||
std::vector<Nd4jLong> shapeAt, shapeBt;
|
||||
auto outShape = nd4j::ShapeUtils::evalShapeForTensorDot(aShapeInfo, bShapeInfo, axes_0, axes_1, permutAt, permutBt, shapeAt, shapeBt);
|
||||
|
||||
return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(ArrayOptions::dataType(aShapeInfo), 'c', outShape)));
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
DECLARE_TYPES(tensormmul) {
|
||||
getOpDescriptor()
|
||||
->setAllowedInputTypes(0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF})
|
||||
->setAllowedInputTypes(1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF})
|
||||
->setAllowedOutputTypes(0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF});
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
CUSTOM_OP_IMPL(tensormmul_bp, 3, 2, false, 0, -1) {
|
||||
|
||||
auto A = INPUT_VARIABLE(0);
|
||||
auto B = INPUT_VARIABLE(1);
|
||||
|
||||
auto dLdC = INPUT_VARIABLE(2);
|
||||
|
||||
auto dLdA = OUTPUT_VARIABLE(0);
|
||||
auto dLdB = OUTPUT_VARIABLE(1);
|
||||
|
||||
REQUIRE_TRUE( (A->dataType() == B->dataType() && (dLdC->dataType() == A->dataType())), 0, "tensormmul_bp: A, B and dLdC data types must be the same");
|
||||
|
||||
int axe0Size = INT_ARG(0);
|
||||
int axe1Size = INT_ARG(axe0Size + 1);
|
||||
|
||||
auto Arank = A->rankOf();
|
||||
auto Brank = B->rankOf();
|
||||
auto dLdCrank = dLdC->rankOf();
|
||||
|
||||
REQUIRE_TRUE((Arank >= axe0Size), 0, "tensormmul_bp: A rank must be the higher or same as input axes 0");
|
||||
|
||||
REQUIRE_TRUE((Brank >= axe1Size), 0, "tensormmul_bp: B rank must be the higher or same as input axes 1");
|
||||
|
||||
// building axes
|
||||
std::vector<int> axes0(axe0Size), axes1(axe1Size);
|
||||
for (uint e = 0; e < axe0Size; e++)
|
||||
axes0[e] = (int)INT_ARG(e + 1);
|
||||
for (uint e = 0; e < axe1Size; e++)
|
||||
axes1[e] = (int)INT_ARG(e + axe0Size + 2);
|
||||
|
||||
std::vector<int> permutAt, permutBt;
|
||||
std::vector<Nd4jLong> shapeAt, shapeBt;
|
||||
|
||||
ShapeUtils::evalShapeForTensorDot(A, B, axes0, axes1, permutAt, permutBt, shapeAt, shapeBt);
|
||||
|
||||
// special case for scalar value
|
||||
if (dLdC->isScalar()) {
|
||||
|
||||
dLdA->assign((*dLdC) * *B);
|
||||
dLdB->assign((*dLdC) * *A);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::vector<int> axesA = ShapeUtils::evalDimsToExclude(Arank, axes0);
|
||||
std::vector<int> axesB = ShapeUtils::evalDimsToExclude(Brank, axes1);
|
||||
|
||||
// rank always have to be divided by 2
|
||||
std::vector<int> axesAdLdC, axesBdLdC;
|
||||
if (dLdCrank > 1) {
|
||||
axesAdLdC.resize(dLdCrank / 2);
|
||||
std::iota(axesAdLdC.begin(), axesAdLdC.end(), 0);
|
||||
axesBdLdC = ShapeUtils::evalDimsToExclude(dLdCrank, axesAdLdC);
|
||||
}
|
||||
else {
|
||||
axesAdLdC.push_back(0);
|
||||
axesBdLdC.push_back(0);
|
||||
}
|
||||
|
||||
// calculate dLdA
|
||||
MmulHelper::tensorDot(dLdC, B, dLdA, axesBdLdC, axesB, permutAt);
|
||||
|
||||
// calculate dLdB
|
||||
MmulHelper::tensorDot(A, dLdC, dLdB, axesA, axesAdLdC, permutBt);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
DECLARE_SHAPE_FN(tensormmul_bp) {
|
||||
|
||||
auto aShapeInfo = inputShape->at(0);
|
||||
auto bShapeInfo = inputShape->at(1);
|
||||
auto dLShapeInfo = inputShape->at(2);
|
||||
|
||||
REQUIRE_TRUE((ArrayOptions::dataType(aShapeInfo) == ArrayOptions::dataType(bShapeInfo) &&
|
||||
(ArrayOptions::dataType(dLShapeInfo) == ArrayOptions::dataType(aShapeInfo))), 0, "tensormmul_bp: A, B and dLdC data types must be the same");
|
||||
|
||||
Nd4jLong* dLdAShapeInfo = nullptr;
|
||||
Nd4jLong* dLdBShapeInfo = nullptr;
|
||||
|
||||
COPY_SHAPE(aShapeInfo, dLdAShapeInfo);
|
||||
COPY_SHAPE(bShapeInfo, dLdBShapeInfo);
|
||||
|
||||
return SHAPELIST(CONSTANT(dLdAShapeInfo), CONSTANT(dLdBShapeInfo));
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
DECLARE_TYPES(tensormmul_bp) {
|
||||
getOpDescriptor()
|
||||
->setAllowedInputTypes(0, { DataType::FLOAT32, DataType::DOUBLE, DataType::HALF }) // maybe better ALL_FLOATS
|
||||
->setAllowedInputTypes(1, { DataType::FLOAT32, DataType::DOUBLE, DataType::HALF })
|
||||
->setAllowedInputTypes(2, { DataType::FLOAT32, DataType::DOUBLE, DataType::HALF })
|
||||
->setAllowedOutputTypes(0, { DataType::FLOAT32, DataType::DOUBLE, DataType::HALF })
|
||||
->setAllowedOutputTypes(1, { DataType::FLOAT32, DataType::DOUBLE, DataType::HALF });
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
|
@@ -79,7 +79,7 @@ CUSTOM_OP_IMPL(conv1d, 2, 1, false, 0, 5) {
}
auto inputReshaped = input ->reshape(input->ordering(), reshapeForInput);
auto outputReshaped = output ->reshape(output->ordering(), reshapeForOutput);
auto outputReshaped = output ->reshape(output->ordering(), reshapeForOutput, false);
auto weightsReshaped = weights->reshape(weights->ordering(), {1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}); // [kW, iC, oC] -> [1, kW, iC, oC]
nd4j::ops::conv2d conv2d;

@@ -216,10 +216,10 @@ CUSTOM_OP_IMPL(conv1d_bp, 3, 2, false, 0, 5) {
}
auto inputReshaped = input ->reshape(input->ordering(), reshapeForInput);
auto gradIReshaped = gradI ->reshape(gradI->ordering(), reshapeForInput);
auto gradIReshaped = gradI ->reshape(gradI->ordering(), reshapeForInput, false);
auto gradOReshaped = gradO ->reshape(gradO->ordering(), reshapeForGradO);
auto weightsReshaped = weights->reshape(weights->ordering(),{1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}); // [kW, iC, oC] -> [1, kW, iC, oC]
auto gradWReshaped = gradW ->reshape(gradW->ordering(), {1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}); // [kW, iC, oC] -> [1, kW, iC, oC]
auto weightsReshaped = weights->reshape(weights->ordering(),{1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}); // [kW, iC, oC] -> [1, kW, iC, oC]
auto gradWReshaped = gradW ->reshape(gradW->ordering(), {1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}, false);// [kW, iC, oC] -> [1, kW, iC, oC]
nd4j::ops::conv2d_bp conv2dBP;
auto status = conv2dBP.execute({&inputReshaped, &weightsReshaped, bias, &gradOReshaped}, {&gradIReshaped, &gradWReshaped, gradB}, {}, {1,kW, 1,sW, 0,pW, 1,dW, paddingMode, !isNCW}, {});

@@ -239,7 +239,7 @@ CUSTOM_OP_IMPL(conv3dnew_bp, 3, 2, false, 0, 13) {
//----- calculation of gradO -----//
if(gradB) {
if(gradB->rankOf() == 2)
gradB = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()}));
gradB = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()}, false));
gradO->reduceAlongDimension(reduce::Sum, *gradB, gradOaxesForDot); // sum over bS oD oH oW
if(gradB != OUTPUT_VARIABLE(2))
delete gradB;

@@ -233,7 +233,7 @@ CUSTOM_OP_IMPL(deconv2d_bp, 3, 2, false, 0, 9) {
// ----- calculation of gradB ----- //
if(gradB) {
if(gradB->rankOf() == 2)
gradB = new NDArray(gradB->reshape(gradB->ordering(), {gradB->lengthOf()}));
gradB = new NDArray(gradB->reshape(gradB->ordering(), {gradB->lengthOf()}, false));
gradO->reduceAlongDimension(reduce::Sum, *gradB, {0, 2, 3}); // sum over bS, oH, oW
if(gradB != OUTPUT_VARIABLE(2))
delete gradB;

@@ -243,7 +243,7 @@ CUSTOM_OP_IMPL(deconv3d_bp, 3, 2, false, 0, 13) {
// ----- calculation of gradB ----- //
if(gradB) {
if(gradB->rankOf() == 2)
gradB = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()}));
gradB = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()}, false));
gradO->reduceAlongDimension(reduce::Sum, *gradB, {0, 2, 3, 4}); // sum over bS, oD, oH, oW
if(gradB != OUTPUT_VARIABLE(2))
delete gradB;
@@ -31,22 +31,17 @@ namespace nd4j {
REQUIRE_TRUE(w->isMatrix(), 0, "relu_layer: weights argument should be a 2D tensor, but got rank %i instead!", w->rankOf());
REQUIRE_TRUE(b->isVector(), 0, "relu_layer: biases argument should be a 1D tensor, but got rank %i instead!", b->rankOf());
REQUIRE_TRUE(b->lengthOf() == w->sizeAt(1), 0, "relu_layer: biases array length should match to columns of weights matrix, however got length = %i and columns = %i!", b->lengthOf(), w->sizeAt(1));
REQUIRE_TRUE(x->sizeAt(1) == w->sizeAt(0), 0, "relu_layer: number of x columns should match to row number of weights matrix, but got x_columns = %i and weights_rows = %i!",
x->sizeAt(1), w->sizeAt(0));
REQUIRE_TRUE(x->sizeAt(1) == w->sizeAt(0), 0, "relu_layer: number of x columns should match to row number of weights matrix, but got x_columns = %i and weights_rows = %i!", x->sizeAt(1), w->sizeAt(0));
auto output = OUTPUT_VARIABLE(0);
//T bound = (T)0.f;
//nd4j_printf("Matrix x(%ix%i), Matrix w(%ix%i), b(1x%i)\n", x->sizeAt(0), x->sizeAt(1), w->sizeAt(0), w->sizeAt(1), b->lengthOf());
nd4j::ops::xw_plus_b op;
std::unique_ptr<ResultSet> result(op.evaluate({x, w, b}));
REQUIRE_TRUE(Status::OK() == result->status(), 0, "relu_layer: xw_plus_b op failed on input data.");
auto status = op.execute({x, w, b}, {output});
REQUIRE_TRUE(Status::OK() == status, 0, "relu_layer: xw_plus_b op failed on input data.");
auto scalar = block.numT() > 0 ? block.getTArguments()->at(0) : 0.0;
auto xw = result->at(0);
xw->applyScalar(nd4j::scalar::RELU, scalar, *output);
output->applyScalar(nd4j::scalar::RELU, scalar, *output);
return Status::OK();
}
@@ -23,7 +23,8 @@
//#include <ops/declarable/headers/parity_ops.h>
#include <ops/declarable/CustomOperations.h>
#include <ops/declarable/helpers/image_resize.h>
#include <ops/declarable/helpers/crop_and_resize.h>
namespace nd4j {
namespace ops {
CUSTOM_OP_IMPL(crop_and_resize, 4, 1, false, 0, 0) {
@@ -61,13 +61,13 @@ namespace nd4j {
}
auto source = inRank == 4?image->reshape(image->ordering(), {image->sizeAt(0), image->sizeAt(1), image->sizeAt(2), image->sizeAt(3)}):image->reshape(image->ordering(), {1, image->sizeAt(0), image->sizeAt(1), image->sizeAt(2)});
auto target = inRank == 4?output->reshape(output->ordering(), {output->sizeAt(0), output->sizeAt(1), output->sizeAt(2), output->sizeAt(3)}):output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)});
auto target = inRank == 4?output->reshape(output->ordering(), {output->sizeAt(0), output->sizeAt(1), output->sizeAt(2), output->sizeAt(3)}, false) : output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)}, false);
return helpers::resizeAreaFunctor(block.launchContext(), &source, width, height, alignCorners, &target);
}
DECLARE_SHAPE_FN(resize_area) {
auto shapeList = SHAPELIST();
auto shapeList = SHAPELIST();
auto in = inputShape->at(0);
Nd4jLong* outputShape;

@@ -90,7 +90,7 @@ namespace nd4j {
}
REQUIRE_TRUE(inRank == 4 || inRank == 3, 0, "resize_area: Source tensor should have rank 4, but %i given.", inRank);
ALLOCATE(outputShape, block.getWorkspace(), shape::shapeInfoLength(inRank), Nd4jLong);
outputShape[0] = inRank;
if (inRank == 4) {

@@ -62,13 +62,13 @@ namespace nd4j {
REQUIRE_TRUE(!halfPixelAlign || (halfPixelAlign && !alignCorners), 0, "resize_bicubic: `half_pixel_centers' should be false or true only when `align_corners' is false");
auto source = inRank == 4?image->reshape(image->ordering(), {image->sizeAt(0), image->sizeAt(1), image->sizeAt(2), image->sizeAt(3)}):image->reshape(image->ordering(), {1, image->sizeAt(0), image->sizeAt(1), image->sizeAt(2)});
auto target = inRank == 4?output->reshape(output->ordering(), {output->sizeAt(0), output->sizeAt(1), output->sizeAt(2), output->sizeAt(3)}):output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)});
auto target = inRank == 4?output->reshape(output->ordering(), {output->sizeAt(0), output->sizeAt(1), output->sizeAt(2), output->sizeAt(3)}, false) : output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)}, false);
return helpers::resizeBicubicFunctorA(block.launchContext(), &source, width, height, alignCorners, halfPixelAlign, &target);
}
DECLARE_SHAPE_FN(resize_bicubic) {
auto shapeList = SHAPELIST();
auto shapeList = SHAPELIST();
auto in = inputShape->at(0);
Nd4jLong* outputShape;

@@ -82,7 +82,7 @@ namespace nd4j {
height = newImageSize->e<int>(1);
REQUIRE_TRUE(inRank == 4 || inRank == 3, 0, "resize_bicubic: Source tensor should have rank 4, but %i given.", inRank);
ALLOCATE(outputShape, block.getWorkspace(), shape::shapeInfoLength(inRank), Nd4jLong);
outputShape[0] = inRank;
if (inRank == 4) {

@@ -43,7 +43,7 @@ namespace nd4j {
REQUIRE_TRUE(inRank == output->rankOf(), 0, "resize_bilinear: Input and output ranks should be equals, but %i and %i occured.", inRank, output->rankOf());
auto source = inRank == 4?image->reshape(image->ordering(), {image->sizeAt(0), image->sizeAt(1), image->sizeAt(2), image->sizeAt(3)}):image->reshape(image->ordering(), {1, image->sizeAt(0), image->sizeAt(1), image->sizeAt(2)});
auto target = inRank == 4?output->reshape(output->ordering(), {output->sizeAt(0), output->sizeAt(1), output->sizeAt(2), output->sizeAt(3)}):output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)});
auto target = inRank == 4?output->reshape(output->ordering(), {output->sizeAt(0), output->sizeAt(1), output->sizeAt(2), output->sizeAt(3)}, false) : output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)}, false);
if (block.width() > 1) {
auto newImageSize = INPUT_VARIABLE(1);

@@ -71,7 +71,7 @@ namespace nd4j {
}
DECLARE_SHAPE_FN(resize_bilinear) {
auto shapeList = SHAPELIST();
auto shapeList = SHAPELIST();
auto in = inputShape->at(0);
Nd4jLong* outputShape;

@@ -94,7 +94,7 @@ namespace nd4j {
width = INT_ARG(0);
height = INT_ARG(1);
}
ALLOCATE(outputShape, block.getWorkspace(), shape::shapeInfoLength(inRank), Nd4jLong);
outputShape[0] = inRank;
if (inRank == 4) {

@@ -63,13 +63,13 @@ namespace nd4j {
REQUIRE_TRUE(((alignCorners && height > 2) || (height > 0)) && ((alignCorners && width > 1) || (width > 0)), 0, "resize_nearest_neighbor: Wrong input or output size to resize (width = %d, height = %d)", width, height);
auto source = inRank == 4?*image:image->reshape(image->ordering(), {1, image->sizeAt(0), image->sizeAt(1), image->sizeAt(2)});
auto target = inRank == 4?*output:output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)});
auto target = inRank == 4 ? *output : output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)}, false);
return helpers::resizeNeighborFunctor(block.launchContext(), inRank==4?image:&source, width, height, alignCorners, halfPixelCenter, inRank == 4 ? output : &target);
}
DECLARE_SHAPE_FN(resize_nearest_neighbor) {
auto shapeList = SHAPELIST();
auto shapeList = SHAPELIST();
auto in = inputShape->at(0);
auto inRank = shape::rank(in);
Nd4jLong* outputShape;
@@ -47,11 +47,12 @@ namespace nd4j {
shape.insert(shape.begin() + axis, 1);
auto tmp = input->reshape(input->ordering(), shape);
output->assign(tmp);
STORE_RESULT(output);
if (input->ews() == 1 && output->ews() == 1 && input->ordering() == output->ordering()) {
output->dataBuffer()->copyBufferFrom(*input->dataBuffer().get(), output->lengthOf() * DataTypeUtils::sizeOfElement(output->dataType()), 0, input->bufferOffset());
} else {
auto tmp = input->reshape(input->ordering(), shape);
output->assign(tmp);
}
return Status::OK();
}
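The new branch above copies the buffer directly when input and output are both contiguous with matching ordering. A tiny standalone sketch of why that is safe for expand_dims, using hypothetical names rather than the NDArray/DataBuffer API:

#include <cstring>
#include <vector>
#include <cstdio>

// when both arrays are contiguous and share ordering, inserting a size-1
// dimension changes only the shape metadata, so a flat copy suffices
void expandDimsFlatCopy(const float* input, float* output, size_t length) {
    std::memcpy(output, input, length * sizeof(float));
}

int main() {
    std::vector<float> in = {1, 2, 3, 4, 5, 6};   // viewed as shape [2, 3]
    std::vector<float> out(in.size());            // viewed as shape [2, 1, 3]
    expandDimsFlatCopy(in.data(), out.data(), in.size());
    std::printf("%f %f\n", out[0], out[5]);
    return 0;
}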
@@ -15,7 +15,8 @@
******************************************************************************/
//
// Created by raver119 on 29/10/17.
// @author raver119@gmail.com
// @author Yurii Shyrma (iuriish@yahoo.com)
//
#include <op_boilerplate.h>
@@ -29,80 +30,52 @@ namespace nd4j {
//////////////////////////////////////////////////////////////////////////
// here iArgs is int vector of ordered set of dimensions to be permuted
CUSTOM_OP_IMPL(permute, 1, 1, true, 0, -2) {
auto x = INPUT_VARIABLE(0);
CUSTOM_OP_IMPL(permute, 1, 1, true, 0, -2) {
bool replace = false;
auto x = INPUT_VARIABLE(0);
auto z = OUTPUT_VARIABLE(0);
auto origArgs = block.width() > 1 ? INPUT_VARIABLE(1)->asVectorT<int>() : *block.getIArguments();
std::vector<int> arguments({});
if(origArgs.size() > 0){
for (int e = 0; e < origArgs.size(); e++) {
int ax = origArgs[e];
if (ax < 0)
ax += x->rankOf();
arguments.emplace_back(ax);
}
replace = true;
} else {
for (int e = x->rankOf() - 1; e >= 0; e--)
arguments.emplace_back(e);
}
// 0D edge case
if (x->rankOf() == 0) {
REQUIRE_TRUE(arguments.size() == 1, 0, "Permute: only one axis is allowed for scalar");
auto output = OUTPUT_VARIABLE(0);
if (!block.isInplace())
output->assign(x);
return Status::OK();
}
if(block.isInplace()) { // in-place
x->permutei(arguments);
STORE_RESULT(x);
} else {
auto output = OUTPUT_VARIABLE(0);
auto result = x->permute(arguments);
output->assign(result);
STORE_RESULT(output);
}
return Status::OK();
}
DECLARE_TYPES(permute) {
getOpDescriptor()
->setAllowedInputTypes(0, nd4j::DataType::ANY)
->setAllowedInputTypes(1, {ALL_INTS})
->setSameMode(true);
}
DECLARE_SHAPE_FN(permute) {
auto shapeList = SHAPELIST();
auto arguments = block.width() > 1 ? INPUT_VARIABLE(1)->asVectorT<int>() : *block.getIArguments();
if (shape::rank(inputShape->at(0)) == 0) {
shapeList->push_back(ConstantShapeHelper::getInstance()->scalarShapeInfo(ArrayOptions::dataType(inputShape->at(0))));
} else if (inputShape->size() == 1 && !arguments.empty()) {
shapeList->push_back(ShapeUtils::evalPermShapeInfo(arguments.data(), arguments.size(), *INPUT_VARIABLE(0), block.workspace()));
} else {
if(arguments.size() == 0){
//Reverse dimensions
int rank = shape::rank(inputShape->at(0));
for (int e = rank - 1; e >= 0; e--)
arguments.emplace_back(e);
}
shapeList->push_back(ShapeUtils::evalPermShapeInfo(arguments.data(), arguments.size(), *INPUT_VARIABLE(0), block.workspace()));
}
return shapeList;
}
if (x->isEmpty()) {
REQUIRE_TRUE(z->isEmpty(), 0, "PERMUTE OP: when input is empty, output must also be empty");
return Status::OK(); //No op
}
if (block.width() == 1 && block.getIArguments()->size() == 0) {
z->assign(x->transpose());
return Status::OK();
}
std::vector<int> permutationVector = block.width() > 1 ? INPUT_VARIABLE(1)->asVectorT<int>() : *block.getIArguments();
z->assign(x->permute(permutationVector));
return Status::OK();
}
//////////////////////////////////////////////////////////////////////////
DECLARE_TYPES(permute) {
getOpDescriptor()
->setAllowedInputTypes(0, nd4j::DataType::ANY)
->setAllowedInputTypes(1, {ALL_INTS})
->setSameMode(true);
}
//////////////////////////////////////////////////////////////////////////
DECLARE_SHAPE_FN(permute) {
auto x = INPUT_VARIABLE(0);
if (block.width() == 1 && block.getIArguments()->size() == 0)
return SHAPELIST(ShapeUtils::evalTranspShapeInfo(*x, block.workspace(), true));
std::vector<int> permutationVector = block.width() > 1 ? INPUT_VARIABLE(1)->asVectorT<int>() : *block.getIArguments();
auto outputShapeInfo = ShapeUtils::evalPermShapeInfo(permutationVector.data(), x->rankOf(), *x, block.workspace(), true);
return SHAPELIST(outputShapeInfo);
}
}
}
#endif
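As a side note to the rewritten permute shape function, here is a minimal standalone sketch of the permuted-shape computation it delegates to ShapeUtils::evalPermShapeInfo; the names below are hypothetical and the real helper also handles strides, ordering and data type.

#include <vector>
#include <cstdio>

// output dimension i is the input dimension perm[i]; the data itself is only
// re-viewed through the permutation, never copied
std::vector<long long> permuteShape(const std::vector<long long>& shape,
                                    const std::vector<int>& perm) {
    std::vector<long long> out(shape.size());
    for (size_t i = 0; i < perm.size(); i++)
        out[i] = shape[perm[i]];
    return out;
}

int main() {
    auto out = permuteShape({2, 3, 5}, {2, 0, 1});   // -> {5, 2, 3}
    std::printf("%lld %lld %lld\n", out[0], out[1], out[2]);
    return 0;
}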
@@ -24,254 +24,240 @@
|
|||
#include <ops/declarable/CustomOperations.h>
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
// here iArgs is a vector with (optional) negative of order as first element:
|
||||
// ({-order, dim1, dim2, dim3, ...})
|
||||
CUSTOM_OP_IMPL(reshape, 1, 1, true, 0, -2) {
|
||||
auto x = INPUT_VARIABLE(0);
|
||||
namespace ops {
|
||||
|
||||
if (block.width() == 1) {
|
||||
auto arguments = block.getIArguments();
|
||||
int argsSize = arguments->size();
|
||||
|
||||
//Special case: empty.reshape(<other empty shape>) -> return empty
|
||||
if (x->isEmpty()) {
|
||||
REQUIRE_TRUE(OUTPUT_VARIABLE(0)->isEmpty(), 0, "Reshape: when input is empty, output must also be empty");
|
||||
return ND4J_STATUS_OK; //No op
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
// here iArgs is a vector with (optional) negative of order as first element:
|
||||
// ({-order, dim1, dim2, dim3, ...})
|
||||
CUSTOM_OP_IMPL(reshape, 1, 1, false, 0, -2) {
|
||||
|
||||
auto x = INPUT_VARIABLE(0);
|
||||
auto z = OUTPUT_VARIABLE(0);
|
||||
|
||||
//Special case: empty.reshape(<other empty shape>) -> return empty
|
||||
if (x->isEmpty()) {
|
||||
REQUIRE_TRUE(z->isEmpty(), 0, "Reshape: when input is empty, output must also be empty");
|
||||
return Status::OK(); //No op
|
||||
}
|
||||
|
||||
if (block.width() == 1) {
|
||||
|
||||
auto arguments = block.getIArguments();
|
||||
int argsSize = arguments->size();
|
||||
|
||||
|
||||
|
||||
int e = 1;
|
||||
char order = (char) -(*arguments)[0];
|
||||
if (order != 'c' && order != 'f') {
|
||||
order = 'c'; //x->ordering();
|
||||
e = 0;
|
||||
}
|
||||
|
||||
REQUIRE_TRUE(argsSize - e >= 1, 0, "Reshape arguments should have at least 1 dimension");
|
||||
|
||||
std::vector<Nd4jLong> shapeNew;
|
||||
int e2 = e;
|
||||
for (; e < (int) arguments->size(); e++) {
|
||||
if (arguments->at(e) == -1){
|
||||
Nd4jLong shapeLength = 1;
|
||||
for(; e2 < e; e2++){
|
||||
shapeLength *= arguments->at(e2);
|
||||
}
|
||||
|
||||
int e = 1;
|
||||
char order = (char) -(*arguments)[0];
|
||||
if (order != 'c' && order != 'f') {
|
||||
order = 'c'; //x->ordering();
|
||||
e = 0;
|
||||
}
|
||||
|
||||
REQUIRE_TRUE(argsSize - e >= 1, 0, "Reshape arguments should have at least 1 dimension");
|
||||
|
||||
std::vector<Nd4jLong> shapeNew;
|
||||
int e2 = e;
|
||||
for (; e < (int) arguments->size(); e++) {
|
||||
if (arguments->at(e) == -1){
|
||||
Nd4jLong shapeLength = 1;
|
||||
for(; e2 < e; e2++){
|
||||
shapeLength *= arguments->at(e2);
|
||||
}
|
||||
for(e2 = e + 1; e2 < arguments->size(); e2++){
|
||||
shapeLength *= arguments->at(e2);
|
||||
}
|
||||
Nd4jLong realShape = x->lengthOf() / shapeLength;
|
||||
shapeNew.push_back(realShape);
|
||||
}
|
||||
else{
|
||||
shapeNew.push_back(arguments->at(e));
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
auto len = shape::prodLong(shapeNew.data(), shapeNew.size());
|
||||
REQUIRE_TRUE(len == x->lengthOf(), 0, "Reshape: lengths before and after reshape should match, but got %i vs %i", x->lengthOf(), len);
|
||||
|
||||
if (Environment::getInstance()->isDebugAndVerbose()) {
|
||||
nd4j_printv("Reshape: new shape", shapeNew);
|
||||
}
|
||||
|
||||
if (block.isInplace()) {
|
||||
if (x->reshapei(order, shapeNew)) {
|
||||
STORE_RESULT(*x);
|
||||
return ND4J_STATUS_OK;
|
||||
}
|
||||
} else {
|
||||
auto ret = OUTPUT_VARIABLE(0);
|
||||
auto xr = x->reshape(order, shapeNew);
|
||||
ret->assign(xr);
|
||||
STORE_RESULT(*ret);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
} else if (block.width() == 2) {
|
||||
auto s = INPUT_VARIABLE(1);
|
||||
|
||||
//Special case: empty.reshape(-1) -> return empty
|
||||
if (x->isEmpty()) {
|
||||
//REQUIRE_TRUE(s->lengthOf() == 1 && s->e<Nd4jLong>(0) == -1, 0, "Reshape: when input is empty, shape must be [-1]");
|
||||
REQUIRE_TRUE(OUTPUT_VARIABLE(0)->isEmpty(), 0, "Reshape: when input is empty, output must also be empty");
|
||||
return Status::OK(); //No op
|
||||
}
|
||||
|
||||
char order = 'c';
|
||||
if (block.numI() > 0)
|
||||
order = (char) -INT_ARG(0);
|
||||
|
||||
std::vector<Nd4jLong> shapeNew(s->lengthOf());
|
||||
|
||||
for (int e = 0; e < (int) s->lengthOf(); e++) {
|
||||
auto dim = s->e<Nd4jLong >(e);
|
||||
if (dim == -1){
|
||||
Nd4jLong shapeLength = 1;
|
||||
for(int e2 = 0; e2 < e; e2++){
|
||||
shapeLength *= s->e<Nd4jLong>(e2);
|
||||
}
|
||||
for(int e2 = e + 1; e2 < (int) s->lengthOf(); e2++){
|
||||
REQUIRE_TRUE(s->e<Nd4jLong>(e2) != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed.");
|
||||
shapeLength *= s->e<Nd4jLong>(e2);
|
||||
}
|
||||
Nd4jLong realShape = x->lengthOf() / shapeLength;
|
||||
shapeNew[e] = realShape;
|
||||
}
|
||||
else{
|
||||
shapeNew[e] = dim;
|
||||
}
|
||||
}
|
||||
|
||||
if (Environment::getInstance()->isDebugAndVerbose()) {
|
||||
nd4j_printv("Reshape: new shape", shapeNew);
|
||||
}
|
||||
|
||||
if (block.isInplace()) {
|
||||
if (x->reshapei(order, shapeNew)) {
|
||||
STORE_RESULT(*x);
|
||||
return Status::OK();
|
||||
}
|
||||
} else {
|
||||
auto ret = OUTPUT_VARIABLE(0);
|
||||
if (s->isEmpty()) {
|
||||
// just a scalar
|
||||
ret->assign(x);
|
||||
} else {
|
||||
auto xr = x->reshape(order, shapeNew);
|
||||
ret->assign(xr);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
for(e2 = e + 1; e2 < arguments->size(); e2++){
|
||||
shapeLength *= arguments->at(e2);
|
||||
}
|
||||
Nd4jLong realShape = x->lengthOf() / shapeLength;
|
||||
shapeNew.push_back(realShape);
|
||||
}
|
||||
else{
|
||||
                shapeNew.push_back(arguments->at(e));
            }
        }

        auto len = shape::prodLong(shapeNew.data(), shapeNew.size());
        REQUIRE_TRUE(len == x->lengthOf(), 0, "Reshape: lengths before and after reshape should match, but got %i vs %i", x->lengthOf(), len);

        if (Environment::getInstance()->isDebugAndVerbose()) {
            nd4j_printv("Reshape: new shape", shapeNew);
        }

        auto xr = x->reshape(order, shapeNew);
        z->assign(xr);
        STORE_RESULT(*z);

        return Status::OK();

    } else if (block.width() == 2) {

        auto s = INPUT_VARIABLE(1);

        char order = 'c';
        if (block.numI() > 0)
            order = (char) -INT_ARG(0);

        std::vector<Nd4jLong> shapeNew(s->lengthOf());

        for (int e = 0; e < (int) s->lengthOf(); e++) {
            auto dim = s->e<Nd4jLong>(e);
            if (dim == -1) {
                Nd4jLong shapeLength = 1;
                for (int e2 = 0; e2 < e; e2++) {
                    shapeLength *= s->e<Nd4jLong>(e2);
                }
                for (int e2 = e + 1; e2 < (int) s->lengthOf(); e2++) {
                    REQUIRE_TRUE(s->e<Nd4jLong>(e2) != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed.");
                    shapeLength *= s->e<Nd4jLong>(e2);
                }

                if (shapeLength == 0) {
                    //Edge case for empty:
                    shapeNew[e] = 0;
                } else {
                    //Standard case
                    Nd4jLong realShape = x->lengthOf() / shapeLength;
                    shapeNew[e] = realShape;
                }
            } else {
                shapeNew[e] = dim;
            }
        }

        if (Environment::getInstance()->isDebugAndVerbose()) {
            nd4j_printv("Reshape: new shape", shapeNew);
        }

        if (s->isEmpty()) {
            // just a scalar
            z->assign(x);
        } else {
            auto xr = x->reshape(order, shapeNew);
            z->assign(xr);
        }

        return Status::OK();
    }

    return ND4J_STATUS_BAD_INPUT;
}
    DECLARE_TYPES(reshape) {
        getOpDescriptor()
                ->setAllowedInputTypes(0, nd4j::DataType::ANY)
                ->setAllowedInputTypes(1, {ALL_INTS})
                ->setSameMode(true);
    }

    DECLARE_SHAPE_FN(reshape) {
        auto inp = inputShape->at(0);

        // we can launch op using Int arguments
        if (inputShape->size() == 1) {
            REQUIRE_TRUE(block.numI() > 0, 0, "Reshape: new shape should be provided as NDArray or int arguments, but nothing was defined");
            std::vector<int> *arguments = block.getIArguments();

            int e = 1;
            char order = (char) -(*arguments)[0];
            if (order != 'c' && order != 'f') {
                order = shape::order(inp);
                e = 0;
            }

            std::vector<Nd4jLong> shapeNew;

            int e2 = e;
            for (; e < (int) arguments->size(); e++) {
                if ((int) arguments->at(e) == -1) {

                    Nd4jLong shapeLength = 1;
                    for (; e2 < e; e2++) {
                        shapeLength *= arguments->at(e2);
                    }
                    for (e2 = e + 1; e2 < arguments->size(); e2++) {
                        REQUIRE_TRUE(arguments->at(e2) != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed.");
                        shapeLength *= arguments->at(e2);
                    }

                    if (shapeLength == 0) {
                        //Edge case for empty:
                        shapeNew.push_back(0);
                    } else {
                        //Standard case
                        Nd4jLong realShape = shape::length(inp) / shapeLength;
                        shapeNew.push_back(realShape);
                    }
                } else {
                    shapeNew.push_back(arguments->at(e));
                }
            }

            return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(ArrayOptions::dataType(inp), order, shapeNew)));
        } else {
            // or, with second input "as shape"
            auto x = INPUT_VARIABLE(0);
            auto y = INPUT_VARIABLE(1);

            // special case here
            if (y->isEmpty()) {
                REQUIRE_TRUE(x->lengthOf() == 1, 0, "Reshape: new length doesn't match existing array");
                return SHAPELIST(ConstantShapeHelper::getInstance()->scalarShapeInfo(ArrayOptions::dataType(inp)));
            }
            //Special case: empty.reshape(-1) -> return empty
            if (x->isEmpty()) {
                //REQUIRE_TRUE(y->lengthOf() == 1 && y->e<Nd4jLong>(0) == -1, 0, "Reshape: when input is empty, shape must be [-1]");
                auto shapeOf = y->getBufferAsVector<Nd4jLong>();
                Nd4jLong prod = 1;
                bool hasNegs = false;
                for (auto v : shapeOf) {
                    if (v < 0) {
                        hasNegs = true;
                        v = 0;
                    }

                    prod *= v;
                }

                REQUIRE_TRUE(prod == 0, 0, "Reshape: in case of empty arrays reshape must return empty array as well");

                // if there are -1s - we turn them into zeros
                if (hasNegs) {
                    for (int e = 0; e < shapeOf.size(); e++)
                        if (shapeOf[e] < 0)
                            shapeOf[e] = 0;
                }

                auto newShape = ShapeBuilders::createShapeInfo(ArrayOptions::dataType(inp), shape::order(inp), y->lengthOf(), shapeOf.data());
                return SHAPELIST(CONSTANT(newShape));
            }

            std::vector<Nd4jLong> shapeNew(y->lengthOf());

            for (int e = 0; e < (int) y->lengthOf(); e++) {
                auto dim = y->e<Nd4jLong>(e);
                if (dim == -1) {
                    Nd4jLong shapeLength = 1;
                    for (int e2 = 0; e2 < e; e2++) {
                        shapeLength *= y->e<Nd4jLong>(e2);
                    }
                    for (int e2 = e + 1; e2 < (int) y->lengthOf(); e2++) {
                        REQUIRE_TRUE(y->e<Nd4jLong>(e2) != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed.");
                        shapeLength *= y->e<Nd4jLong>(e2);
                    }

                    if (shapeLength == 0) {
                        //Edge case for empty:
                        shapeNew[e] = 0;
                    } else {
                        Nd4jLong realShape = shape::length(inp) / shapeLength;
                        shapeNew[e] = realShape;
                    }
                } else {
                    shapeNew[e] = dim;
                }
            }

            return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inp), 'c', shapeNew));
        }
    }
}
}

#endif
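For readers following the -1 handling in the reshape kernel and shape function above: the single unknown dimension is recovered by dividing the array's total length by the product of the explicitly given dimensions, with a zero product treated as the empty-array edge case. The following is a minimal standalone sketch of that arithmetic only; the inferShape helper is hypothetical and is not part of the libnd4j sources shown in this diff.

// infer_reshape.cpp - standalone sketch of the "-1 dimension" inference (assumed helper, not libnd4j code)
#include <cstdint>
#include <stdexcept>
#include <vector>

// Given the total element count and a requested shape with at most one -1,
// replace the -1 with totalLength / (product of the known dimensions).
std::vector<int64_t> inferShape(int64_t totalLength, std::vector<int64_t> shape) {
    int unknownIndex = -1;
    int64_t known = 1;
    for (size_t i = 0; i < shape.size(); i++) {
        if (shape[i] == -1) {
            if (unknownIndex >= 0)
                throw std::invalid_argument("Only one unknown dimension (-1) is allowed.");
            unknownIndex = (int) i;
        } else {
            known *= shape[i];
        }
    }
    if (unknownIndex >= 0)
        shape[unknownIndex] = (known == 0) ? 0 : totalLength / known;  // zero product mirrors the empty edge case above
    return shape;
}

int main() {
    auto s = inferShape(24, {2, -1, 3});   // -> {2, 4, 3}
    return s[1] == 4 ? 0 : 1;
}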
@ -28,35 +28,27 @@ namespace nd4j {

    //////////////////////////////////////////////////////////////////////////
    CUSTOM_OP_IMPL(reshapeas, 2, 1, true, 0, 0) {
    CUSTOM_OP_IMPL(reshapeas, 2, 1, false, 0, 0) {

        auto x = INPUT_VARIABLE(0);
        auto y = INPUT_VARIABLE(1);

        auto z = OUTPUT_VARIABLE(0);

        std::vector<Nd4jLong> shapeNew(y->shapeOf(), y->shapeOf() + y->rankOf());
        char order = y->ordering();

        if (x->reshapei(order, shapeNew)) {
            *z = *x;
            STORE_RESULT(*z);
        if (x->reshapei(y->ordering(), y->getShapeAsVector())) {

            z->assign(x);
            return Status::OK();
        }

        return ND4J_STATUS_BAD_INPUT;
    }
    DECLARE_SYN(reshape_as, reshapeas);

    DECLARE_SHAPE_FN(reshapeas) {

        auto inputShapeInfo = inputShape->at(1);
        int shapeInfoLength = inputShapeInfo[0]*2 + 4;

        Nd4jLong* outputShapeInfo(nullptr);
        COPY_SHAPE(inputShapeInfo, outputShapeInfo);

        return SHAPELIST(CONSTANT(outputShapeInfo));
    }
    DECLARE_SHAPE_FN(reshapeas) {

        return SHAPELIST(ShapeBuilders::copyShapeInfo(INPUT_VARIABLE(1)->getShapeInfo(), false, block.workspace()));
    }

    DECLARE_TYPES(reshapeas) {
        getOpDescriptor()
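reshapeas only succeeds when x->reshapei(...) can give x the shape of y, i.e. when both arrays describe the same number of elements; otherwise the op returns ND4J_STATUS_BAD_INPUT. Below is a minimal sketch of that validity condition on its own; canReshapeAs is a hypothetical helper, not the NDArray API.

// reshape_as_check.cpp - sketch of the validity condition behind reshapeas (assumed helper, not libnd4j code)
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// x can only take y's shape when both shapes describe the same number of elements.
bool canReshapeAs(const std::vector<int64_t>& xShape, const std::vector<int64_t>& yShape) {
    auto prod = [](const std::vector<int64_t>& s) {
        return std::accumulate(s.begin(), s.end(), (int64_t) 1, std::multiplies<int64_t>());
    };
    return prod(xShape) == prod(yShape);
}

int main() {
    bool ok  = canReshapeAs({2, 6}, {3, 4});   // 12 == 12 -> true
    bool bad = canReshapeAs({2, 6}, {5, 3});   // 12 != 15 -> false
    return (ok && !bad) ? 0 : 1;
}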
@ -25,7 +25,7 @@

namespace nd4j {
    namespace ops {
        CUSTOM_OP_IMPL(squeeze, 1, 1, true, 0, -2) {
        CUSTOM_OP_IMPL(squeeze, 1, 1, false, 0, -2) {
            auto input = INPUT_VARIABLE(0);
            auto output = OUTPUT_VARIABLE(0);
@ -36,14 +36,14 @@ namespace nd4j {

                    int _a = INT_ARG(e);
                    if (_a < 0)
                        _a += input->rankOf();

                    axis.emplace_back(_a);
                }
            else if (block.width() > 1) {
                auto a = INPUT_VARIABLE(1);
                for (Nd4jLong e = 0; e < a->lengthOf(); e++) {
                    int _a = a->e<int>(e);

                    if (_a < 0)
                        _a += input->rankOf();
@ -71,10 +71,14 @@ namespace nd4j {
            }

            if (block.isInplace()) {
                output->reshapei(input->ordering(), shape);
                output->reshapei(input->ordering(), shape, false);
            } else {
                auto tmp = input->reshape(input->ordering(), shape);
                output->assign(tmp);
                if (input->ews() == 1 && output->ews() == 1 && input->ordering() == output->ordering()) {
                    output->dataBuffer()->copyBufferFrom(*input->dataBuffer().get(), output->lengthOf() * DataTypeUtils::sizeOfElement(output->dataType()), 0, input->bufferOffset());
                } else {
                    auto tmp = input->reshape(input->ordering(), shape);
                    output->assign(tmp);
                }
            }

            return Status::OK();
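The squeeze hunk above adds a fast path: when both arrays are densely laid out (ews() == 1) and share the same ordering, the copy degenerates to a raw buffer copy of lengthOf() * sizeOfElement bytes instead of a reshape-and-assign. A rough standalone illustration of that idea follows; it uses plain std::memcpy and hypothetical names rather than the DataBuffer API.

// contiguous_copy.cpp - illustration of the dense fast path above (assumed helper, not libnd4j code)
#include <cstdint>
#include <cstring>
#include <vector>

// When two buffers are dense (element-wise stride 1) with matching ordering,
// a squeeze/reshape copy is just numElements * sizeof(element) raw bytes.
void copyContiguous(float* dst, const float* src, int64_t numElements) {
    std::memcpy(dst, src, static_cast<size_t>(numElements) * sizeof(float));
}

int main() {
    std::vector<float> in(12, 1.0f), out(12, 0.0f);  // e.g. shape [3,1,4] squeezed to [3,4]
    copyContiguous(out.data(), in.data(), static_cast<int64_t>(in.size()));
    return out[5] == 1.0f ? 0 : 1;
}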
@ -106,20 +110,20 @@ namespace nd4j {

                    int _a = INT_ARG(e);
                    if (_a < 0)
                        _a += rank;

                    axis.emplace_back(_a);
                }
            else if (block.width() > 1) {
                auto a = INPUT_VARIABLE(1);
                for (int e = 0; e < a->lengthOf(); e++) {
                    int _a = a->e<int>(e);

                    if (_a < 0)
                        _a += rank;

                    axis.emplace_back(_a);
                }

            }

            auto order = shape::order(in);
@ -25,7 +25,7 @@

namespace nd4j {
    namespace ops {
        CUSTOM_OP_IMPL(tile_to_shape, 1, 1, true, 0, -1) {
        CUSTOM_OP_IMPL(tile_to_shape, 1, 1, false, 0, -1) {

            auto input = INPUT_VARIABLE(0);
            auto output = OUTPUT_VARIABLE(0);
@ -15,7 +15,8 @@
 ******************************************************************************/

//
// Created by raver119 on 29/10/17.
// @author raver119@gmail.com
// @author Yurii Shyrma (iuriish@yahoo.com)
//

#include <op_boilerplate.h>
@ -25,113 +26,52 @@
#include <helpers/ShapeUtils.h>

namespace nd4j {
    namespace ops {

        //////////////////////////////////////////////////////////////////////////
        CUSTOM_OP_IMPL(transpose, 1, 1, true, 0, 0) {
            auto x = INPUT_VARIABLE(0);
            if (block.width() == 1) {
                if (block.isInplace()) {
                    x->transposei();
                    STORE_RESULT(*x);
                } else {
                    auto output = OUTPUT_VARIABLE(0);
                    auto t = x->transpose();
                    output->assign(t);
                    STORE_RESULT(*output);
                }
            } else {
                // this is tf-mode transpose, that's nd4j permute
                bool replace = false;
                std::vector<int> arguments(*block.getIArguments());

                auto w = block.width();
                auto a = arguments.size();

                if (w == 2 && a == 0) {
                    auto axis = INPUT_VARIABLE(1);
                    for (int e = 0; e < axis->lengthOf(); e++) {
                        auto ax = axis->e<int>(e);
                        if (ax < 0)
                            ax += x->rankOf();

                        arguments.emplace_back(ax);
                    }

                    replace = true;
                } else if (a == 0) {
                    for (int e = x->rankOf() - 1; e >= 0; e--)
                        arguments.emplace_back(e);
                }

                // 0D edge case
                if (x->rankOf() == 0) {
                    REQUIRE_TRUE(arguments.size() == 1, 0, "Permute: only one axis is allowed for scalar");
                    auto output = OUTPUT_VARIABLE(0);
                    if (!block.isInplace())
                        output->assign(x);

                    return Status::OK();
                }

                if (block.isInplace()) { // in-place
                    x->permutei(arguments);
                    STORE_RESULT(x);
                } else {
                    auto input = x->permute(arguments);

                    auto output = OUTPUT_VARIABLE(0);
                    output->assign(input);
                }
            }

            return Status::OK();
        }

        //////////////////////////////////////////////////////////////////////////
        CUSTOM_OP_IMPL(transpose, 1, 1, false, 0, 0) {

            auto x = INPUT_VARIABLE(0);
            auto z = OUTPUT_VARIABLE(0);

            //Special case: empty.reshape(<other empty shape>) -> return empty
            if (x->isEmpty()) {
                REQUIRE_TRUE(z->isEmpty(), 0, "TRANSPOSE OP: when input is empty, output must also be empty");
                return Status::OK(); //No op
            }

            if (block.width() == 1 && block.getIArguments()->size() == 0) {
                z->assign(x->transpose());
                return Status::OK();
            }

            std::vector<int> permutationVector = block.width() > 1 ? INPUT_VARIABLE(1)->asVectorT<int>() : *block.getIArguments();

            z->assign(x->permute(permutationVector));

            return Status::OK();
        }

        DECLARE_TYPES(transpose) {
            getOpDescriptor()
                    ->setAllowedInputTypes(nd4j::DataType::ANY)
                    ->setSameMode(true);
        }

        DECLARE_SHAPE_FN(transpose) {
            if (block.width() == 1) {
                auto outputShapeInfo = ShapeUtils::evalTranspShapeInfo(*INPUT_VARIABLE(0), block.workspace());
                return SHAPELIST(outputShapeInfo);
            } else {
                // this is basically permute mode
                auto shapeList = SHAPELIST();
                auto arguments = block.getIArguments();
                if (shape::rank(inputShape->at(0)) == 0) {
                    Nd4jLong *newshape;
                    ALLOCATE(newshape, block.getWorkspace(), shape::shapeInfoLength(inputShape->at(0)), Nd4jLong);
                    newshape[0] = 0;
                    newshape[1] = 0;
                    newshape[2] = 1;
                    newshape[3] = 99;
                    ArrayOptions::copyDataType(newshape, inputShape->at(0));
                    shapeList->push_back(newshape);
                } else if (arguments->size() > 0 || inputShape->size() > 1) {
                    auto axis = arguments->size() > 0 ? *arguments : (INPUT_VARIABLE(1))->template asVectorT<int>();
                    auto outputShapeInfo = ShapeUtils::evalPermShapeInfo(axis.data(), axis.size(), *INPUT_VARIABLE(0), block.workspace());
                    shapeList->push_back(outputShapeInfo);
                } else if (inputShape->size() == 2) {
                    // dead end
                    auto axis = INPUT_VARIABLE(1);
                    auto axisV = axis->template asVectorT<Nd4jLong>();
                    auto newshape = ShapeUtils::evalPermShapeInfo(axisV.data(), axisV.size(), *INPUT_VARIABLE(0), block.workspace());
                    shapeList->push_back(newshape);
                } else {
                    int rank = shape::rank(inputShape->at(0));
                    for (int e = rank - 1; e >= 0; e--)
                        arguments->emplace_back(e);

                    auto outputShapeInfo = ShapeUtils::evalPermShapeInfo(arguments->data(), arguments->size(), *INPUT_VARIABLE(0), block.workspace());
                    shapeList->push_back(outputShapeInfo);
                }

                return shapeList;
            }
        }

        DECLARE_SHAPE_FN(transpose) {

            auto x = INPUT_VARIABLE(0);

            if (block.width() == 1 && block.getIArguments()->size() == 0)
                return SHAPELIST(ShapeUtils::evalTranspShapeInfo(*x, block.workspace(), true));

            std::vector<int> permutationVector = block.width() > 1 ? INPUT_VARIABLE(1)->asVectorT<int>() : *block.getIArguments();

            auto outputShapeInfo = ShapeUtils::evalPermShapeInfo(permutationVector.data(), x->rankOf(), *x, block.workspace(), true);

            return SHAPELIST(outputShapeInfo);
        }
    }
}
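Both the old and the rewritten transpose fall back to reversing all axes when neither a second input nor integer arguments supply a permutation. A small standalone sketch of that default follows; defaultPermutation and permuteShape are hypothetical helpers used only to show how the reversed-axes permutation maps a shape, not libnd4j functions.

// default_transpose_perm.cpp - sketch of the "reverse all axes" default used above (assumed helpers, not libnd4j code)
#include <cstdint>
#include <vector>

// With no explicit permutation, transpose permutes axes [0..rank-1] in reverse order.
std::vector<int> defaultPermutation(int rank) {
    std::vector<int> perm(rank);
    for (int e = rank - 1, i = 0; e >= 0; e--, i++)
        perm[i] = e;
    return perm;
}

// Applying a permutation to a shape: outShape[i] = inShape[perm[i]].
std::vector<int64_t> permuteShape(const std::vector<int64_t>& shape, const std::vector<int>& perm) {
    std::vector<int64_t> out(shape.size());
    for (size_t i = 0; i < perm.size(); i++)
        out[i] = shape[perm[i]];
    return out;
}

int main() {
    auto perm = defaultPermutation(3);          // {2, 1, 0}
    auto out = permuteShape({2, 3, 5}, perm);   // {5, 3, 2}
    return (out[0] == 5 && out[2] == 2) ? 0 : 1;
}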