Merge pull request #8723 from KonduitAI/master

Merge recent development updates
Branch: master
Alex Black 2020-02-21 20:00:35 +11:00 committed by GitHub
commit e4ddf109c3
239 changed files with 7636 additions and 4122 deletions

View File

@@ -1,47 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.arbiter.util;
import org.slf4j.Logger;
import java.awt.*;
import java.net.URI;
/**
* Various utilities for webpages and dealing with browsers
*/
public class WebUtils {
public static void tryOpenBrowser(String path, Logger log) {
try {
WebUtils.openBrowser(new URI(path));
} catch (Exception e) {
log.error("Could not open browser", e);
System.out.println("Browser could not be launched automatically.\nUI path: " + path);
}
}
public static void openBrowser(URI uri) throws Exception {
if (Desktop.isDesktopSupported()) {
Desktop.getDesktop().browse(uri);
} else {
throw new UnsupportedOperationException(
"Cannot open browser on this platform: Desktop.isDesktopSupported() == false");
}
}
}

View File

@@ -127,7 +127,7 @@ public class BraninFunction {
 BraninConfig candidate = (BraninConfig) c.getValue();
 double score = scoreFunction.score(candidate, null, (Map) null);
-System.out.println(candidate.getX1() + "\t" + candidate.getX2() + "\t" + score);
+// System.out.println(candidate.getX1() + "\t" + candidate.getX2() + "\t" + score);
 Thread.sleep(20);

View File

@@ -54,7 +54,7 @@ public class TestRandomSearch extends BaseDL4JTest {
 runner.execute();
-System.out.println("----- Complete -----");
+// System.out.println("----- Complete -----");
 }

View File

@@ -16,8 +16,8 @@
 package org.deeplearning4j.arbiter.optimize.genetic;
+import org.apache.commons.lang3.NotImplementedException;
 import org.apache.commons.math3.random.RandomGenerator;
-import sun.reflect.generics.reflectiveObjects.NotImplementedException;
 public class TestRandomGenerator implements RandomGenerator {
 private final int[] intRandomNumbers;
@@ -63,17 +63,17 @@ public class TestRandomGenerator implements RandomGenerator {
 @Override
 public long nextLong() {
-throw new NotImplementedException();
+throw new NotImplementedException("Not implemented");
 }
 @Override
 public boolean nextBoolean() {
-throw new NotImplementedException();
+throw new NotImplementedException("Not implemented");
 }
 @Override
 public float nextFloat() {
-throw new NotImplementedException();
+throw new NotImplementedException("Not implemented");
 }
 @Override
@@ -83,6 +83,6 @@ public class TestRandomGenerator implements RandomGenerator {
 @Override
 public double nextGaussian() {
-throw new NotImplementedException();
+throw new NotImplementedException("Not implemented");
 }
 }
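
This hunk (and the similar ones below) replaces the JDK-internal sun.reflect.generics.reflectiveObjects.NotImplementedException with the supported org.apache.commons.lang3.NotImplementedException. As an illustrative sketch, not part of this diff, the commons-lang3 class can carry a message and extends UnsupportedOperationException, so callers can handle it with the standard type:

import org.apache.commons.lang3.NotImplementedException;

public class NotImplementedDemo {
    // Same stub pattern as the test classes above: fail fast with a message.
    static long stubNextLong() {
        throw new NotImplementedException("Not implemented");
    }

    public static void main(String[] args) {
        try {
            stubNextLong();
        } catch (UnsupportedOperationException e) {
            // commons-lang3's NotImplementedException extends UnsupportedOperationException.
            System.out.println("caught: " + e.getMessage());
        }
    }
}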

View File

@@ -16,6 +16,7 @@
 package org.deeplearning4j.arbiter.optimize.genetic.crossover;
+import org.apache.commons.lang3.NotImplementedException;
 import org.deeplearning4j.BaseDL4JTest;
 import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverResult;
 import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.TwoParentsCrossoverOperator;
@@ -26,7 +27,6 @@ import org.deeplearning4j.arbiter.optimize.genetic.TestParentSelection;
 import org.deeplearning4j.arbiter.optimize.genetic.TestPopulationInitializer;
 import org.junit.Assert;
 import org.junit.Test;
-import sun.reflect.generics.reflectiveObjects.NotImplementedException;
 public class TwoParentsCrossoverOperatorTests extends BaseDL4JTest {
@@ -42,7 +42,7 @@ public class TwoParentsCrossoverOperatorTests extends BaseDL4JTest {
 @Override
 public CrossoverResult crossover() {
-throw new NotImplementedException();
+throw new NotImplementedException("Not implemented");
 }
 }

View File

@@ -16,6 +16,7 @@
 package org.deeplearning4j.arbiter.optimize.genetic.culling;
+import org.apache.commons.lang3.NotImplementedException;
 import org.deeplearning4j.BaseDL4JTest;
 import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome;
 import org.deeplearning4j.arbiter.optimize.generator.genetic.culling.RatioCullOperator;
@@ -24,7 +25,6 @@ import org.deeplearning4j.arbiter.optimize.generator.genetic.population.Populati
 import org.deeplearning4j.arbiter.optimize.genetic.TestPopulationInitializer;
 import org.junit.Assert;
 import org.junit.Test;
-import sun.reflect.generics.reflectiveObjects.NotImplementedException;
 import java.util.List;
@@ -46,7 +46,7 @@ public class RatioCullOperatorTests extends BaseDL4JTest {
 @Override
 public void cullPopulation() {
-throw new NotImplementedException();
+throw new NotImplementedException("Not implemented");
 }
 public double getCullRatio() {

View File

@@ -16,6 +16,7 @@
 package org.deeplearning4j.arbiter.optimize.genetic.selection;
+import org.apache.commons.lang3.NotImplementedException;
 import org.apache.commons.math3.random.RandomGenerator;
 import org.deeplearning4j.BaseDL4JTest;
 import org.deeplearning4j.arbiter.optimize.generator.genetic.ChromosomeFactory;
@@ -33,7 +34,6 @@ import org.deeplearning4j.arbiter.optimize.genetic.TestPopulationInitializer;
 import org.deeplearning4j.arbiter.optimize.genetic.TestRandomGenerator;
 import org.junit.Assert;
 import org.junit.Test;
-import sun.reflect.generics.reflectiveObjects.NotImplementedException;
 import static org.junit.Assert.assertArrayEquals;
@@ -55,7 +55,7 @@ public class GeneticSelectionOperatorTests extends BaseDL4JTest {
 @Override
 public void cullPopulation() {
-throw new NotImplementedException();
+throw new NotImplementedException("Not implemented");
 }
 @Override

View File

@@ -16,6 +16,7 @@
 package org.deeplearning4j.arbiter.optimize.genetic.selection;
+import org.apache.commons.lang3.NotImplementedException;
 import org.deeplearning4j.BaseDL4JTest;
 import org.deeplearning4j.arbiter.optimize.generator.genetic.ChromosomeFactory;
 import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationInitializer;
@@ -24,7 +25,6 @@ import org.deeplearning4j.arbiter.optimize.generator.genetic.selection.Selection
 import org.deeplearning4j.arbiter.optimize.genetic.TestPopulationInitializer;
 import org.junit.Assert;
 import org.junit.Test;
-import sun.reflect.generics.reflectiveObjects.NotImplementedException;
 public class SelectionOperatorTests extends BaseDL4JTest {
 private class TestSelectionOperator extends SelectionOperator {
@@ -39,7 +39,7 @@ public class SelectionOperatorTests extends BaseDL4JTest {
 @Override
 public double[] buildNextGenes() {
-throw new NotImplementedException();
+throw new NotImplementedException("Not implemented");
 }
 }

View File

@@ -158,7 +158,7 @@ public class TestComputationGraphSpace extends BaseDL4JTest {
 }
 }
-System.out.println("ReLU vs. Tanh: " + reluCount + "\t" + tanhCount);
+// System.out.println("ReLU vs. Tanh: " + reluCount + "\t" + tanhCount);
 assertTrue(reluCount > 0);
 assertTrue(tanhCount > 0);

View File

@@ -162,7 +162,7 @@ public class TestGraphLocalExecution extends BaseDL4JTest {
 List<ResultReference> results = runner.getResults();
 assertTrue(results.size() > 0);
-System.out.println("----- COMPLETE - " + results.size() + " results -----");
+// System.out.println("----- COMPLETE - " + results.size() + " results -----");
 }
 }

View File

@@ -165,7 +165,7 @@ public class TestGraphLocalExecutionGenetic extends BaseDL4JTest {
 List<ResultReference> results = runner.getResults();
 assertTrue(results.size() > 0);
-System.out.println("----- COMPLETE - " + results.size() + " results -----");
+// System.out.println("----- COMPLETE - " + results.size() + " results -----");
 }
 }

View File

@@ -101,7 +101,7 @@ public class TestLayerSpace extends BaseDL4JTest {
 double l2 = TestUtils.getL2(l);
 IActivation activation = l.getActivationFn();
-System.out.println(lr + "\t" + l2 + "\t" + activation);
+// System.out.println(lr + "\t" + l2 + "\t" + activation);
 assertTrue(lr >= 0.3 && lr <= 0.4);
 assertTrue(l2 >= 0.01 && l2 <= 0.1);
@@ -190,7 +190,7 @@ public class TestLayerSpace extends BaseDL4JTest {
 ActivationLayer al = als.getValue(d);
 IActivation activation = al.getActivationFn();
-System.out.println(activation);
+// System.out.println(activation);
 assertTrue(containsActivationFunction(actFns, activation));
 }
@@ -228,7 +228,7 @@ public class TestLayerSpace extends BaseDL4JTest {
 IActivation activation = el.getActivationFn();
 long nOut = el.getNOut();
-System.out.println(activation + "\t" + nOut);
+// System.out.println(activation + "\t" + nOut);
 assertTrue(containsActivationFunction(actFns, activation));
 assertTrue(nOut >= 10 && nOut <= 20);
@@ -295,7 +295,7 @@ public class TestLayerSpace extends BaseDL4JTest {
 long nOut = el.getNOut();
 double forgetGate = el.getForgetGateBiasInit();
-System.out.println(activation + "\t" + nOut + "\t" + forgetGate);
+// System.out.println(activation + "\t" + nOut + "\t" + forgetGate);
 assertTrue(containsActivationFunction(actFns, activation));
 assertTrue(nOut >= 10 && nOut <= 20);

View File

@@ -293,8 +293,8 @@ public class TestMultiLayerSpace extends BaseDL4JTest {
 assertTrue(nLayerCounts[i] >= 5); //Expect approx equal (50/3 each), but some variation randomly
 }
-System.out.println("Number of layers: " + Arrays.toString(nLayerCounts));
-System.out.println("ReLU vs. Tanh: " + reluCount + "\t" + tanhCount);
+// System.out.println("Number of layers: " + Arrays.toString(nLayerCounts));
+// System.out.println("ReLU vs. Tanh: " + reluCount + "\t" + tanhCount);
 }

View File

@@ -98,7 +98,8 @@ public class ArbiterCLIRunnerTest extends BaseDL4JTest {
 assertEquals(configuration,OptimizationConfiguration.fromJson(configuration.toJson()));
 FileUtils.writeStringToFile(new File(configPath),configuration.toJson());
-System.out.println(configuration.toJson());
+// System.out.println(configuration.toJson());
+configuration.toJson();
 log.info("Starting test");
 cliRunner.runMain(

View File

@@ -0,0 +1,390 @@
/*
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
*/
package org.deeplearning4j.regressiontest;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.TestUtils;
import org.deeplearning4j.nn.conf.BackpropType;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.graph.LayerVertex;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.layers.convolutional.Cropping2D;
import org.deeplearning4j.nn.conf.layers.recurrent.Bidirectional;
import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn;
import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.graph.vertex.impl.MergeVertex;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInitXavier;
import org.deeplearning4j.regressiontest.customlayer100a.CustomLayer;
import org.junit.Test;
import org.nd4j.linalg.activations.impl.*;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.RmsProp;
import org.nd4j.linalg.learning.regularization.L2Regularization;
import org.nd4j.linalg.lossfunctions.impl.LossMAE;
import org.nd4j.linalg.lossfunctions.impl.LossMCXENT;
import org.nd4j.resources.Resources;
import java.io.DataInputStream;
import java.io.File;
import java.io.FileInputStream;
import static org.junit.Assert.*;
public class RegressionTest100b6 extends BaseDL4JTest {
@Override
public DataType getDataType() {
return DataType.FLOAT;
}
@Test
public void testCustomLayer() throws Exception {
for (DataType dtype : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) {
String dtypeName = dtype.toString().toLowerCase();
File f = Resources.asFile("regression_testing/100b6/CustomLayerExample_100b6_" + dtypeName + ".bin");
MultiLayerNetwork.load(f, true);
MultiLayerNetwork net = MultiLayerNetwork.load(f, true);
// net = net.clone();
DenseLayer l0 = (DenseLayer) net.getLayer(0).conf().getLayer();
assertEquals(new ActivationTanH(), l0.getActivationFn());
assertEquals(new L2Regularization(0.03), TestUtils.getL2Reg(l0));
assertEquals(new RmsProp(0.95), l0.getIUpdater());
CustomLayer l1 = (CustomLayer) net.getLayer(1).conf().getLayer();
assertEquals(new ActivationTanH(), l1.getActivationFn());
assertEquals(new ActivationSigmoid(), l1.getSecondActivationFunction());
assertEquals(new RmsProp(0.95), l1.getIUpdater());
INDArray outExp;
File f2 = Resources
.asFile("regression_testing/100b6/CustomLayerExample_Output_100b6_" + dtypeName + ".bin");
try (DataInputStream dis = new DataInputStream(new FileInputStream(f2))) {
outExp = Nd4j.read(dis);
}
INDArray in;
File f3 = Resources.asFile("regression_testing/100b6/CustomLayerExample_Input_100b6_" + dtypeName + ".bin");
try (DataInputStream dis = new DataInputStream(new FileInputStream(f3))) {
in = Nd4j.read(dis);
}
assertEquals(dtype, in.dataType());
assertEquals(dtype, outExp.dataType());
assertEquals(dtype, net.params().dataType());
assertEquals(dtype, net.getFlattenedGradients().dataType());
assertEquals(dtype, net.getUpdater().getStateViewArray().dataType());
//System.out.println(Arrays.toString(net.params().data().asFloat()));
INDArray outAct = net.output(in);
assertEquals(dtype, outAct.dataType());
assertEquals(dtype, net.getLayerWiseConfigurations().getDataType());
assertEquals(dtype, net.params().dataType());
boolean eq = outExp.equalsWithEps(outAct, 0.01);
assertTrue(outExp + " vs " + outAct, eq);
}
}
@Test
public void testLSTM() throws Exception {
File f = Resources.asFile("regression_testing/100b6/GravesLSTMCharModelingExample_100b6.bin");
MultiLayerNetwork net = MultiLayerNetwork.load(f, true);
LSTM l0 = (LSTM) net.getLayer(0).conf().getLayer();
assertEquals(new ActivationTanH(), l0.getActivationFn());
assertEquals(200, l0.getNOut());
assertEquals(new WeightInitXavier(), l0.getWeightInitFn());
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l0));
assertEquals(new Adam(0.005), l0.getIUpdater());
LSTM l1 = (LSTM) net.getLayer(1).conf().getLayer();
assertEquals(new ActivationTanH(), l1.getActivationFn());
assertEquals(200, l1.getNOut());
assertEquals(new WeightInitXavier(), l1.getWeightInitFn());
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l1));
assertEquals(new Adam(0.005), l1.getIUpdater());
RnnOutputLayer l2 = (RnnOutputLayer) net.getLayer(2).conf().getLayer();
assertEquals(new ActivationSoftmax(), l2.getActivationFn());
assertEquals(77, l2.getNOut());
assertEquals(new WeightInitXavier(), l2.getWeightInitFn());
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l2));
assertEquals(new Adam(0.005), l2.getIUpdater());
assertEquals(BackpropType.TruncatedBPTT, net.getLayerWiseConfigurations().getBackpropType());
assertEquals(50, net.getLayerWiseConfigurations().getTbpttBackLength());
assertEquals(50, net.getLayerWiseConfigurations().getTbpttFwdLength());
INDArray outExp;
File f2 = Resources.asFile("regression_testing/100b6/GravesLSTMCharModelingExample_Output_100b6.bin");
try (DataInputStream dis = new DataInputStream(new FileInputStream(f2))) {
outExp = Nd4j.read(dis);
}
INDArray in;
File f3 = Resources.asFile("regression_testing/100b6/GravesLSTMCharModelingExample_Input_100b6.bin");
try (DataInputStream dis = new DataInputStream(new FileInputStream(f3))) {
in = Nd4j.read(dis);
}
INDArray outAct = net.output(in);
assertEquals(outExp, outAct);
}
@Test
public void testVae() throws Exception {
File f = Resources.asFile("regression_testing/100b6/VaeMNISTAnomaly_100b6.bin");
MultiLayerNetwork net = MultiLayerNetwork.load(f, true);
VariationalAutoencoder l0 = (VariationalAutoencoder) net.getLayer(0).conf().getLayer();
assertEquals(new ActivationLReLU(), l0.getActivationFn());
assertEquals(32, l0.getNOut());
assertArrayEquals(new int[]{256, 256}, l0.getEncoderLayerSizes());
assertArrayEquals(new int[]{256, 256}, l0.getDecoderLayerSizes());
assertEquals(new WeightInitXavier(), l0.getWeightInitFn());
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l0));
assertEquals(new Adam(1e-3), l0.getIUpdater());
INDArray outExp;
File f2 = Resources.asFile("regression_testing/100b6/VaeMNISTAnomaly_Output_100b6.bin");
try (DataInputStream dis = new DataInputStream(new FileInputStream(f2))) {
outExp = Nd4j.read(dis);
}
INDArray in;
File f3 = Resources.asFile("regression_testing/100b6/VaeMNISTAnomaly_Input_100b6.bin");
try (DataInputStream dis = new DataInputStream(new FileInputStream(f3))) {
in = Nd4j.read(dis);
}
INDArray outAct = net.output(in);
assertEquals(outExp, outAct);
}
@Test
public void testYoloHouseNumber() throws Exception {
File f = Resources.asFile("regression_testing/100b6/HouseNumberDetection_100b6.bin");
ComputationGraph net = ComputationGraph.load(f, true);
int nBoxes = 5;
int nClasses = 10;
ConvolutionLayer cl = (ConvolutionLayer) ((LayerVertex) net.getConfiguration().getVertices()
.get("convolution2d_9")).getLayerConf().getLayer();
assertEquals(nBoxes * (5 + nClasses), cl.getNOut());
assertEquals(new ActivationIdentity(), cl.getActivationFn());
assertEquals(ConvolutionMode.Same, cl.getConvolutionMode());
assertEquals(new WeightInitXavier(), cl.getWeightInitFn());
assertArrayEquals(new int[]{1, 1}, cl.getKernelSize());
INDArray outExp;
File f2 = Resources.asFile("regression_testing/100b6/HouseNumberDetection_Output_100b6.bin");
try (DataInputStream dis = new DataInputStream(new FileInputStream(f2))) {
outExp = Nd4j.read(dis);
}
INDArray in;
File f3 = Resources.asFile("regression_testing/100b6/HouseNumberDetection_Input_100b6.bin");
try (DataInputStream dis = new DataInputStream(new FileInputStream(f3))) {
in = Nd4j.read(dis);
}
INDArray outAct = net.outputSingle(in);
boolean eq = outExp.equalsWithEps(outAct.castTo(outExp.dataType()), 1e-3);
assertTrue(eq);
}
@Test
public void testSyntheticCNN() throws Exception {
File f = Resources.asFile("regression_testing/100b6/SyntheticCNN_100b6.bin");
MultiLayerNetwork net = MultiLayerNetwork.load(f, true);
ConvolutionLayer l0 = (ConvolutionLayer) net.getLayer(0).conf().getLayer();
assertEquals(new ActivationReLU(), l0.getActivationFn());
assertEquals(4, l0.getNOut());
assertEquals(new WeightInitXavier(), l0.getWeightInitFn());
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l0));
assertEquals(new Adam(0.005), l0.getIUpdater());
assertArrayEquals(new int[]{3, 3}, l0.getKernelSize());
assertArrayEquals(new int[]{2, 1}, l0.getStride());
assertArrayEquals(new int[]{1, 1}, l0.getDilation());
assertArrayEquals(new int[]{0, 0}, l0.getPadding());
SeparableConvolution2D l1 = (SeparableConvolution2D) net.getLayer(1).conf().getLayer();
assertEquals(new ActivationReLU(), l1.getActivationFn());
assertEquals(8, l1.getNOut());
assertEquals(new WeightInitXavier(), l1.getWeightInitFn());
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l1));
assertEquals(new Adam(0.005), l1.getIUpdater());
assertArrayEquals(new int[]{3, 3}, l1.getKernelSize());
assertArrayEquals(new int[]{1, 1}, l1.getStride());
assertArrayEquals(new int[]{1, 1}, l1.getDilation());
assertArrayEquals(new int[]{0, 0}, l1.getPadding());
assertEquals(ConvolutionMode.Same, l1.getConvolutionMode());
assertEquals(1, l1.getDepthMultiplier());
SubsamplingLayer l2 = (SubsamplingLayer) net.getLayer(2).conf().getLayer();
assertArrayEquals(new int[]{3, 3}, l2.getKernelSize());
assertArrayEquals(new int[]{2, 2}, l2.getStride());
assertArrayEquals(new int[]{1, 1}, l2.getDilation());
assertArrayEquals(new int[]{0, 0}, l2.getPadding());
assertEquals(PoolingType.MAX, l2.getPoolingType());
ZeroPaddingLayer l3 = (ZeroPaddingLayer) net.getLayer(3).conf().getLayer();
assertArrayEquals(new int[]{4, 4, 4, 4}, l3.getPadding());
Upsampling2D l4 = (Upsampling2D) net.getLayer(4).conf().getLayer();
assertArrayEquals(new int[]{3, 3}, l4.getSize());
DepthwiseConvolution2D l5 = (DepthwiseConvolution2D) net.getLayer(5).conf().getLayer();
assertEquals(new ActivationReLU(), l5.getActivationFn());
assertEquals(16, l5.getNOut());
assertEquals(new WeightInitXavier(), l5.getWeightInitFn());
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l5));
assertEquals(new Adam(0.005), l5.getIUpdater());
assertArrayEquals(new int[]{3, 3}, l5.getKernelSize());
assertArrayEquals(new int[]{1, 1}, l5.getStride());
assertArrayEquals(new int[]{1, 1}, l5.getDilation());
assertArrayEquals(new int[]{0, 0}, l5.getPadding());
assertEquals(2, l5.getDepthMultiplier());
SubsamplingLayer l6 = (SubsamplingLayer) net.getLayer(6).conf().getLayer();
assertArrayEquals(new int[]{2, 2}, l6.getKernelSize());
assertArrayEquals(new int[]{2, 2}, l6.getStride());
assertArrayEquals(new int[]{1, 1}, l6.getDilation());
assertArrayEquals(new int[]{0, 0}, l6.getPadding());
assertEquals(PoolingType.MAX, l6.getPoolingType());
Cropping2D l7 = (Cropping2D) net.getLayer(7).conf().getLayer();
assertArrayEquals(new int[]{3, 3, 2, 2}, l7.getCropping());
ConvolutionLayer l8 = (ConvolutionLayer) net.getLayer(8).conf().getLayer();
assertEquals(4, l8.getNOut());
assertEquals(new WeightInitXavier(), l8.getWeightInitFn());
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l8));
assertEquals(new Adam(0.005), l8.getIUpdater());
assertArrayEquals(new int[]{4, 4}, l8.getKernelSize());
assertArrayEquals(new int[]{1, 1}, l8.getStride());
assertArrayEquals(new int[]{1, 1}, l8.getDilation());
assertArrayEquals(new int[]{0, 0}, l8.getPadding());
CnnLossLayer l9 = (CnnLossLayer) net.getLayer(9).conf().getLayer();
assertEquals(new WeightInitXavier(), l9.getWeightInitFn());
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l9));
assertEquals(new Adam(0.005), l9.getIUpdater());
assertEquals(new LossMAE(), l9.getLossFn());
INDArray outExp;
File f2 = Resources.asFile("regression_testing/100b6/SyntheticCNN_Output_100b6.bin");
try (DataInputStream dis = new DataInputStream(new FileInputStream(f2))) {
outExp = Nd4j.read(dis);
}
INDArray in;
File f3 = Resources.asFile("regression_testing/100b6/SyntheticCNN_Input_100b6.bin");
try (DataInputStream dis = new DataInputStream(new FileInputStream(f3))) {
in = Nd4j.read(dis);
}
INDArray outAct = net.output(in);
//19 layers - CPU vs. GPU difference accumulates notably, but appears to be correct
if(Nd4j.getBackend().getClass().getName().toLowerCase().contains("native")){
assertEquals(outExp, outAct);
} else {
boolean eq = outExp.equalsWithEps(outAct, 0.1);
assertTrue(eq);
}
}
@Test
public void testSyntheticBidirectionalRNNGraph() throws Exception {
File f = Resources.asFile("regression_testing/100b6/SyntheticBidirectionalRNNGraph_100b6.bin");
ComputationGraph net = ComputationGraph.load(f, true);
Bidirectional l0 = (Bidirectional) net.getLayer("rnn1").conf().getLayer();
LSTM l1 = (LSTM) l0.getFwd();
assertEquals(16, l1.getNOut());
assertEquals(new ActivationReLU(), l1.getActivationFn());
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l1));
LSTM l2 = (LSTM) l0.getBwd();
assertEquals(16, l2.getNOut());
assertEquals(new ActivationReLU(), l2.getActivationFn());
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l2));
Bidirectional l3 = (Bidirectional) net.getLayer("rnn2").conf().getLayer();
SimpleRnn l4 = (SimpleRnn) l3.getFwd();
assertEquals(16, l4.getNOut());
assertEquals(new ActivationReLU(), l4.getActivationFn());
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l4));
SimpleRnn l5 = (SimpleRnn) l3.getBwd();
assertEquals(16, l5.getNOut());
assertEquals(new ActivationReLU(), l5.getActivationFn());
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l5));
MergeVertex mv = (MergeVertex) net.getVertex("concat");
GlobalPoolingLayer gpl = (GlobalPoolingLayer) net.getLayer("pooling").conf().getLayer();
assertEquals(PoolingType.MAX, gpl.getPoolingType());
assertArrayEquals(new int[]{2}, gpl.getPoolingDimensions());
assertTrue(gpl.isCollapseDimensions());
OutputLayer outl = (OutputLayer) net.getLayer("out").conf().getLayer();
assertEquals(3, outl.getNOut());
assertEquals(new LossMCXENT(), outl.getLossFn());
INDArray outExp;
File f2 = Resources.asFile("regression_testing/100b6/SyntheticBidirectionalRNNGraph_Output_100b6.bin");
try (DataInputStream dis = new DataInputStream(new FileInputStream(f2))) {
outExp = Nd4j.read(dis);
}
INDArray in;
File f3 = Resources.asFile("regression_testing/100b6/SyntheticBidirectionalRNNGraph_Input_100b6.bin");
try (DataInputStream dis = new DataInputStream(new FileInputStream(f3))) {
in = Nd4j.read(dis);
}
INDArray outAct = net.output(in)[0];
assertEquals(outExp, outAct);
}
}
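
The expected-input and expected-output fixtures read above with Nd4j.read(DataInputStream) are plain ND4J-serialized arrays. As an illustrative sketch, not part of this diff (the file name and array contents are placeholders), such a fixture could be produced with the matching Nd4j.write call:

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

import java.io.DataOutputStream;
import java.io.FileOutputStream;

public class WriteRegressionFixture {
    public static void main(String[] args) throws Exception {
        // Placeholder array; a real fixture would hold a saved network's output.
        INDArray arr = Nd4j.rand(1, 10);
        try (DataOutputStream dos = new DataOutputStream(new FileOutputStream("Example_Output.bin"))) {
            Nd4j.write(arr, dos); // counterpart of the Nd4j.read(dis) calls in the test above
        }
    }
}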

View File

@@ -41,7 +41,7 @@ public class TestGraphLoading extends BaseDL4JTest {
 IGraph<String, String> graph = GraphLoader
 .loadUndirectedGraphEdgeListFile(cpr.getTempFileFromArchive().getAbsolutePath(), 7, ",");
-System.out.println(graph);
+// System.out.println(graph);
 assertEquals(graph.numVertices(), 7);
 int[][] edges = {{1, 2}, {0, 2, 4}, {0, 1, 3, 4}, {2, 4, 5}, {1, 2, 3, 5, 6}, {3, 4, 6}, {4, 5}};
@@ -66,7 +66,7 @@ public class TestGraphLoading extends BaseDL4JTest {
 edgeLineProcessor, vertexFactory, 10, false);
-System.out.println(graph);
+// System.out.println(graph);
 for (int i = 0; i < 10; i++) {
 List<Edge<String>> edges = graph.getEdgesOut(i);
@@ -111,7 +111,7 @@ public class TestGraphLoading extends BaseDL4JTest {
 Graph<String, String> graph = GraphLoader.loadGraph(verticesCPR.getTempFileFromArchive().getAbsolutePath(),
 edgesCPR.getTempFileFromArchive().getAbsolutePath(), vertexLoader, edgeLineProcessor, false);
-System.out.println(graph);
+// System.out.println(graph);
 for (int i = 0; i < 10; i++) {
 List<Edge<String>> edges = graph.getEdgesOut(i);

View File

@@ -71,7 +71,7 @@ public class TestGraphLoadingWeighted extends BaseDL4JTest {
 }
 }
-System.out.println(graph);
+// System.out.println(graph);
 }

View File

@@ -220,7 +220,7 @@ public class TestGraph extends BaseDL4JTest {
 sum += transitionProb[i][j];
 for (int j = 0; j < transitionProb[i].length; j++)
 transitionProb[i][j] /= sum;
-System.out.println(Arrays.toString(transitionProb[i]));
+// System.out.println(Arrays.toString(transitionProb[i]));
 }
 //Check that transition probs are essentially correct (within bounds of random variation)

View File

@@ -145,8 +145,8 @@ public class DeepWalkGradientCheck extends BaseDL4JTest {
 if (relError > MAX_REL_ERROR && absErr > MIN_ABS_ERROR)
 fail(msg);
-else
-System.out.println(msg);
+// else
+// System.out.println(msg);
 }
 }
@@ -333,10 +333,10 @@ public class DeepWalkGradientCheck extends BaseDL4JTest {
 if (relError > MAX_REL_ERROR && absErr > MIN_ABS_ERROR)
 fail(msg);
-else
-System.out.println(msg);
+// else
+// System.out.println(msg);
 }
-System.out.println();
+// System.out.println();
 }
 }

View File

@@ -67,7 +67,7 @@ public class TestDeepWalk extends BaseDL4JTest {
 for (int i = 0; i < 7; i++) {
 INDArray vector = deepWalk.getVertexVector(i);
 assertArrayEquals(new long[] {vectorSize}, vector.shape());
-System.out.println(Arrays.toString(vector.dup().data().asFloat()));
+// System.out.println(Arrays.toString(vector.dup().data().asFloat()));
 }
 GraphWalkIterator<String> iter = new RandomWalkIterator<>(graph, 8);
@@ -77,11 +77,11 @@ public class TestDeepWalk extends BaseDL4JTest {
 for (int t = 0; t < 5; t++) {
 iter.reset();
 deepWalk.fit(iter);
-System.out.println("--------------------");
+// System.out.println("--------------------");
 for (int i = 0; i < 7; i++) {
 INDArray vector = deepWalk.getVertexVector(i);
 assertArrayEquals(new long[] {vectorSize}, vector.shape());
-System.out.println(Arrays.toString(vector.dup().data().asFloat()));
+// System.out.println(Arrays.toString(vector.dup().data().asFloat()));
 }
 }
 }
@@ -160,7 +160,7 @@ public class TestDeepWalk extends BaseDL4JTest {
 continue;
 double sim = deepWalk.similarity(i, nearestTo);
-System.out.println(i + "\t" + nearestTo + "\t" + sim);
+// System.out.println(i + "\t" + nearestTo + "\t" + sim);
 assertTrue(sim <= minSimNearest);
 }
 }
@@ -211,7 +211,7 @@ public class TestDeepWalk extends BaseDL4JTest {
 Graph<String, String> graph = GraphLoader
 .loadUndirectedGraphEdgeListFile(cpr.getTempFileFromArchive().getAbsolutePath(), 13, ",");
-System.out.println(graph);
+// System.out.println(graph);
 Nd4j.getRandom().setSeed(12345);
@@ -229,11 +229,13 @@ public class TestDeepWalk extends BaseDL4JTest {
 //Calculate similarity(0,i)
 for (int i = 0; i < nVertices; i++) {
-System.out.println(deepWalk.similarity(0, i));
+// System.out.println(deepWalk.similarity(0, i));
+deepWalk.similarity(0, i);
 }
 for (int i = 0; i < nVertices; i++)
-System.out.println(deepWalk.getVertexVector(i));
+// System.out.println(deepWalk.getVertexVector(i));
+deepWalk.getVertexVector(i);
 }
 @Test(timeout = 60000L)

View File

@@ -38,9 +38,11 @@ public class TestGraphHuffman extends BaseDL4JTest {
 gh.buildTree(vertexDegrees);
-for (int i = 0; i < 7; i++)
-System.out.println(i + "\t" + gh.getCodeLength(i) + "\t" + gh.getCodeString(i) + "\t\t" + gh.getCode(i)
-+ "\t\t" + Arrays.toString(gh.getPathInnerNodes(i)));
+for (int i = 0; i < 7; i++) {
+String s = i + "\t" + gh.getCodeLength(i) + "\t" + gh.getCodeString(i) + "\t\t" + gh.getCode(i)
++ "\t\t" + Arrays.toString(gh.getPathInnerNodes(i));
+// System.out.println(s);
+}
 int[] expectedLengths = {3, 2, 2, 5, 4, 2, 5};
 for (int i = 0; i < vertexDegrees.length; i++) {

View File

@@ -3,6 +3,7 @@ package org.deeplearning4j.util;
 import lombok.NonNull;
 import org.apache.commons.io.IOUtils;
 import org.deeplearning4j.nn.api.Model;
+import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
 import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
 import org.deeplearning4j.nn.graph.ComputationGraph;
 import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
@@ -121,7 +122,7 @@ public class DL4JModelValidator {
 }
 try{
-MultiLayerConfiguration.fromJson(config);
+ComputationGraphConfiguration.fromJson(config);
 } catch (Throwable t){
 return ValidationResult.builder()
 .formatType("ComputationGraph")

View File

@@ -79,8 +79,9 @@ public class ParameterServerParallelWrapperTest extends BaseDL4JTest {
 model.init();
 ParallelWrapper parameterServerParallelWrapper =
-new ParallelWrapper.Builder(model).trainerFactory(new ParameterServerTrainerContext())
-.workers(Runtime.getRuntime().availableProcessors())
+new ParallelWrapper.Builder(model)
+.workers(Math.min(4, Runtime.getRuntime().availableProcessors()))
+.trainerFactory(new ParameterServerTrainerContext())
 .reportScoreAfterAveraging(true).prefetchBuffer(3).build();
 parameterServerParallelWrapper.fit(mnistTrain);
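
For context, the builder calls in this hunk come from DL4J's ParallelWrapper. A minimal sketch of the capped-worker configuration, assuming a placeholder single-layer network (seed, layer sizes, and class name here are illustrative, not part of this PR):

import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.parallelism.ParallelWrapper;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.lossfunctions.LossFunctions;

public class CappedWorkersSketch {
    public static void main(String[] args) {
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(12345)
                .list()
                .layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                        .nIn(4).nOut(3).activation(Activation.SOFTMAX).build())
                .build();
        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();

        // Cap the worker count, as in the hunk above, so small machines are not oversubscribed.
        ParallelWrapper wrapper = new ParallelWrapper.Builder<>(model)
                .workers(Math.min(4, Runtime.getRuntime().availableProcessors()))
                .prefetchBuffer(3)
                .reportScoreAfterAveraging(true)
                .build();
        // wrapper.fit(trainIterator); // training data omitted in this sketch
    }
}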

View File

@@ -104,7 +104,7 @@ public class SparkWord2VecTest extends BaseDL4JTest {
 public void call(ExportContainer<VocabWord> v) throws Exception {
 assertNotNull(v.getElement());
 assertNotNull(v.getArray());
-System.out.println(v.getElement() + " - " + v.getArray());
+// System.out.println(v.getElement() + " - " + v.getArray());
 }
 }
 }

View File

@@ -66,7 +66,7 @@ public class TestEarlyStoppingSpark extends BaseSparkTest {
 .lossFunction(LossFunctions.LossFunction.MCXENT).build())
 .build();
 MultiLayerNetwork net = new MultiLayerNetwork(conf);
-net.setListeners(new ScoreIterationListener(1));
+net.setListeners(new ScoreIterationListener(5));
 JavaRDD<DataSet> irisData = getIris();
@@ -119,7 +119,7 @@ public class TestEarlyStoppingSpark extends BaseSparkTest {
 .lossFunction(LossFunctions.LossFunction.MSE).build())
 .build();
 MultiLayerNetwork net = new MultiLayerNetwork(conf);
-net.setListeners(new ScoreIterationListener(1));
+net.setListeners(new ScoreIterationListener(5));
 JavaRDD<DataSet> irisData = getIris();
 EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();
@@ -155,7 +155,7 @@ public class TestEarlyStoppingSpark extends BaseSparkTest {
 .lossFunction(LossFunctions.LossFunction.MCXENT).build())
 .build();
 MultiLayerNetwork net = new MultiLayerNetwork(conf);
-net.setListeners(new ScoreIterationListener(1));
+net.setListeners(new ScoreIterationListener(5));
 JavaRDD<DataSet> irisData = getIris();
@@ -198,7 +198,7 @@ public class TestEarlyStoppingSpark extends BaseSparkTest {
 .lossFunction(LossFunctions.LossFunction.MCXENT).build())
 .build();
 MultiLayerNetwork net = new MultiLayerNetwork(conf);
-net.setListeners(new ScoreIterationListener(1));
+net.setListeners(new ScoreIterationListener(5));
 JavaRDD<DataSet> irisData = getIris();
@@ -231,7 +231,7 @@ public class TestEarlyStoppingSpark extends BaseSparkTest {
 .lossFunction(LossFunctions.LossFunction.MCXENT).build())
 .build();
 MultiLayerNetwork net = new MultiLayerNetwork(conf);
-net.setListeners(new ScoreIterationListener(1));
+net.setListeners(new ScoreIterationListener(5));
 JavaRDD<DataSet> irisData = getIris();

View File

@@ -69,7 +69,7 @@ public class TestEarlyStoppingSparkCompGraph extends BaseSparkTest {
 .lossFunction(LossFunctions.LossFunction.MCXENT).build(), "in")
 .setOutputs("0").build();
 ComputationGraph net = new ComputationGraph(conf);
-net.setListeners(new ScoreIterationListener(1));
+net.setListeners(new ScoreIterationListener(5));
 JavaRDD<DataSet> irisData = getIris();
@@ -120,7 +120,7 @@ public class TestEarlyStoppingSparkCompGraph extends BaseSparkTest {
 .lossFunction(LossFunctions.LossFunction.MSE).build(), "in")
 .setOutputs("0").build();
 ComputationGraph net = new ComputationGraph(conf);
-net.setListeners(new ScoreIterationListener(1));
+net.setListeners(new ScoreIterationListener(5));
 JavaRDD<DataSet> irisData = getIris();
 EarlyStoppingModelSaver<ComputationGraph> saver = new InMemoryModelSaver<>();
@@ -158,7 +158,7 @@ public class TestEarlyStoppingSparkCompGraph extends BaseSparkTest {
 .lossFunction(LossFunctions.LossFunction.MCXENT).build(), "in")
 .setOutputs("0").build();
 ComputationGraph net = new ComputationGraph(conf);
-net.setListeners(new ScoreIterationListener(1));
+net.setListeners(new ScoreIterationListener(5));
 JavaRDD<DataSet> irisData = getIris();
@@ -203,7 +203,7 @@ public class TestEarlyStoppingSparkCompGraph extends BaseSparkTest {
 .lossFunction(LossFunctions.LossFunction.MCXENT).build(), "in")
 .setOutputs("0").build();
 ComputationGraph net = new ComputationGraph(conf);
-net.setListeners(new ScoreIterationListener(1));
+net.setListeners(new ScoreIterationListener(5));
 JavaRDD<DataSet> irisData = getIris();
@@ -238,7 +238,7 @@ public class TestEarlyStoppingSparkCompGraph extends BaseSparkTest {
 .lossFunction(LossFunctions.LossFunction.MCXENT).build(), "in")
 .setOutputs("0").build();
 ComputationGraph net = new ComputationGraph(conf);
-net.setListeners(new ScoreIterationListener(1));
+net.setListeners(new ScoreIterationListener(5));
 JavaRDD<DataSet> irisData = getIris();

View File

@@ -59,7 +59,7 @@ public class TestShuffleExamples extends BaseSparkTest {
 int totalExampleCount = 0;
 for (DataSet ds : shuffledList) {
 totalExampleCount += ds.getFeatures().length();
-System.out.println(Arrays.toString(ds.getFeatures().data().asFloat()));
+// System.out.println(Arrays.toString(ds.getFeatures().data().asFloat()));
 assertEquals(ds.getFeatures(), ds.getLabels());
 }

View File

@@ -86,7 +86,7 @@ public class TestExport extends BaseSparkTest {
 for (File file : files) {
 if (!file.getPath().endsWith(".bin"))
 continue;
-System.out.println(file);
+// System.out.println(file);
 DataSet ds = new DataSet();
 ds.load(file);
 assertEquals(minibatchSize, ds.numExamples());
@@ -144,7 +144,7 @@ public class TestExport extends BaseSparkTest {
 for (File file : files) {
 if (!file.getPath().endsWith(".bin"))
 continue;
-System.out.println(file);
+// System.out.println(file);
 MultiDataSet ds = new org.nd4j.linalg.dataset.MultiDataSet();
 ds.load(file);
 assertEquals(minibatchSize, ds.getFeatures(0).size(0));

View File

@@ -92,9 +92,9 @@ public class HashingBalancedPartitionerTest extends BaseSparkTest {
 int[][] colorCountsByPartition = new int[3][2];
 for (final Tuple2<Tuple2<Long, Integer>, String> val : testList) {
-System.out.println(val);
+// System.out.println(val);
 Integer partition = hbp.getPartition(val._1());
-System.out.println(partition);
+// System.out.println(partition);
 if (val._2().equals("red"))
 colorCountsByPartition[partition][0] += 1;
@@ -102,9 +102,9 @@ public class HashingBalancedPartitionerTest extends BaseSparkTest {
 colorCountsByPartition[partition][1] += 1;
 }
-for (int i = 0; i < 3; i++) {
-System.out.println(Arrays.toString(colorCountsByPartition[i]));
-}
+// for (int i = 0; i < 3; i++) {
+// System.out.println(Arrays.toString(colorCountsByPartition[i]));
+// }
 for (int i = 0; i < 3; i++) {
 // avg red per partition : 2.33
 assertTrue(colorCountsByPartition[i][0] >= 1 && colorCountsByPartition[i][0] < 4);
@@ -178,12 +178,12 @@ public class HashingBalancedPartitionerTest extends BaseSparkTest {
 colorCountsByPartition[partition][1] += 1;
 }
-for (int i = 0; i < numPartitions; i++) {
-System.out.println(Arrays.toString(colorCountsByPartition[i]));
-}
-System.out.println("Ideal red # per partition: " + avgRed);
-System.out.println("Ideal blue # per partition: " + avgBlue);
+// for (int i = 0; i < numPartitions; i++) {
+// System.out.println(Arrays.toString(colorCountsByPartition[i]));
+// }
+//
+// System.out.println("Ideal red # per partition: " + avgRed);
+// System.out.println("Ideal blue # per partition: " + avgBlue);
 for (int i = 0; i < numPartitions; i++) {
 // avg red per partition : 2.33

View File

@@ -115,7 +115,7 @@ public class TestSparkComputationGraph extends BaseSparkTest {
 TrainingMaster tm = new ParameterAveragingTrainingMaster(true, numExecutors(), 1, 10, 1, 0);
 SparkComputationGraph scg = new SparkComputationGraph(sc, cg, tm);
-scg.setListeners(Collections.singleton((TrainingListener) new ScoreIterationListener(1)));
+scg.setListeners(Collections.singleton((TrainingListener) new ScoreIterationListener(5)));
 JavaRDD<MultiDataSet> rdd = sc.parallelize(list);
 scg.fitMultiDataSet(rdd);

View File

@@ -31,8 +31,11 @@ import org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster;
 import org.junit.Test;
 import org.nd4j.evaluation.classification.Evaluation;
 import org.nd4j.linalg.activations.Activation;
+import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.dataset.DataSet;
 import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.learning.config.Adam;
 import org.nd4j.linalg.learning.config.Nesterovs;
 import org.nd4j.linalg.lossfunctions.LossFunctions;
@@ -45,8 +48,24 @@ import static org.junit.Assert.assertTrue;
 @Slf4j
 public class TestSparkDl4jMultiLayer extends BaseSparkTest {
-@Test(timeout = 120000L)
+@Override
+public long getTimeoutMilliseconds() {
+return 120000L;
+}
+@Override
+public DataType getDataType() {
+return DataType.FLOAT;
+}
+@Override
+public DataType getDefaultFPDataType() {
+return DataType.FLOAT;
+}
+@Test
 public void testEvaluationSimple() throws Exception {
+Nd4j.getRandom().setSeed(12345);
 for( int evalWorkers : new int[]{1, 4, 8}) {
 //Simple test to validate DL4J issue 4099 is fixed...
@@ -75,18 +94,18 @@ public class TestSparkDl4jMultiLayer extends BaseSparkTest {
 //----------------------------------
 //Create network configuration and conduct network training
 MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
+.dataType(DataType.FLOAT)
 .seed(12345)
 .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
 .activation(Activation.LEAKYRELU)
 .weightInit(WeightInit.XAVIER)
-.updater(new Nesterovs(0.02, 0.9))
-.l2(1e-4)
+.updater(new Adam(1e-3))
+.l2(1e-5)
 .list()
 .layer(0, new DenseLayer.Builder().nIn(28 * 28).nOut(500).build())
 .layer(1, new DenseLayer.Builder().nIn(500).nOut(100).build())
 .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
 .activation(Activation.SOFTMAX).nIn(100).nOut(10).build())
 .build();
 //Configuration for Spark training: see https://deeplearning4j.org/docs/latest/deeplearning4j-scaleout-howto for explanation of these configuration options

View File

@@ -333,15 +333,16 @@ public class TestCompareParameterAveragingSparkVsSingleMachine {
 sparkNet.fit(rdd);
 }
-System.out.println(sparkNet.getSparkTrainingStats().statsAsString());
+// System.out.println(sparkNet.getSparkTrainingStats().statsAsString());
+sparkNet.getSparkTrainingStats().statsAsString();
 INDArray finalSparkParams = sparkNet.getNetwork().params().dup();
-System.out.println("Initial (Local) params: " + Arrays.toString(initialParams.data().asFloat()));
-System.out.println("Initial (Spark) params: "
-+ Arrays.toString(initialSparkParams.data().asFloat()));
-System.out.println("Final (Local) params: " + Arrays.toString(finalParams.data().asFloat()));
-System.out.println("Final (Spark) params: " + Arrays.toString(finalSparkParams.data().asFloat()));
+// System.out.println("Initial (Local) params: " + Arrays.toString(initialParams.data().asFloat()));
+// System.out.println("Initial (Spark) params: "
+// + Arrays.toString(initialSparkParams.data().asFloat()));
+// System.out.println("Final (Local) params: " + Arrays.toString(finalParams.data().asFloat()));
+// System.out.println("Final (Spark) params: " + Arrays.toString(finalSparkParams.data().asFloat()));
 assertEquals(initialParams, initialSparkParams);
 assertNotEquals(initialParams, finalParams);
 assertEquals(finalParams, finalSparkParams);
@@ -405,15 +406,16 @@ public class TestCompareParameterAveragingSparkVsSingleMachine {
 sparkNet.fit(rdd);
 }
-System.out.println(sparkNet.getSparkTrainingStats().statsAsString());
+// System.out.println(sparkNet.getSparkTrainingStats().statsAsString());
+sparkNet.getSparkTrainingStats().statsAsString();
 INDArray finalSparkParams = sparkNet.getNetwork().params().dup();
-System.out.println("Initial (Local) params: " + Arrays.toString(initialParams.data().asFloat()));
-System.out.println("Initial (Spark) params: "
-+ Arrays.toString(initialSparkParams.data().asFloat()));
-System.out.println("Final (Local) params: " + Arrays.toString(finalParams.data().asFloat()));
-System.out.println("Final (Spark) params: " + Arrays.toString(finalSparkParams.data().asFloat()));
+// System.out.println("Initial (Local) params: " + Arrays.toString(initialParams.data().asFloat()));
+// System.out.println("Initial (Spark) params: "
+// + Arrays.toString(initialSparkParams.data().asFloat()));
+// System.out.println("Final (Local) params: " + Arrays.toString(finalParams.data().asFloat()));
+// System.out.println("Final (Spark) params: " + Arrays.toString(finalSparkParams.data().asFloat()));
 assertArrayEquals(initialParams.data().asFloat(), initialSparkParams.data().asFloat(), 1e-8f);
 assertArrayEquals(finalParams.data().asFloat(), finalSparkParams.data().asFloat(), 1e-6f);
@@ -478,18 +480,19 @@ public class TestCompareParameterAveragingSparkVsSingleMachine {
 sparkNet.fit(rdd);
 }
-System.out.println(sparkNet.getSparkTrainingStats().statsAsString());
+// System.out.println(sparkNet.getSparkTrainingStats().statsAsString());
+sparkNet.getSparkTrainingStats().statsAsString();
 INDArray finalSparkParams = sparkNet.getNetwork().params().dup();
 // executioner.addToWatchdog(finalSparkParams, "finalSparkParams");
 float[] fp = finalParams.data().asFloat();
 float[] fps = finalSparkParams.data().asFloat();
-System.out.println("Initial (Local) params: " + Arrays.toString(initialParams.data().asFloat()));
-System.out.println("Initial (Spark) params: "
-+ Arrays.toString(initialSparkParams.data().asFloat()));
-System.out.println("Final (Local) params: " + Arrays.toString(fp));
-System.out.println("Final (Spark) params: " + Arrays.toString(fps));
+// System.out.println("Initial (Local) params: " + Arrays.toString(initialParams.data().asFloat()));
+// System.out.println("Initial (Spark) params: "
+// + Arrays.toString(initialSparkParams.data().asFloat()));
+// System.out.println("Final (Local) params: " + Arrays.toString(fp));
+// System.out.println("Final (Spark) params: " + Arrays.toString(fps));
 assertEquals(initialParams, initialSparkParams);
 assertNotEquals(initialParams, finalParams);
@@ -551,14 +554,15 @@ public class TestCompareParameterAveragingSparkVsSingleMachine {
 sparkNet.fit(rdd);
 }
-System.out.println(sparkNet.getSparkTrainingStats().statsAsString());
+// System.out.println(sparkNet.getSparkTrainingStats().statsAsString());
+sparkNet.getSparkTrainingStats().statsAsString();
 INDArray finalSparkParams = sparkNet.getNetwork().params().dup();
-System.out.println("Initial (Local) params: " + Arrays.toString(initialParams.data().asFloat()));
-System.out.println("Initial (Spark) params: " + Arrays.toString(initialSparkParams.data().asFloat()));
-System.out.println("Final (Local) params: " + Arrays.toString(finalParams.data().asFloat()));
-System.out.println("Final (Spark) params: " + Arrays.toString(finalSparkParams.data().asFloat()));
+// System.out.println("Initial (Local) params: " + Arrays.toString(initialParams.data().asFloat()));
+// System.out.println("Initial (Spark) params: " + Arrays.toString(initialSparkParams.data().asFloat()));
+// System.out.println("Final (Local) params: " + Arrays.toString(finalParams.data().asFloat()));
+// System.out.println("Final (Spark) params: " + Arrays.toString(finalSparkParams.data().asFloat()));
 assertArrayEquals(initialParams.data().asFloat(), initialSparkParams.data().asFloat(), 1e-8f);
 assertArrayEquals(finalParams.data().asFloat(), finalSparkParams.data().asFloat(), 1e-6f);

View File

@ -37,7 +37,7 @@ public class TestJsonYaml {
String json = tm.toJson(); String json = tm.toJson();
String yaml = tm.toYaml(); String yaml = tm.toYaml();
System.out.println(json); // System.out.println(json);
TrainingMaster fromJson = ParameterAveragingTrainingMaster.fromJson(json); TrainingMaster fromJson = ParameterAveragingTrainingMaster.fromJson(json);
TrainingMaster fromYaml = ParameterAveragingTrainingMaster.fromYaml(yaml); TrainingMaster fromYaml = ParameterAveragingTrainingMaster.fromYaml(yaml);

View File

@ -389,7 +389,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
List<EventStats> workerFitStats = stats.getValue("ParameterAveragingWorkerFitTimesMs"); List<EventStats> workerFitStats = stats.getValue("ParameterAveragingWorkerFitTimesMs");
for (EventStats e : workerFitStats) { for (EventStats e : workerFitStats) {
ExampleCountEventStats eces = (ExampleCountEventStats) e; ExampleCountEventStats eces = (ExampleCountEventStats) e;
System.out.println(eces.getTotalExampleCount()); // System.out.println(eces.getTotalExampleCount());
} }
for (EventStats e : workerFitStats) { for (EventStats e : workerFitStats) {
@ -457,7 +457,8 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
assertNotEquals(paramsBefore, paramsAfter); assertNotEquals(paramsBefore, paramsAfter);
SparkTrainingStats stats = sparkNet.getSparkTrainingStats(); SparkTrainingStats stats = sparkNet.getSparkTrainingStats();
System.out.println(stats.statsAsString()); // System.out.println(stats.statsAsString());
stats.statsAsString();
sparkNet.getTrainingMaster().deleteTempFiles(sc); sparkNet.getTrainingMaster().deleteTempFiles(sc);
} }
@ -483,7 +484,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
i++; i++;
} }
System.out.println("Saved to: " + tempDirF.getAbsolutePath()); // System.out.println("Saved to: " + tempDirF.getAbsolutePath());
@ -527,7 +528,8 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
SparkTrainingStats stats = sparkNet.getSparkTrainingStats(); SparkTrainingStats stats = sparkNet.getSparkTrainingStats();
//Expect //Expect
System.out.println(stats.statsAsString()); // System.out.println(stats.statsAsString());
stats.statsAsString();
assertEquals(numSplits, stats.getValue("ParameterAveragingMasterRepartitionTimesMs").size()); assertEquals(numSplits, stats.getValue("ParameterAveragingMasterRepartitionTimesMs").size());
List<EventStats> list = stats.getValue("ParameterAveragingWorkerFitTimesMs"); List<EventStats> list = stats.getValue("ParameterAveragingWorkerFitTimesMs");
@ -566,8 +568,8 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
i++; i++;
} }
System.out.println("Saved to: " + tempDirF.getAbsolutePath()); // System.out.println("Saved to: " + tempDirF.getAbsolutePath());
System.out.println("Saved to: " + tempDirF2.getAbsolutePath()); // System.out.println("Saved to: " + tempDirF2.getAbsolutePath());
@ -610,7 +612,8 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
assertNotEquals(paramsBefore, paramsAfter); assertNotEquals(paramsBefore, paramsAfter);
SparkTrainingStats stats = sparkNet.getSparkTrainingStats(); SparkTrainingStats stats = sparkNet.getSparkTrainingStats();
System.out.println(stats.statsAsString()); // System.out.println(stats.statsAsString());
stats.statsAsString();
//Same thing, but for MultiDataSet objects: //Same thing, but for MultiDataSet objects:
config = new Configuration(); config = new Configuration();
@ -631,7 +634,8 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
assertNotEquals(paramsBefore, paramsAfter); assertNotEquals(paramsBefore, paramsAfter);
stats = sparkNet.getSparkTrainingStats(); stats = sparkNet.getSparkTrainingStats();
System.out.println(stats.statsAsString()); // System.out.println(stats.statsAsString());
stats.statsAsString();
} }
@ -730,13 +734,13 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
.build(); .build();
for (int avgFreq : new int[] {1, 5, 10}) { for (int avgFreq : new int[] {1, 5, 10}) {
System.out.println("--- Avg freq " + avgFreq + " ---"); // System.out.println("--- Avg freq " + avgFreq + " ---");
SparkDl4jMultiLayer sparkNet = new SparkDl4jMultiLayer(sc, conf.clone(), SparkDl4jMultiLayer sparkNet = new SparkDl4jMultiLayer(sc, conf.clone(),
new ParameterAveragingTrainingMaster.Builder(numExecutors(), dataSetObjSize) new ParameterAveragingTrainingMaster.Builder(numExecutors(), dataSetObjSize)
.batchSizePerWorker(batchSizePerExecutor).averagingFrequency(avgFreq) .batchSizePerWorker(batchSizePerExecutor).averagingFrequency(avgFreq)
.repartionData(Repartition.Always).build()); .repartionData(Repartition.Always).build());
sparkNet.setListeners(new ScoreIterationListener(1)); sparkNet.setListeners(new ScoreIterationListener(5));
@ -778,13 +782,13 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
.setOutputs("1").build(); .setOutputs("1").build();
for (int avgFreq : new int[] {1, 5, 10}) { for (int avgFreq : new int[] {1, 5, 10}) {
System.out.println("--- Avg freq " + avgFreq + " ---"); // System.out.println("--- Avg freq " + avgFreq + " ---");
SparkComputationGraph sparkNet = new SparkComputationGraph(sc, conf.clone(), SparkComputationGraph sparkNet = new SparkComputationGraph(sc, conf.clone(),
new ParameterAveragingTrainingMaster.Builder(numExecutors(), dataSetObjSize) new ParameterAveragingTrainingMaster.Builder(numExecutors(), dataSetObjSize)
.batchSizePerWorker(batchSizePerExecutor).averagingFrequency(avgFreq) .batchSizePerWorker(batchSizePerExecutor).averagingFrequency(avgFreq)
.repartionData(Repartition.Always).build()); .repartionData(Repartition.Always).build());
sparkNet.setListeners(new ScoreIterationListener(1)); sparkNet.setListeners(new ScoreIterationListener(5));
JavaRDD<DataSet> rdd = sc.parallelize(list); JavaRDD<DataSet> rdd = sc.parallelize(list);

View File

@ -107,7 +107,7 @@ public class TestTrainingStatsCollection extends BaseSparkTest {
expectedStatNames.addAll(c); expectedStatNames.addAll(c);
} }
System.out.println(expectedStatNames); // System.out.println(expectedStatNames);
SparkTrainingStats stats = sparkNet.getSparkTrainingStats(); SparkTrainingStats stats = sparkNet.getSparkTrainingStats();
@ -119,7 +119,7 @@ public class TestTrainingStatsCollection extends BaseSparkTest {
} }
String statsAsString = stats.statsAsString(); String statsAsString = stats.statsAsString();
System.out.println(statsAsString); // System.out.println(statsAsString);
assertEquals(actualKeySet.size(), statsAsString.split("\n").length); //One line per stat assertEquals(actualKeySet.size(), statsAsString.split("\n").length); //One line per stat

View File

@ -35,7 +35,7 @@ public class TestTimeSource {
long systemTime = System.currentTimeMillis(); long systemTime = System.currentTimeMillis();
long ntpTime = timeSource.currentTimeMillis(); long ntpTime = timeSource.currentTimeMillis();
long offset = ntpTime - systemTime; long offset = ntpTime - systemTime;
System.out.println("System: " + systemTime + "\tNTPTimeSource: " + ntpTime + "\tOffset: " + offset); // System.out.println("System: " + systemTime + "\tNTPTimeSource: " + ntpTime + "\tOffset: " + offset);
Thread.sleep(500); Thread.sleep(500);
} }
} }
@ -49,7 +49,7 @@ public class TestTimeSource {
long systemTime = System.currentTimeMillis(); long systemTime = System.currentTimeMillis();
long ntpTime = timeSource.currentTimeMillis(); long ntpTime = timeSource.currentTimeMillis();
long offset = ntpTime - systemTime; long offset = ntpTime - systemTime;
System.out.println("System: " + systemTime + "\tSystemClockTimeSource: " + ntpTime + "\tOffset: " + offset); // System.out.println("System: " + systemTime + "\tSystemClockTimeSource: " + ntpTime + "\tOffset: " + offset);
assertEquals(systemTime, ntpTime, 2); //Should be exact, but we might randomly tick over between one ms and the next assertEquals(systemTime, ntpTime, 2); //Should be exact, but we might randomly tick over between one ms and the next
Thread.sleep(500); Thread.sleep(500);
} }

View File

@ -87,7 +87,7 @@ public class TestListeners extends BaseSparkTest {
net.fit(rdd); net.fit(rdd);
List<String> sessions = ss.listSessionIDs(); List<String> sessions = ss.listSessionIDs();
System.out.println("Sessions: " + sessions); // System.out.println("Sessions: " + sessions);
assertEquals(1, sessions.size()); assertEquals(1, sessions.size());
String sid = sessions.get(0); String sid = sessions.get(0);
@ -95,15 +95,15 @@ public class TestListeners extends BaseSparkTest {
List<String> typeIDs = ss.listTypeIDsForSession(sid); List<String> typeIDs = ss.listTypeIDsForSession(sid);
List<String> workers = ss.listWorkerIDsForSession(sid); List<String> workers = ss.listWorkerIDsForSession(sid);
System.out.println(sid + "\t" + typeIDs + "\t" + workers); // System.out.println(sid + "\t" + typeIDs + "\t" + workers);
List<Persistable> lastUpdates = ss.getLatestUpdateAllWorkers(sid, StatsListener.TYPE_ID); List<Persistable> lastUpdates = ss.getLatestUpdateAllWorkers(sid, StatsListener.TYPE_ID);
System.out.println(lastUpdates); // System.out.println(lastUpdates);
System.out.println("Static info:"); // System.out.println("Static info:");
for (String wid : workers) { for (String wid : workers) {
Persistable staticInfo = ss.getStaticInfo(sid, StatsListener.TYPE_ID, wid); Persistable staticInfo = ss.getStaticInfo(sid, StatsListener.TYPE_ID, wid);
System.out.println(sid + "\t" + wid); // System.out.println(sid + "\t" + wid);
} }
assertEquals(1, typeIDs.size()); assertEquals(1, typeIDs.size());

View File

@ -63,7 +63,7 @@ public class TestRepartitioning extends BaseSparkTest {
assertEquals(10, rdd2.partitions().size()); assertEquals(10, rdd2.partitions().size());
for (int i = 0; i < 10; i++) { for (int i = 0; i < 10; i++) {
List<String> partition = rdd2.collectPartitions(new int[] {i})[0]; List<String> partition = rdd2.collectPartitions(new int[] {i})[0];
System.out.println("Partition " + i + " size: " + partition.size()); // System.out.println("Partition " + i + " size: " + partition.size());
assertEquals(100, partition.size()); //Should be exactly 100, for the util method (but NOT spark .repartition) assertEquals(100, partition.size()); //Should be exactly 100, for the util method (but NOT spark .repartition)
} }
} }
@ -170,7 +170,7 @@ public class TestRepartitioning extends BaseSparkTest {
List<Tuple2<Integer, Integer>> partitionCounts = initial.values().mapPartitionsWithIndex(new CountPartitionsFunction<Integer>(), true).collect(); List<Tuple2<Integer, Integer>> partitionCounts = initial.values().mapPartitionsWithIndex(new CountPartitionsFunction<Integer>(), true).collect();
System.out.println(partitionCounts); // System.out.println(partitionCounts);
List<Tuple2<Integer,Integer>> initialExpected = Arrays.asList( List<Tuple2<Integer,Integer>> initialExpected = Arrays.asList(
new Tuple2<>(0,29), new Tuple2<>(0,29),
@ -185,7 +185,7 @@ public class TestRepartitioning extends BaseSparkTest {
JavaRDD<Integer> afterRepartition = SparkUtils.repartitionBalanceIfRequired(initial.values(), Repartition.Always, 2, 112); JavaRDD<Integer> afterRepartition = SparkUtils.repartitionBalanceIfRequired(initial.values(), Repartition.Always, 2, 112);
List<Tuple2<Integer, Integer>> partitionCountsAfter = afterRepartition.mapPartitionsWithIndex(new CountPartitionsFunction<Integer>(), true).collect(); List<Tuple2<Integer, Integer>> partitionCountsAfter = afterRepartition.mapPartitionsWithIndex(new CountPartitionsFunction<Integer>(), true).collect();
System.out.println(partitionCountsAfter); // System.out.println(partitionCountsAfter);
for(Tuple2<Integer,Integer> t2 : partitionCountsAfter){ for(Tuple2<Integer,Integer> t2 : partitionCountsAfter){
assertEquals(2, (int)t2._2()); assertEquals(2, (int)t2._2());
@ -219,8 +219,8 @@ public class TestRepartitioning extends BaseSparkTest {
} }
} }
System.out.println("min: " + min + "\t@\t" + minIdx); // System.out.println("min: " + min + "\t@\t" + minIdx);
System.out.println("max: " + max + "\t@\t" + maxIdx); // System.out.println("max: " + max + "\t@\t" + maxIdx);
assertEquals(1, min); assertEquals(1, min);
assertEquals(2, max); assertEquals(2, max);
@ -244,7 +244,7 @@ public class TestRepartitioning extends BaseSparkTest {
for (int i = 0; i < 10; i++) { for (int i = 0; i < 10; i++) {
List<String> partition = rdd2.collectPartitions(new int[] {i})[0]; List<String> partition = rdd2.collectPartitions(new int[] {i})[0];
System.out.println("Partition " + i + " size: " + partition.size()); // System.out.println("Partition " + i + " size: " + partition.size());
assertTrue(partition.size() >= 90 && partition.size() <= 110); assertTrue(partition.size() >= 90 && partition.size() <= 110);
} }
} }

View File

@ -5,7 +5,7 @@ project(mkldnn-download NONE)
include(ExternalProject) include(ExternalProject)
ExternalProject_Add(mkldnn ExternalProject_Add(mkldnn
GIT_REPOSITORY https://github.com/intel/mkl-dnn.git GIT_REPOSITORY https://github.com/intel/mkl-dnn.git
GIT_TAG v1.1.3 GIT_TAG v1.2
SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/mkldnn-src" SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/mkldnn-src"
BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/mkldnn-build" BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/mkldnn-build"
CONFIGURE_COMMAND "" CONFIGURE_COMMAND ""

View File

@ -999,14 +999,14 @@ namespace nd4j {
* set new order and shape in case of suitable array length (in-place operation) * set new order and shape in case of suitable array length (in-place operation)
* order - order to set * order - order to set
* shape - shape to set * shape - shape to set
* * copyToNewBuff - if true, the contents of the old buffer are copied into the new buffer whenever a new buffer has to be allocated during reshaping
* if a permute was applied beforehand or the strides are non-standard, then a new buffer is allocated for the array * if a permute was applied beforehand or the strides are non-standard, then a new buffer is allocated for the array
*/ */
bool reshapei(const char order, const std::initializer_list<Nd4jLong>& shape); bool reshapei(const char order, const std::initializer_list<Nd4jLong>& shape, const bool copyToNewBuff = true);
bool reshapei(const char order, const std::vector<Nd4jLong>& shape); bool reshapei(const char order, const std::vector<Nd4jLong>& shape, const bool copyToNewBuff = true);
bool reshapei(const std::initializer_list<Nd4jLong>& shape); bool reshapei(const std::initializer_list<Nd4jLong>& shape, const bool copyToNewBuff = true);
bool reshapei(const std::vector<Nd4jLong>& shape); bool reshapei(const std::vector<Nd4jLong>& shape, const bool copyToNewBuff = true);
/** /**
* creates new array with corresponding order and shape, new array will point on _buffer of this array * creates new array with corresponding order and shape, new array will point on _buffer of this array
@ -1015,8 +1015,8 @@ namespace nd4j {
* *
* if a permute has been applied beforehand or the strides are non-standard, then a new buffer is allocated for the new array * if a permute has been applied beforehand or the strides are non-standard, then a new buffer is allocated for the new array
*/ */
NDArray reshape(const char order, const std::vector<Nd4jLong>& shape) const &; NDArray reshape(const char order, const std::vector<Nd4jLong>& shape, const bool copyToNewBuff = true) const &;
NDArray reshape(const char order, const std::vector<Nd4jLong>& shape) &&; NDArray reshape(const char order, const std::vector<Nd4jLong>& shape, const bool copyToNewBuff = true) &&;
/** /**
* calculate strides and set given order * calculate strides and set given order

View File

@ -501,7 +501,7 @@ NDArray::NDArray(const std::vector<Nd4jLong>& shape, const std::vector<const cha
auto data = reinterpret_cast<int8_t*>(bufferAsT<int8_t>() + headerLength); auto data = reinterpret_cast<int8_t*>(bufferAsT<int8_t>() + headerLength);
auto func = PRAGMA_THREADS_FOR{ auto func = PRAGMA_THREADS_FOR{
for (auto e = start; e < stop; e += increment) { for (auto e = start; e < stop; e++) {
auto cdata = data + offsets[e]; auto cdata = data + offsets[e];
if (dataType == DataType::UTF16) { if (dataType == DataType::UTF16) {
unicode::utf8to16(string[e], cdata, std::char_traits<char>::length(string[e])); unicode::utf8to16(string[e], cdata, std::char_traits<char>::length(string[e]));
@ -568,7 +568,7 @@ NDArray::NDArray(const std::vector<Nd4jLong>& shape, const std::vector<std::stri
auto data = reinterpret_cast<int8_t*>(bufferAsT<int8_t>() + headerLength); auto data = reinterpret_cast<int8_t*>(bufferAsT<int8_t>() + headerLength);
auto func = PRAGMA_THREADS_FOR{ auto func = PRAGMA_THREADS_FOR{
for (auto e = start; e < stop; e += increment) { for (auto e = start; e < stop; e++) {
auto cdata = data + offsets[e]; auto cdata = data + offsets[e];
if (dataType == DataType::UTF16) { if (dataType == DataType::UTF16) {
unicode::utf8to16(string[e].data(), cdata, string[e].size()); unicode::utf8to16(string[e].data(), cdata, string[e].size());
@ -635,7 +635,7 @@ NDArray::NDArray(const std::vector<Nd4jLong>& shape, const std::vector<std::u16s
auto data = reinterpret_cast<int8_t*>(bufferAsT<int8_t>() + headerLength); auto data = reinterpret_cast<int8_t*>(bufferAsT<int8_t>() + headerLength);
auto func = PRAGMA_THREADS_FOR{ auto func = PRAGMA_THREADS_FOR{
for (auto e = start; e < stop; e += increment) { for (auto e = start; e < stop; e++) {
auto cdata = data + offsets[e]; auto cdata = data + offsets[e];
if (dtype == DataType::UTF16) { if (dtype == DataType::UTF16) {
memcpy(cdata, string[e].data(), string[e].size() * sizeof(uint16_t)); memcpy(cdata, string[e].data(), string[e].size() * sizeof(uint16_t));
@ -701,7 +701,7 @@ NDArray::NDArray(const std::vector<Nd4jLong>& shape, const std::vector<const cha
auto func = PRAGMA_THREADS_FOR{ auto func = PRAGMA_THREADS_FOR{
for (auto e = start; e < stop; e += increment) { for (auto e = start; e < stop; e++) {
auto cdata = data + offsets[e]; auto cdata = data + offsets[e];
if (dtype == DataType::UTF16) { if (dtype == DataType::UTF16) {
memcpy(cdata, string[e], std::char_traits<char16_t>::length(string[e]) * sizeof(uint16_t)); memcpy(cdata, string[e], std::char_traits<char16_t>::length(string[e]) * sizeof(uint16_t));
@ -767,7 +767,7 @@ NDArray::NDArray(const std::vector<Nd4jLong>& shape, const std::vector<std::u32s
auto data = reinterpret_cast<int8_t*>(bufferAsT<int8_t>() + headerLength); auto data = reinterpret_cast<int8_t*>(bufferAsT<int8_t>() + headerLength);
auto func = PRAGMA_THREADS_FOR{ auto func = PRAGMA_THREADS_FOR{
for (auto e = start; e < stop; e += increment) { for (auto e = start; e < stop; e++) {
auto cdata = data + offsets[e]; auto cdata = data + offsets[e];
if (dtype == DataType::UTF16) { if (dtype == DataType::UTF16) {
unicode::utf32to16(string[e].data(), cdata, string[e].size()); unicode::utf32to16(string[e].data(), cdata, string[e].size());
@ -833,7 +833,7 @@ NDArray::NDArray(const std::vector<Nd4jLong>& shape, const std::vector<const cha
auto data = reinterpret_cast<int8_t*>(bufferAsT<int8_t>() + headerLength); auto data = reinterpret_cast<int8_t*>(bufferAsT<int8_t>() + headerLength);
auto func = PRAGMA_THREADS_FOR{ auto func = PRAGMA_THREADS_FOR{
for (auto e = start; e < stop; e += increment) { for (auto e = start; e < stop; e++) {
auto cdata = data + offsets[e]; auto cdata = data + offsets[e];
if (dtype == DataType::UTF16) { if (dtype == DataType::UTF16) {
unicode::utf32to16(string[e], cdata, std::char_traits<char32_t>::length(string[e])); unicode::utf32to16(string[e], cdata, std::char_traits<char32_t>::length(string[e]));
@ -1197,8 +1197,8 @@ void NDArray::assign(const NDArray& other, bool allowParallelism) {
throw std::runtime_error("NDArray::assign: lengths of arrays are mismatched"); throw std::runtime_error("NDArray::assign: lengths of arrays are mismatched");
} }
// memcpy is allowed only for same order && same ews (being equal to 1) // memcpy is allowed only for matching 'c' order && same ews (being equal to 1)
if (ordering() == other.ordering() && dataType() == other.dataType() && ews() == 1 && other.ews() == 1) if (ordering() == other.ordering() && ordering() == 'c' && dataType() == other.dataType() && ews() == 1 && other.ews() == 1)
copyBuffersContinuouslyFrom(other, other.lengthOf() * other.sizeOfT()); copyBuffersContinuouslyFrom(other, other.lengthOf() * other.sizeOfT());
else { else {
NDArray::prepareSpecialUse({this}, {&other}); NDArray::prepareSpecialUse({this}, {&other});
@ -1569,20 +1569,25 @@ Nd4jLong NDArray::tensorsAlongDimension(const std::vector<int>& dimensions) cons
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
void NDArray::printShapeInfo(const char * msg) const { void NDArray::printShapeInfo(const char * msg) const {
//shape::printShapeInfo(_shapeInfo);
if (msg == nullptr)
shape::printShapeInfoLinear(_shapeInfo);
else {
int rank = shape::rank(_shapeInfo); int rank = shape::rank(_shapeInfo);
int lim = shape::shapeInfoLength(rank); int lim = shape::shapeInfoLength(rank);
printf("%s: [", msg);
for (int i = 0; i < shape::shapeInfoLength(rank); i++) { if(msg != nullptr)
printf("%lld", (long long) _shapeInfo[i]); printf("shapeInfo %s: [", msg);
if (i < lim - 1) else
printf(", "); printf("shapeInfo: [");
}
printf("]\n"); printf("%i, ", rank);
for (int i = 1; i < shape::shapeInfoLength(rank) - 3; i++){
if(i == rank + 1)
printf(" ");
printf("%lld,", _shapeInfo[i]);
} }
printf(" %lld,", shape::type(_shapeInfo));
printf("%lld,", shape::elementWiseStride(_shapeInfo));
printf("%lld]\n", (Nd4jLong)shape::order(_shapeInfo));
fflush(stdout); fflush(stdout);
} }
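For reference, a hedged sketch of the output produced by the reworked printShapeInfo: it now dumps the raw shapeInfo buffer rather than only the shape array (the field names below are added purely for illustration).

    // shapeInfo msg: [rank,  shape_0 .. shape_{rank-1},  stride_0 .. stride_{rank-1},  dataType, ews, order]
    // 'order' is printed via its numeric character code, e.g. 99 for 'c'.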
@ -1855,19 +1860,19 @@ void NDArray::updateStrides(const char order) {
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
// set new order and shape in case of suitable array length // set new order and shape in case of suitable array length
bool NDArray::reshapei(const char order, const std::initializer_list<Nd4jLong>& shape) { bool NDArray::reshapei(const char order, const std::initializer_list<Nd4jLong>& shape, const bool copyToNewBuff) {
std::vector<Nd4jLong> vShape(shape); std::vector<Nd4jLong> vShape(shape);
return reshapei(order, vShape); return reshapei(order, vShape, copyToNewBuff);
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
bool NDArray::reshapei(const std::initializer_list<Nd4jLong>& shape) { bool NDArray::reshapei(const std::initializer_list<Nd4jLong>& shape, const bool copyToNewBuff) {
return reshapei('c', shape); return reshapei(ordering(), shape, copyToNewBuff);
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
bool NDArray::reshapei(const std::vector<Nd4jLong>& shape) { bool NDArray::reshapei(const std::vector<Nd4jLong>& shape, const bool copyToNewBuff) {
return reshapei('c', shape); return reshapei(ordering(), shape, copyToNewBuff);
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
@ -1918,18 +1923,18 @@ Nd4jLong NDArray::argMax(std::initializer_list<int> dimensions) {
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
// create new array with corresponding order and shape, new array will point to the same _buffer as this array // create new array with corresponding order and shape, new array will point to the same _buffer as this array
NDArray NDArray::reshape(const char order, const std::vector<Nd4jLong>& shape) const & { NDArray NDArray::reshape(const char order, const std::vector<Nd4jLong>& shape, const bool copyToNewBuff) const & {
NDArray newArr(getDataBuffer(), ShapeDescriptor(getShapeInfo()), getContext(), getBufferOffset()); NDArray newArr(getDataBuffer(), ShapeDescriptor(getShapeInfo()), getContext(), getBufferOffset());
newArr.reshapei(order, shape); newArr.reshapei(order, shape, copyToNewBuff);
return newArr; return newArr;
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
NDArray NDArray::reshape(const char order, const std::vector<Nd4jLong>& shape) && { NDArray NDArray::reshape(const char order, const std::vector<Nd4jLong>& shape, const bool copyToNewBuff) && {
this->reshapei(order, shape); this->reshapei(order, shape, copyToNewBuff);
return std::move(*this); return std::move(*this);
} }
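A minimal usage sketch for the new copyToNewBuff flag introduced above; the include path, array shape, and wrapper function are assumptions for illustration only.

    #include <NDArray.h>   // assumed include path

    void reshapeSketch(nd4j::NDArray& x) {   // assume x is a float32 array of shape {2, 3}
        // In-place reshape; if a new buffer must be allocated (e.g. after a permute),
        // the old contents are copied because copyToNewBuff defaults to true.
        x.reshapei('c', {3, 2});

        // Out-of-place reshape that skips copying into a freshly allocated buffer,
        // useful when the result is about to be overwritten anyway.
        nd4j::NDArray y = x.reshape('c', {6}, /*copyToNewBuff=*/false);
    }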
@ -1971,7 +1976,7 @@ bool NDArray::permutei(const std::initializer_list<int>& dimensions) {
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
bool NDArray::permutei(const std::vector<int>& dimensions) { bool NDArray::permutei(const std::vector<int>& dimensions) {
return permutei(dimensions.data(), dimensions.size()); return permutei(dimensions.data(), rankOf());
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
@ -1993,7 +1998,7 @@ bool NDArray::permutei(const std::vector<Nd4jLong>& dimensions) {
for (int e = 0; e < dimensions.size(); e++) for (int e = 0; e < dimensions.size(); e++)
ivec[e] = dimensions[e]; ivec[e] = dimensions[e];
return permutei(ivec.data(), ivec.size()); return permutei(ivec.data(), rankOf());
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
@ -2029,9 +2034,8 @@ NDArray NDArray::permute(const Nd4jLong* dimensions, const int rank) && {
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
NDArray NDArray::permute(const std::vector<int>& dimensions) const &{ NDArray NDArray::permute(const std::vector<int>& dimensions) const &{
auto data = dimensions.data();
auto size = dimensions.size(); return permute(dimensions.data(), rankOf());
return permute(data, size);
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
@ -2043,7 +2047,8 @@ NDArray NDArray::permute(const std::vector<int>& dimensions) && {
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
NDArray NDArray::permute(const std::vector<Nd4jLong>& dimensions) const & { NDArray NDArray::permute(const std::vector<Nd4jLong>& dimensions) const & {
return permute(dimensions.data(), dimensions.size());
return permute(dimensions.data(), rankOf());
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
@ -2106,12 +2111,12 @@ void NDArray::permute(const Nd4jLong *dimensions, const int rank, NDArray& targe
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
void NDArray::permute(const std::vector<int>& dimensions, NDArray& target) const { void NDArray::permute(const std::vector<int>& dimensions, NDArray& target) const {
permute(dimensions.data(), dimensions.size(), target); permute(dimensions.data(), rankOf(), target);
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
void NDArray::permute(const std::vector<Nd4jLong>& dimensions, NDArray& target) const { void NDArray::permute(const std::vector<Nd4jLong>& dimensions, NDArray& target) const {
permute(dimensions.data(), dimensions.size(), target); permute(dimensions.data(), rankOf(), target);
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
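A hedged usage sketch for the permute overloads changed above: the dimensions container is now read for rankOf() entries, so it is expected to name every axis exactly once. The constructor call and shapes are illustrative assumptions.

    nd4j::NDArray a('c', {2, 3, 4}, nd4j::DataType::FLOAT32);   // assumed rank-3 input
    nd4j::NDArray b = a.permute(std::vector<int>({2, 0, 1}));   // out-of-place, shape {4, 2, 3}
    a.permutei({1, 0, 2});                                      // in-place, shape becomes {3, 2, 4}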
@ -2362,7 +2367,7 @@ NDArray NDArray::asS() const {
const auto inData = bufferAsT<int8_t>() + offsetsLength; const auto inData = bufferAsT<int8_t>() + offsetsLength;
auto func = PRAGMA_THREADS_FOR{ auto func = PRAGMA_THREADS_FOR{
for (int e = start; e < stop; e += increment) { for (int e = start; e < stop; e++) {
auto cdata = outData + offsets[e]; auto cdata = outData + offsets[e];
auto end = nInputoffsets[e + 1]; auto end = nInputoffsets[e + 1];
auto idata = inData + nInputoffsets[e]; auto idata = inData + nInputoffsets[e];
@ -3221,7 +3226,7 @@ BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT std::vector, NDArray::asVectorT(), LI
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
// set new order and shape in case of suitable array length // set new order and shape in case of suitable array length
bool NDArray::reshapei(const char order, const std::vector<Nd4jLong>& cshape) { bool NDArray::reshapei(const char order, const std::vector<Nd4jLong>& cshape, const bool copyToNewBuff) {
// check firstly whether cshape is identical to shape of array, if yes then reshape is unnecessary // check firstly whether cshape is identical to shape of array, if yes then reshape is unnecessary
if(order == ordering() && shape::shapeEquals(rankOf(), shapeOf(), cshape.size(), cshape.data())) if(order == ordering() && shape::shapeEquals(rankOf(), shapeOf(), cshape.size(), cshape.data()))
@ -3293,18 +3298,14 @@ bool NDArray::reshapei(const char order, const std::vector<Nd4jLong>& cshape) {
Nd4jLong *shapeInfoNew; Nd4jLong *shapeInfoNew;
ALLOCATE(shapeInfoNew, getContext()->getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong); ALLOCATE(shapeInfoNew, getContext()->getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong);
bool canReshape = shape::reshapeC(rankOf(), shapeInfo(), shape.size(), shape.data(), shapeInfoNew); bool canReshape = shape::reshapeC(shapeInfo(), order, shape.size(), shape.data(), shapeInfoNew);
// we can do this only if there was no permute applied, or there are no weird strides
if (canReshape) { if (canReshape) {
if(ordering() == 'c' && order == 'f')
throw std::invalid_argument("NDArray::reshapei(order, shape): in case of reshapeC it doesn't make sense to reshape from c order to f order !");
shape::setEws(shapeInfoNew, arrLength);
setShapeInfo(shapeInfoNew); setShapeInfo(shapeInfoNew);
} }
else { else {
NDArray temp(order, shape, dataType(), getContext()); NDArray temp(order, shape, dataType(), getContext());
if(copyToNewBuff)
this->applyTransform(transform::Assign, temp, nullptr); this->applyTransform(transform::Assign, temp, nullptr);
*this = std::move(temp); *this = std::move(temp);
} }
@ -3465,7 +3466,7 @@ NDArray NDArray::dup(const char newOrder) const {
std::vector<std::string> strings(lengthOf()); std::vector<std::string> strings(lengthOf());
auto func = PRAGMA_THREADS_FOR{ auto func = PRAGMA_THREADS_FOR{
for (auto i = start; i < stop; i += increment) { for (auto i = start; i < stop; i++) {
strings[i] = std::move(this->e<std::string>(i)); strings[i] = std::move(this->e<std::string>(i));
} }
}; };
@ -3478,7 +3479,7 @@ NDArray NDArray::dup(const char newOrder) const {
std::vector<std::u16string> strings(lengthOf()); std::vector<std::u16string> strings(lengthOf());
auto func = PRAGMA_THREADS_FOR{ auto func = PRAGMA_THREADS_FOR{
for (auto i = start; i < stop; i += increment) { for (auto i = start; i < stop; i++) {
strings[i] = std::move(this->e<std::u16string>(i)); strings[i] = std::move(this->e<std::u16string>(i));
} }
}; };
@ -3490,7 +3491,7 @@ NDArray NDArray::dup(const char newOrder) const {
std::vector<std::u32string> strings(lengthOf()); std::vector<std::u32string> strings(lengthOf());
auto func = PRAGMA_THREADS_FOR{ auto func = PRAGMA_THREADS_FOR{
for (auto i = start; i < stop; i += increment) { for (auto i = start; i < stop; i++) {
strings[i] = std::move(this->e<std::u32string>(i)); strings[i] = std::move(this->e<std::u32string>(i));
} }
}; };
@ -4846,7 +4847,7 @@ NDArray NDArray::operator()(const std::vector<Nd4jLong>& idx, const bool keepUni
auto shapeOf = shape::shapeOf(newShapeInfo); auto shapeOf = shape::shapeOf(newShapeInfo);
auto stridesOf = shape::stride(newShapeInfo); auto stridesOf = shape::stride(newShapeInfo);
Nd4jLong offset(0), subArrLen(1); Nd4jLong offset = 0;
int n(isStrided ? 3 : 2), first, last, stride; int n(isStrided ? 3 : 2), first, last, stride;
for (int d = rank - 1; d >= 0; --d) { for (int d = rank - 1; d >= 0; --d) {
@ -4863,29 +4864,31 @@ NDArray NDArray::operator()(const std::vector<Nd4jLong>& idx, const bool keepUni
if(shapeOf[d] != 1) if(shapeOf[d] != 1)
stridesOf[d] *= stride; stridesOf[d] *= stride;
} }
}
subArrLen *= shapeOf[d]; Nd4jLong *newShapeInfo2 = newShapeInfo;
if(!keepUnitiesInShape) {
std::vector<int> dimsWithUnities;
for (uint d = 0; d < rank; ++d)
if(idx[n*d] != idx[n*d+1] && shapeOf[d] == 1)
dimsWithUnities.push_back(d);
if(!dimsWithUnities.empty())
newShapeInfo2 = ShapeBuilders::copyShapeInfoWithoutUnites(newShapeInfo, dimsWithUnities.size(), dimsWithUnities.data(), getContext()->getWorkspace());
} }
// check if there is possibility to set ews = 1 // check if there is possibility to set ews = 1
shape::setEws(newShapeInfo, subArrLen); shape::checkStridesEwsAndOrder(newShapeInfo2);
NDArray result(_buffer, ShapeDescriptor(newShapeInfo), getContext(), offset + getBufferOffset()); NDArray result(_buffer, ShapeDescriptor(newShapeInfo2), getContext(), offset + getBufferOffset());
result._isView = true; result._isView = true;
if(!keepUnitiesInShape) {
const int coeff = isStrided ? 3 : 2;
std::vector<Nd4jLong> nonUnitDims;
for (int d = 0; d < rank; ++d)
if(!(idx[coeff*d] != idx[coeff*d+1] && newShapeInfo[d+1] == 1))
nonUnitDims.push_back(newShapeInfo[d+1]);
if(nonUnitDims.size() != rank)
result.reshapei(nonUnitDims);
}
RELEASE(newShapeInfo, getContext()->getWorkspace()); RELEASE(newShapeInfo, getContext()->getWorkspace());
if(newShapeInfo != newShapeInfo2)
RELEASE(newShapeInfo2, getContext()->getWorkspace());
return result; return result;
} }

View File

@ -179,7 +179,7 @@ namespace graph {
nd4j_debug("Embedded graph execution finished. %i variable(s) migrated\n", cnt); nd4j_debug("Embedded graph execution finished. %i variable(s) migrated\n", cnt);
} else if (node->hasCustomOp()) { } else if (node->hasCustomOp()) {
// if we have something to execute - lets just execute it. // now, if we have something to execute - let's just execute it.
auto status = node->getCustomOp()->execute(&context); auto status = node->getCustomOp()->execute(&context);
if (status != ND4J_STATUS_OK) if (status != ND4J_STATUS_OK)
return status; return status;
@ -494,8 +494,10 @@ Nd4jStatus GraphExecutioner::execute(Graph *graph, VariableSpace* variableSpace)
nd4j::memory::MemoryRegistrator::getInstance()->setGraphMemoryFootprintIfGreater(h, m); nd4j::memory::MemoryRegistrator::getInstance()->setGraphMemoryFootprintIfGreater(h, m);
} }
if (tempFlow) if (tempFlow) {
delete flowPath; delete flowPath;
__variableSpace->setFlowPath(nullptr);
}
return Status::OK(); return Status::OK();
} }

View File

@ -98,7 +98,7 @@ void NDArray::fillAsTriangular(const float val, int lower, int upper, NDArray& t
auto func = PRAGMA_THREADS_FOR { auto func = PRAGMA_THREADS_FOR {
Nd4jLong coords[MAX_RANK]; Nd4jLong coords[MAX_RANK];
for (auto i = start; i < stop; i += increment) { for (auto i = start; i < stop; i++) {
shape::index2coords(i, target.getShapeInfo(), coords); shape::index2coords(i, target.getShapeInfo(), coords);
const auto zOffset = shape::getOffset(target.getShapeInfo(), coords); const auto zOffset = shape::getOffset(target.getShapeInfo(), coords);
@ -152,7 +152,7 @@ static void templatedSwap(void *xBuffer, void *yBuffer, Nd4jLong length) {
auto y = reinterpret_cast<T *>(yBuffer); auto y = reinterpret_cast<T *>(yBuffer);
auto func = PRAGMA_THREADS_FOR { auto func = PRAGMA_THREADS_FOR {
for (auto i = start; i < stop; i += increment) { for (auto i = start; i < stop; i++) {
auto temp = x[i]; auto temp = x[i];
x[i] = y[i]; x[i] = y[i];
y[i] = temp; y[i] = temp;
@ -266,7 +266,7 @@ NDArray NDArray::tile(const std::vector<Nd4jLong>& reps) const {
if(result.ordering() == 'c') { // ews == 1 always here if(result.ordering() == 'c') { // ews == 1 always here
auto func = PRAGMA_THREADS_FOR { auto func = PRAGMA_THREADS_FOR {
for (auto i = start; i < stop; i += increment) { for (auto i = start; i < stop; i++) {
auto yOffset = shape::subArrayOffset(i, newShapeInfo, getShapeInfo()); auto yOffset = shape::subArrayOffset(i, newShapeInfo, getShapeInfo());
BUILD_SINGLE_SELECTOR(xType, this->template templatedAssign,(result.getBuffer(), i, this->getBuffer(), yOffset), LIBND4J_TYPES); BUILD_SINGLE_SELECTOR(xType, this->template templatedAssign,(result.getBuffer(), i, this->getBuffer(), yOffset), LIBND4J_TYPES);
} }
@ -277,7 +277,7 @@ NDArray NDArray::tile(const std::vector<Nd4jLong>& reps) const {
else { else {
auto func = PRAGMA_THREADS_FOR { auto func = PRAGMA_THREADS_FOR {
for (auto i = start; i < stop; i += increment) { for (auto i = start; i < stop; i++) {
auto xOffset = result.getOffset(i); auto xOffset = result.getOffset(i);
auto yOffset = shape::subArrayOffset(i, newShapeInfo, getShapeInfo()); auto yOffset = shape::subArrayOffset(i, newShapeInfo, getShapeInfo());
BUILD_SINGLE_SELECTOR(xType, this->template templatedAssign,(result.getBuffer(), xOffset, this->getBuffer(), yOffset), LIBND4J_TYPES); BUILD_SINGLE_SELECTOR(xType, this->template templatedAssign,(result.getBuffer(), xOffset, this->getBuffer(), yOffset), LIBND4J_TYPES);
@ -377,7 +377,7 @@ static void repeat_(const NDArray& input, NDArray& output, const std::vector<int
// loop through input array // loop through input array
auto func = PRAGMA_THREADS_FOR { auto func = PRAGMA_THREADS_FOR {
Nd4jLong coords[MAX_RANK]; Nd4jLong coords[MAX_RANK];
for (auto i = start; i < stop; i += increment) { for (auto i = start; i < stop; i++) {
shape::index2coords(i, output.getShapeInfo(), coords); shape::index2coords(i, output.getShapeInfo(), coords);
const auto zOffset = shape::getOffset(output.getShapeInfo(), coords); const auto zOffset = shape::getOffset(output.getShapeInfo(), coords);

View File

@ -22,7 +22,7 @@ void NDArray::applyTriplewiseLambda(NDArray& second, NDArray& third, const std::
if (this->ordering() == second.ordering() && this->ordering() == third.ordering() && this->ordering() == target.ordering() && (this->ews() == 1 && target.ews() == 1) && this->ews() == second.ews() && this->ews() == third.ews()) { if (this->ordering() == second.ordering() && this->ordering() == third.ordering() && this->ordering() == target.ordering() && (this->ews() == 1 && target.ews() == 1) && this->ews() == second.ews() && this->ews() == third.ews()) {
auto loop = PRAGMA_THREADS_FOR { auto loop = PRAGMA_THREADS_FOR {
for (auto e = start; e < stop; e += increment) for (auto e = start; e < stop; e++)
z[e] = func(f[e], s[e], t[e]); z[e] = func(f[e], s[e], t[e]);
}; };
@ -31,7 +31,7 @@ void NDArray::applyTriplewiseLambda(NDArray& second, NDArray& third, const std::
if (f == z) { if (f == z) {
auto loop = PRAGMA_THREADS_FOR { auto loop = PRAGMA_THREADS_FOR {
for (auto e = start; e < stop; e += increment) { for (auto e = start; e < stop; e++) {
auto tOffset = this->getOffset(e); auto tOffset = this->getOffset(e);
auto uOffset = second.getOffset(e); auto uOffset = second.getOffset(e);
auto vOffset = third.getOffset(e); auto vOffset = third.getOffset(e);
@ -44,7 +44,7 @@ void NDArray::applyTriplewiseLambda(NDArray& second, NDArray& third, const std::
} else { } else {
auto loop = PRAGMA_THREADS_FOR { auto loop = PRAGMA_THREADS_FOR {
for (auto e = start; e < stop; e += increment) { for (auto e = start; e < stop; e++) {
auto tOffset = this->getOffset(e); auto tOffset = this->getOffset(e);
auto uOffset = second.getOffset(e); auto uOffset = second.getOffset(e);
auto vOffset = third.getOffset(e); auto vOffset = third.getOffset(e);
@ -93,7 +93,7 @@ void NDArray::applyPairwiseLambda(const NDArray& other, const std::function<T(T,
if (this->ordering() == other.ordering() && this->ordering() == target.ordering() && (this->ews() == 1 && target.ews() == 1) && this->ews() == other.ews()) { if (this->ordering() == other.ordering() && this->ordering() == target.ordering() && (this->ews() == 1 && target.ews() == 1) && this->ews() == other.ews()) {
auto loop = PRAGMA_THREADS_FOR { auto loop = PRAGMA_THREADS_FOR {
for (auto e = start; e < stop; e += increment) for (auto e = start; e < stop; e++)
z[e] = func(f[e], s[e]); z[e] = func(f[e], s[e]);
}; };
@ -102,7 +102,7 @@ void NDArray::applyPairwiseLambda(const NDArray& other, const std::function<T(T,
if (f == z) { if (f == z) {
auto loop = PRAGMA_THREADS_FOR { auto loop = PRAGMA_THREADS_FOR {
for (auto e = start; e < stop; e += increment) { for (auto e = start; e < stop; e++) {
auto xOffset = this->getOffset(e); auto xOffset = this->getOffset(e);
auto yOffset = other.getOffset(e); auto yOffset = other.getOffset(e);
@ -114,7 +114,7 @@ void NDArray::applyPairwiseLambda(const NDArray& other, const std::function<T(T,
} else { } else {
auto loop = PRAGMA_THREADS_FOR { auto loop = PRAGMA_THREADS_FOR {
for (auto e = start; e < stop; e += increment) { for (auto e = start; e < stop; e++) {
auto xOffset = this->getOffset(e); auto xOffset = this->getOffset(e);
auto yOffset = other.getOffset(e); auto yOffset = other.getOffset(e);
auto zOffset = target.getOffset(e); auto zOffset = target.getOffset(e);
@ -156,7 +156,7 @@ void NDArray::applyLambda(const std::function<T(T)>& func, NDArray& target) {
if (this->ordering() == target.ordering() && (this->ews() == 1 && target.ews() == 1)) { if (this->ordering() == target.ordering() && (this->ews() == 1 && target.ews() == 1)) {
auto loop = PRAGMA_THREADS_FOR { auto loop = PRAGMA_THREADS_FOR {
for (auto e = start; e < stop; e += increment) for (auto e = start; e < stop; e++)
z[e] = func(f[e]); z[e] = func(f[e]);
}; };
@ -165,7 +165,7 @@ void NDArray::applyLambda(const std::function<T(T)>& func, NDArray& target) {
if (f == z) { if (f == z) {
auto loop = PRAGMA_THREADS_FOR { auto loop = PRAGMA_THREADS_FOR {
for (auto e = start; e < stop; e += increment) { for (auto e = start; e < stop; e++) {
auto xOffset = this->getOffset(e); auto xOffset = this->getOffset(e);
f[xOffset] = func(f[xOffset]); f[xOffset] = func(f[xOffset]);
@ -176,7 +176,7 @@ void NDArray::applyLambda(const std::function<T(T)>& func, NDArray& target) {
} else { } else {
auto loop = PRAGMA_THREADS_FOR { auto loop = PRAGMA_THREADS_FOR {
for (auto e = start; e < stop; e += increment) { for (auto e = start; e < stop; e++) {
auto xOffset = this->getOffset(e); auto xOffset = this->getOffset(e);
auto zOffset = target.getOffset(e); auto zOffset = target.getOffset(e);
@ -217,7 +217,7 @@ void NDArray::applyIndexedLambda(const std::function<T(Nd4jLong, T)>& func, NDAr
if (this->ordering() == target.ordering() && (this->ews() == 1 && target.ews() == 1)) { if (this->ordering() == target.ordering() && (this->ews() == 1 && target.ews() == 1)) {
auto loop = PRAGMA_THREADS_FOR { auto loop = PRAGMA_THREADS_FOR {
for (auto e = start; e < stop; e += increment) for (auto e = start; e < stop; e++)
z[e] = func(e, f[e]); z[e] = func(e, f[e]);
}; };
@ -226,7 +226,7 @@ void NDArray::applyIndexedLambda(const std::function<T(Nd4jLong, T)>& func, NDAr
if (f == z) { if (f == z) {
auto loop = PRAGMA_THREADS_FOR { auto loop = PRAGMA_THREADS_FOR {
for (auto e = start; e < stop; e += increment) { for (auto e = start; e < stop; e++) {
auto xOffset = this->getOffset(e); auto xOffset = this->getOffset(e);
f[xOffset] = func(e, f[xOffset]); f[xOffset] = func(e, f[xOffset]);
@ -237,7 +237,7 @@ void NDArray::applyIndexedLambda(const std::function<T(Nd4jLong, T)>& func, NDAr
} else { } else {
auto loop = PRAGMA_THREADS_FOR { auto loop = PRAGMA_THREADS_FOR {
for (auto e = start; e < stop; e += increment) { for (auto e = start; e < stop; e++) {
auto xOffset = this->getOffset(e); auto xOffset = this->getOffset(e);
auto zOffset = target.getOffset(e); auto zOffset = target.getOffset(e);
@ -283,7 +283,7 @@ void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function<T(N
if (this->ordering() == other.ordering() && this->ordering() == target.ordering() && (this->ews() == 1 && target.ews() == 1) && this->ews() == other.ews()) { if (this->ordering() == other.ordering() && this->ordering() == target.ordering() && (this->ews() == 1 && target.ews() == 1) && this->ews() == other.ews()) {
auto loop = PRAGMA_THREADS_FOR { auto loop = PRAGMA_THREADS_FOR {
for (auto e = start; e < stop; e += increment) for (auto e = start; e < stop; e++)
z[e] = func((Nd4jLong) e, f[e], s[e]); z[e] = func((Nd4jLong) e, f[e], s[e]);
}; };
@ -292,7 +292,7 @@ void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function<T(N
if (f == z) { if (f == z) {
auto loop = PRAGMA_THREADS_FOR { auto loop = PRAGMA_THREADS_FOR {
for (auto e = start; e < stop; e += increment) { for (auto e = start; e < stop; e++) {
auto xOffset = this->getOffset(e); auto xOffset = this->getOffset(e);
auto yOffset = other.getOffset(e); auto yOffset = other.getOffset(e);
@ -304,7 +304,7 @@ void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function<T(N
} else { } else {
auto loop = PRAGMA_THREADS_FOR { auto loop = PRAGMA_THREADS_FOR {
for (auto e = start; e < stop; e += increment) { for (auto e = start; e < stop; e++) {
auto xOffset = this->getOffset(e); auto xOffset = this->getOffset(e);
auto yOffset = other.getOffset(e); auto yOffset = other.getOffset(e);
auto zOffset = target.getOffset(e); auto zOffset = target.getOffset(e);

View File

@ -163,15 +163,44 @@ void NativeOpExecutioner::execBroadcast(nd4j::LaunchContext *lc,
BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functions::broadcast::Broadcast, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, LIBND4J_TYPES); BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functions::broadcast::Broadcast, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, LIBND4J_TYPES);
#else #else
auto loopKind = nd4j::LoopKind::deduceKindOfLoopBroadcast(hXShapeInfo, hYShapeInfo, hZShapeInfo);
auto func = PRAGMA_THREADS_FOR { auto func = PRAGMA_THREADS_FOR {
BUILD_SINGLE_SELECTOR_THRICE(xType, functions::broadcast::Broadcast, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ, start, stop), LIBND4J_TYPES); BUILD_SINGLE_SELECTOR_THRICE(xType, functions::broadcast::Broadcast, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ, loopKind, start, stop), LIBND4J_TYPES);
}; };
Nd4jLong numTads = 0;
switch (loopKind) {
case nd4j::LoopKind::BROADCAST_SCALAR_X: {
numTads = shape::length(hXShapeInfo);
}
break;
case nd4j::LoopKind::BROADCAST_SCALAR_Y: {
numTads = shape::length(hYShapeInfo);
}
break;
case nd4j::LoopKind::BROADCAST_3D: {
numTads = shape::sizeAt(hZShapeInfo, 0);
}
break;
case nd4j::LoopKind::BROADCAST_4D: {
numTads = shape::sizeAt(hZShapeInfo, 0) * shape::sizeAt(hZShapeInfo, 1);
}
break;
case nd4j::LoopKind::BROADCAST_5D: {
numTads = shape::sizeAt(hZShapeInfo, 0) * shape::sizeAt(hZShapeInfo, 1);
}
break;
default: {
auto xLen = shape::length(hXShapeInfo); auto xLen = shape::length(hXShapeInfo);
auto yLen = shape::length(hYShapeInfo); auto yLen = shape::length(hYShapeInfo);
auto numTads = xLen / yLen; numTads = xLen / yLen;
}
}
samediff::Threads::parallel_tad(func, 0, numTads); samediff::Threads::parallel_tad(func, 0, numTads);
#endif #endif
} }
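As a worked illustration of the numTads selection above (all shapes and lengths are assumed purely for the arithmetic):

    // z shape {8, 16, 32}     with BROADCAST_3D -> numTads = 8              (one task per outermost slice)
    // z shape {8, 16, 32, 64} with BROADCAST_4D -> numTads = 8 * 16 = 128
    // x length 1024, y length 32, default case  -> numTads = 1024 / 32 = 32 (one task per TAD of x)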

View File

@ -1291,7 +1291,7 @@ void pullRowsGeneric(void *vx,
_threads = nd4j::math::nd4j_min<int>(_threads, nd4j::Environment::getInstance()->maxThreads()); _threads = nd4j::math::nd4j_min<int>(_threads, nd4j::Environment::getInstance()->maxThreads());
auto func = PRAGMA_THREADS_FOR { auto func = PRAGMA_THREADS_FOR {
for (auto idx = start; idx < stop; idx += increment) { for (auto idx = start; idx < stop; idx++) {
auto xTadOffsetForBlock = tadOffsets[indexes[idx]]; auto xTadOffsetForBlock = tadOffsets[indexes[idx]];
auto zTadOffsetForBlock = zTadOffsets[idx]; auto zTadOffsetForBlock = zTadOffsets[idx];
@ -1356,7 +1356,7 @@ void tearGeneric(void *vx,
auto numTads = shape::length(hXShapeInfo) / tadLength; auto numTads = shape::length(hXShapeInfo) / tadLength;
auto func = PRAGMA_THREADS_FOR { auto func = PRAGMA_THREADS_FOR {
for (auto i = start; i < stop; i += increment) { for (auto i = start; i < stop; i++) {
auto hZ = reinterpret_cast<T *>(targets[i]); auto hZ = reinterpret_cast<T *>(targets[i]);
auto s = hX + tadOffsets[i]; auto s = hX + tadOffsets[i];
@ -1478,7 +1478,7 @@ void shuffleGeneric(void **hX, Nd4jLong **hXShapeInfo, void **dz, Nd4jLong **hZS
auto dZ = reinterpret_cast<T **>(dz); auto dZ = reinterpret_cast<T **>(dz);
auto func = PRAGMA_THREADS_FOR { auto func = PRAGMA_THREADS_FOR {
for (auto f = start; f < stop; f += increment) { for (auto f = start; f < stop; f++) {
auto hX = reinterpret_cast<T *>(dX[f]); auto hX = reinterpret_cast<T *>(dX[f]);
//auto hZ = reinterpret_cast<T *>(dZ[f]); //auto hZ = reinterpret_cast<T *>(dZ[f]);

View File

@ -52,7 +52,7 @@ namespace nd4j {
TypeCast::convertGeneric<T2, T>(nullptr, tmp, length, buffer); TypeCast::convertGeneric<T2, T>(nullptr, tmp, length, buffer);
#else #else
auto func = PRAGMA_THREADS_FOR { auto func = PRAGMA_THREADS_FOR {
for (auto e = start; e < stop; e += increment) for (auto e = start; e < stop; e++)
buffer[e] = canKeep ? static_cast<T>(tmp[e]) : BitwiseUtils::swap_bytes<T>(static_cast<T>(tmp[e])); buffer[e] = canKeep ? static_cast<T>(tmp[e]) : BitwiseUtils::swap_bytes<T>(static_cast<T>(tmp[e]));
}; };
@ -110,7 +110,7 @@ namespace nd4j {
TypeCast::convertGeneric<float, T>(nullptr, tmp, length, buffer); TypeCast::convertGeneric<float, T>(nullptr, tmp, length, buffer);
#else #else
auto func = PRAGMA_THREADS_FOR { auto func = PRAGMA_THREADS_FOR {
for (auto e = start; e < stop; e += increment) for (auto e = start; e < stop; e++)
buffer[e] = canKeep ? static_cast<T>(tmp[e]) : BitwiseUtils::swap_bytes<T>(static_cast<T>(tmp[e])); buffer[e] = canKeep ? static_cast<T>(tmp[e]) : BitwiseUtils::swap_bytes<T>(static_cast<T>(tmp[e]));
}; };
@ -138,7 +138,7 @@ namespace nd4j {
#else #else
auto func = PRAGMA_THREADS_FOR { auto func = PRAGMA_THREADS_FOR {
for (auto e = start; e < stop; e += increment) for (auto e = start; e < stop; e++)
buffer[e] = canKeep ? static_cast<T>(tmp[e]) : BitwiseUtils::swap_bytes<T>(static_cast<T>(tmp[e])); buffer[e] = canKeep ? static_cast<T>(tmp[e]) : BitwiseUtils::swap_bytes<T>(static_cast<T>(tmp[e]));
}; };
@ -164,7 +164,7 @@ namespace nd4j {
TypeCast::convertGeneric<float16, T>(nullptr, tmp, length, buffer); TypeCast::convertGeneric<float16, T>(nullptr, tmp, length, buffer);
#else #else
auto func = PRAGMA_THREADS_FOR { auto func = PRAGMA_THREADS_FOR {
for (auto e = start; e < stop; e += increment) for (auto e = start; e < stop; e++)
buffer[e] = canKeep ? static_cast<T>(tmp[e]) : BitwiseUtils::swap_bytes<T>(static_cast<T>(tmp[e])); buffer[e] = canKeep ? static_cast<T>(tmp[e]) : BitwiseUtils::swap_bytes<T>(static_cast<T>(tmp[e]));
}; };

View File

@ -58,6 +58,7 @@ namespace nd4j {
virtual void putVariable(int id, Variable *variable); virtual void putVariable(int id, Variable *variable);
virtual void putVariable(int id, NDArray *array); virtual void putVariable(int id, NDArray *array);
virtual void putVariable(int id, int idx, NDArray *array); virtual void putVariable(int id, int idx, NDArray *array);
virtual void putVariable(int id, int idx, NDArray &array);
virtual void putVariable(int id, int idx, Variable *array); virtual void putVariable(int id, int idx, Variable *array);
virtual void replaceVariable(Variable *variable); virtual void replaceVariable(Variable *variable);

View File

@ -100,6 +100,7 @@ namespace nd4j {
virtual void putVariable(int id, Variable *variable); virtual void putVariable(int id, Variable *variable);
virtual void putVariable(int id, NDArray *array); virtual void putVariable(int id, NDArray *array);
virtual void putVariable(int id, int idx, NDArray *array); virtual void putVariable(int id, int idx, NDArray *array);
virtual void putVariable(int id, int idx, NDArray &array);
virtual void putVariable(int id, int idx, Variable *array); virtual void putVariable(int id, int idx, Variable *array);
virtual void dropVariable(std::pair<int,int> &pair); virtual void dropVariable(std::pair<int,int> &pair);

View File

@ -1088,8 +1088,23 @@ namespace nd4j {
if (e < node->input()->size() - 1) if (e < node->input()->size() - 1)
nd4j_printf(", ", ""); nd4j_printf(", ", "");
} }
if (node->opType() == OpType_CUSTOM) {
auto ctx = node->protoContext();
if (ctx->getIArguments()->size() > 0) {
printf("]; iArgs: [");
for (int e = 0; e < ctx->getIArguments()->size(); e++) {
printf("%i", ctx->getIArguments()->at(e));
if (e < ctx->getIArguments()->size() - 1)
nd4j_printf(", ", "");
}
}
}
nd4j_printf("]; \n", ""); nd4j_printf("]; \n", "");
// printf("\n"); // printf("\n");
fflush(stdout); fflush(stdout);
} }

View File

@ -60,8 +60,11 @@ namespace nd4j {
result->_name = this->_name; result->_name = this->_name;
result->_index = this->_index; result->_index = this->_index;
if (this->_ndarray != nullptr) if (this->_ndarray != nullptr) {
result->_ndarray = new NDArray(this->_ndarray->dup(this->_ndarray->ordering())); result->_ndarray = new NDArray(this->_ndarray->dup(this->_ndarray->ordering()));
result->_readOnly = false;
result->_removable = true;
}
if (this->_list != nullptr) if (this->_list != nullptr)
result->_list = this->_list->clone(); result->_list = this->_list->clone();

View File

@ -191,6 +191,9 @@ namespace nd4j {
_current->putVariable(id, array); _current->putVariable(id, array);
} }
void nd4j::graph::VariableProxy::putVariable(int id, int idx, NDArray &array) {
_current->putVariable(id, idx, array);
}
void VariableProxy::putVariable(int id, int idx, NDArray *array) { void VariableProxy::putVariable(int id, int idx, NDArray *array) {
_current->putVariable(id, idx, array); _current->putVariable(id, idx, array);

View File

@ -263,18 +263,18 @@ namespace nd4j {
void nd4j::graph::VariableSpace::putVariable(int id, Variable *variable) { void nd4j::graph::VariableSpace::putVariable(int id, Variable *variable) {
// we don't want to add variables more than once // we don't want to add variables more than once
if (_variables.count(id) > 0 || _temporary.count(id) > 0) { if (_variables.count(id) > 0 || _temporary.count(id) > 0) {
// nd4j_verbose("Trying to update variable for node_%i\n", id);
auto local = id < 0 ? _variables.at(id) : _temporary.at(id); auto local = id < 0 ? _variables.at(id) : _temporary.at(id);
if (!local->hasNDArray() && variable->hasNDArray()) { if (!local->hasNDArray() && variable->hasNDArray()) {
// nd4j_verbose("Saving variable for node_%i\n", id);
local->setNDArray(variable->getNDArray()); local->setNDArray(variable->getNDArray());
}
return; // we're inheriting this from Variable
local->markReadOnly(variable->isReadOnly());
local->markRemovable(variable->isRemovable());
} }
//nd4j_debug("Adding Variable to Space: id: %i; Array is null: %i;\n", id, variable->getNDArray() == nullptr); return;
}
_varmap.lock(); _varmap.lock();
@ -314,6 +314,21 @@ namespace nd4j {
} }
} }
void nd4j::graph::VariableSpace::putVariable(int id, int idx, NDArray &array) {
auto *var = new nd4j::graph::Variable(&array, "", id, idx);
var->markRemovable(false);
var->markReadOnly(true);
// check whether a variable already exists for this (id, idx) slot
bool d = this->hasVariable(id, idx);
this->putVariable(id, var);
// if a variable for this node id already exists, just delete the temporary wrapper
if (d)
delete var;
}
void nd4j::graph::VariableSpace::putVariable(int id, NDArray *array) { void nd4j::graph::VariableSpace::putVariable(int id, NDArray *array) {
auto *var = new nd4j::graph::Variable(array); auto *var = new nd4j::graph::Variable(array);
this->putVariable(id, var); this->putVariable(id, var);
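A hedged usage fragment for the new NDArray& overload (nodeId and someArray are placeholders): the array is wrapped in a Variable marked read-only and non-removable, and if a variable already exists for that (id, idx) slot the temporary wrapper is deleted again so existing ownership is preserved.

    // fragment, illustrative only
    varSpace->putVariable(nodeId, /*idx=*/0, someArray);
    // a second call for the same (nodeId, 0): the wrapper built internally is deleted,
    // the previously stored variable stays in place
    varSpace->putVariable(nodeId, /*idx=*/0, someArray);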

View File

@ -24,6 +24,7 @@
#include <pointercast.h> #include <pointercast.h>
#include <dll.h> #include <dll.h>
#include <string> #include <string>
#include <vector>
namespace nd4j { namespace nd4j {
namespace graph { namespace graph {
@ -65,6 +66,9 @@ namespace nd4j {
// total amount of memory used during execution // total amount of memory used during execution
Nd4jLong _memoryTotal = 0L; Nd4jLong _memoryTotal = 0L;
std::vector<std::string> _inputShapes;
std::vector<std::string> _outputShapes;
public: public:
NodeProfile() = default; NodeProfile() = default;
~NodeProfile() = default; ~NodeProfile() = default;
@ -84,10 +88,15 @@ namespace nd4j {
void setObjectsSize(Nd4jLong bytes); void setObjectsSize(Nd4jLong bytes);
void setTotalSize(Nd4jLong bytes); void setTotalSize(Nd4jLong bytes);
Nd4jLong getActivationsSize(); void addInputShape(Nd4jLong *shapeInfo);
Nd4jLong getTemporarySize(); void addOutputShape(Nd4jLong *shapeInfo);
Nd4jLong getObjectsSize();
Nd4jLong getTotalSize(); Nd4jLong getActivationsSize() const;
Nd4jLong getTemporarySize() const;
Nd4jLong getObjectsSize() const;
Nd4jLong getTotalSize() const;
Nd4jLong getExecutionTime() const;
std::string& name(); std::string& name();

View File

@ -21,6 +21,8 @@
#include <graph/profiling/GraphProfile.h> #include <graph/profiling/GraphProfile.h>
#include <helpers/logger.h> #include <helpers/logger.h>
#include <chrono> #include <chrono>
#include <templatemath.h>
#include <algorithm>
namespace nd4j { namespace nd4j {
namespace graph { namespace graph {
@ -184,8 +186,25 @@ namespace nd4j {
if (_profiles.empty()) if (_profiles.empty())
nd4j_printf("No nodes in graph\n",""); nd4j_printf("No nodes in graph\n","");
for (auto v: _profiles) // printing out stuff
std::vector<NodeProfile*> sorted;
for (auto v: _profiles) {
v->printOut(); v->printOut();
sorted.emplace_back(v);
}
if (_profiles.size() > 1) {
// building hot spots
std::sort(sorted.begin(), sorted.end(), [](const NodeProfile *a, const NodeProfile *b) -> bool {
return a->getExecutionTime() > b->getExecutionTime();
});
nd4j_printf("\nTop 30 reports by EXEC:\n", "");
auto limit = nd4j::math::nd4j_min<int>(30, sorted.size());
for (int e = 0; e < limit; e++) {
sorted[e]->printOut();
}
}
nd4j_printf("\nSpecial timers:\n", ""); nd4j_printf("\nSpecial timers:\n", "");
if (_timings.empty()) if (_timings.empty())

View File

@ -32,7 +32,7 @@ namespace nd4j {
// graph->printOut(); // graph->printOut();
// warm up // warm up
for (int e = 0; e < 1000; e++) { for (int e = 0; e < iterations; e++) {
FlowPath fp; FlowPath fp;
auto _vs = varSpace->clone(); auto _vs = varSpace->clone();

View File

@ -20,6 +20,7 @@
#include <helpers/logger.h> #include <helpers/logger.h>
#include <graph/profiling/NodeProfile.h> #include <graph/profiling/NodeProfile.h>
#include <helpers/ShapeUtils.h>
namespace nd4j { namespace nd4j {
namespace graph { namespace graph {
@ -35,9 +36,23 @@ namespace nd4j {
nd4j_printf(" Memory: ACT: %lld; TMP: %lld; OBJ: %lld; TTL: %lld;\n", _memoryActivations / _merges, _memoryTemporary / _merges, _memoryObjects / _merges, _memoryTotal / _merges); nd4j_printf(" Memory: ACT: %lld; TMP: %lld; OBJ: %lld; TTL: %lld;\n", _memoryActivations / _merges, _memoryTemporary / _merges, _memoryObjects / _merges, _memoryTotal / _merges);
nd4j_printf(" Time: PREP: %lld ns; EXEC: %lld ns; TTL: %lld ns;\n", _preparationTime / _merges, _executionTime / _merges, _totalTime / _merges); nd4j_printf(" Time: PREP: %lld ns; EXEC: %lld ns; TTL: %lld ns;\n", _preparationTime / _merges, _executionTime / _merges, _totalTime / _merges);
nd4j_printf(" PREP: INPUT: %lld ns; SHAPE: %lld ns; ARRAY: %lld ns;\n", _inputTime / _merges, _shapeTime / _merges, _arrayTime / _merges); nd4j_printf(" PREP: INPUT: %lld ns; SHAPE: %lld ns; ARRAY: %lld ns;\n", _inputTime / _merges, _shapeTime / _merges, _arrayTime / _merges);
std::string inputs;
std::string outputs;
int cnt = 0;
for (const auto &v: _inputShapes)
inputs += v + " ";
for (const auto &v: _outputShapes)
outputs += v + " ";
nd4j_printf(" Inputs: %s\n", inputs.c_str());
nd4j_printf(" Outputs: %s\n", outputs.c_str());
}; };
Nd4jLong NodeProfile::getActivationsSize() { Nd4jLong NodeProfile::getActivationsSize() const {
return _memoryActivations; return _memoryActivations;
} }
@ -53,15 +68,15 @@ namespace nd4j {
_inputTime = time; _inputTime = time;
} }
Nd4jLong NodeProfile::getTemporarySize() { Nd4jLong NodeProfile::getTemporarySize() const{
return _memoryTemporary; return _memoryTemporary;
} }
Nd4jLong NodeProfile::getObjectsSize() { Nd4jLong NodeProfile::getObjectsSize() const{
return _memoryObjects; return _memoryObjects;
} }
Nd4jLong NodeProfile::getTotalSize() { Nd4jLong NodeProfile::getTotalSize() const{
return _memoryTotal; return _memoryTotal;
} }
@ -97,6 +112,18 @@ namespace nd4j {
_memoryTotal = bytes; _memoryTotal = bytes;
} }
Nd4jLong NodeProfile::getExecutionTime() const {
return _executionTime;
}
void NodeProfile::addInputShape(Nd4jLong *shapeInfo) {
_inputShapes.emplace_back(ShapeUtils::shapeInfoAsString(shapeInfo));
}
void NodeProfile::addOutputShape(Nd4jLong *shapeInfo) {
_outputShapes.emplace_back(ShapeUtils::shapeInfoAsString(shapeInfo));
}
void NodeProfile::merge(NodeProfile *other) { void NodeProfile::merge(NodeProfile *other) {
_merges += other->_merges; _merges += other->_merges;
_memoryObjects += other->_memoryObjects; _memoryObjects += other->_memoryObjects;
@ -110,6 +137,9 @@ namespace nd4j {
_shapeTime += other->_shapeTime; _shapeTime += other->_shapeTime;
_arrayTime += other->_arrayTime; _arrayTime += other->_arrayTime;
_inputTime += other->_inputTime; _inputTime += other->_inputTime;
_inputShapes = other->_inputShapes;
_outputShapes = other->_outputShapes;
} }
std::string& NodeProfile::name() { std::string& NodeProfile::name() {
@ -129,6 +159,9 @@ namespace nd4j {
_shapeTime = other->_shapeTime; _shapeTime = other->_shapeTime;
_arrayTime = other->_arrayTime; _arrayTime = other->_arrayTime;
_inputTime = other->_inputTime; _inputTime = other->_inputTime;
_inputShapes = other->_inputShapes;
_outputShapes = other->_outputShapes;
} }
} }
} }
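A hedged fragment showing how the new shape-tracking fields might be fed by a profiling pass (profile, in and out are placeholders; only addInputShape, addOutputShape and printOut come from the code above):

    // fragment, illustrative only
    profile->addInputShape(in->getShapeInfo());    // stored as a string via ShapeUtils::shapeInfoAsString
    profile->addOutputShape(out->getShapeInfo());
    profile->printOut();                           // now also prints the Inputs: / Outputs: lines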

View File

@ -37,12 +37,13 @@ namespace nd4j {
class ND4J_EXPORT LoopKind { class ND4J_EXPORT LoopKind {
public: public:
enum Kind {SMALLARR2DX, EWS1, EWSNONZERO, RANK1, RANK2, RANK3, RANK4, RANK5, X_EWSNONZERO, Y_EWSNONZERO, Z_EWSNONZERO, COMMON}; enum Kind { SMALLARR2DX, EWS1, EWSNONZERO, RANK1, RANK2, RANK3, RANK4, RANK5, X_EWSNONZERO, Y_EWSNONZERO, Z_EWSNONZERO, COMMON, BROADCAST_SCALAR_X, BROADCAST_SCALAR_Y, BROADCAST_3D, BROADCAST_4D, BROADCAST_5D };
static FORCEINLINE Kind deduceKindOfLoopXZ(const Nd4jLong* xShapeInfo, const Nd4jLong* zShapeInfo); static FORCEINLINE Kind deduceKindOfLoopXZ(const Nd4jLong* xShapeInfo, const Nd4jLong* zShapeInfo);
static FORCEINLINE Kind deduceKindOfLoopXYZ(const Nd4jLong* xShapeInfo, const Nd4jLong* yShapeInfo, const Nd4jLong* zShapeInfo); static FORCEINLINE Kind deduceKindOfLoopXYZ(const Nd4jLong* xShapeInfo, const Nd4jLong* yShapeInfo, const Nd4jLong* zShapeInfo);
static FORCEINLINE Kind deduceKindOfLoopTadXZ(const Nd4jLong* xShapeInfo, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo); static FORCEINLINE Kind deduceKindOfLoopTadXZ(const Nd4jLong* xShapeInfo, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo);
static FORCEINLINE Kind deduceKindOfLoopTadXYZ(const Nd4jLong* xTadShapeInfo, const Nd4jLong* yTadShapeInfo, const Nd4jLong* zShapeInfo); static FORCEINLINE Kind deduceKindOfLoopTadXYZ(const Nd4jLong* xTadShapeInfo, const Nd4jLong* yTadShapeInfo, const Nd4jLong* zShapeInfo);
static FORCEINLINE Kind deduceKindOfLoopBroadcast(const Nd4jLong* xShapeInfo, const Nd4jLong* yShapeInfo, const Nd4jLong* zShapeInfo);
}; };
@ -82,6 +83,57 @@ LoopKind::Kind LoopKind::deduceKindOfLoopXZ(const Nd4jLong* xShapeInfo, const Nd
return COMMON; return COMMON;
} }
LoopKind::Kind LoopKind::deduceKindOfLoopBroadcast(const Nd4jLong* xShapeInfo, const Nd4jLong* yShapeInfo, const Nd4jLong* zShapeInfo) {
auto xRank = shape::rank(xShapeInfo);
auto yRank = shape::rank(yShapeInfo);
auto zRank = shape::rank(zShapeInfo);
auto xOrder = shape::order(xShapeInfo);
auto yOrder = shape::order(yShapeInfo);
auto zOrder = shape::order(zShapeInfo);
auto xEws = shape::elementWiseStride(xShapeInfo);
auto yEws = shape::elementWiseStride(yShapeInfo);
auto zEws = shape::elementWiseStride(zShapeInfo);
bool bNDLoopsRanks = (xRank == zRank && yRank <= xRank && yRank >= 2);
int countUnityDimsInY = 0, countUnityDimsInX = 0;
for (int i = 0; i < xRank; i++) {
if (i < yRank)
countUnityDimsInY += (1 == shape::sizeAt(yShapeInfo, i)) ? 1 : 0;
countUnityDimsInX += (1 == shape::sizeAt(xShapeInfo, i)) ? 1 : 0;
}
bool bNotCommonVectorCase = (countUnityDimsInY != yRank - 1) && (countUnityDimsInX != xRank - 1);
if (3 == xRank && bNDLoopsRanks && bNotCommonVectorCase)
return nd4j::LoopKind::BROADCAST_3D;
if (4 == xRank && bNDLoopsRanks && bNotCommonVectorCase)
return nd4j::LoopKind::BROADCAST_4D;
if (5 == xRank && bNDLoopsRanks && bNotCommonVectorCase)
return nd4j::LoopKind::BROADCAST_5D;
if (xRank == yRank && xRank == zRank && xOrder == 'c' && yOrder == 'c' && zOrder == 'c' && xEws == 1 && yEws == 1 && zEws == 1 && xRank >= 2) {
// we validate that shapes are equal till the last dim
for (int e = 0; e < xRank - 1; e++) {
if (xShapeInfo[e+1] != yShapeInfo[e+1])
return COMMON;
}
// now, if one of the shapes has 1 as last dim
auto detect = xShapeInfo[xRank] == 1 ? -1 : (yShapeInfo[xRank] == 1) ? 1 : 0;
if (detect == 1)
return nd4j::LoopKind::BROADCAST_SCALAR_Y;
else if (detect == -1)
return nd4j::LoopKind::BROADCAST_SCALAR_X;
}
return nd4j::LoopKind::COMMON;
}
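For orientation, a few concrete shape combinations and the kind the heuristic above selects, assuming c-order and element-wise stride 1 for all three arrays (derived from the conditions above, illustrative only):

    // x{4,5},     y{4,1},     z{4,5}       -> BROADCAST_SCALAR_Y (one y value per row of x)
    // x{4,1},     y{4,5},     z{4,5}       -> BROADCAST_SCALAR_X
    // x{2,3,4},   y{2,1,4},   z{2,3,4}     -> BROADCAST_3D
    // x{2,3,4,5}, y{2,1,4,1}, z{2,3,4,5}   -> BROADCAST_4D
    // x{2,3},     y{1,3},     z{2,3}       -> COMMON (shapes differ before the last dimension)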
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////
LoopKind::Kind LoopKind::deduceKindOfLoopXYZ(const Nd4jLong* xShapeInfo, const Nd4jLong* yShapeInfo, const Nd4jLong* zShapeInfo) { LoopKind::Kind LoopKind::deduceKindOfLoopXYZ(const Nd4jLong* xShapeInfo, const Nd4jLong* yShapeInfo, const Nd4jLong* zShapeInfo) {

View File

@ -51,6 +51,13 @@ namespace nd4j {
static Nd4jLong* copyShapeInfoAndType(const Nd4jLong* inShapeInfo, const DataType dtype, const bool copyStrides, memory::Workspace* workspace = nullptr); static Nd4jLong* copyShapeInfoAndType(const Nd4jLong* inShapeInfo, const DataType dtype, const bool copyStrides, memory::Workspace* workspace = nullptr);
static Nd4jLong* copyShapeInfoAndType(const Nd4jLong* inShapeInfo, const Nd4jLong* shapeInfoToGetTypeFrom, const bool copyStrides, memory::Workspace* workspace = nullptr); static Nd4jLong* copyShapeInfoAndType(const Nd4jLong* inShapeInfo, const Nd4jLong* shapeInfoToGetTypeFrom, const bool copyStrides, memory::Workspace* workspace = nullptr);
/**
* allocates memory for a new shapeInfo and copies all information from inShapeInfo to the new shapeInfo, except the dimensions listed in dimsToExclude (unit dimensions) and their corresponding strides
* for example if inShapeInfo is {5, 2,1,3,1,4, 12,12,4,4,1, 16384,1,99}, dimsToExclude = {1,3}, dimsSize = 2
* then outShapeInfo will contain {3, 2,3,4, 12,4,1, 16384,1,99}
*/
static Nd4jLong* copyShapeInfoWithoutUnites(const Nd4jLong* inShapeInfo, const int dimsSize, const int* dimsToExclude, memory::Workspace* workspace = nullptr);
static Nd4jLong* emptyShapeInfo(const nd4j::DataType dataType, memory::Workspace* workspace = nullptr); static Nd4jLong* emptyShapeInfo(const nd4j::DataType dataType, memory::Workspace* workspace = nullptr);
static Nd4jLong* emptyShapeInfo(const nd4j::DataType dataType, const char order, const std::vector<Nd4jLong> &shape, memory::Workspace* workspace = nullptr); static Nd4jLong* emptyShapeInfo(const nd4j::DataType dataType, const char order, const std::vector<Nd4jLong> &shape, memory::Workspace* workspace = nullptr);
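A hedged usage fragment for copyShapeInfoWithoutUnites, matching the example in the comment above (inShapeInfo is assumed to describe the {2,1,3,1,4} array; with a null workspace the buffer is heap-allocated and must be released by the caller):

    // fragment, illustrative only
    const int dimsToExclude[] = {1, 3};   // 0-based indices of the unit dimensions
    Nd4jLong* outShapeInfo = ShapeBuilders::copyShapeInfoWithoutUnites(inShapeInfo, 2, dimsToExclude, nullptr);
    // outShapeInfo now describes shape {2,3,4} with strides {12,4,1}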

View File

@ -50,11 +50,13 @@ namespace nd4j {
static std::vector<Nd4jLong> evalRepeatShape(int axis, const std::vector<int>& repeats, const NDArray& arr); static std::vector<Nd4jLong> evalRepeatShape(int axis, const std::vector<int>& repeats, const NDArray& arr);
// evaluate shapeInfo of permuted array // evaluate shapeInfo of permuted array
static Nd4jLong* evalPermShapeInfo(const int* dimensions, const int rank, const NDArray& arr, nd4j::memory::Workspace* workspace); // if setContigStrides = true, then set contiguous strides in output shapeInfo in accordance with arr order
static Nd4jLong* evalPermShapeInfo(const int* dimensions, const int rank, const NDArray& arr, nd4j::memory::Workspace* workspace, const bool setContigStrides = false);
static Nd4jLong* evalPermShapeInfo(const Nd4jLong* dimensions, const int rank, const NDArray& arr, nd4j::memory::Workspace* workspace); static Nd4jLong* evalPermShapeInfo(const Nd4jLong* dimensions, const int rank, const NDArray& arr, nd4j::memory::Workspace* workspace);
// evaluate shapeInfo of transposed array // evaluate shapeInfo of transposed array
static Nd4jLong* evalTranspShapeInfo(const NDArray& arr, nd4j::memory::Workspace* workspace); // if setContigStrides = true, then set contiguous strides in output shapeInfo in accordance with arr order
static Nd4jLong* evalTranspShapeInfo(const NDArray& arr, nd4j::memory::Workspace* workspace, const bool setContigStrides = false);
static bool copyVectorPart(std::vector<int>& target, std::vector<int>& source, int rank, int offset); static bool copyVectorPart(std::vector<int>& target, std::vector<int>& source, int rank, int offset);
@ -97,6 +99,8 @@ namespace nd4j {
static std::string shapeAsString(const int rank, const Nd4jLong* shapeInfo); static std::string shapeAsString(const int rank, const Nd4jLong* shapeInfo);
static std::string strideAsString(const NDArray* array); static std::string strideAsString(const NDArray* array);
static std::string shapeInfoAsString(const Nd4jLong* shapeInfo);
static std::vector<Nd4jLong> shapeAsVector(const Nd4jLong* shapeInfo); static std::vector<Nd4jLong> shapeAsVector(const Nd4jLong* shapeInfo);
// evaluate shapeInfo for diagonal array which is made using input arr elements as diagonal // evaluate shapeInfo for diagonal array which is made using input arr elements as diagonal
@ -176,6 +180,17 @@ namespace nd4j {
return (numStrings + 1) * sizeof(Nd4jLong); return (numStrings + 1) * sizeof(Nd4jLong);
} }
/**
* This method selects strides based on the dimensions required for broadcasting
* @param inShapeInfo const pointer to the input (Y) shape info used for stride selection
* @param nRank rank of the input (X) being broadcast to
* @param dimsSize number of broadcast dimensions
* @param dims const pointer to the dimensions used for broadcasting
* @param outStrides pointer to the output strides; must be pre-allocated and zero-initialized
* @return
*/
static void copyCertainStridesFromShapeInfo(const Nd4jLong* inShapeInfo, const int nRank, const int dimsSize, const int* dims, Nd4jLong* outStrides);
/* /*
* check whether arr1/arr2 is sub-array of arr2/arr1, * check whether arr1/arr2 is sub-array of arr2/arr1,
* this method does not evaluate which array is the sub-array, it returns true if arr1 is a sub-array of arr2 or arr2 is a sub-array of arr1 * this method does not evaluate which array is the sub-array, it returns true if arr1 is a sub-array of arr2 or arr2 is a sub-array of arr1

View File

@ -68,7 +68,7 @@ namespace nd4j {
const Nd4jLong numOfSubArrs = ShapeUtils::getNumOfSubArrs(shapeInfo, dimsToExclude); const Nd4jLong numOfSubArrs = ShapeUtils::getNumOfSubArrs(shapeInfo, dimsToExclude);
const int subArrRank = (rank == dimsToExclude.size() || descriptor.areUnitiesinShape()) ? rank : rank - dimsToExclude.size(); const int subArrRank = (rank == dimsToExclude.size() || descriptor.areUnitiesinShape()) ? rank : rank - dimsToExclude.size();
auto sPtr = new Nd4jLong[shape::shapeInfoLength(subArrRank)]; auto sPtr = new Nd4jLong[shape::shapeInfoLength(subArrRank)]; // shape of sub-arrays (same for all of them)
auto oPtr = new Nd4jLong[numOfSubArrs]; auto oPtr = new Nd4jLong[numOfSubArrs];
if (numOfSubArrs > 0) if (numOfSubArrs > 0)

View File

@ -49,7 +49,7 @@ void nd4j::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
case nd4j::LoopKind::EWS1: { case nd4j::LoopKind::EWS1: {
auto func = PRAGMA_THREADS_FOR { auto func = PRAGMA_THREADS_FOR {
for (auto i = start; i < stop; i += increment) { for (auto i = start; i < stop; i++) {
auto tad = const_cast<X *>(x) + tadOffsets[i]; auto tad = const_cast<X *>(x) + tadOffsets[i];
auto indexValue = OpType::startingIndexValue(tad); auto indexValue = OpType::startingIndexValue(tad);
@ -70,7 +70,7 @@ void nd4j::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
case nd4j::LoopKind::EWSNONZERO: { case nd4j::LoopKind::EWSNONZERO: {
auto func = PRAGMA_THREADS_FOR { auto func = PRAGMA_THREADS_FOR {
for (auto i = start; i < stop; i += increment) { for (auto i = start; i < stop; i++) {
auto tad = const_cast<X *>(x) + tadOffsets[i]; auto tad = const_cast<X *>(x) + tadOffsets[i];
auto indexValue = OpType::startingIndexValue(tad); auto indexValue = OpType::startingIndexValue(tad);
@ -91,7 +91,7 @@ void nd4j::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
case nd4j::LoopKind::RANK1: { case nd4j::LoopKind::RANK1: {
auto func = PRAGMA_THREADS_FOR { auto func = PRAGMA_THREADS_FOR {
for (auto i = start; i < stop; i += increment) { for (auto i = start; i < stop; i++) {
auto tad = const_cast<X *>(x) + tadOffsets[i]; auto tad = const_cast<X *>(x) + tadOffsets[i];
auto indexValue = OpType::startingIndexValue(tad); auto indexValue = OpType::startingIndexValue(tad);
@ -114,7 +114,7 @@ void nd4j::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
shape::updateStrides(2, tadShape, newStride, 'c'); shape::updateStrides(2, tadShape, newStride, 'c');
auto func = PRAGMA_THREADS_FOR { auto func = PRAGMA_THREADS_FOR {
for (auto i = start; i < stop; i += increment) { for (auto i = start; i < stop; i++) {
auto tad = const_cast<X *>(x) + tadOffsets[i]; auto tad = const_cast<X *>(x) + tadOffsets[i];
auto indexValue = OpType::startingIndexValue(tad); auto indexValue = OpType::startingIndexValue(tad);
@ -141,7 +141,7 @@ void nd4j::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
shape::updateStrides(3, tadShape, newStride, 'c'); shape::updateStrides(3, tadShape, newStride, 'c');
auto func = PRAGMA_THREADS_FOR { auto func = PRAGMA_THREADS_FOR {
for (auto i = start; i < stop; i += increment) { for (auto i = start; i < stop; i++) {
auto tad = const_cast<X *>(x) + tadOffsets[i]; auto tad = const_cast<X *>(x) + tadOffsets[i];
auto indexValue = OpType::startingIndexValue(tad); auto indexValue = OpType::startingIndexValue(tad);
@ -170,7 +170,7 @@ void nd4j::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
shape::updateStrides(4, tadShape, newStride, 'c'); shape::updateStrides(4, tadShape, newStride, 'c');
auto func = PRAGMA_THREADS_FOR { auto func = PRAGMA_THREADS_FOR {
for (auto i = start; i < stop; i += increment) { for (auto i = start; i < stop; i++) {
auto tad = const_cast<X *>(x) + tadOffsets[i]; auto tad = const_cast<X *>(x) + tadOffsets[i];
auto indexValue = OpType::startingIndexValue(tad); auto indexValue = OpType::startingIndexValue(tad);
@ -201,7 +201,7 @@ void nd4j::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
shape::updateStrides(5, tadShape, newStride, 'c'); shape::updateStrides(5, tadShape, newStride, 'c');
auto func = PRAGMA_THREADS_FOR { auto func = PRAGMA_THREADS_FOR {
for (auto i = start; i < stop; i += increment) { for (auto i = start; i < stop; i++) {
auto tad = const_cast<X *>(x) + tadOffsets[i]; auto tad = const_cast<X *>(x) + tadOffsets[i];
auto indexValue = OpType::startingIndexValue(tad); auto indexValue = OpType::startingIndexValue(tad);
@ -234,7 +234,7 @@ void nd4j::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo<uint>(zShapeInfo, castZShapeInfo); const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo<uint>(zShapeInfo, castZShapeInfo);
auto func = PRAGMA_THREADS_FOR { auto func = PRAGMA_THREADS_FOR {
for (auto i = start; i < stop; i += increment) { for (auto i = start; i < stop; i++) {
auto tad = const_cast<X *>(x) + tadOffsets[i]; auto tad = const_cast<X *>(x) + tadOffsets[i];
auto indexValue = OpType::startingIndexValue(tad); auto indexValue = OpType::startingIndexValue(tad);
@ -258,7 +258,7 @@ void nd4j::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
const bool canCastTad = nd4j::DataTypeUtils::castShapeInfo<uint>(tadShapeInfo, castTadShapeInfo); const bool canCastTad = nd4j::DataTypeUtils::castShapeInfo<uint>(tadShapeInfo, castTadShapeInfo);
auto func = PRAGMA_THREADS_FOR { auto func = PRAGMA_THREADS_FOR {
for (auto i = start; i < stop; i += increment) { for (auto i = start; i < stop; i++) {
auto tad = const_cast<X *>(x) + tadOffsets[i]; auto tad = const_cast<X *>(x) + tadOffsets[i];
auto indexValue = OpType::startingIndexValue(tad); auto indexValue = OpType::startingIndexValue(tad);
@ -284,7 +284,7 @@ void nd4j::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo<uint>(zShapeInfo, castZShapeInfo); const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo<uint>(zShapeInfo, castZShapeInfo);
auto func = PRAGMA_THREADS_FOR { auto func = PRAGMA_THREADS_FOR {
for (auto i = start; i < stop; i += increment) { for (auto i = start; i < stop; i++) {
auto tad = const_cast<X *>(x) + tadOffsets[i]; auto tad = const_cast<X *>(x) + tadOffsets[i];
auto indexValue = OpType::startingIndexValue(tad); auto indexValue = OpType::startingIndexValue(tad);

View File

@ -43,23 +43,30 @@ nd4j::NDArray* nd4j::MmulHelper::tensorDot(const nd4j::NDArray* a, const nd4j::N
auto outShape = ShapeUtils::evalShapeForTensorDot(a, b, axes_0, axes_1, permutAt, permutBt, shapeAt, shapeBt); auto outShape = ShapeUtils::evalShapeForTensorDot(a, b, axes_0, axes_1, permutAt, permutBt, shapeAt, shapeBt);
NDArray aPR = a->permute(permutAt); // check whether permutation is necessary
NDArray bPR = b->permute(permutBt); const NDArray* aP = permutAt.empty() ? a : new NDArray(a->permute(permutAt));
const NDArray* bP = permutBt.empty() ? b : new NDArray(b->permute(permutBt));
// check whether reshape is necessary // check whether reshape is necessary
if(!aPR.isSameShape(shapeAt)) const NDArray* aPR = aP->isSameShape(shapeAt) ? aP : new NDArray(aP->reshape(aP->ordering(), shapeAt));
aPR.reshapei( shapeAt); const NDArray* bPR = bP->isSameShape(shapeBt) ? bP : new NDArray(bP->reshape(bP->ordering(), shapeBt));
if(!bPR.isSameShape(shapeBt))
bPR.reshapei( shapeBt);
NDArray* c = mmul(&aPR, &bPR, nullptr, 1.0, 0.0); NDArray* c = mmul(aPR, bPR, nullptr, 1.0, 0.0);
c->reshapei(outShape); c->reshapei(outShape);
if(aP != aPR)
delete aPR;
if(bP != bPR)
delete bPR;
if(a != aP)
delete aP;
if(b != bP)
delete bP;
return c; return c;
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
void nd4j::MmulHelper::tensorDot(const nd4j::NDArray* a, const nd4j::NDArray* b, nd4j::NDArray* c, const std::vector<int>& axes_a, const std::vector<int>& axes_b, const std::vector<int>& permutForC) { void nd4j::MmulHelper::tensorDot(const nd4j::NDArray* a, const nd4j::NDArray* b, nd4j::NDArray* c, const std::vector<int>& axes_a, const std::vector<int>& axes_b, const std::vector<int>& permutForC) {
@ -67,32 +74,38 @@ void nd4j::MmulHelper::tensorDot(const nd4j::NDArray* a, const nd4j::NDArray* b,
std::vector<Nd4jLong> shapeAt, shapeBt; std::vector<Nd4jLong> shapeAt, shapeBt;
ShapeUtils::evalShapeForTensorDot(a, b, axes_a, axes_b, permutAt, permutBt, shapeAt, shapeBt); ShapeUtils::evalShapeForTensorDot(a, b, axes_a, axes_b, permutAt, permutBt, shapeAt, shapeBt);
NDArray *cP(c), *cPR(c);
// check whether permutation is required // check whether permutation is required
if(!permutForC.empty()) NDArray* cP = permutForC.empty() ? c : new NDArray(c->permute(permutForC));
cP = new NDArray(c->permute(permutForC));
auto aPR = a->permute(permutAt); // check whether permutation is necessary
auto bPR = b->permute(permutBt); const NDArray* aP = permutAt.empty() ? a : new NDArray(a->permute(permutAt));
const NDArray* bP = permutBt.empty() ? b : new NDArray(b->permute(permutBt));
// check whether reshape is necessary // check whether reshape is necessary
if(!aPR.isSameShape(shapeAt)) const NDArray* aPR = aP->isSameShape(shapeAt) ? aP : new NDArray(aP->reshape(aP->ordering(), shapeAt));
aPR.reshapei(shapeAt); const NDArray* bPR = bP->isSameShape(shapeBt) ? bP : new NDArray(bP->reshape(bP->ordering(), shapeBt));
if(!bPR.isSameShape(shapeBt))
bPR.reshapei(shapeBt);
if(!cP->isSameShape({aPR.sizeAt(0), bPR.sizeAt(1)})) std::vector<Nd4jLong> requiredCshape = {aPR->sizeAt(0), bPR->sizeAt(1)};
cPR = new NDArray(cP->reshape(cP->ordering(), {aPR.sizeAt(0), bPR.sizeAt(1)}));
mmul(&aPR, &bPR, cPR, 1.0, 0.0); NDArray* cPR = cP->isSameShape(requiredCshape) ? cP : new NDArray(cP->reshape(cP->ordering(), requiredCshape, false));
mmul(aPR, bPR, cPR, 1.0, 0.0);
if(cPR->getBuffer() != cP->getBuffer() || cPR->getSpecialBuffer() != cP->getSpecialBuffer() ) // this means both permute and reshape have been performed on c, cP always points on c->getBuffer() if(cPR->getBuffer() != cP->getBuffer() || cPR->getSpecialBuffer() != cP->getSpecialBuffer() ) // this means both permute and reshape have been performed on c, cP always points on c->getBuffer()
cP->assign(cPR); cP->assign(cPR);
if(cPR != c) if(aP != aPR)
delete aPR;
if(bP != bPR)
delete bPR;
if(a != aP)
delete aP;
if(b != bP)
delete bP;
if(cP != cPR)
delete cPR; delete cPR;
if(cP != c) if(c != cP)
delete cP; delete cP;
} }
@ -129,7 +142,7 @@ void nd4j::MmulHelper::tensorDot(const NDArray* a, const NDArray* b, NDArray* c,
if(!whatToDoWithC.empty()) { if(!whatToDoWithC.empty()) {
cArrs = std::vector<NDArray*>(whatToDoWithC.size()+1, c); cArrs = std::vector<NDArray*>(whatToDoWithC.size()+1, c);
for(int i = 0; i < cArrs.size()-1; ++i) for(int i = 0; i < cArrs.size()-1; ++i)
cArrs[i+1] = (whatToDoWithC[i] == 'p') ? new NDArray(cArrs[i]->permute(modifC[i])) : new NDArray(cArrs[i]->reshape(c->ordering(), modifC[i])); // since we ignore first element in cArrs (that is cArrs[0]) then it is always equal to c cArrs[i+1] = (whatToDoWithC[i] == 'p') ? new NDArray(cArrs[i]->permute(modifC[i])) : new NDArray(cArrs[i]->reshape(c->ordering(), modifC[i], false)); // since we ignore first element in cArrs (that is cArrs[0]) then it is always equal to c
} }
mmul(aPR, bPR, cArrs[cArrs.size()-1], 1.0, 0.0); mmul(aPR, bPR, cArrs[cArrs.size()-1], 1.0, 0.0);
@ -208,7 +221,7 @@ nd4j::NDArray* MmulHelper::mmul(const nd4j::NDArray* A, const nd4j::NDArray* B,
// vector x matrix, A{M} x B{M,N} = C{N} -> reduce to matrix x matrix A2{1,M} x B{M,N} = C2{1,N}, since there is no corresponding blas operation sgevm // vector x matrix, A{M} x B{M,N} = C{N} -> reduce to matrix x matrix A2{1,M} x B{M,N} = C2{1,N}, since there is no corresponding blas operation sgevm
if(isAVector && bRank == 2) { if(isAVector && bRank == 2) {
NDArray* A2 = new NDArray(A->reshape(A->ordering(), {1, A->lengthOf()})); // A{M} -> A2{1,M} NDArray* A2 = new NDArray(A->reshape(A->ordering(), {1, A->lengthOf()})); // A{M} -> A2{1,M}
NDArray* C2 = C ? new NDArray(C->reshape(C->ordering(), {1, C->lengthOf()})) : nullptr; // C{N} -> C2{1,N} NDArray* C2 = C ? new NDArray(C->reshape(C->ordering(), {1, C->lengthOf()}, false)) : nullptr; // C{N} -> C2{1,N}
auto result = mmulMxM(A2, B, C2, alpha, beta, outOrder); // result{1,N} auto result = mmulMxM(A2, B, C2, alpha, beta, outOrder); // result{1,N}
delete A2; delete A2;
delete C2; delete C2;
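The tensorDot changes above follow one ownership rule: allocate a permuted or reshaped wrapper only when it actually differs from its source, and afterwards delete only the wrappers that were allocated. A small standalone sketch of that rule with plain pointers (not NDArray; the helper name is an assumption):

    #include <cstdio>
    #include <string>

    // stands in for "permute/reshape only if it changes something"
    const std::string* viewIfNeeded(const std::string* in, bool changeNeeded) {
        return changeNeeded ? new std::string(*in + " (view)") : in;   // allocate only on change
    }

    int main() {
        std::string a = "a";
        const std::string* aP  = viewIfNeeded(&a, true);    // wrapper allocated
        const std::string* aPR = viewIfNeeded(aP, false);   // no change, same pointer
        std::printf("%s\n", aPR->c_str());
        if (aP != aPR) delete aPR;   // delete only what was allocated
        if (&a != aP)  delete aP;
        return 0;
    }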

View File

@ -139,5 +139,15 @@ namespace nd4j {
return ShapeBuilders::copyShapeInfoAndType(inShapeInfo, ArrayOptions::dataType(shapeInfoToGetTypeFrom), copyStrides, workspace); return ShapeBuilders::copyShapeInfoAndType(inShapeInfo, ArrayOptions::dataType(shapeInfoToGetTypeFrom), copyStrides, workspace);
} }
////////////////////////////////////////////////////////////////////////////////
Nd4jLong* ShapeBuilders::copyShapeInfoWithoutUnites(const Nd4jLong* inShapeInfo, const int dimsSize, const int* dimsToExclude, memory::Workspace* workspace) {
Nd4jLong *outShapeInfo = nullptr;
ALLOCATE(outShapeInfo, workspace, shape::shapeInfoLength(inShapeInfo[0] - dimsSize), Nd4jLong);
shape::excludeUnitiesFromShapeInfo(inShapeInfo, dimsSize, dimsToExclude, outShapeInfo);
return outShapeInfo;
}
} }

View File

@ -75,10 +75,23 @@ std::vector<Nd4jLong> ShapeUtils::evalShapeForTensorDot(const Nd4jLong* aShapeIn
permutBt = axesB; permutBt = axesB;
permutBt.insert(permutBt.end(), list_B.begin(), list_B.end()); permutBt.insert(permutBt.end(), list_B.begin(), list_B.end());
// if the permutation is the identity {0,1,2,...,rank-1} there is no need to permute, and we return an empty vector in this case
uint i1, i2;
for(i1 = 0; i1 < aRank; ++i1)
if(permutAt[i1] != i1)
break;
if(i1 == aRank)
permutAt = {};
for(i2 = 0; i2 < bRank; ++i2)
if(permutBt[i2] != i2)
break;
if(i2 == bRank)
permutBt = {};
Nd4jLong n2 = 1; Nd4jLong n2 = 1;
for (int i = 0; i < axeAsize; i++) for (int i = 0; i < axeAsize; i++)
n2 *= aShapeInfo[axesA[i] + 1]; n2 *= aShapeInfo[axesA[i] + 1];
shapeAt = {-1, n2}; shapeAt = {shape::length(aShapeInfo) / n2, n2};
std::vector<Nd4jLong> oldShapeA; std::vector<Nd4jLong> oldShapeA;
oldShapeA.resize(list_A.size()); oldShapeA.resize(list_A.size());
@ -89,7 +102,7 @@ std::vector<Nd4jLong> ShapeUtils::evalShapeForTensorDot(const Nd4jLong* aShapeIn
Nd4jLong n3 = 1; Nd4jLong n3 = 1;
for (int i = 0; i < axeBsize; i++) for (int i = 0; i < axeBsize; i++)
n3 *= bShapeInfo[axesB[i] + 1]; n3 *= bShapeInfo[axesB[i] + 1];
shapeBt = {n3, -1}; shapeBt = {n3, shape::length(bShapeInfo) / n3};
std::vector<Nd4jLong> oldShapeB; std::vector<Nd4jLong> oldShapeB;
oldShapeB.resize(list_B.size()); oldShapeB.resize(list_B.size());
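A standalone worked example of the two reshape targets computed above, for a tensordot of a{2,3,4} with b{4,5} contracted over a's axis 2 and b's axis 0 (the shapes are assumptions chosen for illustration):

    #include <cstdio>

    int main() {
        const long aLen = 2 * 3 * 4, bLen = 4 * 5;
        const long n2 = 4;                         // product of a's contracted extents
        const long n3 = 4;                         // product of b's contracted extents
        // the -1 placeholder is now replaced by the explicit remaining length
        const long shapeAt[2] = {aLen / n2, n2};   // {6, 4}
        const long shapeBt[2] = {n3, bLen / n3};   // {4, 5}
        std::printf("shapeAt = {%ld, %ld}, shapeBt = {%ld, %ld}\n",
                    shapeAt[0], shapeAt[1], shapeBt[0], shapeBt[1]);
        return 0;
    }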
@ -300,32 +313,37 @@ std::vector<Nd4jLong> ShapeUtils::evalRepeatShape(int axis, const std::vector<in
return outShape; return outShape;
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
// evaluate shapeInfo of permuted array // evaluate shapeInfo of permuted array
Nd4jLong* ShapeUtils::evalPermShapeInfo(const int* dimensions, const int rank, const NDArray& arr, nd4j::memory::Workspace* workspace) { Nd4jLong* ShapeUtils::evalPermShapeInfo(const int* dimensions, const int rank, const NDArray& arr, nd4j::memory::Workspace* workspace, const bool setContigStrides) {
if (!arr.nonNull()) if (!arr.nonNull())
throw std::runtime_error("ShapeUtils::evalPermShapeInfo static method: wrong arguments in pn/termute method: either array is nullptr!"); throw std::runtime_error("ShapeUtils::evalPermShapeInfo static method: wrong arguments: array is nullptr!");
if (rank != arr.rankOf()) if (rank != arr.rankOf())
throw std::runtime_error("ShapeUtils::evalPermShapeInfo static method: wrong arguments in pn/termute method: rank is not suitable!"); throw std::runtime_error("ShapeUtils::evalPermShapeInfo static method: wrong arguments: rank is not suitable!");
auto shapeInfoLength = shape::shapeInfoLength(rank); auto shapeInfoLength = shape::shapeInfoLength(rank);
// allocate memory for new array - shapeInfo
// allocate memory for new array - shapeInfo
Nd4jLong *shapeInfoNew = nullptr; Nd4jLong *shapeInfoNew = nullptr;
ALLOCATE(shapeInfoNew, workspace, shapeInfoLength, Nd4jLong); ALLOCATE(shapeInfoNew, workspace, shapeInfoLength, Nd4jLong);
// copy arr _shapeInfo into new array // copy arr _shapeInfo into new array
memcpy(shapeInfoNew, arr.getShapeInfo(), shape::shapeInfoByteLength(rank)); memcpy(shapeInfoNew, arr.getShapeInfo(), shape::shapeInfoByteLength(rank));
// perform buffer permutation // perform buffer permutation
shape::doPermuteShapeInfo(shapeInfoNew, dimensions); shape::doPermuteShapeInfo(shapeInfoNew, dimensions, arr.lengthOf());
if(setContigStrides)
shape::updateStrides(shapeInfoNew, arr.ordering());
ShapeDescriptor descriptor(shapeInfoNew); ShapeDescriptor descriptor(shapeInfoNew);
RELEASE(shapeInfoNew, workspace);
return ConstantShapeHelper::getInstance()->bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
}
RELEASE(shapeInfoNew, workspace);
return ConstantShapeHelper::getInstance()->bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
}
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
// evaluate shapeInfo of permuted array // evaluate shapeInfo of permuted array
@ -337,14 +355,14 @@ std::vector<Nd4jLong> ShapeUtils::evalRepeatShape(int axis, const std::vector<in
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
// evaluate shapeInfo of transposed array // evaluate shapeInfo of transposed array
Nd4jLong* ShapeUtils::evalTranspShapeInfo(const NDArray& arr, nd4j::memory::Workspace* workspace) { Nd4jLong* ShapeUtils::evalTranspShapeInfo(const NDArray& arr, nd4j::memory::Workspace* workspace, const bool setContigStrides) {
int rank = arr.rankOf(); int rank = arr.rankOf();
std::vector<int> dimensions(rank); std::vector<int> dimensions(rank);
for (int i = 0; i < rank; ++i) for (int i = 0; i < rank; ++i)
dimensions[i] = rank - 1 - i; dimensions[i] = rank - 1 - i;
return evalPermShapeInfo(dimensions.data(), dimensions.size(), arr, workspace); return evalPermShapeInfo(dimensions.data(), dimensions.size(), arr, workspace, setContigStrides);
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
@ -653,6 +671,26 @@ Nd4jLong* ShapeUtils::evalTileShapeInfo(const NDArray& arr, const std::vector<Nd
return result; return result;
} }
std::string ShapeUtils::shapeInfoAsString(const Nd4jLong* shapeInfo) {
if(!shapeInfo)
throw std::runtime_error("ShapeUtils::shapeAsString method: input shapeInfo must not be nullptr !");
std::string result;
int len = shape::shapeInfoLength(shapeInfo[0]);
result.append("[");
for (int e = 0; e < len; e++) {
result += flatbuffers::NumToString(shapeInfo[e]);
if (e < len - 1)
result.append(", ");
}
result.append("]");
return result;
}
std::string ShapeUtils::shapeAsString(const int rank, const Nd4jLong* shapeInfo) { std::string ShapeUtils::shapeAsString(const int rank, const Nd4jLong* shapeInfo) {
if(!shapeInfo) if(!shapeInfo)
@ -1019,6 +1057,29 @@ std::vector<int> ShapeUtils::tadAxesForSimpleBroadcast(const NDArray& max, const
return numOfMinTads == 1 ? maxTadDims : std::vector<int>(); return numOfMinTads == 1 ? maxTadDims : std::vector<int>();
} }
void ShapeUtils::copyCertainStridesFromShapeInfo(const Nd4jLong* inShapeInfo, const int nRank, const int dimsSize, const int* dims, Nd4jLong* outStrides) {
int yRank = shape::rank(inShapeInfo);
auto yOrigStride = shape::stride(inShapeInfo);
if (yRank == nRank) {
for (int i = 0; i < yRank; ++i) {
// x[2,3,4] * y[2,1,4] = z[2,3,4]
outStrides[i] = (1 == shape::sizeAt(inShapeInfo, i)) ? 0 : yOrigStride[i];
}
}
else {
auto dimEx = nd4j::ShapeUtils::evalDimsToExclude(nRank, dimsSize, dims);
for (int i = 0, it = 0; i < nRank; ++i) {
auto nCount = std::count(dimEx.cbegin(), dimEx.cend(), i);
outStrides[i] = (0 == nCount) ? yOrigStride[it++] : 0;
if (it == yRank)
break;
}
}
}
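A standalone sketch of the equal-rank branch above, using the shapes from its comment (x[2,3,4] * y[2,1,4]): every unit dimension of y gets stride 0 so the same y element is reused along that axis (plain arrays here, not the packed shapeInfo format):

    #include <cstdio>

    int main() {
        const long yShape[]  = {2, 1, 4};   // y has a unit dimension at index 1
        const long yStride[] = {4, 4, 1};   // c-order strides of y
        long outStrides[3];
        for (int i = 0; i < 3; ++i)
            outStrides[i] = (yShape[i] == 1) ? 0 : yStride[i];   // broadcast dim -> stride 0
        std::printf("%ld %ld %ld\n", outStrides[0], outStrides[1], outStrides[2]);   // 4 0 1
        return 0;
    }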
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
/* /*
bool ShapeUtils::isSubArrayCase(const NDArray& arr1, const NDArray& arr2, std::vector<int>& sameDims) { bool ShapeUtils::isSubArrayCase(const NDArray& arr1, const NDArray& arr2, std::vector<int>& sameDims) {

File diff suppressed because it is too large

View File

@ -40,6 +40,7 @@
#endif #endif
#include <helpers/TAD.h> #include <helpers/TAD.h>
#include <helpers/LoopKind.h>
#include "legacy_ops.h" #include "legacy_ops.h"
@ -122,6 +123,7 @@ namespace functions {
Nd4jLong *tadOffset, Nd4jLong *tadOffset,
Nd4jLong *tadShapeInfoZ, Nd4jLong *tadShapeInfoZ,
Nd4jLong *tadOffsetZ, Nd4jLong *tadOffsetZ,
nd4j::LoopKind::Kind loopKind,
uint64_t start, uint64_t start,
uint64_t stop); uint64_t stop);
@ -149,6 +151,7 @@ namespace functions {
Nd4jLong *tadOffset, Nd4jLong *tadOffset,
Nd4jLong *tadShapeInfoZ, Nd4jLong *tadShapeInfoZ,
Nd4jLong *tadOffsetZ, Nd4jLong *tadOffsetZ,
nd4j::LoopKind::Kind loopKind,
uint64_t start, uint64_t start,
uint64_t stop); uint64_t stop);

View File

@ -14,9 +14,9 @@
* SPDX-License-Identifier: Apache-2.0 * SPDX-License-Identifier: Apache-2.0
******************************************************************************/ ******************************************************************************/
// //
// @author Yurii Shyrma (iuriish@yahoo.com) // @author Yurii Shyrma (iuriish@yahoo.com)
// //
#include <loops/TrueBroadcastHelper.h> #include <loops/TrueBroadcastHelper.h>
#include <ops/ops.h> #include <ops/ops.h>
@ -25,12 +25,12 @@
using namespace simdOps; using namespace simdOps;
namespace nd4j { namespace nd4j {
namespace helpers { namespace helpers {
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
template <typename X, typename Y, typename Z> template <typename X, typename Y, typename Z>
template<typename OpType> template<typename OpType>
void TrueBroadcastHelper<X, Y, Z>::exec(const NDArray& xArr, const NDArray& yArr, NDArray& zArr) { void TrueBroadcastHelper<X, Y, Z>::exec(const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
const X* x = reinterpret_cast<X*>(xArr.getBuffer()); const X* x = reinterpret_cast<X*>(xArr.getBuffer());
@ -45,11 +45,11 @@ void TrueBroadcastHelper<X, Y, Z>::exec(const NDArray& xArr, const NDArray& yArr
const int yRank = yArr.rankOf(); const int yRank = yArr.rankOf();
const int zRank = zArr.rankOf(); const int zRank = zArr.rankOf();
bool bSpecialCase = (1 == xArr.ews() && 'c' == xArr.ordering() && 1 == yRank && bool bSpecialCase = (1 == xArr.ews() && 'c' == xArr.ordering() &&
1 == yArr.ews() && 'c' == yArr.ordering() && 1 == yArr.ews() && 'c' == yArr.ordering() &&
1 == zArr.ews() && 'c' == zArr.ordering()); 1 == zArr.ews() && 'c' == zArr.ordering());
if (bSpecialCase) { if (bSpecialCase && yArr.isColumnVector() && 1 == xArr.sizeAt(-1) ) {
auto yLen = (uint32_t)yArr.lengthOf(); auto yLen = (uint32_t)yArr.lengthOf();
auto func = PRAGMA_THREADS_FOR{ auto func = PRAGMA_THREADS_FOR{
for (uint32_t i = start; i < stop; i++) { for (uint32_t i = start; i < stop; i++) {
@ -64,8 +64,44 @@ void TrueBroadcastHelper<X, Y, Z>::exec(const NDArray& xArr, const NDArray& yArr
return; return;
} }
auto yShapeInt = yArr.getShapeAsVectorInt();
auto xShapeInt = xArr.getShapeAsVectorInt();
auto nCountY = std::count_if(yShapeInt.cbegin(), yShapeInt.cend(), [](int i) { return i == 1; });
auto nCountX = std::count_if(xShapeInt.cbegin(), xShapeInt.cend(), [](int i) { return i == 1; });
bool bSpecialCase2 = (xRank == zRank && yRank == zRank && 1 == xArr.sizeAt(-1) && 1 == yArr.sizeAt(-2) && 1 == nCountY && 1 == nCountX);
if (bSpecialCase && bSpecialCase2) {
int zDim1 = zArr.sizeAt(-2);
int zDim2 = zArr.sizeAt(-1);
int nLen = zArr.lengthOf() / yArr.sizeAt(-1);
auto func = PRAGMA_THREADS_FOR{
for (uint32_t total = start; total < stop; total++) {
uint32_t i = total / zDim1;
uint32_t j = total % zDim1;
uint32_t index = (i * zDim1) + j;
auto rZ = z + (index * zDim2);
auto rY = y + (i * zDim2);
auto rX = x[index];
for (uint32_t n = 0; n < zDim2; n++) {
rZ[n] = OpType::op(rX, rY[n]);
}
}
};
samediff::Threads::parallel_tad(func, 0, nLen, 1);
return;
}
const Nd4jLong zLen = zArr.lengthOf(); const Nd4jLong zLen = zArr.lengthOf();
auto func = PRAGMA_THREADS_FOR { auto func = PRAGMA_THREADS_FOR{
std::vector<Nd4jLong> xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf()); std::vector<Nd4jLong> xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf());
for (auto i = start; i < stop; ++i) { for (auto i = start; i < stop; ++i) {
@ -77,7 +113,8 @@ void TrueBroadcastHelper<X, Y, Z>::exec(const NDArray& xArr, const NDArray& yArr
if (ix >= 0) { if (ix >= 0) {
if (xShapeInfo[ix + 1] == zShapeInfo[iz + 1]) { if (xShapeInfo[ix + 1] == zShapeInfo[iz + 1]) {
xCoords[ix--] = zCoords[iz]; xCoords[ix--] = zCoords[iz];
} else { }
else {
xCoords[ix--] = 0; xCoords[ix--] = 0;
} }
} }
@ -85,7 +122,8 @@ void TrueBroadcastHelper<X, Y, Z>::exec(const NDArray& xArr, const NDArray& yArr
if (iy >= 0) { if (iy >= 0) {
if (yShapeInfo[iy + 1] == zShapeInfo[iz + 1]) { if (yShapeInfo[iy + 1] == zShapeInfo[iz + 1]) {
yCoords[iy--] = zCoords[iz]; yCoords[iy--] = zCoords[iz];
} else { }
else {
yCoords[iy--] = 0; yCoords[iy--] = 0;
} }
} }
@ -100,17 +138,17 @@ void TrueBroadcastHelper<X, Y, Z>::exec(const NDArray& xArr, const NDArray& yArr
}; };
samediff::Threads::parallel_for(func, 0, zLen); samediff::Threads::parallel_for(func, 0, zLen);
} }
template <typename X, typename Y, typename Z> template <typename X, typename Y, typename Z>
void TrueBroadcastHelper<X, Y, Z>::exec(const nd4j::broadcast::Ops opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr) { void TrueBroadcastHelper<X, Y, Z>::exec(const nd4j::broadcast::Ops opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
DISPATCH_BY_OPNUM_TTT(exec, PARAMS(xArr, yArr, zArr), BROADCAST_OPS); DISPATCH_BY_OPNUM_TTT(exec, PARAMS(xArr, yArr, zArr), BROADCAST_OPS);
} }
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
template <typename X, typename Z> template <typename X, typename Z>
template<typename OpType> template<typename OpType>
void TrueBroadcastBoolHelper<X, Z>::exec(const NDArray& xArr, const NDArray& yArr, NDArray& zArr) { void TrueBroadcastBoolHelper<X, Z>::exec(const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
const X* x = reinterpret_cast<X*>(xArr.getBuffer()); const X* x = reinterpret_cast<X*>(xArr.getBuffer());
const X* y = reinterpret_cast<X*>(yArr.getBuffer()); const X* y = reinterpret_cast<X*>(yArr.getBuffer());
@ -126,7 +164,7 @@ void TrueBroadcastBoolHelper<X, Z>::exec(const NDArray& xArr, const NDArray& yAr
const Nd4jLong zLen = zArr.lengthOf(); const Nd4jLong zLen = zArr.lengthOf();
auto func = PRAGMA_THREADS_FOR { auto func = PRAGMA_THREADS_FOR{
std::vector<Nd4jLong> xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf()); std::vector<Nd4jLong> xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf());
for (auto i = start; i < stop; ++i) { for (auto i = start; i < stop; ++i) {
@ -138,7 +176,8 @@ void TrueBroadcastBoolHelper<X, Z>::exec(const NDArray& xArr, const NDArray& yAr
if (ix >= 0) { if (ix >= 0) {
if (xShapeInfo[ix + 1] == zShapeInfo[iz + 1]) { if (xShapeInfo[ix + 1] == zShapeInfo[iz + 1]) {
xCoords[ix--] = zCoords[iz]; xCoords[ix--] = zCoords[iz];
} else { }
else {
xCoords[ix--] = 0; xCoords[ix--] = 0;
} }
} }
@ -146,7 +185,8 @@ void TrueBroadcastBoolHelper<X, Z>::exec(const NDArray& xArr, const NDArray& yAr
if (iy >= 0) { if (iy >= 0) {
if (yShapeInfo[iy + 1] == zShapeInfo[iz + 1]) { if (yShapeInfo[iy + 1] == zShapeInfo[iz + 1]) {
yCoords[iy--] = zCoords[iz]; yCoords[iy--] = zCoords[iz];
} else { }
else {
yCoords[iy--] = 0; yCoords[iy--] = 0;
} }
} }
@ -161,17 +201,17 @@ void TrueBroadcastBoolHelper<X, Z>::exec(const NDArray& xArr, const NDArray& yAr
}; };
samediff::Threads::parallel_for(func, 0, zLen); samediff::Threads::parallel_for(func, 0, zLen);
} }
template <typename X, typename Y> template <typename X, typename Y>
void TrueBroadcastBoolHelper<X, Y>::exec(const nd4j::broadcast::BoolOps opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr) { void TrueBroadcastBoolHelper<X, Y>::exec(const nd4j::broadcast::BoolOps opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
DISPATCH_BY_OPNUM_TT(exec, PARAMS(xArr, yArr, zArr), BROADCAST_BOOL_OPS); DISPATCH_BY_OPNUM_TT(exec, PARAMS(xArr, yArr, zArr), BROADCAST_BOOL_OPS);
} }
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
template <typename X> template <typename X>
template<typename OpType> template<typename OpType>
void TrueBroadcastIntHelper<X>::exec(const NDArray& xArr, const NDArray& yArr, NDArray& zArr) { void TrueBroadcastIntHelper<X>::exec(const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
const X* x = reinterpret_cast<X*>(xArr.getBuffer()); const X* x = reinterpret_cast<X*>(xArr.getBuffer());
const X* y = reinterpret_cast<X*>(yArr.getBuffer()); const X* y = reinterpret_cast<X*>(yArr.getBuffer());
@ -187,7 +227,7 @@ void TrueBroadcastIntHelper<X>::exec(const NDArray& xArr, const NDArray& yArr, N
const Nd4jLong zLen = zArr.lengthOf(); const Nd4jLong zLen = zArr.lengthOf();
auto func = PRAGMA_THREADS_FOR { auto func = PRAGMA_THREADS_FOR{
std::vector<Nd4jLong> xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf()); std::vector<Nd4jLong> xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf());
for (auto i = start; i < stop; ++i) { for (auto i = start; i < stop; ++i) {
@ -199,7 +239,8 @@ void TrueBroadcastIntHelper<X>::exec(const NDArray& xArr, const NDArray& yArr, N
if (ix >= 0) { if (ix >= 0) {
if (xShapeInfo[ix + 1] == zShapeInfo[iz + 1]) { if (xShapeInfo[ix + 1] == zShapeInfo[iz + 1]) {
xCoords[ix--] = zCoords[iz]; xCoords[ix--] = zCoords[iz];
} else { }
else {
xCoords[ix--] = 0; xCoords[ix--] = 0;
} }
} }
@ -207,7 +248,8 @@ void TrueBroadcastIntHelper<X>::exec(const NDArray& xArr, const NDArray& yArr, N
if (iy >= 0) { if (iy >= 0) {
if (yShapeInfo[iy + 1] == zShapeInfo[iz + 1]) { if (yShapeInfo[iy + 1] == zShapeInfo[iz + 1]) {
yCoords[iy--] = zCoords[iz]; yCoords[iy--] = zCoords[iz];
} else { }
else {
yCoords[iy--] = 0; yCoords[iy--] = 0;
} }
} }
@ -222,28 +264,28 @@ void TrueBroadcastIntHelper<X>::exec(const NDArray& xArr, const NDArray& yArr, N
}; };
samediff::Threads::parallel_for(func, 0, zLen); samediff::Threads::parallel_for(func, 0, zLen);
} }
template <typename X> template <typename X>
void TrueBroadcastIntHelper<X>::exec(const nd4j::broadcast::IntOps opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr) { void TrueBroadcastIntHelper<X>::exec(const nd4j::broadcast::IntOps opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
DISPATCH_BY_OPNUM_T(exec, PARAMS(xArr, yArr, zArr), BROADCAST_INT_OPS); DISPATCH_BY_OPNUM_T(exec, PARAMS(xArr, yArr, zArr), BROADCAST_INT_OPS);
} }
/* /*
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_0); BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_0);
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_1); BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_1);
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_2); BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_2);
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_3); BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_3);
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_4); BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_4);
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_5); BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_5);
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_6); BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_6);
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_7); BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_7);
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_8); BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_8);
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_9); BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_9);
BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastBoolHelper, , LIBND4J_TYPES, BOOL_TYPES); BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastBoolHelper, , LIBND4J_TYPES, BOOL_TYPES);
BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastIntHelper, , INTEGER_TYPES); BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastIntHelper, , INTEGER_TYPES);
*/ */
} }
} }
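The general path above walks output coordinates and collapses each coordinate to 0 for an input wherever that input has extent 1. A standalone sketch of that index mapping for x{1,3} + y{2,1} = z{2,3} (addition stands in for OpType::op):

    #include <cstdio>

    int main() {
        const int zShape[2] = {2, 3};
        const int xShape[2] = {1, 3};      // x is broadcast along dim 0
        const int yShape[2] = {2, 1};      // y is broadcast along dim 1
        const float x[3] = {1, 2, 3};
        const float y[2] = {10, 20};
        float z[6];
        for (int i = 0; i < zShape[0]; ++i)
            for (int j = 0; j < zShape[1]; ++j) {
                // collapse a z coordinate to 0 where the corresponding input extent is 1
                const int xi = xShape[0] == 1 ? 0 : i, xj = xShape[1] == 1 ? 0 : j;
                const int yi = yShape[0] == 1 ? 0 : i, yj = yShape[1] == 1 ? 0 : j;
                z[i * zShape[1] + j] = x[xi * xShape[1] + xj] + y[yi * yShape[1] + yj];
            }
        for (int k = 0; k < 6; ++k) std::printf("%g ", z[k]);   // 11 12 13 21 22 23
        return 0;
    }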

View File

@ -25,6 +25,7 @@
#include <LoopKind.h> #include <LoopKind.h>
#include <helpers/ConstantTadHelper.h> #include <helpers/ConstantTadHelper.h>
#include <execution/Threads.h> #include <execution/Threads.h>
#include <helpers/ShapeUtils.h>
using namespace simdOps; using namespace simdOps;
@ -75,6 +76,7 @@ namespace functions {
Nd4jLong *xTadOffset, Nd4jLong *xTadOffset,
Nd4jLong *zTadShapeInfo, Nd4jLong *zTadShapeInfo,
Nd4jLong *zTadOffset, Nd4jLong *zTadOffset,
nd4j::LoopKind::Kind loopKind,
uint64_t start, uint64_t start,
uint64_t stop) { uint64_t stop) {
DISPATCH_BY_OPNUM_TTT(exec, PARAMS(x, DISPATCH_BY_OPNUM_TTT(exec, PARAMS(x,
@ -88,7 +90,7 @@ namespace functions {
xTadShapeInfo, xTadShapeInfo,
xTadOffset, xTadOffset,
zTadShapeInfo, zTadShapeInfo,
zTadOffset, start, stop), BROADCAST_OPS); zTadOffset, loopKind, start, stop), BROADCAST_OPS);
} }
template <typename X, typename Y, typename Z> template <typename X, typename Y, typename Z>
@ -105,6 +107,7 @@ namespace functions {
Nd4jLong *xTadOffset, Nd4jLong *xTadOffset,
Nd4jLong *zTadShapeInfo, Nd4jLong *zTadShapeInfo,
Nd4jLong *zTadOffset, Nd4jLong *zTadOffset,
nd4j::LoopKind::Kind loopKind,
uint64_t start, uint64_t start,
uint64_t stop) { uint64_t stop) {
@ -142,7 +145,14 @@ namespace functions {
auto yEws = shape::elementWiseStride(yShapeInfo); auto yEws = shape::elementWiseStride(yShapeInfo);
auto zEws = shape::elementWiseStride(zTadShapeInfo); auto zEws = shape::elementWiseStride(zTadShapeInfo);
const nd4j::LoopKind::Kind kindOfLoop = nd4j::LoopKind::deduceKindOfLoopXYZ(xTadShapeShapeInfo, yShapeInfo, zTadShapeInfo);
const nd4j::LoopKind::Kind kindOfLoop =
(loopKind == nd4j::LoopKind::BROADCAST_SCALAR_X ||
loopKind == nd4j::LoopKind::BROADCAST_SCALAR_Y ||
loopKind == nd4j::LoopKind::BROADCAST_3D ||
loopKind == nd4j::LoopKind::BROADCAST_4D ||
loopKind == nd4j::LoopKind::BROADCAST_5D)
? loopKind : nd4j::LoopKind::deduceKindOfLoopXYZ(xTadShapeShapeInfo, yShapeInfo, zTadShapeInfo);
if (kindOfLoop == nd4j::LoopKind::EWS1) { if (kindOfLoop == nd4j::LoopKind::EWS1) {
for (auto i = start; i < stop; i++) { for (auto i = start; i < stop; i++) {
@ -163,6 +173,131 @@ namespace functions {
for (unsigned int f = 0; f < tadLength; f++) for (unsigned int f = 0; f < tadLength; f++)
oZ[f * zEws] = OpType::op(oX[f * xEws], y[f * yEws]); oZ[f * zEws] = OpType::op(oX[f * xEws], y[f * yEws]);
} }
} else if(kindOfLoop == nd4j::LoopKind::BROADCAST_SCALAR_X){
// this loop effectively turns the broadcast into a series of scalar ops
auto loopLength = yShapeInfo[shape::rank(yShapeInfo)];
for (auto i = start; i < stop; i++) {
auto oY = y + (i * loopLength);
auto oZ = z + (i * loopLength);
const auto oX = x[i];
PRAGMA_OMP_SIMD
for (unsigned int f = 0; f < loopLength; f++)
oZ[f] = OpType::op(oX, oY[f]);
}
} else if(kindOfLoop == nd4j::LoopKind::BROADCAST_SCALAR_Y){
// this loop effectively turns the broadcast into a series of scalar ops
auto loopLength = xShapeInfo[shape::rank(xShapeInfo)];
for (auto i = start; i < stop; i++) {
auto oX = x + (i * loopLength);
auto oZ = z + (i * loopLength);
const auto oY = y[i];
PRAGMA_OMP_SIMD
for (unsigned int f = 0; f < loopLength; f++)
oZ[f] = OpType::op(oX[f], oY);
}
}
else if (kindOfLoop == nd4j::LoopKind::BROADCAST_3D) {
int xRank = shape::rank(xShapeInfo);
int yRank = shape::rank(yShapeInfo);
auto xStrides = shape::stride(xShapeInfo);
auto zStrides = shape::stride(zShapeInfo);
Nd4jLong yStrides[3] = { 0,0,0 };
nd4j::ShapeUtils::copyCertainStridesFromShapeInfo(yShapeInfo, xRank, dimensionLength, dimension, yStrides);
uint32_t nSize1 = shape::sizeAt(zShapeInfo, 1);
uint32_t nSize2 = shape::sizeAt(zShapeInfo, 2);
for (uint32_t index0 = start; index0 < stop; index0++) {
PRAGMA_OMP_SIMD
for (uint32_t index1 = 0; index1 < nSize1; index1++) {
for (uint32_t index2 = 0; index2 < nSize2; index2++) {
auto rX = x + (xStrides[0] * index0 + xStrides[1] * index1 + xStrides[2] * index2);
auto rY = y + (yStrides[0] * index0 + yStrides[1] * index1 + yStrides[2] * index2);
auto rZ = z + (zStrides[0] * index0 + zStrides[1] * index1 + zStrides[2] * index2);
*rZ = OpType::op(*rX, *rY);
}
}
}
}
else if (kindOfLoop == nd4j::LoopKind::BROADCAST_4D) {
int xRank = shape::rank(xShapeInfo);
int yRank = shape::rank(yShapeInfo);
auto xStrides = shape::stride(xShapeInfo);
auto zStrides = shape::stride(zShapeInfo);
Nd4jLong yStrides[4] = { 0,0,0,0 };
nd4j::ShapeUtils::copyCertainStridesFromShapeInfo(yShapeInfo, xRank, dimensionLength, dimension, yStrides);
uint32_t nSize1 = shape::sizeAt(zShapeInfo, 1);
uint32_t nSize2 = shape::sizeAt(zShapeInfo, 2);
uint32_t nSize3 = shape::sizeAt(zShapeInfo, 3);
for (uint32_t i = start; i < stop; i++) {
uint32_t index0 = i / nSize1;
uint32_t index1 = i % nSize1;
PRAGMA_OMP_SIMD
for (uint32_t index2 = 0; index2 < nSize2; index2++) {
for (uint32_t index3 = 0; index3 < nSize3; index3++) {
auto rX = x + (xStrides[0] * index0 + xStrides[1] * index1 + xStrides[2] * index2 + xStrides[3] * index3);
auto rY = y + (yStrides[0] * index0 + yStrides[1] * index1 + yStrides[2] * index2 + yStrides[3] * index3);
auto rZ = z + (zStrides[0] * index0 + zStrides[1] * index1 + zStrides[2] * index2 + zStrides[3] * index3);
*rZ = OpType::op(*rX, *rY);
}
}
}
}
else if (kindOfLoop == nd4j::LoopKind::BROADCAST_5D) {
int xRank = shape::rank(xShapeInfo);
int yRank = shape::rank(yShapeInfo);
auto xStrides = shape::stride(xShapeInfo);
auto zStrides = shape::stride(zShapeInfo);
Nd4jLong yStrides[5] = { 0,0,0,0,0 };
nd4j::ShapeUtils::copyCertainStridesFromShapeInfo(yShapeInfo, xRank, dimensionLength, dimension, yStrides);
uint32_t nSize1 = shape::sizeAt(zShapeInfo, 1);
uint32_t nSize2 = shape::sizeAt(zShapeInfo, 2);
uint32_t nSize3 = shape::sizeAt(zShapeInfo, 3);
uint32_t nSize4 = shape::sizeAt(zShapeInfo, 4);
for (uint32_t i = start; i < stop; i++) {
uint32_t index0 = i / nSize1;
uint32_t index1 = i % nSize1;
PRAGMA_OMP_SIMD
for (uint32_t index2 = 0; index2 < nSize2; index2++) {
for (uint32_t index3 = 0; index3 < nSize3; index3++) {
for (uint32_t index4 = 0; index4 < nSize4; index4++) {
auto rX = x + (xStrides[0] * index0 + xStrides[1] * index1 + xStrides[2] * index2 + xStrides[3] * index3 + xStrides[4] * index4);
auto rY = y + (yStrides[0] * index0 + yStrides[1] * index1 + yStrides[2] * index2 + yStrides[3] * index3 + yStrides[4] * index4);
auto rZ = z + (zStrides[0] * index0 + zStrides[1] * index1 + zStrides[2] * index2 + zStrides[3] * index3 + zStrides[4] * index4);
*rZ = OpType::op(*rX, *rY);
}
}
}
}
} }
else if(shape::haveSameShapeAndStrides(xTadShapeShapeInfo, yShapeInfo) && shape::haveSameShapeAndStrides(xTadShapeShapeInfo, zTadShapeInfo)) { else if(shape::haveSameShapeAndStrides(xTadShapeShapeInfo, yShapeInfo) && shape::haveSameShapeAndStrides(xTadShapeShapeInfo, zTadShapeInfo)) {
uint tadShapeShapeInfoCast[MAX_RANK]; uint tadShapeShapeInfoCast[MAX_RANK];

View File

@ -73,7 +73,7 @@ Nd4jLong IndexReduce<X, Y>::execScalar(void *vx, Nd4jLong *xShapeInfo, void *vex
auto func = PRAGMA_THREADS_FOR { auto func = PRAGMA_THREADS_FOR {
intermediatery[thread_id] = OpType::startingIndexValue(x); intermediatery[thread_id] = OpType::startingIndexValue(x);
for (auto i = start; i < stop; i += increment) { for (auto i = start; i < stop; i++) {
IndexValue<X> curr(x[i], i); IndexValue<X> curr(x[i], i);
intermediatery[thread_id] = OpType::update(intermediatery[thread_id], curr, extraParams); intermediatery[thread_id] = OpType::update(intermediatery[thread_id], curr, extraParams);
} }
@ -88,7 +88,7 @@ Nd4jLong IndexReduce<X, Y>::execScalar(void *vx, Nd4jLong *xShapeInfo, void *vex
auto func = PRAGMA_THREADS_FOR {
intermediatery[thread_id] = OpType::startingIndexValue(x);
-for (auto i = start; i < stop; i += increment) {
+for (auto i = start; i < stop; i++) {
auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
IndexValue<X> curr(x[offset], i);
intermediatery[thread_id] = OpType::update(intermediatery[thread_id], curr, extraParams);

View File

@ -75,7 +75,7 @@ namespace functions {
auto func = PRAGMA_THREADS_FOR {
PRAGMA_OMP_SIMD
-for (auto i = start; i < stop; i += increment) {
+for (auto i = start; i < stop; i++) {
auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
z[offset] = OpClass::op(x[offset], y[offset], i, length, rng, extraArguments);
}
@ -93,7 +93,7 @@ namespace functions {
auto func = PRAGMA_THREADS_FOR {
PRAGMA_OMP_SIMD
-for (uint64_t i = start; i < stop; i += increment) {
+for (uint64_t i = start; i < stop; i++) {
auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
auto zOffset = shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ);
z[zOffset] = OpClass::op(x[offset], y[offset], i, length, rng, extraArguments);
@ -111,7 +111,7 @@ namespace functions {
auto func = PRAGMA_THREADS_FOR {
PRAGMA_OMP_SIMD
-for (uint64_t i = start; i < stop; i += increment) {
+for (uint64_t i = start; i < stop; i++) {
auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
auto yOffset = shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY);
z[offset] = OpClass::op(x[offset], y[yOffset], i, length, rng, extraArguments);
@ -129,7 +129,7 @@ namespace functions {
auto func = PRAGMA_THREADS_FOR {
PRAGMA_OMP_SIMD
-for (uint64_t i = start; i < stop; i += increment) {
+for (uint64_t i = start; i < stop; i++) {
auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
auto offset = shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY);
z[offset] = OpClass::op(x[xOffset], y[offset], i, length, rng, extraArguments);
@ -149,7 +149,7 @@ namespace functions {
auto func = PRAGMA_THREADS_FOR {
PRAGMA_OMP_SIMD
-for (uint64_t i = start; i < stop; i += increment) {
+for (uint64_t i = start; i < stop; i++) {
auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
auto yOffset = shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY);
auto zOffset = shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ);
@ -197,7 +197,7 @@ namespace functions {
else{
auto func = PRAGMA_THREADS_FOR {
PRAGMA_OMP_SIMD
-for (uint64_t i = start; i < stop; i += increment) {
+for (uint64_t i = start; i < stop; i++) {
auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
z[offset] = OpClass::op(x[offset], i, length, rng, extraArguments);
}
@ -213,7 +213,7 @@ namespace functions {
auto func = PRAGMA_THREADS_FOR {
PRAGMA_OMP_SIMD
-for (uint64_t i = start; i < stop; i += increment) {
+for (uint64_t i = start; i < stop; i++) {
auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
auto zOffset = shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ);
z[zOffset] = OpClass::op(x[xOffset], i, length, rng, extraArguments);
@ -255,7 +255,7 @@ namespace functions {
auto func = PRAGMA_THREADS_FOR {
PRAGMA_OMP_SIMD
-for (uint64_t i = start; i < stop; i += increment) {
+for (uint64_t i = start; i < stop; i++) {
auto offset = shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ);
z[offset] = OpClass::op(i, length, rng, extraArguments);
}

View File

@ -88,7 +88,7 @@ void Reduce3<X,Z>::execScalar(void *vx, Nd4jLong *xShapeInfo,
if (kindOfLoop == nd4j::LoopKind::EWS1) {
auto func = PRAGMA_THREADS_FOR {
-for (auto i = start; i < stop; i += increment) {
+for (auto i = start; i < stop; i++) {
intermediate[thread_id] = OpType::update(intermediate[thread_id], OpType::op(x[i], y[i], extraParamsLocal + 3 * thread_id), extraParamsLocal + 3 * thread_id);
}
};
@ -98,7 +98,7 @@ void Reduce3<X,Z>::execScalar(void *vx, Nd4jLong *xShapeInfo,
} else if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo)) {
auto func = PRAGMA_THREADS_FOR {
-for (auto i = start; i < stop; i += increment) {
+for (auto i = start; i < stop; i++) {
auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
intermediate[thread_id] = OpType::update(intermediate[thread_id], OpType::op(x[offset], y[offset], extraParamsLocal + 3 * thread_id), extraParamsLocal + 3 * thread_id);
}
@ -110,7 +110,7 @@ void Reduce3<X,Z>::execScalar(void *vx, Nd4jLong *xShapeInfo,
const bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);
auto func = PRAGMA_THREADS_FOR {
-for (auto i = start; i < stop; i += increment) {
+for (auto i = start; i < stop; i++) {
auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
auto yOffset = shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY);
intermediate[thread_id] = OpType::update(intermediate[thread_id], OpType::op(x[xOffset], y[yOffset], extraParamsLocal + 3 * thread_id), extraParamsLocal + 3 * thread_id);

View File

@ -158,7 +158,7 @@ namespace functions {
const bool canCast = tadEWS == 1 && tadOrder == 'c' ? false : nd4j::DataTypeUtils::castShapeInfo<uint>(tadShapeShapeInfo, tadShapeShapeInfoCast);
auto func = PRAGMA_THREADS_FOR {
-for (auto r = start; r < stop; r += increment) {
+for (auto r = start; r < stop; r++) {
auto tadOffsetForBlock = tadPack.primaryOffsets()[r];
auto tx = x + tadOffsetForBlock;

View File

@ -84,7 +84,7 @@ namespace functions {
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
int totalThreads = gridDim.x * blockDim.x;
-if(xEws > 0 && zEws > 0 && xOrder == zOrder) {
+if(xEws > 0 && zEws > 0 && xOrder == zOrder && xOrder == 'c') {
for (int i = tid; i < length; i += totalThreads)
z[i * zEws] = OpType::op(x[i * xEws], params);

View File

@ -89,7 +89,7 @@ namespace functions {
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
int totalThreads = gridDim.x * blockDim.x;
-if(xEws > 0 && zEws > 0 && xOrder == zOrder) {
+if(xEws > 0 && zEws > 0 && xOrder == zOrder && xOrder == 'c') {
for (int i = tid; i < length; i += totalThreads)
z[i * zEws] = OpType::op(x[i * xEws], params);

View File

@ -97,7 +97,7 @@ namespace functions {
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
int totalThreads = gridDim.x * blockDim.x;
-if(xEws > 0 && zEws > 0 && xOrder == zOrder) {
+if(xEws > 0 && zEws > 0 && xOrder == zOrder && xOrder == 'c') {
for (Nd4jLong i = tid; i < length; i += totalThreads)
z[i * zEws] = OpType::op(x[i * xEws], params);

View File

@ -87,7 +87,7 @@ namespace functions {
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
int totalThreads = gridDim.x * blockDim.x;
-if(xEws > 0 && zEws > 0 && xOrder == zOrder) {
+if(xEws > 0 && zEws > 0 && xOrder == zOrder && xOrder == 'c') {
for (int i = tid; i < length; i += totalThreads)
z[i * zEws] = OpType::op(x[i * xEws], params);

View File

@ -89,7 +89,7 @@ namespace functions {
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
int totalThreads = gridDim.x * blockDim.x;
-if(xEws > 0 && zEws > 0 && xOrder == zOrder) {
+if(xEws > 0 && zEws > 0 && xOrder == zOrder && xOrder == 'c') {
for (int i = tid; i < length; i += totalThreads)
z[i * zEws] = OpType::op(x[i * xEws], params);

View File

@ -81,7 +81,7 @@ namespace nd4j {
// now we actually apply quantization
auto func = PRAGMA_THREADS_FOR {
-for (auto e = start; e < stop; e += increment) {
+for (auto e = start; e < stop; e++) {
rz[e] = static_cast<char>(nd4j::math::nd4j_round<float, char>( 1.0f * static_cast<float>(x[e]) / nd4j::math::nd4j_max<float>(amax, amin) * max_byte));
}
};
@ -177,7 +177,7 @@ PRAGMA_OMP_ATOMIC_ARGS(write)
int flimit = limit + 4;
auto func = PRAGMA_THREADS_FOR {
-for (auto e = start; e < stop; e += increment) {
+for (auto e = start; e < stop; e++) {
int el = x[e];
int ael = nd4j::math::nd4j_abs<int>(el) - 1;
z[ael] += el > 0 ? static_cast<T>(threshold) : static_cast<T>(-threshold);
@ -202,7 +202,7 @@ PRAGMA_OMP_ATOMIC_ARGS(write)
auto z = reinterpret_cast<T *>(dz);
auto func = PRAGMA_THREADS_FOR {
-for (auto i = start; i < stop; i += increment) {
+for (auto i = start; i < stop; i++) {
z[i] = static_cast<T>(static_cast<float>(x[i]));
}
};
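Note: the quantisation loop above maps each float to a signed byte by scaling with the largest absolute value seen in the array (max_byte being the widest representable byte value, assumed 127 here; amax/amin are assumed to hold the absolute extrema). A minimal standalone round-trip of that formula on hypothetical data (plain C++, not taken from the repository):

#include <algorithm>
#include <cmath>
#include <cstdio>

int main() {
    const float x[] = {0.5f, -2.0f, 1.25f, 3.75f};
    const float max_byte = 127.0f;

    // plays the role of nd4j_max<float>(amax, amin) in the loop above
    float amax = 0.0f;
    for (float v : x) amax = std::max(amax, std::fabs(v));

    for (float v : x) {
        // same formula as the loop above: round(v / amax * max_byte)
        char q = static_cast<char>(std::lround(v / amax * max_byte));
        // dequantised value, to show the round-trip error
        float back = static_cast<float>(q) / max_byte * amax;
        printf("%+.3f -> %4d -> %+.3f\n", v, (int) q, back);
    }
    return 0;
}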

View File

@ -147,6 +147,9 @@ namespace nd4j {
// returns TRUE if this op allows in-place execution
bool allowsInplace();
// this method allows you to enable/disable inplace call for a given op
void allowInplace(bool reallyAllow);
// this method returns opNum (applicable for legacy XYZ ops only)
int getOpNum();

View File

@ -27,13 +27,11 @@ namespace nd4j {
namespace ops {
OP_IMPL(identity, 1, 1, true) {
auto first = INPUT_VARIABLE(0);
-auto z = this->getZ(block);
+auto z = OUTPUT_VARIABLE(0);
-// just for lulz
+if (!block.isInplace())
first->applyTransform(nd4j::transform::Identity, *z);
-STORE_RESULT(*z);
return Status::OK();
}
DECLARE_SYN(linear, identity);
@ -60,8 +58,8 @@ namespace nd4j {
DECLARE_TYPES(identity_bp) {
getOpDescriptor()
->setAllowedInputTypes(0, DataType::ANY)
-->setAllowedInputTypes(1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF})
+->setAllowedInputTypes(1, {ALL_FLOATS})
-->setAllowedOutputTypes(0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF});
+->setAllowedOutputTypes(0, {ALL_FLOATS});
}
}
}

View File

@ -29,7 +29,9 @@
namespace nd4j {
namespace ops {
//////////////////////////////////////////////////////////////////////
CUSTOM_OP_IMPL(matmul, 2, 1, false, 0, -2) {
auto x = INPUT_VARIABLE(0);
auto y = INPUT_VARIABLE(1);
auto z = OUTPUT_VARIABLE(0);
@ -57,58 +59,41 @@ namespace nd4j {
const int yLastButOneDim = transY ? -1 : -2;
// ******* input validation ******* //
-REQUIRE_TRUE(xRank > 0 && yRank > 0, 0,
-"MATMUL OP: input arrays must have rank bigger than 0 (should not be scalars), but got instead: x rank = %i, y rank = %i !",
-xRank, yRank);
+REQUIRE_TRUE(xRank > 0 && yRank > 0, 0, "MATMUL OP: input arrays must have rank bigger than 0 (should not be scalars), but got instead: x rank = %i, y rank = %i !", xRank, yRank);
if (xRank == 1 && yRank == 1) { // dot case, output is scalar (or vector with length = 1)
-REQUIRE_TRUE(x->lengthOf() == y->lengthOf(), 0,
-"MATMUL OP: since input arrays are vectors they must have the same length, but got x length = %i, y length = %i !",
-x->lengthOf(), y->lengthOf());
+REQUIRE_TRUE(x->lengthOf() == y->lengthOf(), 0, "MATMUL OP: since input arrays are vectors they must have the same length, but got x length = %i, y length = %i !", x->lengthOf(), y->lengthOf());
} else if (xRank == 1 && yRank == 2) { // vector x matrix, i.e. [4] x [4,5] = [5], output is vector
-REQUIRE_TRUE(x->lengthOf() == y->sizeAt(yLastButOneDim), 0,
-"MATMUL OP: input arrays have inconsistent shapes for vector-matrix product: x %s, y %s !",
-ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str());
+REQUIRE_TRUE(x->lengthOf() == y->sizeAt(yLastButOneDim), 0, "MATMUL OP: input arrays have inconsistent shapes for vector-matrix product: x %s, y %s !", ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str());
} else if (xRank == 2 && yRank == 1) { // matrix x vector , i.e. [4,5] x [5] = [4], output is vector
-REQUIRE_TRUE(x->sizeAt(xLastDim) == y->lengthOf(), 0,
-"MATMUL OP: input arrays have inconsistent shapes for matrix-vector product: x %s, y %s !",
-ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str());
+REQUIRE_TRUE(x->sizeAt(xLastDim) == y->lengthOf(), 0, "MATMUL OP: input arrays have inconsistent shapes for matrix-vector product: x %s, y %s !", ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str());
} else {
-REQUIRE_TRUE(xRank == yRank && yRank == zRank, 0,
-"MATMUL OP: input and output arrays must have the same rank, but got instead: x rank = %i, y rank = %i, z rank = %i !",
-xRank, yRank, zRank);
-REQUIRE_TRUE(x->sizeAt(xLastDim) == y->sizeAt(yLastButOneDim) &&
-x->sizeAt(xLastButOneDim) == z->sizeAt(-2) && y->sizeAt(yLastDim) == z->sizeAt(-1), 0,
-"MATMUL OP: input/output arrays have inconsistent shapes for matrix product: x %s, y %s, z %s !",
-ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str(),
-ShapeUtils::shapeAsString(z).c_str());
+REQUIRE_TRUE(xRank == yRank && yRank == zRank, 0, "MATMUL OP: input and output arrays must have the same rank, but got instead: x rank = %i, y rank = %i, z rank = %i !", xRank, yRank, zRank);
+REQUIRE_TRUE(x->sizeAt(xLastDim) == y->sizeAt(yLastButOneDim) && x->sizeAt(xLastButOneDim) == z->sizeAt(-2) && y->sizeAt(yLastDim) == z->sizeAt(-1), 0, "MATMUL OP: input/output arrays have inconsistent shapes for matrix product: x %s, y %s, z %s !", ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str(), ShapeUtils::shapeAsString(z).c_str());
if (xRank > 2) // outer dims must be the same
for (int i = 0; i < xRank - 2; ++i)
-REQUIRE_TRUE(x->sizeAt(i) == y->sizeAt(i) && y->sizeAt(i) == z->sizeAt(i), 0,
-"MATMUL OP: input/output arrays have inconsistent shapes for matrix product: x %s, y %s, z %s !",
-ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str(),
-ShapeUtils::shapeAsString(z).c_str());
+REQUIRE_TRUE(x->sizeAt(i) == y->sizeAt(i) && y->sizeAt(i) == z->sizeAt(i), 0, "MATMUL OP: input/output arrays have inconsistent shapes for matrix product: x %s, y %s, z %s !", ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str(), ShapeUtils::shapeAsString(z).c_str());
}
// ******* end of input validation ******* //
MmulHelper::matmul(x, y, z, transX, transY);
return Status::OK();
}
DECLARE_SYN(mMul, matmul);
DECLARE_SYN(mmul, matmul);
DECLARE_SYN(gemm, matmul);
DECLARE_SYN(gemv, matmul);
DECLARE_SYN(dot, matmul);
//////////////////////////////////////////////////////////////////////
DECLARE_SHAPE_FN(matmul) {
auto xShapeInfo = inputShape->at(0);
auto yShapeInfo = inputShape->at(1);
@ -144,17 +129,18 @@ namespace nd4j {
auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(dtypeZ, zOrder, zShapeOnly);
return SHAPELIST(newShape);
}
//////////////////////////////////////////////////////////////////////
DECLARE_TYPES(matmul) {
getOpDescriptor()
->setAllowedInputTypes(0, {ALL_FLOATS})
->setAllowedInputTypes(1, {ALL_FLOATS})
->setAllowedOutputTypes(0, {ALL_FLOATS});
}
//////////////////////////////////////////////////////////////////////
CUSTOM_OP_IMPL(matmul_bp, 3, 2, false, 0, -2) {
auto x = INPUT_VARIABLE(0);
auto y = INPUT_VARIABLE(1);
auto eps = INPUT_VARIABLE(2);
@ -182,10 +168,10 @@ F F T [a,b] [b,c] [c,a] [c,a]
op.execute({x, eps}, {dldy}, {}, {!transX, transZ, transY}, {});
return Status::OK();
}
//////////////////////////////////////////////////////////////////////
DECLARE_SHAPE_FN(matmul_bp) {
Nd4jLong *xShapeInfo;
Nd4jLong *yShapeInfo;
@ -193,18 +179,19 @@ F F T [a,b] [b,c] [c,a] [c,a]
COPY_SHAPE(inputShape->at(1), yShapeInfo);
return SHAPELIST(CONSTANT(xShapeInfo), CONSTANT(yShapeInfo));
}
//////////////////////////////////////////////////////////////////////
DECLARE_TYPES(matmul_bp) {
getOpDescriptor()
->setAllowedInputTypes(0, {ALL_FLOATS})
->setAllowedInputTypes(1, {ALL_FLOATS})
->setAllowedInputTypes(2, {ALL_FLOATS})
->setAllowedOutputTypes(0, {ALL_FLOATS})
->setAllowedOutputTypes(1, {ALL_FLOATS});
}
}
}
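Note: the validation above enforces the usual matmul shape rules under the optional transpose flags. A small standalone check of the rank-2 rule with hypothetical shapes (plain C++, not the op's actual shape function):

#include <cassert>
#include <cstdio>
#include <vector>

// expected output shape of matmul(x, y) where x is [m,k] (or [k,m] if transX)
// and y is [k,n] (or [n,k] if transY) -- mirrors the xLastDim/yLastButOneDim check
std::vector<long long> matmulShape(std::vector<long long> x, std::vector<long long> y,
                                   bool transX, bool transY) {
    long long m  = transX ? x[1] : x[0];
    long long kx = transX ? x[0] : x[1];
    long long ky = transY ? y[1] : y[0];
    long long n  = transY ? y[0] : y[1];
    assert(kx == ky && "inner dimensions must match");
    return {m, n};
}

int main() {
    auto z = matmulShape({3, 4}, {5, 4}, /*transX*/ false, /*transY*/ true);
    printf("[3,4] x [5,4]^T -> [%lld,%lld]\n", z[0], z[1]);   // [3,5]
    return 0;
}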

View File

@ -21,17 +21,22 @@
#include <op_boilerplate.h>
#if NOT_EXCLUDED(OP_tensormmul)
#include <numeric>
#include <helpers/ShapeUtils.h>
#include <ops/declarable/CustomOperations.h>
#include <MmulHelper.h>
namespace nd4j {
namespace ops {
////////////////////////////////////////////////////////////////////////
CUSTOM_OP_IMPL(tensormmul, 2, 1, false, 0, -1) {
auto a = INPUT_VARIABLE(0);
auto b = INPUT_VARIABLE(1);
-auto c = OUTPUT_VARIABLE(0);
+// auto c = OUTPUT_VARIABLE(0);
REQUIRE_TRUE(a->dataType() == b->dataType(), 0, "tensormmul: A, B and C data types must be the same");
@ -40,20 +45,20 @@ namespace nd4j {
int axe1_size = INT_ARG(axe0_size+1);
std::vector<int> axes_0(axe0_size), axes_1(axe1_size);
for (int e = 0; e < axe0_size; e++)
-axes_0[e] = (int) INT_ARG(e+1);
+axes_0[e] = (int)INT_ARG(e + 1);
for (int e = 0; e < axe1_size; e++)
-axes_1[e] = (int) INT_ARG(e + axe0_size + 2);
+axes_1[e] = (int)INT_ARG(e + axe0_size + 2);
nd4j_verbose("axe0: %i; axe1: %i;\n", axes_0.size(), axes_1.size());
MmulHelper::tensorDot(a, b, c, axes_0, axes_1);
return Status::OK();
}
DECLARE_SYN(tensordot, tensormmul);
////////////////////////////////////////////////////////////////////////
DECLARE_SHAPE_FN(tensormmul) {
auto aShapeInfo = inputShape->at(0);
auto bShapeInfo = inputShape->at(1);
@ -76,15 +81,114 @@ namespace nd4j {
auto outShape = nd4j::ShapeUtils::evalShapeForTensorDot(aShapeInfo, bShapeInfo, axes_0, axes_1, permutAt, permutBt, shapeAt, shapeBt);
return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(ArrayOptions::dataType(aShapeInfo), 'c', outShape)));
}
////////////////////////////////////////////////////////////////////////
DECLARE_TYPES(tensormmul) {
getOpDescriptor()
->setAllowedInputTypes(0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF})
->setAllowedInputTypes(1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF})
->setAllowedOutputTypes(0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF});
}
////////////////////////////////////////////////////////////////////////
CUSTOM_OP_IMPL(tensormmul_bp, 3, 2, false, 0, -1) {
auto A = INPUT_VARIABLE(0);
auto B = INPUT_VARIABLE(1);
auto dLdC = INPUT_VARIABLE(2);
auto dLdA = OUTPUT_VARIABLE(0);
auto dLdB = OUTPUT_VARIABLE(1);
REQUIRE_TRUE( (A->dataType() == B->dataType() && (dLdC->dataType() == A->dataType())), 0, "tensormmul_bp: A, B and dLdC data types must be the same");
int axe0Size = INT_ARG(0);
int axe1Size = INT_ARG(axe0Size + 1);
auto Arank = A->rankOf();
auto Brank = B->rankOf();
auto dLdCrank = dLdC->rankOf();
REQUIRE_TRUE((Arank >= axe0Size), 0, "tensormmul_bp: A rank must be the higher or same as input axes 0");
REQUIRE_TRUE((Brank >= axe1Size), 0, "tensormmul_bp: B rank must be the higher or same as input axes 1");
// building axes
std::vector<int> axes0(axe0Size), axes1(axe1Size);
for (uint e = 0; e < axe0Size; e++)
axes0[e] = (int)INT_ARG(e + 1);
for (uint e = 0; e < axe1Size; e++)
axes1[e] = (int)INT_ARG(e + axe0Size + 2);
std::vector<int> permutAt, permutBt;
std::vector<Nd4jLong> shapeAt, shapeBt;
ShapeUtils::evalShapeForTensorDot(A, B, axes0, axes1, permutAt, permutBt, shapeAt, shapeBt);
// special case for scalar value
if (dLdC->isScalar()) {
dLdA->assign((*dLdC) * *B);
dLdB->assign((*dLdC) * *A);
return Status::OK();
}
std::vector<int> axesA = ShapeUtils::evalDimsToExclude(Arank, axes0);
std::vector<int> axesB = ShapeUtils::evalDimsToExclude(Brank, axes1);
// rank always have to be divided by 2
std::vector<int> axesAdLdC, axesBdLdC;
if (dLdCrank > 1) {
axesAdLdC.resize(dLdCrank / 2);
std::iota(axesAdLdC.begin(), axesAdLdC.end(), 0);
axesBdLdC = ShapeUtils::evalDimsToExclude(dLdCrank, axesAdLdC);
}
else {
axesAdLdC.push_back(0);
axesBdLdC.push_back(0);
}
// calculate dLdA
MmulHelper::tensorDot(dLdC, B, dLdA, axesBdLdC, axesB, permutAt);
// calculate dLdB
MmulHelper::tensorDot(A, dLdC, dLdB, axesA, axesAdLdC, permutBt);
return Status::OK();
}
////////////////////////////////////////////////////////////////////////
DECLARE_SHAPE_FN(tensormmul_bp) {
auto aShapeInfo = inputShape->at(0);
auto bShapeInfo = inputShape->at(1);
auto dLShapeInfo = inputShape->at(2);
REQUIRE_TRUE((ArrayOptions::dataType(aShapeInfo) == ArrayOptions::dataType(bShapeInfo) &&
(ArrayOptions::dataType(dLShapeInfo) == ArrayOptions::dataType(aShapeInfo))), 0, "tensormmul_bp: A, B and dLdC data types must be the same");
Nd4jLong* dLdAShapeInfo = nullptr;
Nd4jLong* dLdBShapeInfo = nullptr;
COPY_SHAPE(aShapeInfo, dLdAShapeInfo);
COPY_SHAPE(bShapeInfo, dLdBShapeInfo);
return SHAPELIST(CONSTANT(dLdAShapeInfo), CONSTANT(dLdBShapeInfo));
}
////////////////////////////////////////////////////////////////////////
DECLARE_TYPES(tensormmul_bp) {
getOpDescriptor()
->setAllowedInputTypes(0, { DataType::FLOAT32, DataType::DOUBLE, DataType::HALF }) // maybe better ALL_FLOATS
->setAllowedInputTypes(1, { DataType::FLOAT32, DataType::DOUBLE, DataType::HALF })
->setAllowedInputTypes(2, { DataType::FLOAT32, DataType::DOUBLE, DataType::HALF })
->setAllowedOutputTypes(0, { DataType::FLOAT32, DataType::DOUBLE, DataType::HALF })
->setAllowedOutputTypes(1, { DataType::FLOAT32, DataType::DOUBLE, DataType::HALF });
}
}
}
#endif
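Note: for the plain matrix case (contracting A's last axis with B's first), the two tensorDot calls in tensormmul_bp reduce to the familiar dL/dA = dL/dC * B^T and dL/dB = A^T * dL/dC. A standalone numeric check of that identity on small hypothetical matrices (plain C++, illustrative only):

#include <cstdio>

int main() {
    // A is [2,3], B is [3,2], C = A*B is [2,2]; dLdC is all ones
    double A[2][3]    = {{1, 2, 3}, {4, 5, 6}};
    double B[3][2]    = {{1, 0}, {0, 1}, {1, 1}};
    double dLdC[2][2] = {{1, 1}, {1, 1}};

    // dLdA[i][k] = sum_j dLdC[i][j] * B[k][j]   (i.e. dLdC * B^T)
    double dLdA[2][3] = {};
    for (int i = 0; i < 2; i++)
        for (int k = 0; k < 3; k++)
            for (int j = 0; j < 2; j++)
                dLdA[i][k] += dLdC[i][j] * B[k][j];

    // dLdB[k][j] = sum_i A[i][k] * dLdC[i][j]   (i.e. A^T * dLdC)
    double dLdB[3][2] = {};
    for (int k = 0; k < 3; k++)
        for (int j = 0; j < 2; j++)
            for (int i = 0; i < 2; i++)
                dLdB[k][j] += A[i][k] * dLdC[i][j];

    printf("dLdA[0] = %g %g %g\n", dLdA[0][0], dLdA[0][1], dLdA[0][2]);
    printf("dLdB[0] = %g %g\n", dLdB[0][0], dLdB[0][1]);
    return 0;
}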

View File

@ -79,7 +79,7 @@ CUSTOM_OP_IMPL(conv1d, 2, 1, false, 0, 5) {
}
auto inputReshaped = input ->reshape(input->ordering(), reshapeForInput);
-auto outputReshaped = output ->reshape(output->ordering(), reshapeForOutput);
+auto outputReshaped = output ->reshape(output->ordering(), reshapeForOutput, false);
auto weightsReshaped = weights->reshape(weights->ordering(), {1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}); // [kW, iC, oC] -> [1, kW, iC, oC]
nd4j::ops::conv2d conv2d;
@ -216,10 +216,10 @@ CUSTOM_OP_IMPL(conv1d_bp, 3, 2, false, 0, 5) {
}
auto inputReshaped = input ->reshape(input->ordering(), reshapeForInput);
-auto gradIReshaped = gradI ->reshape(gradI->ordering(), reshapeForInput);
+auto gradIReshaped = gradI ->reshape(gradI->ordering(), reshapeForInput, false);
auto gradOReshaped = gradO ->reshape(gradO->ordering(), reshapeForGradO);
auto weightsReshaped = weights->reshape(weights->ordering(),{1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}); // [kW, iC, oC] -> [1, kW, iC, oC]
-auto gradWReshaped = gradW ->reshape(gradW->ordering(), {1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}); // [kW, iC, oC] -> [1, kW, iC, oC]
+auto gradWReshaped = gradW ->reshape(gradW->ordering(), {1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}, false); // [kW, iC, oC] -> [1, kW, iC, oC]
nd4j::ops::conv2d_bp conv2dBP;
auto status = conv2dBP.execute({&inputReshaped, &weightsReshaped, bias, &gradOReshaped}, {&gradIReshaped, &gradWReshaped, gradB}, {}, {1,kW, 1,sW, 0,pW, 1,dW, paddingMode, !isNCW}, {});
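Note: conv1d reuses conv2d by inserting a dummy height dimension of size 1 into the input, output and weight shapes; the extra boolean passed to reshape in this change only affects how the reshaped view is materialised, not the resulting shapes. A sketch of the shape bookkeeping with hypothetical NWC sizes (plain C++, illustrative only):

#include <cstdio>
#include <vector>

int main() {
    // hypothetical conv1d tensors in NWC layout
    std::vector<long long> input   = {8, 32, 16};   // [bS, iW, iC]
    std::vector<long long> weights = {5, 16, 24};   // [kW, iC, oC]

    // the op reshapes them to 4D so conv2d (NHWC) can be reused
    std::vector<long long> input4d   = {input[0], 1, input[1], input[2]};        // [bS, 1, iW, iC]
    std::vector<long long> weights4d = {1, weights[0], weights[1], weights[2]};  // [1, kW, iC, oC]

    printf("input   %lldx%lldx%lld -> %lldx%lldx%lldx%lld\n",
           input[0], input[1], input[2], input4d[0], input4d[1], input4d[2], input4d[3]);
    printf("weights %lldx%lldx%lld -> %lldx%lldx%lldx%lld\n",
           weights[0], weights[1], weights[2], weights4d[0], weights4d[1], weights4d[2], weights4d[3]);
    return 0;
}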

View File

@ -239,7 +239,7 @@ CUSTOM_OP_IMPL(conv3dnew_bp, 3, 2, false, 0, 13) {
//----- calculation of gradO -----//
if(gradB) {
if(gradB->rankOf() == 2)
-gradB = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()}));
+gradB = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()}, false));
gradO->reduceAlongDimension(reduce::Sum, *gradB, gradOaxesForDot); // sum over bS oD oH oW
if(gradB != OUTPUT_VARIABLE(2))
delete gradB;

View File

@ -233,7 +233,7 @@ CUSTOM_OP_IMPL(deconv2d_bp, 3, 2, false, 0, 9) {
// ----- calculation of gradB ----- //
if(gradB) {
if(gradB->rankOf() == 2)
-gradB = new NDArray(gradB->reshape(gradB->ordering(), {gradB->lengthOf()}));
+gradB = new NDArray(gradB->reshape(gradB->ordering(), {gradB->lengthOf()}, false));
gradO->reduceAlongDimension(reduce::Sum, *gradB, {0, 2, 3}); // sum over bS, oH, oW
if(gradB != OUTPUT_VARIABLE(2))
delete gradB;

View File

@ -243,7 +243,7 @@ CUSTOM_OP_IMPL(deconv3d_bp, 3, 2, false, 0, 13) {
// ----- calculation of gradB ----- //
if(gradB) {
if(gradB->rankOf() == 2)
-gradB = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()}));
+gradB = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()}, false));
gradO->reduceAlongDimension(reduce::Sum, *gradB, {0, 2, 3, 4}); // sum over bS, oD, oH, oW
if(gradB != OUTPUT_VARIABLE(2))
delete gradB;

View File

@ -31,22 +31,17 @@ namespace nd4j {
REQUIRE_TRUE(w->isMatrix(), 0, "relu_layer: weights argument should be a 2D tensor, but got rank %i instead!", w->rankOf());
REQUIRE_TRUE(b->isVector(), 0, "relu_layer: biases argument should be a 1D tensor, but got rank %i instead!", b->rankOf());
REQUIRE_TRUE(b->lengthOf() == w->sizeAt(1), 0, "relu_layer: biases array length should match to columns of weights matrix, however got length = %i and columns = %i!", b->lengthOf(), w->sizeAt(1));
-REQUIRE_TRUE(x->sizeAt(1) == w->sizeAt(0), 0, "relu_layer: number of x columns should match to row number of weights matrix, but got x_columns = %i and weights_rows = %i!",
-x->sizeAt(1), w->sizeAt(0));
+REQUIRE_TRUE(x->sizeAt(1) == w->sizeAt(0), 0, "relu_layer: number of x columns should match to row number of weights matrix, but got x_columns = %i and weights_rows = %i!", x->sizeAt(1), w->sizeAt(0));
auto output = OUTPUT_VARIABLE(0);
-//T bound = (T)0.f;
-//nd4j_printf("Matrix x(%ix%i), Matrix w(%ix%i), b(1x%i)\n", x->sizeAt(0), x->sizeAt(1), w->sizeAt(0), w->sizeAt(1), b->lengthOf());
nd4j::ops::xw_plus_b op;
-std::unique_ptr<ResultSet> result(op.evaluate({x, w, b}));
+auto status = op.execute({x, w, b}, {output});
-REQUIRE_TRUE(Status::OK() == result->status(), 0, "relu_layer: xw_plus_b op failed on input data.");
+REQUIRE_TRUE(Status::OK() == status, 0, "relu_layer: xw_plus_b op failed on input data.");
auto scalar = block.numT() > 0 ? block.getTArguments()->at(0) : 0.0;
-auto xw = result->at(0);
-xw->applyScalar(nd4j::scalar::RELU, scalar, *output);
+output->applyScalar(nd4j::scalar::RELU, scalar, *output);
return Status::OK();
}
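Note: after this change relu_layer computes x*W + b straight into the output buffer and then applies the ReLU threshold in place, instead of going through an intermediate result set. A standalone numeric sketch of relu(x*W + b) with hypothetical data, treating the optional scalar simply as the default 0 cutoff (plain C++, illustrative only):

#include <algorithm>
#include <cstdio>

int main() {
    // x: [1,2], W: [2,2], b: [2]; threshold (the op's optional T arg) left at 0
    double x[2]    = {1.0, -2.0};
    double W[2][2] = {{0.5, 1.0}, {1.5, -1.0}};
    double b[2]    = {0.1, 0.2};
    double threshold = 0.0;

    for (int j = 0; j < 2; j++) {
        double pre = b[j];
        for (int i = 0; i < 2; i++)
            pre += x[i] * W[i][j];                 // xw_plus_b
        double out = std::max(pre, threshold);     // standard ReLU, applied in place
        printf("out[%d] = %g\n", j, out);
    }
    return 0;
}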

View File

@ -23,7 +23,8 @@
//#include <ops/declarable/headers/parity_ops.h>
#include <ops/declarable/CustomOperations.h>
-#include <ops/declarable/helpers/image_resize.h>
+#include <ops/declarable/helpers/crop_and_resize.h>
namespace nd4j {
namespace ops {
CUSTOM_OP_IMPL(crop_and_resize, 4, 1, false, 0, 0) {

View File

@ -61,7 +61,7 @@ namespace nd4j {
}
auto source = inRank == 4?image->reshape(image->ordering(), {image->sizeAt(0), image->sizeAt(1), image->sizeAt(2), image->sizeAt(3)}):image->reshape(image->ordering(), {1, image->sizeAt(0), image->sizeAt(1), image->sizeAt(2)});
-auto target = inRank == 4?output->reshape(output->ordering(), {output->sizeAt(0), output->sizeAt(1), output->sizeAt(2), output->sizeAt(3)}):output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)});
+auto target = inRank == 4?output->reshape(output->ordering(), {output->sizeAt(0), output->sizeAt(1), output->sizeAt(2), output->sizeAt(3)}, false) : output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)}, false);
return helpers::resizeAreaFunctor(block.launchContext(), &source, width, height, alignCorners, &target);
}

View File

@ -62,7 +62,7 @@ namespace nd4j {
REQUIRE_TRUE(!halfPixelAlign || (halfPixelAlign && !alignCorners), 0, "resize_bicubic: `half_pixel_centers' should be false or true only when `align_corners' is false");
auto source = inRank == 4?image->reshape(image->ordering(), {image->sizeAt(0), image->sizeAt(1), image->sizeAt(2), image->sizeAt(3)}):image->reshape(image->ordering(), {1, image->sizeAt(0), image->sizeAt(1), image->sizeAt(2)});
-auto target = inRank == 4?output->reshape(output->ordering(), {output->sizeAt(0), output->sizeAt(1), output->sizeAt(2), output->sizeAt(3)}):output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)});
+auto target = inRank == 4?output->reshape(output->ordering(), {output->sizeAt(0), output->sizeAt(1), output->sizeAt(2), output->sizeAt(3)}, false) : output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)}, false);
return helpers::resizeBicubicFunctorA(block.launchContext(), &source, width, height, alignCorners, halfPixelAlign, &target);
}

View File

@ -43,7 +43,7 @@ namespace nd4j {
REQUIRE_TRUE(inRank == output->rankOf(), 0, "resize_bilinear: Input and output ranks should be equals, but %i and %i occured.", inRank, output->rankOf());
auto source = inRank == 4?image->reshape(image->ordering(), {image->sizeAt(0), image->sizeAt(1), image->sizeAt(2), image->sizeAt(3)}):image->reshape(image->ordering(), {1, image->sizeAt(0), image->sizeAt(1), image->sizeAt(2)});
-auto target = inRank == 4?output->reshape(output->ordering(), {output->sizeAt(0), output->sizeAt(1), output->sizeAt(2), output->sizeAt(3)}):output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)});
+auto target = inRank == 4?output->reshape(output->ordering(), {output->sizeAt(0), output->sizeAt(1), output->sizeAt(2), output->sizeAt(3)}, false) : output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)}, false);
if (block.width() > 1) {
auto newImageSize = INPUT_VARIABLE(1);

View File

@ -63,7 +63,7 @@ namespace nd4j {
REQUIRE_TRUE(((alignCorners && height > 2) || (height > 0)) && ((alignCorners && width > 1) || (width > 0)), 0, "resize_nearest_neighbor: Wrong input or output size to resize (width = %d, height = %d)", width, height);
auto source = inRank == 4?*image:image->reshape(image->ordering(), {1, image->sizeAt(0), image->sizeAt(1), image->sizeAt(2)});
-auto target = inRank == 4?*output:output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)});
+auto target = inRank == 4 ? *output : output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)}, false);
return helpers::resizeNeighborFunctor(block.launchContext(), inRank==4?image:&source, width, height, alignCorners, halfPixelCenter, inRank == 4 ? output : &target);
}

View File

@ -47,11 +47,12 @@ namespace nd4j {
shape.insert(shape.begin() + axis, 1);
if (input->ews() == 1 && output->ews() == 1 && input->ordering() == output->ordering()) {
output->dataBuffer()->copyBufferFrom(*input->dataBuffer().get(), output->lengthOf() * DataTypeUtils::sizeOfElement(output->dataType()), 0, input->bufferOffset());
} else {
auto tmp = input->reshape(input->ordering(), shape);
output->assign(tmp);
}
-STORE_RESULT(output);
return Status::OK();
}
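Note: inserting a size-1 dimension never changes the linear order of elements, which is presumably why the new fast path above can copy the raw buffer for contiguous (ews == 1, same ordering) arrays and fall back to reshape+assign otherwise. A tiny standalone illustration that only the shape changes (plain C++, hypothetical sizes):

#include <cstdio>
#include <vector>

int main() {
    // a 'c'-ordered [2,3] array, flattened
    std::vector<float> buf = {1, 2, 3, 4, 5, 6};
    std::vector<long long> shape    = {2, 3};
    std::vector<long long> expanded = {2, 1, 3};   // axis = 1 inserted

    // element count and linear layout are untouched, only the shape changes,
    // so a plain buffer copy suffices on the fast path
    long long n = 1, m = 1;
    for (auto s : shape)    n *= s;
    for (auto s : expanded) m *= s;
    printf("elements before/after: %lld / %lld, first element %g\n", n, m, buf[0]);
    return 0;
}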

View File

@ -15,7 +15,8 @@
******************************************************************************/
//
-// Created by raver119 on 29/10/17.
+// @author raver119@gmail.com
+// @author Yurii Shyrma (iuriish@yahoo.com)
//
#include <op_boilerplate.h>
@ -29,80 +30,52 @@ namespace nd4j {
//////////////////////////////////////////////////////////////////////////
// here iArgs is int vector of ordered set of dimensions to be permuted
CUSTOM_OP_IMPL(permute, 1, 1, true, 0, -2) {
auto x = INPUT_VARIABLE(0);
-bool replace = false;
-auto origArgs = block.width() > 1 ? INPUT_VARIABLE(1)->asVectorT<int>() : *block.getIArguments();
-std::vector<int> arguments({});
-if(origArgs.size() > 0){
-for (int e = 0; e < origArgs.size(); e++) {
-int ax = origArgs[e];
-if (ax < 0)
-ax += x->rankOf();
-arguments.emplace_back(ax);
-}
-replace = true;
-} else {
-for (int e = x->rankOf() - 1; e >= 0; e--)
-arguments.emplace_back(e);
-}
-// 0D edge case
-if (x->rankOf() == 0) {
-REQUIRE_TRUE(arguments.size() == 1, 0, "Permute: only one axis is allowed for scalar");
-auto output = OUTPUT_VARIABLE(0);
-if (!block.isInplace())
-output->assign(x);
-return Status::OK();
-}
-if(block.isInplace()) { // in-place
-x->permutei(arguments);
-STORE_RESULT(x);
-} else {
-auto output = OUTPUT_VARIABLE(0);
-auto result = x->permute(arguments);
-output->assign(result);
-STORE_RESULT(output);
-}
-return Status::OK();
-}
+auto z = OUTPUT_VARIABLE(0);
+if (x->isEmpty()) {
+REQUIRE_TRUE(z->isEmpty(), 0, "PERMUTE OP: when input is empty, output must also be empty");
+return Status::OK(); //No op
+}
+if (block.width() == 1 && block.getIArguments()->size() == 0) {
+z->assign(x->transpose());
+return Status::OK();
+}
+std::vector<int> permutationVector = block.width() > 1 ? INPUT_VARIABLE(1)->asVectorT<int>() : *block.getIArguments();
+z->assign(x->permute(permutationVector));
+return Status::OK();
+}
//////////////////////////////////////////////////////////////////////////
DECLARE_TYPES(permute) {
getOpDescriptor()
->setAllowedInputTypes(0, nd4j::DataType::ANY)
->setAllowedInputTypes(1, {ALL_INTS})
->setSameMode(true);
}
//////////////////////////////////////////////////////////////////////////
DECLARE_SHAPE_FN(permute) {
-auto shapeList = SHAPELIST();
-auto arguments = block.width() > 1 ? INPUT_VARIABLE(1)->asVectorT<int>() : *block.getIArguments();
-if (shape::rank(inputShape->at(0)) == 0) {
-shapeList->push_back(ConstantShapeHelper::getInstance()->scalarShapeInfo(ArrayOptions::dataType(inputShape->at(0))));
-} else if (inputShape->size() == 1 && !arguments.empty()) {
-shapeList->push_back(ShapeUtils::evalPermShapeInfo(arguments.data(), arguments.size(), *INPUT_VARIABLE(0), block.workspace()));
-} else {
-if(arguments.size() == 0){
-//Reverse dimensions
-int rank = shape::rank(inputShape->at(0));
-for (int e = rank - 1; e >= 0; e--)
-arguments.emplace_back(e);
-}
-shapeList->push_back(ShapeUtils::evalPermShapeInfo(arguments.data(), arguments.size(), *INPUT_VARIABLE(0), block.workspace()));
-}
-return shapeList;
+auto x = INPUT_VARIABLE(0);
+if (block.width() == 1 && block.getIArguments()->size() == 0)
+return SHAPELIST(ShapeUtils::evalTranspShapeInfo(*x, block.workspace(), true));
+std::vector<int> permutationVector = block.width() > 1 ? INPUT_VARIABLE(1)->asVectorT<int>() : *block.getIArguments();
+auto outputShapeInfo = ShapeUtils::evalPermShapeInfo(permutationVector.data(), x->rankOf(), *x, block.workspace(), true);
+return SHAPELIST(outputShapeInfo);
}
}
}
#endif
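Note: with no permutation argument the rewritten op falls back to a full transpose (all dimensions reversed); otherwise the output is the input permuted by the given axis vector. A standalone sketch of how a permutation vector rearranges a hypothetical shape (plain C++, illustrative only):

#include <cstdio>
#include <vector>

int main() {
    std::vector<long long> shape = {2, 3, 5};   // hypothetical input shape
    std::vector<int> perm = {1, 2, 0};          // iArgs / second input

    // default (no arguments) behaves like transpose: reverse the dimensions
    std::vector<int> defaultPerm(shape.size());
    for (size_t e = 0; e < shape.size(); e++)
        defaultPerm[e] = (int) shape.size() - 1 - (int) e;

    auto apply = [&](const std::vector<int>& p) {
        std::vector<long long> out(shape.size());
        for (size_t e = 0; e < p.size(); e++)
            out[e] = shape[p[e]];
        return out;
    };

    auto a = apply(perm);          // {3, 5, 2}
    auto b = apply(defaultPerm);   // {5, 3, 2}
    printf("permute {1,2,0}: [%lld,%lld,%lld]\n", a[0], a[1], a[2]);
    printf("no args (transpose): [%lld,%lld,%lld]\n", b[0], b[1], b[2]);
    return 0;
}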

View File

@ -24,23 +24,29 @@
#include <ops/declarable/CustomOperations.h>
namespace nd4j {
namespace ops {
-//////////////////////////////////////////////////////////////////////////
-// here iArgs is a vector with (optional) negative of order as first element:
-// ({-order, dim1, dim2, dim3, ...})
-CUSTOM_OP_IMPL(reshape, 1, 1, true, 0, -2) {
-auto x = INPUT_VARIABLE(0);
-if (block.width() == 1) {
-auto arguments = block.getIArguments();
-int argsSize = arguments->size();
+//////////////////////////////////////////////////////////////////////////
+// here iArgs is a vector with (optional) negative of order as first element:
+// ({-order, dim1, dim2, dim3, ...})
+CUSTOM_OP_IMPL(reshape, 1, 1, false, 0, -2) {
+auto x = INPUT_VARIABLE(0);
+auto z = OUTPUT_VARIABLE(0);
//Special case: empty.reshape(<other empty shape>) -> return empty
if (x->isEmpty()) {
-REQUIRE_TRUE(OUTPUT_VARIABLE(0)->isEmpty(), 0, "Reshape: when input is empty, output must also be empty");
+REQUIRE_TRUE(z->isEmpty(), 0, "Reshape: when input is empty, output must also be empty");
-return ND4J_STATUS_OK; //No op
+return Status::OK(); //No op
}
+if (block.width() == 1) {
+auto arguments = block.getIArguments();
+int argsSize = arguments->size();
int e = 1;
char order = (char) -(*arguments)[0];
if (order != 'c' && order != 'f') {
@ -77,28 +83,15 @@ namespace nd4j {
nd4j_printv("Reshape: new shape", shapeNew); nd4j_printv("Reshape: new shape", shapeNew);
} }
if (block.isInplace()) {
if (x->reshapei(order, shapeNew)) {
STORE_RESULT(*x);
return ND4J_STATUS_OK;
}
} else {
auto ret = OUTPUT_VARIABLE(0);
auto xr = x->reshape(order, shapeNew); auto xr = x->reshape(order, shapeNew);
ret->assign(xr); z->assign(xr);
STORE_RESULT(*ret); STORE_RESULT(*z);
return Status::OK(); return Status::OK();
}
} else if (block.width() == 2) {
auto s = INPUT_VARIABLE(1);
//Special case: empty.reshape(-1) -> return empty } else if (block.width() == 2) {
if (x->isEmpty()) {
//REQUIRE_TRUE(s->lengthOf() == 1 && s->e<Nd4jLong>(0) == -1, 0, "Reshape: when input is empty, shape must be [-1]"); auto s = INPUT_VARIABLE(1);
REQUIRE_TRUE(OUTPUT_VARIABLE(0)->isEmpty(), 0, "Reshape: when input is empty, output must also be empty");
return Status::OK(); //No op
}
char order = 'c'; char order = 'c';
if (block.numI() > 0) if (block.numI() > 0)
@ -129,37 +122,30 @@ namespace nd4j {
nd4j_printv("Reshape: new shape", shapeNew); nd4j_printv("Reshape: new shape", shapeNew);
} }
if (block.isInplace()) {
if (x->reshapei(order, shapeNew)) {
STORE_RESULT(*x);
return Status::OK();
}
} else {
auto ret = OUTPUT_VARIABLE(0);
if (s->isEmpty()) { if (s->isEmpty()) {
// just a scalar // just a scalar
ret->assign(x); z->assign(x);
} else { } else {
auto xr = x->reshape(order, shapeNew); auto xr = x->reshape(order, shapeNew);
ret->assign(xr); z->assign(xr);
} }
return Status::OK(); return Status::OK();
}
} }
return ND4J_STATUS_BAD_INPUT; return ND4J_STATUS_BAD_INPUT;
} }
DECLARE_TYPES(reshape) { DECLARE_TYPES(reshape) {
getOpDescriptor() getOpDescriptor()
->setAllowedInputTypes(0, nd4j::DataType::ANY) ->setAllowedInputTypes(0, nd4j::DataType::ANY)
->setAllowedInputTypes(1, {ALL_INTS}) ->setAllowedInputTypes(1, {ALL_INTS})
->setSameMode(true); ->setSameMode(true);
} }
DECLARE_SHAPE_FN(reshape) { DECLARE_SHAPE_FN(reshape) {
auto inp = inputShape->at(0); auto inp = inputShape->at(0);
// we can launch op using Int arguments // we can launch op using Int arguments
@ -270,8 +256,8 @@ namespace nd4j {
return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inp), 'c', shapeNew));
}
}
}
}
#endif
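Note: the target shape comes either from iArgs (with the negated order as the first element) or from a second input array, and by the usual reshape convention a single -1 entry is inferred from the remaining length (the empty-input comment above references -1; the inference itself is not visible in this hunk). A standalone sketch of that inference with a hypothetical shape (plain C++):

#include <cstdio>
#include <vector>

int main() {
    long long length = 24;                        // total number of elements
    std::vector<long long> requested = {2, -1, 3};

    long long known = 1;
    int unknownIdx = -1;
    for (size_t e = 0; e < requested.size(); e++) {
        if (requested[e] == -1) unknownIdx = (int) e;
        else known *= requested[e];
    }
    if (unknownIdx >= 0)
        requested[unknownIdx] = length / known;   // 24 / 6 = 4

    printf("resolved shape: [%lld,%lld,%lld]\n", requested[0], requested[1], requested[2]);
    return 0;
}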

View File

@ -28,18 +28,16 @@ namespace nd4j {
//////////////////////////////////////////////////////////////////////////
-CUSTOM_OP_IMPL(reshapeas, 2, 1, true, 0, 0) {
+CUSTOM_OP_IMPL(reshapeas, 2, 1, false, 0, 0) {
auto x = INPUT_VARIABLE(0);
auto y = INPUT_VARIABLE(1);
auto z = OUTPUT_VARIABLE(0);
-std::vector<Nd4jLong> shapeNew(y->shapeOf(), y->shapeOf() + y->rankOf());
-char order = y->ordering();
-if (x->reshapei(order, shapeNew)) {
+if (x->reshapei(y->ordering(), y->getShapeAsVector())) {
-*z = *x;
-STORE_RESULT(*z);
+z->assign(x);
return Status::OK();
}
@ -49,14 +47,8 @@ namespace nd4j {
DECLARE_SHAPE_FN(reshapeas) {
-auto inputShapeInfo = inputShape->at(1);
-int shapeInfoLength = inputShapeInfo[0]*2 + 4;
-Nd4jLong* outputShapeInfo(nullptr);
-COPY_SHAPE(inputShapeInfo, outputShapeInfo);
-return SHAPELIST(CONSTANT(outputShapeInfo));
-}
+return SHAPELIST(ShapeBuilders::copyShapeInfo(INPUT_VARIABLE(1)->getShapeInfo(), false, block.workspace()));
+}
DECLARE_TYPES(reshapeas) {
getOpDescriptor()

View File

@ -25,7 +25,7 @@
namespace nd4j {
namespace ops {
-CUSTOM_OP_IMPL(squeeze, 1, 1, true, 0, -2) {
+CUSTOM_OP_IMPL(squeeze, 1, 1, false, 0, -2) {
auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0);
@ -71,11 +71,15 @@ namespace nd4j {
}
if (block.isInplace()) {
-output->reshapei(input->ordering(), shape);
+output->reshapei(input->ordering(), shape, false);
} else {
+if (input->ews() == 1 && output->ews() == 1 && input->ordering() == output->ordering()) {
+output->dataBuffer()->copyBufferFrom(*input->dataBuffer().get(), output->lengthOf() * DataTypeUtils::sizeOfElement(output->dataType()), 0, input->bufferOffset());
+} else {
auto tmp = input->reshape(input->ordering(), shape);
output->assign(tmp);
}
+}
return Status::OK();
}

View File

@ -25,7 +25,7 @@
namespace nd4j {
namespace ops {
-CUSTOM_OP_IMPL(tile_to_shape, 1, 1, true, 0, -1) {
+CUSTOM_OP_IMPL(tile_to_shape, 1, 1, false, 0, -1) {
auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0);

View File

@ -15,7 +15,8 @@
******************************************************************************/
//
-// Created by raver119 on 29/10/17.
+// @author raver119@gmail.com
+// @author Yurii Shyrma (iuriish@yahoo.com)
//
#include <op_boilerplate.h>
@ -27,111 +28,50 @@
namespace nd4j {
namespace ops {
//////////////////////////////////////////////////////////////////////////
-CUSTOM_OP_IMPL(transpose, 1, 1, true, 0, 0) {
+CUSTOM_OP_IMPL(transpose, 1, 1, false, 0, 0) {
auto x = INPUT_VARIABLE(0);
-if (block.width() == 1) {
+auto z = OUTPUT_VARIABLE(0);
if (block.isInplace()) {
x->transposei();
STORE_RESULT(*x);
} else {
auto output = OUTPUT_VARIABLE(0);
auto t = x->transpose();
output->assign(t);
STORE_RESULT(*output);
}
} else {
// this is tf-mode transpose, that's nd4j permute
bool replace = false;
std::vector<int> arguments(*block.getIArguments());
-auto w = block.width();
-auto a = arguments.size();
-if (w == 2 && a == 0) {
+//Special case: empty.reshape(<other empty shape>) -> return empty
+if (x->isEmpty()) {
+REQUIRE_TRUE(z->isEmpty(), 0, "TRANSPOSE OP: when input is empty, output must also be empty");
+return Status::OK(); //No op
auto axis = INPUT_VARIABLE(1);
for (int e = 0; e < axis->lengthOf(); e++) {
auto ax = axis->e<int>(e);
if (ax < 0)
ax += x->rankOf();
arguments.emplace_back(ax);
} }
-replace = true;
-} else if (a == 0) {
+if (block.width() == 1 && block.getIArguments()->size() == 0) {
+z->assign(x->transpose());
for (int e = x->rankOf() - 1; e >= 0; e--)
arguments.emplace_back(e);
}
// 0D edge case
if (x->rankOf() == 0) {
REQUIRE_TRUE(arguments.size() == 1, 0, "Permute: only one axis is allowed for scalar");
auto output = OUTPUT_VARIABLE(0);
if (!block.isInplace())
output->assign(x);
return Status::OK();
}
-if(block.isInplace()) { // in-place
+std::vector<int> permutationVector = block.width() > 1 ? INPUT_VARIABLE(1)->asVectorT<int>() : *block.getIArguments();
x->permutei(arguments);
-STORE_RESULT(x);
+z->assign(x->permute(permutationVector));
} else {
auto input = x->permute(arguments);
auto output = OUTPUT_VARIABLE(0);
output->assign(input);
}
}
return Status::OK();
}
DECLARE_TYPES(transpose) {
getOpDescriptor()
->setAllowedInputTypes(nd4j::DataType::ANY)
->setSameMode(true);
}
DECLARE_SHAPE_FN(transpose) {
auto x = INPUT_VARIABLE(0);
if (block.width() == 1 && block.getIArguments()->size() == 0)
return SHAPELIST(ShapeUtils::evalTranspShapeInfo(*x, block.workspace(), true));
std::vector<int> permutationVector = block.width() > 1 ? INPUT_VARIABLE(1)->asVectorT<int>() : *block.getIArguments();
auto outputShapeInfo = ShapeUtils::evalPermShapeInfo(permutationVector.data(), x->rankOf(), *x, block.workspace(), true);
DECLARE_SHAPE_FN(transpose) {
if (block.width() == 1) {
auto outputShapeInfo = ShapeUtils::evalTranspShapeInfo(*INPUT_VARIABLE(0), block.workspace());
return SHAPELIST(outputShapeInfo);
-} else {
+}
// this is basically permute mode
auto shapeList = SHAPELIST();
auto arguments = block.getIArguments();
if (shape::rank(inputShape->at(0)) == 0) {
Nd4jLong *newshape;
ALLOCATE(newshape, block.getWorkspace(), shape::shapeInfoLength(inputShape->at(0)), Nd4jLong);
newshape[0] = 0;
newshape[1] = 0;
newshape[2] = 1;
newshape[3] = 99;
ArrayOptions::copyDataType(newshape, inputShape->at(0));
shapeList->push_back(newshape);
} else if (arguments->size() > 0 || inputShape->size() > 1) {
auto axis = arguments->size() > 0 ? *arguments : (INPUT_VARIABLE(1))->template asVectorT<int>();
auto outputShapeInfo = ShapeUtils::evalPermShapeInfo(axis.data(), axis.size(), *INPUT_VARIABLE(0), block.workspace());
shapeList->push_back(outputShapeInfo);
} else if (inputShape->size() == 2) {
// dead end
auto axis = INPUT_VARIABLE(1);
auto axisV = axis->template asVectorT<Nd4jLong>();
auto newshape = ShapeUtils::evalPermShapeInfo(axisV.data(), axisV.size(), *INPUT_VARIABLE(0), block.workspace());
shapeList->push_back(newshape);
} else {
int rank = shape::rank(inputShape->at(0));
for (int e = rank - 1; e >= 0; e--)
arguments->emplace_back(e);
auto outputShapeInfo = ShapeUtils::evalPermShapeInfo(arguments->data(), arguments->size(), *INPUT_VARIABLE(0), block.workspace());
shapeList->push_back(outputShapeInfo);
}
return shapeList;
}
}
} }
}
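Note: with no axis argument the rewritten transpose simply reverses all dimensions (for rank 2, rows and columns swap), and with an axis argument it behaves like permute above. A standalone rank-2 sketch of the index mapping with hypothetical data (plain C++, illustrative only):

#include <cstdio>

int main() {
    // x is [2,3] in 'c' order; its transpose is [3,2]
    double x[2][3] = {{1, 2, 3}, {4, 5, 6}};
    double z[3][2];

    for (int i = 0; i < 2; i++)
        for (int j = 0; j < 3; j++)
            z[j][i] = x[i][j];      // z = x->transpose()

    printf("z[0] = %g %g\n", z[0][0], z[0][1]);   // 1 4
    printf("z[2] = %g %g\n", z[2][0], z[2][1]);   // 3 6
    return 0;
}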

Some files were not shown because too many files have changed in this diff.