Merge pull request #8835 from KonduitAI/master

Merge recent development work
master
Alex Black 2020-04-14 22:35:49 +10:00 committed by GitHub
commit 09f4c21059
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
475 changed files with 28133 additions and 12409 deletions

View File

@ -96,6 +96,11 @@ public class TestGraphLocalExecution extends BaseDL4JTest {
Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT); Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT);
} }
/**
 * Overrides the timeout value reported by this test class: 90000 ms (90 seconds).
 * NOTE(review): presumably consumed by the BaseDL4JTest superclass to bound test run time - confirm.
 */
@Override
public long getTimeoutMilliseconds() {
return 90000L;
}
@Test @Test
public void testLocalExecutionDataSources() throws Exception { public void testLocalExecutionDataSources() throws Exception {
@ -204,7 +209,7 @@ public class TestGraphLocalExecution extends BaseDL4JTest {
OptimizationConfiguration configuration = new OptimizationConfiguration.Builder() OptimizationConfiguration configuration = new OptimizationConfiguration.Builder()
.candidateGenerator(candidateGenerator).dataProvider(dataProvider) .candidateGenerator(candidateGenerator).dataProvider(dataProvider)
.modelSaver(new FileModelSaver(modelSavePath)).scoreFunction(ScoreFunctions.testSetLoss(true)) .modelSaver(new FileModelSaver(modelSavePath)).scoreFunction(ScoreFunctions.testSetLoss(true))
.terminationConditions(new MaxTimeCondition(5, TimeUnit.SECONDS), .terminationConditions(new MaxTimeCondition(30, TimeUnit.SECONDS),
new MaxCandidatesCondition(3)) new MaxCandidatesCondition(3))
.build(); .build();
@ -251,7 +256,7 @@ public class TestGraphLocalExecution extends BaseDL4JTest {
.candidateGenerator(candidateGenerator) .candidateGenerator(candidateGenerator)
.dataProvider(new TestMdsDataProvider(1, 32)) .dataProvider(new TestMdsDataProvider(1, 32))
.modelSaver(new FileModelSaver(modelSavePath)).scoreFunction(ScoreFunctions.testSetLoss(true)) .modelSaver(new FileModelSaver(modelSavePath)).scoreFunction(ScoreFunctions.testSetLoss(true))
.terminationConditions(new MaxTimeCondition(5, TimeUnit.SECONDS), .terminationConditions(new MaxTimeCondition(30, TimeUnit.SECONDS),
new MaxCandidatesCondition(3)) new MaxCandidatesCondition(3))
.scoreFunction(ScoreFunctions.testSetAccuracy()) .scoreFunction(ScoreFunctions.testSetAccuracy())
.build(); .build();

View File

@ -49,7 +49,7 @@ check_cuda_version "$VERSION"
case $VERSION in case $VERSION in
10.2) 10.2)
VERSION2="7.6" VERSION2="7.6"
VERSION3="1.5.2" VERSION3="1.5.3"
;; ;;
10.1) 10.1)
VERSION2="7.6" VERSION2="7.6"

View File

@ -16,6 +16,7 @@
package org.datavec.api.transform.split; package org.datavec.api.transform.split;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
import lombok.Data; import lombok.Data;

View File

@ -72,11 +72,11 @@ public class LFWLoader extends BaseImageLoader implements Serializable {
protected File fullDir; protected File fullDir;
protected boolean useSubset = false; protected boolean useSubset = false;
InputSplit[] inputSplit; protected InputSplit[] inputSplit;
public static Map<String, String> lfwData = new HashMap<>(); public Map<String, String> lfwData = new HashMap<>();
public static Map<String, String> lfwLabel = new HashMap<>(); public Map<String, String> lfwLabel = new HashMap<>();
public static Map<String, String> lfwSubsetData = new HashMap<>(); public Map<String, String> lfwSubsetData = new HashMap<>();
public LFWLoader() { public LFWLoader() {
this(false); this(false);

View File

@ -45,15 +45,23 @@ import static org.junit.Assert.assertTrue;
*/ */
public class LoaderTests { public class LoaderTests {
/**
 * Exercises the CIFAR and LFW loaders once before a test runs.
 * Builds a CifarLoader for train=true and train=false and requests a small batch from each,
 * then builds an LFW record reader and reads a single record from it.
 * NOTE(review): the intent is that missing data gets fetched on first use; the automatic
 * download is stated for CifarLoader below - confirm LFWLoader behaves the same way.
 */
private static void ensureDataAvailable(){
//Ensure test resources available by initializing CifarLoader and relying on auto download
boolean preProcessCifar = false;
int numExamples = 10;
int row = 28;
int col = 28;
int channels = 1;
// Run once with train=true and once with train=false
for( boolean train : new boolean[]{true, false}){
CifarLoader loader = new CifarLoader(row, col, channels, train, preProcessCifar);
loader.next(numExamples);
}
// Same LFWLoader / getRecordReader arguments as used in testLfwReader; the returned record is discarded
new LFWLoader(new long[] {250, 250, 3}, true).getRecordReader(1, 1, 1, new Random(42)).next();
}
@Test @Test
public void testLfwReader() throws Exception { public void testLfwReader() throws Exception {
String subDir = "lfw-a/lfw"; RecordReader rr = new LFWLoader(new long[] {250, 250, 3}, true).getRecordReader(1, 1, 1, new Random(42));
File path = new File(FilenameUtils.concat(System.getProperty("user.home"), subDir));
FileSplit fileSplit = new FileSplit(path, LFWLoader.ALLOWED_FORMATS, new Random(42));
BalancedPathFilter pathFilter = new BalancedPathFilter(new Random(42), LFWLoader.LABEL_PATTERN, 1, 1, 1);
InputSplit[] inputSplit = fileSplit.sample(pathFilter, 1);
RecordReader rr = new ImageRecordReader(250, 250, 3, LFWLoader.LABEL_PATTERN);
rr.initialize(inputSplit[0]);
List<String> exptedLabel = rr.getLabels(); List<String> exptedLabel = rr.getLabels();
RecordReader rr2 = new LFWLoader(new long[] {250, 250, 3}, true).getRecordReader(1, 1, 1, new Random(42)); RecordReader rr2 = new LFWLoader(new long[] {250, 250, 3}, true).getRecordReader(1, 1, 1, new Random(42));
@ -63,6 +71,7 @@ public class LoaderTests {
@Test @Test
public void testCifarLoader() { public void testCifarLoader() {
ensureDataAvailable();
File dir = new File(FilenameUtils.concat(System.getProperty("user.home"), "cifar/cifar-10-batches-bin")); File dir = new File(FilenameUtils.concat(System.getProperty("user.home"), "cifar/cifar-10-batches-bin"));
CifarLoader cifar = new CifarLoader(false, dir); CifarLoader cifar = new CifarLoader(false, dir);
assertTrue(dir.exists()); assertTrue(dir.exists());
@ -71,6 +80,7 @@ public class LoaderTests {
@Test @Test
public void testCifarInputStream() throws Exception { public void testCifarInputStream() throws Exception {
ensureDataAvailable();
// check train // check train
String subDir = "cifar/cifar-10-batches-bin/data_batch_1.bin"; String subDir = "cifar/cifar-10-batches-bin/data_batch_1.bin";
String path = FilenameUtils.concat(System.getProperty("user.home"), subDir); String path = FilenameUtils.concat(System.getProperty("user.home"), subDir);

View File

@ -23,6 +23,11 @@ import org.junit.Test;
public class TestDataSets extends BaseDL4JTest { public class TestDataSets extends BaseDL4JTest {
/**
 * Overrides the timeout value reported by this test class: 180000 ms (180 seconds).
 * NOTE(review): presumably consumed by the BaseDL4JTest superclass to bound test run time - confirm.
 */
@Override
public long getTimeoutMilliseconds() {
return 180000L;
}
@Test @Test
public void testTinyImageNetExists() throws Exception { public void testTinyImageNetExists() throws Exception {
//Simple sanity check on extracting //Simple sanity check on extracting

View File

@ -130,14 +130,6 @@ public class SameDiffConv extends SameDiffLayer {
SDVariable w = paramTable.get(ConvolutionParamInitializer.WEIGHT_KEY); SDVariable w = paramTable.get(ConvolutionParamInitializer.WEIGHT_KEY);
SDVariable[] vars;
if(hasBias){
SDVariable b = paramTable.get(ConvolutionParamInitializer.BIAS_KEY);
vars = new SDVariable[]{layerInput, w, b};
} else {
vars = new SDVariable[]{layerInput, w};
}
Conv2DConfig c = Conv2DConfig.builder() Conv2DConfig c = Conv2DConfig.builder()
.kH(kernel[0]).kW(kernel[1]) .kH(kernel[0]).kW(kernel[1])
.pH(padding[0]).pW(padding[1]) .pH(padding[0]).pW(padding[1])
@ -146,7 +138,13 @@ public class SameDiffConv extends SameDiffLayer {
.isSameMode(this.cm == ConvolutionMode.Same) .isSameMode(this.cm == ConvolutionMode.Same)
.build(); .build();
SDVariable conv = sameDiff.cnn().conv2d(vars, c); //TODO can't set name SDVariable conv = null;
if(hasBias){
SDVariable b = paramTable.get(ConvolutionParamInitializer.BIAS_KEY);
conv = sameDiff.cnn().conv2d(layerInput, w, b, c);
} else {
conv = sameDiff.cnn().conv2d(layerInput, w, c);
}
return activation.asSameDiff("out", sameDiff, conv); return activation.asSameDiff("out", sameDiff, conv);
} }

View File

@ -44,6 +44,11 @@ import static org.junit.Assert.*;
public class TestCheckpointListener extends BaseDL4JTest { public class TestCheckpointListener extends BaseDL4JTest {
/**
 * Overrides the timeout value reported by this test class: 90000 ms (90 seconds).
 * NOTE(review): presumably consumed by the BaseDL4JTest superclass to bound test run time - confirm.
 */
@Override
public long getTimeoutMilliseconds() {
return 90000L;
}
@Rule @Rule
public TemporaryFolder tempDir = new TemporaryFolder(); public TemporaryFolder tempDir = new TemporaryFolder();
@ -57,7 +62,7 @@ public class TestCheckpointListener extends BaseDL4JTest {
MultiLayerNetwork net = new MultiLayerNetwork(conf); MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init(); net.init();
DataSetIterator iter = new IrisDataSetIterator(75,150); DataSetIterator iter = new IrisDataSetIterator(25,50);
return new Pair<>(net, iter); return new Pair<>(net, iter);
} }
@ -178,13 +183,13 @@ public class TestCheckpointListener extends BaseDL4JTest {
CheckpointListener l = new CheckpointListener.Builder(f) CheckpointListener l = new CheckpointListener.Builder(f)
.keepLast(3) .keepLast(3)
.saveEvery(3, TimeUnit.SECONDS) .saveEvery(4, TimeUnit.SECONDS)
.build(); .build();
net.setListeners(l); net.setListeners(l);
for(int i=0; i<5; i++ ){ //10 iterations total for(int i=0; i<5; i++ ){ //10 iterations total
net.fit(iter); net.fit(iter);
Thread.sleep(4000); Thread.sleep(5000);
} }
//Expect models saved at iterations: 2, 4, 6, 8 (iterations 0 and 1 should happen before first 3 seconds is up) //Expect models saved at iterations: 2, 4, 6, 8 (iterations 0 and 1 should happen before first 3 seconds is up)

View File

@ -54,6 +54,11 @@ import static org.junit.Assert.*;
@Slf4j @Slf4j
public class RegressionTest100a extends BaseDL4JTest { public class RegressionTest100a extends BaseDL4JTest {
/**
 * Overrides the timeout value reported by this test class: 90000 ms (90 seconds).
 */
@Override
public long getTimeoutMilliseconds() {
return 90000L; //Most tests should be fast, but slow download may cause timeout on slow connections
}
@Override @Override
public DataType getDataType(){ public DataType getDataType(){
return DataType.FLOAT; return DataType.FLOAT;

View File

@ -52,6 +52,11 @@ import static org.junit.Assert.*;
public class RegressionTest100b3 extends BaseDL4JTest { public class RegressionTest100b3 extends BaseDL4JTest {
/**
 * Overrides the timeout value reported by this test class: 90000 ms (90 seconds).
 */
@Override
public long getTimeoutMilliseconds() {
return 90000L; //Most tests should be fast, but slow download may cause timeout on slow connections
}
@Override @Override
public DataType getDataType(){ public DataType getDataType(){
return DataType.FLOAT; return DataType.FLOAT;

View File

@ -69,6 +69,11 @@ import org.nd4j.resources.Resources;
public class RegressionTest100b4 extends BaseDL4JTest { public class RegressionTest100b4 extends BaseDL4JTest {
/**
 * Overrides the timeout value reported by this test class: 90000 ms (90 seconds).
 */
@Override
public long getTimeoutMilliseconds() {
return 90000L; //Most tests should be fast, but slow download may cause timeout on slow connections
}
@Override @Override
public DataType getDataType() { public DataType getDataType() {
return DataType.FLOAT; return DataType.FLOAT;
@ -123,7 +128,8 @@ public class RegressionTest100b4 extends BaseDL4JTest {
assertEquals(dtype, net.getLayerWiseConfigurations().getDataType()); assertEquals(dtype, net.getLayerWiseConfigurations().getDataType());
assertEquals(dtype, net.params().dataType()); assertEquals(dtype, net.params().dataType());
assertEquals("Test for dtype: " + dtypeName, outExp, outAct); boolean eq = outExp.equalsWithEps(outAct, 0.01);
assertTrue("Test for dtype: " + dtypeName + "\n" + outExp + " vs " + outAct, eq);
} }
} }

View File

@ -56,6 +56,11 @@ public class RegressionTest100b6 extends BaseDL4JTest {
return DataType.FLOAT; return DataType.FLOAT;
} }
/**
 * Overrides the timeout value reported by this test class: 90000 ms (90 seconds).
 */
@Override
public long getTimeoutMilliseconds() {
return 90000L; //Most tests should be fast, but slow download may cause timeout on slow connections
}
@Test @Test
public void testCustomLayer() throws Exception { public void testCustomLayer() throws Exception {
@ -106,7 +111,8 @@ public class RegressionTest100b6 extends BaseDL4JTest {
assertEquals(dtype, net.getLayerWiseConfigurations().getDataType()); assertEquals(dtype, net.getLayerWiseConfigurations().getDataType());
assertEquals(dtype, net.params().dataType()); assertEquals(dtype, net.params().dataType());
boolean eq = outExp.equalsWithEps(outAct, 0.01); boolean eq = outExp.equalsWithEps(outAct, 0.01);
assertTrue(outExp + " vs " + outAct, eq); } assertTrue(outExp + " vs " + outAct, eq);
}
} }

View File

@ -28,7 +28,7 @@
<!-- CUDA version is linked with the artifact name so cannot move to parent pom.xml --> <!-- CUDA version is linked with the artifact name so cannot move to parent pom.xml -->
<cuda.version>10.2</cuda.version> <cuda.version>10.2</cuda.version>
<cudnn.version>7.6</cudnn.version> <cudnn.version>7.6</cudnn.version>
<javacpp-presets.cuda.version>1.5.2</javacpp-presets.cuda.version> <javacpp-presets.cuda.version>1.5.3</javacpp-presets.cuda.version>
</properties> </properties>
<dependencyManagement> <dependencyManagement>

View File

@ -96,6 +96,11 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
} }
}; };
/**
 * Overrides the timeout value reported by this test class: 90000 ms (90 seconds).
 * NOTE(review): presumably consumed by the BaseDL4JTest superclass to bound test run time - confirm.
 */
@Override
public long getTimeoutMilliseconds() {
return 90000L;
}
@Test(expected = IllegalStateException.class) @Test(expected = IllegalStateException.class)
public void fileNotFoundEndToEnd() throws Exception { public void fileNotFoundEndToEnd() throws Exception {
String modelPath = "modelimport/keras/examples/foo/bar.h5"; String modelPath = "modelimport/keras/examples/foo/bar.h5";

View File

@ -25,7 +25,7 @@ import org.nd4j.shade.jackson.annotation.JsonProperty;
*/ */
@Deprecated @Deprecated
@Getter @Getter
@EqualsAndHashCode @EqualsAndHashCode(callSuper = true)
public class EvaluationCalibration extends org.nd4j.evaluation.classification.EvaluationCalibration implements org.deeplearning4j.eval.IEvaluation<org.nd4j.evaluation.classification.EvaluationCalibration> { public class EvaluationCalibration extends org.nd4j.evaluation.classification.EvaluationCalibration implements org.deeplearning4j.eval.IEvaluation<org.nd4j.evaluation.classification.EvaluationCalibration> {
/** /**

View File

@ -31,6 +31,7 @@ import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.ArrayUtil;
import java.util.Map; import java.util.Map;
@ -99,15 +100,15 @@ public class CapsuleLayer extends SameDiffLayer {
} }
@Override @Override
public SDVariable defineLayer(SameDiff SD, SDVariable input, Map<String, SDVariable> paramTable, SDVariable mask) { public SDVariable defineLayer(SameDiff sd, SDVariable input, Map<String, SDVariable> paramTable, SDVariable mask) {
// input: [mb, inputCapsules, inputCapsuleDimensions] // input: [mb, inputCapsules, inputCapsuleDimensions]
// [mb, inputCapsules, 1, inputCapsuleDimensions, 1] // [mb, inputCapsules, 1, inputCapsuleDimensions, 1]
SDVariable expanded = SD.expandDims(SD.expandDims(input, 2), 4); SDVariable expanded = sd.expandDims(sd.expandDims(input, 2), 4);
// [mb, inputCapsules, capsules * capsuleDimensions, inputCapsuleDimensions, 1] // [mb, inputCapsules, capsules * capsuleDimensions, inputCapsuleDimensions, 1]
SDVariable tiled = SD.tile(expanded, 1, 1, capsules * capsuleDimensions, 1, 1); SDVariable tiled = sd.tile(expanded, 1, 1, capsules * capsuleDimensions, 1, 1);
// [1, inputCapsules, capsules * capsuleDimensions, inputCapsuleDimensions] // [1, inputCapsules, capsules * capsuleDimensions, inputCapsuleDimensions]
SDVariable weights = paramTable.get(WEIGHT_PARAM); SDVariable weights = paramTable.get(WEIGHT_PARAM);
@ -119,13 +120,13 @@ public class CapsuleLayer extends SameDiffLayer {
// b is the logits of the routing procedure // b is the logits of the routing procedure
// [mb, inputCapsules, capsules, 1, 1] // [mb, inputCapsules, capsules, 1, 1]
SDVariable b = SD.zerosLike(uHat).get(SDIndex.all(), SDIndex.all(), SDIndex.all(), SDIndex.interval(0, 1), SDIndex.interval(0, 1)); SDVariable b = sd.zerosLike(uHat).get(SDIndex.all(), SDIndex.all(), SDIndex.all(), SDIndex.interval(0, 1), SDIndex.interval(0, 1));
for(int i = 0 ; i < routings ; i++){ for(int i = 0 ; i < routings ; i++){
// c is the coupling coefficient, i.e. the edge weight between the 2 capsules // c is the coupling coefficient, i.e. the edge weight between the 2 capsules
// [mb, inputCapsules, capsules, 1, 1] // [mb, inputCapsules, capsules, 1, 1]
SDVariable c = CapsuleUtils.softmax(SD, b, 2, 5); SDVariable c = sd.nn.softmax(b, 2);
// [mb, 1, capsules, capsuleDimensions, 1] // [mb, 1, capsules, capsuleDimensions, 1]
SDVariable s = c.times(uHat).sum(true, 1); SDVariable s = c.times(uHat).sum(true, 1);
@ -135,14 +136,14 @@ public class CapsuleLayer extends SameDiffLayer {
// v is the per capsule activations. On the last routing iteration, this is output // v is the per capsule activations. On the last routing iteration, this is output
// [mb, 1, capsules, capsuleDimensions, 1] // [mb, 1, capsules, capsuleDimensions, 1]
SDVariable v = CapsuleUtils.squash(SD, s, 3); SDVariable v = CapsuleUtils.squash(sd, s, 3);
if(i == routings - 1){ if(i == routings - 1){
return SD.squeeze(SD.squeeze(v, 1), 3); return sd.squeeze(sd.squeeze(v, 1), 3);
} }
// [mb, inputCapsules, capsules, capsuleDimensions, 1] // [mb, inputCapsules, capsules, capsuleDimensions, 1]
SDVariable vTiled = SD.tile(v, 1, (int) inputCapsules, 1, 1, 1); SDVariable vTiled = sd.tile(v, 1, (int) inputCapsules, 1, 1, 1);
// [mb, inputCapsules, capsules, 1, 1] // [mb, inputCapsules, capsules, 1, 1]
b = b.plus(uHat.times(vTiled).sum(true, 3)); b = b.plus(uHat.times(vTiled).sum(true, 3));

View File

@ -178,9 +178,11 @@ public class LocallyConnected1D extends SameDiffLayer {
//Note: for same mode, bottom/right padding can be 1 more than top/left padding //Note: for same mode, bottom/right padding can be 1 more than top/left padding
//NCW format. //NCW format.
if(cm == ConvolutionMode.Same) { if(cm == ConvolutionMode.Same) {
layerInput = sameDiff.nn().pad(layerInput, new int[][]{{0, 0}, {0, 0}, {padding, paddingR}}, 0); layerInput = sameDiff.nn().pad(layerInput,
sameDiff.constant(Nd4j.createFromArray(new int[][]{{0, 0}, {0, 0}, {padding, paddingR}})), 0);
} else { } else {
layerInput = sameDiff.nn().pad(layerInput, new int[][]{{0, 0}, {0, 0}, {padding, padding}}, 0); layerInput = sameDiff.nn().pad(layerInput,
sameDiff.constant(Nd4j.createFromArray(new int[][]{{0, 0}, {0, 0}, {padding, padding}})), 0);
} }
} }

View File

@ -184,9 +184,11 @@ public class LocallyConnected2D extends SameDiffLayer {
//Note: for same mode, bottom/right padding can be 1 more than top/left padding //Note: for same mode, bottom/right padding can be 1 more than top/left padding
//NCHW format //NCHW format
if(cm == ConvolutionMode.Same){ if(cm == ConvolutionMode.Same){
layerInput = sameDiff.nn().pad(layerInput, new int[][]{{0,0},{0,0},{padding[0], paddingBr[0]}, {padding[1], paddingBr[1]}}, 0); layerInput = sameDiff.nn().pad(layerInput,
sameDiff.constant(Nd4j.createFromArray(new int[][]{{0,0},{0,0},{padding[0], paddingBr[0]}, {padding[1], paddingBr[1]}})), 0.0);
} else { } else {
layerInput = sameDiff.nn().pad(layerInput, new int[][]{{0,0},{0,0},{padding[0], padding[0]}, {padding[1], padding[1]}}, 0); layerInput = sameDiff.nn().pad(layerInput,
sameDiff.constant(Nd4j.createFromArray(new int[][]{{0,0},{0,0},{padding[0], padding[0]}, {padding[1], padding[1]}})), 0.0);
} }
} }

View File

@ -185,7 +185,9 @@ public class RecurrentAttentionLayer extends SameDiffLayer {
final val R = paramTable.get(RECURRENT_WEIGHT_KEY); final val R = paramTable.get(RECURRENT_WEIGHT_KEY);
final val b = paramTable.get(BIAS_KEY); final val b = paramTable.get(BIAS_KEY);
SDVariable[] inputSlices = sameDiff.unstack(layerInput, 2); long[] shape = layerInput.getShape();
Preconditions.checkState(shape != null, "Null shape for input placeholder");
SDVariable[] inputSlices = sameDiff.unstack(layerInput, 2, (int)shape[2]);
this.timeSteps = inputSlices.length; this.timeSteps = inputSlices.length;
SDVariable[] outputSlices = new SDVariable[timeSteps]; SDVariable[] outputSlices = new SDVariable[timeSteps];
SDVariable prev = null; SDVariable prev = null;

View File

@ -72,7 +72,7 @@ public class SimpleRnn extends BaseRecurrentLayer<org.deeplearning4j.nn.conf.lay
INDArray out = activateHelper(last, training, false, workspaceMgr).getFirst(); INDArray out = activateHelper(last, training, false, workspaceMgr).getFirst();
if(storeLastForTBPTT){ if(storeLastForTBPTT){
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()){ try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()){
tBpttStateMap.put(STATE_KEY_PREV_ACTIVATION, out.get(all(), all(), point(out.size(2)-1))); tBpttStateMap.put(STATE_KEY_PREV_ACTIVATION, out.get(all(), all(), point(out.size(2)-1)).dup());
} }
} }
return out; return out;

View File

@ -45,15 +45,4 @@ public class CapsuleUtils {
return x.times(squaredNorm).div(squaredNorm.plus(1.0).times(scale)); return x.times(squaredNorm).div(squaredNorm.plus(1.0).times(scale));
} }
/**
* Compute softmax along a given dimension
*/
public static SDVariable softmax(SameDiff SD, SDVariable x, int dimension, int rank){
int[] permutation = ArrayUtil.range(0, rank);
permutation[0] = dimension;
permutation[dimension] = 0;
return SD.nn.softmax(x.permute(permutation)).permute(ArrayUtil.invertPermutation(permutation));
}
} }

View File

@ -84,6 +84,13 @@
<profiles> <profiles>
<!-- Profile "testresources": marked active by default; it declares no properties, dependencies or build settings of its own here -->
<profile>
<id>testresources</id>
<activation>
<activeByDefault>true</activeByDefault>
</activation>
</profile>
<profile> <profile>
<id>test-nd4j-native</id> <id>test-nd4j-native</id>
<activation> <activation>

View File

@ -19,6 +19,7 @@ import org.nd4j.remote.clients.serde.BinarySerializer;
import org.nd4j.remote.clients.serde.JsonDeserializer; import org.nd4j.remote.clients.serde.JsonDeserializer;
import org.nd4j.remote.clients.serde.JsonSerializer; import org.nd4j.remote.clients.serde.JsonSerializer;
import org.nd4j.remote.clients.serde.impl.IntegerSerde; import org.nd4j.remote.clients.serde.impl.IntegerSerde;
import org.nd4j.resources.Resources;
import org.nd4j.shade.jackson.databind.ObjectMapper; import org.nd4j.shade.jackson.databind.ObjectMapper;
import javax.imageio.ImageIO; import javax.imageio.ImageIO;
@ -65,7 +66,7 @@ public class BinaryModelServerTest extends BaseDL4JTest {
@Test @Test
public void testMlnMnist_ImageInput() throws Exception { public void testMlnMnist_ImageInput() throws Exception {
val modelFile = new ClassPathResource("models/mnist/mnist-model.zip").getFile(); val modelFile = Resources.asFile("models/mnist/mnist-model.zip");
MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(modelFile); MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(modelFile);
val server = new JsonModelServer.Builder<BufferedImage, Integer>(net) val server = new JsonModelServer.Builder<BufferedImage, Integer>(net)
@ -129,7 +130,7 @@ public class BinaryModelServerTest extends BaseDL4JTest {
@Test @Test
public void testMlnMnist_ImageInput_Async() throws Exception { public void testMlnMnist_ImageInput_Async() throws Exception {
val modelFile = new ClassPathResource("models/mnist/mnist-model.zip").getFile(); val modelFile = Resources.asFile("models/mnist/mnist-model.zip");
MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(modelFile); MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(modelFile);
val server = new JsonModelServer.Builder<BufferedImage, Integer>(net) val server = new JsonModelServer.Builder<BufferedImage, Integer>(net)
@ -198,7 +199,7 @@ public class BinaryModelServerTest extends BaseDL4JTest {
@Test @Test
public void testBinaryIn_BinaryOut() throws Exception { public void testBinaryIn_BinaryOut() throws Exception {
val modelFile = new ClassPathResource("models/mnist/mnist-model.zip").getFile(); val modelFile = Resources.asFile("models/mnist/mnist-model.zip");
MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(modelFile); MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(modelFile);
val server = new JsonModelServer.Builder<BufferedImage, BufferedImage>(net) val server = new JsonModelServer.Builder<BufferedImage, BufferedImage>(net)

View File

@ -495,7 +495,7 @@ public class JsonModelServerTest extends BaseDL4JTest {
SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 28*28); SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 28*28);
SDVariable w = sd.var("w", Nd4j.rand(DataType.FLOAT, 28*28, 10)); SDVariable w = sd.var("w", Nd4j.rand(DataType.FLOAT, 28*28, 10));
SDVariable b = sd.var("b", Nd4j.rand(DataType.FLOAT, 1, 10)); SDVariable b = sd.var("b", Nd4j.rand(DataType.FLOAT, 1, 10));
SDVariable sm = sd.nn.softmax("softmax", in.mmul(w).add(b)); SDVariable sm = sd.nn.softmax("softmax", in.mmul(w).add(b), -1);
val server = new JsonModelServer.Builder<float[], Integer>(sd) val server = new JsonModelServer.Builder<float[], Integer>(sd)
.outputSerializer( new IntSerde()) .outputSerializer( new IntSerde())

View File

@ -20,6 +20,9 @@
<name>deeplearning4j-remote</name> <name>deeplearning4j-remote</name>
<profiles> <profiles>
<!-- Profile "testresources": declared by id only, with no activation or settings in this module -->
<profile>
<id>testresources</id>
</profile>
<profile> <profile>
<id>test-nd4j-native</id> <id>test-nd4j-native</id>
</profile> </profile>

View File

@ -58,7 +58,7 @@ public class TestSameDiffUI extends BaseDL4JTest {
SDVariable b = sd.var("b", DataType.FLOAT, 1, 4); SDVariable b = sd.var("b", DataType.FLOAT, 1, 4);
SDVariable z = in.mmul(w).add(b); SDVariable z = in.mmul(w).add(b);
SDVariable a = sd.nn().tanh(z); SDVariable a = sd.math().tanh(z);
LogFileWriter lfw = new LogFileWriter(f); LogFileWriter lfw = new LogFileWriter(f);
lfw.writeGraphStructure(sd); lfw.writeGraphStructure(sd);

View File

@ -20,7 +20,10 @@ package org.deeplearning4j.integration;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FileUtils; import org.apache.commons.io.FileUtils;
import org.deeplearning4j.datasets.iterator.MultiDataSetWrapperIterator; import org.deeplearning4j.datasets.iterator.MultiDataSetWrapperIterator;
import org.deeplearning4j.integration.testcases.dl4j.*;
import org.deeplearning4j.integration.testcases.samediff.SameDiffCNNCases;
import org.deeplearning4j.integration.testcases.samediff.SameDiffMLPTestCases; import org.deeplearning4j.integration.testcases.samediff.SameDiffMLPTestCases;
import org.deeplearning4j.integration.testcases.samediff.SameDiffRNNTestCases;
import org.deeplearning4j.nn.api.Model; import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
@ -66,7 +69,29 @@ public class IntegrationTestBaselineGenerator {
} }
runGeneration( runGeneration(
SameDiffMLPTestCases.getMLPMnist()
// DL4J integration test cases.
// CNN1DTestCases.getCnn1dTestCaseCharRNN(),
// CNN2DTestCases.testLenetTransferDropoutRepeatability(),
//// CNN2DTestCases.getCnn2DSynthetic(),
// CNN2DTestCases.getLenetMnist(),
// CNN2DTestCases.getVGG16TransferTinyImagenet(),
// CNN2DTestCases.getYoloHouseNumbers(),
// CNN3DTestCases.getCnn3dTestCaseSynthetic(),
// MLPTestCases.getMLPMnist(),
// MLPTestCases.getMLPMoon(),
// RNNTestCases.getRnnCharacterTestCase(),
// RNNTestCases.getRnnCsvSequenceClassificationTestCase1(),
// RNNTestCases.getRnnCsvSequenceClassificationTestCase2(),
// UnsupervisedTestCases.getVAEMnistAnomaly(),
// Samediff test cases done
SameDiffMLPTestCases.getMLPMnist(),
SameDiffMLPTestCases.getMLPMoon(),
SameDiffCNNCases.getLenetMnist(),
SameDiffCNNCases.getCnn3dSynthetic(),
SameDiffRNNTestCases.getRnnCsvSequenceClassificationTestCase1()
); );
} }
@ -331,7 +356,6 @@ public class IntegrationTestBaselineGenerator {
} }
} }
if (tc.isTestEvaluation()) { if (tc.isTestEvaluation()) {
IEvaluation[] evals = tc.getNewEvaluations(); IEvaluation[] evals = tc.getNewEvaluations();
MultiDataSetIterator iter = tc.getEvaluationTestData(); MultiDataSetIterator iter = tc.getEvaluationTestData();

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.integration; package org.deeplearning4j.integration;
import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.integration.testcases.samediff.SameDiffCNNCases;
import org.deeplearning4j.integration.testcases.samediff.SameDiffMLPTestCases; import org.deeplearning4j.integration.testcases.samediff.SameDiffMLPTestCases;
import org.junit.Rule; import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
@ -37,4 +38,20 @@ public class IntegrationTestsSameDiff extends BaseDL4JTest {
IntegrationTestRunner.runTest(SameDiffMLPTestCases.getMLPMnist(), testDir); IntegrationTestRunner.runTest(SameDiffMLPTestCases.getMLPMnist(), testDir);
} }
/** Runs the test case returned by SameDiffMLPTestCases.getMLPMoon() through IntegrationTestRunner.runTest, passing testDir. */
@Test
public void testMLPMoon() throws Exception {
IntegrationTestRunner.runTest(SameDiffMLPTestCases.getMLPMoon(), testDir);
}
/** Runs the test case returned by SameDiffCNNCases.getLenetMnist() through IntegrationTestRunner.runTest, passing testDir. */
@Test
public void testLenetMnist() throws Exception {
IntegrationTestRunner.runTest(SameDiffCNNCases.getLenetMnist(), testDir);
}
/** Runs the test case returned by SameDiffCNNCases.getCnn3dSynthetic() through IntegrationTestRunner.runTest, passing testDir. */
@Test
public void testCnn3dSynthetic() throws Exception {
IntegrationTestRunner.runTest(SameDiffCNNCases.getCnn3dSynthetic(), testDir);
}
} }

View File

@ -194,6 +194,8 @@ public class CNN2DTestCases {
testParamsPostTraining = false; //Skip - requires saving all params (approx 500mb) testParamsPostTraining = false; //Skip - requires saving all params (approx 500mb)
testEvaluation = false; testEvaluation = false;
testOverfitting = false; testOverfitting = false;
maxRelativeErrorOutput = 0.2;
minAbsErrorOutput = 0.05; //Max value is around 0.22
} }
@Override @Override
@ -314,6 +316,7 @@ public class CNN2DTestCases {
ComputationGraph model = new TransferLearning.GraphBuilder(pretrained) ComputationGraph model = new TransferLearning.GraphBuilder(pretrained)
.fineTuneConfiguration(fineTuneConf) .fineTuneConfiguration(fineTuneConf)
.removeVertexKeepConnections("conv2d_9") .removeVertexKeepConnections("conv2d_9")
.removeVertexAndConnections("outputs")
.addLayer("convolution2d_9", .addLayer("convolution2d_9",
new ConvolutionLayer.Builder(1,1) new ConvolutionLayer.Builder(1,1)
.nIn(1024) .nIn(1024)
@ -393,7 +396,7 @@ public class CNN2DTestCases {
@Override @Override
public ModelType modelType() { public ModelType modelType() {
return ModelType.CG; return ModelType.MLN;
} }
@Override @Override

View File

@ -77,6 +77,10 @@ public class MLPTestCases {
testOverfitting = true; testOverfitting = true;
maxRelativeErrorOverfit = 2e-2; maxRelativeErrorOverfit = 2e-2;
minAbsErrorOverfit = 1e-2; minAbsErrorOverfit = 1e-2;
maxRelativeErrorGradients = 0.01;
minAbsErrorGradients = 0.05;
maxRelativeErrorParamsPostTraining = 0.01;
minAbsErrorParamsPostTraining = 0.05;
} }
@Override @Override
@ -135,8 +139,7 @@ public class MLPTestCases {
public IEvaluation[] getNewEvaluations(){ public IEvaluation[] getNewEvaluations(){
return new IEvaluation[]{ return new IEvaluation[]{
new Evaluation(), new Evaluation(),
new ROCMultiClass(), new ROCMultiClass()
new EvaluationCalibration()
}; };
} }

View File

@ -24,6 +24,7 @@ import org.nd4j.evaluation.classification.EvaluationCalibration;
import org.nd4j.evaluation.classification.ROCMultiClass; import org.nd4j.evaluation.classification.ROCMultiClass;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.dataset.api.preprocessor.CompositeMultiDataSetPreProcessor; import org.nd4j.linalg.dataset.api.preprocessor.CompositeMultiDataSetPreProcessor;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.shade.guava.io.Files; import org.nd4j.shade.guava.io.Files;
import org.deeplearning4j.integration.TestCase; import org.deeplearning4j.integration.TestCase;
import org.deeplearning4j.integration.testcases.dl4j.misc.CharacterIterator; import org.deeplearning4j.integration.testcases.dl4j.misc.CharacterIterator;
@ -91,7 +92,7 @@ public class RNNTestCases {
} }
private int miniBatchSize = 32; private int miniBatchSize = 32;
private int exampleLength = 1000; private int exampleLength = 200;
@Override @Override
@ -101,6 +102,7 @@ public class RNNTestCases {
@Override @Override
public Object getConfiguration() throws Exception { public Object getConfiguration() throws Exception {
Nd4j.getRandom().setSeed(12345);
CharacterIterator iter = CharacterIterator.getShakespeareIterator(miniBatchSize,exampleLength); CharacterIterator iter = CharacterIterator.getShakespeareIterator(miniBatchSize,exampleLength);
int nOut = iter.totalOutcomes(); int nOut = iter.totalOutcomes();
@ -113,7 +115,7 @@ public class RNNTestCases {
.seed(12345) .seed(12345)
.l2(0.001) .l2(0.001)
.weightInit(WeightInit.XAVIER) .weightInit(WeightInit.XAVIER)
.updater(new RmsProp(0.1)) .updater(new Adam(1e-3))
.list() .list()
.layer(0, new LSTM.Builder().nIn(iter.inputColumns()).nOut(lstmLayerSize) .layer(0, new LSTM.Builder().nIn(iter.inputColumns()).nOut(lstmLayerSize)
.activation(Activation.TANH).build()) .activation(Activation.TANH).build())
@ -140,7 +142,7 @@ public class RNNTestCases {
@Override @Override
public MultiDataSetIterator getTrainingData() throws Exception { public MultiDataSetIterator getTrainingData() throws Exception {
DataSetIterator iter = CharacterIterator.getShakespeareIterator(miniBatchSize,exampleLength); DataSetIterator iter = CharacterIterator.getShakespeareIterator(miniBatchSize,exampleLength);
iter = new EarlyTerminationDataSetIterator(iter, 2); //3 minibatches, 1000/200 = 5 updates per minibatch iter = new EarlyTerminationDataSetIterator(iter, 2); //2 minibatches, 200/50 = 4 updates per minibatch
return new MultiDataSetIteratorAdapter(iter); return new MultiDataSetIteratorAdapter(iter);
} }

View File

@ -72,12 +72,12 @@ public class UnsupervisedTestCases {
return new NeuralNetConfiguration.Builder() return new NeuralNetConfiguration.Builder()
.dataType(DataType.FLOAT) .dataType(DataType.FLOAT)
.seed(12345) .seed(12345)
.updater(new Adam(0.05)) .updater(new Adam(1e-3))
.weightInit(WeightInit.XAVIER) .weightInit(WeightInit.XAVIER)
.l2(1e-4) .l2(1e-4)
.list() .list()
.layer(0, new VariationalAutoencoder.Builder() .layer(0, new VariationalAutoencoder.Builder()
.activation(Activation.LEAKYRELU) .activation(Activation.TANH)
.encoderLayerSizes(256, 256) //2 encoder layers, each of size 256 .encoderLayerSizes(256, 256) //2 encoder layers, each of size 256
.decoderLayerSizes(256, 256) //2 decoder layers, each of size 256 .decoderLayerSizes(256, 256) //2 decoder layers, each of size 256
.pzxActivationFunction(Activation.IDENTITY) //p(z|data) activation function .pzxActivationFunction(Activation.IDENTITY) //p(z|data) activation function

View File

@ -0,0 +1,398 @@
/* ******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.integration.testcases.samediff;
import org.deeplearning4j.datasets.iterator.EarlyTerminationDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.MultiDataSetIteratorAdapter;
import org.deeplearning4j.datasets.iterator.impl.SingletonMultiDataSetIterator;
import org.deeplearning4j.integration.ModelType;
import org.deeplearning4j.integration.TestCase;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.TrainingConfig;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.evaluation.classification.EvaluationCalibration;
import org.nd4j.evaluation.classification.ROCMultiClass;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv3DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling3DConfig;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.Nesterovs;
import java.util.*;
public class SameDiffCNNCases {
/**
 * LeNet-style CNN on MNIST, defined directly as a SameDiff graph.
 *
 * <p>Graph: conv(5x5, 20) -> maxpool(2x2) -> conv(5x5, 50) -> maxpool(2x2) -> dense(500) -> softmax(10),
 * trained with Adam(1e-3) and L2 regularization. All activations use NCHW layout.
 *
 * @return the test case; every check except overfitting is enabled
 */
public static TestCase getLenetMnist() {
    return new TestCase() {
        {
            testName = "LenetMnistSD";
            testType = TestType.RANDOM_INIT;
            testPredictions = true;
            testTrainingCurves = true;
            testGradients = true;
            testParamsPostTraining = true;
            testEvaluation = true;
            testOverfitting = false;
        }

        @Override
        public ModelType modelType() {
            return ModelType.SAMEDIFF;
        }

        /**
         * Builds the SameDiff graph and attaches its training configuration.
         * The RNG is seeded first, and parameters are created in a fixed order, so the random
         * initial values are reproducible.
         */
        @Override
        public Object getConfiguration() throws Exception {
            Nd4j.getRandom().setSeed(12345);

            int nChannels = 1;  // Number of input channels
            int outputNum = 10; // The number of possible outcomes

            SameDiff sd = SameDiff.create();
            SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 784);
            SDVariable label = sd.placeHolder("label", DataType.FLOAT, -1, outputNum);

            // Flat 784-element rows -> NCHW images: [minibatch, channels=1, height=28, width=28]
            SDVariable in4d = in.reshape(-1, nChannels, 28, 28);

            int kernelHeight = 5;
            int kernelWidth = 5;

            // Size arithmetic in the comments below: outputSize = (inputSize - kernelSize + 2*padding) / stride + 1
            // NOTE(review): the arithmetic takes padding = 0; the flatten size (4 * 4 * 50) further down relies on it.

            // w0 [kernelHeight = 5, kernelWidth = 5, inputChannels = 1, outputChannels = 20]
            // b0 [20]
            SDVariable w0 = sd.var("w0", Nd4j.rand(DataType.FLOAT, kernelHeight, kernelWidth, nChannels, 20).muli(0.01));
            SDVariable b0 = sd.var("b0", Nd4j.rand(DataType.FLOAT, 20).muli(0.01));

            SDVariable layer0 = sd.nn.relu(sd.cnn.conv2d("layer0", in4d, w0, b0, Conv2DConfig.builder()
                    .kH(kernelHeight)
                    .kW(kernelWidth)
                    .sH(1)
                    .sW(1)
                    .dataFormat("NCHW")
                    .build()), 0);
            // H(W) = (28 - 5 + 2*0) / 1 + 1 = 24  ->  [minibatch, 20, 24, 24]

            SDVariable layer1 = sd.cnn.maxPooling2d("layer1", layer0, Pooling2DConfig.builder()
                    .kH(2).kW(2)
                    .sH(2).sW(2)
                    .isNHWC(false)
                    .build());
            // H(W) = (24 - 2 + 2*0) / 2 + 1 = 12  ->  [minibatch, 20, 12, 12]

            // w2 [kernelHeight = 5, kernelWidth = 5, inputChannels = 20, outputChannels = 50]
            // b2 [50]
            SDVariable w2 = sd.var("w2", Nd4j.rand(DataType.FLOAT, kernelHeight, kernelWidth, 20, 50).muli(0.01));
            SDVariable b2 = sd.var("b2", Nd4j.rand(DataType.FLOAT, 50).muli(0.01));

            SDVariable layer2 = sd.nn.relu(sd.cnn.conv2d("layer2", layer1, w2, b2, Conv2DConfig.builder()
                    .kH(kernelHeight)
                    .kW(kernelWidth)
                    .sH(1)
                    .sW(1)
                    .dataFormat("NCHW")
                    .build()), 0);
            // H(W) = (12 - 5 + 2*0) / 1 + 1 = 8  ->  [minibatch, 50, 8, 8]

            SDVariable layer3 = sd.cnn.maxPooling2d("layer3", layer2, Pooling2DConfig.builder()
                    .kH(2).kW(2)
                    .sH(2).sW(2)
                    .isNHWC(false)
                    .build());
            // H(W) = (8 - 2 + 2*0) / 2 + 1 = 4  ->  [minibatch, 50, 4, 4]

            // Flatten channels x height x width for the dense layers
            int channels_height_width = 4 * 4 * 50;
            SDVariable layer3_reshaped = layer3.reshape(-1, channels_height_width);

            SDVariable w4 = sd.var("w4", Nd4j.rand(DataType.FLOAT, channels_height_width, 500).muli(0.01));
            SDVariable b4 = sd.var("b4", Nd4j.rand(DataType.FLOAT, 500).muli(0.01));
            SDVariable layer4 = sd.nn.relu("layer4", layer3_reshaped.mmul(w4).add(b4), 0);

            // Unlike the earlier layers, the output layer parameters are not scaled by 0.01
            SDVariable w5 = sd.var("w5", Nd4j.rand(DataType.FLOAT, 500, outputNum));
            SDVariable b5 = sd.var("b5", Nd4j.rand(DataType.FLOAT, outputNum));

            SDVariable out = sd.nn.softmax("out", layer4.mmul(w5).add(b5));
            // Registers the "loss" variable in the graph; the returned handle is not needed here
            sd.loss.logLoss("loss", label, out);

            //Also set the training configuration:
            sd.setTrainingConfig(TrainingConfig.builder()
                    .updater(new Adam(1e-3))
                    .l2(1e-3)
                    .dataSetFeatureMapping("in")        //features[0] -> "in" placeholder
                    .dataSetLabelMapping("label")       //labels[0]   -> "label" placeholder
                    .build());

            return sd;
        }

        /** Placeholder values for the gradient check: the first MNIST training minibatch of 8 examples. */
        @Override
        public Map<String, INDArray> getGradientsTestDataSameDiff() throws Exception {
            DataSet ds = new MnistDataSetIterator(8, true, 12345).next();
            Map<String, INDArray> map = new HashMap<>();
            map.put("in", ds.getFeatures());
            map.put("label", ds.getLabels());
            return map;
        }

        /** Training data: MNIST in minibatches of 16, cut off after 60 minibatches. */
        @Override
        public MultiDataSetIterator getTrainingData() throws Exception {
            DataSetIterator iter = new MnistDataSetIterator(16, true, 12345);
            iter = new EarlyTerminationDataSetIterator(iter, 60);
            return new MultiDataSetIteratorAdapter(iter);
        }

        /** Evaluation data: MNIST test set in minibatches of 32, cut off after 10 minibatches. */
        @Override
        public MultiDataSetIterator getEvaluationTestData() throws Exception {
            return new MultiDataSetIteratorAdapter(new EarlyTerminationDataSetIterator(new MnistDataSetIterator(32, false, 12345), 10));
        }

        /**
         * Two prediction inputs: a single example (the first of the first minibatch),
         * followed by a full minibatch of 8.
         */
        @Override
        public List<Map<String, INDArray>> getPredictionsTestDataSameDiff() throws Exception {
            DataSetIterator iter = new MnistDataSetIterator(8, true, 12345);
            List<Map<String, INDArray>> list = new ArrayList<>();

            // Single-example case
            org.nd4j.linalg.dataset.DataSet ds = iter.next();
            ds = ds.asList().get(0);
            list.add(Collections.singletonMap("in", ds.getFeatures()));

            // Full-minibatch case
            ds = iter.next();
            list.add(Collections.singletonMap("in", ds.getFeatures()));
            return list;
        }

        @Override
        public List<String> getPredictionsNamesSameDiff() {
            return Collections.singletonList("out");
        }

        @Override
        public IEvaluation[] getNewEvaluations() {
            return new IEvaluation[]{
                    new Evaluation(),
                    new ROCMultiClass(),
                    new EvaluationCalibration()};
        }

        /** Evaluates the "out" variable against label index 0 and returns the same evaluation array. */
        @Override
        public IEvaluation[] doEvaluationSameDiff(SameDiff sd, MultiDataSetIterator iter, IEvaluation[] evaluations) {
            sd.evaluate(iter, "out", 0, evaluations);
            return evaluations;
        }
    };
}
/**
 * Small 3D CNN on synthetic random data, defined directly as a SameDiff graph.
 *
 * <p>Graph: conv3d(3x3x3, stride 2, 8 filters) -> maxpool3d(2x2x2) -> softmax(10),
 * trained with Nesterov momentum. The same single minibatch of 2 random examples is used
 * for gradients, training and evaluation.
 *
 * @return the test case; every check except overfitting is enabled
 */
public static TestCase getCnn3dSynthetic() {
    return new TestCase() {
        {
            testName = "Cnn3dSynthetic";
            testType = TestType.RANDOM_INIT;
            testPredictions = true;
            testTrainingCurves = true;
            testGradients = true;
            testParamsPostTraining = true;
            testEvaluation = true;
            testOverfitting = false;
        }

        @Override
        public ModelType modelType() {
            return ModelType.SAMEDIFF;
        }

        /**
         * Builds the SameDiff graph and attaches its training configuration.
         * The RNG is seeded first, and parameters are created in a fixed order, so the random
         * initial values are reproducible.
         */
        @Override
        public Object getConfiguration() throws Exception {
            Nd4j.getRandom().setSeed(12345);

            int nChannels = 3;  // Number of input channels
            int outputNum = 10; // The number of possible outcomes

            SameDiff sd = SameDiff.create();

            // Input in NCDHW: [minibatch, channels=3, depth=8, height=8, width=8]
            SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, nChannels, 8, 8, 8);
            // Labels are [minibatch, outputNum]. The leading dimension is the minibatch size
            // (2 in this test's data), not the channel count, so it is left variable like "in".
            SDVariable label = sd.placeHolder("label", DataType.FLOAT, -1, outputNum);

            // Weights for conv3d. Rank 5 with shape [kernelDepth, kernelHeight, kernelWidth, inputChannels, outputChannels]
            // [kernelDepth = 3, kernelHeight = 3, kernelWidth = 3, inputChannels = 3, outputChannels = 8]
            SDVariable w0 = sd.var("w0", Nd4j.rand(DataType.FLOAT, 3, 3, 3, nChannels, 8));
            // 1D bias array with shape [outputChannels]
            SDVariable b0 = sd.var("b0", Nd4j.rand(DataType.FLOAT, 8));

            SDVariable layer0 = sd.nn.relu(sd.cnn.conv3d("layer0", in, w0, b0, Conv3DConfig.builder()
                    .kH(3)
                    .kW(3)
                    .kD(3)
                    .sH(2)
                    .sW(2)
                    .sD(2)
                    .dataFormat("NCDHW")
                    .build()), 0);
            // outputSize = (inputSize - kernelSize + 2*padding) / stride + 1
            // D(H)(W) = (8 - 3 + 2*0) / 2 + 1 = 3  ->  [minibatch, 8, 3, 3, 3]
            // NOTE(review): the arithmetic takes padding = 0; the flatten size below relies on it.

            SDVariable layer1 = sd.cnn.maxPooling3d("layer1", layer0, Pooling3DConfig.builder()
                    .kH(2).kW(2).kD(2)
                    .sH(2).sW(2).sD(2)
                    .isNCDHW(true)
                    .build());
            // D(H)(W) = (3 - 2 + 2*0) / 2 + 1 = 1  ->  [minibatch, 8, 1, 1, 1]

            int channels_height_width_depth = 8 * 1 * 1 * 1;
            SDVariable layer1_reshaped = layer1.reshape(-1, channels_height_width_depth);

            // NOTE(review): the graph names of these parameters are "w4"/"b4" although they are the
            // second layer's parameters. They are left unchanged because renaming them would change
            // the variable names in the graph.
            SDVariable w1 = sd.var("w4", Nd4j.rand(DataType.FLOAT, channels_height_width_depth, 10));
            SDVariable b1 = sd.var("b4", Nd4j.rand(DataType.FLOAT, 10));

            SDVariable out = sd.nn.softmax("out", layer1_reshaped.mmul(w1).add(b1));
            // Registers the "loss" variable in the graph; the returned handle is not needed here
            sd.loss.logLoss("loss", label, out);

            //Also set the training configuration:
            sd.setTrainingConfig(TrainingConfig.builder()
                    .updater(new Nesterovs(0.01, 0.9))
                    .dataSetFeatureMapping("in")        //features[0] -> "in" placeholder
                    .dataSetLabelMapping("label")       //labels[0]   -> "label" placeholder
                    .build());

            return sd;
        }

        /**
         * Placeholder values for the gradient check. Built from {@link #getGradientsTestData()},
         * so both views of the gradient data come from one generation routine.
         */
        @Override
        public Map<String, INDArray> getGradientsTestDataSameDiff() throws Exception {
            MultiDataSet mds = getGradientsTestData();
            Map<String, INDArray> map = new HashMap<>();
            map.put("in", mds.getFeatures(0));
            map.put("label", mds.getLabels(0));
            return map;
        }

        @Override
        public List<String> getPredictionsNamesSameDiff() {
            return Collections.singletonList("out");
        }

        /** A single prediction input: 2 random NCDHW examples generated from a fixed seed. */
        @Override
        public List<Map<String, INDArray>> getPredictionsTestDataSameDiff() throws Exception {
            Nd4j.getRandom().setSeed(12345);
            List<Map<String, INDArray>> list = new ArrayList<>();
            INDArray arr = Nd4j.rand(new int[]{2, 3, 8, 8, 8});
            list.add(Collections.singletonMap("in", arr));
            return list;
        }

        /** 2 random NCDHW examples with random one-hot labels over 10 classes, from a fixed seed. */
        @Override
        public MultiDataSet getGradientsTestData() throws Exception {
            Nd4j.getRandom().setSeed(12345);
            //NCDHW format
            INDArray arr = Nd4j.rand(new int[]{2, 3, 8, 8, 8});
            INDArray labels = org.deeplearning4j.integration.TestUtils.randomOneHot(2, 10);
            return new org.nd4j.linalg.dataset.MultiDataSet(arr, labels);
        }

        /** Training data: the gradient-check minibatch wrapped in a single-element iterator. */
        @Override
        public MultiDataSetIterator getTrainingData() throws Exception {
            return new SingletonMultiDataSetIterator(getGradientsTestData());
        }

        /** Evaluation reuses the training data. */
        @Override
        public MultiDataSetIterator getEvaluationTestData() throws Exception {
            return getTrainingData();
        }

        /** Evaluates the "out" variable against label index 0 and returns the same evaluation array. */
        @Override
        public IEvaluation[] doEvaluationSameDiff(SameDiff sd, MultiDataSetIterator iter, IEvaluation[] evaluations) {
            sd.evaluate(iter, "out", 0, evaluations);
            return evaluations;
        }

        @Override
        public IEvaluation[] getNewEvaluations() {
            return new IEvaluation[]{new Evaluation()};
        }
    };
}
}

View File

@ -15,27 +15,46 @@
******************************************************************************/ ******************************************************************************/
package org.deeplearning4j.integration.testcases.samediff; package org.deeplearning4j.integration.testcases.samediff;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.datasets.iterator.EarlyTerminationDataSetIterator; import org.deeplearning4j.datasets.iterator.EarlyTerminationDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.MultiDataSetIteratorAdapter; import org.deeplearning4j.datasets.iterator.impl.MultiDataSetIteratorAdapter;
import org.deeplearning4j.datasets.iterator.impl.SingletonMultiDataSetIterator;
import org.deeplearning4j.integration.ModelType; import org.deeplearning4j.integration.ModelType;
import org.deeplearning4j.integration.TestCase; import org.deeplearning4j.integration.TestCase;
import org.nd4j.autodiff.loss.LossReduce;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.TrainingConfig; import org.nd4j.autodiff.samediff.TrainingConfig;
import org.nd4j.evaluation.IEvaluation; import org.nd4j.evaluation.IEvaluation;
import org.nd4j.evaluation.classification.Evaluation; import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.evaluation.classification.EvaluationCalibration;
import org.nd4j.evaluation.classification.ROCMultiClass;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv3DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling3DConfig;
import org.nd4j.linalg.dataset.api.DataSet; import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet; import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.io.ClassPathResource;
import org.nd4j.linalg.learning.config.Adam; import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.resources.Resources;
import java.io.File;
import java.util.*; import java.util.*;
import static org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig.*;
public class SameDiffMLPTestCases { public class SameDiffMLPTestCases {
@ -68,10 +87,10 @@ public class SameDiffMLPTestCases {
SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 784); SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 784);
SDVariable label = sd.placeHolder("label", DataType.FLOAT, -1, 10); SDVariable label = sd.placeHolder("label", DataType.FLOAT, -1, 10);
SDVariable w0 = sd.var("w0", Nd4j.rand(DataType.FLOAT, 784, 256)); SDVariable w0 = sd.var("w0", Nd4j.rand(DataType.FLOAT, 784, 256).muli(0.1));
SDVariable b0 = sd.var("b0", Nd4j.rand(DataType.FLOAT, 256)); SDVariable b0 = sd.var("b0", Nd4j.rand(DataType.FLOAT, 256).muli(0.1));
SDVariable w1 = sd.var("w1", Nd4j.rand(DataType.FLOAT, 256, 10)); SDVariable w1 = sd.var("w1", Nd4j.rand(DataType.FLOAT, 256, 10).muli(0.1));
SDVariable b1 = sd.var("b1", Nd4j.rand(DataType.FLOAT, 10)); SDVariable b1 = sd.var("b1", Nd4j.rand(DataType.FLOAT, 10).muli(0.1));
SDVariable a0 = sd.nn.tanh(in.mmul(w0).add(b0)); SDVariable a0 = sd.nn.tanh(in.mmul(w0).add(b0));
SDVariable out = sd.nn.softmax("out", a0.mmul(w1).add(b1)); SDVariable out = sd.nn.softmax("out", a0.mmul(w1).add(b1));
@ -152,4 +171,160 @@ public class SameDiffMLPTestCases {
}; };
} }
public static TestCase getMLPMoon() {
return new TestCase() {
{
testName = "MLPMoonSD";
testType = TestType.RANDOM_INIT;
testPredictions = true;
testTrainingCurves = true;
testGradients = true;
testParamsPostTraining = true;
testEvaluation = true;
testOverfitting = true;
maxRelativeErrorOverfit = 2e-2;
minAbsErrorOverfit = 1e-2;
} }
@Override
public ModelType modelType() {
return ModelType.SAMEDIFF;
}
@Override
public Object getConfiguration() throws Exception {
int numInputs = 2;
int numOutputs = 2;
int numHiddenNodes = 20;
double learningRate = 0.005;
Nd4j.getRandom().setSeed(12345);
//Define the network structure:
SameDiff sd = SameDiff.create();
SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, numInputs);
SDVariable label = sd.placeHolder("label", DataType.FLOAT, -1, numOutputs);
SDVariable w0 = sd.var("w0", Nd4j.rand(DataType.FLOAT, numInputs, numHiddenNodes));
SDVariable b0 = sd.var("b0", Nd4j.rand(DataType.FLOAT, numHiddenNodes));
SDVariable w1 = sd.var("w1", Nd4j.rand(DataType.FLOAT, numHiddenNodes, numOutputs));
SDVariable b1 = sd.var("b1", Nd4j.rand(DataType.FLOAT, numOutputs));
SDVariable a0 = sd.nn.relu(in.mmul(w0).add(b0), 0);
SDVariable out = sd.nn.softmax("out", a0.mmul(w1).add(b1));
SDVariable loss = sd.loss.logLoss("loss", label, out);
//Also set the training configuration:
sd.setTrainingConfig(TrainingConfig.builder()
.updater(new Nesterovs(learningRate, 0.9))
.weightDecay(1e-3, true)
.dataSetFeatureMapping("in") //features[0] -> "in" placeholder
.dataSetLabelMapping("label") //labels[0] -> "label" placeholder
.build());
return sd;
}
@Override
public List<Map<String, INDArray>> getPredictionsTestDataSameDiff() throws Exception {
List<Map<String, INDArray>> out = new ArrayList<>();
File f = Resources.asFile("dl4j-integration-tests/data/moon_data_eval.csv");
RecordReader rr = new CSVRecordReader();
rr.initialize(new FileSplit(f));
DataSetIterator iter = new RecordReaderDataSetIterator(rr, 1, 0, 2);
out.add(Collections.singletonMap("in", iter.next().getFeatures()));
return out;
}
@Override
public List<String> getPredictionsNamesSameDiff() throws Exception {
return Collections.singletonList("out");
}
@Override
public Map<String, INDArray> getGradientsTestDataSameDiff() throws Exception {
File f = Resources.asFile("dl4j-integration-tests/data/moon_data_eval.csv");
RecordReader rr = new CSVRecordReader();
rr.initialize(new FileSplit(f));
org.nd4j.linalg.dataset.DataSet ds = new RecordReaderDataSetIterator(rr, 5, 0, 2).next();
Map<String, INDArray> map = new HashMap<>();
map.put("in", ds.getFeatures());
map.put("label", ds.getLabels());
return map;
}
@Override
public MultiDataSetIterator getTrainingData() throws Exception {
File f = Resources.asFile("dl4j-integration-tests/data/moon_data_train.csv");
RecordReader rr = new CSVRecordReader();
rr.initialize(new FileSplit(f));
DataSetIterator iter = new RecordReaderDataSetIterator(rr, 32, 0, 2);
iter = new EarlyTerminationDataSetIterator(iter, 32);
return new MultiDataSetIteratorAdapter(iter);
}
@Override
public IEvaluation[] getNewEvaluations() {
return new IEvaluation[]{
new Evaluation(),
new ROCMultiClass(),
new EvaluationCalibration()};
}
@Override
public MultiDataSetIterator getEvaluationTestData() throws Exception {
File f = Resources.asFile("dl4j-integration-tests/data/moon_data_eval.csv");
RecordReader rr = new CSVRecordReader();
rr.initialize(new FileSplit(f));
DataSetIterator iter = new RecordReaderDataSetIterator(rr, 32, 0, 2);
return new MultiDataSetIteratorAdapter(iter);
}
@Override
public IEvaluation[] doEvaluationSameDiff(SameDiff sd, MultiDataSetIterator iter, IEvaluation[] evaluations) {
sd.evaluate(iter, "out", 0, evaluations);
return evaluations;
}
@Override
public MultiDataSet getOverfittingData() throws Exception {
File f = Resources.asFile("dl4j-integration-tests/data/moon_data_eval.csv");
RecordReader rr = new CSVRecordReader();
rr.initialize(new FileSplit(f));
return new RecordReaderDataSetIterator(rr, 1, 0, 2).next().toMultiDataSet();
}
@Override
public int getOverfitNumIterations() {
return 200;
}
};
}
}

View File

@ -0,0 +1,289 @@
/* ******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.integration.testcases.samediff;
import org.datavec.api.records.reader.SequenceRecordReader;
import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader;
import org.datavec.api.split.NumberedFileInputSplit;
import org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.MultiDataSetIteratorAdapter;
import org.deeplearning4j.integration.ModelType;
import org.deeplearning4j.integration.TestCase;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.TrainingConfig;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.evaluation.classification.EvaluationCalibration;
import org.nd4j.evaluation.classification.ROCMultiClass;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMActivations;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMDataFormat;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMDirectionMode;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMLayerConfig;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs.LSTMLayerOutputs;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMLayerWeights;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.MultiDataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.CompositeMultiDataSetPreProcessor;
import org.nd4j.linalg.dataset.api.preprocessor.MultiDataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.MultiNormalizerStandardize;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.resources.Resources;
import org.nd4j.shade.guava.io.Files;
import java.io.File;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
public class SameDiffRNNTestCases {
public static TestCase getRnnCsvSequenceClassificationTestCase1() {
return new SameDiffRNNTestCases.RnnCsvSequenceClassificationTestCase1();
}
protected static class RnnCsvSequenceClassificationTestCase1 extends TestCase {
protected RnnCsvSequenceClassificationTestCase1() {
testName = "RnnCsvSequenceClassification1";
testType = TestType.RANDOM_INIT;
testPredictions = true;
testTrainingCurves = false;
testGradients = false;
testParamsPostTraining = false;
testEvaluation = true;
testOverfitting = false; //Not much point on this one - it already fits very well...
}
protected MultiDataNormalization normalizer;
protected MultiDataNormalization getNormalizer() throws Exception {
if (normalizer != null) {
return normalizer;
}
normalizer = new MultiNormalizerStandardize();
normalizer.fit(getTrainingDataUnnormalized());
return normalizer;
}
@Override
public ModelType modelType() {
return ModelType.SAMEDIFF;
}
@Override
public Object getConfiguration() throws Exception {
Nd4j.getRandom().setSeed(12345);
int miniBatchSize = 10;
int numLabelClasses = 6;
int nIn = 60;
int numUnits = 7;
int timeSteps = 3;
SameDiff sd = SameDiff.create();
SDVariable in = sd.placeHolder("in", DataType.FLOAT, miniBatchSize, timeSteps, nIn);
SDVariable label = sd.placeHolder("label", DataType.FLOAT, miniBatchSize, numLabelClasses);
SDVariable cLast = sd.var("cLast", Nd4j.zeros(DataType.FLOAT, miniBatchSize, numUnits));
SDVariable yLast = sd.var("yLast", Nd4j.zeros(DataType.FLOAT, miniBatchSize, numUnits));
LSTMLayerConfig c = LSTMLayerConfig.builder()
.lstmdataformat(LSTMDataFormat.NTS)
.directionMode(LSTMDirectionMode.FWD)
.gateAct(LSTMActivations.SIGMOID)
.cellAct(LSTMActivations.TANH)
.outAct(LSTMActivations.TANH)
.retFullSequence(true)
.retLastC(true)
.retLastH(true)
.build();
LSTMLayerOutputs outputs = new LSTMLayerOutputs(sd.rnn.lstmLayer(
in, cLast, yLast, null,
LSTMLayerWeights.builder()
.weights(sd.var("weights", Nd4j.rand(DataType.FLOAT, nIn, 4 * numUnits)))
.rWeights(sd.var("rWeights", Nd4j.rand(DataType.FLOAT, numUnits, 4 * numUnits)))
.peepholeWeights(sd.var("inputPeepholeWeights", Nd4j.rand(DataType.FLOAT, 3 * numUnits)))
.bias(sd.var("bias", Nd4j.rand(DataType.FLOAT, 4 * numUnits)))
.build(),
c), c);
// Behaviour with default settings: 3d (time series) input with shape
// [miniBatchSize, vectorSize, timeSeriesLength] -> 2d output [miniBatchSize, vectorSize]
SDVariable layer0 = outputs.getOutput();
SDVariable layer1 = layer0.mean(1);
SDVariable w1 = sd.var("w1", Nd4j.rand(DataType.FLOAT, numUnits, numLabelClasses));
SDVariable b1 = sd.var("b1", Nd4j.rand(DataType.FLOAT, numLabelClasses));
SDVariable out = sd.nn.softmax("out", layer1.mmul(w1).add(b1));
SDVariable loss = sd.loss.logLoss("loss", label, out);
//Also set the training configuration:
sd.setTrainingConfig(TrainingConfig.builder()
.updater(new Adam(5e-2))
.l1(1e-3).l2(1e-3)
.dataSetFeatureMapping("in") //features[0] -> "in" placeholder
.dataSetLabelMapping("label") //labels[0] -> "label" placeholder
.build());
return sd;
}
@Override
public List<Map<String, INDArray>> getPredictionsTestDataSameDiff() throws Exception {
MultiDataSet mds = getTrainingData().next();
List<Map<String, INDArray>> list = new ArrayList<>();
list.add(Collections.singletonMap("in", mds.getFeatures()[0].reshape(10, 1, 60)));
//[batchsize, insize]
return list;
}
@Override
public List<String> getPredictionsNamesSameDiff() throws Exception {
return Collections.singletonList("out");
}
@Override
public MultiDataSetIterator getTrainingData() throws Exception {
MultiDataSetIterator iter = getTrainingDataUnnormalized();
MultiDataSetPreProcessor pp = multiDataSet -> {
INDArray l = multiDataSet.getLabels(0);
l = l.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(l.size(2) - 1));
multiDataSet.setLabels(0, l);
multiDataSet.setLabelsMaskArray(0, null);
};
iter.setPreProcessor(new CompositeMultiDataSetPreProcessor(getNormalizer(), pp));
return iter;
}
protected MultiDataSetIterator getTrainingDataUnnormalized() throws Exception {
int miniBatchSize = 10;
int numLabelClasses = 6;
File featuresDirTrain = Files.createTempDir();
File labelsDirTrain = Files.createTempDir();
Resources.copyDirectory("dl4j-integration-tests/data/uci_seq/train/features/", featuresDirTrain);
Resources.copyDirectory("dl4j-integration-tests/data/uci_seq/train/labels/", labelsDirTrain);
SequenceRecordReader trainFeatures = new CSVSequenceRecordReader();
trainFeatures.initialize(new NumberedFileInputSplit(featuresDirTrain.getAbsolutePath() + "/%d.csv", 0, 449));
SequenceRecordReader trainLabels = new CSVSequenceRecordReader();
trainLabels.initialize(new NumberedFileInputSplit(labelsDirTrain.getAbsolutePath() + "/%d.csv", 0, 449));
DataSetIterator trainData = new SequenceRecordReaderDataSetIterator(trainFeatures, trainLabels, miniBatchSize, numLabelClasses,
false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);
MultiDataSetIterator iter = new MultiDataSetIteratorAdapter(trainData);
return iter;
}
@Override
public IEvaluation[] getNewEvaluations() {
return new IEvaluation[]{
new Evaluation(),
new ROCMultiClass(),
new EvaluationCalibration()
};
}
@Override
public MultiDataSetIterator getEvaluationTestData() throws Exception {
int miniBatchSize = 10;
int numLabelClasses = 6;
// File featuresDirTest = new ClassPathResource("/RnnCsvSequenceClassification/uci_seq/test/features/").getFile();
// File labelsDirTest = new ClassPathResource("/RnnCsvSequenceClassification/uci_seq/test/labels/").getFile();
File featuresDirTest = Files.createTempDir();
File labelsDirTest = Files.createTempDir();
Resources.copyDirectory("dl4j-integration-tests/data/uci_seq/test/features/", featuresDirTest);
Resources.copyDirectory("dl4j-integration-tests/data/uci_seq/test/labels/", labelsDirTest);
SequenceRecordReader trainFeatures = new CSVSequenceRecordReader();
trainFeatures.initialize(new NumberedFileInputSplit(featuresDirTest.getAbsolutePath() + "/%d.csv", 0, 149));
SequenceRecordReader trainLabels = new CSVSequenceRecordReader();
trainLabels.initialize(new NumberedFileInputSplit(labelsDirTest.getAbsolutePath() + "/%d.csv", 0, 149));
DataSetIterator testData = new SequenceRecordReaderDataSetIterator(trainFeatures, trainLabels, miniBatchSize, numLabelClasses,
false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);
MultiDataSetIterator iter = new MultiDataSetIteratorAdapter(testData);
MultiDataSetPreProcessor pp = multiDataSet -> {
INDArray l = multiDataSet.getLabels(0);
l = l.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(l.size(2) - 1));
multiDataSet.setLabels(0, l);
multiDataSet.setLabelsMaskArray(0, null);
};
iter.setPreProcessor(new CompositeMultiDataSetPreProcessor(getNormalizer(), pp));
return iter;
}
/**
 * Evaluates the given SameDiff instance on the supplied data.
 *
 * @param sd          the SameDiff instance whose {@code evaluate} method is invoked
 * @param iter        the data to evaluate on
 * @param evaluations the evaluation objects passed to {@code evaluate}
 * @return the same {@code evaluations} array that was passed in
 */
@Override
public IEvaluation[] doEvaluationSameDiff(SameDiff sd, MultiDataSetIterator iter, IEvaluation[] evaluations) {
    // Name of the variable to evaluate and the index passed alongside it
    // (presumably the label array index - confirm against SameDiff.evaluate)
    final String outputName = "out";
    final int labelIdx = 0;
    sd.evaluate(iter, outputName, labelIdx, evaluations);
    return evaluations;
}
}
}

View File

@ -40,21 +40,21 @@ implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}
implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-arm64" implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-arm64"
implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-x86" implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-x86"
implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-x86_64" implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-x86_64"
implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2' implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.9-1.5.3'
implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2', classifier: "android-arm" implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.9-1.5.3', classifier: "android-arm"
implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2', classifier: "android-arm64" implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.9-1.5.3', classifier: "android-arm64"
implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2', classifier: "android-x86" implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.9-1.5.3', classifier: "android-x86"
implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2', classifier: "android-x86_64" implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.9-1.5.3', classifier: "android-x86_64"
implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2' implementation group: 'org.bytedeco', name: 'opencv', version: '4.3.0-1.5.3'
implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2', classifier: "android-arm" implementation group: 'org.bytedeco', name: 'opencv', version: '4.3.0-1.5.3', classifier: "android-arm"
implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2', classifier: "android-arm64" implementation group: 'org.bytedeco', name: 'opencv', version: '4.3.0-1.5.3', classifier: "android-arm64"
implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2', classifier: "android-x86" implementation group: 'org.bytedeco', name: 'opencv', version: '4.3.0-1.5.3', classifier: "android-x86"
implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2', classifier: "android-x86_64" implementation group: 'org.bytedeco', name: 'opencv', version: '4.3.0-1.5.3', classifier: "android-x86_64"
implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2' implementation group: 'org.bytedeco', name: 'leptonica', version: '1.79.0-1.5.3'
implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2', classifier: "android-arm" implementation group: 'org.bytedeco', name: 'leptonica', version: '1.79.0-1.5.3', classifier: "android-arm"
implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2', classifier: "android-arm64" implementation group: 'org.bytedeco', name: 'leptonica', version: '1.79.0-1.5.3', classifier: "android-arm64"
implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2', classifier: "android-x86" implementation group: 'org.bytedeco', name: 'leptonica', version: '1.79.0-1.5.3', classifier: "android-x86"
implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2', classifier: "android-x86_64" implementation group: 'org.bytedeco', name: 'leptonica', version: '1.79.0-1.5.3', classifier: "android-x86_64"
implementation 'com.google.code.gson:gson:2.8.2' implementation 'com.google.code.gson:gson:2.8.2'
annotationProcessor 'org.projectlombok:lombok:1.16.16' annotationProcessor 'org.projectlombok:lombok:1.16.16'

View File

@ -35,21 +35,21 @@ implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}
implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-arm64" implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-arm64"
implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-x86" implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-x86"
implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-x86_64" implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-x86_64"
implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2' implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.9-1.5.3'
implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2', classifier: "android-arm" implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.9-1.5.3', classifier: "android-arm"
implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2', classifier: "android-arm64" implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.9-1.5.3', classifier: "android-arm64"
implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2', classifier: "android-x86" implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.9-1.5.3', classifier: "android-x86"
implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2', classifier: "android-x86_64" implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.9-1.5.3', classifier: "android-x86_64"
implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2' implementation group: 'org.bytedeco', name: 'opencv', version: '4.3.0-1.5.3'
implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2', classifier: "android-arm" implementation group: 'org.bytedeco', name: 'opencv', version: '4.3.0-1.5.3', classifier: "android-arm"
implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2', classifier: "android-arm64" implementation group: 'org.bytedeco', name: 'opencv', version: '4.3.0-1.5.3', classifier: "android-arm64"
implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2', classifier: "android-x86" implementation group: 'org.bytedeco', name: 'opencv', version: '4.3.0-1.5.3', classifier: "android-x86"
implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2', classifier: "android-x86_64" implementation group: 'org.bytedeco', name: 'opencv', version: '4.3.0-1.5.3', classifier: "android-x86_64"
implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2' implementation group: 'org.bytedeco', name: 'leptonica', version: '1.79.0-1.5.3'
implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2', classifier: "android-arm" implementation group: 'org.bytedeco', name: 'leptonica', version: '1.79.0-1.5.3', classifier: "android-arm"
implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2', classifier: "android-arm64" implementation group: 'org.bytedeco', name: 'leptonica', version: '1.79.0-1.5.3', classifier: "android-arm64"
implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2', classifier: "android-x86" implementation group: 'org.bytedeco', name: 'leptonica', version: '1.79.0-1.5.3', classifier: "android-x86"
implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2', classifier: "android-x86_64" implementation group: 'org.bytedeco', name: 'leptonica', version: '1.79.0-1.5.3', classifier: "android-x86_64"
``` ```
Compiling these dependencies involves a large number of files, thus it is necessary to set multiDexEnabled to true in defaultConfig. Compiling these dependencies involves a large number of files, thus it is necessary to set multiDexEnabled to true in defaultConfig.

View File

@ -43,21 +43,21 @@ implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}
implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-arm64" implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-arm64"
implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-x86" implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-x86"
implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-x86_64" implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-x86_64"
implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2' implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.9-1.5.3'
implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2', classifier: "android-arm" implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.9-1.5.3', classifier: "android-arm"
implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2', classifier: "android-arm64" implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.9-1.5.3', classifier: "android-arm64"
implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2', classifier: "android-x86" implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.9-1.5.3', classifier: "android-x86"
implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2', classifier: "android-x86_64" implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.9-1.5.3', classifier: "android-x86_64"
implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2' implementation group: 'org.bytedeco', name: 'opencv', version: '4.3.0-1.5.3'
implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2', classifier: "android-arm" implementation group: 'org.bytedeco', name: 'opencv', version: '4.3.0-1.5.3', classifier: "android-arm"
implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2', classifier: "android-arm64" implementation group: 'org.bytedeco', name: 'opencv', version: '4.3.0-1.5.3', classifier: "android-arm64"
implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2', classifier: "android-x86" implementation group: 'org.bytedeco', name: 'opencv', version: '4.3.0-1.5.3', classifier: "android-x86"
implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2', classifier: "android-x86_64" implementation group: 'org.bytedeco', name: 'opencv', version: '4.3.0-1.5.3', classifier: "android-x86_64"
implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2' implementation group: 'org.bytedeco', name: 'leptonica', version: '1.79.0-1.5.3'
implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2', classifier: "android-arm" implementation group: 'org.bytedeco', name: 'leptonica', version: '1.79.0-1.5.3', classifier: "android-arm"
implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2', classifier: "android-arm64" implementation group: 'org.bytedeco', name: 'leptonica', version: '1.79.0-1.5.3', classifier: "android-arm64"
implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2', classifier: "android-x86" implementation group: 'org.bytedeco', name: 'leptonica', version: '1.79.0-1.5.3', classifier: "android-x86"
implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2', classifier: "android-x86_64" implementation group: 'org.bytedeco', name: 'leptonica', version: '1.79.0-1.5.3', classifier: "android-x86_64"
testimplementation 'junit:junit:4.12' testimplementation 'junit:junit:4.12'
``` ```

View File

@ -46,21 +46,21 @@ implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}
implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-arm64" implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-arm64"
implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-x86" implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-x86"
implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-x86_64" implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-x86_64"
implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2' implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.9-1.5.3'
implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2', classifier: "android-arm" implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.9-1.5.3', classifier: "android-arm"
implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2', classifier: "android-arm64" implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.9-1.5.3', classifier: "android-arm64"
implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2', classifier: "android-x86" implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.9-1.5.3', classifier: "android-x86"
implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2', classifier: "android-x86_64" implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.9-1.5.3', classifier: "android-x86_64"
implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2' implementation group: 'org.bytedeco', name: 'opencv', version: '4.3.0-1.5.3'
implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2', classifier: "android-arm" implementation group: 'org.bytedeco', name: 'opencv', version: '4.3.0-1.5.3', classifier: "android-arm"
implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2', classifier: "android-arm64" implementation group: 'org.bytedeco', name: 'opencv', version: '4.3.0-1.5.3', classifier: "android-arm64"
implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2', classifier: "android-x86" implementation group: 'org.bytedeco', name: 'opencv', version: '4.3.0-1.5.3', classifier: "android-x86"
implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2', classifier: "android-x86_64" implementation group: 'org.bytedeco', name: 'opencv', version: '4.3.0-1.5.3', classifier: "android-x86_64"
implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2' implementation group: 'org.bytedeco', name: 'leptonica', version: '1.79.0-1.5.3'
implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2', classifier: "android-arm" implementation group: 'org.bytedeco', name: 'leptonica', version: '1.79.0-1.5.3', classifier: "android-arm"
implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2', classifier: "android-arm64" implementation group: 'org.bytedeco', name: 'leptonica', version: '1.79.0-1.5.3', classifier: "android-arm64"
implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2', classifier: "android-x86" implementation group: 'org.bytedeco', name: 'leptonica', version: '1.79.0-1.5.3', classifier: "android-x86"
implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2', classifier: "android-x86_64" implementation group: 'org.bytedeco', name: 'leptonica', version: '1.79.0-1.5.3', classifier: "android-x86_64"
``` ```

View File

@ -62,7 +62,7 @@ Alternatively, in the case of CUDA 10.2, cuDNN comes bundled with the "redist" p
<dependency> <dependency>
<groupId>org.bytedeco</groupId> <groupId>org.bytedeco</groupId>
<artifactId>cuda-platform-redist</artifactId> <artifactId>cuda-platform-redist</artifactId>
<version>10.2-7.6-1.5.2</version> <version>10.2-7.6-1.5.3</version>
</dependency> </dependency>
Also note that, by default, Deeplearning4j will use the fastest algorithms available according to cuDNN, but memory usage may be excessive, causing strange launch errors. When this happens, try to reduce memory usage by using the [`NO_WORKSPACE` mode settable via the network configuration](/api/{{page.version}}/org/deeplearning4j/nn/conf/layers/ConvolutionLayer.Builder.html#cudnnAlgoMode-org.deeplearning4j.nn.conf.layers.ConvolutionLayer.AlgoMode-), instead of the default of `ConvolutionLayer.AlgoMode.PREFER_FASTEST`, for example: Also note that, by default, Deeplearning4j will use the fastest algorithms available according to cuDNN, but memory usage may be excessive, causing strange launch errors. When this happens, try to reduce memory usage by using the [`NO_WORKSPACE` mode settable via the network configuration](/api/{{page.version}}/org/deeplearning4j/nn/conf/layers/ConvolutionLayer.Builder.html#cudnnAlgoMode-org.deeplearning4j.nn.conf.layers.ConvolutionLayer.AlgoMode-), instead of the default of `ConvolutionLayer.AlgoMode.PREFER_FASTEST`, for example:

View File

@ -5,7 +5,7 @@ project(mkldnn-download NONE)
include(ExternalProject) include(ExternalProject)
ExternalProject_Add(mkldnn ExternalProject_Add(mkldnn
GIT_REPOSITORY https://github.com/intel/mkl-dnn.git GIT_REPOSITORY https://github.com/intel/mkl-dnn.git
GIT_TAG v1.2.2 GIT_TAG v1.3
SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/mkldnn-src" SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/mkldnn-src"
BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/mkldnn-build" BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/mkldnn-build"
CONFIGURE_COMMAND "" CONFIGURE_COMMAND ""

View File

@ -403,7 +403,6 @@ NDArray::NDArray(const std::u32string& u32string, sd::DataType dtype, sd::Launch
///////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////
// u8 string constructors // u8 string constructors
/////////////////////////////////////////////////////////////////////////
NDArray::NDArray(const std::string& str, sd::DataType dtype, sd::LaunchContext* context) { NDArray::NDArray(const std::string& str, sd::DataType dtype, sd::LaunchContext* context) {
if (!DataTypeUtils::isS(dtype)) { if (!DataTypeUtils::isS(dtype)) {
@ -1944,7 +1943,7 @@ void NDArray::tilei(const std::vector<Nd4jLong>& reps) {
Nd4jLong NDArray::sizeAt(const int dim) const { Nd4jLong NDArray::sizeAt(const int dim) const {
if (dim >= this->rankOf() || dim < -this->rankOf()) if (dim >= this->rankOf() || dim < -this->rankOf())
throw std::runtime_error("Bad size index requested"); throw std::runtime_error("NDArray::sizeAt: bad size index requested");
if (dim >= 0) if (dim >= 0)
return shape::shapeOf(_shapeInfo)[dim]; return shape::shapeOf(_shapeInfo)[dim];

View File

@ -10,7 +10,7 @@ void NDArray::applyTriplewiseLambda(NDArray& second, NDArray& third, const std::
throw std::runtime_error("NDArray::applyTriplewiseLambda<T> method: bother four arrays (this, second, third, target) should have the same type !"); throw std::runtime_error("NDArray::applyTriplewiseLambda<T> method: bother four arrays (this, second, third, target) should have the same type !");
if (this->lengthOf() != second.lengthOf() || this->lengthOf() != third.lengthOf() || !this->isSameShape(second) || !this->isSameShape(third)) { if (this->lengthOf() != second.lengthOf() || this->lengthOf() != third.lengthOf() || !this->isSameShape(second) || !this->isSameShape(third)) {
nd4j_printf("applyPairwiseLambda requires both operands to have the same shape\n",""); nd4j_printf("applyTriplewiseLambda requires all operands to have the same shape\n","");
throw std::runtime_error("Shapes mismach"); throw std::runtime_error("Shapes mismach");
} }

View File

@ -23,6 +23,7 @@
#include <array/NDArrayList.h> #include <array/NDArrayList.h>
#include <helpers/ShapeUtils.h> #include <helpers/ShapeUtils.h>
#include <ops/declarable/CustomOperations.h> #include <ops/declarable/CustomOperations.h>
#include<ops/declarable/helpers/stack.h>
namespace sd { namespace sd {
NDArrayList::NDArrayList(int height, bool expandable) { NDArrayList::NDArrayList(int height, bool expandable) {
@ -144,23 +145,36 @@ namespace sd {
NDArray* NDArrayList::stack() { NDArray* NDArrayList::stack() {
// FIXME: this is bad for perf, but ok as poc // FIXME: this is bad for perf, but ok as poc
sd::ops::stack op;
std::vector<NDArray*> inputs;
std::vector<double> targs;
std::vector<Nd4jLong> iargs({0});
std::vector<bool> bargs;
int numElements = _elements.load();
int numElements = _elements.load();
std::vector<const NDArray*> inputs(numElements);
for (int e = 0; e < numElements; e++) { for (int e = 0; e < numElements; e++) {
_chunks[e]->syncToDevice(); _chunks[e]->syncToDevice();
inputs.emplace_back(_chunks[e]); inputs[e] = _chunks[e];
} }
iargs.push_back(_axis); auto inShapeInfo = inputs[0]->getShapeInfo();
int rank = shape::rank(inShapeInfo);
NDArray* array = nullptr;
auto result = op.evaluate(inputs); if (shape::isEmpty(inShapeInfo)) {
switch (rank) {
case 0: {
if (numElements == 1) {
array = new NDArray(inputs[0]->ordering(), {0}, ArrayOptions::dataType(inShapeInfo), inputs[0]->getContext());
} else {
array = new NDArray('c', {(Nd4jLong) numElements, 0}, ArrayOptions::dataType(inShapeInfo), inputs[0]->getContext() ) ;
}
}
}
}
else{
std::vector<Nd4jLong> outShape(inShapeInfo + 1, inShapeInfo + 1 + rank);
outShape.insert(outShape.begin(), (Nd4jLong) numElements);
array = new NDArray( shape::order(inShapeInfo), outShape, ArrayOptions::dataType(inShapeInfo), inputs[0]->getContext());
}
auto array = new NDArray(result.at(0)->dup()); ops::helpers::stack(inputs[0]->getContext(), inputs, *array, 0);
return array; return array;
} }

View File

@ -47,13 +47,13 @@ class ND4J_EXPORT GradCheck {
* opBP - back propagation operation * opBP - back propagation operation
* argsHolderFF - argument holder for feed forward operation * argsHolderFF - argument holder for feed forward operation
* argsHolderBP - argument holder for back propagation operation * argsHolderBP - argument holder for back propagation operation
* whatArrsToCheck - specifies what output gradient arrays to check, for example {0, 1, 0} means that only second output gradient array will be checked, default value is empty array which means to check all arrays * whatArrsToCheck - specifies what output gradient arrays to check, for example {0, 1, 0} means that only second output gradient array will be checked, default value is empty std::vector which means to check all arrays
* IdxRange - specifies indexes range over which array elements will be checked, for example {0.2, 0.7} means range [0.2*array_length, 0.7*array_length), default value is {0., 1.} * IdxRange - specifies indexes range over which array elements will be checked, for example {0.2, 0.7} means range [0.2*array_length, 0.7*array_length), default value is {0., 1.}
* loss - type of scalar loss function, it specifies what elements values will be filled into input gradient arrays automatically, default value is SUM * loss - type of scalar loss function, it specifies what elements values will be filled into input gradient arrays automatically, default value is SUM
* outArrsFFIdx - contains indexes of ff output arrays which are independent from each other, default means all are independent
*/ */
static bool checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, const OpArgsHolder& argsHolderFF, const OpArgsHolder& argsHolderBP, static bool checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, const OpArgsHolder& argsHolderFF, const OpArgsHolder& argsHolderBP,
const std::vector<bool>& whatArrsToCheck = std::vector<bool>(), const std::vector<double>& IdxRange = {0., 1.}, const LossFunc loss = SUM); const std::vector<bool>& whatArrsToCheck = std::vector<bool>(), const std::vector<double>& IdxRange = {0., 1.}, const LossFunc loss = SUM, const std::vector<int>& outArrsFFIdx = {});
}; };

View File

@ -206,7 +206,7 @@ LoopKind::Kind LoopKind::deduceKindOfLoopTadXZ(const Nd4jLong* xShapeInfo, const
const bool tVectorOrC = shape::isCommonVector(tadShapeInfo, temp) || tOrder == 'c'; const bool tVectorOrC = shape::isCommonVector(tadShapeInfo, temp) || tOrder == 'c';
const bool zVectorOrC = shape::isCommonVector(zShapeInfo, temp) || zOrder == 'c';; const bool zVectorOrC = shape::isCommonVector(zShapeInfo, temp) || zOrder == 'c';;
if(shape::length(tadShapeInfo) * shape::length(zShapeInfo) <= Environment::getInstance()->elementwiseThreshold() && shape::rank(xShapeInfo) == 2 && xEws == 1 && xOrder == 'c' && xRank == 2 && if(shape::length(tadShapeInfo) * shape::length(zShapeInfo) <= Environment::getInstance()->elementwiseThreshold() && xEws == 1 && xOrder == 'c' && xRank == 2 &&
tEws > 1 && zEws == 1 && (allC || (tVectorOrC && zVectorOrC))) tEws > 1 && zEws == 1 && (allC || (tVectorOrC && zVectorOrC)))
return SMALLARR2DX; return SMALLARR2DX;
if(tEws == 1 && zEws == 1 && (allC || (tVectorOrC && zVectorOrC))) if(tEws == 1 && zEws == 1 && (allC || (tVectorOrC && zVectorOrC)))

View File

@ -60,7 +60,7 @@ namespace sd {
static sd::NDArray* tensorDot(const sd::NDArray* a, const sd::NDArray* b, const std::vector<std::vector<Nd4jLong>>& modifA, const std::vector<std::vector<Nd4jLong>>& modifB); static sd::NDArray* tensorDot(const sd::NDArray* a, const sd::NDArray* b, const std::vector<std::vector<Nd4jLong>>& modifA, const std::vector<std::vector<Nd4jLong>>& modifB);
#endif #endif
static void matmul(const sd::NDArray* x, const sd::NDArray* y, sd::NDArray* z, const bool transX, const bool transY); static void matmul(const sd::NDArray* x, const sd::NDArray* y, sd::NDArray* z, const bool transX, const bool transY, double alpha = 1.0, double beta = 0.0);
}; };
} }

View File

@ -372,16 +372,16 @@ NDArray* MmulHelper::dot(const NDArray* X, const NDArray* Y, sd::NDArray* Z, con
int xLenDim(0), yLenDim(0); int xLenDim(0), yLenDim(0);
if(!shape::isCommonVector(X->getShapeInfo(), xLenDim)) if(!shape::isCommonVector(X->getShapeInfo(), xLenDim))
throw std::runtime_error("MmulHelper::dot cuda: X array must be vector !"); throw std::runtime_error("MmulHelper::dot: X array must be vector !");
if(!shape::isCommonVector(Y->getShapeInfo(), yLenDim)) if(!shape::isCommonVector(Y->getShapeInfo(), yLenDim))
throw std::runtime_error("MmulHelper::dot cuda: Y array must be vector !"); throw std::runtime_error("MmulHelper::dot: Y array must be vector !");
if(Z != nullptr && !Z->isScalar()) if(Z != nullptr && !Z->isScalar())
throw std::runtime_error("MmulHelper::dot cuda: Z array must be scalar !"); throw std::runtime_error("MmulHelper::dot: Z array must be scalar !");
const auto length = X->lengthOf(); const auto length = X->lengthOf();
if(Y->lengthOf() != length) if(Y->lengthOf() != length)
throw std::runtime_error("MmulHelper::dot cuda: lengths of input vectors are different !"); throw std::runtime_error("MmulHelper::dot: lengths of input vectors are different !");
if(Z == nullptr) if(Z == nullptr)
Z = new NDArray(DataTypeUtils::pickPairwiseResultType(X->dataType(), Y->dataType()), X->getContext()); Z = new NDArray(DataTypeUtils::pickPairwiseResultType(X->dataType(), Y->dataType()), X->getContext());

View File

@ -49,7 +49,7 @@ void GradCheck::fillGradArrays(const LossFunc loss, const std::vector<NDArray*>&
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, const OpArgsHolder& argsHolderFF, const OpArgsHolder& argsHolderBP, bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, const OpArgsHolder& argsHolderFF, const OpArgsHolder& argsHolderBP,
const std::vector<bool>& whatArrsToCheck, const std::vector<double>& idxRange, const LossFunc loss ) { const std::vector<bool>& whatArrsToCheck, const std::vector<double>& idxRange, const LossFunc loss, const std::vector<int>& outArrsFFIdx) {
const int numInArrsFF = argsHolderFF.getNumInArrs(); // at the same time numInArrsFF = number of output arrays in opBP const int numInArrsFF = argsHolderFF.getNumInArrs(); // at the same time numInArrsFF = number of output arrays in opBP
const int numInGradArrsBP = argsHolderBP.getNumInArrs() - numInArrsFF; // because argsHolderBP.getNumInArrs() = numInArrsFF + numInGradArrsBP const int numInGradArrsBP = argsHolderBP.getNumInArrs() - numInArrsFF; // because argsHolderBP.getNumInArrs() = numInArrsFF + numInGradArrsBP
@ -82,6 +82,16 @@ bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, cons
int numOutArrs = outArrsFF.size(); int numOutArrs = outArrsFF.size();
double scorePlus = 0.; double scorePlus = 0.;
if(!outArrsFFIdx.empty()) {
for(const auto& k : outArrsFFIdx) { // loop through independent output arrays
if(loss == SUM)
outArrsFF.at(k)->reduceNumber(reduce::Sum, tmpScalar);
else
outArrsFF.at(k)->reduceNumber(reduce::Mean, tmpScalar);
scorePlus += tmpScalar.e<double>(0);
}
}
else {
for(int k = 0; k < numOutArrs; ++k) { // loop through output arrays for(int k = 0; k < numOutArrs; ++k) { // loop through output arrays
if(loss == SUM) if(loss == SUM)
outArrsFF.at(k)->reduceNumber(reduce::Sum, tmpScalar); outArrsFF.at(k)->reduceNumber(reduce::Sum, tmpScalar);
@ -89,12 +99,23 @@ bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, cons
outArrsFF.at(k)->reduceNumber(reduce::Mean, tmpScalar); outArrsFF.at(k)->reduceNumber(reduce::Mean, tmpScalar);
scorePlus += tmpScalar.e<double>(0); scorePlus += tmpScalar.e<double>(0);
} }
}
// subtract epsilon, feed forward // subtract epsilon, feed forward
inArrsFF[i]->p<double>(j, orig - EPSILON); inArrsFF[i]->p<double>(j, orig - EPSILON);
outArrsFF = opFF.execute(argsHolderFF); outArrsFF = opFF.execute(argsHolderFF);
double scoreMinus = 0.; double scoreMinus = 0.;
if(!outArrsFFIdx.empty()) {
for(const auto& k : outArrsFFIdx) { // loop through independent output arrays
if(loss == SUM)
outArrsFF.at(k)->reduceNumber(reduce::Sum, tmpScalar);
else
outArrsFF.at(k)->reduceNumber(reduce::Mean, tmpScalar);
scoreMinus += tmpScalar.e<double>(0);
}
}
else {
for(int k = 0; k < numOutArrs; ++k) { // loop through output arrays for(int k = 0; k < numOutArrs; ++k) { // loop through output arrays
if(loss == SUM) if(loss == SUM)
outArrsFF.at(k)->reduceNumber(reduce::Sum, tmpScalar); outArrsFF.at(k)->reduceNumber(reduce::Sum, tmpScalar);
@ -102,6 +123,7 @@ bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, cons
outArrsFF.at(k)->reduceNumber(reduce::Mean, tmpScalar); outArrsFF.at(k)->reduceNumber(reduce::Mean, tmpScalar);
scoreMinus += tmpScalar.e<double>(0); scoreMinus += tmpScalar.e<double>(0);
} }
}
// restore initial element value // restore initial element value
inArrsFF[i]->p<double>(j, orig); inArrsFF[i]->p<double>(j, orig);
@ -120,7 +142,7 @@ bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, cons
throw std::runtime_error(""); throw std::runtime_error("");
} }
// printf("num = %.5f, ana = %.5f\n", numericalGrad, analyticGrad); // printf("%lld: num = %.15f, ana = %.15f\n", j, numericalGrad, analyticGrad);
// calculate relative error // calculate relative error
double relError; double relError;
@ -134,7 +156,7 @@ bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, cons
if(math::nd4j_abs<double>(analyticGrad - numericalGrad) < MINABSERR) if(math::nd4j_abs<double>(analyticGrad - numericalGrad) < MINABSERR)
continue; continue;
printf("numericalGrad = %f, analyticGrad = %f \n", numericalGrad, analyticGrad); printf("numericalGrad = %.15f, analyticGrad = %.15f \n", numericalGrad, analyticGrad);
printf("GradCheck::checkGrad: got RELERROR = %f > MAXRELERROR(%f) for input array # %i and its element at position %lld ! \n", relError, MAXRELERR, i, j); printf("GradCheck::checkGrad: got RELERROR = %f > MAXRELERROR(%f) for input array # %i and its element at position %lld ! \n", relError, MAXRELERR, i, j);
return false; return false;
} }

View File

@ -239,7 +239,7 @@ sd::NDArray* MmulHelper::mmul(const sd::NDArray* A, const sd::NDArray* B, sd::ND
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
void MmulHelper::matmul(const sd::NDArray* x, const sd::NDArray* y, sd::NDArray* z, const bool transX, const bool transY) { void MmulHelper::matmul(const sd::NDArray* x, const sd::NDArray* y, sd::NDArray* z, const bool transX, const bool transY, double alpha, double beta) {
int xRank = x->rankOf(); int xRank = x->rankOf();
int yRank = y->rankOf(); int yRank = y->rankOf();
@ -276,7 +276,7 @@ sd::NDArray* MmulHelper::mmul(const sd::NDArray* A, const sd::NDArray* B, sd::ND
zT = new NDArray(z->reshape(z->ordering(), {1, z->lengthOf()})); zT = new NDArray(z->reshape(z->ordering(), {1, z->lengthOf()}));
} }
mmul(xT, yT, zT, 1., 0.); mmul(xT, yT, zT, alpha, beta);
} }
else { // rest cases - batched mmul else { // rest cases - batched mmul
@ -292,7 +292,7 @@ sd::NDArray* MmulHelper::mmul(const sd::NDArray* A, const sd::NDArray* B, sd::ND
auto xSubArr = (*xT)(i, dimsToExclude); auto xSubArr = (*xT)(i, dimsToExclude);
auto ySubArr = (*yT)(i, dimsToExclude); auto ySubArr = (*yT)(i, dimsToExclude);
auto zSubArr = (*zT)(i, dimsToExclude); auto zSubArr = (*zT)(i, dimsToExclude);
mmul(&xSubArr, &ySubArr, &zSubArr, 1., 0.); mmul(&xSubArr, &ySubArr, &zSubArr, alpha, beta);
} }
} }

View File

@ -253,7 +253,8 @@
(45, ReversePow), \ (45, ReversePow), \
(46, DivideNoNan), \ (46, DivideNoNan), \
(47, IGamma), \ (47, IGamma), \
(48, IGammac) (48, IGammac), \
(49, RELUDerivative)

View File

@ -0,0 +1,43 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// Created by raver on 6/6/2018.
//
#ifndef SD_BROADCASTABLEBOOLOP_H
#define SD_BROADCASTABLEBOOLOP_H
#include <graph/Context.h>
#include "OpDescriptor.h"
#include "DeclarableOp.h"
#include "DeclarableCustomOp.h"
namespace sd {
namespace ops {
/**
 * Abstract declarable op built on top of DeclarableCustomOp.
 *
 * The class cannot be instantiated directly: validateAndExecute() is pure
 * virtual, so every concrete op has to supply its own execution body.
 * calculateOutputShape() is overridden here; its implementation lives outside
 * this header.
 *
 * NOTE(review): judging by the name, this is the base for broadcastable ops
 * that produce boolean output (comparisons and the like) - confirm against
 * the implementation file.
 */
class ND4J_EXPORT BroadcastableBoolOp : public DeclarableCustomOp{
protected:
// Execution entry point; pure virtual, must be implemented by each derived op.
Nd4jStatus validateAndExecute(Context& block) override = 0;
public:
// name      - op name
// numTArgs  - presumably the number of floating-point (T) arguments - TODO confirm
// numIArgs  - presumably the number of integer (I) arguments - TODO confirm
BroadcastableBoolOp(const char *name, int numTArgs, int numIArgs);
// Overrides the base-class shape function: takes the list of input shapes and
// the op context, and returns the list of output shapes.
ShapeList *calculateOutputShape(ShapeList *inputShape, sd::graph::Context& block) override;
};
}
}
#endif //SD_BROADCASTABLEBOOLOP_H

View File

@ -36,10 +36,14 @@ CUSTOM_OP_IMPL(matmul, 2, 1, false, 0, -2) {
auto y = INPUT_VARIABLE(1); auto y = INPUT_VARIABLE(1);
auto z = OUTPUT_VARIABLE(0); auto z = OUTPUT_VARIABLE(0);
const int iSize = (int) block.getIArguments()->size(); int iSize = (int) block.getIArguments()->size();
int transX = iSize > 0 ? INT_ARG(0) : 0; int transX = iSize > 0 ? INT_ARG(0) : 0;
int transY = iSize > 1 ? INT_ARG(1) : 0; int transY = iSize > 1 ? INT_ARG(1) : 0;
const int transZ = iSize > 2 ? INT_ARG(2) : 0; const int transZ = iSize > 2 ? INT_ARG(2) : 0;
// optionally use alpha and beta
iSize = (int)block.getTArguments()->size();
double alpha = iSize > 0 ? T_ARG(0) : 1.0;
double beta = iSize > 1 ? T_ARG(1) : 0.0;
const int xRank = x->rankOf(); const int xRank = x->rankOf();
const int yRank = y->rankOf(); const int yRank = y->rankOf();
@ -77,7 +81,7 @@ CUSTOM_OP_IMPL(matmul, 2, 1, false, 0, -2) {
} }
// ******* end of input validation ******* // // ******* end of input validation ******* //
MmulHelper::matmul(x, y, z, transX, transY); MmulHelper::matmul(x, y, z, transX, transY, alpha, beta);
return Status::OK(); return Status::OK();
} }
@ -147,11 +151,17 @@ CUSTOM_OP_IMPL(matmul_bp, 3, 2, false, 0, -2) {
auto dldx = OUTPUT_VARIABLE(0); auto dldx = OUTPUT_VARIABLE(0);
auto dldy = OUTPUT_VARIABLE(1); auto dldy = OUTPUT_VARIABLE(1);
const int iSize = (int) block.getIArguments()->size(); int iSize = (int) block.getIArguments()->size();
int transX = iSize > 0 ? INT_ARG(0) : 0; int transX = iSize > 0 ? INT_ARG(0) : 0;
int transY = iSize > 1 ? INT_ARG(1) : 0; int transY = iSize > 1 ? INT_ARG(1) : 0;
const int transZ = iSize > 2 ? INT_ARG(2) : 0; const int transZ = iSize > 2 ? INT_ARG(2) : 0;
// optionally use alpha and beta
iSize = (int)block.getTArguments()->size();
double alpha = iSize > 0 ? T_ARG(0) : 1.0;
double beta = iSize > 1 ? T_ARG(1) : 0.0;
/* /*
In: x=[a,b], y=[b,c] In: x=[a,b], y=[b,c]
tX tY tZ x y z dz dLdx dLdy tX tY tZ x y z dz dLdx dLdy
@ -164,8 +174,8 @@ F F T [a,b] [b,c] [c,a] [c,a]
sd::ops::matmul op; sd::ops::matmul op;
op.execute({eps, y}, {dldx}, {}, {transZ, !transY, transX}, {}); op.execute({eps, y}, {dldx}, {alpha, beta}, {transZ, !transY, transX}, {});
op.execute({x, eps}, {dldy}, {}, {!transX, transZ, transY}, {}); op.execute({x, eps}, {dldy}, {alpha, beta}, {!transX, transZ, transY}, {});
return Status::OK(); return Status::OK();
} }

View File

@ -23,7 +23,7 @@
namespace sd { namespace sd {
namespace ops { namespace ops {
BROADCASTABLE_OP_IMPL(equals, 0, 0) { BROADCASTABLE_BOOL_OP_IMPL(equals, 0, 0) {
auto x = INPUT_VARIABLE(0); auto x = INPUT_VARIABLE(0);
auto y = INPUT_VARIABLE(1); auto y = INPUT_VARIABLE(1);
auto z = OUTPUT_VARIABLE(0); auto z = OUTPUT_VARIABLE(0);

View File

@ -66,11 +66,11 @@ namespace sd {
auto gradY = OUTPUT_VARIABLE(1); auto gradY = OUTPUT_VARIABLE(1);
gradX->assign(epsNext); gradX->assign(epsNext);
sd::ops::floormod op; NDArray temp(*epsNext);
auto tmpResult(op.evaluate({x, y})); BroadcastHelper::broadcastApply(BROADCAST(FloorMod), x, y, &temp);
if (gradY->rankOf() == gradX->rankOf()) if (gradY->rankOf() == gradX->rankOf())
epsNext->applyPairwiseTransform(pairwise::Multiply, *tmpResult.at(0), *gradY); epsNext->applyPairwiseTransform(pairwise::Multiply, temp, *gradY);
else // epsNext is greater than gradY else // epsNext is greater than gradY
{ {
std::vector<Nd4jLong> dims(epsNext->rankOf() * 2); std::vector<Nd4jLong> dims(epsNext->rankOf() * 2);
@ -78,7 +78,7 @@ namespace sd {
for (Nd4jLong d = 0; d < gap; d++) { for (Nd4jLong d = 0; d < gap; d++) {
dims[d * 2 + 1] = 1; dims[d * 2 + 1] = 1;
} }
auto tempIn((*tmpResult.at(0))(dims)); auto tempIn((temp)(dims));
(*epsNext)(dims).applyPairwiseTransform(pairwise::Multiply, tempIn, *gradY); (*epsNext)(dims).applyPairwiseTransform(pairwise::Multiply, tempIn, *gradY);
} }
return Status::OK(); return Status::OK();

View File

@ -23,7 +23,7 @@
namespace sd { namespace sd {
namespace ops { namespace ops {
BROADCASTABLE_OP_IMPL(greater, 0, 0) { BROADCASTABLE_BOOL_OP_IMPL(greater, 0, 0) {
auto x = INPUT_VARIABLE(0); auto x = INPUT_VARIABLE(0);
auto y = INPUT_VARIABLE(1); auto y = INPUT_VARIABLE(1);
auto z = OUTPUT_VARIABLE(0); auto z = OUTPUT_VARIABLE(0);

View File

@ -22,7 +22,7 @@
namespace sd { namespace sd {
namespace ops { namespace ops {
BROADCASTABLE_OP_IMPL(greater_equal, 0, 0) { BROADCASTABLE_BOOL_OP_IMPL(greater_equal, 0, 0) {
auto x = INPUT_VARIABLE(0); auto x = INPUT_VARIABLE(0);
auto y = INPUT_VARIABLE(1); auto y = INPUT_VARIABLE(1);
auto z = OUTPUT_VARIABLE(0); auto z = OUTPUT_VARIABLE(0);

View File

@ -22,7 +22,7 @@
namespace sd { namespace sd {
namespace ops { namespace ops {
BROADCASTABLE_OP_IMPL(less, 0, 0) { BROADCASTABLE_BOOL_OP_IMPL(less, 0, 0) {
auto x = INPUT_VARIABLE(0); auto x = INPUT_VARIABLE(0);
auto y = INPUT_VARIABLE(1); auto y = INPUT_VARIABLE(1);
auto z = OUTPUT_VARIABLE(0); auto z = OUTPUT_VARIABLE(0);

View File

@ -22,7 +22,7 @@
namespace sd { namespace sd {
namespace ops { namespace ops {
BROADCASTABLE_OP_IMPL(less_equal, 0, 0) { BROADCASTABLE_BOOL_OP_IMPL(less_equal, 0, 0) {
auto x = INPUT_VARIABLE(0); auto x = INPUT_VARIABLE(0);
auto y = INPUT_VARIABLE(1); auto y = INPUT_VARIABLE(1);
auto z = OUTPUT_VARIABLE(0); auto z = OUTPUT_VARIABLE(0);

View File

@ -22,7 +22,7 @@
namespace sd { namespace sd {
namespace ops { namespace ops {
BROADCASTABLE_OP_IMPL(not_equals, 0, 0) { BROADCASTABLE_BOOL_OP_IMPL(not_equals, 0, 0) {
auto x = INPUT_VARIABLE(0); auto x = INPUT_VARIABLE(0);
auto y = INPUT_VARIABLE(1); auto y = INPUT_VARIABLE(1);
auto z = OUTPUT_VARIABLE(0); auto z = OUTPUT_VARIABLE(0);

View File

@ -80,7 +80,7 @@ namespace sd {
} }
// now once we have all strings in single vector time to fill // now once we have all strings in single vector time to fill
auto tmp = NDArrayFactory::string({(Nd4jLong) strings.size()}, strings); auto tmp = NDArrayFactory::string({ (Nd4jLong)strings.size() }, strings, input->dataType(), block.launchContext());
auto blen = StringUtils::byteLength(tmp) + ShapeUtils::stringBufferHeaderRequirements(strings.size()); auto blen = StringUtils::byteLength(tmp) + ShapeUtils::stringBufferHeaderRequirements(strings.size());
// for CUDA mostly // for CUDA mostly

View File

@ -68,6 +68,7 @@ namespace sd {
} }
case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array
NDArray sum; NDArray sum;
sum.setContext(block.launchContext());
if (weights->isScalar()) if (weights->isScalar())
sum = *weights * E.lengthOf(); sum = *weights * E.lengthOf();
else else
@ -201,6 +202,7 @@ namespace sd {
case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array
NDArray sum; NDArray sum;
sum.setContext(block.launchContext());
if (weights->isScalar()) if (weights->isScalar())
sum = (*weights) * E.lengthOf(); sum = (*weights) * E.lengthOf();
else else

View File

@ -73,6 +73,7 @@ CUSTOM_OP_IMPL(huber_loss, 3, 1, false, 1, 1) {
} }
case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array
NDArray sum; NDArray sum;
sum.setContext(block.launchContext());
if (weights->isScalar()) if (weights->isScalar())
sum = *weights * E.lengthOf(); sum = *weights * E.lengthOf();
else else
@ -216,6 +217,7 @@ DECLARE_SHAPE_FN(huber_loss) {
case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array
NDArray sum; NDArray sum;
sum.setContext(block.launchContext());
if (weights->isScalar()) if (weights->isScalar())
sum = (*weights) * E.lengthOf(); sum = (*weights) * E.lengthOf();
else else

View File

@ -70,6 +70,7 @@ CUSTOM_OP_IMPL(log_loss, 3, 1, false, 1, 1) {
} }
case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array
NDArray sum; NDArray sum;
sum.setContext(block.launchContext());
if (weights->isScalar()) if (weights->isScalar())
sum = *weights * E.lengthOf(); sum = *weights * E.lengthOf();
else else
@ -206,6 +207,7 @@ CUSTOM_OP_IMPL(log_loss_grad, 3, 3, false, 1, 1) {
case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array
NDArray sum; NDArray sum;
sum.setContext(block.launchContext());
if (weights->isScalar()) if (weights->isScalar())
sum = (*weights) * E.lengthOf(); sum = (*weights) * E.lengthOf();
else else

View File

@ -74,6 +74,7 @@ namespace ops {
} }
case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array
NDArray sum; NDArray sum;
sum.setContext(block.launchContext());
if (weights->isScalar()) if (weights->isScalar())
sum = *weights * E.lengthOf(); sum = *weights * E.lengthOf();
else else
@ -209,6 +210,7 @@ namespace ops {
case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array
NDArray sum; NDArray sum;
sum.setContext(block.launchContext());
if (weights->isScalar()) if (weights->isScalar())
sum = (*weights) * E.lengthOf(); sum = (*weights) * E.lengthOf();
else else

View File

@ -143,6 +143,7 @@ namespace sd {
} }
case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array
NDArray sum; NDArray sum;
sum.setContext(block.launchContext());
if (weights->isScalar()) if (weights->isScalar())
sum = (*weights) * E.lengthOf(); sum = (*weights) * E.lengthOf();
else else
@ -282,6 +283,7 @@ namespace sd {
case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array
NDArray sum; NDArray sum;
sum.setContext(block.launchContext());
if (weights->isScalar()) if (weights->isScalar())
sum = (*weights) * E.lengthOf(); sum = (*weights) * E.lengthOf();
else else

View File

@ -67,6 +67,7 @@ CUSTOM_OP_IMPL(mean_sqerr_loss, 3, 1, false, 0, 1) {
} }
case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array
NDArray sum; NDArray sum;
sum.setContext(block.launchContext());
if (weights->isScalar()) if (weights->isScalar())
sum = (*weights) * E.lengthOf(); sum = (*weights) * E.lengthOf();
else else
@ -200,6 +201,7 @@ CUSTOM_OP_IMPL(mean_sqerr_loss_grad, 3, 3, false, 0, 1) {
case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array
NDArray sum; NDArray sum;
sum.setContext(block.launchContext());
if (weights->isScalar()) if (weights->isScalar())
sum = (*weights) * E.lengthOf(); sum = (*weights) * E.lengthOf();
else else

View File

@ -78,6 +78,7 @@ CUSTOM_OP_IMPL(sigm_cross_entropy_loss, 3, 1, false, 1, 1) {
} }
case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array
NDArray sum; NDArray sum;
sum.setContext(block.launchContext());
if (weights->isScalar()) if (weights->isScalar())
sum = (*weights) * E.lengthOf(); sum = (*weights) * E.lengthOf();
else else
@ -219,6 +220,7 @@ CUSTOM_OP_IMPL(sigm_cross_entropy_loss_grad, 3, 3, false, 1, 1) {
} }
case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array
NDArray sum; NDArray sum;
sum.setContext(block.launchContext());
if (weights->isScalar()) if (weights->isScalar())
sum = (*weights) * E.lengthOf(); sum = (*weights) * E.lengthOf();
else else

View File

@ -24,10 +24,10 @@
#include <ops/declarable/CustomOperations.h> #include <ops/declarable/CustomOperations.h>
#include<ops/declarable/helpers/lstmLayer.h> #include<ops/declarable/helpers/lstmLayer.h>
namespace sd { namespace sd {
namespace ops { namespace ops {
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
CUSTOM_OP_IMPL(lstmLayer, 3, 1, false, 1, 5) { CUSTOM_OP_IMPL(lstmLayer, 3, 1, false, 1, 5) {
@ -43,7 +43,7 @@ CUSTOM_OP_IMPL(lstmLayer, 3, 1, false, 1, 5) {
// it = σ(Wxi * xt + Wri * ht-1 + Wpi ◦ ct-1 + bi) // it = σ(Wxi * xt + Wri * ht-1 + Wpi ◦ ct-1 + bi)
// ft = σ(Wxf * xt + Wrf * ht-1 + Wpf ◦ ct-1 + bf) // ft = σ(Wxf * xt + Wrf * ht-1 + Wpf ◦ ct-1 + bf)
// c't = tanh(Wxc * xt + Wrc * ht-1 + bc) // c't = tanh(Wxc * xt + Wrc * ht-1 + bc)
// ct = ft ◦ ct-1 + it ◦ c't // ct = clip(ft ◦ ct-1 + it ◦ c't)
// ot = σ(Wxo * xt + Wro * ht-1 + Wpo ◦ ct + bo) // ot = σ(Wxo * xt + Wro * ht-1 + Wpo ◦ ct + bo)
// ht = ot ◦ tanh(ct) // ht = ot ◦ tanh(ct)
@ -72,26 +72,26 @@ CUSTOM_OP_IMPL(lstmLayer, 3, 1, false, 1, 5) {
// 2) [2, nOut, 4*nOut] when directionMode >= 2 // 2) [2, nOut, 4*nOut] when directionMode >= 2
// ******* // *******
// peephole weights Wp: // peephole weights Wp, optional:
// 1) [3*nOut] when directionMode < 2 // 1) [3*nOut] when directionMode < 2
// 2) [2, 3*nOut] when directionMode >= 2 // 2) [2, 3*nOut] when directionMode >= 2
// ******* // *******
// biases b: // biases b, optional:
// 1) [4*nOut] when directionMode < 2 // 1) [4*nOut] when directionMode < 2
// 2) [2, 4*nOut] when directionMode >= 2 // 2) [2, 4*nOut] when directionMode >= 2
// ******* // *******
// sequence length array seqLen: // sequence length array seqLen, optional:
// 1) [bS] always // 1) [bS]
// ******* // *******
// initial output hI: // initial output hI, optional:
// 1) [bS, nOut] when directionMode < 2 // 1) [bS, nOut] when directionMode < 2
// 2) [2, bS, nOut] when directionMode >= 2 // 2) [2, bS, nOut] when directionMode >= 2
// ******* // *******
// initial cell state cI (same shape as in hI): // initial cell state cI (same shape as in hI), optional:
// 1) [bS, nOut] when directionMode < 2 // 1) [bS, nOut] when directionMode < 2
// 2) [2, bS, nOut] when directionMode >= 2 // 2) [2, bS, nOut] when directionMode >= 2
@ -99,7 +99,7 @@ CUSTOM_OP_IMPL(lstmLayer, 3, 1, false, 1, 5) {
// OUTPUTS: // OUTPUTS:
// ******* // *******
// output h: // output h, optional:
// 1) [sL, bS, nOut] when directionMode <= 2 && dataFormat == 0 // 1) [sL, bS, nOut] when directionMode <= 2 && dataFormat == 0
// 2) [bS, sL, nOut] when directionMode <= 2 && dataFormat == 1 // 2) [bS, sL, nOut] when directionMode <= 2 && dataFormat == 1
// 3) [bS, nOut, sL] when directionMode <= 2 && dataFormat == 2 // 3) [bS, nOut, sL] when directionMode <= 2 && dataFormat == 2
@ -109,19 +109,19 @@ CUSTOM_OP_IMPL(lstmLayer, 3, 1, false, 1, 5) {
// 7) [sL, 2, bS, nOut] when directionMode == 4 && dataFormat == 3 // 7) [sL, 2, bS, nOut] when directionMode == 4 && dataFormat == 3
// ******* // *******
// output at last step hL: // output at last step hL, optional:
// 1) [bS, nOut] when directionMode < 2 // 1) [bS, nOut] when directionMode < 2
// 2) [2, bS, nOut] when directionMode >= 2 // 2) [2, bS, nOut] when directionMode >= 2
// ******* // *******
// cell state at last step cL (same shape as in hL): // cell state at last step cL (same shape as in hL), optional:
// 1) [bS, nOut] when directionMode < 2 // 1) [bS, nOut] when directionMode < 2
// 2) [2, bS, nOut] when directionMode >= 2 // 2) [2, bS, nOut] when directionMode >= 2
// !!! dimension 4*nOut implies order it, ft, c't, ot // !!! dimension 4*nOut implies order it, ft, c't, ot
// !!! dimension 3*nOut implies order it, ft, ot // !!! dimension 3*nOut implies order it, ft, ot
const auto dataFormat = INT_ARG(0); // for unidirectional: 0 = [sL, bS, nIn], 1 = [bS, sL ,nIn], 2 = [bS, nIn, sL], for bidirectional: 3 = [sL, 2, bS, nOut] (for ONNX) const auto dataFormat = INT_ARG(0); // for unidirectional: 0 = [sL, bS, nIn], 1 = [bS, sL ,nIn], 2 = [bS, nIn, sL], for bidirectional: 3 = [sL, bS, nIn] && [sL, 2, bS, nOut] (for ONNX)
const auto directionMode = INT_ARG(1); // direction: 0 = fwd, 1 = bwd, 2 = bidirectional sum, 3 = bidirectional concat, 4 = bidirectional extra output dim (in conjunction with format dataFormat = 3) const auto directionMode = INT_ARG(1); // direction: 0 = fwd, 1 = bwd, 2 = bidirectional sum, 3 = bidirectional concat, 4 = bidirectional extra output dim (in conjunction with format dataFormat = 3)
// integer numbers corresponding to activations: 0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky relu, 5= thresholded relu, 6=scaled tanh, 7=hard sigmoid, 8=ELU, 9=softsign, 10=softplus // integer numbers corresponding to activations: 0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky relu, 5= thresholded relu, 6=scaled tanh, 7=hard sigmoid, 8=ELU, 9=softsign, 10=softplus
@ -135,8 +135,8 @@ CUSTOM_OP_IMPL(lstmLayer, 3, 1, false, 1, 5) {
const auto hasInitC = B_ARG(3); // indicates whether initial cell state is provided const auto hasInitC = B_ARG(3); // indicates whether initial cell state is provided
const auto hasPH = B_ARG(4); // indicates whether peephole connections are present const auto hasPH = B_ARG(4); // indicates whether peephole connections are present
const auto retFullSeq = B_ARG(5); // indicates whether to return whole time sequence h {h_0, h_1, ... , h_sL-1} const auto retFullSeq = B_ARG(5); // indicates whether to return whole time sequence h {h_0, h_1, ... , h_sL-1}
const auto retLastH = B_ARG(6); // indicates whether to return output at last time step only, in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument) const auto retLastH = B_ARG(6); // indicates whether to return output at last time step only
const auto retLastC = B_ARG(7); // indicates whether to return cells state at last time step only, in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument) const auto retLastC = B_ARG(7); // indicates whether to return cells state at last time step only
const auto gateActHasAlpha = gateAct == 3 || gateAct == 4 || gateAct == 5 || gateAct == 6 || gateAct == 8; const auto gateActHasAlpha = gateAct == 3 || gateAct == 4 || gateAct == 5 || gateAct == 6 || gateAct == 8;
const auto cellActHasAlpha = cellAct == 3 || cellAct == 4 || cellAct == 5 || cellAct == 6 || cellAct == 8; const auto cellActHasAlpha = cellAct == 3 || cellAct == 4 || cellAct == 5 || cellAct == 6 || cellAct == 8;
@ -176,8 +176,8 @@ CUSTOM_OP_IMPL(lstmLayer, 3, 1, false, 1, 5) {
// evaluate dimensions // evaluate dimensions
const Nd4jLong sL = dataFormat == 3 ? x->sizeAt(0) : x->sizeAt(dataFormat); const Nd4jLong sL = dataFormat == 3 ? x->sizeAt(0) : x->sizeAt(dataFormat);
const Nd4jLong bS = dataFormat == 1 || dataFormat == 2 ? x->sizeAt(0) : x->sizeAt(-2); const Nd4jLong bS = dataFormat == 1 || dataFormat == 2 ? x->sizeAt(0) : x->sizeAt(1);
const Nd4jLong nIn = dataFormat == 2 ? x->sizeAt(1) : x->sizeAt(-1); const Nd4jLong nIn = dataFormat == 2 ? x->sizeAt(1) : x->sizeAt(2);
const Nd4jLong nOut = Wx->sizeAt(-1) / 4; const Nd4jLong nOut = Wx->sizeAt(-1) / 4;
// inputs validations // inputs validations
@ -323,9 +323,9 @@ DECLARE_SHAPE_FN(lstmLayer) {
const auto Wr = INPUT_VARIABLE(2); // recurrent weights const auto Wr = INPUT_VARIABLE(2); // recurrent weights
// evaluate dimensions // evaluate dimensions
const Nd4jLong sL = dataFormat == 0 || dataFormat == 3 ? x->sizeAt(0) : ( dataFormat == 1 ? x->sizeAt(1) : x->sizeAt(2) ); const Nd4jLong sL = dataFormat == 3 ? x->sizeAt(0) : x->sizeAt(dataFormat);
const Nd4jLong bS = dataFormat == 1 || dataFormat == 2 ? x->sizeAt(0) : x->sizeAt(-2); const Nd4jLong bS = dataFormat == 1 || dataFormat == 2 ? x->sizeAt(0) : x->sizeAt(1);
const Nd4jLong nIn = dataFormat == 2 ? x->sizeAt(1) : x->sizeAt(-1); const Nd4jLong nIn = dataFormat == 2 ? x->sizeAt(1) : x->sizeAt(2);
const Nd4jLong nOut = Wx->sizeAt(-1) / 4; const Nd4jLong nOut = Wx->sizeAt(-1) / 4;
DataType type; DataType type;
@ -398,6 +398,412 @@ DECLARE_SHAPE_FN(lstmLayer) {
} }
//////////////////////////////////////////////////////////////////////////
CUSTOM_OP_IMPL(lstmLayer_bp, 4, 1, false, 1, 5) {
// equations (no peephole connections)
// it = σ(Wxi * xt + Wri * ht-1 + bi)
// ft = σ(Wxf * xt + Wrf * ht-1 + bf)
// c't = tanh(Wxc * xt + Wrc * ht-1 + bc)
// ct = ft ◦ ct-1 + it ◦ c't
// ot = σ(Wxo * xt + Wro * ht-1 + bo)
// ht = ot ◦ tanh(ct)
// equations (peephole connections are present)
// it = σ(Wxi * xt + Wri * ht-1 + Wpi ◦ ct-1 + bi)
// ft = σ(Wxf * xt + Wrf * ht-1 + Wpf ◦ ct-1 + bf)
// c't = tanh(Wxc * xt + Wrc * ht-1 + bc)
// ct = clip(ft ◦ ct-1 + it ◦ c't)
// ot = σ(Wxo * xt + Wro * ht-1 + Wpo ◦ ct + bo)
// ht = ot ◦ tanh(ct)
// notations:
// bS - batch size
// sL - sequence length, number of time steps
// nIn - input size
// nOut - output size (hidden size)
// INPUTS:
// *******
// input x:
// 1) [sL, bS, nIn] when dataFormat == 0
// 2) [bS, sL, nIn] when dataFormat == 1
// 3) [bS, nIn, sL] when dataFormat == 2
// *******
// input weights Wx:
// 1) [nIn, 4*nOut] when directionMode < 2
// 2) [2, nIn, 4*nOut] when directionMode >= 2
// *******
// recurrent weights Wr:
// 1) [nOut, 4*nOut] when directionMode < 2
// 2) [2, nOut, 4*nOut] when directionMode >= 2
// *******
// peephole weights Wp, optional:
// 1) [3*nOut] when directionMode < 2
// 2) [2, 3*nOut] when directionMode >= 2
// *******
// biases b, optional:
// 1) [4*nOut] when directionMode < 2
// 2) [2, 4*nOut] when directionMode >= 2
// *******
// sequence length array seqLen, optional:
// 1) [bS]
// *******
// initial output hI, optional:
// 1) [bS, nOut] when directionMode < 2
// 2) [2, bS, nOut] when directionMode >= 2
// *******
// initial cell state cI (same shape as in hI), optional:
// 1) [bS, nOut] when directionMode < 2
// 2) [2, bS, nOut] when directionMode >= 2
// *******
// gradient vs. output dLdh, optional:
// 1) [sL, bS, nOut] when directionMode <= 2 && dataFormat == 0
// 2) [bS, sL, nOut] when directionMode <= 2 && dataFormat == 1
// 3) [bS, nOut, sL] when directionMode <= 2 && dataFormat == 2
// 4) [sL, bS, 2*nOut] when directionMode == 3 && dataFormat == 0
// 5) [bS, sL, 2*nOut] when directionMode == 3 && dataFormat == 1
// 6) [bS, 2*nOut, sL] when directionMode == 3 && dataFormat == 2
// 7) [sL, 2, bS, nOut] when directionMode == 4 && dataFormat == 3
// *******
// gradient vs output at last time step dLdhL, optional:
// 1) [bS, nOut] when directionMode < 2
// 2) [2, bS, nOut] when directionMode >= 2
// *******
// gradient vs cell state at last time step dLdcL(same shape as in dLdhL), optional:
// 1) [bS, nOut] when directionMode < 2
// 2) [2, bS, nOut] when directionMode >= 2
// OUTPUTS:
// *******
// gradient vs. input dLdx:
// 1) [sL, bS, nIn] when dataFormat == 0
// 2) [bS, sL, nIn] when dataFormat == 1
// 3) [bS, nIn, sL] when dataFormat == 2
// *******
// gradient vs. input weights dLdWx:
// 1) [nIn, 4*nOut] when directionMode < 2
// 2) [2, nIn, 4*nOut] when directionMode >= 2
// *******
// gradient vs. recurrent weights dLdWr:
// 1) [nOut, 4*nOut] when directionMode < 2
// 2) [2, nOut, 4*nOut] when directionMode >= 2
// *******
// gradient vs. peephole weights dLdWp, optional:
// 1) [3*nOut] when directionMode < 2
// 2) [2, 3*nOut] when directionMode >= 2
// *******
// gradient vs. biases dLdb, optional:
// 1) [4*nOut] when directionMode < 2
// 2) [2, 4*nOut] when directionMode >= 2
// gradient vs. sequence length array dLdsL, optional (do not calculate it!!!):
// 1) [bS] always
// *******
// gradient vs. initial output dLdhI, optional:
// 1) [bS, nOut] when directionMode < 2
// 2) [2, bS, nOut] when directionMode >= 2
// *******
// gradient vs. initial cell state dLdcI (same shape as in dLdhI), optional:
// 1) [bS, nOut] when directionMode < 2
// 2) [2, bS, nOut] when directionMode >= 2
// !!! dimension 4*nOut implies order it, ft, c't, ot
// !!! dimension 3*nOut implies order it, ft, ot
const auto dataFormat = INT_ARG(0); // for unidirectional: 0 = [sL, bS, nIn], 1 = [bS, sL ,nIn], 2 = [bS, nIn, sL], for bidirectional: 3 = [sL, bS, nIn] && [sL, 2, bS, nOut] (for ONNX)
const auto directionMode = INT_ARG(1); // direction: 0 = fwd, 1 = bwd, 2 = bidirectional sum, 3 = bidirectional concat, 4 = bidirectional extra output dim (in conjunction with format dataFormat = 3)
// integer numbers corresponding to activations: 0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky relu, 5= thresholded relu, 6=scaled tanh, 7=hard sigmoid, 8=ELU, 9=softsign, 10=softplus
const auto gateAct = INT_ARG(2); // activation for input (i), forget (f) and output (o) gates
const auto cellAct = INT_ARG(3); // activation for cell state (c)
const auto outAct = INT_ARG(4); // activation for output (h)
const auto hasBiases = B_ARG(0); // indicates whether biases array is provided
const auto hasSeqLen = B_ARG(1); // indicates whether seqLen array is provided
const auto hasInitH = B_ARG(2); // indicates whether initial output is provided
const auto hasInitC = B_ARG(3); // indicates whether initial cell state is provided
const auto hasPH = B_ARG(4); // indicates whether peephole connections are present
const auto retFullSeq = B_ARG(5); // indicates whether gradient vs. outputs is given for whole time sequence dLdh {dLdh_0, dLdh_1, ... , dLdh_sL-1}
const auto retLastH = B_ARG(6); // indicates whether gradient vs. output at last time step (dLdhL) is given
const auto retLastC = B_ARG(7); // indicates whether gradient vs. cell state at last time step (dLdcL) is given
const auto gateActHasAlpha = gateAct == 3 || gateAct == 4 || gateAct == 5 || gateAct == 6 || gateAct == 8;
const auto cellActHasAlpha = cellAct == 3 || cellAct == 4 || cellAct == 5 || cellAct == 6 || cellAct == 8;
const auto outActHasAlpha = outAct == 3 || outAct == 4 || outAct == 5 || outAct == 6 || outAct == 8;
const auto gateActHasBeta = gateAct == 3 || gateAct == 6;
const auto cellActHasBeta = cellAct == 3 || cellAct == 6;
const auto outActHasBeta = outAct == 3 || outAct == 6;
uint count = 1;
const auto cellClip = T_ARG(0); // cell clipping value, if it = 0 then do not apply clipping
const auto gateAlpha = gateActHasAlpha ? T_ARG(count++) : 0;
const auto gateBeta = gateActHasBeta ? T_ARG(count++) : 0;
const auto cellAlpha = cellActHasAlpha ? T_ARG(count++) : 0;
const auto cellBeta = cellActHasBeta ? T_ARG(count++) : 0;
const auto outAlpha = outActHasAlpha ? T_ARG(count++) : 0;
const auto outBeta = outActHasBeta ? T_ARG(count++) : 0;
REQUIRE_TRUE(dataFormat < 3 || (dataFormat == 3 && directionMode == 4), 0, "LSTM_LAYER_BP operation: if argument dataFormat = 3, then directionMode = 4, but got dataFormat = %i and directionMode = %i instead !", dataFormat, directionMode);
REQUIRE_TRUE(cellClip >= 0 , 0, "LSTM_LAYER_BP operation: cell clipping value should be nonnegative (>=0) !");
REQUIRE_TRUE(retFullSeq || retLastH || retLastC, 0, "LSTM_LAYER_BP operation: please specify at least one of three input gradient arrays: dLdh, dLdhL or dLdcL !");
const auto x = INPUT_VARIABLE(0); // input
const auto Wx = INPUT_VARIABLE(1); // input weights
const auto Wr = INPUT_VARIABLE(2); // recurrent weights
count = 3;
const auto b = hasBiases ? INPUT_VARIABLE(count++) : nullptr; // biases
const auto seqLen = hasSeqLen ? INPUT_VARIABLE(count++) : nullptr; // seqLen vector
const auto hI = hasInitH ? INPUT_VARIABLE(count++) : nullptr; // initial output
const auto cI = hasInitC ? INPUT_VARIABLE(count++) : nullptr; // initial cell state
const auto Wp = hasPH ? INPUT_VARIABLE(count++) : nullptr; // peephole weights
const auto dLdh = retFullSeq ? INPUT_VARIABLE(count++) : nullptr; // gradient vs. output
const auto dLdhL = retLastH ? INPUT_VARIABLE(count++) : nullptr; // gradient vs. output at last time step
const auto dLdcL = retLastC ? INPUT_VARIABLE(count++) : nullptr; // gradient vs. cell state at last time step
count = 3;
auto dLdx = OUTPUT_VARIABLE(0); // gradient vs. input
auto dLdWx = OUTPUT_NULLIFIED(1); // gradient vs. input weights
auto dLdWr = OUTPUT_NULLIFIED(2); // gradient vs. recurrent weights
auto dLdb = hasBiases ? OUTPUT_NULLIFIED(count++) : nullptr; // gradient vs. biases
auto dLdsL = hasSeqLen ? INPUT_VARIABLE(count++) : nullptr; // gradient vs. seqLen vector, we don't calculate it !!!
auto dLdhI = hasInitH ? OUTPUT_NULLIFIED(count++) : nullptr; // gradient vs. initial output
auto dLdcI = hasInitC ? OUTPUT_NULLIFIED(count++) : nullptr; // gradient vs. initial cell state
auto dLdWp = hasPH ? OUTPUT_NULLIFIED(count) : nullptr; // gradient vs. peephole weights
// evaluate dimensions
const Nd4jLong sL = dataFormat == 3 ? x->sizeAt(0) : x->sizeAt(dataFormat);
const Nd4jLong bS = dataFormat == 1 || dataFormat == 2 ? x->sizeAt(0) : x->sizeAt(1);
const Nd4jLong nIn = dataFormat == 2 ? x->sizeAt(1) : x->sizeAt(2);
const Nd4jLong nOut = Wx->sizeAt(-1) / 4;
// inputs validations
if(directionMode < 2) { // no bidirectional
// Wx validation
if(Wx->rankOf() != 2 || Wx->sizeAt(0) != nIn)
REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of input weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nIn, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wx).c_str());
// Wr validation
if(Wr->rankOf() != 2 || Wr->sizeAt(0) != nOut || Wr->sizeAt(1) != 4*nOut)
REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of recurrent weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nOut, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wr).c_str());
// biases validation
if(b != nullptr && (b->rankOf() != 1 || b->sizeAt(0) != 4*nOut))
REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({4*nOut}).c_str(), ShapeUtils::shapeAsString(b).c_str());
// initial output validation
if(hI != nullptr && (hI->rankOf() != 2 || hI->sizeAt(0) != bS || hI->sizeAt(1) != nOut))
REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of initial output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS, nOut}).c_str(), ShapeUtils::shapeAsString(hI).c_str());
// initial cell validation
if(cI != nullptr && (cI->rankOf() != 2 || cI->sizeAt(0) != bS || cI->sizeAt(1) != nOut))
REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS, nOut}).c_str(), ShapeUtils::shapeAsString(cI).c_str());
// peephole weights validation
if(Wp != nullptr && (Wp->rankOf() != 1 || Wp->sizeAt(0) != 3*nOut))
REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong peephole weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({3*nOut}).c_str(), ShapeUtils::shapeAsString(Wp).c_str());
// gradient vs. output at last time step validation
if(dLdhL != nullptr && (dLdhL->rankOf() != 2 || dLdhL->sizeAt(0) != bS || dLdhL->sizeAt(1) != nOut))
REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of gradient vs. output at last time step, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS, nOut}).c_str(), ShapeUtils::shapeAsString(dLdhL).c_str());
// gradient vs. cell state at last time step validation
if(dLdcL != nullptr && (dLdcL->rankOf() != 2 || dLdcL->sizeAt(0) != bS || dLdcL->sizeAt(1) != nOut))
REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of gradient vs. cell state at last time, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS, nOut}).c_str(), ShapeUtils::shapeAsString(dLdcL).c_str());
}
else { // bidirectional
// Wx validation
if(Wx->rankOf() != 3 || Wx->sizeAt(0) != 2 || Wx->sizeAt(1) != nIn)
REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of input weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, nIn, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wx).c_str());
// Wr validation
if(Wr->rankOf() != 3 || Wr->sizeAt(0) != 2 || Wr->sizeAt(1) != nOut || Wr->sizeAt(2) != 4*nOut)
REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of recurrent weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, nOut, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wr).c_str());
// biases validation
if(b != nullptr && (b->rankOf() != 2 || b->sizeAt(0) != 2 || b->sizeAt(1) != 4*nOut))
REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, 4*nOut}).c_str(), ShapeUtils::shapeAsString(b).c_str());
// initial output validation
if(hI != nullptr && (hI->rankOf() != 3 || hI->sizeAt(0) != 2 || hI->sizeAt(1) != bS || hI->sizeAt(2) != nOut))
REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of initial output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), ShapeUtils::shapeAsString(hI).c_str());
// initial cell validation
if(cI != nullptr && (cI->rankOf() != 3 || cI->sizeAt(0) != 2 || cI->sizeAt(1) != bS || cI->sizeAt(2) != nOut))
REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), ShapeUtils::shapeAsString(cI).c_str());
// peephole weights validation
if(Wp != nullptr && (Wp->rankOf() != 2 || Wp->sizeAt(0) != 2 || Wp->sizeAt(1) != 3*nOut))
REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong peephole weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, 3*nOut}).c_str(), ShapeUtils::shapeAsString(Wp).c_str());
// gradient vs. output at last time step validation
if(dLdhL != nullptr && (dLdhL->rankOf() != 3 || dLdhL->sizeAt(0) != 2 || dLdhL->sizeAt(1) != bS || dLdhL->sizeAt(2) != nOut))
REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of gradient vs. output at last time step, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), ShapeUtils::shapeAsString(dLdhL).c_str());
// gradient vs. cell state at last time step validation
if(dLdcL != nullptr && (dLdcL->rankOf() != 3 || dLdcL->sizeAt(0) != 2 || dLdcL->sizeAt(1) != bS || dLdcL->sizeAt(2) != nOut))
REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of gradient vs. cell state at last time, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), ShapeUtils::shapeAsString(dLdcL).c_str());
}
// gradient vs. output validation
if(dLdh) {
int factor = directionMode <= 2 ? 1 : 2;
std::vector<Nd4jLong> expdLdhShape;
if(dataFormat == 0) expdLdhShape = std::vector<Nd4jLong>{sL, bS, factor*nOut};
else if(dataFormat == 1) expdLdhShape = std::vector<Nd4jLong>{bS, sL, factor*nOut};
else if(dataFormat == 2) expdLdhShape = std::vector<Nd4jLong>{bS, factor*nOut, sL};
else expdLdhShape = std::vector<Nd4jLong>{sL, 2, bS, nOut};
REQUIRE_TRUE(dLdh->isSameShape(expdLdhShape), 0, "LSTM_LAYER_CELL_BP operation: wrong shape of gradient vs. output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expdLdhShape).c_str(), ShapeUtils::shapeAsString(dLdh).c_str());
}
std::vector<float> params = {static_cast<float>(dataFormat), static_cast<float>(directionMode), static_cast<float>(cellClip),
static_cast<float>(gateAct), static_cast<float>(gateAlpha), static_cast<float>(gateBeta),
static_cast<float>(cellAct), static_cast<float>(cellAlpha), static_cast<float>(cellBeta),
static_cast<float>(outAct), static_cast<float>(outAlpha), static_cast<float>(outBeta)};
if(directionMode == 0) { // forward
helpers::lstmLayerTimeLoopBp(x, Wx, Wr, b, seqLen, hI, cI, Wp, dLdh, dLdhL, dLdcL, params, true, dLdx, dLdWx, dLdWr, dLdb, dLdhI, dLdcI, dLdWp);
}
else if(directionMode == 1) { // backward
helpers::lstmLayerTimeLoopBp(x, Wx, Wr, b, seqLen, hI, cI, Wp, dLdh, dLdhL, dLdcL, params, false, dLdx, dLdWx, dLdWr, dLdb, dLdhI, dLdcI, dLdWp);
}
else { // bidirectional
NDArray WxFwd = (*Wx)({0,1, 0,0, 0,0});
NDArray WxBwd = (*Wx)({1,2, 0,0, 0,0});
NDArray dLdWxFwd = (*dLdWx)({0,1, 0,0, 0,0});
NDArray dLdWxBwd = (*dLdWx)({1,2, 0,0, 0,0});
NDArray WrFwd = (*Wr)({0,1, 0,0, 0,0});
NDArray WrBwd = (*Wr)({1,2, 0,0, 0,0});
NDArray dLdWrFwd = (*dLdWr)({0,1, 0,0, 0,0});
NDArray dLdWrBwd = (*dLdWr)({1,2, 0,0, 0,0});
NDArray *WpFwd(nullptr), *WpBwd(nullptr), *bFwd(nullptr), *bBwd(nullptr), *hIFwd(nullptr), *hIBwd(nullptr), *cIFwd(nullptr), *cIBwd(nullptr),
*dLdhFwd(nullptr), *dLdhBwd(nullptr), *dLdhLFwd(nullptr), *dLdhLBwd(nullptr), *dLdcLFwd(nullptr), *dLdcLBwd(nullptr),
*dLdWpFwd(nullptr), *dLdWpBwd(nullptr), *dLdbFwd(nullptr), *dLdbBwd(nullptr),
*dLdhIFwd(nullptr), *dLdhIBwd(nullptr), *dLdcIFwd(nullptr), *dLdcIBwd(nullptr);
if(Wp) {
WpFwd = new NDArray((*Wp)({0,1, 0,0}));
WpBwd = new NDArray((*Wp)({1,2, 0,0}));
dLdWpFwd = new NDArray((*dLdWp)({0,1, 0,0}));
dLdWpBwd = new NDArray((*dLdWp)({1,2, 0,0}));
}
if(b) {
bFwd = new NDArray((*b)({0,1, 0,0}));
bBwd = new NDArray((*b)({1,2, 0,0}));
dLdbFwd = new NDArray((*dLdb)({0,1, 0,0}));
dLdbBwd = new NDArray((*dLdb)({1,2, 0,0}));
}
if(hI) {
hIFwd = new NDArray((*hI)({0,1, 0,0, 0,0}));
hIBwd = new NDArray((*hI)({1,2, 0,0, 0,0}));
dLdhIFwd = new NDArray((*dLdhI)({0,1, 0,0, 0,0}));
dLdhIBwd = new NDArray((*dLdhI)({1,2, 0,0, 0,0}));
}
if(cI) {
cIFwd = new NDArray((*cI)({0,1, 0,0, 0,0}));
cIBwd = new NDArray((*cI)({1,2, 0,0, 0,0}));
dLdcIFwd = new NDArray((*dLdcI)({0,1, 0,0, 0,0}));
dLdcIBwd = new NDArray((*dLdcI)({1,2, 0,0, 0,0}));
}
if(dLdhL) {
dLdhLFwd = new NDArray((*dLdhL)({0,1, 0,0, 0,0}));
dLdhLBwd = new NDArray((*dLdhL)({1,2, 0,0, 0,0}));
}
if(dLdcL) {
dLdcLFwd = new NDArray((*dLdcL)({0,1, 0,0, 0,0}));
dLdcLBwd = new NDArray((*dLdcL)({1,2, 0,0, 0,0}));
}
// FIXME looks like sum (directionMode == 2) is impossible for backprop
if(dLdh) {
if(directionMode == 2) { // sum
REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: mode for bidirectional sum and dLdh being present has no sense for backpropagation !");
// dLdhFwd = dLdh;
// dLdhBwd = new NDArray(dLdh->ordering(), dLdh->getShapeAsVector(), dLdh->dataType(), dLdh->getContext()); // automatically nullifies content
}
else if(directionMode == 3) { // concat
dLdhFwd = new NDArray(dataFormat <= 1 ? (*dLdh)({0,0, 0,0, 0,nOut}) : (*dLdh)({0,0, 0,nOut, 0,0}));
dLdhBwd = new NDArray(dataFormat <= 1 ? (*dLdh)({0,0, 0,0, nOut,2*nOut}) : (*dLdh)({0,0, nOut,2*nOut, 0,0}));
}
else { // directionMode == 4
dLdhFwd = new NDArray((*dLdh)({0,0, 0,1, 0,0, 0,0}));
dLdhBwd = new NDArray((*dLdh)({0,0, 1,2, 0,0, 0,0}));
}
}
helpers::lstmLayerTimeLoopBp(x, &WxFwd, &WrFwd, bFwd, seqLen, hIFwd, cIFwd, WpFwd, dLdhFwd, dLdhLFwd, dLdcLFwd, params, true, dLdx, &dLdWxFwd, &dLdWrFwd, dLdbFwd, dLdhIFwd, dLdcIFwd, dLdWpFwd);
NDArray dLdxBwd = dLdx->ulike();
helpers::lstmLayerTimeLoopBp(x, &WxBwd, &WrBwd, bBwd, seqLen, hIBwd, cIBwd, WpBwd, dLdhBwd, dLdhLBwd, dLdcLBwd, params, false, &dLdxBwd, &dLdWxBwd, &dLdWrBwd, dLdbBwd, dLdhIBwd, dLdcIBwd, dLdWpBwd);
*dLdx += dLdxBwd;
delete WpFwd; delete WpBwd; delete bFwd; delete bBwd; delete hIFwd; delete hIBwd; delete cIFwd; delete cIBwd;
delete dLdhBwd; delete dLdhLFwd; delete dLdhLBwd; delete dLdcLFwd; delete dLdcLBwd;
delete dLdWpFwd; delete dLdWpBwd; delete dLdbFwd; delete dLdbBwd;
delete dLdhIFwd; delete dLdhIBwd; delete dLdcIFwd; delete dLdcIBwd;
if(dLdhFwd != dLdh)
delete dLdhFwd;
}
return Status::OK();
}
DECLARE_TYPES(lstmLayer_bp) {
    // Type constraints of the op: inputs are accepted with any data type,
    // outputs are restricted to the ALL_FLOATS set.
    auto descriptor = getOpDescriptor();
    descriptor->setAllowedInputTypes(sd::DataType::ANY);
    descriptor->setAllowedOutputTypes({ALL_FLOATS});
}
DECLARE_SHAPE_FN(lstmLayer_bp) {

    // Every output takes the shape of the input at the same position:
    // x, Wx, Wr first, followed by whichever of b, seqLen, hI, cI, Wp are flagged as provided.

    // presence flags of the optional inputs, in the order those inputs are passed to the op
    const bool optionalProvided[] = {
        B_ARG(0),   // biases array
        B_ARG(1),   // seqLen array
        B_ARG(2),   // initial output
        B_ARG(3),   // initial cell state
        B_ARG(4)    // peephole weights
    };

    std::vector<Nd4jLong*> gradShapes;

    // mandatory inputs: x (0), Wx (1), Wr (2)
    auto x  = INPUT_VARIABLE(0);
    auto Wx = INPUT_VARIABLE(1);
    auto Wr = INPUT_VARIABLE(2);

    // optional inputs occupy consecutive positions starting from 3, absent ones are simply skipped
    std::vector<NDArray*> optionalInputs;
    int nextInput = 3;
    for (const bool provided : optionalProvided) {
        if (!provided)
            continue;
        auto arr = INPUT_VARIABLE(nextInput++);
        if (arr != nullptr)
            optionalInputs.push_back(arr);
    }

    gradShapes.push_back(x->getShapeInfo());
    gradShapes.push_back(Wx->getShapeInfo());
    gradShapes.push_back(Wr->getShapeInfo());

    for (auto arr : optionalInputs)
        gradShapes.push_back(arr->getShapeInfo());

    return new ShapeList(gradShapes);
}
} }
} }

View File

@ -0,0 +1,339 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author Yurii Shyrma (iuriish@yahoo.com)
//
#include <system/op_boilerplate.h>
#if NOT_EXCLUDED(OP_lstmLayerCell)
#include <ops/declarable/CustomOperations.h>
#include<ops/declarable/helpers/lstmLayer.h>
namespace sd {
namespace ops {
//////////////////////////////////////////////////////////////////////////
CUSTOM_OP_IMPL(lstmLayerCell, 5, 2, false, 1, 3) {

    // Forward pass of LSTM cell for a single time step: validates arguments and
    // delegates the actual computation to helpers::lstmLayerCell.

    // equations (no peephole connections)
    // it  = σ(Wxi * xt  +  Wri * ht-1  +  bi)
    // ft  = σ(Wxf * xt  +  Wrf * ht-1  +  bf)
    // c't = tanh(Wxc * xt  +  Wrc * ht-1  +  bc)
    // ct  = ft ◦ ct-1 + it ◦ c't
    // ot  = σ(Wxo * xt  +  Wro * ht-1  +  bo)
    // ht  = ot ◦ tanh(ct)

    // equations (peephole connections are present)
    // it  = σ(Wxi * xt  +  Wri * ht-1  +  Wpi ◦ ct-1  +  bi)
    // ft  = σ(Wxf * xt  +  Wrf * ht-1  +  Wpf ◦ ct-1  +  bf)
    // c't = tanh(Wxc * xt  +  Wrc * ht-1  +  bc)
    // ct  = clip(ft ◦ ct-1 + it ◦ c't)
    // ot  = σ(Wxo * xt  +  Wro * ht-1  +  Wpo ◦ ct  +  bo)
    // ht  = ot ◦ tanh(ct)

    // notations:
    // bS   - batch size
    // nIn  - input size
    // nOut - output size (hidden size)

    // INPUTS (listed in the order they are read below):
    // input x:                          [bS, nIn] or [nIn]
    // input weights Wx:                 [nIn, 4*nOut]
    // recurrent weights Wr:             [nOut, 4*nOut]
    // biases b (optional):              [4*nOut]
    // initial (previous) output hI:     [bS, nOut] or [nOut]
    // initial (previous) cell state cI: [bS, nOut] or [nOut]
    // peephole weights Wp (optional):   [3*nOut]

    // OUTPUTS:
    // current output h:     [bS, nOut] or [nOut]
    // current cell state c: [bS, nOut] or [nOut]

    // !!! dimension 4*nOut implies order it, ft, c't, ot
    // !!! dimension 3*nOut implies order it, ft, ot

    // integer numbers corresponding to activations: 0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky relu, 5= thresholded relu, 6=scaled tanh, 7=hard sigmoid, 8=ELU, 9=softsign, 10=softplus
    const auto gateAct = INT_ARG(0);    // activation for input (i), forget (f) and output (o) gates
    const auto cellAct = INT_ARG(1);    // activation for cell state (c)
    const auto outAct  = INT_ARG(2);    // activation for output (h)

    const auto hasBiases = B_ARG(0);    // indicates whether biases array is provided
    const auto hasPH     = B_ARG(1);    // indicates whether peephole connections are present

    // which activations take an alpha and/or beta parameter
    const auto gateActHasAlpha = gateAct == 3 || gateAct == 4 || gateAct == 5 || gateAct == 6 || gateAct == 8;
    const auto cellActHasAlpha = cellAct == 3 || cellAct == 4 || cellAct == 5 || cellAct == 6 || cellAct == 8;
    const auto outActHasAlpha  = outAct  == 3 || outAct  == 4 || outAct  == 5 || outAct  == 6 || outAct  == 8;
    const auto gateActHasBeta  = gateAct == 3 || gateAct == 6;
    const auto cellActHasBeta  = cellAct == 3 || cellAct == 6;
    const auto outActHasBeta   = outAct  == 3 || outAct  == 6;

    // T_ARG(0) is the clipping value; alpha/beta values follow it packed one after another,
    // i.e. a slot is consumed only for those activations which actually require the parameter
    uint count = 1;
    const auto cellClip  = T_ARG(0);                                    // cell clipping value, if it = 0 then do not apply clipping
    const auto gateAlpha = gateActHasAlpha ? T_ARG(count++) : 0;
    const auto gateBeta  = gateActHasBeta  ? T_ARG(count++) : 0;
    const auto cellAlpha = cellActHasAlpha ? T_ARG(count++) : 0;
    const auto cellBeta  = cellActHasBeta  ? T_ARG(count++) : 0;
    const auto outAlpha  = outActHasAlpha  ? T_ARG(count++) : 0;
    const auto outBeta   = outActHasBeta   ? T_ARG(count++) : 0;

    // positions of hI, cI and Wp shift by one when biases are provided
    count = 3;
    const auto x  = INPUT_VARIABLE(0);                                  // input
    const auto Wx = INPUT_VARIABLE(1);                                  // input weights
    const auto Wr = INPUT_VARIABLE(2);                                  // recurrent weights
    const auto b  = hasBiases ? INPUT_VARIABLE(count++) : nullptr;      // biases
    const auto hI = INPUT_VARIABLE(count++);                            // initial output
    const auto cI = INPUT_VARIABLE(count++);                            // initial cell state
    const auto Wp = hasPH ? INPUT_VARIABLE(count) : nullptr;            // peephole weights

    REQUIRE_TRUE(cellClip >= 0 , 0, "LSTM_LAYER_CELL operation: cell clipping value should be nonnegative (>=0) !");

    auto h = OUTPUT_VARIABLE(0);
    auto c = OUTPUT_VARIABLE(1);

    // evaluate dimensions
    const Nd4jLong bS   = x->rankOf() == 1 ? 0 : x->sizeAt(0);          // 0 is used as "no batch dimension" marker for rank-1 x
    const Nd4jLong nIn  = x->sizeAt(-1);
    const Nd4jLong nOut = Wx->sizeAt(-1) / 4;                           // integer division, exactness is verified in Wx validation below

    // inputs validations
    // Wx validation: the last dimension is checked as well, since nOut is obtained via integer division
    // and a last dimension which is not a multiple of 4 would otherwise go unnoticed
    if(Wx->rankOf() != 2 || Wx->sizeAt(0) != nIn || Wx->sizeAt(1) != 4*nOut)
        REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL operation: wrong shape of input weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nIn, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wx).c_str());
    // Wr validation
    if(Wr->rankOf() != 2 || Wr->sizeAt(0) != nOut || Wr->sizeAt(1) != 4*nOut)
        REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL operation: wrong shape of recurrent weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nOut, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wr).c_str());
    // initial output/cell validation
    std::vector<Nd4jLong> exphIcIShape = x->rankOf() == 1 ? std::vector<Nd4jLong>{nOut} : std::vector<Nd4jLong>{bS, nOut};
    REQUIRE_TRUE(hI->isSameShape(exphIcIShape), 0, "LSTM_LAYER_CELL operation: wrong shape of initial output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(exphIcIShape).c_str(), ShapeUtils::shapeAsString(hI).c_str());
    REQUIRE_TRUE(cI->isSameShape(exphIcIShape), 0, "LSTM_LAYER_CELL operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(exphIcIShape).c_str(), ShapeUtils::shapeAsString(cI).c_str());
    // biases validation
    if(b != nullptr && (b->rankOf() != 1 || b->sizeAt(0) != 4*nOut))
        REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({4*nOut}).c_str(), ShapeUtils::shapeAsString(b).c_str());
    // peephole weights validation
    if(Wp != nullptr && (Wp->rankOf() != 1 || Wp->sizeAt(0) != 3*nOut))
        REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL operation: wrong shape of peephole weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({3*nOut}).c_str(), ShapeUtils::shapeAsString(Wp).c_str());

    // parameters are handed to the helper as flat float vector; the first two slots are not used by the cell op
    std::vector<float> params = {static_cast<float>(0)/*ignore*/, static_cast<float>(0)/*ignore*/, static_cast<float>(cellClip),
                                 static_cast<float>(gateAct), static_cast<float>(gateAlpha), static_cast<float>(gateBeta),
                                 static_cast<float>(cellAct), static_cast<float>(cellAlpha), static_cast<float>(cellBeta),
                                 static_cast<float>(outAct), static_cast<float>(outAlpha), static_cast<float>(outBeta)};

    helpers::lstmLayerCell(x, Wx, Wr, b, hI, cI, Wp, params, h, c);

    return Status::OK();
}
DECLARE_TYPES(lstmLayerCell) {
    // Type constraints of the op: inputs are accepted with any data type,
    // outputs are restricted to the ALL_FLOATS set.
    auto descriptor = getOpDescriptor();
    descriptor->setAllowedInputTypes(sd::DataType::ANY);
    descriptor->setAllowedOutputTypes({ALL_FLOATS});
}
DECLARE_SHAPE_FN(lstmLayerCell) {

    // Outputs h and c take the shapes of hI and cI respectively.
    // hI and cI are passed right after x, Wx, Wr and the optional biases array,
    // hence hI sits at position 4 when biases are provided and at position 3 otherwise.
    const uint hIPos = B_ARG(0) ? 4 : 3;

    auto hI = INPUT_VARIABLE(hIPos);        // initial output
    auto cI = INPUT_VARIABLE(hIPos + 1);    // initial cell state

    return new ShapeList({hI->getShapeInfo(), cI->getShapeInfo()});
}
//////////////////////////////////////////////////////////////////////////
CUSTOM_OP_IMPL(lstmLayerCellBp, 7, 5, false, 1, 3) {

    // Backward pass of LSTM cell for a single time step: validates arguments, re-runs the forward
    // cell helper to obtain intermediate arrays and then delegates to helpers::lstmLayerCellBp.

    // equations (no peephole connections)
    // it  = σ(Wxi * xt  +  Wri * ht-1  +  bi)
    // ft  = σ(Wxf * xt  +  Wrf * ht-1  +  bf)
    // c't = tanh(Wxc * xt  +  Wrc * ht-1  +  bc)
    // ct  = ft ◦ ct-1 + it ◦ c't
    // ot  = σ(Wxo * xt  +  Wro * ht-1  +  bo)
    // ht  = ot ◦ tanh(ct)

    // equations (peephole connections are present)
    // it  = σ(Wxi * xt  +  Wri * ht-1  +  Wpi ◦ ct-1  +  bi)
    // ft  = σ(Wxf * xt  +  Wrf * ht-1  +  Wpf ◦ ct-1  +  bf)
    // c't = tanh(Wxc * xt  +  Wrc * ht-1  +  bc)
    // ct  = clip(ft ◦ ct-1 + it ◦ c't)
    // ot  = σ(Wxo * xt  +  Wro * ht-1  +  Wpo ◦ ct  +  bo)
    // ht  = ot ◦ tanh(ct)

    // notations:
    // bS   - batch size
    // nIn  - input size
    // nOut - output size (hidden size)

    // INPUTS (listed in the order they are read below):
    // input x:                          [bS, nIn] or [nIn]
    // input weights Wx:                 [nIn, 4*nOut]
    // recurrent weights Wr:             [nOut, 4*nOut]
    // biases b (optional):              [4*nOut]
    // initial (previous) output hI:     [bS, nOut] or [nOut]
    // initial (previous) cell state cI: [bS, nOut] or [nOut]
    // peephole weights Wp (optional):   [3*nOut]
    // gradient wrt output dLdh:         [bS, nOut] or [nOut]
    // NOTE: a gradient wrt cell state (dLdc) input is not read by this op; nullptr is passed to the
    // helper right after dLdh (presumably that is the dLdc slot - confirm against helper signature)

    // OUTPUTS (listed in the order they are read below):
    // gradient wrt x  dLdx:             [bS, nIn] or [nIn]
    // gradient wrt Wx dLdWx:            [nIn, 4*nOut]
    // gradient wrt Wr dLdWr:            [nOut, 4*nOut]
    // gradient wrt b  dLdb (optional):  [4*nOut]
    // gradient wrt hI dLdhI:            [bS, nOut] or [nOut]
    // gradient wrt cI dLdcI:            [bS, nOut] or [nOut]
    // gradient wrt Wp dLdWp (optional): [3*nOut]

    // !!! dimension 4*nOut implies order it, ft, c't, ot
    // !!! dimension 3*nOut implies order it, ft, ot

    // integer numbers corresponding to activations: 0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky relu, 5= thresholded relu, 6=scaled tanh, 7=hard sigmoid, 8=ELU, 9=softsign, 10=softplus
    const auto gateAct = INT_ARG(0);    // activation for input (i), forget (f) and output (o) gates
    const auto cellAct = INT_ARG(1);    // activation for cell state (c)
    const auto outAct  = INT_ARG(2);    // activation for output (h)

    const auto hasBiases = B_ARG(0);    // indicates whether biases array is provided
    const auto hasPH     = B_ARG(1);    // indicates whether peephole connections are present

    // which activations take an alpha and/or beta parameter
    const auto gateActHasAlpha = gateAct == 3 || gateAct == 4 || gateAct == 5 || gateAct == 6 || gateAct == 8;
    const auto cellActHasAlpha = cellAct == 3 || cellAct == 4 || cellAct == 5 || cellAct == 6 || cellAct == 8;
    const auto outActHasAlpha  = outAct  == 3 || outAct  == 4 || outAct  == 5 || outAct  == 6 || outAct  == 8;
    const auto gateActHasBeta  = gateAct == 3 || gateAct == 6;
    const auto cellActHasBeta  = cellAct == 3 || cellAct == 6;
    const auto outActHasBeta   = outAct  == 3 || outAct  == 6;

    // T_ARG(0) is the clipping value; alpha/beta values follow it packed one after another,
    // i.e. a slot is consumed only for those activations which actually require the parameter
    uint count = 1;
    const auto cellClip  = T_ARG(0);                                    // cell clipping value, if it = 0 then do not apply clipping
    const auto gateAlpha = gateActHasAlpha ? T_ARG(count++) : 0;
    const auto gateBeta  = gateActHasBeta  ? T_ARG(count++) : 0;
    const auto cellAlpha = cellActHasAlpha ? T_ARG(count++) : 0;
    const auto cellBeta  = cellActHasBeta  ? T_ARG(count++) : 0;
    const auto outAlpha  = outActHasAlpha  ? T_ARG(count++) : 0;
    const auto outBeta   = outActHasBeta   ? T_ARG(count++) : 0;

    // positions of inputs following Wr shift depending on presence of optional arrays
    count = 3;
    const auto x    = INPUT_VARIABLE(0);                                // input
    const auto Wx   = INPUT_VARIABLE(1);                                // input weights
    const auto Wr   = INPUT_VARIABLE(2);                                // recurrent weights
    const auto b    = hasBiases ? INPUT_VARIABLE(count++) : nullptr;    // biases
    const auto hI   = INPUT_VARIABLE(count++);                          // initial output
    const auto cI   = INPUT_VARIABLE(count++);                          // initial cell state
    const auto Wp   = hasPH ? INPUT_VARIABLE(count++) : nullptr;        // peephole weights
    const auto dLdh = INPUT_VARIABLE(count);                            // gradient wrt output

    REQUIRE_TRUE(cellClip >= 0 , 0, "LSTM_LAYER_CELL_BP operation: cell clipping value should be nonnegative (>=0) !");

    // outputs follow the same positional scheme as inputs
    count = 3;
    auto dLdx  = OUTPUT_VARIABLE(0);
    auto dLdWx = OUTPUT_VARIABLE(1);
    auto dLdWr = OUTPUT_VARIABLE(2);
    auto dLdb  = hasBiases ? OUTPUT_VARIABLE(count++) : nullptr;
    auto dLdhI = OUTPUT_VARIABLE(count++);
    auto dLdcI = OUTPUT_VARIABLE(count++);
    auto dLdWp = hasPH ? OUTPUT_VARIABLE(count) : nullptr;

    // evaluate dimensions
    const Nd4jLong bS   = x->rankOf() == 1 ? 0 : x->sizeAt(0);          // 0 is used as "no batch dimension" marker for rank-1 x
    const Nd4jLong nIn  = x->sizeAt(-1);
    const Nd4jLong nOut = Wx->sizeAt(-1) / 4;                           // integer division, exactness is verified in Wx validation below

    // inputs validations
    // Wx validation: the last dimension is checked as well, since nOut is obtained via integer division
    // and a last dimension which is not a multiple of 4 would otherwise go unnoticed
    if(Wx->rankOf() != 2 || Wx->sizeAt(0) != nIn || Wx->sizeAt(1) != 4*nOut)
        REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL_BP operation: wrong shape of input weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nIn, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wx).c_str());
    // Wr validation
    if(Wr->rankOf() != 2 || Wr->sizeAt(0) != nOut || Wr->sizeAt(1) != 4*nOut)
        REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL_BP operation: wrong shape of recurrent weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nOut, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wr).c_str());
    // initial output/cell validation
    std::vector<Nd4jLong> exphIcIShape = x->rankOf() == 1 ? std::vector<Nd4jLong>{nOut} : std::vector<Nd4jLong>{bS, nOut};
    REQUIRE_TRUE(hI->isSameShape(exphIcIShape), 0, "LSTM_LAYER_CELL_BP operation: wrong shape of initial output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(exphIcIShape).c_str(), ShapeUtils::shapeAsString(hI).c_str());
    REQUIRE_TRUE(cI->isSameShape(exphIcIShape), 0, "LSTM_LAYER_CELL_BP operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(exphIcIShape).c_str(), ShapeUtils::shapeAsString(cI).c_str());
    REQUIRE_TRUE(dLdh->isSameShape(exphIcIShape), 0, "LSTM_LAYER_CELL_BP operation: wrong shape of dLdh gradient, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(exphIcIShape).c_str(), ShapeUtils::shapeAsString(dLdh).c_str());
    // biases validation
    if(b != nullptr && (b->rankOf() != 1 || b->sizeAt(0) != 4*nOut))
        REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL_BP operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({4*nOut}).c_str(), ShapeUtils::shapeAsString(b).c_str());
    if(dLdb != nullptr && (dLdb->rankOf() != 1 || dLdb->sizeAt(0) != 4*nOut))
        REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL_BP operation: wrong shape of dLdb gradient, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({4*nOut}).c_str(), ShapeUtils::shapeAsString(dLdb).c_str());
    // peephole weights validation
    if(Wp != nullptr && (Wp->rankOf() != 1 || Wp->sizeAt(0) != 3*nOut))
        REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL_BP operation: wrong shape of peephole weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({3*nOut}).c_str(), ShapeUtils::shapeAsString(Wp).c_str());
    if(dLdWp != nullptr && (dLdWp->rankOf() != 1 || dLdWp->sizeAt(0) != 3*nOut))
        REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL_BP operation: wrong shape of dLdWp gradient, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({3*nOut}).c_str(), ShapeUtils::shapeAsString(dLdWp).c_str());

    // parameters are handed to the helpers as flat float vector; the first two slots are not used by the cell op
    std::vector<float> params = {static_cast<float>(0)/*ignore*/, static_cast<float>(0)/*ignore*/, static_cast<float>(cellClip),
                                 static_cast<float>(gateAct), static_cast<float>(gateAlpha), static_cast<float>(gateBeta),
                                 static_cast<float>(cellAct), static_cast<float>(cellAlpha), static_cast<float>(cellBeta),
                                 static_cast<float>(outAct), static_cast<float>(outAlpha), static_cast<float>(outBeta)};

    // temporaries filled by the forward helper and consumed by the backward helper:
    // z and a have shape [bS, 4*nOut] (or [4*nOut]), h and c have the shape of cI
    // (presumably z holds gate pre-activations and a the activated values - confirm against helper)
    std::vector<Nd4jLong> zShape = x->rankOf() == 1 ? std::vector<Nd4jLong>({4*nOut}) : std::vector<Nd4jLong>({bS, 4*nOut});

    NDArray z(x->ordering(), zShape, x->dataType(), block.launchContext());
    NDArray a = z.ulike();
    NDArray h = cI->ulike();
    NDArray c = cI->ulike();

    // forward step first, its intermediate results are required for gradient evaluation
    helpers::lstmLayerCell(x,Wx, Wr, b, hI, cI, Wp, params, &z, &a, &h, &c);

    helpers::lstmLayerCellBp(x, Wx, Wr, b, hI, cI, Wp, dLdh, nullptr, &z, &a, &c, params, dLdx, dLdWx, dLdWr, dLdhI, dLdcI, dLdb, dLdWp);

    return Status::OK();
}
DECLARE_TYPES(lstmLayerCellBp) {
    // Type constraints of the op: inputs are accepted with any data type,
    // outputs are restricted to the ALL_FLOATS set.
    auto descriptor = getOpDescriptor();
    descriptor->setAllowedInputTypes(sd::DataType::ANY);
    descriptor->setAllowedOutputTypes({ALL_FLOATS});
}
DECLARE_SHAPE_FN(lstmLayerCellBp) {

    // Every output takes the shape of the input at the same position:
    // x, Wx, Wr, [b], hI, cI, [Wp] (bracketed ones only when flagged as provided).
    const bool withBiases   = B_ARG(0);     // biases array is provided
    const bool withPeephole = B_ARG(1);     // peephole connections are present

    // gather the inputs first, in the order they are passed to the op
    std::vector<NDArray*> sources = {INPUT_VARIABLE(0), INPUT_VARIABLE(1), INPUT_VARIABLE(2)};

    uint nextInput = 3;

    if (withBiases) {
        auto b = INPUT_VARIABLE(nextInput++);
        if (b != nullptr)
            sources.push_back(b);
    }

    sources.push_back(INPUT_VARIABLE(nextInput++));     // initial output
    sources.push_back(INPUT_VARIABLE(nextInput++));     // initial cell state

    if (withPeephole) {
        auto Wp = INPUT_VARIABLE(nextInput);
        if (Wp != nullptr)
            sources.push_back(Wp);
    }

    // then translate them into the list of output shapes
    std::vector<Nd4jLong*> shapes;
    for (auto arr : sources)
        shapes.push_back(arr->getShapeInfo());

    return new ShapeList(shapes);
}
}
}
#endif

View File

@ -16,6 +16,7 @@
// //
// xw_plus_b op. Created by GS <george@skymind.io> 31.01.2018 // xw_plus_b op. Created by GS <george@skymind.io> 31.01.2018
// @author Oleg Semeniv <oleg.semeniv@gmail.com>
// //
// //
@ -29,25 +30,47 @@
namespace sd { namespace sd {
namespace ops { namespace ops {
CUSTOM_OP_IMPL(xw_plus_b, 3, 1, false, 0, 0) { CUSTOM_OP_IMPL(xw_plus_b, 3, 1, false, 0, 0) {
auto x = INPUT_VARIABLE(0); auto x = INPUT_VARIABLE(0);
auto y = INPUT_VARIABLE(1);
auto b = INPUT_VARIABLE(2); auto b = INPUT_VARIABLE(2);
auto z = OUTPUT_VARIABLE(0); auto z = OUTPUT_VARIABLE(0);
REQUIRE_TRUE(x->rankOf() <= 2 && y->rankOf() <= 2 && z->rankOf() <= 2, 0, "xw_plus_b: Input and Output NDArrays should have rank less or equal to 2"); if (x->isEmpty() || INPUT_VARIABLE(1)->isEmpty() || b->isEmpty())
REQUIRE_TRUE(b->isVector() && b->lengthOf() == z->sizeAt(-1), 0, "xw_plus_b: Input vector should have proper dimension 1x%i. " return Status::OK();
"But %i != %i.", z->sizeAt(-1), b->lengthOf(), z->sizeAt(-1));
const bool bTranspose = (block.getIArguments()->size() > 0 ? INT_ARG(0) == 1 : false);
auto w = bTranspose ? new NDArray(INPUT_VARIABLE(1)->transpose()) : INPUT_VARIABLE(1);
REQUIRE_TRUE(x->rankOf() == 2, 0, "xw_plus_b: Input x array should have rank equal 2, but got instead %i!", x->rankOf());
REQUIRE_TRUE(w->rankOf() == 2, 0, "xw_plus_b: Input weights array should have rank equal 2, but got instead %i!", w->rankOf());
REQUIRE_TRUE(z->rankOf() == 2, 0, "xw_plus_b: Output array should have rank equal 2, but got instead %i!", z->rankOf());
REQUIRE_TRUE(1 == b->rankOf() && b->lengthOf() == z->sizeAt(-1), 0, "xw_plus_b: Input bias vector should be 1D and have proper dimension 1x%i."
" But got rank %i, and got length %i instead %i.", z->sizeAt(-1), b->rankOf(), b->lengthOf(), z->sizeAt(-1));
// multiply x to y // multiply x to y
MmulHelper::mmul(x, y, z, 1.0, 0.0); MmulHelper::mmul(x, w, z, 1.0, 0.0);
// adding b vector // adding b vector
z->addiRowVector(*b); z->addiRowVector(*b);
if (bTranspose)
delete w;
return Status::OK(); return Status::OK();
} }
DECLARE_SHAPE_FN(xw_plus_b) { DECLARE_SHAPE_FN(xw_plus_b) {
auto outputShape = ShapeUtils::matrixProductShape(inputShape->at(0), inputShape->at(1), false, false,
auto weights = INPUT_VARIABLE(1);
const int nWeightsFormat = block.getIArguments()->size() > 0 ? INT_ARG(0) : 0;
auto weightsShape = (1 == nWeightsFormat) ? ShapeUtils::evalTranspShapeInfo(*weights, block.getWorkspace()) : inputShape->at(1);
auto outputShape = ShapeUtils::matrixProductShape(inputShape->at(0), weightsShape, false, false,
ArrayOptions::dataType(inputShape->at(0)), block.getWorkspace()); ArrayOptions::dataType(inputShape->at(0)), block.getWorkspace());
return SHAPELIST(CONSTANT(outputShape)); return SHAPELIST(CONSTANT(outputShape));
@ -58,6 +81,63 @@ namespace sd {
->setAllowedInputTypes(sd::DataType::ANY) ->setAllowedInputTypes(sd::DataType::ANY)
->setAllowedOutputTypes({ ALL_FLOATS }); ->setAllowedOutputTypes({ ALL_FLOATS });
} }
CUSTOM_OP_IMPL(xw_plus_b_bp, 4, 3, false, 0, 0) {
    // Backprop step of xw_plus_b.
    // Inputs:   0 - x, 1 - weights, 2 - bias b, 3 - dLdz (gradient arriving from the next op).
    // Outputs:  0 - dLdx, 1 - dLdw, 2 - dLdb.
    // Int args: 0 - optional; when it equals 1 the weights and their gradient are used transposed.
    // Returns Status::OK() for empty inputs, otherwise the status reported by matmul_bp.
    auto x = INPUT_VARIABLE(0);
    auto wInput = INPUT_VARIABLE(1);
    auto b = INPUT_VARIABLE(2);
    auto dLdz = INPUT_VARIABLE(3);

    auto dLdx = OUTPUT_VARIABLE(0);
    auto dLdb = OUTPUT_VARIABLE(2);

    // nothing to compute when any of the inputs is empty
    if (x->isEmpty() || wInput->isEmpty() || b->isEmpty() || dLdz->isEmpty())
        return Status::OK();

    const bool bTranspose = (block.getIArguments()->size() > 0 ? INT_ARG(0) == 1 : false);

    // All validation runs BEFORE the transposed temporaries are allocated, so a failed check
    // cannot leak them (NOTE(review): REQUIRE_TRUE is assumed to leave the op on failure - confirm
    // against the macro definition). The weights rank is checked on the raw input because
    // transposing does not change the rank.
    REQUIRE_TRUE(x->rankOf() == 2, 0, "xw_plus_b BP: Input x array should have rank equal 2, but got instead %i!", x->rankOf());
    REQUIRE_TRUE(wInput->rankOf() == 2, 0, "xw_plus_b BP: Input weights array should have rank equal 2, but got instead %i!", wInput->rankOf());
    REQUIRE_TRUE(dLdz->rankOf() == 2, 0, "xw_plus_b BP: Output array should have rank equal 2, but got instead %i!", dLdz->rankOf());
    REQUIRE_TRUE(1 == b->rankOf() && b->lengthOf() == dLdz->sizeAt(-1), 0, "xw_plus_b BP: Input bias vector should be 1D and have proper dimension 1x%i."
                 " But got rank %i, and got length %i instead %i.", dLdz->sizeAt(-1), b->rankOf(), b->lengthOf(), dLdz->sizeAt(-1));

    // in the transposed format both the weights and their gradient are wrapped in owned temporaries
    auto w = bTranspose ? new NDArray(wInput->transpose()) : wInput;
    auto dLdw = bTranspose ? new NDArray(OUTPUT_VARIABLE(1)->transpose()) : OUTPUT_VARIABLE(1);

    // dLdb: sum of dLdz along dimension 0
    dLdb->assign(dLdz->reduceAlongDimension(reduce::Sum, { 0 }));

    // dLdx and dLdw are delegated to the matmul backprop op; keep its status instead of dropping it
    matmul_bp mmul_bp;
    const auto status = mmul_bp.execute({ x, w, dLdz }, std::vector<NDArray*>{dLdx, dLdw}, {}, {}, {});

    // release the temporaries created for the transposed format
    if (bTranspose) {
        delete w;
        delete dLdw;
    }

    return status;
}
DECLARE_SHAPE_FN(xw_plus_b_bp) {
    // The three output shapes are plain copies of the shape infos of inputs 0, 1 and 2
    // (x, weights and bias respectively in the op implementation).
    Nd4jLong* dLdxShapeInfo;
    COPY_SHAPE(inputShape->at(0), dLdxShapeInfo);

    Nd4jLong* dLdwShapeInfo;
    COPY_SHAPE(inputShape->at(1), dLdwShapeInfo);

    Nd4jLong* dLdbShapeInfo;
    COPY_SHAPE(inputShape->at(2), dLdbShapeInfo);

    return SHAPELIST(CONSTANT(dLdxShapeInfo), CONSTANT(dLdwShapeInfo), CONSTANT(dLdbShapeInfo));
}
DECLARE_TYPES(xw_plus_b_bp) {
    // inputs may be of any data type; outputs are restricted to the ALL_FLOATS set
    auto descriptor = getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY);
    descriptor->setAllowedOutputTypes({ ALL_FLOATS });
}
} }
} }

View File

@ -29,7 +29,7 @@ namespace sd {
auto x = INPUT_VARIABLE(0); auto x = INPUT_VARIABLE(0);
auto y = INPUT_VARIABLE(1); auto y = INPUT_VARIABLE(1);
auto z = OUTPUT_VARIABLE(0); auto z = OUTPUT_VARIABLE(0);
auto z0 = NDArrayFactory::create<bool>(x->ordering(), x->getShapeAsVector()); auto z0 = NDArrayFactory::create<bool>(x->ordering(), x->getShapeAsVector(), block.launchContext());
BROADCAST_CHECK_EMPTY(x, y, (&z0)); BROADCAST_CHECK_EMPTY(x, y, (&z0));
auto tZ = BroadcastHelper::broadcastApply(BROADCAST_BOOL(GreaterThan), x, y, &z0); auto tZ = BroadcastHelper::broadcastApply(BROADCAST_BOOL(GreaterThan), x, y, &z0);

View File

@ -45,8 +45,8 @@ namespace sd {
weights = INPUT_VARIABLE(2); weights = INPUT_VARIABLE(2);
REQUIRE_TRUE(weights->isSameShape(predictions),0, "CONFUSION_MATRIX: Weights and predictions should have equal shape"); REQUIRE_TRUE(weights->isSameShape(predictions),0, "CONFUSION_MATRIX: Weights and predictions should have equal shape");
} }
auto output = OUTPUT_VARIABLE(0); auto output = OUTPUT_NULLIFIED(0);
output->assign(0.);
int minPrediction = predictions->reduceNumber(reduce::Min).e<int>(0); int minPrediction = predictions->reduceNumber(reduce::Min).e<int>(0);
int minLabel = labels->reduceNumber(reduce::Min).e<int>(0); int minLabel = labels->reduceNumber(reduce::Min).e<int>(0);
@ -64,11 +64,7 @@ namespace sd {
DECLARE_SHAPE_FN(confusion_matrix) { DECLARE_SHAPE_FN(confusion_matrix) {
auto labels = INPUT_VARIABLE(0); auto labels = INPUT_VARIABLE(0);
auto predictions = INPUT_VARIABLE(1); auto predictions = INPUT_VARIABLE(1);
auto dtype = block.dataType(); auto dtype = block.numD() ? D_ARG(0) : sd::DataType::INT64;
dtype = sd::DataType::INT64; // dtype - should be a param with int argument
if (block.numI() > 1)
dtype = (sd::DataType)INT_ARG(1);
int numClasses = 0; int numClasses = 0;
if (block.getIArguments()->size() > 0) { if (block.getIArguments()->size() > 0) {

View File

@ -34,7 +34,7 @@ namespace sd {
auto resVariances = OUTPUT_VARIABLE(1); auto resVariances = OUTPUT_VARIABLE(1);
// FIXME: double? // FIXME: double?
NDArray shift = NDArrayFactory::create<double>(0.); NDArray shift = NDArrayFactory::create<double>(0., block.launchContext());
if (block.getTArguments()->size() > 0) { if (block.getTArguments()->size() > 0) {
shift.assign(T_ARG(0)); shift.assign(T_ARG(0));

View File

@ -49,8 +49,8 @@ namespace sd {
bool disposable = false; bool disposable = false;
if (min == nullptr && max == nullptr && block.numT() >= 2) { if (min == nullptr && max == nullptr && block.numT() >= 2) {
min = NDArrayFactory::create_(dtype); min = NDArrayFactory::create_(dtype, block.launchContext());
max = NDArrayFactory::create_(dtype); max = NDArrayFactory::create_(dtype, block.launchContext());
min->p(0, T_ARG(0)); min->p(0, T_ARG(0));
max->p(0, T_ARG(1)); max->p(0, T_ARG(1));
disposable = true; disposable = true;

View File

@ -40,106 +40,14 @@ CUSTOM_OP_IMPL(reshape, 1, 1, false, 0, -2) {
return Status::OK(); //No op return Status::OK(); //No op
} }
if (block.width() == 1) { REQUIRE_TRUE(x->lengthOf() == z->lengthOf(), 0, "Reshape: lengths before and after reshape should match, but got %i vs %i", x->lengthOf(), z->lengthOf());
auto arguments = block.getIArguments(); if (Environment::getInstance()->isDebugAndVerbose())
int argsSize = arguments->size(); nd4j_printv("Reshape: new shape", z->getShapeAsVector());
z->assign(x->reshape(z->ordering(), z->getShapeAsVector()));
int e = 1;
char order = (char) -(*arguments)[0];
if (order != 'c' && order != 'f') {
order = 'c'; //x->ordering();
e = 0;
}
REQUIRE_TRUE(argsSize - e >= 1, 0, "Reshape arguments should have at least 1 dimension");
std::vector<Nd4jLong> shapeNew;
int e2 = e;
for (; e < (int) arguments->size(); e++) {
if (arguments->at(e) == -1){
Nd4jLong shapeLength = 1;
for(; e2 < e; e2++){
shapeLength *= arguments->at(e2);
}
for(e2 = e + 1; e2 < arguments->size(); e2++){
shapeLength *= arguments->at(e2);
}
Nd4jLong realShape = x->lengthOf() / shapeLength;
shapeNew.push_back(realShape);
}
else{
shapeNew.push_back(arguments->at(e));
}
}
auto len = shape::prodLong(shapeNew.data(), shapeNew.size());
REQUIRE_TRUE(len == x->lengthOf(), 0, "Reshape: lengths before and after reshape should match, but got %i vs %i", x->lengthOf(), len);
if (Environment::getInstance()->isDebugAndVerbose()) {
nd4j_printv("Reshape: new shape", shapeNew);
}
auto xr = x->reshape(order, shapeNew);
z->assign(xr);
STORE_RESULT(*z);
return Status::OK(); return Status::OK();
} else if (block.width() == 2) {
auto s = INPUT_VARIABLE(1);
char order = 'c';
if (block.numI() > 0)
order = (char) -INT_ARG(0);
std::vector<Nd4jLong> shapeNew(s->lengthOf());
for (int e = 0; e < (int) s->lengthOf(); e++) {
auto dim = s->e<Nd4jLong >(e);
if (dim == -1){
Nd4jLong shapeLength = 1;
for(int e2 = 0; e2 < e; e2++){
shapeLength *= s->e<Nd4jLong>(e2);
}
for(int e2 = e + 1; e2 < (int) s->lengthOf(); e2++){
REQUIRE_TRUE(s->e<Nd4jLong>(e2) != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed.");
shapeLength *= s->e<Nd4jLong>(e2);
}
Nd4jLong realShape = x->lengthOf() / shapeLength;
shapeNew[e] = realShape;
}
else{
shapeNew[e] = dim;
}
}
if (Environment::getInstance()->isDebugAndVerbose()) {
nd4j_printv("Reshape: new shape", shapeNew);
}
if (s->isScalar()) {
// just a scalar
z->assign(x);
} else {
// in some cases we might go away with simple memcpy call instead of assign call
if (x->ordering() == 'c' && z->ordering() == x->ordering() && shape::reshapeC(x->shapeInfo(), z->shapeInfo())) {
z->dataBuffer()->copyBufferFrom(*x->dataBuffer().get(), z->lengthOf() * DataTypeUtils::sizeOfElement(z->dataType()), 0, x->bufferOffset());
} else {
auto xr = x->reshape(order, shapeNew);
z->assign(xr);
}
}
return Status::OK();
}
return ND4J_STATUS_BAD_INPUT;
} }
@ -151,117 +59,111 @@ DECLARE_TYPES(reshape) {
} }
DECLARE_SHAPE_FN(reshape) { DECLARE_SHAPE_FN(reshape) {
auto inp = inputShape->at(0);
// we can launch op using Int arguments const auto x = INPUT_VARIABLE(0);
if (inputShape->size() == 1) {
REQUIRE_TRUE(block.numI() > 0, 0, "Reshape: new shape should be provided as NDArray or int arguments, but nothing was defined");
std::vector<int> *arguments = block.getIArguments();
int e = 1;
char order = (char) -(*arguments)[0];
if (order != 'c' && order != 'f') {
order = shape::order(inp);
e = 0;
}
std::vector<int> reshapeArgs;
std::vector<Nd4jLong> shapeNew; std::vector<Nd4jLong> shapeNew;
char orderNew = 'c';
int e2 = e; if (block.width() == 1) {
for (; e < (int) arguments->size(); e++) { reshapeArgs = *block.getIArguments();
if ((int) arguments->at(e) == -1){ if(!reshapeArgs.empty()) {
orderNew = (char) -reshapeArgs[0];
Nd4jLong shapeLength = 1; if(orderNew == 'c' || orderNew == 'f')
for(; e2 < e; e2 ++){ reshapeArgs.erase(reshapeArgs.begin()); // remove first element being order in this case
shapeLength *= arguments->at(e2);
}
for(e2 = e + 1; e2 < arguments->size(); e2++){
REQUIRE_TRUE(arguments->at(e2) != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed.");
shapeLength *= arguments->at(e2);
}
if(shapeLength == 0){
//Edge case for empty:
shapeNew.push_back(0);
} else {
//Standard case
Nd4jLong realShape = shape::length(inp) / shapeLength;
shapeNew.push_back(realShape);
} }
} }
else { else {
shapeNew.push_back(arguments->at(e)); reshapeArgs = INPUT_VARIABLE(1)->getBufferAsVector<int>();
} orderNew = block.numI() > 0 ? (char) -INT_ARG(0) : 'c';
} }
return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(ArrayOptions::dataType(inp), order, shapeNew))); REQUIRE_TRUE(!reshapeArgs.empty() || x->lengthOf() == 1, 0, "Reshape buffer should have at least 1 dimension !");
} else {
// or, with second input "as shape"
auto x = INPUT_VARIABLE(0);
auto y = INPUT_VARIABLE(1);
// special case here // Nd4jLong xLen = x->lengthOf();
if (y->isEmpty()) { // if(x->isEmpty()) {
REQUIRE_TRUE(x->lengthOf() == 1, 0, "Reshape: new length doesn't match existing array"); // xLen = 1;
return SHAPELIST(ConstantShapeHelper::getInstance()->scalarShapeInfo(ArrayOptions::dataType(inp))); // for (uint i = 0; i < x->rankOf(); ++i) // take into account possible empty shapes
} // if(x->sizeAt(i) != 0)
//Special case: empty.reshape(-1) -> return empty // xLen *= x->sizeAt(i);
if (x->isEmpty()) { // }
//REQUIRE_TRUE(y->lengthOf() == 1 && y->e<Nd4jLong>(0) == -1, 0, "Reshape: when input is empty, shape must be [-1]");
auto shapeOf = y->getBufferAsVector<Nd4jLong>();
Nd4jLong prod = 1;
bool hasNegs = false;
for (auto v:shapeOf) {
if (v < 0) {
hasNegs = true;
v = 0;
}
prod *= v; // for (uint i = 0; i < reshapeArgs.size(); ++i) {
}
REQUIRE_TRUE(prod == 0, 0, "Reshape: in case of empty arrays reshape must return empty array as well"); // if (reshapeArgs[i] == -1) {
// if there are -1s - we turn them into zeros // uint shapeLength = 1, numOfZeros = 0;
if (hasNegs) {
for (int e = 0; e < shapeOf.size(); e++)
if (shapeOf[e] < 0)
shapeOf[e] = 0;
}
auto newShape = ShapeBuilders::createShapeInfo(ArrayOptions::dataType(inp), shape::order(inp), y->lengthOf(), shapeOf.data()); // for(uint j = 0; j < i; ++j)
return SHAPELIST(CONSTANT(newShape)); // if(reshapeArgs[j] != 0)
} // shapeLength *= reshapeArgs[j];
// else
// ++numOfZeros;
std::vector<Nd4jLong> shapeNew(y->lengthOf()); // for(uint j = i + 1; j < reshapeArgs.size(); ++j) {
// REQUIRE_TRUE(reshapeArgs[j] != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed.");
// if(reshapeArgs[j] != 0)
// shapeLength *= reshapeArgs[j];
// else
// ++numOfZeros;
// }
// const auto dim = xLen / shapeLength;
// if(x->isEmpty() && (1 == dim || 0 == numOfZeros))
// shapeNew.push_back(0);
// else
// shapeNew.push_back(dim);
// }
// else
// shapeNew.push_back(reshapeArgs[i]);
// }
Nd4jLong newShapeLen = 1;
int pos = -1;
bool newShapeEmpty = false;
for (int i = 0; i < reshapeArgs.size(); ++i) {
const int dim = reshapeArgs[i];
for (int e = 0; e < (int) y->lengthOf(); e++) {
auto dim = y->e<Nd4jLong>(e);
if (dim == -1) { if (dim == -1) {
Nd4jLong shapeLength = 1; REQUIRE_TRUE(pos == -1, 0, "Reshape : Only one unknown dimension (-1) is allowed.");
for(int e2 = 0; e2 < e; e2++){ pos = i;
shapeLength *= y->e<Nd4jLong>(e2); shapeNew.push_back(1);
} }
for(int e2 = e + 1; e2 < (int)y->lengthOf(); e2++){ else if (dim == 0) {
REQUIRE_TRUE(y->e<Nd4jLong>(e2) != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed."); shapeNew.push_back(0);
shapeLength *= y->e<Nd4jLong>(e2); newShapeEmpty = true;
} }
else {
if(shapeLength == 0){ shapeNew.push_back(dim);
//Edge case for empty: newShapeLen *= dim;
shapeNew[e] = 0;
} else {
Nd4jLong realShape = shape::length(inp) / shapeLength;
shapeNew[e] = realShape;
}
}else {
shapeNew[e] = dim;
} }
} }
return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inp), 'c', shapeNew)); if (pos != -1) {
Nd4jLong xLen = x->lengthOf();
if(x->isEmpty()) {
xLen = 1;
for (uint i = 0; i < x->rankOf(); ++i) // take into account possible empty shapes
if(x->sizeAt(i) > 0 || !newShapeEmpty)
xLen *= x->sizeAt(i);
} }
shapeNew[pos] = xLen / newShapeLen;
} }
auto len = shape::prodLong(shapeNew.data(), shapeNew.size());
REQUIRE_TRUE(x->lengthOf() == len, 0, "Reshape: lengths before and after reshape should match, but got %i vs %i", x->lengthOf(), len);
return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(x->dataType(), orderNew, shapeNew));
}
} }
} }

View File

@ -68,7 +68,7 @@ namespace ops {
if (block.getIArguments()->size() > 0) if (block.getIArguments()->size() > 0)
N = INT_ARG(0); N = INT_ARG(0);
auto dataType = rowP->dataType(); //ArrayOptions::dataType(inputShape->at(0)); auto dataType = rowP->dataType(); //ArrayOptions::dataType(inputShape->at(0));
NDArray* rowCounts = NDArrayFactory::create_<int>('c', {N}); //rowP->dup(); NDArray* rowCounts = NDArrayFactory::create_<int>('c', { N }, block.launchContext()); //rowP->dup();
//srowCounts->assign(0); //srowCounts->assign(0);
Nd4jLong len = helpers::barnes_row_count(rowP, colP, N, *rowCounts); Nd4jLong len = helpers::barnes_row_count(rowP, colP, N, *rowCounts);
rowCounts->syncToHost(); rowCounts->syncToHost();

View File

@ -22,6 +22,7 @@
#define LIBND4J_HEADERS_BROADCASTABLE_H #define LIBND4J_HEADERS_BROADCASTABLE_H
#include <ops/declarable/BroadcastableOp.h> #include <ops/declarable/BroadcastableOp.h>
#include <ops/declarable/BroadcastableBoolOp.h>
#include <ops/declarable/headers/common.h> #include <ops/declarable/headers/common.h>
#include <ops/declarable/generic/helpers/BroadcastHelper.h> #include <ops/declarable/generic/helpers/BroadcastHelper.h>
@ -261,7 +262,7 @@ namespace sd {
* *
*/ */
#if NOT_EXCLUDED(OP_equals) #if NOT_EXCLUDED(OP_equals)
DECLARE_BROADCASTABLE_OP(equals, 0, 0); DECLARE_BROADCASTABLE_BOOL_OP(equals, 0, 0);
#endif #endif
/** /**
@ -269,7 +270,7 @@ namespace sd {
* Math is: _x != _y ? (T) 1.0f : (T) 0.0f; * Math is: _x != _y ? (T) 1.0f : (T) 0.0f;
*/ */
#if NOT_EXCLUDED(OP_not_equals) #if NOT_EXCLUDED(OP_not_equals)
DECLARE_BROADCASTABLE_OP(not_equals, 0, 0); DECLARE_BROADCASTABLE_BOOL_OP(not_equals, 0, 0);
#endif #endif
/** /**
@ -277,7 +278,7 @@ namespace sd {
* Math is: _x <= _y ? (T) 1.0f : (T) 0.0f; * Math is: _x <= _y ? (T) 1.0f : (T) 0.0f;
*/ */
#if NOT_EXCLUDED(OP_less_equal) #if NOT_EXCLUDED(OP_less_equal)
DECLARE_BROADCASTABLE_OP(less_equal, 0, 0); DECLARE_BROADCASTABLE_BOOL_OP(less_equal, 0, 0);
#endif #endif
/** /**
@ -285,7 +286,7 @@ namespace sd {
* Math is: _x >= _y ? (T) 1.0f : (T) 0.0f; * Math is: _x >= _y ? (T) 1.0f : (T) 0.0f;
*/ */
#if NOT_EXCLUDED(OP_greater_equal) #if NOT_EXCLUDED(OP_greater_equal)
DECLARE_BROADCASTABLE_OP(greater_equal, 0, 0); DECLARE_BROADCASTABLE_BOOL_OP(greater_equal, 0, 0);
#endif #endif
/** /**
@ -293,7 +294,7 @@ namespace sd {
* Math is: _x < _y ? (T) 1.0f : (T) 0.0f; * Math is: _x < _y ? (T) 1.0f : (T) 0.0f;
*/ */
#if NOT_EXCLUDED(OP_less) #if NOT_EXCLUDED(OP_less)
DECLARE_BROADCASTABLE_OP(less, 0, 0); DECLARE_BROADCASTABLE_BOOL_OP(less, 0, 0);
#endif #endif
/** /**
@ -301,7 +302,7 @@ namespace sd {
* Math is: _x > _y ? (T) 1.0f : (T) 0.0f; * Math is: _x > _y ? (T) 1.0f : (T) 0.0f;
*/ */
#if NOT_EXCLUDED(OP_greater) #if NOT_EXCLUDED(OP_greater)
DECLARE_BROADCASTABLE_OP(greater, 0, 0); DECLARE_BROADCASTABLE_BOOL_OP(greater, 0, 0);
#endif #endif
/** /**

View File

@ -867,9 +867,12 @@ namespace sd {
* - 2D matrix MxN * - 2D matrix MxN
* - 1D vector with N elements * - 1D vector with N elements
* output value - 2D matrix NxN as multiply of matrixes and add vector * output value - 2D matrix NxN as multiply of matrixes and add vector
* Int args:
* 0 - optional switch for the weights format: if int arg == 1 - mkldnn format, otherwise mmul format
*/ */
#if NOT_EXCLUDED(OP_xw_plus_b) #if NOT_EXCLUDED(OP_xw_plus_b)
DECLARE_CUSTOM_OP(xw_plus_b, 3, 1, false, 0, 0); DECLARE_CUSTOM_OP(xw_plus_b, 3, 1, false, 0, 0);
DECLARE_CUSTOM_OP(xw_plus_b_bp, 4, 3, false, 0, 0);
#endif #endif
/** /**

View File

@ -149,6 +149,13 @@ namespace ops {
DECLARE_CUSTOM_OP(lstmCell, 8, 2, false, 3, 2); DECLARE_CUSTOM_OP(lstmCell, 8, 2, false, 3, 2);
#endif #endif
#if NOT_EXCLUDED(OP_lstmLayerCell)
DECLARE_CUSTOM_OP(lstmLayerCell, 5, 2, false, 1, 3);
#endif
#if NOT_EXCLUDED(OP_lstmLayerCell)
DECLARE_CUSTOM_OP(lstmLayerCellBp, 7, 5, false, 1, 3);
#endif
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
/** /**
@ -236,6 +243,11 @@ namespace ops {
DECLARE_CUSTOM_OP(lstmLayer, 3, 1, false, 1, 5); DECLARE_CUSTOM_OP(lstmLayer, 3, 1, false, 1, 5);
#endif #endif
//////////////////////////////////////////////////////////////////////////
#if NOT_EXCLUDED(OP_lstmLayer)
DECLARE_CUSTOM_OP(lstmLayer_bp, 4, 1, false, 1, 5);
#endif
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
/** /**

View File

@ -142,7 +142,7 @@ namespace helpers {
const int rowNum = input->rows(); const int rowNum = input->rows();
const int columnNum = input->columns(); const int columnNum = input->columns();
NDArray determinant = NDArrayFactory::create<T>(1.f); NDArray determinant = NDArrayFactory::create<T>(1.f, context);
NDArray compoundMatrix = *input; // copy NDArray compoundMatrix = *input; // copy
NDArray permutationMatrix(input, false, context); // has same shape as input and contiguous strides NDArray permutationMatrix(input, false, context); // has same shape as input and contiguous strides
permutationMatrix.setIdentity(); permutationMatrix.setIdentity();

View File

@ -39,7 +39,7 @@ namespace helpers {
template <typename T> template <typename T>
NDArray vmul(NDArray const& v, int n) NDArray vmul(NDArray const& v, int n)
{ {
NDArray res('c', {n,n}, v.dataType()); // x = matrix_new(n, n); NDArray res('c', {n,n}, v.dataType(), v.getContext()); // x = matrix_new(n, n);
T const* vBuf = v.getDataBuffer()->primaryAsT<T>(); T const* vBuf = v.getDataBuffer()->primaryAsT<T>();
T* resBuf = res.dataBuffer()->primaryAsT<T>(); T* resBuf = res.dataBuffer()->primaryAsT<T>();
auto interloop = PRAGMA_THREADS_FOR_2D { auto interloop = PRAGMA_THREADS_FOR_2D {
@ -61,7 +61,7 @@ namespace helpers {
std::vector<NDArray> q(M); std::vector<NDArray> q(M);
NDArray z = *matrix; NDArray z = *matrix;
NDArray e('c', {M}, DataTypeUtils::fromT<T>()); // two internal buffers and scalar for squared norm NDArray e('c', {M}, DataTypeUtils::fromT<T>(), Q->getContext()); // two internal buffers and scalar for squared norm
for (Nd4jLong k = 0; k < N && k < M - 1; k++) { // loop for columns, but not further then row number for (Nd4jLong k = 0; k < N && k < M - 1; k++) { // loop for columns, but not further then row number
e.nullify(); e.nullify();

View File

@ -69,9 +69,9 @@ namespace helpers {
auto trial = (*input)(e, dimsToExclude); auto trial = (*input)(e, dimsToExclude);
// fill up the first k elements // fill up the first k elements
NDArray topValues = NDArrayFactory::create<T>('c', {k}); NDArray topValues = NDArrayFactory::create<T>('c', {k}, input->getContext());
NDArray sortedVals = NDArrayFactory::create<T>('c', {k}); NDArray sortedVals = NDArrayFactory::create<T>('c', {k}, input->getContext());
NDArray topIndices = NDArrayFactory::create<Nd4jLong>('c', {k}); NDArray topIndices = NDArrayFactory::create<Nd4jLong>('c', {k}, input->getContext());
for (uint pos = 0; pos < k; ++pos) { for (uint pos = 0; pos < k; ++pos) {
topIndices.t<Nd4jLong>(pos) = pos; topIndices.t<Nd4jLong>(pos) = pos;
topValues.t<T>(pos) = trial.t<T>(pos); topValues.t<T>(pos) = trial.t<T>(pos);
@ -144,7 +144,7 @@ namespace helpers {
for (int i = 0; i < input->rankOf() - 1; i++) for (int i = 0; i < input->rankOf() - 1; i++)
shapeI[i] = input->sizeAt(i); shapeI[i] = input->sizeAt(i);
shapeI[input->rankOf() - 1] = k; shapeI[input->rankOf() - 1] = k;
std::unique_ptr<NDArray> indices(NDArrayFactory::create_<Nd4jLong>(input->ordering(), shapeI)); std::unique_ptr<NDArray> indices(NDArrayFactory::create_<Nd4jLong>(input->ordering(), shapeI, context));
NDArray* values = nullptr; NDArray* values = nullptr;
int status = topKFunctor(context, input, values, indices.get(), k, true); int status = topKFunctor(context, input, values, indices.get(), k, true);
result->assign(0); result->assign(0);

View File

@ -112,7 +112,7 @@ namespace sd {
int numThreads = 256; int numThreads = 256;
int numBlocks = sd::math::nd4j_max<int>(256, sd::math::nd4j_min<int>(1, shape::length(xShapeInfo) / numThreads)); int numBlocks = sd::math::nd4j_max<int>(256, sd::math::nd4j_min<int>(1, shape::length(xShapeInfo) / numThreads));
int workspaceSize = numBlocks * numBins; int workspaceSize = numBlocks * numBins;
auto tmp = NDArrayFactory::create<Z>('c', {workspaceSize}); auto tmp = NDArrayFactory::create<Z>('c', {workspaceSize}, context);
histogramKernel<X, Z><<<numBlocks, numThreads, 32768, *context->getCudaStream()>>>(xBuffer, dxShapeInfo, zBuffer, zShapeInfo, tmp.getSpecialBuffer(), context->getReductionPointer(), numBins, reinterpret_cast<X*>(min_val), reinterpret_cast<X*>(max_val)); histogramKernel<X, Z><<<numBlocks, numThreads, 32768, *context->getCudaStream()>>>(xBuffer, dxShapeInfo, zBuffer, zShapeInfo, tmp.getSpecialBuffer(), context->getReductionPointer(), numBins, reinterpret_cast<X*>(min_val), reinterpret_cast<X*>(max_val));

View File

@ -25,7 +25,7 @@ namespace ops {
namespace helpers { namespace helpers {
typedef NDArray ColorTable_t; typedef NDArray ColorTable_t;
static NDArray DefaultColorTable(int depth) { static NDArray DefaultColorTable(int depth, sd::LaunchContext* context) {
//std::vector<std::vector<float>> colorTable; //std::vector<std::vector<float>> colorTable;
const Nd4jLong kDefaultTableLength = 10; const Nd4jLong kDefaultTableLength = 10;
const Nd4jLong kDefaultChannelLength = 4; const Nd4jLong kDefaultChannelLength = 4;
@ -40,7 +40,7 @@ namespace helpers {
0, 0, 0.5, 1, // 7: navy blue 0, 0, 0.5, 1, // 7: navy blue
0, 1, 1, 1, // 8: aqua 0, 1, 1, 1, // 8: aqua
1, 0, 1, 1 // 9: fuchsia 1, 0, 1, 1 // 9: fuchsia
}, DataType::FLOAT32); }, DataType::FLOAT32, context);
if (depth == 1) { if (depth == 1) {
colorTable.assign(1.f); // all to white when black and white colors colorTable.assign(1.f); // all to white when black and white colors
@ -144,7 +144,7 @@ namespace helpers {
auto channels = images->sizeAt(3); auto channels = images->sizeAt(3);
auto stream = context->getCudaStream(); auto stream = context->getCudaStream();
auto boxSize = boxes->sizeAt(1); auto boxSize = boxes->sizeAt(1);
NDArray colorsTable = DefaultColorTable(channels); NDArray colorsTable = DefaultColorTable(channels, context);
if ((colors != nullptr && colors->lengthOf() > 0)) { if ((colors != nullptr && colors->lengthOf() > 0)) {
colorsTable = *colors; colorsTable = *colors;
} }

View File

@ -188,7 +188,7 @@ namespace helpers {
static void nonMaxSuppressionV2_(sd::LaunchContext* context, NDArray* boxes, NDArray* scales, int maxSize, double threshold, double scoreThreshold, NDArray* output) { static void nonMaxSuppressionV2_(sd::LaunchContext* context, NDArray* boxes, NDArray* scales, int maxSize, double threshold, double scoreThreshold, NDArray* output) {
auto stream = context->getCudaStream(); auto stream = context->getCudaStream();
NDArray::prepareSpecialUse({output}, {boxes, scales}); NDArray::prepareSpecialUse({output}, {boxes, scales});
std::unique_ptr<NDArray> indices(NDArrayFactory::create_<I>('c', {scales->lengthOf()})); // - 1, scales->lengthOf()); //, scales->getContext()); std::unique_ptr<NDArray> indices(NDArrayFactory::create_<I>('c', {scales->lengthOf()}, context)); // - 1, scales->lengthOf()); //, scales->getContext());
NDArray scores(*scales); NDArray scores(*scales);
Nd4jPointer extras[2] = {nullptr, stream}; Nd4jPointer extras[2] = {nullptr, stream};
@ -198,7 +198,7 @@ namespace helpers {
indices->tickWriteDevice(); indices->tickWriteDevice();
sortByValue(extras, indices->buffer(), indices->shapeInfo(), indices->specialBuffer(), indices->specialShapeInfo(), scores.buffer(), scores.shapeInfo(), scores.specialBuffer(), scores.specialShapeInfo(), true); sortByValue(extras, indices->buffer(), indices->shapeInfo(), indices->specialBuffer(), indices->specialShapeInfo(), scores.buffer(), scores.shapeInfo(), scores.specialBuffer(), scores.specialShapeInfo(), true);
indices->tickWriteDevice(); indices->tickWriteDevice();
NDArray selectedIndices = NDArrayFactory::create<I>('c', {output->lengthOf()}); NDArray selectedIndices = NDArrayFactory::create<I>('c', {output->lengthOf()}, context);
int numSelected = 0; int numSelected = 0;
int numBoxes = boxes->sizeAt(0); int numBoxes = boxes->sizeAt(0);
auto boxesBuf = reinterpret_cast<T*>(boxes->specialBuffer()); auto boxesBuf = reinterpret_cast<T*>(boxes->specialBuffer());
@ -347,8 +347,8 @@ namespace helpers {
scores->syncToDevice(); scores->syncToDevice();
} }
NDArray indices = NDArrayFactory::create<I>('c', {scores->lengthOf()}); // - 1, scales->lengthOf()); //, scales->getContext()); NDArray indices = NDArrayFactory::create<I>('c', {scores->lengthOf()}, context); // - 1, scales->lengthOf()); //, scales->getContext());
NDArray startPositions = NDArrayFactory::create<I>('c', {scores->lengthOf()}); NDArray startPositions = NDArrayFactory::create<I>('c', {scores->lengthOf()}, context);
NDArray selectedScores(*scores); NDArray selectedScores(*scores);
Nd4jPointer extras[2] = {nullptr, stream}; Nd4jPointer extras[2] = {nullptr, stream};
auto indexBuf = indices.dataBuffer()->specialAsT<I>();///reinterpret_cast<I*>(indices->specialBuffer()); auto indexBuf = indices.dataBuffer()->specialAsT<I>();///reinterpret_cast<I*>(indices->specialBuffer());

View File

@ -598,7 +598,7 @@ namespace helpers {
static void lu_(LaunchContext * context, NDArray* input, NDArray* output, NDArray* permutationVectors) { static void lu_(LaunchContext * context, NDArray* input, NDArray* output, NDArray* permutationVectors) {
auto n = input->sizeAt(-1); auto n = input->sizeAt(-1);
auto stream = context->getCudaStream(); auto stream = context->getCudaStream();
NDArray iota('c', {n}, permutationVectors->dataType());// = NDArrayFactory::create(); // <int>('c', {n}); NDArray iota('c', {n}, permutationVectors->dataType(), context);// = NDArrayFactory::create(); // <int>('c', {n});
iota.linspace(0); iota.syncToDevice(); iota.linspace(0); iota.syncToDevice();
output->assign(input); // fill up output tensor with zeros output->assign(input); // fill up output tensor with zeros
@ -631,7 +631,7 @@ namespace helpers {
// if (dtype != DataType::DOUBLE) // if (dtype != DataType::DOUBLE)
// dtype = DataType::FLOAT32; // dtype = DataType::FLOAT32;
auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, DataTypeUtils::fromT<T>(), context); //, block.getWorkspace()); auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, DataTypeUtils::fromT<T>(), context); //, block.getWorkspace());
auto det = NDArrayFactory::create<T>(1); auto det = NDArrayFactory::create<T>(1, context);
auto stream = context->getCudaStream(); auto stream = context->getCudaStream();
NDArray::prepareSpecialUse({output}, {input}); NDArray::prepareSpecialUse({output}, {input});
dim3 launchDims(256, 256, 1024); dim3 launchDims(256, 256, 1024);
@ -677,7 +677,7 @@ namespace helpers {
dtype = DataType::FLOAT32; dtype = DataType::FLOAT32;
auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, dtype, context); //, block.getWorkspace()); auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, dtype, context); //, block.getWorkspace());
auto det = NDArrayFactory::create<T>(1); auto det = NDArrayFactory::create<T>(1, context);
auto stream = context->getCudaStream(); auto stream = context->getCudaStream();
NDArray::prepareSpecialUse({output}, {input}); NDArray::prepareSpecialUse({output}, {input});
dim3 launchDims(256, 256, 1024); dim3 launchDims(256, 256, 1024);

View File

@ -110,7 +110,7 @@ namespace helpers {
auto resR = fullMatricies?R->ulike():matrix->ulike(); auto resR = fullMatricies?R->ulike():matrix->ulike();
std::vector<NDArray> q(M); std::vector<NDArray> q(M);
NDArray z = *matrix; NDArray z = *matrix;
NDArray e('c', {M}, DataTypeUtils::fromT<T>()); // two internal buffers and scalar for squared norm NDArray e('c', {M}, DataTypeUtils::fromT<T>(), context); // two internal buffers and scalar for squared norm
for (auto k = 0; k < N && k < M - 1; k++) { // loop for columns, but not further then row number for (auto k = 0; k < N && k < M - 1; k++) { // loop for columns, but not further then row number
e.nullify(); e.nullify();
z = matrixMinor<T>(context, z, k); // minor computing for current column with given matrix z (initally is a input matrix) z = matrixMinor<T>(context, z, k); // minor computing for current column with given matrix z (initally is a input matrix)
@ -177,4 +177,3 @@ namespace helpers {
} }
} }
} }

View File

@ -167,8 +167,8 @@ namespace sd {
auto stream = context->getCudaStream(); auto stream = context->getCudaStream();
indices->syncToHost(); indices->syncToHost();
Nd4jLong numOfClasses = indices->e<Nd4jLong>(indices->lengthOf() - 1) + 1; Nd4jLong numOfClasses = indices->e<Nd4jLong>(indices->lengthOf() - 1) + 1;
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses}); NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses}, context);
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numOfClasses}); NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numOfClasses}, context);
classesRangesBegs.assign(indices->lengthOf()); classesRangesBegs.assign(indices->lengthOf());
classesRangesLens.assign(0); classesRangesLens.assign(0);
@ -209,8 +209,8 @@ namespace sd {
// NDArray classes = NDArrayFactory::create<int>('c', {numOfClasses, 2}); // NDArray classes = NDArrayFactory::create<int>('c', {numOfClasses, 2});
output->assign(DataTypeUtils::infOrMax<T>()); output->assign(DataTypeUtils::infOrMax<T>());
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numOfClasses}); NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numOfClasses}, context);
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses}); NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses}, context);
// NDArray row = NDArrayFactory::create<int>('c', {1, 2}, {(int)indices->lengthOf(), (int)0}); // NDArray row = NDArrayFactory::create<int>('c', {1, 2}, {(int)indices->lengthOf(), (int)0});
// classes.applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), row, classes); // classes.applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), row, classes);
classesRangesBegs.assign(indices->lengthOf()); classesRangesBegs.assign(indices->lengthOf());

View File

@ -158,8 +158,8 @@ namespace helpers {
static void segmentMeanFunctor_(LaunchContext* context, NDArray* input, NDArray* indices, NDArray* output) { static void segmentMeanFunctor_(LaunchContext* context, NDArray* input, NDArray* indices, NDArray* output) {
auto stream = context->getCudaStream(); auto stream = context->getCudaStream();
Nd4jLong numClasses = indices->e<Nd4jLong>(indices->lengthOf() - 1) + 1; Nd4jLong numClasses = indices->e<Nd4jLong>(indices->lengthOf() - 1) + 1;
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numClasses}); NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numClasses}, context);
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numClasses}); NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numClasses}, context);
classesRangesBegs.assign(indices->lengthOf()); classesRangesBegs.assign(indices->lengthOf());
classesRangesLens.assign(0); classesRangesLens.assign(0);
@ -198,8 +198,8 @@ namespace helpers {
auto stream = context->getCudaStream(); auto stream = context->getCudaStream();
// NDArray classes = NDArrayFactory::create<int>('c', {numOfClasses, 2}); // NDArray classes = NDArrayFactory::create<int>('c', {numOfClasses, 2});
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numOfClasses}); NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numOfClasses}, context);
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses}); NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses}, context);
// NDArray row = NDArrayFactory::create<int>('c', {1, 2}, {(int)indices->lengthOf(), (int)0}); // NDArray row = NDArrayFactory::create<int>('c', {1, 2}, {(int)indices->lengthOf(), (int)0});
// classes.applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), &row, &classes); // classes.applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), &row, &classes);
classesRangesBegs.assign(indices->lengthOf()); classesRangesBegs.assign(indices->lengthOf());
@ -314,8 +314,8 @@ namespace helpers {
auto stream = context->getCudaStream(); auto stream = context->getCudaStream();
NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); NDArray::prepareSpecialUse({output}, {input, indices, gradOut});
auto numClasses = indices->e<int>(indices->lengthOf() - 1) + 1; auto numClasses = indices->e<int>(indices->lengthOf() - 1) + 1;
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numClasses}); NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numClasses}, context);
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numClasses}); NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numClasses}, context);
classesRangesBegs.assign(indices->lengthOf()); classesRangesBegs.assign(indices->lengthOf());
classesRangesLens.assign(0); classesRangesLens.assign(0);
@ -367,8 +367,8 @@ namespace helpers {
auto stream = context->getCudaStream(); auto stream = context->getCudaStream();
NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); NDArray::prepareSpecialUse({output}, {input, indices, gradOut});
auto numClasses = indices->e<int>(indices->lengthOf() - 1) + 1; auto numClasses = indices->e<int>(indices->lengthOf() - 1) + 1;
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numClasses}); NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numClasses}, context);
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numClasses}); NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numClasses}, context);
classesRangesBegs.assign(indices->lengthOf()); classesRangesBegs.assign(indices->lengthOf());
classesRangesLens.assign(0); classesRangesLens.assign(0);

View File

@ -161,8 +161,8 @@ namespace helpers {
static void segmentMinFunctor_(LaunchContext* context, NDArray* input, NDArray* indices, NDArray* output) { static void segmentMinFunctor_(LaunchContext* context, NDArray* input, NDArray* indices, NDArray* output) {
auto stream = context->getCudaStream(); auto stream = context->getCudaStream();
Nd4jLong numClasses = indices->e<Nd4jLong>(indices->lengthOf() - 1) + 1; Nd4jLong numClasses = indices->e<Nd4jLong>(indices->lengthOf() - 1) + 1;
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numClasses}); NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numClasses}, context);
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numClasses}); NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numClasses}, context);
output->assign(DataTypeUtils::infOrMax<T>()); output->assign(DataTypeUtils::infOrMax<T>());
classesRangesBegs.assign(indices->lengthOf()); classesRangesBegs.assign(indices->lengthOf());
classesRangesLens.assign(0); classesRangesLens.assign(0);
@ -202,8 +202,8 @@ namespace helpers {
static void unsortedSegmentMinFunctor_(sd::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { static void unsortedSegmentMinFunctor_(sd::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
auto stream = context->getCudaStream(); auto stream = context->getCudaStream();
// NDArray classes = NDArrayFactory::create<int>('c', {numOfClasses, 2}); // NDArray classes = NDArrayFactory::create<int>('c', {numOfClasses, 2});
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numOfClasses}); NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numOfClasses}, context);
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses}); NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses}, context);
// NDArray row = NDArrayFactory::create<int>('c', {1, 2}, {(int)indices->lengthOf(), (int)0}); // NDArray row = NDArrayFactory::create<int>('c', {1, 2}, {(int)indices->lengthOf(), (int)0});
// classes.applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), &row, &classes); // classes.applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), &row, &classes);
output->assign(DataTypeUtils::infOrMax<T>()); output->assign(DataTypeUtils::infOrMax<T>());

View File

@ -122,8 +122,8 @@ namespace helpers {
static void segmentProdFunctor_(sd::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* output) { static void segmentProdFunctor_(sd::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* output) {
auto stream = context->getCudaStream(); auto stream = context->getCudaStream();
Nd4jLong numClasses = indices->e<Nd4jLong>(indices->lengthOf() - 1) + 1; Nd4jLong numClasses = indices->e<Nd4jLong>(indices->lengthOf() - 1) + 1;
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numClasses}); NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numClasses}, context);
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numClasses}); NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numClasses}, context);
output->assign(1); output->assign(1);
classesRangesBegs.assign(indices->lengthOf()); classesRangesBegs.assign(indices->lengthOf());
classesRangesLens.assign(0); classesRangesLens.assign(0);
@ -160,8 +160,8 @@ namespace helpers {
static void unsortedSegmentProdFunctor_(sd::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { static void unsortedSegmentProdFunctor_(sd::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
auto stream = context->getCudaStream(); auto stream = context->getCudaStream();
// NDArray classes = NDArrayFactory::create<int>('c', {numOfClasses, 2}); // NDArray classes = NDArrayFactory::create<int>('c', {numOfClasses, 2});
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numOfClasses}); NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numOfClasses}, context);
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses}); NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses}, context);
// NDArray row = NDArrayFactory::create<int>('c', {1, 2}, {(int)indices->lengthOf(), (int)0}); // NDArray row = NDArrayFactory::create<int>('c', {1, 2}, {(int)indices->lengthOf(), (int)0});
// classes.applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), &row, &classes); // classes.applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), &row, &classes);
classesRangesBegs.assign(indices->lengthOf()); classesRangesBegs.assign(indices->lengthOf());

View File

@ -86,8 +86,8 @@ namespace helpers {
static void unsortedSegmentSqrtNFunctor_(sd::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { static void unsortedSegmentSqrtNFunctor_(sd::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
auto stream = context->getCudaStream(); auto stream = context->getCudaStream();
// NDArray classes = NDArrayFactory::create<int>('c', {numOfClasses, 2}); // NDArray classes = NDArrayFactory::create<int>('c', {numOfClasses, 2});
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numOfClasses}); NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numOfClasses}, context);
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses}); NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses}, context);
// NDArray row = NDArrayFactory::create<int>('c', {1, 2}, {(int)indices->lengthOf(), (int)0}); // NDArray row = NDArrayFactory::create<int>('c', {1, 2}, {(int)indices->lengthOf(), (int)0});
// classes.applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), &row, &classes); // classes.applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), &row, &classes);
classesRangesBegs.assign(indices->lengthOf()); classesRangesBegs.assign(indices->lengthOf());
@ -207,8 +207,8 @@ namespace helpers {
auto stream = context->getCudaStream(); auto stream = context->getCudaStream();
NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); NDArray::prepareSpecialUse({output}, {input, indices, gradOut});
auto numClasses = indices->e<int>(indices->lengthOf() - 1) + 1; auto numClasses = indices->e<int>(indices->lengthOf() - 1) + 1;
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numClasses}); NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numClasses}, context);
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numClasses}); NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numClasses}, context);
classesRangesBegs.assign(indices->lengthOf()); classesRangesBegs.assign(indices->lengthOf());
classesRangesLens.assign(0); classesRangesLens.assign(0);

View File

@ -162,8 +162,8 @@ namespace helpers {
static void segmentSumFunctor_(sd::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* output) { static void segmentSumFunctor_(sd::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* output) {
auto stream = context->getCudaStream(); auto stream = context->getCudaStream();
Nd4jLong numClasses = indices->e<Nd4jLong>(indices->lengthOf() - 1) + 1; Nd4jLong numClasses = indices->e<Nd4jLong>(indices->lengthOf() - 1) + 1;
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numClasses}); NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numClasses}, context);
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numClasses}); NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numClasses}, context);
classesRangesBegs.assign(indices->lengthOf()); classesRangesBegs.assign(indices->lengthOf());
classesRangesLens.assign(0); classesRangesLens.assign(0);
@ -201,8 +201,8 @@ namespace helpers {
static void unsortedSegmentSumFunctor_(sd::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { static void unsortedSegmentSumFunctor_(sd::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
auto stream = context->getCudaStream(); auto stream = context->getCudaStream();
// NDArray classes = NDArrayFactory::create<int>('c', {numOfClasses, 2}); // NDArray classes = NDArrayFactory::create<int>('c', {numOfClasses, 2});
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numOfClasses}); NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numOfClasses}, context);
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses}); NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses}, context);
// NDArray row = NDArrayFactory::create<int>('c', {1, 2}, {(int)indices->lengthOf(), (int)0}); // NDArray row = NDArrayFactory::create<int>('c', {1, 2}, {(int)indices->lengthOf(), (int)0});
// classes.applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), &row, &classes); // classes.applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), &row, &classes);
classesRangesBegs.assign(indices->lengthOf()); classesRangesBegs.assign(indices->lengthOf());

File diff suppressed because it is too large Load Diff

View File

@ -41,7 +41,7 @@ namespace helpers {
length += array->lengthOf(); length += array->lengthOf();
pos++; pos++;
} }
NDArray arrayFull('c', {length}, sd::DataType::INT32); NDArray arrayFull('c', {length}, sd::DataType::INT32, inputList[0]->getContext());
cContext.setOutputArray(0, &arrayFull); cContext.setOutputArray(0, &arrayFull);
cContext.setIArguments(&axis, 1); cContext.setIArguments(&axis, 1);

View File

@ -22,7 +22,6 @@
#define LIBND4J_LSTMLAYER_H #define LIBND4J_LSTMLAYER_H
#include <ops/declarable/helpers/helpers.h> #include <ops/declarable/helpers/helpers.h>
#include <ops/declarable/helpers/activations.h>
namespace sd { namespace sd {
namespace ops { namespace ops {
@ -34,6 +33,20 @@ void ND4J_EXPORT lstmLayerCell(const NDArray* x, const NDArray* Wx, const NDArra
const std::vector<float>& params, const std::vector<float>& params,
NDArray* h, NDArray* c); NDArray* h, NDArray* c);
//////////////////////////////////////////////////////////////////////////
// this auxiliary feed-forward (ff) overload should be run before backprop
void ND4J_EXPORT lstmLayerCell(const NDArray* x, const NDArray* Wx, const NDArray* Wr,
const NDArray* b, const NDArray* hI, const NDArray* cI, const NDArray* Wp,
const std::vector<float>& params,
NDArray* z, NDArray* a, NDArray* h, NDArray* c);
//////////////////////////////////////////////////////////////////////////
void ND4J_EXPORT lstmLayerCellBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, const NDArray* b, const NDArray* hI, const NDArray* cI, const NDArray* Wp,
const NDArray* dLdh, const NDArray* dLdc,
const NDArray* z, const NDArray* a, const NDArray* c, const std::vector<float>& params,
NDArray* dLdx, NDArray* dLdWx, NDArray* dLdWr, NDArray* dLdhI, NDArray* dLdcI, NDArray* dLdb, NDArray* dLdWp);
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
void ND4J_EXPORT lstmLayerTimeLoop(const NDArray* x, const NDArray* Wx, const NDArray* Wr, void ND4J_EXPORT lstmLayerTimeLoop(const NDArray* x, const NDArray* Wx, const NDArray* Wr,
const NDArray* b, const NDArray* seqLen, const NDArray* hI, const NDArray* cI, const NDArray* Wp, const NDArray* b, const NDArray* seqLen, const NDArray* hI, const NDArray* cI, const NDArray* Wp,
@ -42,71 +55,11 @@ void ND4J_EXPORT lstmLayerTimeLoop(const NDArray* x, const NDArray* Wx, const ND
NDArray* h, NDArray* hL, NDArray* cL); NDArray* h, NDArray* hL, NDArray* cL);
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
static FORCEINLINE void applyActivation(NDArray& x, const int opId, const float alpha, const float beta, NDArray& z) { void ND4J_EXPORT lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr,
const NDArray* b, const NDArray* seqLen, NDArray* hI, NDArray* cI, const NDArray* Wp,
switch (opId) { const NDArray* dLdh, const NDArray* dLdhL, const NDArray* dLdcL,
case 0: const std::vector<float>& params, const bool forward,
(const_cast<NDArray&>(x)).applyTransform(transform::Tanh, z); NDArray* dLdx, NDArray* dLdWx, NDArray* dLdWr, NDArray* dLdb, NDArray* dLdhI, NDArray* dLdcI, NDArray* dLdWp);
break;
case 1:
(const_cast<NDArray&>(x)).applyScalar<float>(scalar::RELU, 0, z);
break;
case 2:
(const_cast<NDArray&>(x)).applyTransform(transform::Sigmoid, z);
break;
case 3: {
ExtraArguments args({ static_cast<double>(alpha), static_cast<double>(beta)});
(const_cast<NDArray&>(x)).applyTransform(transform::Affine, z, &args);
break;
}
case 4:
(const_cast<NDArray&>(x)).applyScalar<float>(scalar::LeakyRELU, alpha, z);
break;
case 5:
helpers::thresholdRelu(x.getContext(), x, alpha, z);
break;
case 6: {
ExtraArguments args({ static_cast<double>(alpha), static_cast<double>(beta)});
(const_cast<NDArray&>(x)).applyTransform(transform::ScaledTanh, z, &args);
break;
}
case 7:
(const_cast<NDArray&>(x)).applyTransform(transform::HardSigmoid, z);
break;
case 8:
(const_cast<NDArray&>(x)).applyScalar<float>(scalar::ELU, alpha, z);
break;
case 9:
(const_cast<NDArray&>(x)).applyTransform(transform::SoftSign, z);
break;
case 10:
(const_cast<NDArray&>(x)).applyTransform(transform::SoftPlus, z);
break;
default:
throw std::invalid_argument("LSTM_LAYER operation: wrong id number of activation !");
}
}
//////////////////////////////////////////////////////////////////////////
static FORCEINLINE NDArray tensorAlongTimeBatchDims(const NDArray& arr, const int dataFormat, const int t1, const int t2, const int b1, const int b2) {
if(dataFormat == 0 || dataFormat == 3)
return arr({t1,t2, b1,b2, 0,0}); // TNS: [sL, bS, nIn]
if(dataFormat == 1)
return arr({b1,b2, t1,t2, 0,0}); // NTS: [bS, sL ,nIn]
return arr({b1,b2, 0,0, t1,t2}); // NST: [bS, nIn, sL]
}
//////////////////////////////////////////////////////////////////////////
static FORCEINLINE int getBatchTimeTotalIndex(const int dataFormat, const int sL, const int bS, const int t, const int b) {
if(dataFormat == 0 || dataFormat == 3)
return t * bS + b; // TNS: shape [sL, bS, nIn]
return b * sL + t; // NTS, NST: shape [bS, sL, nIn], [bS, nIn, sL]
}
} }

View File

@ -0,0 +1,72 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// Created by raver on 6/6/2018.
//
#include <system/op_boilerplate.h>
#include <system/pointercast.h>
#include <ops/declarable/BroadcastableBoolOp.h>
#include <helpers/ShapeUtils.h>
namespace sd {
namespace ops {
// Constructs a broadcastable boolean op by forwarding to the DeclarableCustomOp base constructor.
// NOTE(review): the literal arguments (2, 1, false) presumably mean "two inputs, one output,
// not in-place" - confirm against the DeclarableCustomOp constructor declaration.
BroadcastableBoolOp::BroadcastableBoolOp(const char *name, int numTArgs, int numIArgs)
    : DeclarableCustomOp(2, 1, name, false, numTArgs, numIArgs) {
    // no additional state to initialise
}
// Computes the single output shape for a two-input boolean broadcastable op.
// The output descriptor always carries the BOOL data type; its shape is taken from x, from y,
// or from the broadcast of the two, depending on the checks below.
// Returns a shape list holding exactly one shape info.
ShapeList *BroadcastableBoolOp::calculateOutputShape(ShapeList *inputShape, sd::graph::Context &block) {
    auto shapeList = SHAPELIST();
    auto x = inputShape->at(0);
    auto y = inputShape->at(1);

    // the result data type is BOOL regardless of the input shape infos
    const sd::DataType dtype = sd::DataType::BOOL;

    // shape info the output descriptor is built from; x doubles as the fallback choice
    decltype(x) source = x;

    if (shape::isEmpty(x) || shape::isEmpty(y)) {
        // this is edge case, [3, 4] + [] = []
        const bool xIsRankZeroEmpty = shape::isEmpty(x) && shape::rank(x) == 0;
        if (xIsRankZeroEmpty || (shape::isEmpty(y) && shape::rank(y) == 0)) {
            shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor::emptyDescriptor(dtype)));
            return shapeList;
        }

        // at least one operand is empty but not a rank-0 empty: use the broadcast shape
        Nd4jLong *broadcastShape = nullptr;
        ShapeUtils::evalBroadcastShapeInfo(x, y, true, broadcastShape, block.workspace());
        source = broadcastShape;
    } else if (shape::isScalar(x) && shape::isScalar(y)) {
        // both are scalars: keep the one with the higher rank, preferring x on a tie
        source = shape::rank(x) >= shape::rank(y) ? x : y;
    } else if (shape::equalsSoft(x, y)) {
        source = x;
    } else if (shape::isScalar(x) && !shape::isScalar(y)) {
        // scalar x against non-scalar y: the output follows y
        source = y;
    } else if (!shape::isScalar(x) && shape::isScalar(y)) {
        // non-scalar x against scalar y: the output follows x
        source = x;
    } else if (ShapeUtils::areShapesBroadcastable(x, y)) {
        Nd4jLong *broadcastShape = nullptr;
        ShapeUtils::evalBroadcastShapeInfo(x, y, true, broadcastShape, block.workspace());
        source = broadcastShape;
    } else {
        // in this case we'll throw exception later
        source = x;
    }

    shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(source, dtype)));
    return shapeList;
}
}
}

View File

@ -368,7 +368,7 @@ PLATFORM_IMPL(lstmLayer, ENGINE_CPU) {
REQUIRE_TRUE(hasSeqLen == false, 0, "LSTM_LAYER_MKLDNN operation: mkl dnn library doesn't support array specifying max time step per each example in batch !"); REQUIRE_TRUE(hasSeqLen == false, 0, "LSTM_LAYER_MKLDNN operation: mkl dnn library doesn't support array specifying max time step per each example in batch !");
REQUIRE_TRUE(dataFormat < 2, 0, "LSTM_LAYER_MKLDNN operation: wrong data format, only two formats are allowed for input/output tensors in mkl dnn library: TNC and NTC!"); REQUIRE_TRUE(dataFormat < 2, 0, "LSTM_LAYER_MKLDNN operation: wrong data format, only two formats are allowed for input/output tensors in mkl dnn library: TNC and NTC!");
REQUIRE_TRUE(directionMode < 4, 0, "LSTM_LAYER_MKLDNN operation: option for bidirectional extra output dimension is not valid in mkl dnn library !"); REQUIRE_TRUE(directionMode < 4, 0, "LSTM_LAYER_MKLDNN operation: option for bidirectional extra output dimension is not valid in mkl dnn library !");
REQUIRE_TRUE((retLastH && retLastC) || (!retLastH && !retLastC), 0, "LSTM_LAYER_MKLDNN operation: only two options are present: 1) calculate both output at last time and cell state at last time; 2) do not calculate both !"); REQUIRE_TRUE(retLastH == retLastC, 0, "LSTM_LAYER_MKLDNN operation: only two options are present: 1) calculate both output at last time and cell state at last time; 2) do not calculate both !");
count = 0; count = 0;
auto h = retFullSeq ? OUTPUT_VARIABLE(count++) : nullptr; // output auto h = retFullSeq ? OUTPUT_VARIABLE(count++) : nullptr; // output
@ -464,13 +464,21 @@ PLATFORM_IMPL(lstmLayer, ENGINE_CPU) {
} }
PLATFORM_CHECK(lstmLayer, ENGINE_CPU) { PLATFORM_CHECK(lstmLayer, ENGINE_CPU) {
const auto dataFormat = INT_ARG(0); // for unidirectional: 0 = [sL, bS, nIn], 1 = [bS, sL ,nIn], 2 = [bS, nIn, sL], for bidirectional: 3 = [sL, 2, bS, nOut] (for ONNX)
const auto directionMode = INT_ARG(1); // direction: 0 = fwd, 1 = bwd, 2 = bidirectional sum, 3 = bidirectional concat, 4 = bidirectional extra output dim (in conjunction with format dataFormat = 3)
const auto hasBiases = B_ARG(0); // indicates whether biases array is provided const auto hasBiases = B_ARG(0); // indicates whether biases array is provided
const auto hasSeqLen = B_ARG(1); // indicates whether seqLen array is provided
const auto hasInitH = B_ARG(2); // indicates whether initial output is provided const auto hasInitH = B_ARG(2); // indicates whether initial output is provided
const auto hasInitC = B_ARG(3); // indicates whether initial cell state is provided const auto hasInitC = B_ARG(3); // indicates whether initial cell state is provided
const auto hasPH = B_ARG(4); // indicates whether peephole connections are present
const auto retFullSeq = B_ARG(5); // indicates whether to return whole time sequence h {h_0, h_1, ... , h_sL-1} const auto retFullSeq = B_ARG(5); // indicates whether to return whole time sequence h {h_0, h_1, ... , h_sL-1}
const auto retLastH = B_ARG(6); // indicates whether to return output at last time step only, in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument) const auto retLastH = B_ARG(6); // indicates whether to return output at last time step only, in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument)
const auto retLastC = B_ARG(7); // indicates whether to return cells state at last time step only, in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument) const auto retLastC = B_ARG(7); // indicates whether to return cells state at last time step only, in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument)
const auto cellClip = T_ARG(0); // cell clipping value, if it = 0 then do not apply clipping
const auto x = INPUT_VARIABLE(0); // input const auto x = INPUT_VARIABLE(0); // input
const auto Wx = INPUT_VARIABLE(1); // input weights const auto Wx = INPUT_VARIABLE(1); // input weights
const auto Wr = INPUT_VARIABLE(2); // recurrent weights const auto Wr = INPUT_VARIABLE(2); // recurrent weights
@ -495,7 +503,15 @@ PLATFORM_CHECK(lstmLayer, ENGINE_CPU) {
DataType hLType = hL != nullptr ? hL->dataType() : xType; DataType hLType = hL != nullptr ? hL->dataType() : xType;
DataType cLType = cL != nullptr ? cL->dataType() : xType; DataType cLType = cL != nullptr ? cL->dataType() : xType;
return block.isUseMKLDNN() && ( auto featuresSupported = (cellClip == 0) //Cell clipping not supported
&& retFullSeq //Always return full sequence in case of MKL DNN
&& !hasPH //Peephole connections not supported in MKL DNN
&& !hasSeqLen //Sequence length array not supported in MKL DNN
&& dataFormat < 2 //Data format - only 0 and 1 supported in MKL DNN- 0 = [sL, bS, nIn], 1 = [bS, sL ,nIn]
&& directionMode < 4 //Direction mode - only 0-3 supported in MKL DNN (no extra dim option) - 0 = fwd, 1 = bwd, 2 = bidirectional sum, 3 = bidirectional concat
&& retLastH == retLastC; //Return both lastH and lastC, or return neither (not just 1 or other)
return block.isUseMKLDNN() && featuresSupported && (
(xType==DataType::FLOAT32 && WxType==DataType::FLOAT32 && WrType==DataType::FLOAT32 && bType==DataType::FLOAT32 && hIType==DataType::FLOAT32 && cIType==DataType::FLOAT32 && hType==DataType::FLOAT32 && hLType==DataType::FLOAT32 && cLType==DataType::FLOAT32) || (xType==DataType::FLOAT32 && WxType==DataType::FLOAT32 && WrType==DataType::FLOAT32 && bType==DataType::FLOAT32 && hIType==DataType::FLOAT32 && cIType==DataType::FLOAT32 && hType==DataType::FLOAT32 && hLType==DataType::FLOAT32 && cLType==DataType::FLOAT32) ||
(xType==DataType::HALF && WxType==DataType::HALF && WrType==DataType::HALF && bType==DataType::HALF && hIType==DataType::HALF && cIType==DataType::HALF && hType==DataType::HALF && hLType==DataType::HALF && cLType==DataType::HALF) || (xType==DataType::HALF && WxType==DataType::HALF && WrType==DataType::HALF && bType==DataType::HALF && hIType==DataType::HALF && cIType==DataType::HALF && hType==DataType::HALF && hLType==DataType::HALF && cLType==DataType::HALF) ||
(xType==DataType::UINT8 && WxType==DataType::INT8 && WrType==DataType::INT8 && bType==DataType::FLOAT32 && hIType==DataType::UINT8 && cIType==DataType::UINT8 && (hType==DataType::FLOAT32 && hLType==DataType::FLOAT32 && cLType==DataType::FLOAT32 || hType==DataType::UINT8 && hLType==DataType::UINT8 && cLType==DataType::UINT8)) (xType==DataType::UINT8 && WxType==DataType::INT8 && WrType==DataType::INT8 && bType==DataType::FLOAT32 && hIType==DataType::UINT8 && cIType==DataType::UINT8 && (hType==DataType::FLOAT32 && hLType==DataType::FLOAT32 && cLType==DataType::FLOAT32 || hType==DataType::UINT8 && hLType==DataType::UINT8 && cLType==DataType::UINT8))

View File

@ -32,7 +32,7 @@ namespace ops {
namespace platforms { namespace platforms {
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
static void matmulMKLDNN(const NDArray* x, const NDArray* y, NDArray* z, const bool transX, const bool transY) { static void matmulMKLDNN(const NDArray* x, const NDArray* y, NDArray* z, const bool transX, const bool transY, float alpha = 1.f, float beta = 0.f) {
// mkl works with following // mkl works with following
// [M,K] x [K,N] = [M,N] // [M,K] x [K,N] = [M,N]
@ -150,6 +150,12 @@ static void matmulMKLDNN(const NDArray* x, const NDArray* y, NDArray* z, const b
// Create attributes (to handle alpha and beta if necessary) // Create attributes (to handle alpha and beta if necessary)
dnnl::primitive_attr attr; // it is empty since we have usual values for alpha (=1) and beta (=0) dnnl::primitive_attr attr; // it is empty since we have usual values for alpha (=1) and beta (=0)
if (alpha != 1.f) attr.set_output_scales(0, {alpha});
if (beta != 0.f) {
dnnl::post_ops po;
po.append_sum(beta);
attr.set_post_ops(po);
}
// operation primitive description // operation primitive description
dnnl::matmul::desc op_desc(x_mkl_md, y_mkl_md, z_mkl_md); dnnl::matmul::desc op_desc(x_mkl_md, y_mkl_md, z_mkl_md);
@ -224,11 +230,16 @@ PLATFORM_IMPL(matmul, ENGINE_CPU) {
if(x->isEmpty() || y->isEmpty()) if(x->isEmpty() || y->isEmpty())
return Status::OK(); return Status::OK();
const int iSize = (int) block.getIArguments()->size(); int iSize = (int) block.getIArguments()->size();
int transX = iSize > 0 ? INT_ARG(0) : 0; int transX = iSize > 0 ? INT_ARG(0) : 0;
int transY = iSize > 1 ? INT_ARG(1) : 0; int transY = iSize > 1 ? INT_ARG(1) : 0;
const int transZ = iSize > 2 ? INT_ARG(2) : 0; const int transZ = iSize > 2 ? INT_ARG(2) : 0;
// optionally use alpha and beta
iSize = (int)block.getTArguments()->size();
float alpha = iSize > 0 ? T_ARG(0) : 1.0;
float beta = iSize > 1 ? T_ARG(1) : 0.0;
const int xRank = x->rankOf(); const int xRank = x->rankOf();
const int yRank = y->rankOf(); const int yRank = y->rankOf();
const int zRank = z->rankOf(); const int zRank = z->rankOf();
@ -265,7 +276,7 @@ PLATFORM_IMPL(matmul, ENGINE_CPU) {
} }
// ******* end of input validation ******* // // ******* end of input validation ******* //
matmulMKLDNN(x, y, z, transX, transY); matmulMKLDNN(x, y, z, transX, transY, alpha, beta);
return Status::OK(); return Status::OK();
} }
@ -276,14 +287,16 @@ PLATFORM_CHECK(matmul, ENGINE_CPU) {
auto x = INPUT_VARIABLE(0); auto x = INPUT_VARIABLE(0);
auto y = INPUT_VARIABLE(1); auto y = INPUT_VARIABLE(1);
auto z = INPUT_VARIABLE(0); auto z = OUTPUT_VARIABLE(0);
const DataType xType = x->dataType(); const DataType xType = x->dataType();
const DataType yType = y->dataType(); const DataType yType = y->dataType();
const DataType zType = z->dataType(); const DataType zType = z->dataType();
float alpha = block.numT() > 0 ? T_ARG(0) : 1.0;
float beta = block.numT() > 1 ? T_ARG(1) : 0.0;
return block.isUseMKLDNN() && x->rankOf() < 3 && return !(z->ordering() == 'f' && beta != 0.f) && block.isUseMKLDNN() && x->rankOf() < 3 &&
( (
(xType==DataType::FLOAT32 && yType==DataType::FLOAT32 && zType==DataType::FLOAT32) || (xType==DataType::FLOAT32 && yType==DataType::FLOAT32 && zType==DataType::FLOAT32) ||
(xType==DataType::HALF && yType==DataType::HALF && zType==DataType::FLOAT32) || (xType==DataType::HALF && yType==DataType::HALF && zType==DataType::FLOAT32) ||

Some files were not shown because too many files have changed in this diff Show More