Merge pull request #8723 from KonduitAI/master

Merge recent development updates
Alex Black 2020-02-21 20:00:35 +11:00 committed by GitHub
commit e4ddf109c3
239 changed files with 7636 additions and 4122 deletions

View File

@ -1,47 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.arbiter.util;
import org.slf4j.Logger;
import java.awt.*;
import java.net.URI;
/**
* Various utilities for webpages and dealing with browsers
*/
public class WebUtils {
public static void tryOpenBrowser(String path, Logger log) {
try {
WebUtils.openBrowser(new URI(path));
} catch (Exception e) {
log.error("Could not open browser", e);
System.out.println("Browser could not be launched automatically.\nUI path: " + path);
}
}
public static void openBrowser(URI uri) throws Exception {
if (Desktop.isDesktopSupported()) {
Desktop.getDesktop().browse(uri);
} else {
throw new UnsupportedOperationException(
"Cannot open browser on this platform: Desktop.isDesktopSupported() == false");
}
}
}

View File

@ -127,7 +127,7 @@ public class BraninFunction {
BraninConfig candidate = (BraninConfig) c.getValue();
double score = scoreFunction.score(candidate, null, (Map) null);
System.out.println(candidate.getX1() + "\t" + candidate.getX2() + "\t" + score);
// System.out.println(candidate.getX1() + "\t" + candidate.getX2() + "\t" + score);
Thread.sleep(20);

View File

@ -54,7 +54,7 @@ public class TestRandomSearch extends BaseDL4JTest {
runner.execute();
System.out.println("----- Complete -----");
// System.out.println("----- Complete -----");
}

View File

@ -16,8 +16,8 @@
package org.deeplearning4j.arbiter.optimize.genetic;
import org.apache.commons.lang3.NotImplementedException;
import org.apache.commons.math3.random.RandomGenerator;
import sun.reflect.generics.reflectiveObjects.NotImplementedException;
public class TestRandomGenerator implements RandomGenerator {
private final int[] intRandomNumbers;
@ -63,17 +63,17 @@ public class TestRandomGenerator implements RandomGenerator {
@Override
public long nextLong() {
throw new NotImplementedException();
throw new NotImplementedException("Not implemented");
}
@Override
public boolean nextBoolean() {
throw new NotImplementedException();
throw new NotImplementedException("Not implemented");
}
@Override
public float nextFloat() {
throw new NotImplementedException();
throw new NotImplementedException("Not implemented");
}
@Override
@ -83,6 +83,6 @@ public class TestRandomGenerator implements RandomGenerator {
@Override
public double nextGaussian() {
throw new NotImplementedException();
throw new NotImplementedException("Not implemented");
}
}
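
The recurring change across these Arbiter genetic-test stubs is swapping the JDK-internal sun.reflect.generics.reflectiveObjects.NotImplementedException for Apache Commons Lang's org.apache.commons.lang3.NotImplementedException, whose constructor takes a message. A minimal sketch of a stub following that pattern, assuming commons-lang3 and commons-math3 on the classpath; the class name is illustrative and not part of the change:

import org.apache.commons.lang3.NotImplementedException;
import org.apache.commons.math3.random.RandomGenerator;

// Illustrative stub only - same pattern as TestRandomGenerator above, with every
// unimplemented method throwing the commons-lang3 exception with a message
public class StubRandomGenerator implements RandomGenerator {
    @Override public void setSeed(int seed) { }
    @Override public void setSeed(int[] seed) { }
    @Override public void setSeed(long seed) { }
    @Override public void nextBytes(byte[] bytes) { throw new NotImplementedException("Not implemented"); }
    @Override public int nextInt() { throw new NotImplementedException("Not implemented"); }
    @Override public int nextInt(int n) { throw new NotImplementedException("Not implemented"); }
    @Override public long nextLong() { throw new NotImplementedException("Not implemented"); }
    @Override public boolean nextBoolean() { throw new NotImplementedException("Not implemented"); }
    @Override public float nextFloat() { throw new NotImplementedException("Not implemented"); }
    @Override public double nextDouble() { throw new NotImplementedException("Not implemented"); }
    @Override public double nextGaussian() { throw new NotImplementedException("Not implemented"); }
}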

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.arbiter.optimize.genetic.crossover;
import org.apache.commons.lang3.NotImplementedException;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverResult;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.TwoParentsCrossoverOperator;
@ -26,7 +27,6 @@ import org.deeplearning4j.arbiter.optimize.genetic.TestParentSelection;
import org.deeplearning4j.arbiter.optimize.genetic.TestPopulationInitializer;
import org.junit.Assert;
import org.junit.Test;
import sun.reflect.generics.reflectiveObjects.NotImplementedException;
public class TwoParentsCrossoverOperatorTests extends BaseDL4JTest {
@ -42,7 +42,7 @@ public class TwoParentsCrossoverOperatorTests extends BaseDL4JTest {
@Override
public CrossoverResult crossover() {
throw new NotImplementedException();
throw new NotImplementedException("Not implemented");
}
}

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.arbiter.optimize.genetic.culling;
import org.apache.commons.lang3.NotImplementedException;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome;
import org.deeplearning4j.arbiter.optimize.generator.genetic.culling.RatioCullOperator;
@ -24,7 +25,6 @@ import org.deeplearning4j.arbiter.optimize.generator.genetic.population.Populati
import org.deeplearning4j.arbiter.optimize.genetic.TestPopulationInitializer;
import org.junit.Assert;
import org.junit.Test;
import sun.reflect.generics.reflectiveObjects.NotImplementedException;
import java.util.List;
@ -46,7 +46,7 @@ public class RatioCullOperatorTests extends BaseDL4JTest {
@Override
public void cullPopulation() {
throw new NotImplementedException();
throw new NotImplementedException("Not implemented");
}
public double getCullRatio() {

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.arbiter.optimize.genetic.selection;
import org.apache.commons.lang3.NotImplementedException;
import org.apache.commons.math3.random.RandomGenerator;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.optimize.generator.genetic.ChromosomeFactory;
@ -33,7 +34,6 @@ import org.deeplearning4j.arbiter.optimize.genetic.TestPopulationInitializer;
import org.deeplearning4j.arbiter.optimize.genetic.TestRandomGenerator;
import org.junit.Assert;
import org.junit.Test;
import sun.reflect.generics.reflectiveObjects.NotImplementedException;
import static org.junit.Assert.assertArrayEquals;
@ -55,7 +55,7 @@ public class GeneticSelectionOperatorTests extends BaseDL4JTest {
@Override
public void cullPopulation() {
throw new NotImplementedException();
throw new NotImplementedException("Not implemented");
}
@Override

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.arbiter.optimize.genetic.selection;
import org.apache.commons.lang3.NotImplementedException;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.optimize.generator.genetic.ChromosomeFactory;
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationInitializer;
@ -24,7 +25,6 @@ import org.deeplearning4j.arbiter.optimize.generator.genetic.selection.Selection
import org.deeplearning4j.arbiter.optimize.genetic.TestPopulationInitializer;
import org.junit.Assert;
import org.junit.Test;
import sun.reflect.generics.reflectiveObjects.NotImplementedException;
public class SelectionOperatorTests extends BaseDL4JTest {
private class TestSelectionOperator extends SelectionOperator {
@ -39,7 +39,7 @@ public class SelectionOperatorTests extends BaseDL4JTest {
@Override
public double[] buildNextGenes() {
throw new NotImplementedException();
throw new NotImplementedException("Not implemented");
}
}

View File

@ -158,7 +158,7 @@ public class TestComputationGraphSpace extends BaseDL4JTest {
}
}
System.out.println("ReLU vs. Tanh: " + reluCount + "\t" + tanhCount);
// System.out.println("ReLU vs. Tanh: " + reluCount + "\t" + tanhCount);
assertTrue(reluCount > 0);
assertTrue(tanhCount > 0);

View File

@ -162,7 +162,7 @@ public class TestGraphLocalExecution extends BaseDL4JTest {
List<ResultReference> results = runner.getResults();
assertTrue(results.size() > 0);
System.out.println("----- COMPLETE - " + results.size() + " results -----");
// System.out.println("----- COMPLETE - " + results.size() + " results -----");
}
}

View File

@ -165,7 +165,7 @@ public class TestGraphLocalExecutionGenetic extends BaseDL4JTest {
List<ResultReference> results = runner.getResults();
assertTrue(results.size() > 0);
System.out.println("----- COMPLETE - " + results.size() + " results -----");
// System.out.println("----- COMPLETE - " + results.size() + " results -----");
}
}

View File

@ -101,7 +101,7 @@ public class TestLayerSpace extends BaseDL4JTest {
double l2 = TestUtils.getL2(l);
IActivation activation = l.getActivationFn();
System.out.println(lr + "\t" + l2 + "\t" + activation);
// System.out.println(lr + "\t" + l2 + "\t" + activation);
assertTrue(lr >= 0.3 && lr <= 0.4);
assertTrue(l2 >= 0.01 && l2 <= 0.1);
@ -190,7 +190,7 @@ public class TestLayerSpace extends BaseDL4JTest {
ActivationLayer al = als.getValue(d);
IActivation activation = al.getActivationFn();
System.out.println(activation);
// System.out.println(activation);
assertTrue(containsActivationFunction(actFns, activation));
}
@ -228,7 +228,7 @@ public class TestLayerSpace extends BaseDL4JTest {
IActivation activation = el.getActivationFn();
long nOut = el.getNOut();
System.out.println(activation + "\t" + nOut);
// System.out.println(activation + "\t" + nOut);
assertTrue(containsActivationFunction(actFns, activation));
assertTrue(nOut >= 10 && nOut <= 20);
@ -295,7 +295,7 @@ public class TestLayerSpace extends BaseDL4JTest {
long nOut = el.getNOut();
double forgetGate = el.getForgetGateBiasInit();
System.out.println(activation + "\t" + nOut + "\t" + forgetGate);
// System.out.println(activation + "\t" + nOut + "\t" + forgetGate);
assertTrue(containsActivationFunction(actFns, activation));
assertTrue(nOut >= 10 && nOut <= 20);

View File

@ -293,8 +293,8 @@ public class TestMultiLayerSpace extends BaseDL4JTest {
assertTrue(nLayerCounts[i] >= 5); //Expect approx equal (50/3 each), allowing for some random variation
}
System.out.println("Number of layers: " + Arrays.toString(nLayerCounts));
System.out.println("ReLU vs. Tanh: " + reluCount + "\t" + tanhCount);
// System.out.println("Number of layers: " + Arrays.toString(nLayerCounts));
// System.out.println("ReLU vs. Tanh: " + reluCount + "\t" + tanhCount);
}

View File

@ -98,7 +98,8 @@ public class ArbiterCLIRunnerTest extends BaseDL4JTest {
assertEquals(configuration,OptimizationConfiguration.fromJson(configuration.toJson()));
FileUtils.writeStringToFile(new File(configPath),configuration.toJson());
System.out.println(configuration.toJson());
// System.out.println(configuration.toJson());
configuration.toJson();
log.info("Starting test");
cliRunner.runMain(

View File

@ -0,0 +1,390 @@
/*
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
*/
package org.deeplearning4j.regressiontest;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.TestUtils;
import org.deeplearning4j.nn.conf.BackpropType;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.graph.LayerVertex;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.layers.convolutional.Cropping2D;
import org.deeplearning4j.nn.conf.layers.recurrent.Bidirectional;
import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn;
import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.graph.vertex.impl.MergeVertex;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInitXavier;
import org.deeplearning4j.regressiontest.customlayer100a.CustomLayer;
import org.junit.Test;
import org.nd4j.linalg.activations.impl.*;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.RmsProp;
import org.nd4j.linalg.learning.regularization.L2Regularization;
import org.nd4j.linalg.lossfunctions.impl.LossMAE;
import org.nd4j.linalg.lossfunctions.impl.LossMCXENT;
import org.nd4j.resources.Resources;
import java.io.DataInputStream;
import java.io.File;
import java.io.FileInputStream;
import static org.junit.Assert.*;
public class RegressionTest100b6 extends BaseDL4JTest {
@Override
public DataType getDataType() {
return DataType.FLOAT;
}
@Test
public void testCustomLayer() throws Exception {
for (DataType dtype : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) {
String dtypeName = dtype.toString().toLowerCase();
File f = Resources.asFile("regression_testing/100b6/CustomLayerExample_100b6_" + dtypeName + ".bin");
MultiLayerNetwork.load(f, true);
MultiLayerNetwork net = MultiLayerNetwork.load(f, true);
// net = net.clone();
DenseLayer l0 = (DenseLayer) net.getLayer(0).conf().getLayer();
assertEquals(new ActivationTanH(), l0.getActivationFn());
assertEquals(new L2Regularization(0.03), TestUtils.getL2Reg(l0));
assertEquals(new RmsProp(0.95), l0.getIUpdater());
CustomLayer l1 = (CustomLayer) net.getLayer(1).conf().getLayer();
assertEquals(new ActivationTanH(), l1.getActivationFn());
assertEquals(new ActivationSigmoid(), l1.getSecondActivationFunction());
assertEquals(new RmsProp(0.95), l1.getIUpdater());
INDArray outExp;
File f2 = Resources
.asFile("regression_testing/100b6/CustomLayerExample_Output_100b6_" + dtypeName + ".bin");
try (DataInputStream dis = new DataInputStream(new FileInputStream(f2))) {
outExp = Nd4j.read(dis);
}
INDArray in;
File f3 = Resources.asFile("regression_testing/100b6/CustomLayerExample_Input_100b6_" + dtypeName + ".bin");
try (DataInputStream dis = new DataInputStream(new FileInputStream(f3))) {
in = Nd4j.read(dis);
}
assertEquals(dtype, in.dataType());
assertEquals(dtype, outExp.dataType());
assertEquals(dtype, net.params().dataType());
assertEquals(dtype, net.getFlattenedGradients().dataType());
assertEquals(dtype, net.getUpdater().getStateViewArray().dataType());
//System.out.println(Arrays.toString(net.params().data().asFloat()));
INDArray outAct = net.output(in);
assertEquals(dtype, outAct.dataType());
assertEquals(dtype, net.getLayerWiseConfigurations().getDataType());
assertEquals(dtype, net.params().dataType());
boolean eq = outExp.equalsWithEps(outAct, 0.01);
assertTrue(outExp + " vs " + outAct, eq);
}
}
@Test
public void testLSTM() throws Exception {
File f = Resources.asFile("regression_testing/100b6/GravesLSTMCharModelingExample_100b6.bin");
MultiLayerNetwork net = MultiLayerNetwork.load(f, true);
LSTM l0 = (LSTM) net.getLayer(0).conf().getLayer();
assertEquals(new ActivationTanH(), l0.getActivationFn());
assertEquals(200, l0.getNOut());
assertEquals(new WeightInitXavier(), l0.getWeightInitFn());
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l0));
assertEquals(new Adam(0.005), l0.getIUpdater());
LSTM l1 = (LSTM) net.getLayer(1).conf().getLayer();
assertEquals(new ActivationTanH(), l1.getActivationFn());
assertEquals(200, l1.getNOut());
assertEquals(new WeightInitXavier(), l1.getWeightInitFn());
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l1));
assertEquals(new Adam(0.005), l1.getIUpdater());
RnnOutputLayer l2 = (RnnOutputLayer) net.getLayer(2).conf().getLayer();
assertEquals(new ActivationSoftmax(), l2.getActivationFn());
assertEquals(77, l2.getNOut());
assertEquals(new WeightInitXavier(), l2.getWeightInitFn());
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l2));
assertEquals(new Adam(0.005), l2.getIUpdater());
assertEquals(BackpropType.TruncatedBPTT, net.getLayerWiseConfigurations().getBackpropType());
assertEquals(50, net.getLayerWiseConfigurations().getTbpttBackLength());
assertEquals(50, net.getLayerWiseConfigurations().getTbpttFwdLength());
INDArray outExp;
File f2 = Resources.asFile("regression_testing/100b6/GravesLSTMCharModelingExample_Output_100b6.bin");
try (DataInputStream dis = new DataInputStream(new FileInputStream(f2))) {
outExp = Nd4j.read(dis);
}
INDArray in;
File f3 = Resources.asFile("regression_testing/100b6/GravesLSTMCharModelingExample_Input_100b6.bin");
try (DataInputStream dis = new DataInputStream(new FileInputStream(f3))) {
in = Nd4j.read(dis);
}
INDArray outAct = net.output(in);
assertEquals(outExp, outAct);
}
@Test
public void testVae() throws Exception {
File f = Resources.asFile("regression_testing/100b6/VaeMNISTAnomaly_100b6.bin");
MultiLayerNetwork net = MultiLayerNetwork.load(f, true);
VariationalAutoencoder l0 = (VariationalAutoencoder) net.getLayer(0).conf().getLayer();
assertEquals(new ActivationLReLU(), l0.getActivationFn());
assertEquals(32, l0.getNOut());
assertArrayEquals(new int[]{256, 256}, l0.getEncoderLayerSizes());
assertArrayEquals(new int[]{256, 256}, l0.getDecoderLayerSizes());
assertEquals(new WeightInitXavier(), l0.getWeightInitFn());
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l0));
assertEquals(new Adam(1e-3), l0.getIUpdater());
INDArray outExp;
File f2 = Resources.asFile("regression_testing/100b6/VaeMNISTAnomaly_Output_100b6.bin");
try (DataInputStream dis = new DataInputStream(new FileInputStream(f2))) {
outExp = Nd4j.read(dis);
}
INDArray in;
File f3 = Resources.asFile("regression_testing/100b6/VaeMNISTAnomaly_Input_100b6.bin");
try (DataInputStream dis = new DataInputStream(new FileInputStream(f3))) {
in = Nd4j.read(dis);
}
INDArray outAct = net.output(in);
assertEquals(outExp, outAct);
}
@Test
public void testYoloHouseNumber() throws Exception {
File f = Resources.asFile("regression_testing/100b6/HouseNumberDetection_100b6.bin");
ComputationGraph net = ComputationGraph.load(f, true);
int nBoxes = 5;
int nClasses = 10;
ConvolutionLayer cl = (ConvolutionLayer) ((LayerVertex) net.getConfiguration().getVertices()
.get("convolution2d_9")).getLayerConf().getLayer();
assertEquals(nBoxes * (5 + nClasses), cl.getNOut());
assertEquals(new ActivationIdentity(), cl.getActivationFn());
assertEquals(ConvolutionMode.Same, cl.getConvolutionMode());
assertEquals(new WeightInitXavier(), cl.getWeightInitFn());
assertArrayEquals(new int[]{1, 1}, cl.getKernelSize());
INDArray outExp;
File f2 = Resources.asFile("regression_testing/100b6/HouseNumberDetection_Output_100b6.bin");
try (DataInputStream dis = new DataInputStream(new FileInputStream(f2))) {
outExp = Nd4j.read(dis);
}
INDArray in;
File f3 = Resources.asFile("regression_testing/100b6/HouseNumberDetection_Input_100b6.bin");
try (DataInputStream dis = new DataInputStream(new FileInputStream(f3))) {
in = Nd4j.read(dis);
}
INDArray outAct = net.outputSingle(in);
boolean eq = outExp.equalsWithEps(outAct.castTo(outExp.dataType()), 1e-3);
assertTrue(eq);
}
@Test
public void testSyntheticCNN() throws Exception {
File f = Resources.asFile("regression_testing/100b6/SyntheticCNN_100b6.bin");
MultiLayerNetwork net = MultiLayerNetwork.load(f, true);
ConvolutionLayer l0 = (ConvolutionLayer) net.getLayer(0).conf().getLayer();
assertEquals(new ActivationReLU(), l0.getActivationFn());
assertEquals(4, l0.getNOut());
assertEquals(new WeightInitXavier(), l0.getWeightInitFn());
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l0));
assertEquals(new Adam(0.005), l0.getIUpdater());
assertArrayEquals(new int[]{3, 3}, l0.getKernelSize());
assertArrayEquals(new int[]{2, 1}, l0.getStride());
assertArrayEquals(new int[]{1, 1}, l0.getDilation());
assertArrayEquals(new int[]{0, 0}, l0.getPadding());
SeparableConvolution2D l1 = (SeparableConvolution2D) net.getLayer(1).conf().getLayer();
assertEquals(new ActivationReLU(), l1.getActivationFn());
assertEquals(8, l1.getNOut());
assertEquals(new WeightInitXavier(), l1.getWeightInitFn());
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l1));
assertEquals(new Adam(0.005), l1.getIUpdater());
assertArrayEquals(new int[]{3, 3}, l1.getKernelSize());
assertArrayEquals(new int[]{1, 1}, l1.getStride());
assertArrayEquals(new int[]{1, 1}, l1.getDilation());
assertArrayEquals(new int[]{0, 0}, l1.getPadding());
assertEquals(ConvolutionMode.Same, l1.getConvolutionMode());
assertEquals(1, l1.getDepthMultiplier());
SubsamplingLayer l2 = (SubsamplingLayer) net.getLayer(2).conf().getLayer();
assertArrayEquals(new int[]{3, 3}, l2.getKernelSize());
assertArrayEquals(new int[]{2, 2}, l2.getStride());
assertArrayEquals(new int[]{1, 1}, l2.getDilation());
assertArrayEquals(new int[]{0, 0}, l2.getPadding());
assertEquals(PoolingType.MAX, l2.getPoolingType());
ZeroPaddingLayer l3 = (ZeroPaddingLayer) net.getLayer(3).conf().getLayer();
assertArrayEquals(new int[]{4, 4, 4, 4}, l3.getPadding());
Upsampling2D l4 = (Upsampling2D) net.getLayer(4).conf().getLayer();
assertArrayEquals(new int[]{3, 3}, l4.getSize());
DepthwiseConvolution2D l5 = (DepthwiseConvolution2D) net.getLayer(5).conf().getLayer();
assertEquals(new ActivationReLU(), l5.getActivationFn());
assertEquals(16, l5.getNOut());
assertEquals(new WeightInitXavier(), l5.getWeightInitFn());
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l5));
assertEquals(new Adam(0.005), l5.getIUpdater());
assertArrayEquals(new int[]{3, 3}, l5.getKernelSize());
assertArrayEquals(new int[]{1, 1}, l5.getStride());
assertArrayEquals(new int[]{1, 1}, l5.getDilation());
assertArrayEquals(new int[]{0, 0}, l5.getPadding());
assertEquals(2, l5.getDepthMultiplier());
SubsamplingLayer l6 = (SubsamplingLayer) net.getLayer(6).conf().getLayer();
assertArrayEquals(new int[]{2, 2}, l6.getKernelSize());
assertArrayEquals(new int[]{2, 2}, l6.getStride());
assertArrayEquals(new int[]{1, 1}, l6.getDilation());
assertArrayEquals(new int[]{0, 0}, l6.getPadding());
assertEquals(PoolingType.MAX, l6.getPoolingType());
Cropping2D l7 = (Cropping2D) net.getLayer(7).conf().getLayer();
assertArrayEquals(new int[]{3, 3, 2, 2}, l7.getCropping());
ConvolutionLayer l8 = (ConvolutionLayer) net.getLayer(8).conf().getLayer();
assertEquals(4, l8.getNOut());
assertEquals(new WeightInitXavier(), l8.getWeightInitFn());
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l8));
assertEquals(new Adam(0.005), l8.getIUpdater());
assertArrayEquals(new int[]{4, 4}, l8.getKernelSize());
assertArrayEquals(new int[]{1, 1}, l8.getStride());
assertArrayEquals(new int[]{1, 1}, l8.getDilation());
assertArrayEquals(new int[]{0, 0}, l8.getPadding());
CnnLossLayer l9 = (CnnLossLayer) net.getLayer(9).conf().getLayer();
assertEquals(new WeightInitXavier(), l9.getWeightInitFn());
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l9));
assertEquals(new Adam(0.005), l9.getIUpdater());
assertEquals(new LossMAE(), l9.getLossFn());
INDArray outExp;
File f2 = Resources.asFile("regression_testing/100b6/SyntheticCNN_Output_100b6.bin");
try (DataInputStream dis = new DataInputStream(new FileInputStream(f2))) {
outExp = Nd4j.read(dis);
}
INDArray in;
File f3 = Resources.asFile("regression_testing/100b6/SyntheticCNN_Input_100b6.bin");
try (DataInputStream dis = new DataInputStream(new FileInputStream(f3))) {
in = Nd4j.read(dis);
}
INDArray outAct = net.output(in);
//19 layers - CPU vs. GPU difference accumulates notably, but appears to be correct
if(Nd4j.getBackend().getClass().getName().toLowerCase().contains("native")){
assertEquals(outExp, outAct);
} else {
boolean eq = outExp.equalsWithEps(outAct, 0.1);
assertTrue(eq);
}
}
@Test
public void testSyntheticBidirectionalRNNGraph() throws Exception {
File f = Resources.asFile("regression_testing/100b6/SyntheticBidirectionalRNNGraph_100b6.bin");
ComputationGraph net = ComputationGraph.load(f, true);
Bidirectional l0 = (Bidirectional) net.getLayer("rnn1").conf().getLayer();
LSTM l1 = (LSTM) l0.getFwd();
assertEquals(16, l1.getNOut());
assertEquals(new ActivationReLU(), l1.getActivationFn());
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l1));
LSTM l2 = (LSTM) l0.getBwd();
assertEquals(16, l2.getNOut());
assertEquals(new ActivationReLU(), l2.getActivationFn());
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l2));
Bidirectional l3 = (Bidirectional) net.getLayer("rnn2").conf().getLayer();
SimpleRnn l4 = (SimpleRnn) l3.getFwd();
assertEquals(16, l4.getNOut());
assertEquals(new ActivationReLU(), l4.getActivationFn());
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l4));
SimpleRnn l5 = (SimpleRnn) l3.getBwd();
assertEquals(16, l5.getNOut());
assertEquals(new ActivationReLU(), l5.getActivationFn());
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l5));
MergeVertex mv = (MergeVertex) net.getVertex("concat");
GlobalPoolingLayer gpl = (GlobalPoolingLayer) net.getLayer("pooling").conf().getLayer();
assertEquals(PoolingType.MAX, gpl.getPoolingType());
assertArrayEquals(new int[]{2}, gpl.getPoolingDimensions());
assertTrue(gpl.isCollapseDimensions());
OutputLayer outl = (OutputLayer) net.getLayer("out").conf().getLayer();
assertEquals(3, outl.getNOut());
assertEquals(new LossMCXENT(), outl.getLossFn());
INDArray outExp;
File f2 = Resources.asFile("regression_testing/100b6/SyntheticBidirectionalRNNGraph_Output_100b6.bin");
try (DataInputStream dis = new DataInputStream(new FileInputStream(f2))) {
outExp = Nd4j.read(dis);
}
INDArray in;
File f3 = Resources.asFile("regression_testing/100b6/SyntheticBidirectionalRNNGraph_Input_100b6.bin");
try (DataInputStream dis = new DataInputStream(new FileInputStream(f3))) {
in = Nd4j.read(dis);
}
INDArray outAct = net.output(in)[0];
assertEquals(outExp, outAct);
}
}
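
Each test in this new class follows the same pattern: load a network serialized by the older 100b6 build from the test resources, assert selected configuration fields, then run the saved input through the network and compare against the saved output (exactly, or within an epsilon where backends may differ). A condensed sketch of that load-and-compare step; the helper class, method, and resource paths are illustrative, not part of the change:

import java.io.DataInputStream;
import java.io.File;
import java.io.FileInputStream;

import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.resources.Resources;

// Illustrative helper only - resource paths are placeholders, not actual test resources
public class RegressionCheckSketch {

    public static void checkForwardPass(String netPath, String inPath, String outPath, double eps) throws Exception {
        MultiLayerNetwork net = MultiLayerNetwork.load(Resources.asFile(netPath), true);
        INDArray in = read(Resources.asFile(inPath));
        INDArray outExp = read(Resources.asFile(outPath));
        INDArray outAct = net.output(in);
        if (!outExp.equalsWithEps(outAct, eps)) {
            throw new AssertionError(outExp + " vs " + outAct);
        }
    }

    private static INDArray read(File f) throws Exception {
        try (DataInputStream dis = new DataInputStream(new FileInputStream(f))) {
            return Nd4j.read(dis);
        }
    }
}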

View File

@ -41,7 +41,7 @@ public class TestGraphLoading extends BaseDL4JTest {
IGraph<String, String> graph = GraphLoader
.loadUndirectedGraphEdgeListFile(cpr.getTempFileFromArchive().getAbsolutePath(), 7, ",");
System.out.println(graph);
// System.out.println(graph);
assertEquals(graph.numVertices(), 7);
int[][] edges = {{1, 2}, {0, 2, 4}, {0, 1, 3, 4}, {2, 4, 5}, {1, 2, 3, 5, 6}, {3, 4, 6}, {4, 5}};
@ -66,7 +66,7 @@ public class TestGraphLoading extends BaseDL4JTest {
edgeLineProcessor, vertexFactory, 10, false);
System.out.println(graph);
// System.out.println(graph);
for (int i = 0; i < 10; i++) {
List<Edge<String>> edges = graph.getEdgesOut(i);
@ -111,7 +111,7 @@ public class TestGraphLoading extends BaseDL4JTest {
Graph<String, String> graph = GraphLoader.loadGraph(verticesCPR.getTempFileFromArchive().getAbsolutePath(),
edgesCPR.getTempFileFromArchive().getAbsolutePath(), vertexLoader, edgeLineProcessor, false);
System.out.println(graph);
// System.out.println(graph);
for (int i = 0; i < 10; i++) {
List<Edge<String>> edges = graph.getEdgesOut(i);

View File

@ -71,7 +71,7 @@ public class TestGraphLoadingWeighted extends BaseDL4JTest {
}
}
System.out.println(graph);
// System.out.println(graph);
}

View File

@ -220,7 +220,7 @@ public class TestGraph extends BaseDL4JTest {
sum += transitionProb[i][j];
for (int j = 0; j < transitionProb[i].length; j++)
transitionProb[i][j] /= sum;
System.out.println(Arrays.toString(transitionProb[i]));
// System.out.println(Arrays.toString(transitionProb[i]));
}
//Check that transition probs are essentially correct (within bounds of random variation)

View File

@ -145,8 +145,8 @@ public class DeepWalkGradientCheck extends BaseDL4JTest {
if (relError > MAX_REL_ERROR && absErr > MIN_ABS_ERROR)
fail(msg);
else
System.out.println(msg);
// else
// System.out.println(msg);
}
}
@ -333,10 +333,10 @@ public class DeepWalkGradientCheck extends BaseDL4JTest {
if (relError > MAX_REL_ERROR && absErr > MIN_ABS_ERROR)
fail(msg);
else
System.out.println(msg);
// else
// System.out.println(msg);
}
System.out.println();
// System.out.println();
}
}

View File

@ -67,7 +67,7 @@ public class TestDeepWalk extends BaseDL4JTest {
for (int i = 0; i < 7; i++) {
INDArray vector = deepWalk.getVertexVector(i);
assertArrayEquals(new long[] {vectorSize}, vector.shape());
System.out.println(Arrays.toString(vector.dup().data().asFloat()));
// System.out.println(Arrays.toString(vector.dup().data().asFloat()));
}
GraphWalkIterator<String> iter = new RandomWalkIterator<>(graph, 8);
@ -77,11 +77,11 @@ public class TestDeepWalk extends BaseDL4JTest {
for (int t = 0; t < 5; t++) {
iter.reset();
deepWalk.fit(iter);
System.out.println("--------------------");
// System.out.println("--------------------");
for (int i = 0; i < 7; i++) {
INDArray vector = deepWalk.getVertexVector(i);
assertArrayEquals(new long[] {vectorSize}, vector.shape());
System.out.println(Arrays.toString(vector.dup().data().asFloat()));
// System.out.println(Arrays.toString(vector.dup().data().asFloat()));
}
}
}
@ -160,7 +160,7 @@ public class TestDeepWalk extends BaseDL4JTest {
continue;
double sim = deepWalk.similarity(i, nearestTo);
System.out.println(i + "\t" + nearestTo + "\t" + sim);
// System.out.println(i + "\t" + nearestTo + "\t" + sim);
assertTrue(sim <= minSimNearest);
}
}
@ -211,7 +211,7 @@ public class TestDeepWalk extends BaseDL4JTest {
Graph<String, String> graph = GraphLoader
.loadUndirectedGraphEdgeListFile(cpr.getTempFileFromArchive().getAbsolutePath(), 13, ",");
System.out.println(graph);
// System.out.println(graph);
Nd4j.getRandom().setSeed(12345);
@ -229,11 +229,13 @@ public class TestDeepWalk extends BaseDL4JTest {
//Calculate similarity(0,i)
for (int i = 0; i < nVertices; i++) {
System.out.println(deepWalk.similarity(0, i));
// System.out.println(deepWalk.similarity(0, i));
deepWalk.similarity(0, i);
}
for (int i = 0; i < nVertices; i++)
System.out.println(deepWalk.getVertexVector(i));
// System.out.println(deepWalk.getVertexVector(i));
deepWalk.getVertexVector(i);
}
@Test(timeout = 60000L)

View File

@ -38,9 +38,11 @@ public class TestGraphHuffman extends BaseDL4JTest {
gh.buildTree(vertexDegrees);
for (int i = 0; i < 7; i++)
System.out.println(i + "\t" + gh.getCodeLength(i) + "\t" + gh.getCodeString(i) + "\t\t" + gh.getCode(i)
+ "\t\t" + Arrays.toString(gh.getPathInnerNodes(i)));
for (int i = 0; i < 7; i++) {
String s = i + "\t" + gh.getCodeLength(i) + "\t" + gh.getCodeString(i) + "\t\t" + gh.getCode(i)
+ "\t\t" + Arrays.toString(gh.getPathInnerNodes(i));
// System.out.println(s);
}
int[] expectedLengths = {3, 2, 2, 5, 4, 2, 5};
for (int i = 0; i < vertexDegrees.length; i++) {

View File

@ -3,6 +3,7 @@ package org.deeplearning4j.util;
import lombok.NonNull;
import org.apache.commons.io.IOUtils;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
@ -121,7 +122,7 @@ public class DL4JModelValidator {
}
try{
MultiLayerConfiguration.fromJson(config);
ComputationGraphConfiguration.fromJson(config);
} catch (Throwable t){
return ValidationResult.builder()
.formatType("ComputationGraph")
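
The hunk above is a bug fix: the ComputationGraph branch of DL4JModelValidator previously round-tripped the configuration JSON through MultiLayerConfiguration.fromJson instead of ComputationGraphConfiguration.fromJson. A minimal sketch of the corrected check, reduced to a boolean and wrapped in a hypothetical class; the real method builds a ValidationResult as shown above:

import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;

// Illustrative wrapper only - the real validator returns a ValidationResult
public class GraphJsonCheckSketch {

    public static boolean isValidComputationGraphJson(String config) {
        try {
            ComputationGraphConfiguration.fromJson(config);   // previously MultiLayerConfiguration.fromJson(config)
            return true;
        } catch (Throwable t) {
            return false;
        }
    }
}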

View File

@ -79,8 +79,9 @@ public class ParameterServerParallelWrapperTest extends BaseDL4JTest {
model.init();
ParallelWrapper parameterServerParallelWrapper =
new ParallelWrapper.Builder(model).trainerFactory(new ParameterServerTrainerContext())
.workers(Runtime.getRuntime().availableProcessors())
new ParallelWrapper.Builder(model)
.workers(Math.min(4, Runtime.getRuntime().availableProcessors()))
.trainerFactory(new ParameterServerTrainerContext())
.reportScoreAfterAveraging(true).prefetchBuffer(3).build();
parameterServerParallelWrapper.fit(mnistTrain);

View File

@ -104,7 +104,7 @@ public class SparkWord2VecTest extends BaseDL4JTest {
public void call(ExportContainer<VocabWord> v) throws Exception {
assertNotNull(v.getElement());
assertNotNull(v.getArray());
System.out.println(v.getElement() + " - " + v.getArray());
// System.out.println(v.getElement() + " - " + v.getArray());
}
}
}

View File

@ -66,7 +66,7 @@ public class TestEarlyStoppingSpark extends BaseSparkTest {
.lossFunction(LossFunctions.LossFunction.MCXENT).build())
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.setListeners(new ScoreIterationListener(1));
net.setListeners(new ScoreIterationListener(5));
JavaRDD<DataSet> irisData = getIris();
@ -119,7 +119,7 @@ public class TestEarlyStoppingSpark extends BaseSparkTest {
.lossFunction(LossFunctions.LossFunction.MSE).build())
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.setListeners(new ScoreIterationListener(1));
net.setListeners(new ScoreIterationListener(5));
JavaRDD<DataSet> irisData = getIris();
EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();
@ -155,7 +155,7 @@ public class TestEarlyStoppingSpark extends BaseSparkTest {
.lossFunction(LossFunctions.LossFunction.MCXENT).build())
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.setListeners(new ScoreIterationListener(1));
net.setListeners(new ScoreIterationListener(5));
JavaRDD<DataSet> irisData = getIris();
@ -198,7 +198,7 @@ public class TestEarlyStoppingSpark extends BaseSparkTest {
.lossFunction(LossFunctions.LossFunction.MCXENT).build())
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.setListeners(new ScoreIterationListener(1));
net.setListeners(new ScoreIterationListener(5));
JavaRDD<DataSet> irisData = getIris();
@ -231,7 +231,7 @@ public class TestEarlyStoppingSpark extends BaseSparkTest {
.lossFunction(LossFunctions.LossFunction.MCXENT).build())
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.setListeners(new ScoreIterationListener(1));
net.setListeners(new ScoreIterationListener(5));
JavaRDD<DataSet> irisData = getIris();

View File

@ -69,7 +69,7 @@ public class TestEarlyStoppingSparkCompGraph extends BaseSparkTest {
.lossFunction(LossFunctions.LossFunction.MCXENT).build(), "in")
.setOutputs("0").build();
ComputationGraph net = new ComputationGraph(conf);
net.setListeners(new ScoreIterationListener(1));
net.setListeners(new ScoreIterationListener(5));
JavaRDD<DataSet> irisData = getIris();
@ -120,7 +120,7 @@ public class TestEarlyStoppingSparkCompGraph extends BaseSparkTest {
.lossFunction(LossFunctions.LossFunction.MSE).build(), "in")
.setOutputs("0").build();
ComputationGraph net = new ComputationGraph(conf);
net.setListeners(new ScoreIterationListener(1));
net.setListeners(new ScoreIterationListener(5));
JavaRDD<DataSet> irisData = getIris();
EarlyStoppingModelSaver<ComputationGraph> saver = new InMemoryModelSaver<>();
@ -158,7 +158,7 @@ public class TestEarlyStoppingSparkCompGraph extends BaseSparkTest {
.lossFunction(LossFunctions.LossFunction.MCXENT).build(), "in")
.setOutputs("0").build();
ComputationGraph net = new ComputationGraph(conf);
net.setListeners(new ScoreIterationListener(1));
net.setListeners(new ScoreIterationListener(5));
JavaRDD<DataSet> irisData = getIris();
@ -203,7 +203,7 @@ public class TestEarlyStoppingSparkCompGraph extends BaseSparkTest {
.lossFunction(LossFunctions.LossFunction.MCXENT).build(), "in")
.setOutputs("0").build();
ComputationGraph net = new ComputationGraph(conf);
net.setListeners(new ScoreIterationListener(1));
net.setListeners(new ScoreIterationListener(5));
JavaRDD<DataSet> irisData = getIris();
@ -238,7 +238,7 @@ public class TestEarlyStoppingSparkCompGraph extends BaseSparkTest {
.lossFunction(LossFunctions.LossFunction.MCXENT).build(), "in")
.setOutputs("0").build();
ComputationGraph net = new ComputationGraph(conf);
net.setListeners(new ScoreIterationListener(1));
net.setListeners(new ScoreIterationListener(5));
JavaRDD<DataSet> irisData = getIris();

View File

@ -59,7 +59,7 @@ public class TestShuffleExamples extends BaseSparkTest {
int totalExampleCount = 0;
for (DataSet ds : shuffledList) {
totalExampleCount += ds.getFeatures().length();
System.out.println(Arrays.toString(ds.getFeatures().data().asFloat()));
// System.out.println(Arrays.toString(ds.getFeatures().data().asFloat()));
assertEquals(ds.getFeatures(), ds.getLabels());
}

View File

@ -86,7 +86,7 @@ public class TestExport extends BaseSparkTest {
for (File file : files) {
if (!file.getPath().endsWith(".bin"))
continue;
System.out.println(file);
// System.out.println(file);
DataSet ds = new DataSet();
ds.load(file);
assertEquals(minibatchSize, ds.numExamples());
@ -144,7 +144,7 @@ public class TestExport extends BaseSparkTest {
for (File file : files) {
if (!file.getPath().endsWith(".bin"))
continue;
System.out.println(file);
// System.out.println(file);
MultiDataSet ds = new org.nd4j.linalg.dataset.MultiDataSet();
ds.load(file);
assertEquals(minibatchSize, ds.getFeatures(0).size(0));

View File

@ -92,9 +92,9 @@ public class HashingBalancedPartitionerTest extends BaseSparkTest {
int[][] colorCountsByPartition = new int[3][2];
for (final Tuple2<Tuple2<Long, Integer>, String> val : testList) {
System.out.println(val);
// System.out.println(val);
Integer partition = hbp.getPartition(val._1());
System.out.println(partition);
// System.out.println(partition);
if (val._2().equals("red"))
colorCountsByPartition[partition][0] += 1;
@ -102,9 +102,9 @@ public class HashingBalancedPartitionerTest extends BaseSparkTest {
colorCountsByPartition[partition][1] += 1;
}
for (int i = 0; i < 3; i++) {
System.out.println(Arrays.toString(colorCountsByPartition[i]));
}
// for (int i = 0; i < 3; i++) {
// System.out.println(Arrays.toString(colorCountsByPartition[i]));
// }
for (int i = 0; i < 3; i++) {
// avg red per partition : 2.33
assertTrue(colorCountsByPartition[i][0] >= 1 && colorCountsByPartition[i][0] < 4);
@ -178,12 +178,12 @@ public class HashingBalancedPartitionerTest extends BaseSparkTest {
colorCountsByPartition[partition][1] += 1;
}
for (int i = 0; i < numPartitions; i++) {
System.out.println(Arrays.toString(colorCountsByPartition[i]));
}
System.out.println("Ideal red # per partition: " + avgRed);
System.out.println("Ideal blue # per partition: " + avgBlue);
// for (int i = 0; i < numPartitions; i++) {
// System.out.println(Arrays.toString(colorCountsByPartition[i]));
// }
//
// System.out.println("Ideal red # per partition: " + avgRed);
// System.out.println("Ideal blue # per partition: " + avgBlue);
for (int i = 0; i < numPartitions; i++) {
// avg red per partition : 2.33

View File

@ -115,7 +115,7 @@ public class TestSparkComputationGraph extends BaseSparkTest {
TrainingMaster tm = new ParameterAveragingTrainingMaster(true, numExecutors(), 1, 10, 1, 0);
SparkComputationGraph scg = new SparkComputationGraph(sc, cg, tm);
scg.setListeners(Collections.singleton((TrainingListener) new ScoreIterationListener(1)));
scg.setListeners(Collections.singleton((TrainingListener) new ScoreIterationListener(5)));
JavaRDD<MultiDataSet> rdd = sc.parallelize(list);
scg.fitMultiDataSet(rdd);

View File

@ -31,8 +31,11 @@ import org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster;
import org.junit.Test;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions;
@ -45,8 +48,24 @@ import static org.junit.Assert.assertTrue;
@Slf4j
public class TestSparkDl4jMultiLayer extends BaseSparkTest {
@Test(timeout = 120000L)
@Override
public long getTimeoutMilliseconds() {
return 120000L;
}
@Override
public DataType getDataType() {
return DataType.FLOAT;
}
@Override
public DataType getDefaultFPDataType() {
return DataType.FLOAT;
}
@Test
public void testEvaluationSimple() throws Exception {
Nd4j.getRandom().setSeed(12345);
for( int evalWorkers : new int[]{1, 4, 8}) {
//Simple test to validate DL4J issue 4099 is fixed...
@ -75,18 +94,18 @@ public class TestSparkDl4jMultiLayer extends BaseSparkTest {
//----------------------------------
//Create network configuration and conduct network training
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.dataType(DataType.FLOAT)
.seed(12345)
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.activation(Activation.LEAKYRELU)
.weightInit(WeightInit.XAVIER)
.updater(new Nesterovs(0.02, 0.9))
.l2(1e-4)
.updater(new Adam(1e-3))
.l2(1e-5)
.list()
.layer(0, new DenseLayer.Builder().nIn(28 * 28).nOut(500).build())
.layer(1, new DenseLayer.Builder().nIn(500).nOut(100).build())
.layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
.activation(Activation.SOFTMAX).nIn(100).nOut(10).build())
.build();
//Configuration for Spark training: see https://deeplearning4j.org/docs/latest/deeplearning4j-scaleout-howto for explanation of these configuration options

View File

@ -333,15 +333,16 @@ public class TestCompareParameterAveragingSparkVsSingleMachine {
sparkNet.fit(rdd);
}
System.out.println(sparkNet.getSparkTrainingStats().statsAsString());
// System.out.println(sparkNet.getSparkTrainingStats().statsAsString());
sparkNet.getSparkTrainingStats().statsAsString();
INDArray finalSparkParams = sparkNet.getNetwork().params().dup();
System.out.println("Initial (Local) params: " + Arrays.toString(initialParams.data().asFloat()));
System.out.println("Initial (Spark) params: "
+ Arrays.toString(initialSparkParams.data().asFloat()));
System.out.println("Final (Local) params: " + Arrays.toString(finalParams.data().asFloat()));
System.out.println("Final (Spark) params: " + Arrays.toString(finalSparkParams.data().asFloat()));
// System.out.println("Initial (Local) params: " + Arrays.toString(initialParams.data().asFloat()));
// System.out.println("Initial (Spark) params: "
// + Arrays.toString(initialSparkParams.data().asFloat()));
// System.out.println("Final (Local) params: " + Arrays.toString(finalParams.data().asFloat()));
// System.out.println("Final (Spark) params: " + Arrays.toString(finalSparkParams.data().asFloat()));
assertEquals(initialParams, initialSparkParams);
assertNotEquals(initialParams, finalParams);
assertEquals(finalParams, finalSparkParams);
@ -405,15 +406,16 @@ public class TestCompareParameterAveragingSparkVsSingleMachine {
sparkNet.fit(rdd);
}
System.out.println(sparkNet.getSparkTrainingStats().statsAsString());
// System.out.println(sparkNet.getSparkTrainingStats().statsAsString());
sparkNet.getSparkTrainingStats().statsAsString();
INDArray finalSparkParams = sparkNet.getNetwork().params().dup();
System.out.println("Initial (Local) params: " + Arrays.toString(initialParams.data().asFloat()));
System.out.println("Initial (Spark) params: "
+ Arrays.toString(initialSparkParams.data().asFloat()));
System.out.println("Final (Local) params: " + Arrays.toString(finalParams.data().asFloat()));
System.out.println("Final (Spark) params: " + Arrays.toString(finalSparkParams.data().asFloat()));
// System.out.println("Initial (Local) params: " + Arrays.toString(initialParams.data().asFloat()));
// System.out.println("Initial (Spark) params: "
// + Arrays.toString(initialSparkParams.data().asFloat()));
// System.out.println("Final (Local) params: " + Arrays.toString(finalParams.data().asFloat()));
// System.out.println("Final (Spark) params: " + Arrays.toString(finalSparkParams.data().asFloat()));
assertArrayEquals(initialParams.data().asFloat(), initialSparkParams.data().asFloat(), 1e-8f);
assertArrayEquals(finalParams.data().asFloat(), finalSparkParams.data().asFloat(), 1e-6f);
@ -478,18 +480,19 @@ public class TestCompareParameterAveragingSparkVsSingleMachine {
sparkNet.fit(rdd);
}
System.out.println(sparkNet.getSparkTrainingStats().statsAsString());
// System.out.println(sparkNet.getSparkTrainingStats().statsAsString());
sparkNet.getSparkTrainingStats().statsAsString();
INDArray finalSparkParams = sparkNet.getNetwork().params().dup();
// executioner.addToWatchdog(finalSparkParams, "finalSparkParams");
float[] fp = finalParams.data().asFloat();
float[] fps = finalSparkParams.data().asFloat();
System.out.println("Initial (Local) params: " + Arrays.toString(initialParams.data().asFloat()));
System.out.println("Initial (Spark) params: "
+ Arrays.toString(initialSparkParams.data().asFloat()));
System.out.println("Final (Local) params: " + Arrays.toString(fp));
System.out.println("Final (Spark) params: " + Arrays.toString(fps));
// System.out.println("Initial (Local) params: " + Arrays.toString(initialParams.data().asFloat()));
// System.out.println("Initial (Spark) params: "
// + Arrays.toString(initialSparkParams.data().asFloat()));
// System.out.println("Final (Local) params: " + Arrays.toString(fp));
// System.out.println("Final (Spark) params: " + Arrays.toString(fps));
assertEquals(initialParams, initialSparkParams);
assertNotEquals(initialParams, finalParams);
@ -551,14 +554,15 @@ public class TestCompareParameterAveragingSparkVsSingleMachine {
sparkNet.fit(rdd);
}
System.out.println(sparkNet.getSparkTrainingStats().statsAsString());
// System.out.println(sparkNet.getSparkTrainingStats().statsAsString());
sparkNet.getSparkTrainingStats().statsAsString();
INDArray finalSparkParams = sparkNet.getNetwork().params().dup();
System.out.println("Initial (Local) params: " + Arrays.toString(initialParams.data().asFloat()));
System.out.println("Initial (Spark) params: " + Arrays.toString(initialSparkParams.data().asFloat()));
System.out.println("Final (Local) params: " + Arrays.toString(finalParams.data().asFloat()));
System.out.println("Final (Spark) params: " + Arrays.toString(finalSparkParams.data().asFloat()));
// System.out.println("Initial (Local) params: " + Arrays.toString(initialParams.data().asFloat()));
// System.out.println("Initial (Spark) params: " + Arrays.toString(initialSparkParams.data().asFloat()));
// System.out.println("Final (Local) params: " + Arrays.toString(finalParams.data().asFloat()));
// System.out.println("Final (Spark) params: " + Arrays.toString(finalSparkParams.data().asFloat()));
assertArrayEquals(initialParams.data().asFloat(), initialSparkParams.data().asFloat(), 1e-8f);
assertArrayEquals(finalParams.data().asFloat(), finalSparkParams.data().asFloat(), 1e-6f);

View File

@ -37,7 +37,7 @@ public class TestJsonYaml {
String json = tm.toJson();
String yaml = tm.toYaml();
System.out.println(json);
// System.out.println(json);
TrainingMaster fromJson = ParameterAveragingTrainingMaster.fromJson(json);
TrainingMaster fromYaml = ParameterAveragingTrainingMaster.fromYaml(yaml);

View File

@ -389,7 +389,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
List<EventStats> workerFitStats = stats.getValue("ParameterAveragingWorkerFitTimesMs");
for (EventStats e : workerFitStats) {
ExampleCountEventStats eces = (ExampleCountEventStats) e;
System.out.println(eces.getTotalExampleCount());
// System.out.println(eces.getTotalExampleCount());
}
for (EventStats e : workerFitStats) {
@ -457,7 +457,8 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
assertNotEquals(paramsBefore, paramsAfter);
SparkTrainingStats stats = sparkNet.getSparkTrainingStats();
System.out.println(stats.statsAsString());
// System.out.println(stats.statsAsString());
stats.statsAsString();
sparkNet.getTrainingMaster().deleteTempFiles(sc);
}
@ -483,7 +484,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
i++;
}
System.out.println("Saved to: " + tempDirF.getAbsolutePath());
// System.out.println("Saved to: " + tempDirF.getAbsolutePath());
@ -527,7 +528,8 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
SparkTrainingStats stats = sparkNet.getSparkTrainingStats();
//Expect
System.out.println(stats.statsAsString());
// System.out.println(stats.statsAsString());
stats.statsAsString();
assertEquals(numSplits, stats.getValue("ParameterAveragingMasterRepartitionTimesMs").size());
List<EventStats> list = stats.getValue("ParameterAveragingWorkerFitTimesMs");
@ -566,8 +568,8 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
i++;
}
System.out.println("Saved to: " + tempDirF.getAbsolutePath());
System.out.println("Saved to: " + tempDirF2.getAbsolutePath());
// System.out.println("Saved to: " + tempDirF.getAbsolutePath());
// System.out.println("Saved to: " + tempDirF2.getAbsolutePath());
@ -610,7 +612,8 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
assertNotEquals(paramsBefore, paramsAfter);
SparkTrainingStats stats = sparkNet.getSparkTrainingStats();
System.out.println(stats.statsAsString());
// System.out.println(stats.statsAsString());
stats.statsAsString();
//Same thing, buf for MultiDataSet objects:
config = new Configuration();
@ -631,7 +634,8 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
assertNotEquals(paramsBefore, paramsAfter);
stats = sparkNet.getSparkTrainingStats();
System.out.println(stats.statsAsString());
// System.out.println(stats.statsAsString());
stats.statsAsString();
}
@ -730,13 +734,13 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
.build();
for (int avgFreq : new int[] {1, 5, 10}) {
System.out.println("--- Avg freq " + avgFreq + " ---");
// System.out.println("--- Avg freq " + avgFreq + " ---");
SparkDl4jMultiLayer sparkNet = new SparkDl4jMultiLayer(sc, conf.clone(),
new ParameterAveragingTrainingMaster.Builder(numExecutors(), dataSetObjSize)
.batchSizePerWorker(batchSizePerExecutor).averagingFrequency(avgFreq)
.repartionData(Repartition.Always).build());
sparkNet.setListeners(new ScoreIterationListener(1));
sparkNet.setListeners(new ScoreIterationListener(5));
@ -778,13 +782,13 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
.setOutputs("1").build();
for (int avgFreq : new int[] {1, 5, 10}) {
System.out.println("--- Avg freq " + avgFreq + " ---");
// System.out.println("--- Avg freq " + avgFreq + " ---");
SparkComputationGraph sparkNet = new SparkComputationGraph(sc, conf.clone(),
new ParameterAveragingTrainingMaster.Builder(numExecutors(), dataSetObjSize)
.batchSizePerWorker(batchSizePerExecutor).averagingFrequency(avgFreq)
.repartionData(Repartition.Always).build());
sparkNet.setListeners(new ScoreIterationListener(1));
sparkNet.setListeners(new ScoreIterationListener(5));
JavaRDD<DataSet> rdd = sc.parallelize(list);

View File

@ -107,7 +107,7 @@ public class TestTrainingStatsCollection extends BaseSparkTest {
expectedStatNames.addAll(c);
}
System.out.println(expectedStatNames);
// System.out.println(expectedStatNames);
SparkTrainingStats stats = sparkNet.getSparkTrainingStats();
@ -119,7 +119,7 @@ public class TestTrainingStatsCollection extends BaseSparkTest {
}
String statsAsString = stats.statsAsString();
System.out.println(statsAsString);
// System.out.println(statsAsString);
assertEquals(actualKeySet.size(), statsAsString.split("\n").length); //One line per stat

View File

@ -35,7 +35,7 @@ public class TestTimeSource {
long systemTime = System.currentTimeMillis();
long ntpTime = timeSource.currentTimeMillis();
long offset = ntpTime - systemTime;
System.out.println("System: " + systemTime + "\tNTPTimeSource: " + ntpTime + "\tOffset: " + offset);
// System.out.println("System: " + systemTime + "\tNTPTimeSource: " + ntpTime + "\tOffset: " + offset);
Thread.sleep(500);
}
}
@ -49,7 +49,7 @@ public class TestTimeSource {
long systemTime = System.currentTimeMillis();
long ntpTime = timeSource.currentTimeMillis();
long offset = ntpTime - systemTime;
System.out.println("System: " + systemTime + "\tSystemClockTimeSource: " + ntpTime + "\tOffset: " + offset);
// System.out.println("System: " + systemTime + "\tSystemClockTimeSource: " + ntpTime + "\tOffset: " + offset);
assertEquals(systemTime, ntpTime, 2); //Should be exact, but we might randomly tick over between one ms and the next
Thread.sleep(500);
}

View File

@ -87,7 +87,7 @@ public class TestListeners extends BaseSparkTest {
net.fit(rdd);
List<String> sessions = ss.listSessionIDs();
System.out.println("Sessions: " + sessions);
// System.out.println("Sessions: " + sessions);
assertEquals(1, sessions.size());
String sid = sessions.get(0);
@ -95,15 +95,15 @@ public class TestListeners extends BaseSparkTest {
List<String> typeIDs = ss.listTypeIDsForSession(sid);
List<String> workers = ss.listWorkerIDsForSession(sid);
System.out.println(sid + "\t" + typeIDs + "\t" + workers);
// System.out.println(sid + "\t" + typeIDs + "\t" + workers);
List<Persistable> lastUpdates = ss.getLatestUpdateAllWorkers(sid, StatsListener.TYPE_ID);
System.out.println(lastUpdates);
// System.out.println(lastUpdates);
System.out.println("Static info:");
// System.out.println("Static info:");
for (String wid : workers) {
Persistable staticInfo = ss.getStaticInfo(sid, StatsListener.TYPE_ID, wid);
System.out.println(sid + "\t" + wid);
// System.out.println(sid + "\t" + wid);
}
assertEquals(1, typeIDs.size());

View File

@ -63,7 +63,7 @@ public class TestRepartitioning extends BaseSparkTest {
assertEquals(10, rdd2.partitions().size());
for (int i = 0; i < 10; i++) {
List<String> partition = rdd2.collectPartitions(new int[] {i})[0];
System.out.println("Partition " + i + " size: " + partition.size());
// System.out.println("Partition " + i + " size: " + partition.size());
assertEquals(100, partition.size()); //Should be exactly 100, for the util method (but NOT spark .repartition)
}
}
@ -170,7 +170,7 @@ public class TestRepartitioning extends BaseSparkTest {
List<Tuple2<Integer, Integer>> partitionCounts = initial.values().mapPartitionsWithIndex(new CountPartitionsFunction<Integer>(), true).collect();
System.out.println(partitionCounts);
// System.out.println(partitionCounts);
List<Tuple2<Integer,Integer>> initialExpected = Arrays.asList(
new Tuple2<>(0,29),
@ -185,7 +185,7 @@ public class TestRepartitioning extends BaseSparkTest {
JavaRDD<Integer> afterRepartition = SparkUtils.repartitionBalanceIfRequired(initial.values(), Repartition.Always, 2, 112);
List<Tuple2<Integer, Integer>> partitionCountsAfter = afterRepartition.mapPartitionsWithIndex(new CountPartitionsFunction<Integer>(), true).collect();
System.out.println(partitionCountsAfter);
// System.out.println(partitionCountsAfter);
for(Tuple2<Integer,Integer> t2 : partitionCountsAfter){
assertEquals(2, (int)t2._2());
@ -219,8 +219,8 @@ public class TestRepartitioning extends BaseSparkTest {
}
}
System.out.println("min: " + min + "\t@\t" + minIdx);
System.out.println("max: " + max + "\t@\t" + maxIdx);
// System.out.println("min: " + min + "\t@\t" + minIdx);
// System.out.println("max: " + max + "\t@\t" + maxIdx);
assertEquals(1, min);
assertEquals(2, max);
@ -244,7 +244,7 @@ public class TestRepartitioning extends BaseSparkTest {
for (int i = 0; i < 10; i++) {
List<String> partition = rdd2.collectPartitions(new int[] {i})[0];
System.out.println("Partition " + i + " size: " + partition.size());
// System.out.println("Partition " + i + " size: " + partition.size());
assertTrue(partition.size() >= 90 && partition.size() <= 110);
}
}

View File

@ -5,7 +5,7 @@ project(mkldnn-download NONE)
include(ExternalProject)
ExternalProject_Add(mkldnn
GIT_REPOSITORY https://github.com/intel/mkl-dnn.git
GIT_TAG v1.1.3
GIT_TAG v1.2
SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/mkldnn-src"
BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/mkldnn-build"
CONFIGURE_COMMAND ""

View File

@ -999,14 +999,14 @@ namespace nd4j {
* set new order and shape in case of suitable array length (in-place operation)
* order - order to set
* shape - shape to set
*
* copyToNewBuff - if true, the old buffer is copied into the new buffer when a new buffer has to be allocated after reshaping
* if a permute was applied beforehand or the strides are non-standard, a new buffer is allocated for the array
*/
bool reshapei(const char order, const std::initializer_list<Nd4jLong>& shape);
bool reshapei(const char order, const std::vector<Nd4jLong>& shape);
bool reshapei(const char order, const std::initializer_list<Nd4jLong>& shape, const bool copyToNewBuff = true);
bool reshapei(const char order, const std::vector<Nd4jLong>& shape, const bool copyToNewBuff = true);
bool reshapei(const std::initializer_list<Nd4jLong>& shape);
bool reshapei(const std::vector<Nd4jLong>& shape);
bool reshapei(const std::initializer_list<Nd4jLong>& shape, const bool copyToNewBuff = true);
bool reshapei(const std::vector<Nd4jLong>& shape, const bool copyToNewBuff = true);
/**
* creates new array with corresponding order and shape, new array will point on _buffer of this array
@ -1015,8 +1015,8 @@ namespace nd4j {
*
* if a permute was applied beforehand or the strides are non-standard, a new buffer is allocated for the new array
*/
NDArray reshape(const char order, const std::vector<Nd4jLong>& shape) const &;
NDArray reshape(const char order, const std::vector<Nd4jLong>& shape) &&;
NDArray reshape(const char order, const std::vector<Nd4jLong>& shape, const bool copyToNewBuff = true) const &;
NDArray reshape(const char order, const std::vector<Nd4jLong>& shape, const bool copyToNewBuff = true) &&;
/**
* calculate strides and set given order
@ -1493,7 +1493,7 @@ namespace nd4j {
* @return
*/
bool isS() const;
template <typename T>
std::vector<T> asVectorT();

View File

@ -42,7 +42,7 @@ ND4J_EXPORT std::u32string NDArray::e(const Nd4jLong i) const;
////////////////////////////////////////////////////////////////////////
// copy constructor
NDArray::NDArray(const NDArray& other) {
_context = other._context;
_offset = 0;
@ -308,7 +308,7 @@ NDArray::NDArray(const std::u16string& u16string, nd4j::DataType dtype, nd4j::La
if (!unicode::isStringValidU16(u16string.data(), u16string.data() + u16string.size())) {
throw std::invalid_argument("NDArray::NDArray: invalid character in input string");
}
// a single word, hence the argument 1
Nd4jLong headerLength = ShapeUtils::stringBufferHeaderRequirements(1);
@ -435,11 +435,11 @@ NDArray::NDArray(const std::string& str, nd4j::DataType dtype, nd4j::LaunchConte
_offset = 0;
setShapeInfo(ShapeDescriptor::scalarDescriptor(dtype));
memcpy(bufferAsT<int8_t>(), &offsets[0], 2 * sizeof(Nd4jLong));
auto data = reinterpret_cast<int8_t*>(bufferAsT<int8_t>() + headerLength);
if (dtype == DataType::UTF8) {
memcpy(data, str.data(), str.size());
}
@ -456,13 +456,13 @@ NDArray::NDArray(const std::string& str, nd4j::DataType dtype, nd4j::LaunchConte
/////////////////////////////////////////////////////////////////////////
// constructors for vector of strings
NDArray::NDArray(const std::vector<Nd4jLong>& shape, const std::vector<const char*>& string, const nd4j::DataType dataType, nd4j::LaunchContext* context) {
if (!DataTypeUtils::isS(dataType))
throw std::invalid_argument("NDArray::NDArray: invalid DataType, only string dataTypes have to be used");
if (shape::prodLong(shape.data(), shape.size()) != string.size())
throw std::invalid_argument("NDArray::NDArray: Number of strings should match length of array");
for (const auto& str : string) {
if (!unicode::isStringValidU8(str, str + std::char_traits<char>::length(str)) ) {
throw std::invalid_argument("NDArray::NDArray: invalid character in input string");
@ -497,11 +497,11 @@ NDArray::NDArray(const std::vector<Nd4jLong>& shape, const std::vector<const cha
setAttached(context->getWorkspace() != nullptr);
memcpy(bufferAsT<int8_t>(), offsets.data(), offsets.size() * sizeof(Nd4jLong));
auto data = reinterpret_cast<int8_t*>(bufferAsT<int8_t>() + headerLength);
auto func = PRAGMA_THREADS_FOR{
for (auto e = start; e < stop; e += increment) {
for (auto e = start; e < stop; e++) {
auto cdata = data + offsets[e];
if (dataType == DataType::UTF16) {
unicode::utf8to16(string[e], cdata, std::char_traits<char>::length(string[e]));
@ -568,7 +568,7 @@ NDArray::NDArray(const std::vector<Nd4jLong>& shape, const std::vector<std::stri
auto data = reinterpret_cast<int8_t*>(bufferAsT<int8_t>() + headerLength);
auto func = PRAGMA_THREADS_FOR{
for (auto e = start; e < stop; e += increment) {
for (auto e = start; e < stop; e++) {
auto cdata = data + offsets[e];
if (dataType == DataType::UTF16) {
unicode::utf8to16(string[e].data(), cdata, string[e].size());
@ -631,11 +631,11 @@ NDArray::NDArray(const std::vector<Nd4jLong>& shape, const std::vector<std::u16s
setAttached(context->getWorkspace() != nullptr);
memcpy(bufferAsT<int8_t>(), offsets.data(), offsets.size() * sizeof(Nd4jLong));
auto data = reinterpret_cast<int8_t*>(bufferAsT<int8_t>() + headerLength);
auto func = PRAGMA_THREADS_FOR{
for (auto e = start; e < stop; e += increment) {
for (auto e = start; e < stop; e++) {
auto cdata = data + offsets[e];
if (dtype == DataType::UTF16) {
memcpy(cdata, string[e].data(), string[e].size() * sizeof(uint16_t));
@ -699,9 +699,9 @@ NDArray::NDArray(const std::vector<Nd4jLong>& shape, const std::vector<const cha
auto data = reinterpret_cast<int8_t*>(bufferAsT<int8_t>() + headerLength);
auto func = PRAGMA_THREADS_FOR{
for (auto e = start; e < stop; e += increment) {
for (auto e = start; e < stop; e++) {
auto cdata = data + offsets[e];
if (dtype == DataType::UTF16) {
memcpy(cdata, string[e], std::char_traits<char16_t>::length(string[e]) * sizeof(uint16_t));
@ -715,7 +715,7 @@ NDArray::NDArray(const std::vector<Nd4jLong>& shape, const std::vector<const cha
}
};
samediff::Threads::parallel_for(func, 0, lengthOf(), 1);
tickWriteHost();
syncToDevice();
}
@ -764,10 +764,10 @@ NDArray::NDArray(const std::vector<Nd4jLong>& shape, const std::vector<std::u32s
memcpy(bufferAsT<int8_t>(), offsets.data(), offsets.size() * sizeof(Nd4jLong));
auto data = reinterpret_cast<int8_t*>(bufferAsT<int8_t>() + headerLength);
auto data = reinterpret_cast<int8_t*>(bufferAsT<int8_t>() + headerLength);
auto func = PRAGMA_THREADS_FOR{
for (auto e = start; e < stop; e += increment) {
for (auto e = start; e < stop; e++) {
auto cdata = data + offsets[e];
if (dtype == DataType::UTF16) {
unicode::utf32to16(string[e].data(), cdata, string[e].size());
@ -781,7 +781,7 @@ NDArray::NDArray(const std::vector<Nd4jLong>& shape, const std::vector<std::u32s
}
};
samediff::Threads::parallel_for(func, 0, lengthOf(), 1);
tickWriteHost();
syncToDevice();
}
@ -831,9 +831,9 @@ NDArray::NDArray(const std::vector<Nd4jLong>& shape, const std::vector<const cha
memcpy(bufferAsT<int8_t>(), offsets.data(), offsets.size() * sizeof(Nd4jLong));
auto data = reinterpret_cast<int8_t*>(bufferAsT<int8_t>() + headerLength);
auto func = PRAGMA_THREADS_FOR{
for (auto e = start; e < stop; e += increment) {
for (auto e = start; e < stop; e++) {
auto cdata = data + offsets[e];
if (dtype == DataType::UTF16) {
unicode::utf32to16(string[e], cdata, std::char_traits<char32_t>::length(string[e]));
@ -847,7 +847,7 @@ NDArray::NDArray(const std::vector<Nd4jLong>& shape, const std::vector<const cha
}
};
samediff::Threads::parallel_for(func, 0, lengthOf(), 1);
tickWriteHost();
syncToDevice();
}
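Throughout this file (and the rest of this diff) the PRAGMA_THREADS_FOR bodies switch from `e += increment` to `e++`; the assumption behind that change is that samediff::Threads::parallel_for hands each worker a contiguous [start, stop) range, so the step inside a chunk is simply 1. A standalone sketch of that contract follows; `parallel_for_sketch` and its chunking policy are illustrative stand-ins, not the library's implementation:

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <thread>
#include <vector>

// toy stand-in for samediff::Threads::parallel_for: contiguous chunks, unit step inside each chunk
template <typename F>
static void parallel_for_sketch(F func, int64_t startAll, int64_t stopAll, int numThreads = 4) {
    std::vector<std::thread> workers;
    const int64_t span = (stopAll - startAll + numThreads - 1) / numThreads;
    for (int t = 0; t < numThreads; t++) {
        const int64_t start = startAll + t * span;
        const int64_t stop  = std::min(stopAll, start + span);
        if (start < stop)
            workers.emplace_back([func, start, stop] { func(start, stop); });
    }
    for (auto &w : workers) w.join();
}

int main() {
    parallel_for_sketch([](int64_t start, int64_t stop) {
        for (auto e = start; e < stop; e++)   // mirrors the e++ loops introduced in this change
            std::printf("processing element %lld\n", (long long) e);
    }, 0, 10);
    return 0;
}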
@ -887,8 +887,8 @@ bool NDArray::isC() const {
//////////////////////////////////////////////////////////////////////////
bool NDArray::isS() const {
return (dataType() == DataType::UTF8 ||
dataType() == DataType::UTF16 ||
return (dataType() == DataType::UTF8 ||
dataType() == DataType::UTF16 ||
dataType() == DataType::UTF32);
}
@ -1197,8 +1197,8 @@ void NDArray::assign(const NDArray& other, bool allowParallelism) {
throw std::runtime_error("NDArray::assign: lengths of arrays are mismatched");
}
// memcpy is allowed only for same order && same ews (being equal to 1)
if (ordering() == other.ordering() && dataType() == other.dataType() && ews() == 1 && other.ews() == 1)
// memcpy is allowed only when both arrays have c ordering and the same ews (equal to 1)
if (ordering() == other.ordering() && ordering() == 'c' && dataType() == other.dataType() && ews() == 1 && other.ews() == 1)
copyBuffersContinuouslyFrom(other, other.lengthOf() * other.sizeOfT());
else {
NDArray::prepareSpecialUse({this}, {&other});
@ -1569,20 +1569,25 @@ Nd4jLong NDArray::tensorsAlongDimension(const std::vector<int>& dimensions) cons
//////////////////////////////////////////////////////////////////////////
void NDArray::printShapeInfo(const char * msg) const {
//shape::printShapeInfo(_shapeInfo);
if (msg == nullptr)
shape::printShapeInfoLinear(_shapeInfo);
else {
int rank = shape::rank(_shapeInfo);
int lim = shape::shapeInfoLength(rank);
printf("%s: [", msg);
for (int i = 0; i < shape::shapeInfoLength(rank); i++) {
printf("%lld", (long long) _shapeInfo[i]);
if (i < lim - 1)
printf(", ");
}
printf("]\n");
int rank = shape::rank(_shapeInfo);
int lim = shape::shapeInfoLength(rank);
if(msg != nullptr)
printf("shapeInfo %s: [", msg);
else
printf("shapeInfo: [");
printf("%i, ", rank);
for (int i = 1; i < shape::shapeInfoLength(rank) - 3; i++){
if(i == rank + 1)
printf(" ");
printf("%lld,", _shapeInfo[i]);
}
printf(" %lld,", shape::type(_shapeInfo));
printf("%lld,", shape::elementWiseStride(_shapeInfo));
printf("%lld]\n", (Nd4jLong)shape::order(_shapeInfo));
fflush(stdout);
}
@ -1624,7 +1629,7 @@ void NDArray::printBuffer(const char* msg, Nd4jLong limit, const bool sync) cons
if (e < limit - 1)
printf(", ");
}
}
}
else if (this->isS()) {
// todo do we need this print offsets
/*
@ -1773,7 +1778,7 @@ void NDArray::printIndexedBuffer(const char* msg, Nd4jLong limit) const {
printf("%s\n", this->e<bool>(0)?"true":"false");
}
else if (this->isS()) {
// todo do we need this
// todo do we need this
// printf("\"%lld\"\n", this->getOffset(e));
printf("\"%s\"\n", this->e<std::string>(0).c_str());
}
@ -1855,19 +1860,19 @@ void NDArray::updateStrides(const char order) {
//////////////////////////////////////////////////////////////////////////
// set new order and shape in case of suitable array length
bool NDArray::reshapei(const char order, const std::initializer_list<Nd4jLong>& shape) {
bool NDArray::reshapei(const char order, const std::initializer_list<Nd4jLong>& shape, const bool copyToNewBuff) {
std::vector<Nd4jLong> vShape(shape);
return reshapei(order, vShape);
return reshapei(order, vShape, copyToNewBuff);
}
//////////////////////////////////////////////////////////////////////////
bool NDArray::reshapei(const std::initializer_list<Nd4jLong>& shape) {
return reshapei('c', shape);
bool NDArray::reshapei(const std::initializer_list<Nd4jLong>& shape, const bool copyToNewBuff) {
return reshapei(ordering(), shape, copyToNewBuff);
}
//////////////////////////////////////////////////////////////////////////
bool NDArray::reshapei(const std::vector<Nd4jLong>& shape) {
return reshapei('c', shape);
bool NDArray::reshapei(const std::vector<Nd4jLong>& shape, const bool copyToNewBuff) {
return reshapei(ordering(), shape, copyToNewBuff);
}
//////////////////////////////////////////////////////////////////////////
@ -1918,18 +1923,18 @@ Nd4jLong NDArray::argMax(std::initializer_list<int> dimensions) {
//////////////////////////////////////////////////////////////////////////
// create new array with corresponding order and shape, new array will point to the same _buffer as this array
NDArray NDArray::reshape(const char order, const std::vector<Nd4jLong>& shape) const & {
NDArray NDArray::reshape(const char order, const std::vector<Nd4jLong>& shape, const bool copyToNewBuff) const & {
NDArray newArr(getDataBuffer(), ShapeDescriptor(getShapeInfo()), getContext(), getBufferOffset());
newArr.reshapei(order, shape);
newArr.reshapei(order, shape, copyToNewBuff);
return newArr;
}
//////////////////////////////////////////////////////////////////////////
NDArray NDArray::reshape(const char order, const std::vector<Nd4jLong>& shape) && {
NDArray NDArray::reshape(const char order, const std::vector<Nd4jLong>& shape, const bool copyToNewBuff) && {
this->reshapei(order, shape);
this->reshapei(order, shape, copyToNewBuff);
return std::move(*this);
}
@ -1971,7 +1976,7 @@ bool NDArray::permutei(const std::initializer_list<int>& dimensions) {
//////////////////////////////////////////////////////////////////////////
bool NDArray::permutei(const std::vector<int>& dimensions) {
return permutei(dimensions.data(), dimensions.size());
return permutei(dimensions.data(), rankOf());
}
//////////////////////////////////////////////////////////////////////////
@ -1993,7 +1998,7 @@ bool NDArray::permutei(const std::vector<Nd4jLong>& dimensions) {
for (int e = 0; e < dimensions.size(); e++)
ivec[e] = dimensions[e];
return permutei(ivec.data(), ivec.size());
return permutei(ivec.data(), rankOf());
}
//////////////////////////////////////////////////////////////////////////
@ -2029,9 +2034,8 @@ NDArray NDArray::permute(const Nd4jLong* dimensions, const int rank) && {
//////////////////////////////////////////////////////////////////////////
NDArray NDArray::permute(const std::vector<int>& dimensions) const &{
auto data = dimensions.data();
auto size = dimensions.size();
return permute(data, size);
return permute(dimensions.data(), rankOf());
}
//////////////////////////////////////////////////////////////////////////
@ -2043,7 +2047,8 @@ NDArray NDArray::permute(const std::vector<int>& dimensions) && {
//////////////////////////////////////////////////////////////////////////
NDArray NDArray::permute(const std::vector<Nd4jLong>& dimensions) const & {
return permute(dimensions.data(), dimensions.size());
return permute(dimensions.data(), rankOf());
}
//////////////////////////////////////////////////////////////////////////
@ -2106,12 +2111,12 @@ void NDArray::permute(const Nd4jLong *dimensions, const int rank, NDArray& targe
//////////////////////////////////////////////////////////////////////////
void NDArray::permute(const std::vector<int>& dimensions, NDArray& target) const {
permute(dimensions.data(), dimensions.size(), target);
permute(dimensions.data(), rankOf(), target);
}
//////////////////////////////////////////////////////////////////////////
void NDArray::permute(const std::vector<Nd4jLong>& dimensions, NDArray& target) const {
permute(dimensions.data(), dimensions.size(), target);
permute(dimensions.data(), rankOf(), target);
}
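Since the overloads above now forward rankOf() rather than dimensions.size(), the permutation vector is expected to list every axis exactly once. A hedged sketch, where `a` is assumed to be an existing rank-3 NDArray:

std::vector<int> perm = {2, 0, 1};   // must contain rankOf() == 3 indices
NDArray view = a.permute(perm);      // returns a view with reordered strides
a.permutei(perm);                    // in-place variant, same expectation on perm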
//////////////////////////////////////////////////////////////////////////
@ -2280,7 +2285,7 @@ template <typename T>
NDArray NDArray::asT() const{
auto result = isScalar() ? NDArray('c', {}, std::vector<double>{0.}, DataTypeUtils::fromT<T>(), this->getContext()) : NDArray(ordering(), getShapeAsVector(), DataTypeUtils::fromT<T>(), this->getContext());
NDArray::prepareSpecialUse({&result}, {this});
NativeOpExecutioner::execTransformAny(getContext(), transform::AnyOps::Assign, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), result.getBuffer(), result.getShapeInfo(), result.getSpecialBuffer(), result.getSpecialShapeInfo(), nullptr, nullptr, nullptr);
NDArray::registerSpecialUse({&result}, {this});
@ -2298,15 +2303,15 @@ NDArray NDArray::asS() const {
auto dtype = DataTypeUtils::fromT<T>();
if (!(DataTypeUtils::isS(dtype)))
if (!(DataTypeUtils::isS(dtype)))
throw std::invalid_argument("NDArray::asS: invalid DataType used");
if (dtype == dataType()) {
Nd4jLong offsetsLength = ShapeUtils::stringBufferHeaderRequirements(lengthOf());
const auto nInputoffsets = bufferAsT<Nd4jLong>();
std::shared_ptr<DataBuffer> pBuffer = std::make_shared<DataBuffer>(offsetsLength + nInputoffsets[lengthOf()], dtype, getContext()->getWorkspace(), true);
NDArray res(pBuffer, ShapeDescriptor(dtype, ordering(), getShapeAsVector()), getContext());
res.setAttached(getContext()->getWorkspace() != nullptr);
@ -2319,7 +2324,7 @@ NDArray NDArray::asS() const {
registerPrimaryUse({ &res }, { this });
return res;
}
Nd4jLong offsetsLength = ShapeUtils::stringBufferHeaderRequirements(lengthOf());
std::vector<Nd4jLong> offsets(lengthOf() + 1);
@ -2353,7 +2358,7 @@ NDArray NDArray::asS() const {
NDArray res(pBuffer, ShapeDescriptor(dtype, ordering(), getShapeAsVector()), getContext());
res.setAttached(getContext()->getWorkspace() != nullptr);
preparePrimaryUse({ &res }, { this });
memcpy(res.bufferAsT<int8_t>(), offsets.data(), offsets.size() * sizeof(Nd4jLong));
@ -2362,7 +2367,7 @@ NDArray NDArray::asS() const {
const auto inData = bufferAsT<int8_t>() + offsetsLength;
auto func = PRAGMA_THREADS_FOR{
for (int e = start; e < stop; e += increment) {
for (int e = start; e < stop; e++) {
auto cdata = outData + offsets[e];
auto end = nInputoffsets[e + 1];
auto idata = inData + nInputoffsets[e];
@ -2403,7 +2408,7 @@ BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT NDArray NDArray::asS, () const, LIBND
////////////////////////////////////////////////////////////////////////
NDArray NDArray::asT(DataType dtype) const {
if (isS() && !DataTypeUtils::isS(dtype))
throw std::runtime_error("NDArray::asT: you can't use this method on String array with not string DataType!");
@ -3221,7 +3226,7 @@ BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT std::vector, NDArray::asVectorT(), LI
//////////////////////////////////////////////////////////////////////////
// set new order and shape in case of suitable array length
bool NDArray::reshapei(const char order, const std::vector<Nd4jLong>& cshape) {
bool NDArray::reshapei(const char order, const std::vector<Nd4jLong>& cshape, const bool copyToNewBuff) {
// check firstly whether cshape is identical to shape of array, if yes then reshape is unnecessary
if(order == ordering() && shape::shapeEquals(rankOf(), shapeOf(), cshape.size(), cshape.data()))
@ -3293,19 +3298,15 @@ bool NDArray::reshapei(const char order, const std::vector<Nd4jLong>& cshape) {
Nd4jLong *shapeInfoNew;
ALLOCATE(shapeInfoNew, getContext()->getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong);
bool canReshape = shape::reshapeC(rankOf(), shapeInfo(), shape.size(), shape.data(), shapeInfoNew);
bool canReshape = shape::reshapeC(shapeInfo(), order, shape.size(), shape.data(), shapeInfoNew);
// we can do this only if there was no permute applied, or there are no weird strides
if (canReshape) {
if(ordering() == 'c' && order == 'f')
throw std::invalid_argument("NDArray::reshapei(order, shape): in case of reshapeC it doesn't make sense to reshape from c order to f order !");
shape::setEws(shapeInfoNew, arrLength);
setShapeInfo(shapeInfoNew);
}
else {
NDArray temp(order, shape, dataType(), getContext());
this->applyTransform(transform::Assign, temp, nullptr);
if(copyToNewBuff)
this->applyTransform(transform::Assign, temp, nullptr);
*this = std::move(temp);
}
@ -3463,9 +3464,9 @@ NDArray NDArray::dup(const char newOrder) const {
if (isS()) {
if (dataType() == DataType::UTF8) {
std::vector<std::string> strings(lengthOf());
auto func = PRAGMA_THREADS_FOR{
for (auto i = start; i < stop; i += increment) {
for (auto i = start; i < stop; i++) {
strings[i] = std::move(this->e<std::string>(i));
}
};
@ -3478,7 +3479,7 @@ NDArray NDArray::dup(const char newOrder) const {
std::vector<std::u16string> strings(lengthOf());
auto func = PRAGMA_THREADS_FOR{
for (auto i = start; i < stop; i += increment) {
for (auto i = start; i < stop; i++) {
strings[i] = std::move(this->e<std::u16string>(i));
}
};
@ -3490,7 +3491,7 @@ NDArray NDArray::dup(const char newOrder) const {
std::vector<std::u32string> strings(lengthOf());
auto func = PRAGMA_THREADS_FOR{
for (auto i = start; i < stop; i += increment) {
for (auto i = start; i < stop; i++) {
strings[i] = std::move(this->e<std::u32string>(i));
}
};
@ -3521,7 +3522,7 @@ bool NDArray::equalsTo(const NDArray *other, double eps) const {
if (isS()) {
// string is special case, we'll compare them one by one, considering both arrays are guaranteed to have the same length
if (dataType() == DataType::UTF8) {
for (int e = 0; e < this->lengthOf(); e++) {
auto s1 = this->e<std::string>(e);
@ -3585,7 +3586,7 @@ std::string NDArray::e(const Nd4jLong i) const {
if (i == lengthOf())
throw std::runtime_error("Can't get std::string for index out of range");
if (this->dataType() == DataType::UTF16) {
auto u16 = this->e<std::u16string>(i);
std::string s;
@ -4846,7 +4847,7 @@ NDArray NDArray::operator()(const std::vector<Nd4jLong>& idx, const bool keepUni
auto shapeOf = shape::shapeOf(newShapeInfo);
auto stridesOf = shape::stride(newShapeInfo);
Nd4jLong offset(0), subArrLen(1);
Nd4jLong offset = 0;
int n(isStrided ? 3 : 2), first, last, stride;
for (int d = rank - 1; d >= 0; --d) {
@ -4863,29 +4864,31 @@ NDArray NDArray::operator()(const std::vector<Nd4jLong>& idx, const bool keepUni
if(shapeOf[d] != 1)
stridesOf[d] *= stride;
}
}
subArrLen *= shapeOf[d];
Nd4jLong *newShapeInfo2 = newShapeInfo;
if(!keepUnitiesInShape) {
std::vector<int> dimsWithUnities;
for (uint d = 0; d < rank; ++d)
if(idx[n*d] != idx[n*d+1] && shapeOf[d] == 1)
dimsWithUnities.push_back(d);
if(!dimsWithUnities.empty())
newShapeInfo2 = ShapeBuilders::copyShapeInfoWithoutUnites(newShapeInfo, dimsWithUnities.size(), dimsWithUnities.data(), getContext()->getWorkspace());
}
// check if there is possibility to set ews = 1
shape::setEws(newShapeInfo, subArrLen);
shape::checkStridesEwsAndOrder(newShapeInfo2);
NDArray result(_buffer, ShapeDescriptor(newShapeInfo), getContext(), offset + getBufferOffset());
NDArray result(_buffer, ShapeDescriptor(newShapeInfo2), getContext(), offset + getBufferOffset());
result._isView = true;
if(!keepUnitiesInShape) {
const int coeff = isStrided ? 3 : 2;
std::vector<Nd4jLong> nonUnitDims;
for (int d = 0; d < rank; ++d)
if(!(idx[coeff*d] != idx[coeff*d+1] && newShapeInfo[d+1] == 1))
nonUnitDims.push_back(newShapeInfo[d+1]);
if(nonUnitDims.size() != rank)
result.reshapei(nonUnitDims);
}
RELEASE(newShapeInfo, getContext()->getWorkspace());
if(newShapeInfo != newShapeInfo2)
RELEASE(newShapeInfo2, getContext()->getWorkspace());
return result;
}
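A hedged usage sketch of the sub-array operator reworked above, assuming `a` is an existing 3x4 NDArray and the {first, last} half-open pair convention per dimension, where first == last means "take the whole dimension" (both the variable and the interval convention are assumptions inferred from this diff):

NDArray rows = a({1, 3,  0, 0}, false);   // rows 1..2, all columns -> shape {2, 4}; a view over a's buffer
NDArray row  = a({1, 2,  0, 0}, true);    // keepUnitiesInShape = true keeps the unit axis -> shape {1, 4} instead of {4}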

View File

@ -179,7 +179,7 @@ namespace graph {
nd4j_debug("Embedded graph execution finished. %i variable(s) migrated\n", cnt);
} else if (node->hasCustomOp()) {
// if we have something to execute - lets just execute it.
// now, if we have something to execute - let's just execute it.
auto status = node->getCustomOp()->execute(&context);
if (status != ND4J_STATUS_OK)
return status;
@ -494,8 +494,10 @@ Nd4jStatus GraphExecutioner::execute(Graph *graph, VariableSpace* variableSpace)
nd4j::memory::MemoryRegistrator::getInstance()->setGraphMemoryFootprintIfGreater(h, m);
}
if (tempFlow)
if (tempFlow) {
delete flowPath;
__variableSpace->setFlowPath(nullptr);
}
return Status::OK();
}

View File

@ -98,7 +98,7 @@ void NDArray::fillAsTriangular(const float val, int lower, int upper, NDArray& t
auto func = PRAGMA_THREADS_FOR {
Nd4jLong coords[MAX_RANK];
for (auto i = start; i < stop; i += increment) {
for (auto i = start; i < stop; i++) {
shape::index2coords(i, target.getShapeInfo(), coords);
const auto zOffset = shape::getOffset(target.getShapeInfo(), coords);
@ -152,7 +152,7 @@ static void templatedSwap(void *xBuffer, void *yBuffer, Nd4jLong length) {
auto y = reinterpret_cast<T *>(yBuffer);
auto func = PRAGMA_THREADS_FOR {
for (auto i = start; i < stop; i += increment) {
for (auto i = start; i < stop; i++) {
auto temp = x[i];
x[i] = y[i];
y[i] = temp;
@ -266,7 +266,7 @@ NDArray NDArray::tile(const std::vector<Nd4jLong>& reps) const {
if(result.ordering() == 'c') { // ews == 1 always here
auto func = PRAGMA_THREADS_FOR {
for (auto i = start; i < stop; i += increment) {
for (auto i = start; i < stop; i++) {
auto yOffset = shape::subArrayOffset(i, newShapeInfo, getShapeInfo());
BUILD_SINGLE_SELECTOR(xType, this->template templatedAssign,(result.getBuffer(), i, this->getBuffer(), yOffset), LIBND4J_TYPES);
}
@ -277,7 +277,7 @@ NDArray NDArray::tile(const std::vector<Nd4jLong>& reps) const {
else {
auto func = PRAGMA_THREADS_FOR {
for (auto i = start; i < stop; i += increment) {
for (auto i = start; i < stop; i++) {
auto xOffset = result.getOffset(i);
auto yOffset = shape::subArrayOffset(i, newShapeInfo, getShapeInfo());
BUILD_SINGLE_SELECTOR(xType, this->template templatedAssign,(result.getBuffer(), xOffset, this->getBuffer(), yOffset), LIBND4J_TYPES);
@ -377,7 +377,7 @@ static void repeat_(const NDArray& input, NDArray& output, const std::vector<int
// loop through input array
auto func = PRAGMA_THREADS_FOR {
Nd4jLong coords[MAX_RANK];
for (auto i = start; i < stop; i += increment) {
for (auto i = start; i < stop; i++) {
shape::index2coords(i, output.getShapeInfo(), coords);
const auto zOffset = shape::getOffset(output.getShapeInfo(), coords);

View File

@ -22,7 +22,7 @@ void NDArray::applyTriplewiseLambda(NDArray& second, NDArray& third, const std::
if (this->ordering() == second.ordering() && this->ordering() == third.ordering() && this->ordering() == target.ordering() && (this->ews() == 1 && target.ews() == 1) && this->ews() == second.ews() && this->ews() == third.ews()) {
auto loop = PRAGMA_THREADS_FOR {
for (auto e = start; e < stop; e += increment)
for (auto e = start; e < stop; e++)
z[e] = func(f[e], s[e], t[e]);
};
@ -31,7 +31,7 @@ void NDArray::applyTriplewiseLambda(NDArray& second, NDArray& third, const std::
if (f == z) {
auto loop = PRAGMA_THREADS_FOR {
for (auto e = start; e < stop; e += increment) {
for (auto e = start; e < stop; e++) {
auto tOffset = this->getOffset(e);
auto uOffset = second.getOffset(e);
auto vOffset = third.getOffset(e);
@ -44,7 +44,7 @@ void NDArray::applyTriplewiseLambda(NDArray& second, NDArray& third, const std::
} else {
auto loop = PRAGMA_THREADS_FOR {
for (auto e = start; e < stop; e += increment) {
for (auto e = start; e < stop; e++) {
auto tOffset = this->getOffset(e);
auto uOffset = second.getOffset(e);
auto vOffset = third.getOffset(e);
@ -93,7 +93,7 @@ void NDArray::applyPairwiseLambda(const NDArray& other, const std::function<T(T,
if (this->ordering() == other.ordering() && this->ordering() == target.ordering() && (this->ews() == 1 && target.ews() == 1) && this->ews() == other.ews()) {
auto loop = PRAGMA_THREADS_FOR {
for (auto e = start; e < stop; e += increment)
for (auto e = start; e < stop; e++)
z[e] = func(f[e], s[e]);
};
@ -102,7 +102,7 @@ void NDArray::applyPairwiseLambda(const NDArray& other, const std::function<T(T,
if (f == z) {
auto loop = PRAGMA_THREADS_FOR {
for (auto e = start; e < stop; e += increment) {
for (auto e = start; e < stop; e++) {
auto xOffset = this->getOffset(e);
auto yOffset = other.getOffset(e);
@ -114,7 +114,7 @@ void NDArray::applyPairwiseLambda(const NDArray& other, const std::function<T(T,
} else {
auto loop = PRAGMA_THREADS_FOR {
for (auto e = start; e < stop; e += increment) {
for (auto e = start; e < stop; e++) {
auto xOffset = this->getOffset(e);
auto yOffset = other.getOffset(e);
auto zOffset = target.getOffset(e);
@ -156,7 +156,7 @@ void NDArray::applyLambda(const std::function<T(T)>& func, NDArray& target) {
if (this->ordering() == target.ordering() && (this->ews() == 1 && target.ews() == 1)) {
auto loop = PRAGMA_THREADS_FOR {
for (auto e = start; e < stop; e += increment)
for (auto e = start; e < stop; e++)
z[e] = func(f[e]);
};
@ -165,7 +165,7 @@ void NDArray::applyLambda(const std::function<T(T)>& func, NDArray& target) {
if (f == z) {
auto loop = PRAGMA_THREADS_FOR {
for (auto e = start; e < stop; e += increment) {
for (auto e = start; e < stop; e++) {
auto xOffset = this->getOffset(e);
f[xOffset] = func(f[xOffset]);
@ -176,7 +176,7 @@ void NDArray::applyLambda(const std::function<T(T)>& func, NDArray& target) {
} else {
auto loop = PRAGMA_THREADS_FOR {
for (auto e = start; e < stop; e += increment) {
for (auto e = start; e < stop; e++) {
auto xOffset = this->getOffset(e);
auto zOffset = target.getOffset(e);
@ -217,7 +217,7 @@ void NDArray::applyIndexedLambda(const std::function<T(Nd4jLong, T)>& func, NDAr
if (this->ordering() == target.ordering() && (this->ews() == 1 && target.ews() == 1)) {
auto loop = PRAGMA_THREADS_FOR {
for (auto e = start; e < stop; e += increment)
for (auto e = start; e < stop; e++)
z[e] = func(e, f[e]);
};
@ -226,7 +226,7 @@ void NDArray::applyIndexedLambda(const std::function<T(Nd4jLong, T)>& func, NDAr
if (f == z) {
auto loop = PRAGMA_THREADS_FOR {
for (auto e = start; e < stop; e += increment) {
for (auto e = start; e < stop; e++) {
auto xOffset = this->getOffset(e);
f[xOffset] = func(e, f[xOffset]);
@ -237,7 +237,7 @@ void NDArray::applyIndexedLambda(const std::function<T(Nd4jLong, T)>& func, NDAr
} else {
auto loop = PRAGMA_THREADS_FOR {
for (auto e = start; e < stop; e += increment) {
for (auto e = start; e < stop; e++) {
auto xOffset = this->getOffset(e);
auto zOffset = target.getOffset(e);
@ -283,7 +283,7 @@ void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function<T(N
if (this->ordering() == other.ordering() && this->ordering() == target.ordering() && (this->ews() == 1 && target.ews() == 1) && this->ews() == other.ews()) {
auto loop = PRAGMA_THREADS_FOR {
for (auto e = start; e < stop; e += increment)
for (auto e = start; e < stop; e++)
z[e] = func((Nd4jLong) e, f[e], s[e]);
};
@ -292,7 +292,7 @@ void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function<T(N
if (f == z) {
auto loop = PRAGMA_THREADS_FOR {
for (auto e = start; e < stop; e += increment) {
for (auto e = start; e < stop; e++) {
auto xOffset = this->getOffset(e);
auto yOffset = other.getOffset(e);
@ -304,7 +304,7 @@ void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function<T(N
} else {
auto loop = PRAGMA_THREADS_FOR {
for (auto e = start; e < stop; e += increment) {
for (auto e = start; e < stop; e++) {
auto xOffset = this->getOffset(e);
auto yOffset = other.getOffset(e);
auto zOffset = target.getOffset(e);

View File

@ -163,15 +163,44 @@ void NativeOpExecutioner::execBroadcast(nd4j::LaunchContext *lc,
BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functions::broadcast::Broadcast, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, LIBND4J_TYPES);
#else
auto loopKind = nd4j::LoopKind::deduceKindOfLoopBroadcast(hXShapeInfo, hYShapeInfo, hZShapeInfo);
auto func = PRAGMA_THREADS_FOR {
BUILD_SINGLE_SELECTOR_THRICE(xType, functions::broadcast::Broadcast, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ, start, stop), LIBND4J_TYPES);
BUILD_SINGLE_SELECTOR_THRICE(xType, functions::broadcast::Broadcast, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ, loopKind, start, stop), LIBND4J_TYPES);
};
auto xLen = shape::length(hXShapeInfo);
auto yLen = shape::length(hYShapeInfo);
auto numTads = xLen / yLen;
Nd4jLong numTads = 0;
switch (loopKind) {
case nd4j::LoopKind::BROADCAST_SCALAR_X: {
numTads = shape::length(hXShapeInfo);
}
break;
case nd4j::LoopKind::BROADCAST_SCALAR_Y: {
numTads = shape::length(hYShapeInfo);
}
break;
case nd4j::LoopKind::BROADCAST_3D: {
numTads = shape::sizeAt(hZShapeInfo, 0);
}
break;
case nd4j::LoopKind::BROADCAST_4D: {
numTads = shape::sizeAt(hZShapeInfo, 0) * shape::sizeAt(hZShapeInfo, 1);
}
break;
case nd4j::LoopKind::BROADCAST_5D: {
numTads = shape::sizeAt(hZShapeInfo, 0) * shape::sizeAt(hZShapeInfo, 1);
}
break;
default: {
auto xLen = shape::length(hXShapeInfo);
auto yLen = shape::length(hYShapeInfo);
numTads = xLen / yLen;
}
}
samediff::Threads::parallel_tad(func, 0, numTads);
#endif
}
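For reference, a hedged illustration of how numTads now follows from the deduced loop kind; shapes are written as plain dims and all arrays are assumed c-ordered with ews == 1:

// x {6,4},   y {6,1},   z {6,4}    -> BROADCAST_SCALAR_Y -> numTads = shape::length(hYShapeInfo) = 6
// x {2,3,4}, y {2,3,4}, z {2,3,4}  -> BROADCAST_3D       -> numTads = shape::sizeAt(hZShapeInfo, 0) = 2
// x {6,4},   y {4},     z {6,4}    -> COMMON (default)   -> numTads = xLen / yLen = 24 / 4 = 6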

View File

@ -1291,7 +1291,7 @@ void pullRowsGeneric(void *vx,
_threads = nd4j::math::nd4j_min<int>(_threads, nd4j::Environment::getInstance()->maxThreads());
auto func = PRAGMA_THREADS_FOR {
for (auto idx = start; idx < stop; idx += increment) {
for (auto idx = start; idx < stop; idx++) {
auto xTadOffsetForBlock = tadOffsets[indexes[idx]];
auto zTadOffsetForBlock = zTadOffsets[idx];
@ -1356,7 +1356,7 @@ void tearGeneric(void *vx,
auto numTads = shape::length(hXShapeInfo) / tadLength;
auto func = PRAGMA_THREADS_FOR {
for (auto i = start; i < stop; i += increment) {
for (auto i = start; i < stop; i++) {
auto hZ = reinterpret_cast<T *>(targets[i]);
auto s = hX + tadOffsets[i];
@ -1478,7 +1478,7 @@ void shuffleGeneric(void **hX, Nd4jLong **hXShapeInfo, void **dz, Nd4jLong **hZS
auto dZ = reinterpret_cast<T **>(dz);
auto func = PRAGMA_THREADS_FOR {
for (auto f = start; f < stop; f += increment) {
for (auto f = start; f < stop; f++) {
auto hX = reinterpret_cast<T *>(dX[f]);
//auto hZ = reinterpret_cast<T *>(dZ[f]);

View File

@ -52,7 +52,7 @@ namespace nd4j {
TypeCast::convertGeneric<T2, T>(nullptr, tmp, length, buffer);
#else
auto func = PRAGMA_THREADS_FOR {
for (auto e = start; e < stop; e += increment)
for (auto e = start; e < stop; e++)
buffer[e] = canKeep ? static_cast<T>(tmp[e]) : BitwiseUtils::swap_bytes<T>(static_cast<T>(tmp[e]));
};
@ -110,7 +110,7 @@ namespace nd4j {
TypeCast::convertGeneric<float, T>(nullptr, tmp, length, buffer);
#else
auto func = PRAGMA_THREADS_FOR {
for (auto e = start; e < stop; e += increment)
for (auto e = start; e < stop; e++)
buffer[e] = canKeep ? static_cast<T>(tmp[e]) : BitwiseUtils::swap_bytes<T>(static_cast<T>(tmp[e]));
};
@ -138,7 +138,7 @@ namespace nd4j {
#else
auto func = PRAGMA_THREADS_FOR {
for (auto e = start; e < stop; e += increment)
for (auto e = start; e < stop; e++)
buffer[e] = canKeep ? static_cast<T>(tmp[e]) : BitwiseUtils::swap_bytes<T>(static_cast<T>(tmp[e]));
};
@ -164,7 +164,7 @@ namespace nd4j {
TypeCast::convertGeneric<float16, T>(nullptr, tmp, length, buffer);
#else
auto func = PRAGMA_THREADS_FOR {
for (auto e = start; e < stop; e += increment)
for (auto e = start; e < stop; e++)
buffer[e] = canKeep ? static_cast<T>(tmp[e]) : BitwiseUtils::swap_bytes<T>(static_cast<T>(tmp[e]));
};

View File

@ -58,6 +58,7 @@ namespace nd4j {
virtual void putVariable(int id, Variable *variable);
virtual void putVariable(int id, NDArray *array);
virtual void putVariable(int id, int idx, NDArray *array);
virtual void putVariable(int id, int idx, NDArray &array);
virtual void putVariable(int id, int idx, Variable *array);
virtual void replaceVariable(Variable *variable);

View File

@ -100,6 +100,7 @@ namespace nd4j {
virtual void putVariable(int id, Variable *variable);
virtual void putVariable(int id, NDArray *array);
virtual void putVariable(int id, int idx, NDArray *array);
virtual void putVariable(int id, int idx, NDArray &array);
virtual void putVariable(int id, int idx, Variable *array);
virtual void dropVariable(std::pair<int,int> &pair);

View File

@ -1088,8 +1088,23 @@ namespace nd4j {
if (e < node->input()->size() - 1)
nd4j_printf(", ", "");
}
if (node->opType() == OpType_CUSTOM) {
auto ctx = node->protoContext();
if (ctx->getIArguments()->size() > 0) {
printf("]; iArgs: [");
for (int e = 0; e < ctx->getIArguments()->size(); e++) {
printf("%i", ctx->getIArguments()->at(e));
if (e < ctx->getIArguments()->size() - 1)
nd4j_printf(", ", "");
}
}
}
nd4j_printf("]; \n", "");
// printf("\n");
fflush(stdout);
}

View File

@ -60,8 +60,11 @@ namespace nd4j {
result->_name = this->_name;
result->_index = this->_index;
if (this->_ndarray != nullptr)
if (this->_ndarray != nullptr) {
result->_ndarray = new NDArray(this->_ndarray->dup(this->_ndarray->ordering()));
result->_readOnly = false;
result->_removable = true;
}
if (this->_list != nullptr)
result->_list = this->_list->clone();

View File

@ -191,6 +191,9 @@ namespace nd4j {
_current->putVariable(id, array);
}
void nd4j::graph::VariableProxy::putVariable(int id, int idx, NDArray &array) {
_current->putVariable(id, idx, array);
}
void VariableProxy::putVariable(int id, int idx, NDArray *array) {
_current->putVariable(id, idx, array);

View File

@ -263,19 +263,19 @@ namespace nd4j {
void nd4j::graph::VariableSpace::putVariable(int id, Variable *variable) {
// we don't want to add variables more than once
if (_variables.count(id) > 0 || _temporary.count(id) > 0) {
// nd4j_verbose("Trying to update variable for node_%i\n", id);
auto local = id < 0 ? _variables.at(id) : _temporary.at(id);
if (!local->hasNDArray() && variable->hasNDArray()) {
// nd4j_verbose("Saving variable for node_%i\n", id);
local->setNDArray(variable->getNDArray());
// we're inheriting this from Variable
local->markReadOnly(variable->isReadOnly());
local->markRemovable(variable->isRemovable());
}
return;
}
//nd4j_debug("Adding Variable to Space: id: %i; Array is null: %i;\n", id, variable->getNDArray() == nullptr);
_varmap.lock();
_handles->emplace_back(variable);
@ -314,6 +314,21 @@ namespace nd4j {
}
}
void nd4j::graph::VariableSpace::putVariable(int id, int idx, NDArray &array) {
auto *var = new nd4j::graph::Variable(&array, "", id, idx);
var->markRemovable(false);
var->markReadOnly(true);
// check whether a variable already exists for this (id, idx) pair
bool d = this->hasVariable(id, idx);
this->putVariable(id, var);
// if a variable for this node id already existed, the newly created wrapper is simply deleted
if (d)
delete var;
}
void nd4j::graph::VariableSpace::putVariable(int id, NDArray *array) {
auto *var = new nd4j::graph::Variable(array);
this->putVariable(id, var);
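A hedged sketch of the new by-reference overload, assuming `arr` is an existing NDArray owned by the caller (the variable name is illustrative):

nd4j::graph::VariableSpace vs;
vs.putVariable(7, 0, arr);             // wraps &arr in a read-only, non-removable Variable for node 7, output 0
bool present = vs.hasVariable(7, 0);   // true afterwards; if (7, 0) already existed, the freshly built Variable is deleted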

View File

@ -24,6 +24,7 @@
#include <pointercast.h>
#include <dll.h>
#include <string>
#include <vector>
namespace nd4j {
namespace graph {
@ -65,6 +66,9 @@ namespace nd4j {
// total amount of memory used during execution
Nd4jLong _memoryTotal = 0L;
std::vector<std::string> _inputShapes;
std::vector<std::string> _outputShapes;
public:
NodeProfile() = default;
~NodeProfile() = default;
@ -84,10 +88,15 @@ namespace nd4j {
void setObjectsSize(Nd4jLong bytes);
void setTotalSize(Nd4jLong bytes);
Nd4jLong getActivationsSize();
Nd4jLong getTemporarySize();
Nd4jLong getObjectsSize();
Nd4jLong getTotalSize();
void addInputShape(Nd4jLong *shapeInfo);
void addOutputShape(Nd4jLong *shapeInfo);
Nd4jLong getActivationsSize() const;
Nd4jLong getTemporarySize() const;
Nd4jLong getObjectsSize() const;
Nd4jLong getTotalSize() const;
Nd4jLong getExecutionTime() const;
std::string& name();

View File

@ -21,6 +21,8 @@
#include <graph/profiling/GraphProfile.h>
#include <helpers/logger.h>
#include <chrono>
#include <templatemath.h>
#include <algorithm>
namespace nd4j {
namespace graph {
@ -184,9 +186,26 @@ namespace nd4j {
if (_profiles.empty())
nd4j_printf("No nodes in graph\n","");
for (auto v: _profiles)
// printing out the node profiles
std::vector<NodeProfile*> sorted;
for (auto v: _profiles) {
v->printOut();
sorted.emplace_back(v);
}
if (_profiles.size() > 1) {
// building hot spots
std::sort(sorted.begin(), sorted.end(), [](const NodeProfile *a, const NodeProfile *b) -> bool {
return a->getExecutionTime() > b->getExecutionTime();
});
nd4j_printf("\nTop 30 reports by EXEC:\n", "");
auto limit = nd4j::math::nd4j_min<int>(30, sorted.size());
for (int e = 0; e < limit; e++) {
sorted[e]->printOut();
}
}
nd4j_printf("\nSpecial timers:\n", "");
if (_timings.empty())
nd4j_printf("No special timers were set\n","");

View File

@ -32,7 +32,7 @@ namespace nd4j {
// graph->printOut();
// warm up
for (int e = 0; e < 1000; e++) {
for (int e = 0; e < iterations; e++) {
FlowPath fp;
auto _vs = varSpace->clone();

View File

@ -20,6 +20,7 @@
#include <helpers/logger.h>
#include <graph/profiling/NodeProfile.h>
#include <helpers/ShapeUtils.h>
namespace nd4j {
namespace graph {
@ -35,9 +36,23 @@ namespace nd4j {
nd4j_printf(" Memory: ACT: %lld; TMP: %lld; OBJ: %lld; TTL: %lld;\n", _memoryActivations / _merges, _memoryTemporary / _merges, _memoryObjects / _merges, _memoryTotal / _merges);
nd4j_printf(" Time: PREP: %lld ns; EXEC: %lld ns; TTL: %lld ns;\n", _preparationTime / _merges, _executionTime / _merges, _totalTime / _merges);
nd4j_printf(" PREP: INPUT: %lld ns; SHAPE: %lld ns; ARRAY: %lld ns;\n", _inputTime / _merges, _shapeTime / _merges, _arrayTime / _merges);
std::string inputs;
std::string outputs;
int cnt = 0;
for (const auto &v: _inputShapes)
inputs += v + " ";
for (const auto &v: _outputShapes)
outputs += v + " ";
nd4j_printf(" Inputs: %s\n", inputs.c_str());
nd4j_printf(" Outputs: %s\n", outputs.c_str());
};
Nd4jLong NodeProfile::getActivationsSize() {
Nd4jLong NodeProfile::getActivationsSize() const {
return _memoryActivations;
}
@ -53,15 +68,15 @@ namespace nd4j {
_inputTime = time;
}
Nd4jLong NodeProfile::getTemporarySize() {
Nd4jLong NodeProfile::getTemporarySize() const{
return _memoryTemporary;
}
Nd4jLong NodeProfile::getObjectsSize() {
Nd4jLong NodeProfile::getObjectsSize() const{
return _memoryObjects;
}
Nd4jLong NodeProfile::getTotalSize() {
Nd4jLong NodeProfile::getTotalSize() const{
return _memoryTotal;
}
@ -97,6 +112,18 @@ namespace nd4j {
_memoryTotal = bytes;
}
Nd4jLong NodeProfile::getExecutionTime() const {
return _executionTime;
}
void NodeProfile::addInputShape(Nd4jLong *shapeInfo) {
_inputShapes.emplace_back(ShapeUtils::shapeInfoAsString(shapeInfo));
}
void NodeProfile::addOutputShape(Nd4jLong *shapeInfo) {
_outputShapes.emplace_back(ShapeUtils::shapeInfoAsString(shapeInfo));
}
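A hedged sketch of recording shapes on a profile entry; `profile` is assumed to be a valid NodeProfile*, and the shapeInfo buffer is written by hand in the {rank, dims..., strides..., type word, ews, order} layout seen elsewhere in this diff:

Nd4jLong shapeInfo[] = {2, 32,10, 10,1, 8192,1,99};   // 32x10, c order (99 == 'c'); the type word value is illustrative
profile->addInputShape(shapeInfo);                     // stored as a string via ShapeUtils::shapeInfoAsString
profile->addOutputShape(shapeInfo);                    // both vectors are printed by NodeProfile::printOut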
void NodeProfile::merge(NodeProfile *other) {
_merges += other->_merges;
_memoryObjects += other->_memoryObjects;
@ -110,6 +137,9 @@ namespace nd4j {
_shapeTime += other->_shapeTime;
_arrayTime += other->_arrayTime;
_inputTime += other->_inputTime;
_inputShapes = other->_inputShapes;
_outputShapes = other->_outputShapes;
}
std::string& NodeProfile::name() {
@ -129,6 +159,9 @@ namespace nd4j {
_shapeTime = other->_shapeTime;
_arrayTime = other->_arrayTime;
_inputTime = other->_inputTime;
_inputShapes = other->_inputShapes;
_outputShapes = other->_outputShapes;
}
}
}

View File

@ -37,12 +37,13 @@ namespace nd4j {
class ND4J_EXPORT LoopKind {
public:
enum Kind {SMALLARR2DX, EWS1, EWSNONZERO, RANK1, RANK2, RANK3, RANK4, RANK5, X_EWSNONZERO, Y_EWSNONZERO, Z_EWSNONZERO, COMMON};
enum Kind { SMALLARR2DX, EWS1, EWSNONZERO, RANK1, RANK2, RANK3, RANK4, RANK5, X_EWSNONZERO, Y_EWSNONZERO, Z_EWSNONZERO, COMMON, BROADCAST_SCALAR_X, BROADCAST_SCALAR_Y, BROADCAST_3D, BROADCAST_4D, BROADCAST_5D };
static FORCEINLINE Kind deduceKindOfLoopXZ(const Nd4jLong* xShapeInfo, const Nd4jLong* zShapeInfo);
static FORCEINLINE Kind deduceKindOfLoopXYZ(const Nd4jLong* xShapeInfo, const Nd4jLong* yShapeInfo, const Nd4jLong* zShapeInfo);
static FORCEINLINE Kind deduceKindOfLoopTadXZ(const Nd4jLong* xShapeInfo, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo);
static FORCEINLINE Kind deduceKindOfLoopTadXYZ(const Nd4jLong* xTadShapeInfo, const Nd4jLong* yTadShapeInfo, const Nd4jLong* zShapeInfo);
static FORCEINLINE Kind deduceKindOfLoopBroadcast(const Nd4jLong* xShapeInfo, const Nd4jLong* yShapeInfo, const Nd4jLong* zShapeInfo);
};
@ -82,6 +83,57 @@ LoopKind::Kind LoopKind::deduceKindOfLoopXZ(const Nd4jLong* xShapeInfo, const Nd
return COMMON;
}
LoopKind::Kind LoopKind::deduceKindOfLoopBroadcast(const Nd4jLong* xShapeInfo, const Nd4jLong* yShapeInfo, const Nd4jLong* zShapeInfo) {
auto xRank = shape::rank(xShapeInfo);
auto yRank = shape::rank(yShapeInfo);
auto zRank = shape::rank(zShapeInfo);
auto xOrder = shape::order(xShapeInfo);
auto yOrder = shape::order(yShapeInfo);
auto zOrder = shape::order(zShapeInfo);
auto xEws = shape::elementWiseStride(xShapeInfo);
auto yEws = shape::elementWiseStride(yShapeInfo);
auto zEws = shape::elementWiseStride(zShapeInfo);
bool bNDLoopsRanks = (xRank == zRank && yRank <= xRank && yRank >= 2);
int countUnityDimsInY = 0, countUnityDimsInX = 0;
for (int i = 0; i < xRank; i++) {
if (i < yRank)
countUnityDimsInY += (1 == shape::sizeAt(yShapeInfo, i)) ? 1 : 0;
countUnityDimsInX += (1 == shape::sizeAt(xShapeInfo, i)) ? 1 : 0;
}
bool bNotCommonVectorCase = (countUnityDimsInY != yRank - 1) && (countUnityDimsInX != xRank - 1);
if (3 == xRank && bNDLoopsRanks && bNotCommonVectorCase)
return nd4j::LoopKind::BROADCAST_3D;
if (4 == xRank && bNDLoopsRanks && bNotCommonVectorCase)
return nd4j::LoopKind::BROADCAST_4D;
if (5 == xRank && bNDLoopsRanks && bNotCommonVectorCase)
return nd4j::LoopKind::BROADCAST_5D;
if (xRank == yRank && xRank == zRank && xOrder == 'c' && yOrder == 'c' && zOrder == 'c' && xEws == 1 && yEws == 1 && zEws == 1 && xRank >= 2) {
// we validate that the shapes are equal up to, but not including, the last dim
for (int e = 0; e < xRank - 1; e++) {
if (xShapeInfo[e+1] != yShapeInfo[e+1])
return COMMON;
}
// now, if one of the shapes has 1 as last dim
auto detect = xShapeInfo[xRank] == 1 ? -1 : (yShapeInfo[xRank] == 1) ? 1 : 0;
if (detect == 1)
return nd4j::LoopKind::BROADCAST_SCALAR_Y;
else if (detect == -1)
return nd4j::LoopKind::BROADCAST_SCALAR_X;
}
return nd4j::LoopKind::COMMON;
}
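A hedged, in-library sketch of the new deduction, with shapeInfo buffers written by hand in the {rank, dims..., strides..., type word, ews, order} layout used elsewhere in this diff (the type word values are illustrative):

Nd4jLong x[] = {2, 4,5, 5,1, 8192,1,99};     // 4x5, c order, ews 1
Nd4jLong y[] = {2, 4,1, 1,1, 8192,1,99};     // 4x1 -> last dim is 1
Nd4jLong z[] = {2, 4,5, 5,1, 8192,1,99};
auto kind = nd4j::LoopKind::deduceKindOfLoopBroadcast(x, y, z);       // expected: BROADCAST_SCALAR_Y

Nd4jLong s3[] = {3, 2,3,4, 12,4,1, 8192,1,99};                        // rank-3, identical shapes
auto kind3 = nd4j::LoopKind::deduceKindOfLoopBroadcast(s3, s3, s3);   // expected: BROADCAST_3D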
//////////////////////////////////////////////////////////////////////////////
LoopKind::Kind LoopKind::deduceKindOfLoopXYZ(const Nd4jLong* xShapeInfo, const Nd4jLong* yShapeInfo, const Nd4jLong* zShapeInfo) {

View File

@ -30,15 +30,15 @@
namespace nd4j {
class ND4J_EXPORT ShapeBuilders {
public:
public:
static Nd4jLong* createScalarShapeInfo(nd4j::DataType dataType, nd4j::memory::Workspace* workspace = nullptr);
static Nd4jLong* createVectorShapeInfo(const nd4j::DataType dataType, const Nd4jLong length, nd4j::memory::Workspace* workspace = nullptr);
/**
* create shapeInfo for given order basing on shape stored in shapeOnly vector
* memory allocation for shapeInfo is on given workspace
*/
*/
static Nd4jLong* createShapeInfo(const nd4j::DataType dataType, const char order, int rank, const Nd4jLong* shapeOnly, memory::Workspace* workspace = nullptr);
static Nd4jLong* createShapeInfo(const nd4j::DataType dataType, const char order, const std::vector<Nd4jLong>& shapeOnly, memory::Workspace* workspace = nullptr);
static Nd4jLong* createShapeInfo(const nd4j::DataType dataType, const char order, const std::initializer_list<Nd4jLong>& shapeOnly, memory::Workspace* workspace = nullptr);
@ -51,6 +51,13 @@ namespace nd4j {
static Nd4jLong* copyShapeInfoAndType(const Nd4jLong* inShapeInfo, const DataType dtype, const bool copyStrides, memory::Workspace* workspace = nullptr);
static Nd4jLong* copyShapeInfoAndType(const Nd4jLong* inShapeInfo, const Nd4jLong* shapeInfoToGetTypeFrom, const bool copyStrides, memory::Workspace* workspace = nullptr);
/**
* allocates memory for a new shapeInfo and copies all information from inShapeInfo into it, except the dimensions listed in dimsToExclude (unit dimensions) and their corresponding strides
* for example, if inShapeInfo is {5, 2,1,3,1,4, 12,12,4,4,1, 16384,1,99}, dimsToExclude = {1,3}, dimsSize = 2
* then outShapeInfo will contain {3, 2,3,4, 12,4,1, 16384,1,99}
*/
static Nd4jLong* copyShapeInfoWithoutUnites(const Nd4jLong* inShapeInfo, const int dimsSize, const int* dimsToExclude, memory::Workspace* workspace = nullptr);
static Nd4jLong* emptyShapeInfo(const nd4j::DataType dataType, memory::Workspace* workspace = nullptr);
static Nd4jLong* emptyShapeInfo(const nd4j::DataType dataType, const char order, const std::vector<Nd4jLong> &shape, memory::Workspace* workspace = nullptr);

View File

@ -50,11 +50,13 @@ namespace nd4j {
static std::vector<Nd4jLong> evalRepeatShape(int axis, const std::vector<int>& repeats, const NDArray& arr);
// evaluate shapeInfo of permuted array
static Nd4jLong* evalPermShapeInfo(const int* dimensions, const int rank, const NDArray& arr, nd4j::memory::Workspace* workspace);
// if setContigStrides = true, then set contiguous strides in output shapeInfo in accordance with arr order
static Nd4jLong* evalPermShapeInfo(const int* dimensions, const int rank, const NDArray& arr, nd4j::memory::Workspace* workspace, const bool setContigStrides = false);
static Nd4jLong* evalPermShapeInfo(const Nd4jLong* dimensions, const int rank, const NDArray& arr, nd4j::memory::Workspace* workspace);
// evaluate shapeInfo of transposed array
static Nd4jLong* evalTranspShapeInfo(const NDArray& arr, nd4j::memory::Workspace* workspace);
// if setContigStrides = true, then set contiguous strides in output shapeInfo in accordance with arr order
static Nd4jLong* evalTranspShapeInfo(const NDArray& arr, nd4j::memory::Workspace* workspace, const bool setContigStrides = false);
static bool copyVectorPart(std::vector<int>& target, std::vector<int>& source, int rank, int offset);
@ -97,6 +99,8 @@ namespace nd4j {
static std::string shapeAsString(const int rank, const Nd4jLong* shapeInfo);
static std::string strideAsString(const NDArray* array);
static std::string shapeInfoAsString(const Nd4jLong* shapeInfo);
static std::vector<Nd4jLong> shapeAsVector(const Nd4jLong* shapeInfo);
// evaluate shapeInfo for diagonal array which is made using input arr elements as diagonal
@ -176,6 +180,17 @@ namespace nd4j {
return (numStrings + 1) * sizeof(Nd4jLong);
}
/**
* This method selects strides based on the dimensions required for broadcasting
* @param inShapeInfo const pointer to the input (Y) shape info used for stride selection
* @param nRank rank of the input (X) being broadcast to
* @param dimsSize number of dimensions used for broadcasting
* @param dims const pointer to the dimensions used for broadcasting
* @param outStrides pointer to the output strides; must be pre-allocated and zero-initialized
*/
static void copyCertainStridesFromShapeInfo(const Nd4jLong* inShapeInfo, const int nRank, const int dimsSize, const int* dims, Nd4jLong* outStrides);
/*
* check whether arr1/arr2 is sub-array of arr2/arr1,
* this method do not evaluate what array is sub-array, it returns true if arr1 is sub-array of arr2 or arr2 is sub-array of arr1

View File

@ -68,7 +68,7 @@ namespace nd4j {
const Nd4jLong numOfSubArrs = ShapeUtils::getNumOfSubArrs(shapeInfo, dimsToExclude);
const int subArrRank = (rank == dimsToExclude.size() || descriptor.areUnitiesinShape()) ? rank : rank - dimsToExclude.size();
auto sPtr = new Nd4jLong[shape::shapeInfoLength(subArrRank)];
auto sPtr = new Nd4jLong[shape::shapeInfoLength(subArrRank)]; // shape of sub-arrays (same for all of them)
auto oPtr = new Nd4jLong[numOfSubArrs];
if (numOfSubArrs > 0)

View File

@ -49,7 +49,7 @@ void nd4j::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
case nd4j::LoopKind::EWS1: {
auto func = PRAGMA_THREADS_FOR {
for (auto i = start; i < stop; i += increment) {
for (auto i = start; i < stop; i++) {
auto tad = const_cast<X *>(x) + tadOffsets[i];
auto indexValue = OpType::startingIndexValue(tad);
@ -70,7 +70,7 @@ void nd4j::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
case nd4j::LoopKind::EWSNONZERO: {
auto func = PRAGMA_THREADS_FOR {
for (auto i = start; i < stop; i += increment) {
for (auto i = start; i < stop; i++) {
auto tad = const_cast<X *>(x) + tadOffsets[i];
auto indexValue = OpType::startingIndexValue(tad);
@ -91,7 +91,7 @@ void nd4j::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
case nd4j::LoopKind::RANK1: {
auto func = PRAGMA_THREADS_FOR {
for (auto i = start; i < stop; i += increment) {
for (auto i = start; i < stop; i++) {
auto tad = const_cast<X *>(x) + tadOffsets[i];
auto indexValue = OpType::startingIndexValue(tad);
@ -114,7 +114,7 @@ void nd4j::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
shape::updateStrides(2, tadShape, newStride, 'c');
auto func = PRAGMA_THREADS_FOR {
for (auto i = start; i < stop; i += increment) {
for (auto i = start; i < stop; i++) {
auto tad = const_cast<X *>(x) + tadOffsets[i];
auto indexValue = OpType::startingIndexValue(tad);
@ -141,7 +141,7 @@ void nd4j::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
shape::updateStrides(3, tadShape, newStride, 'c');
auto func = PRAGMA_THREADS_FOR {
for (auto i = start; i < stop; i += increment) {
for (auto i = start; i < stop; i++) {
auto tad = const_cast<X *>(x) + tadOffsets[i];
auto indexValue = OpType::startingIndexValue(tad);
@ -170,7 +170,7 @@ void nd4j::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
shape::updateStrides(4, tadShape, newStride, 'c');
auto func = PRAGMA_THREADS_FOR {
for (auto i = start; i < stop; i += increment) {
for (auto i = start; i < stop; i++) {
auto tad = const_cast<X *>(x) + tadOffsets[i];
auto indexValue = OpType::startingIndexValue(tad);
@ -201,7 +201,7 @@ void nd4j::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
shape::updateStrides(5, tadShape, newStride, 'c');
auto func = PRAGMA_THREADS_FOR {
for (auto i = start; i < stop; i += increment) {
for (auto i = start; i < stop; i++) {
auto tad = const_cast<X *>(x) + tadOffsets[i];
auto indexValue = OpType::startingIndexValue(tad);
@ -234,7 +234,7 @@ void nd4j::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo<uint>(zShapeInfo, castZShapeInfo);
auto func = PRAGMA_THREADS_FOR {
for (auto i = start; i < stop; i += increment) {
for (auto i = start; i < stop; i++) {
auto tad = const_cast<X *>(x) + tadOffsets[i];
auto indexValue = OpType::startingIndexValue(tad);
@ -258,7 +258,7 @@ void nd4j::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
const bool canCastTad = nd4j::DataTypeUtils::castShapeInfo<uint>(tadShapeInfo, castTadShapeInfo);
auto func = PRAGMA_THREADS_FOR {
for (auto i = start; i < stop; i += increment) {
for (auto i = start; i < stop; i++) {
auto tad = const_cast<X *>(x) + tadOffsets[i];
auto indexValue = OpType::startingIndexValue(tad);
@ -284,7 +284,7 @@ void nd4j::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo<uint>(zShapeInfo, castZShapeInfo);
auto func = PRAGMA_THREADS_FOR {
for (auto i = start; i < stop; i += increment) {
for (auto i = start; i < stop; i++) {
auto tad = const_cast<X *>(x) + tadOffsets[i];
auto indexValue = OpType::startingIndexValue(tad);

View File

@ -43,23 +43,30 @@ nd4j::NDArray* nd4j::MmulHelper::tensorDot(const nd4j::NDArray* a, const nd4j::N
auto outShape = ShapeUtils::evalShapeForTensorDot(a, b, axes_0, axes_1, permutAt, permutBt, shapeAt, shapeBt);
NDArray aPR = a->permute(permutAt);
NDArray bPR = b->permute(permutBt);
// check whether permutation is necessary
const NDArray* aP = permutAt.empty() ? a : new NDArray(a->permute(permutAt));
const NDArray* bP = permutBt.empty() ? b : new NDArray(b->permute(permutBt));
// check whether reshape is necessary
if(!aPR.isSameShape(shapeAt))
aPR.reshapei( shapeAt);
if(!bPR.isSameShape(shapeBt))
bPR.reshapei( shapeBt);
const NDArray* aPR = aP->isSameShape(shapeAt) ? aP : new NDArray(aP->reshape(aP->ordering(), shapeAt));
const NDArray* bPR = bP->isSameShape(shapeBt) ? bP : new NDArray(bP->reshape(bP->ordering(), shapeBt));
NDArray* c = mmul(&aPR, &bPR, nullptr, 1.0, 0.0);
NDArray* c = mmul(aPR, bPR, nullptr, 1.0, 0.0);
c->reshapei(outShape);
if(aP != aPR)
delete aPR;
if(bP != bPR)
delete bPR;
if(a != aP)
delete aP;
if(b != bP)
delete bP;
return c;
}
//////////////////////////////////////////////////////////////////////////
void nd4j::MmulHelper::tensorDot(const nd4j::NDArray* a, const nd4j::NDArray* b, nd4j::NDArray* c, const std::vector<int>& axes_a, const std::vector<int>& axes_b, const std::vector<int>& permutForC) {
@ -67,32 +74,38 @@ void nd4j::MmulHelper::tensorDot(const nd4j::NDArray* a, const nd4j::NDArray* b,
std::vector<Nd4jLong> shapeAt, shapeBt;
ShapeUtils::evalShapeForTensorDot(a, b, axes_a, axes_b, permutAt, permutBt, shapeAt, shapeBt);
NDArray *cP(c), *cPR(c);
// check whether permutation is required
if(!permutForC.empty())
cP = new NDArray(c->permute(permutForC));
NDArray* cP = permutForC.empty() ? c : new NDArray(c->permute(permutForC));
auto aPR = a->permute(permutAt);
auto bPR = b->permute(permutBt);
// check whether permutation is necessary
const NDArray* aP = permutAt.empty() ? a : new NDArray(a->permute(permutAt));
const NDArray* bP = permutBt.empty() ? b : new NDArray(b->permute(permutBt));
// check whether reshape is necessary
if(!aPR.isSameShape(shapeAt))
aPR.reshapei(shapeAt);
if(!bPR.isSameShape(shapeBt))
bPR.reshapei(shapeBt);
const NDArray* aPR = aP->isSameShape(shapeAt) ? aP : new NDArray(aP->reshape(aP->ordering(), shapeAt));
const NDArray* bPR = bP->isSameShape(shapeBt) ? bP : new NDArray(bP->reshape(bP->ordering(), shapeBt));
if(!cP->isSameShape({aPR.sizeAt(0), bPR.sizeAt(1)}))
cPR = new NDArray(cP->reshape(cP->ordering(), {aPR.sizeAt(0), bPR.sizeAt(1)}));
std::vector<Nd4jLong> requiredCshape = {aPR->sizeAt(0), bPR->sizeAt(1)};
mmul(&aPR, &bPR, cPR, 1.0, 0.0);
NDArray* cPR = cP->isSameShape(requiredCshape) ? cP : new NDArray(cP->reshape(cP->ordering(), requiredCshape, false));
mmul(aPR, bPR, cPR, 1.0, 0.0);
if(cPR->getBuffer() != cP->getBuffer() || cPR->getSpecialBuffer() != cP->getSpecialBuffer() ) // this means both permute and reshape have been performed on c, cP always points to c->getBuffer()
cP->assign(cPR);
if(cPR != c)
if(aP != aPR)
delete aPR;
if(bP != bPR)
delete bPR;
if(a != aP)
delete aP;
if(b != bP)
delete bP;
if(cP != cPR)
delete cPR;
if(cP != c)
if(c != cP)
delete cP;
}
@ -129,7 +142,7 @@ void nd4j::MmulHelper::tensorDot(const NDArray* a, const NDArray* b, NDArray* c,
if(!whatToDoWithC.empty()) {
cArrs = std::vector<NDArray*>(whatToDoWithC.size()+1, c);
for(int i = 0; i < cArrs.size()-1; ++i)
cArrs[i+1] = (whatToDoWithC[i] == 'p') ? new NDArray(cArrs[i]->permute(modifC[i])) : new NDArray(cArrs[i]->reshape(c->ordering(), modifC[i])); // since we ignore first element in cArrs (that is cArrs[0]) then it is always equal to c
cArrs[i+1] = (whatToDoWithC[i] == 'p') ? new NDArray(cArrs[i]->permute(modifC[i])) : new NDArray(cArrs[i]->reshape(c->ordering(), modifC[i], false)); // since we ignore first element in cArrs (that is cArrs[0]) then it is always equal to c
}
mmul(aPR, bPR, cArrs[cArrs.size()-1], 1.0, 0.0);
@ -208,7 +221,7 @@ nd4j::NDArray* MmulHelper::mmul(const nd4j::NDArray* A, const nd4j::NDArray* B,
// vector x matrix, A{M} x B{M,N} = C{N} -> reduce to matrix x matrix A2{1,M} x B{M,N} = C2{1,N}, since there is no corresponding blas operation sgevm
if(isAVector && bRank == 2) {
NDArray* A2 = new NDArray(A->reshape(A->ordering(), {1, A->lengthOf()})); // A{M} -> A2{1,M}
NDArray* C2 = C ? new NDArray(C->reshape(C->ordering(), {1, C->lengthOf()})) : nullptr; // C{N} -> C2{1,N}
NDArray* C2 = C ? new NDArray(C->reshape(C->ordering(), {1, C->lengthOf()}, false)) : nullptr; // C{N} -> C2{1,N}
auto result = mmulMxM(A2, B, C2, alpha, beta, outOrder); // result{1,N}
delete A2;
delete C2;
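A hedged usage sketch of tensorDot with the reworked permute/reshape handling; `a` is assumed to be a {2,3,4} NDArray and `b` a {4,5} NDArray, and the parameter list is taken from the surrounding diff:

auto c = nd4j::MmulHelper::tensorDot(&a, &b, {2}, {0});   // contracts a's axis 2 with b's axis 0 -> c has shape {2, 3, 5}
delete c;                                                  // the returned array is owned by the caller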

View File

@ -139,5 +139,15 @@ namespace nd4j {
return ShapeBuilders::copyShapeInfoAndType(inShapeInfo, ArrayOptions::dataType(shapeInfoToGetTypeFrom), copyStrides, workspace);
}
////////////////////////////////////////////////////////////////////////////////
Nd4jLong* ShapeBuilders::copyShapeInfoWithoutUnites(const Nd4jLong* inShapeInfo, const int dimsSize, const int* dimsToExclude, memory::Workspace* workspace) {
Nd4jLong *outShapeInfo = nullptr;
ALLOCATE(outShapeInfo, workspace, shape::shapeInfoLength(inShapeInfo[0] - dimsSize), Nd4jLong);
shape::excludeUnitiesFromShapeInfo(inShapeInfo, dimsSize, dimsToExclude, outShapeInfo);
return outShapeInfo;
}
}
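The new copyShapeInfoWithoutUnites allocates a shapeInfo with the listed unit dimensions removed before delegating to shape::excludeUnitiesFromShapeInfo. A standalone sketch of the shape/stride half of that operation (plain C++ stand-in, not the libnd4j call; the real routine also carries the type/order fields of shapeInfo):

#include <cstdio>
#include <vector>

// Drop the listed unit dimensions from a (shape, stride) pair -- a minimal
// stand-in for what excludeUnitiesFromShapeInfo is expected to do.
static void dropUnities(const std::vector<long long>& shape,
                        const std::vector<long long>& stride,
                        const std::vector<int>& dimsToExclude,
                        std::vector<long long>& outShape,
                        std::vector<long long>& outStride) {
    for (size_t i = 0; i < shape.size(); ++i) {
        bool excluded = false;
        for (int d : dimsToExclude)
            if ((size_t)d == i && shape[i] == 1)
                excluded = true;
        if (!excluded) {
            outShape.push_back(shape[i]);
            outStride.push_back(stride[i]);
        }
    }
}

int main() {
    std::vector<long long> shape{2, 1, 3, 1}, stride{3, 3, 1, 1}, oS, oT;
    dropUnities(shape, stride, {1, 3}, oS, oT);
    for (size_t i = 0; i < oS.size(); ++i)
        std::printf("%lld(%lld) ", oS[i], oT[i]);   // prints: 2(3) 3(1)
    return 0;
}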

View File

@ -75,10 +75,23 @@ std::vector<Nd4jLong> ShapeUtils::evalShapeForTensorDot(const Nd4jLong* aShapeIn
permutBt = axesB;
permutBt.insert(permutBt.end(), list_B.begin(), list_B.end());
// if permut contains something like {0,1,2,..rank-1}, then there is no need to make permutation and we return empty vector in this case
uint i1, i2;
for(i1 = 0; i1 < aRank; ++i1)
if(permutAt[i1] != i1)
break;
if(i1 == aRank)
permutAt = {};
for(i2 = 0; i2 < bRank; ++i2)
if(permutBt[i2] != i2)
break;
if(i2 == bRank)
permutBt = {};
Nd4jLong n2 = 1;
for (int i = 0; i < axeAsize; i++)
n2 *= aShapeInfo[axesA[i] + 1];
shapeAt = {-1, n2};
shapeAt = {shape::length(aShapeInfo) / n2, n2};
std::vector<Nd4jLong> oldShapeA;
oldShapeA.resize(list_A.size());
@ -89,7 +102,7 @@ std::vector<Nd4jLong> ShapeUtils::evalShapeForTensorDot(const Nd4jLong* aShapeIn
Nd4jLong n3 = 1;
for (int i = 0; i < axeBsize; i++)
n3 *= bShapeInfo[axesB[i] + 1];
shapeBt = {n3, -1};
shapeBt = {n3, shape::length(bShapeInfo) / n3};
std::vector<Nd4jLong> oldShapeB;
oldShapeB.resize(list_B.size());
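Two things change in this part of evalShapeForTensorDot: identity permutations are now cleared so no permute is performed for them, and the target matrix shape is spelled out as {length/n2, n2} instead of the old {-1, n2} wildcard. A standalone illustration of that arithmetic (hypothetical shapes, standard C++ only):

#include <cstdio>
#include <vector>

int main() {
    std::vector<long long> aShape{2, 3, 4};   // hypothetical A
    std::vector<int>       axesA{2};          // contract over the last axis

    long long len = 1, n2 = 1;
    for (long long d : aShape) len *= d;
    for (int ax : axesA)       n2  *= aShape[ax];

    // new behaviour: both dimensions are spelled out
    std::printf("shapeAt = {%lld, %lld}\n", len / n2, n2);   // {6, 4}
    return 0;
}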
@ -300,32 +313,37 @@ std::vector<Nd4jLong> ShapeUtils::evalRepeatShape(int axis, const std::vector<in
return outShape;
}
//////////////////////////////////////////////////////////////////////////
// evaluate shapeInfo of permuted array
Nd4jLong* ShapeUtils::evalPermShapeInfo(const int* dimensions, const int rank, const NDArray& arr, nd4j::memory::Workspace* workspace) {
Nd4jLong* ShapeUtils::evalPermShapeInfo(const int* dimensions, const int rank, const NDArray& arr, nd4j::memory::Workspace* workspace, const bool setContigStrides) {
if (!arr.nonNull())
throw std::runtime_error("ShapeUtils::evalPermShapeInfo static method: wrong arguments in pn/termute method: either array is nullptr!");
if (!arr.nonNull())
throw std::runtime_error("ShapeUtils::evalPermShapeInfo static method: wrong arguments: array is nullptr!");
if (rank != arr.rankOf())
throw std::runtime_error("ShapeUtils::evalPermShapeInfo static method: wrong arguments in pn/termute method: rank is not suitable!");
if (rank != arr.rankOf())
throw std::runtime_error("ShapeUtils::evalPermShapeInfo static method: wrong arguments: rank is not suitable!");
auto shapeInfoLength = shape::shapeInfoLength(rank);
// allocate memory for new array - shapeInfo
auto shapeInfoLength = shape::shapeInfoLength(rank);
Nd4jLong *shapeInfoNew = nullptr;
ALLOCATE(shapeInfoNew, workspace, shapeInfoLength, Nd4jLong);
// copy arr _shapeInfo into new array
memcpy(shapeInfoNew, arr.getShapeInfo(), shape::shapeInfoByteLength(rank));
// perform buffer permutation
shape::doPermuteShapeInfo(shapeInfoNew, dimensions);
// allocate memory for new array - shapeInfo
Nd4jLong *shapeInfoNew = nullptr;
ALLOCATE(shapeInfoNew, workspace, shapeInfoLength, Nd4jLong);
ShapeDescriptor descriptor(shapeInfoNew);
RELEASE(shapeInfoNew, workspace);
return ConstantShapeHelper::getInstance()->bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
}
// copy arr _shapeInfo into new array
memcpy(shapeInfoNew, arr.getShapeInfo(), shape::shapeInfoByteLength(rank));
// perform buffer permutation
shape::doPermuteShapeInfo(shapeInfoNew, dimensions, arr.lengthOf());
if(setContigStrides)
shape::updateStrides(shapeInfoNew, arr.ordering());
ShapeDescriptor descriptor(shapeInfoNew);
RELEASE(shapeInfoNew, workspace);
return ConstantShapeHelper::getInstance()->bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
}
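The permuted shapeInfo is built by applying the same index mapping to both the shape and the stride arrays; the new setContigStrides flag instead recomputes contiguous strides for the array's ordering via shape::updateStrides. A standalone sketch of the default, stride-permuting path, assuming the NumPy-style convention where output axis i takes input axis perm[i]:

#include <cstdio>
#include <vector>

int main() {
    // Hypothetical 2x3x4 c-ordered array: strides are {12, 4, 1}.
    std::vector<long long> shape{2, 3, 4}, stride{12, 4, 1};
    std::vector<int> perm{2, 0, 1};      // output axis i <- input axis perm[i]

    std::vector<long long> pShape(3), pStride(3);
    for (int i = 0; i < 3; ++i) {
        pShape[i]  = shape[perm[i]];
        pStride[i] = stride[perm[i]];
    }
    // shape {4, 2, 3}, strides {1, 12, 4}: same buffer, re-labelled axes.
    for (int i = 0; i < 3; ++i)
        std::printf("%lld(%lld) ", pShape[i], pStride[i]);
    return 0;
}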
//////////////////////////////////////////////////////////////////////////
// evaluate shapeInfo of permuted array
@ -337,14 +355,14 @@ std::vector<Nd4jLong> ShapeUtils::evalRepeatShape(int axis, const std::vector<in
//////////////////////////////////////////////////////////////////////////
// evaluate shapeInfo of transposed array
Nd4jLong* ShapeUtils::evalTranspShapeInfo(const NDArray& arr, nd4j::memory::Workspace* workspace) {
Nd4jLong* ShapeUtils::evalTranspShapeInfo(const NDArray& arr, nd4j::memory::Workspace* workspace, const bool setContigStrides) {
int rank = arr.rankOf();
std::vector<int> dimensions(rank);
for (int i = 0; i < rank; ++i)
dimensions[i] = rank - 1 - i;
return evalPermShapeInfo(dimensions.data(), dimensions.size(), arr, workspace);
return evalPermShapeInfo(dimensions.data(), dimensions.size(), arr, workspace, setContigStrides);
}
//////////////////////////////////////////////////////////////////////////
@ -653,6 +671,26 @@ Nd4jLong* ShapeUtils::evalTileShapeInfo(const NDArray& arr, const std::vector<Nd
return result;
}
std::string ShapeUtils::shapeInfoAsString(const Nd4jLong* shapeInfo) {
if(!shapeInfo)
throw std::runtime_error("ShapeUtils::shapeAsString method: input shapeInfo must not be nullptr !");
std::string result;
int len = shape::shapeInfoLength(shapeInfo[0]);
result.append("[");
for (int e = 0; e < len; e++) {
result += flatbuffers::NumToString(shapeInfo[e]);
if (e < len - 1)
result.append(", ");
}
result.append("]");
return result;
}
std::string ShapeUtils::shapeAsString(const int rank, const Nd4jLong* shapeInfo) {
if(!shapeInfo)
@ -1019,6 +1057,29 @@ std::vector<int> ShapeUtils::tadAxesForSimpleBroadcast(const NDArray& max, const
return numOfMinTads == 1 ? maxTadDims : std::vector<int>();
}
void ShapeUtils::copyCertainStridesFromShapeInfo(const Nd4jLong* inShapeInfo, const int nRank, const int dimsSize, const int* dims, Nd4jLong* outStrides) {
int yRank = shape::rank(inShapeInfo);
auto yOrigStride = shape::stride(inShapeInfo);
if (yRank == nRank) {
for (int i = 0; i < yRank; ++i) {
// x[2,3,4] * y[2,1,4] = z[2,3,4]
outStrides[i] = (1 == shape::sizeAt(inShapeInfo, i)) ? 0 : yOrigStride[i];
}
}
else {
auto dimEx = nd4j::ShapeUtils::evalDimsToExclude(nRank, dimsSize, dims);
for (int i = 0, it = 0; i < nRank; ++i) {
auto nCount = std::count(dimEx.cbegin(), dimEx.cend(), i);
outStrides[i] = (0 == nCount) ? yOrigStride[it++] : 0;
if (it == yRank)
break;
}
}
}
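The comment's own case is the quickest way to read the first branch: for x[2,3,4] * y[2,1,4] = z[2,3,4], any y axis of extent 1 gets stride 0, so the single y element is reused across that axis. A standalone sketch:

#include <cstdio>

int main() {
    // y has shape {2, 1, 4}; its c-order strides would be {4, 4, 1}.
    long long yShape[3]  = {2, 1, 4};
    long long yStride[3] = {4, 4, 1};
    long long outStrides[3];

    for (int i = 0; i < 3; ++i)
        outStrides[i] = (yShape[i] == 1) ? 0 : yStride[i];

    // {4, 0, 1}: stepping along axis 1 of z never moves inside y.
    std::printf("%lld %lld %lld\n", outStrides[0], outStrides[1], outStrides[2]);
    return 0;
}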
////////////////////////////////////////////////////////////////////////////////
/*
bool ShapeUtils::isSubArrayCase(const NDArray& arr1, const NDArray& arr2, std::vector<int>& sameDims) {

File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large.

View File

@ -40,6 +40,7 @@
#endif
#include <helpers/TAD.h>
#include <helpers/LoopKind.h>
#include "legacy_ops.h"
@ -122,6 +123,7 @@ namespace functions {
Nd4jLong *tadOffset,
Nd4jLong *tadShapeInfoZ,
Nd4jLong *tadOffsetZ,
nd4j::LoopKind::Kind loopKind,
uint64_t start,
uint64_t stop);
@ -149,6 +151,7 @@ namespace functions {
Nd4jLong *tadOffset,
Nd4jLong *tadShapeInfoZ,
Nd4jLong *tadOffsetZ,
nd4j::LoopKind::Kind loopKind,
uint64_t start,
uint64_t stop);

View File

@ -14,9 +14,9 @@
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author Yurii Shyrma (iuriish@yahoo.com)
//
//
// @author Yurii Shyrma (iuriish@yahoo.com)
//
#include <loops/TrueBroadcastHelper.h>
#include <ops/ops.h>
@ -24,226 +24,268 @@
using namespace simdOps;
namespace nd4j {
namespace helpers {
namespace nd4j {
namespace helpers {
////////////////////////////////////////////////////////////////////////
template <typename X, typename Y, typename Z>
template<typename OpType>
void TrueBroadcastHelper<X, Y, Z>::exec(const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
////////////////////////////////////////////////////////////////////////
template <typename X, typename Y, typename Z>
template<typename OpType>
void TrueBroadcastHelper<X, Y, Z>::exec(const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
const X* x = reinterpret_cast<X*>(xArr.getBuffer());
const Y* y = reinterpret_cast<Y*>(yArr.getBuffer());
Z* z = reinterpret_cast<Z*>(zArr.getBuffer());
const X* x = reinterpret_cast<X*>(xArr.getBuffer());
const Y* y = reinterpret_cast<Y*>(yArr.getBuffer());
Z* z = reinterpret_cast<Z*>(zArr.getBuffer());
const auto xShapeInfo = xArr.getShapeInfo();
const auto yShapeInfo = yArr.getShapeInfo();
const auto zShapeInfo = zArr.getShapeInfo();
const auto xShapeInfo = xArr.getShapeInfo();
const auto yShapeInfo = yArr.getShapeInfo();
const auto zShapeInfo = zArr.getShapeInfo();
const int xRank = xArr.rankOf();
const int yRank = yArr.rankOf();
const int zRank = zArr.rankOf();
const int xRank = xArr.rankOf();
const int yRank = yArr.rankOf();
const int zRank = zArr.rankOf();
bool bSpecialCase = (1 == xArr.ews() && 'c' == xArr.ordering() && 1 == yRank &&
1 == yArr.ews() && 'c' == yArr.ordering() &&
1 == zArr.ews() && 'c' == zArr.ordering());
bool bSpecialCase = (1 == xArr.ews() && 'c' == xArr.ordering() &&
1 == yArr.ews() && 'c' == yArr.ordering() &&
1 == zArr.ews() && 'c' == zArr.ordering());
if (bSpecialCase) {
auto yLen = (uint32_t)yArr.lengthOf();
auto func = PRAGMA_THREADS_FOR{
for (uint32_t i = start; i < stop; i++) {
auto rZ = z + (i * yLen);
auto v = x[i];
for (uint32_t j = 0; j < yLen; j++) {
rZ[j] = OpType::op(v, y[j]);
}
}
};
samediff::Threads::parallel_tad(func, 0, xArr.lengthOf());
return;
if (bSpecialCase && yArr.isColumnVector() && 1 == xArr.sizeAt(-1) ) {
auto yLen = (uint32_t)yArr.lengthOf();
auto func = PRAGMA_THREADS_FOR{
for (uint32_t i = start; i < stop; i++) {
auto rZ = z + (i * yLen);
auto v = x[i];
for (uint32_t j = 0; j < yLen; j++) {
rZ[j] = OpType::op(v, y[j]);
}
}
};
samediff::Threads::parallel_tad(func, 0, xArr.lengthOf());
return;
}
auto yShapeInt = yArr.getShapeAsVectorInt();
auto xShapeInt = xArr.getShapeAsVectorInt();
auto nCountY = std::count_if(yShapeInt.cbegin(), yShapeInt.cend(), [](int i) { return i == 1; });
auto nCountX = std::count_if(xShapeInt.cbegin(), xShapeInt.cend(), [](int i) { return i == 1; });
bool bSpecialCase2 = (xRank == zRank && yRank == zRank && 1 == xArr.sizeAt(-1) && 1 == yArr.sizeAt(-2) && 1 == nCountY && 1 == nCountX);
if (bSpecialCase && bSpecialCase2) {
int zDim1 = zArr.sizeAt(-2);
int zDim2 = zArr.sizeAt(-1);
int nLen = zArr.lengthOf() / yArr.sizeAt(-1);
auto func = PRAGMA_THREADS_FOR{
for (uint32_t total = start; total < stop; total++) {
uint32_t i = total / zDim1;
uint32_t j = total % zDim1;
uint32_t index = (i * zDim1) + j;
auto rZ = z + (index * zDim2);
auto rY = y + (i * zDim2);
auto rX = x[index];
for (uint32_t n = 0; n < zDim2; n++) {
rZ[n] = OpType::op(rX, rY[n]);
}
}
};
samediff::Threads::parallel_tad(func, 0, nLen, 1);
return;
}
const Nd4jLong zLen = zArr.lengthOf();
auto func = PRAGMA_THREADS_FOR{
std::vector<Nd4jLong> xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf());
for (auto i = start; i < stop; ++i) {
shape::index2coords(i, zShapeInfo, zCoords.data());
for (int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) {
if (ix >= 0) {
if (xShapeInfo[ix + 1] == zShapeInfo[iz + 1]) {
xCoords[ix--] = zCoords[iz];
}
else {
xCoords[ix--] = 0;
}
}
if (iy >= 0) {
if (yShapeInfo[iy + 1] == zShapeInfo[iz + 1]) {
yCoords[iy--] = zCoords[iz];
}
else {
yCoords[iy--] = 0;
}
}
}
const auto xOffset = shape::getOffset(xShapeInfo, xCoords.data());
const auto yOffset = shape::getOffset(yShapeInfo, yCoords.data());
const auto zOffset = shape::getOffset(zShapeInfo, zCoords.data());
z[zOffset] = OpType::op(x[xOffset], y[yOffset]);
}
};
samediff::Threads::parallel_for(func, 0, zLen);
}
template <typename X, typename Y, typename Z>
void TrueBroadcastHelper<X, Y, Z>::exec(const nd4j::broadcast::Ops opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
DISPATCH_BY_OPNUM_TTT(exec, PARAMS(xArr, yArr, zArr), BROADCAST_OPS);
}
////////////////////////////////////////////////////////////////////////
template <typename X, typename Z>
template<typename OpType>
void TrueBroadcastBoolHelper<X, Z>::exec(const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
const X* x = reinterpret_cast<X*>(xArr.getBuffer());
const X* y = reinterpret_cast<X*>(yArr.getBuffer());
Z* z = reinterpret_cast<Z*>(zArr.getBuffer());
const auto xShapeInfo = xArr.getShapeInfo();
const auto yShapeInfo = yArr.getShapeInfo();
const auto zShapeInfo = zArr.getShapeInfo();
const int xRank = xArr.rankOf();
const int yRank = yArr.rankOf();
const int zRank = zArr.rankOf();
const Nd4jLong zLen = zArr.lengthOf();
auto func = PRAGMA_THREADS_FOR{
std::vector<Nd4jLong> xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf());
for (auto i = start; i < stop; ++i) {
shape::index2coords(i, zShapeInfo, zCoords.data());
for (int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) {
if (ix >= 0) {
if (xShapeInfo[ix + 1] == zShapeInfo[iz + 1]) {
xCoords[ix--] = zCoords[iz];
}
else {
xCoords[ix--] = 0;
}
}
if (iy >= 0) {
if (yShapeInfo[iy + 1] == zShapeInfo[iz + 1]) {
yCoords[iy--] = zCoords[iz];
}
else {
yCoords[iy--] = 0;
}
}
}
const auto xOffset = shape::getOffset(xShapeInfo, xCoords.data());
const auto yOffset = shape::getOffset(yShapeInfo, yCoords.data());
const auto zOffset = shape::getOffset(zShapeInfo, zCoords.data());
z[zOffset] = OpType::op(x[xOffset], y[yOffset], nullptr);
}
};
samediff::Threads::parallel_for(func, 0, zLen);
}
template <typename X, typename Y>
void TrueBroadcastBoolHelper<X, Y>::exec(const nd4j::broadcast::BoolOps opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
DISPATCH_BY_OPNUM_TT(exec, PARAMS(xArr, yArr, zArr), BROADCAST_BOOL_OPS);
}
////////////////////////////////////////////////////////////////////////
template <typename X>
template<typename OpType>
void TrueBroadcastIntHelper<X>::exec(const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
const X* x = reinterpret_cast<X*>(xArr.getBuffer());
const X* y = reinterpret_cast<X*>(yArr.getBuffer());
X* z = reinterpret_cast<X*>(zArr.getBuffer());
const auto xShapeInfo = xArr.getShapeInfo();
const auto yShapeInfo = yArr.getShapeInfo();
const auto zShapeInfo = zArr.getShapeInfo();
const int xRank = xArr.rankOf();
const int yRank = yArr.rankOf();
const int zRank = zArr.rankOf();
const Nd4jLong zLen = zArr.lengthOf();
auto func = PRAGMA_THREADS_FOR{
std::vector<Nd4jLong> xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf());
for (auto i = start; i < stop; ++i) {
shape::index2coords(i, zShapeInfo, zCoords.data());
for (int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) {
if (ix >= 0) {
if (xShapeInfo[ix + 1] == zShapeInfo[iz + 1]) {
xCoords[ix--] = zCoords[iz];
}
else {
xCoords[ix--] = 0;
}
}
if (iy >= 0) {
if (yShapeInfo[iy + 1] == zShapeInfo[iz + 1]) {
yCoords[iy--] = zCoords[iz];
}
else {
yCoords[iy--] = 0;
}
}
}
const auto xOffset = shape::getOffset(xShapeInfo, xCoords.data());
const auto yOffset = shape::getOffset(yShapeInfo, yCoords.data());
const auto zOffset = shape::getOffset(zShapeInfo, zCoords.data());
z[zOffset] = OpType::op(x[xOffset], y[yOffset]);
}
};
samediff::Threads::parallel_for(func, 0, zLen);
}
template <typename X>
void TrueBroadcastIntHelper<X>::exec(const nd4j::broadcast::IntOps opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
DISPATCH_BY_OPNUM_T(exec, PARAMS(xArr, yArr, zArr), BROADCAST_INT_OPS);
}
/*
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_0);
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_1);
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_2);
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_3);
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_4);
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_5);
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_6);
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_7);
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_8);
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_9);
BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastBoolHelper, , LIBND4J_TYPES, BOOL_TYPES);
BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastIntHelper, , INTEGER_TYPES);
*/
}
const Nd4jLong zLen = zArr.lengthOf();
auto func = PRAGMA_THREADS_FOR {
std::vector<Nd4jLong> xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf());
for (auto i = start; i < stop; ++i) {
shape::index2coords(i, zShapeInfo, zCoords.data());
for (int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) {
if (ix >= 0) {
if (xShapeInfo[ix + 1] == zShapeInfo[iz + 1]) {
xCoords[ix--] = zCoords[iz];
} else {
xCoords[ix--] = 0;
}
}
if (iy >= 0) {
if (yShapeInfo[iy + 1] == zShapeInfo[iz + 1]) {
yCoords[iy--] = zCoords[iz];
} else {
yCoords[iy--] = 0;
}
}
}
const auto xOffset = shape::getOffset(xShapeInfo, xCoords.data());
const auto yOffset = shape::getOffset(yShapeInfo, yCoords.data());
const auto zOffset = shape::getOffset(zShapeInfo, zCoords.data());
z[zOffset] = OpType::op(x[xOffset], y[yOffset]);
}
};
samediff::Threads::parallel_for(func, 0, zLen);
}
template <typename X, typename Y, typename Z>
void TrueBroadcastHelper<X, Y, Z>::exec(const nd4j::broadcast::Ops opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
DISPATCH_BY_OPNUM_TTT(exec, PARAMS(xArr, yArr, zArr), BROADCAST_OPS);
}
////////////////////////////////////////////////////////////////////////
template <typename X, typename Z>
template<typename OpType>
void TrueBroadcastBoolHelper<X, Z>::exec(const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
const X* x = reinterpret_cast<X*>(xArr.getBuffer());
const X* y = reinterpret_cast<X*>(yArr.getBuffer());
Z* z = reinterpret_cast<Z*>(zArr.getBuffer());
const auto xShapeInfo = xArr.getShapeInfo();
const auto yShapeInfo = yArr.getShapeInfo();
const auto zShapeInfo = zArr.getShapeInfo();
const int xRank = xArr.rankOf();
const int yRank = yArr.rankOf();
const int zRank = zArr.rankOf();
const Nd4jLong zLen = zArr.lengthOf();
auto func = PRAGMA_THREADS_FOR {
std::vector<Nd4jLong> xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf());
for (auto i = start; i < stop; ++i) {
shape::index2coords(i, zShapeInfo, zCoords.data());
for (int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) {
if (ix >= 0) {
if (xShapeInfo[ix + 1] == zShapeInfo[iz + 1]) {
xCoords[ix--] = zCoords[iz];
} else {
xCoords[ix--] = 0;
}
}
if (iy >= 0) {
if (yShapeInfo[iy + 1] == zShapeInfo[iz + 1]) {
yCoords[iy--] = zCoords[iz];
} else {
yCoords[iy--] = 0;
}
}
}
const auto xOffset = shape::getOffset(xShapeInfo, xCoords.data());
const auto yOffset = shape::getOffset(yShapeInfo, yCoords.data());
const auto zOffset = shape::getOffset(zShapeInfo, zCoords.data());
z[zOffset] = OpType::op(x[xOffset], y[yOffset], nullptr);
}
};
samediff::Threads::parallel_for(func, 0, zLen);
}
template <typename X, typename Y>
void TrueBroadcastBoolHelper<X, Y>::exec(const nd4j::broadcast::BoolOps opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
DISPATCH_BY_OPNUM_TT(exec, PARAMS(xArr, yArr, zArr), BROADCAST_BOOL_OPS);
}
////////////////////////////////////////////////////////////////////////
template <typename X>
template<typename OpType>
void TrueBroadcastIntHelper<X>::exec(const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
const X* x = reinterpret_cast<X*>(xArr.getBuffer());
const X* y = reinterpret_cast<X*>(yArr.getBuffer());
X* z = reinterpret_cast<X*>(zArr.getBuffer());
const auto xShapeInfo = xArr.getShapeInfo();
const auto yShapeInfo = yArr.getShapeInfo();
const auto zShapeInfo = zArr.getShapeInfo();
const int xRank = xArr.rankOf();
const int yRank = yArr.rankOf();
const int zRank = zArr.rankOf();
const Nd4jLong zLen = zArr.lengthOf();
auto func = PRAGMA_THREADS_FOR {
std::vector<Nd4jLong> xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf());
for (auto i = start; i < stop; ++i) {
shape::index2coords(i, zShapeInfo, zCoords.data());
for (int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) {
if (ix >= 0) {
if (xShapeInfo[ix + 1] == zShapeInfo[iz + 1]) {
xCoords[ix--] = zCoords[iz];
} else {
xCoords[ix--] = 0;
}
}
if (iy >= 0) {
if (yShapeInfo[iy + 1] == zShapeInfo[iz + 1]) {
yCoords[iy--] = zCoords[iz];
} else {
yCoords[iy--] = 0;
}
}
}
const auto xOffset = shape::getOffset(xShapeInfo, xCoords.data());
const auto yOffset = shape::getOffset(yShapeInfo, yCoords.data());
const auto zOffset = shape::getOffset(zShapeInfo, zCoords.data());
z[zOffset] = OpType::op(x[xOffset], y[yOffset]);
}
};
samediff::Threads::parallel_for(func, 0, zLen);
}
template <typename X>
void TrueBroadcastIntHelper<X>::exec(const nd4j::broadcast::IntOps opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
DISPATCH_BY_OPNUM_T(exec, PARAMS(xArr, yArr, zArr), BROADCAST_INT_OPS);
}
/*
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_0);
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_1);
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_2);
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_3);
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_4);
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_5);
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_6);
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_7);
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_8);
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_9);
BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastBoolHelper, , LIBND4J_TYPES, BOOL_TYPES);
BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastIntHelper, , INTEGER_TYPES);
*/
}
}
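All three helpers share the same inner mapping: walk z's coordinates from the last axis backwards and, for each operand, keep the coordinate where that operand's extent matches z and clamp it to 0 where the operand is broadcast (the real code reads the extents out of shapeInfo, i.e. xShapeInfo[ix + 1]). A standalone sketch of the mapping for one output element, with illustrative shapes:

#include <cstdio>
#include <vector>

int main() {
    // z[2][3], x[2][1], y[3] -- classic broadcast.
    std::vector<long long> zShape{2, 3}, xShape{2, 1}, yShape{3};
    std::vector<long long> zCoord{1, 2};                 // some output element
    std::vector<long long> xCoord(xShape.size()), yCoord(yShape.size());

    for (int ix = (int)xShape.size() - 1, iy = (int)yShape.size() - 1,
             iz = (int)zShape.size() - 1; iz >= 0; --iz) {
        if (ix >= 0) { xCoord[ix] = (xShape[ix] == zShape[iz]) ? zCoord[iz] : 0; --ix; }
        if (iy >= 0) { yCoord[iy] = (yShape[iy] == zShape[iz]) ? zCoord[iz] : 0; --iy; }
    }
    // z[1][2] reads x[1][0] and y[2]
    std::printf("x[%lld][%lld]  y[%lld]\n", xCoord[0], xCoord[1], yCoord[0]);
    return 0;
}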

View File

@ -25,6 +25,7 @@
#include <LoopKind.h>
#include <helpers/ConstantTadHelper.h>
#include <execution/Threads.h>
#include <helpers/ShapeUtils.h>
using namespace simdOps;
@ -75,6 +76,7 @@ namespace functions {
Nd4jLong *xTadOffset,
Nd4jLong *zTadShapeInfo,
Nd4jLong *zTadOffset,
nd4j::LoopKind::Kind loopKind,
uint64_t start,
uint64_t stop) {
DISPATCH_BY_OPNUM_TTT(exec, PARAMS(x,
@ -88,7 +90,7 @@ namespace functions {
xTadShapeInfo,
xTadOffset,
zTadShapeInfo,
zTadOffset, start, stop), BROADCAST_OPS);
zTadOffset, loopKind, start, stop), BROADCAST_OPS);
}
template <typename X, typename Y, typename Z>
@ -105,6 +107,7 @@ namespace functions {
Nd4jLong *xTadOffset,
Nd4jLong *zTadShapeInfo,
Nd4jLong *zTadOffset,
nd4j::LoopKind::Kind loopKind,
uint64_t start,
uint64_t stop) {
@ -142,7 +145,14 @@ namespace functions {
auto yEws = shape::elementWiseStride(yShapeInfo);
auto zEws = shape::elementWiseStride(zTadShapeInfo);
const nd4j::LoopKind::Kind kindOfLoop = nd4j::LoopKind::deduceKindOfLoopXYZ(xTadShapeShapeInfo, yShapeInfo, zTadShapeInfo);
const nd4j::LoopKind::Kind kindOfLoop =
(loopKind == nd4j::LoopKind::BROADCAST_SCALAR_X ||
loopKind == nd4j::LoopKind::BROADCAST_SCALAR_Y ||
loopKind == nd4j::LoopKind::BROADCAST_3D ||
loopKind == nd4j::LoopKind::BROADCAST_4D ||
loopKind == nd4j::LoopKind::BROADCAST_5D)
? loopKind : nd4j::LoopKind::deduceKindOfLoopXYZ(xTadShapeShapeInfo, yShapeInfo, zTadShapeInfo);
if (kindOfLoop == nd4j::LoopKind::EWS1) {
for (auto i = start; i < stop; i++) {
@ -163,6 +173,131 @@ namespace functions {
for (unsigned int f = 0; f < tadLength; f++)
oZ[f * zEws] = OpType::op(oX[f * xEws], y[f * yEws]);
}
} else if(kindOfLoop == nd4j::LoopKind::BROADCAST_SCALAR_X){
// this loop effectively turns broadcast into series of scalar ops
auto loopLength = yShapeInfo[shape::rank(yShapeInfo)];
for (auto i = start; i < stop; i++) {
auto oY = y + (i * loopLength);
auto oZ = z + (i * loopLength);
const auto oX = x[i];
PRAGMA_OMP_SIMD
for (unsigned int f = 0; f < loopLength; f++)
oZ[f] = OpType::op(oX, oY[f]);
}
} else if(kindOfLoop == nd4j::LoopKind::BROADCAST_SCALAR_Y){
// this loop effectively turns broadcast into series of scalar ops
auto loopLength = xShapeInfo[shape::rank(xShapeInfo)];
for (auto i = start; i < stop; i++) {
auto oX = x + (i * loopLength);
auto oZ = z + (i * loopLength);
const auto oY = y[i];
PRAGMA_OMP_SIMD
for (unsigned int f = 0; f < loopLength; f++)
oZ[f] = OpType::op(oX[f], oY);
}
}
else if (kindOfLoop == nd4j::LoopKind::BROADCAST_3D) {
int xRank = shape::rank(xShapeInfo);
int yRank = shape::rank(yShapeInfo);
auto xStrides = shape::stride(xShapeInfo);
auto zStrides = shape::stride(zShapeInfo);
Nd4jLong yStrides[3] = { 0,0,0 };
nd4j::ShapeUtils::copyCertainStridesFromShapeInfo(yShapeInfo, xRank, dimensionLength, dimension, yStrides);
uint32_t nSize1 = shape::sizeAt(zShapeInfo, 1);
uint32_t nSize2 = shape::sizeAt(zShapeInfo, 2);
for (uint32_t index0 = start; index0 < stop; index0++) {
PRAGMA_OMP_SIMD
for (uint32_t index1 = 0; index1 < nSize1; index1++) {
for (uint32_t index2 = 0; index2 < nSize2; index2++) {
auto rX = x + (xStrides[0] * index0 + xStrides[1] * index1 + xStrides[2] * index2);
auto rY = y + (yStrides[0] * index0 + yStrides[1] * index1 + yStrides[2] * index2);
auto rZ = z + (zStrides[0] * index0 + zStrides[1] * index1 + zStrides[2] * index2);
*rZ = OpType::op(*rX, *rY);
}
}
}
}
else if (kindOfLoop == nd4j::LoopKind::BROADCAST_4D) {
int xRank = shape::rank(xShapeInfo);
int yRank = shape::rank(yShapeInfo);
auto xStrides = shape::stride(xShapeInfo);
auto zStrides = shape::stride(zShapeInfo);
Nd4jLong yStrides[4] = { 0,0,0,0 };
nd4j::ShapeUtils::copyCertainStridesFromShapeInfo(yShapeInfo, xRank, dimensionLength, dimension, yStrides);
uint32_t nSize1 = shape::sizeAt(zShapeInfo, 1);
uint32_t nSize2 = shape::sizeAt(zShapeInfo, 2);
uint32_t nSize3 = shape::sizeAt(zShapeInfo, 3);
for (uint32_t i = start; i < stop; i++) {
uint32_t index0 = i / nSize1;
uint32_t index1 = i % nSize1;
PRAGMA_OMP_SIMD
for (uint32_t index2 = 0; index2 < nSize2; index2++) {
for (uint32_t index3 = 0; index3 < nSize3; index3++) {
auto rX = x + (xStrides[0] * index0 + xStrides[1] * index1 + xStrides[2] * index2 + xStrides[3] * index3);
auto rY = y + (yStrides[0] * index0 + yStrides[1] * index1 + yStrides[2] * index2 + yStrides[3] * index3);
auto rZ = z + (zStrides[0] * index0 + zStrides[1] * index1 + zStrides[2] * index2 + zStrides[3] * index3);
*rZ = OpType::op(*rX, *rY);
}
}
}
}
else if (kindOfLoop == nd4j::LoopKind::BROADCAST_5D) {
int xRank = shape::rank(xShapeInfo);
int yRank = shape::rank(yShapeInfo);
auto xStrides = shape::stride(xShapeInfo);
auto zStrides = shape::stride(zShapeInfo);
Nd4jLong yStrides[5] = { 0,0,0,0,0 };
nd4j::ShapeUtils::copyCertainStridesFromShapeInfo(yShapeInfo, xRank, dimensionLength, dimension, yStrides);
uint32_t nSize1 = shape::sizeAt(zShapeInfo, 1);
uint32_t nSize2 = shape::sizeAt(zShapeInfo, 2);
uint32_t nSize3 = shape::sizeAt(zShapeInfo, 3);
uint32_t nSize4 = shape::sizeAt(zShapeInfo, 4);
for (uint32_t i = start; i < stop; i++) {
uint32_t index0 = i / nSize1;
uint32_t index1 = i % nSize1;
PRAGMA_OMP_SIMD
for (uint32_t index2 = 0; index2 < nSize2; index2++) {
for (uint32_t index3 = 0; index3 < nSize3; index3++) {
for (uint32_t index4 = 0; index4 < nSize4; index4++) {
auto rX = x + (xStrides[0] * index0 + xStrides[1] * index1 + xStrides[2] * index2 + xStrides[3] * index3 + xStrides[4] * index4);
auto rY = y + (yStrides[0] * index0 + yStrides[1] * index1 + yStrides[2] * index2 + yStrides[3] * index3 + yStrides[4] * index4);
auto rZ = z + (zStrides[0] * index0 + zStrides[1] * index1 + zStrides[2] * index2 + zStrides[3] * index3 + zStrides[4] * index4);
*rZ = OpType::op(*rX, *rY);
}
}
}
}
}
else if(shape::haveSameShapeAndStrides(xTadShapeShapeInfo, yShapeInfo) && shape::haveSameShapeAndStrides(xTadShapeShapeInfo, zTadShapeInfo)) {
uint tadShapeShapeInfoCast[MAX_RANK];
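The new BROADCAST_SCALAR_X/Y branches turn the TAD loop into plain rows of scalar-vector ops, and the 3D/4D/5D branches walk explicit strides, with the broadcast operand's strides zeroed by copyCertainStridesFromShapeInfo above. A standalone sketch of the SCALAR_X shape of the loop, using addition as a stand-in for OpType::op:

#include <cstdio>

int main() {
    // x broadcasts one scalar per row of y/z: x[2], y[2][3], z[2][3]
    const int rows = 2, len = 3;
    float x[rows]       = {10.f, 20.f};
    float y[rows * len] = {1.f, 2.f, 3.f, 4.f, 5.f, 6.f};
    float z[rows * len];

    for (int i = 0; i < rows; ++i) {            // the loop the helper parallelises
        const float v  = x[i];                  // "scalar X" for this row
        const float* oY = y + i * len;
        float*       oZ = z + i * len;
        for (int f = 0; f < len; ++f)
            oZ[f] = v + oY[f];                  // OpType::op(v, oY[f])
    }
    for (int i = 0; i < rows * len; ++i)
        std::printf("%.0f ", z[i]);             // 11 12 13 24 25 26
    return 0;
}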

View File

@ -73,7 +73,7 @@ Nd4jLong IndexReduce<X, Y>::execScalar(void *vx, Nd4jLong *xShapeInfo, void *vex
auto func = PRAGMA_THREADS_FOR {
intermediatery[thread_id] = OpType::startingIndexValue(x);
for (auto i = start; i < stop; i += increment) {
for (auto i = start; i < stop; i++) {
IndexValue<X> curr(x[i], i);
intermediatery[thread_id] = OpType::update(intermediatery[thread_id], curr, extraParams);
}
@ -88,7 +88,7 @@ Nd4jLong IndexReduce<X, Y>::execScalar(void *vx, Nd4jLong *xShapeInfo, void *vex
auto func = PRAGMA_THREADS_FOR {
intermediatery[thread_id] = OpType::startingIndexValue(x);
for (auto i = start; i < stop; i += increment) {
for (auto i = start; i < stop; i++) {
auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
IndexValue<X> curr(x[offset], i);
intermediatery[thread_id] = OpType::update(intermediatery[thread_id], curr, extraParams);

View File

@ -75,7 +75,7 @@ namespace functions {
auto func = PRAGMA_THREADS_FOR {
PRAGMA_OMP_SIMD
for (auto i = start; i < stop; i += increment) {
for (auto i = start; i < stop; i++) {
auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
z[offset] = OpClass::op(x[offset], y[offset], i, length, rng, extraArguments);
}
@ -93,7 +93,7 @@ namespace functions {
auto func = PRAGMA_THREADS_FOR {
PRAGMA_OMP_SIMD
for (uint64_t i = start; i < stop; i += increment) {
for (uint64_t i = start; i < stop; i++) {
auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
auto zOffset = shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ);
z[zOffset] = OpClass::op(x[offset], y[offset], i, length, rng, extraArguments);
@ -111,7 +111,7 @@ namespace functions {
auto func = PRAGMA_THREADS_FOR {
PRAGMA_OMP_SIMD
for (uint64_t i = start; i < stop; i += increment) {
for (uint64_t i = start; i < stop; i++) {
auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
auto yOffset = shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY);
z[offset] = OpClass::op(x[offset], y[yOffset], i, length, rng, extraArguments);
@ -129,7 +129,7 @@ namespace functions {
auto func = PRAGMA_THREADS_FOR {
PRAGMA_OMP_SIMD
for (uint64_t i = start; i < stop; i += increment) {
for (uint64_t i = start; i < stop; i++) {
auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
auto offset = shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY);
z[offset] = OpClass::op(x[xOffset], y[offset], i, length, rng, extraArguments);
@ -149,7 +149,7 @@ namespace functions {
auto func = PRAGMA_THREADS_FOR {
PRAGMA_OMP_SIMD
for (uint64_t i = start; i < stop; i += increment) {
for (uint64_t i = start; i < stop; i++) {
auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
auto yOffset = shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY);
auto zOffset = shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ);
@ -197,7 +197,7 @@ namespace functions {
else{
auto func = PRAGMA_THREADS_FOR {
PRAGMA_OMP_SIMD
for (uint64_t i = start; i < stop; i += increment) {
for (uint64_t i = start; i < stop; i++) {
auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
z[offset] = OpClass::op(x[offset], i, length, rng, extraArguments);
}
@ -213,7 +213,7 @@ namespace functions {
auto func = PRAGMA_THREADS_FOR {
PRAGMA_OMP_SIMD
for (uint64_t i = start; i < stop; i += increment) {
for (uint64_t i = start; i < stop; i++) {
auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
auto zOffset = shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ);
z[zOffset] = OpClass::op(x[xOffset], i, length, rng, extraArguments);
@ -255,7 +255,7 @@ namespace functions {
auto func = PRAGMA_THREADS_FOR {
PRAGMA_OMP_SIMD
for (uint64_t i = start; i < stop; i += increment) {
for (uint64_t i = start; i < stop; i++) {
auto offset = shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ);
z[offset] = OpClass::op(i, length, rng, extraArguments);
}

View File

@ -88,7 +88,7 @@ void Reduce3<X,Z>::execScalar(void *vx, Nd4jLong *xShapeInfo,
if (kindOfLoop == nd4j::LoopKind::EWS1) {
auto func = PRAGMA_THREADS_FOR {
for (auto i = start; i < stop; i += increment) {
for (auto i = start; i < stop; i++) {
intermediate[thread_id] = OpType::update(intermediate[thread_id], OpType::op(x[i], y[i], extraParamsLocal + 3 * thread_id), extraParamsLocal + 3 * thread_id);
}
};
@ -98,7 +98,7 @@ void Reduce3<X,Z>::execScalar(void *vx, Nd4jLong *xShapeInfo,
} else if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo)) {
auto func = PRAGMA_THREADS_FOR {
for (auto i = start; i < stop; i += increment) {
for (auto i = start; i < stop; i++) {
auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
intermediate[thread_id] = OpType::update(intermediate[thread_id], OpType::op(x[offset], y[offset], extraParamsLocal + 3 * thread_id), extraParamsLocal + 3 * thread_id);
}
@ -110,7 +110,7 @@ void Reduce3<X,Z>::execScalar(void *vx, Nd4jLong *xShapeInfo,
const bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);
auto func = PRAGMA_THREADS_FOR {
for (auto i = start; i < stop; i += increment) {
for (auto i = start; i < stop; i++) {
auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
auto yOffset = shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY);
intermediate[thread_id] = OpType::update(intermediate[thread_id], OpType::op(x[xOffset], y[yOffset], extraParamsLocal + 3 * thread_id), extraParamsLocal + 3 * thread_id);

View File

@ -158,7 +158,7 @@ namespace functions {
const bool canCast = tadEWS == 1 && tadOrder == 'c' ? false : nd4j::DataTypeUtils::castShapeInfo<uint>(tadShapeShapeInfo, tadShapeShapeInfoCast);
auto func = PRAGMA_THREADS_FOR {
for (auto r = start; r < stop; r += increment) {
for (auto r = start; r < stop; r++) {
auto tadOffsetForBlock = tadPack.primaryOffsets()[r];
auto tx = x + tadOffsetForBlock;

View File

@ -84,7 +84,7 @@ namespace functions {
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
int totalThreads = gridDim.x * blockDim.x;
if(xEws > 0 && zEws > 0 && xOrder == zOrder) {
if(xEws > 0 && zEws > 0 && xOrder == zOrder && xOrder == 'c') {
for (int i = tid; i < length; i += totalThreads)
z[i * zEws] = OpType::op(x[i * xEws], params);

View File

@ -89,7 +89,7 @@ namespace functions {
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
int totalThreads = gridDim.x * blockDim.x;
if(xEws > 0 && zEws > 0 && xOrder == zOrder) {
if(xEws > 0 && zEws > 0 && xOrder == zOrder && xOrder == 'c') {
for (int i = tid; i < length; i += totalThreads)
z[i * zEws] = OpType::op(x[i * xEws], params);

View File

@ -97,7 +97,7 @@ namespace functions {
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
int totalThreads = gridDim.x * blockDim.x;
if(xEws > 0 && zEws > 0 && xOrder == zOrder) {
if(xEws > 0 && zEws > 0 && xOrder == zOrder && xOrder == 'c') {
for (Nd4jLong i = tid; i < length; i += totalThreads)
z[i * zEws] = OpType::op(x[i * xEws], params);

View File

@ -87,7 +87,7 @@ namespace functions {
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
int totalThreads = gridDim.x * blockDim.x;
if(xEws > 0 && zEws > 0 && xOrder == zOrder) {
if(xEws > 0 && zEws > 0 && xOrder == zOrder && xOrder == 'c') {
for (int i = tid; i < length; i += totalThreads)
z[i * zEws] = OpType::op(x[i * xEws], params);

View File

@ -89,7 +89,7 @@ namespace functions {
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
int totalThreads = gridDim.x * blockDim.x;
if(xEws > 0 && zEws > 0 && xOrder == zOrder) {
if(xEws > 0 && zEws > 0 && xOrder == zOrder && xOrder == 'c') {
for (int i = tid; i < length; i += totalThreads)
z[i * zEws] = OpType::op(x[i * xEws], params);

View File

@ -81,7 +81,7 @@ namespace nd4j {
// now we actually apply quantization
auto func = PRAGMA_THREADS_FOR {
for (auto e = start; e < stop; e += increment) {
for (auto e = start; e < stop; e++) {
rz[e] = static_cast<char>(nd4j::math::nd4j_round<float, char>( 1.0f * static_cast<float>(x[e]) / nd4j::math::nd4j_max<float>(amax, amin) * max_byte));
}
};
@ -177,7 +177,7 @@ PRAGMA_OMP_ATOMIC_ARGS(write)
int flimit = limit + 4;
auto func = PRAGMA_THREADS_FOR {
for (auto e = start; e < stop; e += increment) {
for (auto e = start; e < stop; e++) {
int el = x[e];
int ael = nd4j::math::nd4j_abs<int>(el) - 1;
z[ael] += el > 0 ? static_cast<T>(threshold) : static_cast<T>(-threshold);
@ -202,7 +202,7 @@ PRAGMA_OMP_ATOMIC_ARGS(write)
auto z = reinterpret_cast<T *>(dz);
auto func = PRAGMA_THREADS_FOR {
for (auto i = start; i < stop; i += increment) {
for (auto i = start; i < stop; i++) {
z[i] = static_cast<T>(static_cast<float>(x[i]));
}
};
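Worked numbers for the byte-quantization line above, with an assumed max_byte of 127 (the constant is not shown in this hunk):

#include <cmath>
#include <cstdio>

int main() {
    const float maxAbs   = 2.0f;    // stand-in for nd4j_max(amax, amin)
    const float max_byte = 127.f;   // assumed quantization range, not taken from the diff
    const float x        = 0.5f;

    const char q = (char)std::lround(1.0f * x / maxAbs * max_byte);
    std::printf("%d\n", (int)q);    // 32: round(0.5 / 2.0 * 127) = round(31.75)
    return 0;
}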

View File

@ -147,6 +147,9 @@ namespace nd4j {
// returns TRUE if this op allows in-place execution
bool allowsInplace();
// this method allows you to enable/disable inplace call for a given op
void allowInplace(bool reallyAllow);
// this method returns opNum (applicable for legacy XYZ ops only)
int getOpNum();

View File

@ -27,12 +27,10 @@ namespace nd4j {
namespace ops {
OP_IMPL(identity, 1, 1, true) {
auto first = INPUT_VARIABLE(0);
auto z = this->getZ(block);
auto z = OUTPUT_VARIABLE(0);
// just for lulz
first->applyTransform(nd4j::transform::Identity, *z);
STORE_RESULT(*z);
if (!block.isInplace())
first->applyTransform(nd4j::transform::Identity, *z);
return Status::OK();
}
@ -60,8 +58,8 @@ namespace nd4j {
DECLARE_TYPES(identity_bp) {
getOpDescriptor()
->setAllowedInputTypes(0, DataType::ANY)
->setAllowedInputTypes(1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF})
->setAllowedOutputTypes(0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF});
->setAllowedInputTypes(1, {ALL_FLOATS})
->setAllowedOutputTypes(0, {ALL_FLOATS});
}
}
}

View File

@ -20,7 +20,7 @@
// @author Yurii Shyrma (iuriish@yahoo.com), fully rewritten
//
#include <op_boilerplate.h>
#include <op_boilerplate.h>
#if NOT_EXCLUDED(OP_matmul)
#include <ops/declarable/CustomOperations.h>
@ -29,142 +29,128 @@
namespace nd4j {
namespace ops {
CUSTOM_OP_IMPL(matmul, 2, 1, false, 0, -2) {
auto x = INPUT_VARIABLE(0);
auto y = INPUT_VARIABLE(1);
auto z = OUTPUT_VARIABLE(0);
//////////////////////////////////////////////////////////////////////
CUSTOM_OP_IMPL(matmul, 2, 1, false, 0, -2) {
const int iSize = (int) block.getIArguments()->size();
int transX = iSize > 0 ? INT_ARG(0) : 0;
int transY = iSize > 1 ? INT_ARG(1) : 0;
const int transZ = iSize > 2 ? INT_ARG(2) : 0;
auto x = INPUT_VARIABLE(0);
auto y = INPUT_VARIABLE(1);
auto z = OUTPUT_VARIABLE(0);
const int xRank = x->rankOf();
const int yRank = y->rankOf();
const int zRank = z->rankOf();
const int iSize = (int) block.getIArguments()->size();
int transX = iSize > 0 ? INT_ARG(0) : 0;
int transY = iSize > 1 ? INT_ARG(1) : 0;
const int transZ = iSize > 2 ? INT_ARG(2) : 0;
if (transZ) {
x = INPUT_VARIABLE(1);
y = INPUT_VARIABLE(0);
bool temp = transX;
transX = !transY;
transY = !temp;
}
const int xRank = x->rankOf();
const int yRank = y->rankOf();
const int zRank = z->rankOf();
const int xLastDim = transX ? -2 : -1;
const int yLastDim = transY ? -2 : -1;
const int xLastButOneDim = transX ? -1 : -2;
const int yLastButOneDim = transY ? -1 : -2;
if (transZ) {
x = INPUT_VARIABLE(1);
y = INPUT_VARIABLE(0);
bool temp = transX;
transX = !transY;
transY = !temp;
}
// ******* input validation ******* //
REQUIRE_TRUE(xRank > 0 && yRank > 0, 0,
"MATMUL OP: input arrays must have rank bigger than 0 (should not be scalars), but got instead: x rank = %i, y rank = %i !",
xRank, yRank);
const int xLastDim = transX ? -2 : -1;
const int yLastDim = transY ? -2 : -1;
const int xLastButOneDim = transX ? -1 : -2;
const int yLastButOneDim = transY ? -1 : -2;
if (xRank == 1 && yRank == 1) { // dot case, output is scalar (or vector with length = 1)
REQUIRE_TRUE(x->lengthOf() == y->lengthOf(), 0,
"MATMUL OP: since input arrays are vectors they must have the same length, but got x length = %i, y length = %i !",
x->lengthOf(), y->lengthOf());
} else if (xRank == 1 && yRank == 2) { // vector x matrix, i.e. [4] x [4,5] = [5], output is vector
REQUIRE_TRUE(x->lengthOf() == y->sizeAt(yLastButOneDim), 0,
"MATMUL OP: input arrays have inconsistent shapes for vector-matrix product: x %s, y %s !",
ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str());
} else if (xRank == 2 && yRank == 1) { // matrix x vector , i.e. [4,5] x [5] = [4], output is vector
REQUIRE_TRUE(x->sizeAt(xLastDim) == y->lengthOf(), 0,
"MATMUL OP: input arrays have inconsistent shapes for matrix-vector product: x %s, y %s !",
ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str());
} else {
REQUIRE_TRUE(xRank == yRank && yRank == zRank, 0,
"MATMUL OP: input and output arrays must have the same rank, but got instead: x rank = %i, y rank = %i, z rank = %i !",
xRank, yRank, zRank);
REQUIRE_TRUE(x->sizeAt(xLastDim) == y->sizeAt(yLastButOneDim) &&
x->sizeAt(xLastButOneDim) == z->sizeAt(-2) && y->sizeAt(yLastDim) == z->sizeAt(-1), 0,
"MATMUL OP: input/output arrays have inconsistent shapes for matrix product: x %s, y %s, z %s !",
ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str(),
ShapeUtils::shapeAsString(z).c_str());
// ******* input validation ******* //
REQUIRE_TRUE(xRank > 0 && yRank > 0, 0, "MATMUL OP: input arrays must have rank bigger than 0 (should not be scalars), but got instead: x rank = %i, y rank = %i !", xRank, yRank);
if (xRank > 2) // outer dims must be the same
for (int i = 0; i < xRank - 2; ++i)
REQUIRE_TRUE(x->sizeAt(i) == y->sizeAt(i) && y->sizeAt(i) == z->sizeAt(i), 0,
"MATMUL OP: input/output arrays have inconsistent shapes for matrix product: x %s, y %s, z %s !",
ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str(),
ShapeUtils::shapeAsString(z).c_str());
}
// ******* end of input validation ******* //
if (xRank == 1 && yRank == 1) { // dot case, output is scalar (or vector with length = 1)
REQUIRE_TRUE(x->lengthOf() == y->lengthOf(), 0, "MATMUL OP: since input arrays are vectors they must have the same length, but got x length = %i, y length = %i !", x->lengthOf(), y->lengthOf());
} else if (xRank == 1 && yRank == 2) { // vector x matrix, i.e. [4] x [4,5] = [5], output is vector
REQUIRE_TRUE(x->lengthOf() == y->sizeAt(yLastButOneDim), 0, "MATMUL OP: input arrays have inconsistent shapes for vector-matrix product: x %s, y %s !", ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str());
} else if (xRank == 2 && yRank == 1) { // matrix x vector , i.e. [4,5] x [5] = [4], output is vector
REQUIRE_TRUE(x->sizeAt(xLastDim) == y->lengthOf(), 0, "MATMUL OP: input arrays have inconsistent shapes for matrix-vector product: x %s, y %s !", ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str());
} else {
REQUIRE_TRUE(xRank == yRank && yRank == zRank, 0, "MATMUL OP: input and output arrays must have the same rank, but got instead: x rank = %i, y rank = %i, z rank = %i !", xRank, yRank, zRank);
REQUIRE_TRUE(x->sizeAt(xLastDim) == y->sizeAt(yLastButOneDim) && x->sizeAt(xLastButOneDim) == z->sizeAt(-2) && y->sizeAt(yLastDim) == z->sizeAt(-1), 0, "MATMUL OP: input/output arrays have inconsistent shapes for matrix product: x %s, y %s, z %s !", ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str(), ShapeUtils::shapeAsString(z).c_str());
MmulHelper::matmul(x, y, z, transX, transY);
if (xRank > 2) // outer dims must be the same
for (int i = 0; i < xRank - 2; ++i)
REQUIRE_TRUE(x->sizeAt(i) == y->sizeAt(i) && y->sizeAt(i) == z->sizeAt(i), 0, "MATMUL OP: input/output arrays have inconsistent shapes for matrix product: x %s, y %s, z %s !", ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str(), ShapeUtils::shapeAsString(z).c_str());
}
// ******* end of input validation ******* //
return Status::OK();
}
MmulHelper::matmul(x, y, z, transX, transY);
DECLARE_SYN(mMul, matmul);
return Status::OK();
}
DECLARE_SYN(mmul, matmul);
DECLARE_SYN(mMul, matmul);
DECLARE_SYN(gemm, matmul);
DECLARE_SYN(mmul, matmul);
DECLARE_SYN(gemv, matmul);
DECLARE_SYN(gemm, matmul);
DECLARE_SYN(dot, matmul);
DECLARE_SYN(gemv, matmul);
DECLARE_SYN(dot, matmul);
DECLARE_SHAPE_FN(matmul) {
//////////////////////////////////////////////////////////////////////
DECLARE_SHAPE_FN(matmul) {
auto xShapeInfo = inputShape->at(0);
auto yShapeInfo = inputShape->at(1);
auto xShapeInfo = inputShape->at(0);
auto yShapeInfo = inputShape->at(1);
const int iSize = (int) block.getIArguments()->size();
int transX = iSize > 0 ? INT_ARG(0) : 0;
int transY = iSize > 1 ? INT_ARG(1) : 0;
const int transZ = iSize > 2 ? INT_ARG(2) : 0;
const int iSize = (int) block.getIArguments()->size();
int transX = iSize > 0 ? INT_ARG(0) : 0;
int transY = iSize > 1 ? INT_ARG(1) : 0;
const int transZ = iSize > 2 ? INT_ARG(2) : 0;
REQUIRE_TRUE(xShapeInfo[0] > 0 && yShapeInfo[0] > 0, 0,
"MATMUL OP: input arrays must have rank bigger than 0 (should not be scalars), but got instead: x rank = %i, y rank = %i !",
xShapeInfo[0], yShapeInfo[0]);
REQUIRE_TRUE(xShapeInfo[0] > 0 && yShapeInfo[0] > 0, 0,
"MATMUL OP: input arrays must have rank bigger than 0 (should not be scalars), but got instead: x rank = %i, y rank = %i !",
xShapeInfo[0], yShapeInfo[0]);
if (transZ) {
xShapeInfo = inputShape->at(1);
yShapeInfo = inputShape->at(0);
bool temp = transX;
transX = !transY;
transY = !temp;
}
if (transZ) {
xShapeInfo = inputShape->at(1);
yShapeInfo = inputShape->at(0);
bool temp = transX;
transX = !transY;
transY = !temp;
}
auto zShapeOnly = ShapeUtils::evalShapeForMatmul(xShapeInfo, yShapeInfo, transX, transY);
auto zShapeOnly = ShapeUtils::evalShapeForMatmul(xShapeInfo, yShapeInfo, transX, transY);
auto dtypeX = ArrayOptions::dataType(xShapeInfo);
auto dtypeY = ArrayOptions::dataType(yShapeInfo);
auto dtypeX = ArrayOptions::dataType(xShapeInfo);
auto dtypeY = ArrayOptions::dataType(yShapeInfo);
auto xOrder = shape::order(xShapeInfo);
auto yOrder = shape::order(yShapeInfo);
auto zOrder = xOrder == 'c' && yOrder == 'c' ? 'c' : 'f';
auto xOrder = shape::order(xShapeInfo);
auto yOrder = shape::order(yShapeInfo);
auto zOrder = xOrder == 'c' && yOrder == 'c' ? 'c' : 'f';
// we just pick the higher data type out of X and Y
auto dtypeZ = dtypeX > dtypeY ? dtypeX : dtypeY;
// we just pick the higher data type out of X and Y
auto dtypeZ = dtypeX > dtypeY ? dtypeX : dtypeY;
auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(dtypeZ, zOrder, zShapeOnly);
return SHAPELIST(newShape);
}
auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(dtypeZ, zOrder, zShapeOnly);
return SHAPELIST(newShape);
}
DECLARE_TYPES(matmul) {
getOpDescriptor()
->setAllowedInputTypes(0, {ALL_FLOATS})
->setAllowedInputTypes(1, {ALL_FLOATS})
->setAllowedOutputTypes(0, {ALL_FLOATS});
}
//////////////////////////////////////////////////////////////////////
DECLARE_TYPES(matmul) {
getOpDescriptor()
->setAllowedInputTypes(0, {ALL_FLOATS})
->setAllowedInputTypes(1, {ALL_FLOATS})
->setAllowedOutputTypes(0, {ALL_FLOATS});
}
//////////////////////////////////////////////////////////////////////
CUSTOM_OP_IMPL(matmul_bp, 3, 2, false, 0, -2) {
auto x = INPUT_VARIABLE(0);
auto y = INPUT_VARIABLE(1);
auto eps = INPUT_VARIABLE(2);
auto dldx = OUTPUT_VARIABLE(0);
auto dldy = OUTPUT_VARIABLE(1);
CUSTOM_OP_IMPL(matmul_bp, 3, 2, false, 0, -2) {
auto x = INPUT_VARIABLE(0);
auto y = INPUT_VARIABLE(1);
auto eps = INPUT_VARIABLE(2);
auto dldx = OUTPUT_VARIABLE(0);
auto dldy = OUTPUT_VARIABLE(1);
const int iSize = (int) block.getIArguments()->size();
int transX = iSize > 0 ? INT_ARG(0) : 0;
int transY = iSize > 1 ? INT_ARG(1) : 0;
const int transZ = iSize > 2 ? INT_ARG(2) : 0;
const int iSize = (int) block.getIArguments()->size();
int transX = iSize > 0 ? INT_ARG(0) : 0;
int transY = iSize > 1 ? INT_ARG(1) : 0;
const int transZ = iSize > 2 ? INT_ARG(2) : 0;
/*
In: x=[a,b], y=[b,c]
@ -177,34 +163,35 @@ F F T [a,b] [b,c] [c,a] [c,a]
*/
nd4j::ops::matmul op;
op.execute({eps, y}, {dldx}, {}, {transZ, !transY, transX}, {});
op.execute({x, eps}, {dldy}, {}, {!transX, transZ, transY}, {});
nd4j::ops::matmul op;
op.execute({eps, y}, {dldx}, {}, {transZ, !transY, transX}, {});
op.execute({x, eps}, {dldy}, {}, {!transX, transZ, transY}, {});
return Status::OK();
}
return Status::OK();
}
//////////////////////////////////////////////////////////////////////
DECLARE_SHAPE_FN(matmul_bp) {
Nd4jLong *xShapeInfo;
Nd4jLong *yShapeInfo;
DECLARE_SHAPE_FN(matmul_bp) {
Nd4jLong *xShapeInfo;
Nd4jLong *yShapeInfo;
COPY_SHAPE(inputShape->at(0), xShapeInfo);
COPY_SHAPE(inputShape->at(1), yShapeInfo);
COPY_SHAPE(inputShape->at(0), xShapeInfo);
COPY_SHAPE(inputShape->at(1), yShapeInfo);
return SHAPELIST(CONSTANT(xShapeInfo), CONSTANT(yShapeInfo));
}
return SHAPELIST(CONSTANT(xShapeInfo), CONSTANT(yShapeInfo));
}
//////////////////////////////////////////////////////////////////////
DECLARE_TYPES(matmul_bp) {
getOpDescriptor()
->setAllowedInputTypes(0, {ALL_FLOATS})
->setAllowedInputTypes(1, {ALL_FLOATS})
->setAllowedInputTypes(2, {ALL_FLOATS})
->setAllowedOutputTypes(0, {ALL_FLOATS})
->setAllowedOutputTypes(1, {ALL_FLOATS});
}
DECLARE_TYPES(matmul_bp) {
getOpDescriptor()
->setAllowedInputTypes(0, {ALL_FLOATS})
->setAllowedInputTypes(1, {ALL_FLOATS})
->setAllowedInputTypes(2, {ALL_FLOATS})
->setAllowedOutputTypes(0, {ALL_FLOATS})
->setAllowedOutputTypes(1, {ALL_FLOATS});
}
}
}
}
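For readers checking the integer arguments of the two execute calls in matmul_bp: with all transpose flags off they implement the standard matrix-product gradients, and the flag juggling extends them to the transposed variants tabulated in the source comment. For Z = X * Y with X = [a,b] and Y = [b,c]:

    dL/dX = dL/dZ * Y^T    ( [a,c] * [c,b] = [a,b] )
    dL/dY = X^T * dL/dZ    ( [b,a] * [a,c] = [b,c] )

which is exactly matmul(eps, y) with {transA = 0, transB = 1} and matmul(x, eps) with {transA = 1, transB = 0} once transX = transY = transZ = 0.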

View File

@ -21,70 +21,174 @@
#include <op_boilerplate.h>
#if NOT_EXCLUDED(OP_tensormmul)
#include <numeric>
#include <helpers/ShapeUtils.h>
#include <ops/declarable/CustomOperations.h>
#include <MmulHelper.h>
namespace nd4j {
namespace ops {
CUSTOM_OP_IMPL(tensormmul, 2, 1, false, 0, -1) {
auto a = INPUT_VARIABLE(0);
auto b = INPUT_VARIABLE(1);
namespace ops {
auto c = OUTPUT_VARIABLE(0); //
////////////////////////////////////////////////////////////////////////
CUSTOM_OP_IMPL(tensormmul, 2, 1, false, 0, -1) {
REQUIRE_TRUE(a->dataType() == b->dataType(), 0, "tensormmul: A, B and C data types must be the same");
auto a = INPUT_VARIABLE(0);
auto b = INPUT_VARIABLE(1);
// building axes
int axe0_size = INT_ARG(0);
int axe1_size = INT_ARG(axe0_size+1);
std::vector<int> axes_0(axe0_size), axes_1(axe1_size);
for (int e = 0; e < axe0_size; e++)
axes_0[e] = (int) INT_ARG(e+1);
auto c = OUTPUT_VARIABLE(0);
for (int e = 0; e < axe1_size; e++)
axes_1[e] = (int) INT_ARG(e + axe0_size + 2);
REQUIRE_TRUE(a->dataType() == b->dataType(), 0, "tensormmul: A, B and C data types must be the same");
nd4j_verbose("axe0: %i; axe1: %i;\n", axes_0.size(), axes_1.size());
// building axes
int axe0_size = INT_ARG(0);
int axe1_size = INT_ARG(axe0_size+1);
std::vector<int> axes_0(axe0_size), axes_1(axe1_size);
for (int e = 0; e < axe0_size; e++)
axes_0[e] = (int)INT_ARG(e + 1);
MmulHelper::tensorDot(a, b, c, axes_0, axes_1);
return Status::OK();
}
DECLARE_SYN(tensordot, tensormmul);
for (int e = 0; e < axe1_size; e++)
axes_1[e] = (int)INT_ARG(e + axe0_size + 2);
nd4j_verbose("axe0: %i; axe1: %i;\n", axes_0.size(), axes_1.size());
DECLARE_SHAPE_FN(tensormmul) {
auto aShapeInfo = inputShape->at(0);
auto bShapeInfo = inputShape->at(1);
MmulHelper::tensorDot(a, b, c, axes_0, axes_1);
return Status::OK();
}
DECLARE_SYN(tensordot, tensormmul);
REQUIRE_TRUE(ArrayOptions::dataType(aShapeInfo) == ArrayOptions::dataType(bShapeInfo), 0, "tensormmul: A and B data types must be the same");
////////////////////////////////////////////////////////////////////////
DECLARE_SHAPE_FN(tensormmul) {
// building axes
int axe0_size = INT_ARG(0);
int axe1_size = INT_ARG(axe0_size+1);
std::vector<int> axes_0(axe0_size), axes_1(axe1_size);
for (int e = 0; e < axe0_size; e++)
axes_0[e] = (int) INT_ARG(e+1);
auto aShapeInfo = inputShape->at(0);
auto bShapeInfo = inputShape->at(1);
for (int e = 0; e < axe1_size; e++)
axes_1[e] = (int) INT_ARG(e + axe0_size + 2);
REQUIRE_TRUE(ArrayOptions::dataType(aShapeInfo) == ArrayOptions::dataType(bShapeInfo), 0, "tensormmul: A and B data types must be the same");
// evaluate shapes
std::vector<int> permutAt, permutBt;
std::vector<Nd4jLong> shapeAt, shapeBt;
auto outShape = nd4j::ShapeUtils::evalShapeForTensorDot(aShapeInfo, bShapeInfo, axes_0, axes_1, permutAt, permutBt, shapeAt, shapeBt);
// building axes
int axe0_size = INT_ARG(0);
int axe1_size = INT_ARG(axe0_size+1);
std::vector<int> axes_0(axe0_size), axes_1(axe1_size);
for (int e = 0; e < axe0_size; e++)
axes_0[e] = (int) INT_ARG(e+1);
return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(ArrayOptions::dataType(aShapeInfo), 'c', outShape)));
}
for (int e = 0; e < axe1_size; e++)
axes_1[e] = (int) INT_ARG(e + axe0_size + 2);
DECLARE_TYPES(tensormmul) {
getOpDescriptor()
->setAllowedInputTypes(0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF})
->setAllowedInputTypes(1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF})
->setAllowedOutputTypes(0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF});
}
// evaluate shapes
std::vector<int> permutAt, permutBt;
std::vector<Nd4jLong> shapeAt, shapeBt;
auto outShape = nd4j::ShapeUtils::evalShapeForTensorDot(aShapeInfo, bShapeInfo, axes_0, axes_1, permutAt, permutBt, shapeAt, shapeBt);
return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(ArrayOptions::dataType(aShapeInfo), 'c', outShape)));
}
////////////////////////////////////////////////////////////////////////
DECLARE_TYPES(tensormmul) {
getOpDescriptor()
->setAllowedInputTypes(0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF})
->setAllowedInputTypes(1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF})
->setAllowedOutputTypes(0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF});
}
////////////////////////////////////////////////////////////////////////
CUSTOM_OP_IMPL(tensormmul_bp, 3, 2, false, 0, -1) {
auto A = INPUT_VARIABLE(0);
auto B = INPUT_VARIABLE(1);
auto dLdC = INPUT_VARIABLE(2);
auto dLdA = OUTPUT_VARIABLE(0);
auto dLdB = OUTPUT_VARIABLE(1);
REQUIRE_TRUE( (A->dataType() == B->dataType() && (dLdC->dataType() == A->dataType())), 0, "tensormmul_bp: A, B and dLdC data types must be the same");
int axe0Size = INT_ARG(0);
int axe1Size = INT_ARG(axe0Size + 1);
auto Arank = A->rankOf();
auto Brank = B->rankOf();
auto dLdCrank = dLdC->rankOf();
REQUIRE_TRUE((Arank >= axe0Size), 0, "tensormmul_bp: A rank must be higher than or equal to the number of axes given for A");
REQUIRE_TRUE((Brank >= axe1Size), 0, "tensormmul_bp: B rank must be higher than or equal to the number of axes given for B");
// building axes
std::vector<int> axes0(axe0Size), axes1(axe1Size);
for (uint e = 0; e < axe0Size; e++)
axes0[e] = (int)INT_ARG(e + 1);
for (uint e = 0; e < axe1Size; e++)
axes1[e] = (int)INT_ARG(e + axe0Size + 2);
std::vector<int> permutAt, permutBt;
std::vector<Nd4jLong> shapeAt, shapeBt;
ShapeUtils::evalShapeForTensorDot(A, B, axes0, axes1, permutAt, permutBt, shapeAt, shapeBt);
// special case for scalar value
if (dLdC->isScalar()) {
dLdA->assign((*dLdC) * *B);
dLdB->assign((*dLdC) * *A);
return Status::OK();
}
std::vector<int> axesA = ShapeUtils::evalDimsToExclude(Arank, axes0);
std::vector<int> axesB = ShapeUtils::evalDimsToExclude(Brank, axes1);
// dLdC rank is always split in two halves: the leading axes correspond to A, the trailing axes to B
std::vector<int> axesAdLdC, axesBdLdC;
if (dLdCrank > 1) {
axesAdLdC.resize(dLdCrank / 2);
std::iota(axesAdLdC.begin(), axesAdLdC.end(), 0);
axesBdLdC = ShapeUtils::evalDimsToExclude(dLdCrank, axesAdLdC);
}
else {
axesAdLdC.push_back(0);
axesBdLdC.push_back(0);
}
// calculate dLdA
MmulHelper::tensorDot(dLdC, B, dLdA, axesBdLdC, axesB, permutAt);
// calculate dLdB
MmulHelper::tensorDot(A, dLdC, dLdB, axesA, axesAdLdC, permutBt);
return Status::OK();
}
////////////////////////////////////////////////////////////////////////
DECLARE_SHAPE_FN(tensormmul_bp) {
auto aShapeInfo = inputShape->at(0);
auto bShapeInfo = inputShape->at(1);
auto dLShapeInfo = inputShape->at(2);
REQUIRE_TRUE((ArrayOptions::dataType(aShapeInfo) == ArrayOptions::dataType(bShapeInfo) &&
(ArrayOptions::dataType(dLShapeInfo) == ArrayOptions::dataType(aShapeInfo))), 0, "tensormmul_bp: A, B and dLdC data types must be the same");
Nd4jLong* dLdAShapeInfo = nullptr;
Nd4jLong* dLdBShapeInfo = nullptr;
COPY_SHAPE(aShapeInfo, dLdAShapeInfo);
COPY_SHAPE(bShapeInfo, dLdBShapeInfo);
return SHAPELIST(CONSTANT(dLdAShapeInfo), CONSTANT(dLdBShapeInfo));
}
////////////////////////////////////////////////////////////////////////
DECLARE_TYPES(tensormmul_bp) {
getOpDescriptor()
->setAllowedInputTypes(0, { DataType::FLOAT32, DataType::DOUBLE, DataType::HALF }) // maybe better ALL_FLOATS
->setAllowedInputTypes(1, { DataType::FLOAT32, DataType::DOUBLE, DataType::HALF })
->setAllowedInputTypes(2, { DataType::FLOAT32, DataType::DOUBLE, DataType::HALF })
->setAllowedOutputTypes(0, { DataType::FLOAT32, DataType::DOUBLE, DataType::HALF })
->setAllowedOutputTypes(1, { DataType::FLOAT32, DataType::DOUBLE, DataType::HALF });
}
}
}
#endif
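A way to sanity-check the new tensormmul_bp (hedged reading, with illustrative shapes): take A = [2,3], B = [3,4] and axes {1}/{0}, so C = A * B is the plain matrix product. Then axesA = {0}, axesB = {1}, axesAdLdC = {0}, axesBdLdC = {1}, and the two tensorDot calls reduce to

    dLdA = tensorDot(dLdC, B, {1}, {1}) = dLdC * B^T    ( [2,4] * [4,3] = [2,3] )
    dLdB = tensorDot(A, dLdC, {0}, {0}) = A^T * dLdC    ( [3,2] * [2,4] = [3,4] )

i.e. the familiar matmul gradients, while the scalar branch handles the full-contraction case, where C is the sum of all pairwise products and the gradients collapse to dLdC * B and dLdC * A.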

View File

@ -79,7 +79,7 @@ CUSTOM_OP_IMPL(conv1d, 2, 1, false, 0, 5) {
}
auto inputReshaped = input ->reshape(input->ordering(), reshapeForInput);
auto outputReshaped = output ->reshape(output->ordering(), reshapeForOutput);
auto outputReshaped = output ->reshape(output->ordering(), reshapeForOutput, false);
auto weightsReshaped = weights->reshape(weights->ordering(), {1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}); // [kW, iC, oC] -> [1, kW, iC, oC]
nd4j::ops::conv2d conv2d;
@ -216,10 +216,10 @@ CUSTOM_OP_IMPL(conv1d_bp, 3, 2, false, 0, 5) {
}
auto inputReshaped = input ->reshape(input->ordering(), reshapeForInput);
auto gradIReshaped = gradI ->reshape(gradI->ordering(), reshapeForInput);
auto gradIReshaped = gradI ->reshape(gradI->ordering(), reshapeForInput, false);
auto gradOReshaped = gradO ->reshape(gradO->ordering(), reshapeForGradO);
auto weightsReshaped = weights->reshape(weights->ordering(),{1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}); // [kW, iC, oC] -> [1, kW, iC, oC]
auto gradWReshaped = gradW ->reshape(gradW->ordering(), {1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}); // [kW, iC, oC] -> [1, kW, iC, oC]
auto weightsReshaped = weights->reshape(weights->ordering(),{1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}); // [kW, iC, oC] -> [1, kW, iC, oC]
auto gradWReshaped = gradW ->reshape(gradW->ordering(), {1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}, false);// [kW, iC, oC] -> [1, kW, iC, oC]
nd4j::ops::conv2d_bp conv2dBP;
auto status = conv2dBP.execute({&inputReshaped, &weightsReshaped, bias, &gradOReshaped}, {&gradIReshaped, &gradWReshaped, gradB}, {}, {1,kW, 1,sW, 0,pW, 1,dW, paddingMode, !isNCW}, {});

View File

@ -239,7 +239,7 @@ CUSTOM_OP_IMPL(conv3dnew_bp, 3, 2, false, 0, 13) {
//----- calculation of gradO -----//
if(gradB) {
if(gradB->rankOf() == 2)
gradB = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()}));
gradB = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()}, false));
gradO->reduceAlongDimension(reduce::Sum, *gradB, gradOaxesForDot); // sum over bS oD oH oW
if(gradB != OUTPUT_VARIABLE(2))
delete gradB;

View File

@ -233,7 +233,7 @@ CUSTOM_OP_IMPL(deconv2d_bp, 3, 2, false, 0, 9) {
// ----- calculation of gradB ----- //
if(gradB) {
if(gradB->rankOf() == 2)
gradB = new NDArray(gradB->reshape(gradB->ordering(), {gradB->lengthOf()}));
gradB = new NDArray(gradB->reshape(gradB->ordering(), {gradB->lengthOf()}, false));
gradO->reduceAlongDimension(reduce::Sum, *gradB, {0, 2, 3}); // sum over bS, oH, oW
if(gradB != OUTPUT_VARIABLE(2))
delete gradB;

View File

@ -243,7 +243,7 @@ CUSTOM_OP_IMPL(deconv3d_bp, 3, 2, false, 0, 13) {
// ----- calculation of gradB ----- //
if(gradB) {
if(gradB->rankOf() == 2)
gradB = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()}));
gradB = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()}, false));
gradO->reduceAlongDimension(reduce::Sum, *gradB, {0, 2, 3, 4}); // sum over bS, oD, oH, oW
if(gradB != OUTPUT_VARIABLE(2))
delete gradB;

View File

@ -31,22 +31,17 @@ namespace nd4j {
REQUIRE_TRUE(w->isMatrix(), 0, "relu_layer: weights argument should be a 2D tensor, but got rank %i instead!", w->rankOf());
REQUIRE_TRUE(b->isVector(), 0, "relu_layer: biases argument should be a 1D tensor, but got rank %i instead!", b->rankOf());
REQUIRE_TRUE(b->lengthOf() == w->sizeAt(1), 0, "relu_layer: biases array length should match to columns of weights matrix, however got length = %i and columns = %i!", b->lengthOf(), w->sizeAt(1));
REQUIRE_TRUE(x->sizeAt(1) == w->sizeAt(0), 0, "relu_layer: number of x columns should match to row number of weights matrix, but got x_columns = %i and weights_rows = %i!",
x->sizeAt(1), w->sizeAt(0));
REQUIRE_TRUE(x->sizeAt(1) == w->sizeAt(0), 0, "relu_layer: number of x columns should match to row number of weights matrix, but got x_columns = %i and weights_rows = %i!", x->sizeAt(1), w->sizeAt(0));
auto output = OUTPUT_VARIABLE(0);
//T bound = (T)0.f;
//nd4j_printf("Matrix x(%ix%i), Matrix w(%ix%i), b(1x%i)\n", x->sizeAt(0), x->sizeAt(1), w->sizeAt(0), w->sizeAt(1), b->lengthOf());
nd4j::ops::xw_plus_b op;
std::unique_ptr<ResultSet> result(op.evaluate({x, w, b}));
REQUIRE_TRUE(Status::OK() == result->status(), 0, "relu_layer: xw_plus_b op failed on input data.");
auto status = op.execute({x, w, b}, {output});
REQUIRE_TRUE(Status::OK() == status, 0, "relu_layer: xw_plus_b op failed on input data.");
auto scalar = block.numT() > 0 ? block.getTArguments()->at(0) : 0.0;
auto xw = result->at(0);
xw->applyScalar(nd4j::scalar::RELU, scalar, *output);
output->applyScalar(nd4j::scalar::RELU, scalar, *output);
return Status::OK();
}
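
For context, a hedged usage sketch of relu_layer as rewritten above: it computes relu(x * w + b), with the optional T argument acting as the cutoff passed to the RELU scalar transform. The factory and evaluate calls are assumptions in the codebase's test style, not part of the patch.

// Sketch only: relu_layer on a [4, 3] input with a [3, 5] weight matrix and a [5] bias vector.
#include <memory>
#include <NDArrayFactory.h>
#include <ops/declarable/CustomOperations.h>

static void reluLayerSketch() {
    auto x = nd4j::NDArrayFactory::create<float>('c', {4, 3});
    auto w = nd4j::NDArrayFactory::create<float>('c', {3, 5});
    auto b = nd4j::NDArrayFactory::create<float>('c', {5});
    x.assign(1.f);
    w.assign(0.5f);
    b.assign(-2.f);

    nd4j::ops::relu_layer op;
    // the single T argument (here 0.0) is the cutoff used by the RELU scalar transform
    std::unique_ptr<nd4j::ResultSet> result(op.evaluate({&x, &w, &b}, {0.0}, {}));
    // each output element is relu(3 * 1 * 0.5 - 2) = relu(-0.5) = 0, output shape [4, 5]
}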

View File

@ -23,7 +23,8 @@
//#include <ops/declarable/headers/parity_ops.h>
#include <ops/declarable/CustomOperations.h>
#include <ops/declarable/helpers/image_resize.h>
#include <ops/declarable/helpers/crop_and_resize.h>
namespace nd4j {
namespace ops {
CUSTOM_OP_IMPL(crop_and_resize, 4, 1, false, 0, 0) {

View File

@ -61,13 +61,13 @@ namespace nd4j {
}
auto source = inRank == 4?image->reshape(image->ordering(), {image->sizeAt(0), image->sizeAt(1), image->sizeAt(2), image->sizeAt(3)}):image->reshape(image->ordering(), {1, image->sizeAt(0), image->sizeAt(1), image->sizeAt(2)});
auto target = inRank == 4?output->reshape(output->ordering(), {output->sizeAt(0), output->sizeAt(1), output->sizeAt(2), output->sizeAt(3)}):output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)});
auto target = inRank == 4?output->reshape(output->ordering(), {output->sizeAt(0), output->sizeAt(1), output->sizeAt(2), output->sizeAt(3)}, false) : output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)}, false);
return helpers::resizeAreaFunctor(block.launchContext(), &source, width, height, alignCorners, &target);
}
DECLARE_SHAPE_FN(resize_area) {
auto shapeList = SHAPELIST();
auto shapeList = SHAPELIST();
auto in = inputShape->at(0);
Nd4jLong* outputShape;
@ -90,7 +90,7 @@ namespace nd4j {
}
REQUIRE_TRUE(inRank == 4 || inRank == 3, 0, "resize_area: Source tensor should have rank 4, but %i given.", inRank);
ALLOCATE(outputShape, block.getWorkspace(), shape::shapeInfoLength(inRank), Nd4jLong);
outputShape[0] = inRank;
if (inRank == 4) {

View File

@ -62,13 +62,13 @@ namespace nd4j {
REQUIRE_TRUE(!halfPixelAlign || (halfPixelAlign && !alignCorners), 0, "resize_bicubic: `half_pixel_centers' can be true only when `align_corners' is false");
auto source = inRank == 4?image->reshape(image->ordering(), {image->sizeAt(0), image->sizeAt(1), image->sizeAt(2), image->sizeAt(3)}):image->reshape(image->ordering(), {1, image->sizeAt(0), image->sizeAt(1), image->sizeAt(2)});
auto target = inRank == 4?output->reshape(output->ordering(), {output->sizeAt(0), output->sizeAt(1), output->sizeAt(2), output->sizeAt(3)}):output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)});
auto target = inRank == 4?output->reshape(output->ordering(), {output->sizeAt(0), output->sizeAt(1), output->sizeAt(2), output->sizeAt(3)}, false) : output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)}, false);
return helpers::resizeBicubicFunctorA(block.launchContext(), &source, width, height, alignCorners, halfPixelAlign, &target);
}
DECLARE_SHAPE_FN(resize_bicubic) {
auto shapeList = SHAPELIST();
auto shapeList = SHAPELIST();
auto in = inputShape->at(0);
Nd4jLong* outputShape;
@ -82,7 +82,7 @@ namespace nd4j {
height = newImageSize->e<int>(1);
REQUIRE_TRUE(inRank == 4 || inRank == 3, 0, "resize_bicubic: Source tensor should have rank 4, but %i given.", inRank);
ALLOCATE(outputShape, block.getWorkspace(), shape::shapeInfoLength(inRank), Nd4jLong);
outputShape[0] = inRank;
if (inRank == 4) {

View File

@ -43,7 +43,7 @@ namespace nd4j {
REQUIRE_TRUE(inRank == output->rankOf(), 0, "resize_bilinear: Input and output ranks should be equal, but %i and %i occurred.", inRank, output->rankOf());
auto source = inRank == 4?image->reshape(image->ordering(), {image->sizeAt(0), image->sizeAt(1), image->sizeAt(2), image->sizeAt(3)}):image->reshape(image->ordering(), {1, image->sizeAt(0), image->sizeAt(1), image->sizeAt(2)});
auto target = inRank == 4?output->reshape(output->ordering(), {output->sizeAt(0), output->sizeAt(1), output->sizeAt(2), output->sizeAt(3)}):output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)});
auto target = inRank == 4?output->reshape(output->ordering(), {output->sizeAt(0), output->sizeAt(1), output->sizeAt(2), output->sizeAt(3)}, false) : output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)}, false);
if (block.width() > 1) {
auto newImageSize = INPUT_VARIABLE(1);
@ -71,7 +71,7 @@ namespace nd4j {
}
DECLARE_SHAPE_FN(resize_bilinear) {
auto shapeList = SHAPELIST();
auto shapeList = SHAPELIST();
auto in = inputShape->at(0);
Nd4jLong* outputShape;
@ -94,7 +94,7 @@ namespace nd4j {
width = INT_ARG(0);
height = INT_ARG(1);
}
ALLOCATE(outputShape, block.getWorkspace(), shape::shapeInfoLength(inRank), Nd4jLong);
outputShape[0] = inRank;
if (inRank == 4) {

View File

@ -63,13 +63,13 @@ namespace nd4j {
REQUIRE_TRUE(((alignCorners && height > 2) || (height > 0)) && ((alignCorners && width > 1) || (width > 0)), 0, "resize_nearest_neighbor: Wrong input or output size to resize (width = %d, height = %d)", width, height);
auto source = inRank == 4?*image:image->reshape(image->ordering(), {1, image->sizeAt(0), image->sizeAt(1), image->sizeAt(2)});
auto target = inRank == 4?*output:output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)});
auto target = inRank == 4 ? *output : output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)}, false);
return helpers::resizeNeighborFunctor(block.launchContext(), inRank==4?image:&source, width, height, alignCorners, halfPixelCenter, inRank == 4 ? output : &target);
}
DECLARE_SHAPE_FN(resize_nearest_neighbor) {
auto shapeList = SHAPELIST();
auto shapeList = SHAPELIST();
auto in = inputShape->at(0);
auto inRank = shape::rank(in);
Nd4jLong* outputShape;

View File

@ -47,11 +47,12 @@ namespace nd4j {
shape.insert(shape.begin() + axis, 1);
auto tmp = input->reshape(input->ordering(), shape);
output->assign(tmp);
STORE_RESULT(output);
if (input->ews() == 1 && output->ews() == 1 && input->ordering() == output->ordering()) {
output->dataBuffer()->copyBufferFrom(*input->dataBuffer().get(), output->lengthOf() * DataTypeUtils::sizeOfElement(output->dataType()), 0, input->bufferOffset());
} else {
auto tmp = input->reshape(input->ordering(), shape);
output->assign(tmp);
}
return Status::OK();
}
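
The hunk above replaces an unconditional reshape-and-assign with a zero-copy fast path. The sketch below isolates that pattern into a hypothetical helper (the name is invented); it relies only on the NDArray/DataBuffer calls that appear in the hunk itself.

// Hypothetical helper mirroring the fast path above: copy the raw buffer when both arrays
// are densely laid out (ews == 1) and share the same ordering, otherwise reshape and assign.
#include <vector>
#include <NDArray.h>
#include <DataTypeUtils.h>

static void copyOrAssign(nd4j::NDArray* input, nd4j::NDArray* output, const std::vector<Nd4jLong>& shape) {
    if (input->ews() == 1 && output->ews() == 1 && input->ordering() == output->ordering()) {
        output->dataBuffer()->copyBufferFrom(*input->dataBuffer().get(),
                output->lengthOf() * nd4j::DataTypeUtils::sizeOfElement(output->dataType()),
                0, input->bufferOffset());
    } else {
        auto tmp = input->reshape(input->ordering(), shape);
        output->assign(tmp);
    }
}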

View File

@ -15,7 +15,8 @@
******************************************************************************/
//
// Created by raver119 on 29/10/17.
// @author raver119@gmail.com
// @author Yurii Shyrma (iuriish@yahoo.com)
//
#include <op_boilerplate.h>
@ -29,80 +30,52 @@ namespace nd4j {
//////////////////////////////////////////////////////////////////////////
// here iArgs is int vector of ordered set of dimensions to be permuted
CUSTOM_OP_IMPL(permute, 1, 1, true, 0, -2) {
auto x = INPUT_VARIABLE(0);
CUSTOM_OP_IMPL(permute, 1, 1, true, 0, -2) {
bool replace = false;
auto x = INPUT_VARIABLE(0);
auto z = OUTPUT_VARIABLE(0);
auto origArgs = block.width() > 1 ? INPUT_VARIABLE(1)->asVectorT<int>() : *block.getIArguments();
std::vector<int> arguments({});
if(origArgs.size() > 0){
for (int e = 0; e < origArgs.size(); e++) {
int ax = origArgs[e];
if (ax < 0)
ax += x->rankOf();
arguments.emplace_back(ax);
}
replace = true;
} else {
for (int e = x->rankOf() - 1; e >= 0; e--)
arguments.emplace_back(e);
}
// 0D edge case
if (x->rankOf() == 0) {
REQUIRE_TRUE(arguments.size() == 1, 0, "Permute: only one axis is allowed for scalar");
auto output = OUTPUT_VARIABLE(0);
if (!block.isInplace())
output->assign(x);
return Status::OK();
}
if(block.isInplace()) { // in-place
x->permutei(arguments);
STORE_RESULT(x);
} else {
auto output = OUTPUT_VARIABLE(0);
auto result = x->permute(arguments);
output->assign(result);
STORE_RESULT(output);
}
return Status::OK();
}
DECLARE_TYPES(permute) {
getOpDescriptor()
->setAllowedInputTypes(0, nd4j::DataType::ANY)
->setAllowedInputTypes(1, {ALL_INTS})
->setSameMode(true);
}
DECLARE_SHAPE_FN(permute) {
auto shapeList = SHAPELIST();
auto arguments = block.width() > 1 ? INPUT_VARIABLE(1)->asVectorT<int>() : *block.getIArguments();
if (shape::rank(inputShape->at(0)) == 0) {
shapeList->push_back(ConstantShapeHelper::getInstance()->scalarShapeInfo(ArrayOptions::dataType(inputShape->at(0))));
} else if (inputShape->size() == 1 && !arguments.empty()) {
shapeList->push_back(ShapeUtils::evalPermShapeInfo(arguments.data(), arguments.size(), *INPUT_VARIABLE(0), block.workspace()));
} else {
if(arguments.size() == 0){
//Reverse dimensions
int rank = shape::rank(inputShape->at(0));
for (int e = rank - 1; e >= 0; e--)
arguments.emplace_back(e);
}
shapeList->push_back(ShapeUtils::evalPermShapeInfo(arguments.data(), arguments.size(), *INPUT_VARIABLE(0), block.workspace()));
}
return shapeList;
}
if (x->isEmpty()) {
REQUIRE_TRUE(z->isEmpty(), 0, "PERMUTE OP: when input is empty, output must also be empty");
return Status::OK(); //No op
}
if (block.width() == 1 && block.getIArguments()->size() == 0) {
z->assign(x->transpose());
return Status::OK();
}
std::vector<int> permutationVector = block.width() > 1 ? INPUT_VARIABLE(1)->asVectorT<int>() : *block.getIArguments();
z->assign(x->permute(permutationVector));
return Status::OK();
}
//////////////////////////////////////////////////////////////////////////
DECLARE_TYPES(permute) {
getOpDescriptor()
->setAllowedInputTypes(0, nd4j::DataType::ANY)
->setAllowedInputTypes(1, {ALL_INTS})
->setSameMode(true);
}
//////////////////////////////////////////////////////////////////////////
DECLARE_SHAPE_FN(permute) {
auto x = INPUT_VARIABLE(0);
if (block.width() == 1 && block.getIArguments()->size() == 0)
return SHAPELIST(ShapeUtils::evalTranspShapeInfo(*x, block.workspace(), true));
std::vector<int> permutationVector = block.width() > 1 ? INPUT_VARIABLE(1)->asVectorT<int>() : *block.getIArguments();
auto outputShapeInfo = ShapeUtils::evalPermShapeInfo(permutationVector.data(), x->rankOf(), *x, block.workspace(), true);
return SHAPELIST(outputShapeInfo);
}
}
}
#endif
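
For reference, a usage sketch of the rewritten permute op; the factory and evaluate calls mirror the test-style API used elsewhere in the codebase and are assumptions, not part of this change.

// Sketch only: permute [2, 3, 4] with axes {1, 2, 0}; with no axes the op falls back to a full transpose.
#include <memory>
#include <NDArrayFactory.h>
#include <ops/declarable/CustomOperations.h>

static void permuteSketch() {
    auto x = nd4j::NDArrayFactory::create<float>('c', {2, 3, 4});
    nd4j::ops::permute op;
    std::unique_ptr<nd4j::ResultSet> result(op.evaluate({&x}, {}, {1, 2, 0}));
    // result->at(0) has shape [3, 4, 2]; the permutation can also be passed as a second NDArray input
}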

View File

@ -24,254 +24,240 @@
#include <ops/declarable/CustomOperations.h>
namespace nd4j {
namespace ops {
//////////////////////////////////////////////////////////////////////////
// here iArgs is a vector with (optional) negative of order as first element:
// ({-order, dim1, dim2, dim3, ...})
CUSTOM_OP_IMPL(reshape, 1, 1, true, 0, -2) {
auto x = INPUT_VARIABLE(0);
namespace ops {
if (block.width() == 1) {
auto arguments = block.getIArguments();
int argsSize = arguments->size();
//Special case: empty.reshape(<other empty shape>) -> return empty
if (x->isEmpty()) {
REQUIRE_TRUE(OUTPUT_VARIABLE(0)->isEmpty(), 0, "Reshape: when input is empty, output must also be empty");
return ND4J_STATUS_OK; //No op
//////////////////////////////////////////////////////////////////////////
// here iArgs is a vector with (optional) negative of order as first element:
// ({-order, dim1, dim2, dim3, ...})
CUSTOM_OP_IMPL(reshape, 1, 1, false, 0, -2) {
auto x = INPUT_VARIABLE(0);
auto z = OUTPUT_VARIABLE(0);
//Special case: empty.reshape(<other empty shape>) -> return empty
if (x->isEmpty()) {
REQUIRE_TRUE(z->isEmpty(), 0, "Reshape: when input is empty, output must also be empty");
return Status::OK(); //No op
}
if (block.width() == 1) {
auto arguments = block.getIArguments();
int argsSize = arguments->size();
int e = 1;
char order = (char) -(*arguments)[0];
if (order != 'c' && order != 'f') {
order = 'c'; //x->ordering();
e = 0;
}
REQUIRE_TRUE(argsSize - e >= 1, 0, "Reshape arguments should have at least 1 dimension");
std::vector<Nd4jLong> shapeNew;
int e2 = e;
for (; e < (int) arguments->size(); e++) {
if (arguments->at(e) == -1){
Nd4jLong shapeLength = 1;
for(; e2 < e; e2++){
shapeLength *= arguments->at(e2);
}
int e = 1;
char order = (char) -(*arguments)[0];
if (order != 'c' && order != 'f') {
order = 'c'; //x->ordering();
e = 0;
}
REQUIRE_TRUE(argsSize - e >= 1, 0, "Reshape arguments should have at least 1 dimension");
std::vector<Nd4jLong> shapeNew;
int e2 = e;
for (; e < (int) arguments->size(); e++) {
if (arguments->at(e) == -1){
Nd4jLong shapeLength = 1;
for(; e2 < e; e2++){
shapeLength *= arguments->at(e2);
}
for(e2 = e + 1; e2 < arguments->size(); e2++){
shapeLength *= arguments->at(e2);
}
Nd4jLong realShape = x->lengthOf() / shapeLength;
shapeNew.push_back(realShape);
}
else{
shapeNew.push_back(arguments->at(e));
}
}
auto len = shape::prodLong(shapeNew.data(), shapeNew.size());
REQUIRE_TRUE(len == x->lengthOf(), 0, "Reshape: lengths before and after reshape should match, but got %i vs %i", x->lengthOf(), len);
if (Environment::getInstance()->isDebugAndVerbose()) {
nd4j_printv("Reshape: new shape", shapeNew);
}
if (block.isInplace()) {
if (x->reshapei(order, shapeNew)) {
STORE_RESULT(*x);
return ND4J_STATUS_OK;
}
} else {
auto ret = OUTPUT_VARIABLE(0);
auto xr = x->reshape(order, shapeNew);
ret->assign(xr);
STORE_RESULT(*ret);
return Status::OK();
}
} else if (block.width() == 2) {
auto s = INPUT_VARIABLE(1);
//Special case: empty.reshape(-1) -> return empty
if (x->isEmpty()) {
//REQUIRE_TRUE(s->lengthOf() == 1 && s->e<Nd4jLong>(0) == -1, 0, "Reshape: when input is empty, shape must be [-1]");
REQUIRE_TRUE(OUTPUT_VARIABLE(0)->isEmpty(), 0, "Reshape: when input is empty, output must also be empty");
return Status::OK(); //No op
}
char order = 'c';
if (block.numI() > 0)
order = (char) -INT_ARG(0);
std::vector<Nd4jLong> shapeNew(s->lengthOf());
for (int e = 0; e < (int) s->lengthOf(); e++) {
auto dim = s->e<Nd4jLong >(e);
if (dim == -1){
Nd4jLong shapeLength = 1;
for(int e2 = 0; e2 < e; e2++){
shapeLength *= s->e<Nd4jLong>(e2);
}
for(int e2 = e + 1; e2 < (int) s->lengthOf(); e2++){
REQUIRE_TRUE(s->e<Nd4jLong>(e2) != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed.");
shapeLength *= s->e<Nd4jLong>(e2);
}
Nd4jLong realShape = x->lengthOf() / shapeLength;
shapeNew[e] = realShape;
}
else{
shapeNew[e] = dim;
}
}
if (Environment::getInstance()->isDebugAndVerbose()) {
nd4j_printv("Reshape: new shape", shapeNew);
}
if (block.isInplace()) {
if (x->reshapei(order, shapeNew)) {
STORE_RESULT(*x);
return Status::OK();
}
} else {
auto ret = OUTPUT_VARIABLE(0);
if (s->isEmpty()) {
// just a scalar
ret->assign(x);
} else {
auto xr = x->reshape(order, shapeNew);
ret->assign(xr);
}
return Status::OK();
for(e2 = e + 1; e2 < arguments->size(); e2++){
shapeLength *= arguments->at(e2);
}
Nd4jLong realShape = x->lengthOf() / shapeLength;
shapeNew.push_back(realShape);
}
else{
shapeNew.push_back(arguments->at(e));
}
return ND4J_STATUS_BAD_INPUT;
}
auto len = shape::prodLong(shapeNew.data(), shapeNew.size());
REQUIRE_TRUE(len == x->lengthOf(), 0, "Reshape: lengths before and after reshape should match, but got %i vs %i", x->lengthOf(), len);
DECLARE_TYPES(reshape) {
getOpDescriptor()
->setAllowedInputTypes(0, nd4j::DataType::ANY)
->setAllowedInputTypes(1, {ALL_INTS})
->setSameMode(true);
if (Environment::getInstance()->isDebugAndVerbose()) {
nd4j_printv("Reshape: new shape", shapeNew);
}
DECLARE_SHAPE_FN(reshape) {
auto inp = inputShape->at(0);
auto xr = x->reshape(order, shapeNew);
z->assign(xr);
STORE_RESULT(*z);
// we can launch op using Int arguments
if (inputShape->size() == 1) {
REQUIRE_TRUE(block.numI() > 0, 0, "Reshape: new shape should be provided as NDArray or int arguments, but nothing was defined");
std::vector<int> *arguments = block.getIArguments();
return Status::OK();
int e = 1;
char order = (char) -(*arguments)[0];
if (order != 'c' && order != 'f') {
order = shape::order(inp);
e = 0;
} else if (block.width() == 2) {
auto s = INPUT_VARIABLE(1);
char order = 'c';
if (block.numI() > 0)
order = (char) -INT_ARG(0);
std::vector<Nd4jLong> shapeNew(s->lengthOf());
for (int e = 0; e < (int) s->lengthOf(); e++) {
auto dim = s->e<Nd4jLong >(e);
if (dim == -1){
Nd4jLong shapeLength = 1;
for(int e2 = 0; e2 < e; e2++){
shapeLength *= s->e<Nd4jLong>(e2);
}
std::vector<Nd4jLong> shapeNew;
int e2 = e;
for (; e < (int) arguments->size(); e++) {
if ((int) arguments->at(e) == -1){
Nd4jLong shapeLength = 1;
for(; e2 < e; e2 ++){
shapeLength *= arguments->at(e2);
}
for(e2 = e + 1; e2 < arguments->size(); e2++){
REQUIRE_TRUE(arguments->at(e2) != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed.");
shapeLength *= arguments->at(e2);
}
if(shapeLength == 0){
//Edge case for empty:
shapeNew.push_back(0);
} else {
//Standard case
Nd4jLong realShape = shape::length(inp) / shapeLength;
shapeNew.push_back(realShape);
}
}
else{
shapeNew.push_back(arguments->at(e));
}
for(int e2 = e + 1; e2 < (int) s->lengthOf(); e2++){
REQUIRE_TRUE(s->e<Nd4jLong>(e2) != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed.");
shapeLength *= s->e<Nd4jLong>(e2);
}
return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(ArrayOptions::dataType(inp), order, shapeNew)));
} else {
// or, with second input "as shape"
auto x = INPUT_VARIABLE(0);
auto y = INPUT_VARIABLE(1);
// special case here
if (y->isEmpty()) {
REQUIRE_TRUE(x->lengthOf() == 1, 0, "Reshape: new length doesn't match existing array");
return SHAPELIST(ConstantShapeHelper::getInstance()->scalarShapeInfo(ArrayOptions::dataType(inp)));
}
//Special case: empty.reshape(-1) -> return empty
if (x->isEmpty()) {
//REQUIRE_TRUE(y->lengthOf() == 1 && y->e<Nd4jLong>(0) == -1, 0, "Reshape: when input is empty, shape must be [-1]");
auto shapeOf = y->getBufferAsVector<Nd4jLong>();
Nd4jLong prod = 1;
bool hasNegs = false;
for (auto v:shapeOf) {
if (v < 0) {
hasNegs = true;
v = 0;
}
prod *= v;
}
REQUIRE_TRUE(prod == 0, 0, "Reshape: in case of empty arrays reshape must return empty array as well");
// if there are -1s - we turn them into zeros
if (hasNegs) {
for (int e = 0; e < shapeOf.size(); e++)
if (shapeOf[e] < 0)
shapeOf[e] = 0;
}
auto newShape = ShapeBuilders::createShapeInfo(ArrayOptions::dataType(inp), shape::order(inp), y->lengthOf(), shapeOf.data());
return SHAPELIST(CONSTANT(newShape));
}
std::vector<Nd4jLong> shapeNew(y->lengthOf());
for (int e = 0; e < (int) y->lengthOf(); e++) {
auto dim = y->e<Nd4jLong>(e);
if (dim == -1){
Nd4jLong shapeLength = 1;
for(int e2 = 0; e2 < e; e2++){
shapeLength *= y->e<Nd4jLong>(e2);
}
for(int e2 = e + 1; e2 < (int)y->lengthOf(); e2++){
REQUIRE_TRUE(y->e<Nd4jLong>(e2) != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed.");
shapeLength *= y->e<Nd4jLong>(e2);
}
if(shapeLength == 0){
//Edge case for empty:
shapeNew[e] = 0;
} else {
Nd4jLong realShape = shape::length(inp) / shapeLength;
shapeNew[e] = realShape;
}
}else {
shapeNew[e] = dim;
}
}
return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inp), 'c', shapeNew));
Nd4jLong realShape = x->lengthOf() / shapeLength;
shapeNew[e] = realShape;
}
else{
shapeNew[e] = dim;
}
}
if (Environment::getInstance()->isDebugAndVerbose()) {
nd4j_printv("Reshape: new shape", shapeNew);
}
if (s->isEmpty()) {
// just a scalar
z->assign(x);
} else {
auto xr = x->reshape(order, shapeNew);
z->assign(xr);
}
return Status::OK();
}
return ND4J_STATUS_BAD_INPUT;
}
DECLARE_TYPES(reshape) {
getOpDescriptor()
->setAllowedInputTypes(0, nd4j::DataType::ANY)
->setAllowedInputTypes(1, {ALL_INTS})
->setSameMode(true);
}
DECLARE_SHAPE_FN(reshape) {
auto inp = inputShape->at(0);
// we can launch op using Int arguments
if (inputShape->size() == 1) {
REQUIRE_TRUE(block.numI() > 0, 0, "Reshape: new shape should be provided as NDArray or int arguments, but nothing was defined");
std::vector<int> *arguments = block.getIArguments();
int e = 1;
char order = (char) -(*arguments)[0];
if (order != 'c' && order != 'f') {
order = shape::order(inp);
e = 0;
}
std::vector<Nd4jLong> shapeNew;
int e2 = e;
for (; e < (int) arguments->size(); e++) {
if ((int) arguments->at(e) == -1){
Nd4jLong shapeLength = 1;
for(; e2 < e; e2 ++){
shapeLength *= arguments->at(e2);
}
for(e2 = e + 1; e2 < arguments->size(); e2++){
REQUIRE_TRUE(arguments->at(e2) != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed.");
shapeLength *= arguments->at(e2);
}
if(shapeLength == 0){
//Edge case for empty:
shapeNew.push_back(0);
} else {
//Standard case
Nd4jLong realShape = shape::length(inp) / shapeLength;
shapeNew.push_back(realShape);
}
}
else{
shapeNew.push_back(arguments->at(e));
}
}
return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(ArrayOptions::dataType(inp), order, shapeNew)));
} else {
// or, with second input "as shape"
auto x = INPUT_VARIABLE(0);
auto y = INPUT_VARIABLE(1);
// special case here
if (y->isEmpty()) {
REQUIRE_TRUE(x->lengthOf() == 1, 0, "Reshape: new length doesn't match existing array");
return SHAPELIST(ConstantShapeHelper::getInstance()->scalarShapeInfo(ArrayOptions::dataType(inp)));
}
//Special case: empty.reshape(-1) -> return empty
if (x->isEmpty()) {
//REQUIRE_TRUE(y->lengthOf() == 1 && y->e<Nd4jLong>(0) == -1, 0, "Reshape: when input is empty, shape must be [-1]");
auto shapeOf = y->getBufferAsVector<Nd4jLong>();
Nd4jLong prod = 1;
bool hasNegs = false;
for (auto v:shapeOf) {
if (v < 0) {
hasNegs = true;
v = 0;
}
prod *= v;
}
REQUIRE_TRUE(prod == 0, 0, "Reshape: in case of empty arrays reshape must return empty array as well");
// if there are -1s - we turn them into zeros
if (hasNegs) {
for (int e = 0; e < shapeOf.size(); e++)
if (shapeOf[e] < 0)
shapeOf[e] = 0;
}
auto newShape = ShapeBuilders::createShapeInfo(ArrayOptions::dataType(inp), shape::order(inp), y->lengthOf(), shapeOf.data());
return SHAPELIST(CONSTANT(newShape));
}
std::vector<Nd4jLong> shapeNew(y->lengthOf());
for (int e = 0; e < (int) y->lengthOf(); e++) {
auto dim = y->e<Nd4jLong>(e);
if (dim == -1){
Nd4jLong shapeLength = 1;
for(int e2 = 0; e2 < e; e2++){
shapeLength *= y->e<Nd4jLong>(e2);
}
for(int e2 = e + 1; e2 < (int)y->lengthOf(); e2++){
REQUIRE_TRUE(y->e<Nd4jLong>(e2) != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed.");
shapeLength *= y->e<Nd4jLong>(e2);
}
if(shapeLength == 0){
//Edge case for empty:
shapeNew[e] = 0;
} else {
Nd4jLong realShape = shape::length(inp) / shapeLength;
shapeNew[e] = realShape;
}
}else {
shapeNew[e] = dim;
}
}
return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inp), 'c', shapeNew));
}
}
}
}
#endif
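
A usage sketch for the reshape op above, showing the inferred (-1) dimension and the order encoding handled by the integer-argument branch; again, the factory/evaluate calls are assumptions following the codebase's test style.

// Sketch only: reshape [2, 3, 4] -> [2, 12]; -1 makes the op infer that dimension.
// The first iArg may be -'c' (-99) or -'f' (-102) to select the order; otherwise 'c' is used
// and all iArgs are treated as dimensions.
#include <memory>
#include <NDArrayFactory.h>
#include <ops/declarable/CustomOperations.h>

static void reshapeSketch() {
    auto x = nd4j::NDArrayFactory::create<float>('c', {2, 3, 4});
    nd4j::ops::reshape op;
    std::unique_ptr<nd4j::ResultSet> result(op.evaluate({&x}, {}, {2, -1}));
    // result->at(0) has shape [2, 12]
}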

View File

@ -28,35 +28,27 @@ namespace nd4j {
//////////////////////////////////////////////////////////////////////////
CUSTOM_OP_IMPL(reshapeas, 2, 1, true, 0, 0) {
CUSTOM_OP_IMPL(reshapeas, 2, 1, false, 0, 0) {
auto x = INPUT_VARIABLE(0);
auto y = INPUT_VARIABLE(1);
auto z = OUTPUT_VARIABLE(0);
std::vector<Nd4jLong> shapeNew(y->shapeOf(), y->shapeOf() + y->rankOf());
char order = y->ordering();
if (x->reshapei(order, shapeNew)) {
*z = *x;
STORE_RESULT(*z);
if (x->reshapei(y->ordering(), y->getShapeAsVector())) {
z->assign(x);
return Status::OK();
}
return ND4J_STATUS_BAD_INPUT;
}
DECLARE_SYN(reshape_as, reshapeas);
DECLARE_SHAPE_FN(reshapeas) {
auto inputShapeInfo = inputShape->at(1);
int shapeInfoLength = inputShapeInfo[0]*2 + 4;
Nd4jLong* outputShapeInfo(nullptr);
COPY_SHAPE(inputShapeInfo, outputShapeInfo);
return SHAPELIST(CONSTANT(outputShapeInfo));
}
DECLARE_SHAPE_FN(reshapeas) {
return SHAPELIST(ShapeBuilders::copyShapeInfo(INPUT_VARIABLE(1)->getShapeInfo(), false, block.workspace()));
}
DECLARE_TYPES(reshapeas) {
getOpDescriptor()

View File

@ -25,7 +25,7 @@
namespace nd4j {
namespace ops {
CUSTOM_OP_IMPL(squeeze, 1, 1, true, 0, -2) {
CUSTOM_OP_IMPL(squeeze, 1, 1, false, 0, -2) {
auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0);
@ -36,14 +36,14 @@ namespace nd4j {
int _a = INT_ARG(e);
if (_a < 0)
_a += input->rankOf();
axis.emplace_back(_a);
}
else if (block.width() > 1) {
auto a = INPUT_VARIABLE(1);
for (Nd4jLong e = 0; e < a->lengthOf(); e++) {
int _a = a->e<int>(e);
if (_a < 0)
_a += input->rankOf();
@ -71,10 +71,14 @@ namespace nd4j {
}
if (block.isInplace()) {
output->reshapei(input->ordering(), shape);
output->reshapei(input->ordering(), shape, false);
} else {
auto tmp = input->reshape(input->ordering(), shape);
output->assign(tmp);
if (input->ews() == 1 && output->ews() == 1 && input->ordering() == output->ordering()) {
output->dataBuffer()->copyBufferFrom(*input->dataBuffer().get(), output->lengthOf() * DataTypeUtils::sizeOfElement(output->dataType()), 0, input->bufferOffset());
} else {
auto tmp = input->reshape(input->ordering(), shape);
output->assign(tmp);
}
}
return Status::OK();
@ -106,20 +110,20 @@ namespace nd4j {
int _a = INT_ARG(e);
if (_a < 0)
_a += rank;
axis.emplace_back(_a);
}
else if (block.width() > 1) {
auto a = INPUT_VARIABLE(1);
for (int e = 0; e < a->lengthOf(); e++) {
int _a = a->e<int>(e);
if (_a < 0)
_a += rank;
axis.emplace_back(_a);
}
}
auto order = shape::order(in);

View File

@ -25,7 +25,7 @@
namespace nd4j {
namespace ops {
CUSTOM_OP_IMPL(tile_to_shape, 1, 1, true, 0, -1) {
CUSTOM_OP_IMPL(tile_to_shape, 1, 1, false, 0, -1) {
auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0);

View File

@ -15,7 +15,8 @@
******************************************************************************/
//
// Created by raver119 on 29/10/17.
// @author raver119@gmail.com
// @author Yurii Shyrma (iuriish@yahoo.com)
//
#include <op_boilerplate.h>
@ -25,113 +26,52 @@
#include <helpers/ShapeUtils.h>
namespace nd4j {
namespace ops {
namespace ops {
//////////////////////////////////////////////////////////////////////////
CUSTOM_OP_IMPL(transpose, 1, 1, true, 0, 0) {
auto x = INPUT_VARIABLE(0);
if (block.width() == 1) {
if (block.isInplace()) {
x->transposei();
STORE_RESULT(*x);
} else {
auto output = OUTPUT_VARIABLE(0);
auto t = x->transpose();
output->assign(t);
STORE_RESULT(*output);
}
} else {
// this is tf-mode transpose, that's nd4j permute
bool replace = false;
std::vector<int> arguments(*block.getIArguments());
//////////////////////////////////////////////////////////////////////////
CUSTOM_OP_IMPL(transpose, 1, 1, false, 0, 0) {
auto w = block.width();
auto a = arguments.size();
auto x = INPUT_VARIABLE(0);
auto z = OUTPUT_VARIABLE(0);
if (w == 2 && a == 0) {
auto axis = INPUT_VARIABLE(1);
for (int e = 0; e < axis->lengthOf(); e++) {
auto ax = axis->e<int>(e);
if (ax < 0)
ax += x->rankOf();
//Special case: empty input -> return empty output
if (x->isEmpty()) {
REQUIRE_TRUE(z->isEmpty(), 0, "TRANSPOSE OP: when input is empty, output must also be empty");
return Status::OK(); //No op
}
arguments.emplace_back(ax);
}
replace = true;
} else if (a == 0) {
for (int e = x->rankOf() - 1; e >= 0; e--)
arguments.emplace_back(e);
}
// 0D edge case
if (x->rankOf() == 0) {
REQUIRE_TRUE(arguments.size() == 1, 0, "Permute: only one axis is allowed for scalar");
auto output = OUTPUT_VARIABLE(0);
if (!block.isInplace())
output->assign(x);
return Status::OK();
}
if(block.isInplace()) { // in-place
x->permutei(arguments);
STORE_RESULT(x);
} else {
auto input = x->permute(arguments);
auto output = OUTPUT_VARIABLE(0);
output->assign(input);
}
}
if (block.width() == 1 && block.getIArguments()->size() == 0) {
z->assign(x->transpose());
return Status::OK();
}
DECLARE_TYPES(transpose) {
getOpDescriptor()
->setAllowedInputTypes(nd4j::DataType::ANY)
->setSameMode(true);
}
std::vector<int> permutationVector = block.width() > 1 ? INPUT_VARIABLE(1)->asVectorT<int>() : *block.getIArguments();
DECLARE_SHAPE_FN(transpose) {
if (block.width() == 1) {
auto outputShapeInfo = ShapeUtils::evalTranspShapeInfo(*INPUT_VARIABLE(0), block.workspace());
return SHAPELIST(outputShapeInfo);
} else {
// this is basically permute mode
auto shapeList = SHAPELIST();
auto arguments = block.getIArguments();
if (shape::rank(inputShape->at(0)) == 0) {
Nd4jLong *newshape;
ALLOCATE(newshape, block.getWorkspace(), shape::shapeInfoLength(inputShape->at(0)), Nd4jLong);
newshape[0] = 0;
newshape[1] = 0;
newshape[2] = 1;
newshape[3] = 99;
ArrayOptions::copyDataType(newshape, inputShape->at(0));
shapeList->push_back(newshape);
} else if (arguments->size() > 0 || inputShape->size() > 1) {
auto axis = arguments->size() > 0 ? *arguments : (INPUT_VARIABLE(1))->template asVectorT<int>();
auto outputShapeInfo = ShapeUtils::evalPermShapeInfo(axis.data(), axis.size(), *INPUT_VARIABLE(0), block.workspace());
shapeList->push_back(outputShapeInfo);
} else if (inputShape->size() == 2) {
// dead end
auto axis = INPUT_VARIABLE(1);
auto axisV = axis->template asVectorT<Nd4jLong>();
auto newshape = ShapeUtils::evalPermShapeInfo(axisV.data(), axisV.size(), *INPUT_VARIABLE(0), block.workspace());
shapeList->push_back(newshape);
} else {
int rank = shape::rank(inputShape->at(0));
for (int e = rank - 1; e >= 0; e--)
arguments->emplace_back(e);
z->assign(x->permute(permutationVector));
auto outputShapeInfo = ShapeUtils::evalPermShapeInfo(arguments->data(), arguments->size(), *INPUT_VARIABLE(0), block.workspace());
shapeList->push_back(outputShapeInfo);
}
return Status::OK();
}
DECLARE_TYPES(transpose) {
getOpDescriptor()
->setAllowedInputTypes(nd4j::DataType::ANY)
->setSameMode(true);
}
DECLARE_SHAPE_FN(transpose) {
auto x = INPUT_VARIABLE(0);
if (block.width() == 1 && block.getIArguments()->size() == 0)
return SHAPELIST(ShapeUtils::evalTranspShapeInfo(*x, block.workspace(), true));
std::vector<int> permutationVector = block.width() > 1 ? INPUT_VARIABLE(1)->asVectorT<int>() : *block.getIArguments();
auto outputShapeInfo = ShapeUtils::evalPermShapeInfo(permutationVector.data(), x->rankOf(), *x, block.workspace(), true);
return SHAPELIST(outputShapeInfo);
}
return shapeList;
}
}
}
}
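
Finally, a usage sketch of the rewritten transpose op: with no axes it reverses all dimensions, and with a second input it behaves like permute. The evaluate call is the same assumed test-style helper as in the earlier sketches.

// Sketch only: transpose reverses all axes when no permutation is supplied.
#include <memory>
#include <NDArrayFactory.h>
#include <ops/declarable/CustomOperations.h>

static void transposeSketch() {
    auto x = nd4j::NDArrayFactory::create<float>('c', {2, 3, 4});
    nd4j::ops::transpose op;
    std::unique_ptr<nd4j::ResultSet> result(op.evaluate({&x}));
    // result->at(0) has shape [4, 3, 2]
}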

Some files were not shown because too many files have changed in this diff.