commit 09f4c21059

@@ -96,6 +96,11 @@ public class TestGraphLocalExecution extends BaseDL4JTest {
         Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT);
     }
 
+    @Override
+    public long getTimeoutMilliseconds() {
+        return 90000L;
+    }
+
     @Test
     public void testLocalExecutionDataSources() throws Exception {
 
@@ -204,7 +209,7 @@ public class TestGraphLocalExecution extends BaseDL4JTest {
         OptimizationConfiguration configuration = new OptimizationConfiguration.Builder()
                 .candidateGenerator(candidateGenerator).dataProvider(dataProvider)
                 .modelSaver(new FileModelSaver(modelSavePath)).scoreFunction(ScoreFunctions.testSetLoss(true))
-                .terminationConditions(new MaxTimeCondition(5, TimeUnit.SECONDS),
+                .terminationConditions(new MaxTimeCondition(30, TimeUnit.SECONDS),
                         new MaxCandidatesCondition(3))
                 .build();
 
@@ -251,7 +256,7 @@ public class TestGraphLocalExecution extends BaseDL4JTest {
                 .candidateGenerator(candidateGenerator)
                 .dataProvider(new TestMdsDataProvider(1, 32))
                 .modelSaver(new FileModelSaver(modelSavePath)).scoreFunction(ScoreFunctions.testSetLoss(true))
-                .terminationConditions(new MaxTimeCondition(5, TimeUnit.SECONDS),
+                .terminationConditions(new MaxTimeCondition(30, TimeUnit.SECONDS),
                         new MaxCandidatesCondition(3))
                 .scoreFunction(ScoreFunctions.testSetAccuracy())
                 .build();
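
Note: both hunks above widen the search budget from MaxTimeCondition(5, TimeUnit.SECONDS) to 30 seconds. With only 5 seconds, a slow CI machine could hit the time limit before MaxCandidatesCondition(3) ever allowed three candidates to run, making the test flaky. A minimal sketch of how the two conditions compose (same Arbiter builder API as in the hunks; candidateGenerator, dataProvider and modelSavePath stand in for the test's local variables) - the search stops at whichever condition fires first:

    OptimizationConfiguration configuration = new OptimizationConfiguration.Builder()
            .candidateGenerator(candidateGenerator)
            .dataProvider(dataProvider)
            .modelSaver(new FileModelSaver(modelSavePath))
            .scoreFunction(ScoreFunctions.testSetLoss(true))
            .terminationConditions(
                    new MaxTimeCondition(30, TimeUnit.SECONDS), // hard wall-clock cap...
                    new MaxCandidatesCondition(3))              // ...or at most 3 candidates
            .build();
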
@@ -49,7 +49,7 @@ check_cuda_version "$VERSION"
 case $VERSION in
     10.2)
         VERSION2="7.6"
-        VERSION3="1.5.2"
+        VERSION3="1.5.3"
         ;;
     10.1)
         VERSION2="7.6"
@@ -16,6 +16,7 @@
 
 package org.datavec.api.transform.split;
 
+
 import lombok.AllArgsConstructor;
 import lombok.Data;
 
@@ -72,11 +72,11 @@ public class LFWLoader extends BaseImageLoader implements Serializable {
     protected File fullDir;
 
     protected boolean useSubset = false;
-    InputSplit[] inputSplit;
+    protected InputSplit[] inputSplit;
 
-    public static Map<String, String> lfwData = new HashMap<>();
-    public static Map<String, String> lfwLabel = new HashMap<>();
-    public static Map<String, String> lfwSubsetData = new HashMap<>();
+    public Map<String, String> lfwData = new HashMap<>();
+    public Map<String, String> lfwLabel = new HashMap<>();
+    public Map<String, String> lfwSubsetData = new HashMap<>();
 
     public LFWLoader() {
         this(false);
@@ -45,15 +45,23 @@ import static org.junit.Assert.assertTrue;
  */
 public class LoaderTests {
 
+    private static void ensureDataAvailable(){
+        //Ensure test resources available by initializing CifarLoader and relying on auto download
+        boolean preProcessCifar = false;
+        int numExamples = 10;
+        int row = 28;
+        int col = 28;
+        int channels = 1;
+        for( boolean train : new boolean[]{true, false}){
+            CifarLoader loader = new CifarLoader(row, col, channels, train, preProcessCifar);
+            loader.next(numExamples);
+        }
+        new LFWLoader(new long[] {250, 250, 3}, true).getRecordReader(1, 1, 1, new Random(42)).next();
+    }
+
     @Test
     public void testLfwReader() throws Exception {
-        String subDir = "lfw-a/lfw";
-        File path = new File(FilenameUtils.concat(System.getProperty("user.home"), subDir));
-        FileSplit fileSplit = new FileSplit(path, LFWLoader.ALLOWED_FORMATS, new Random(42));
-        BalancedPathFilter pathFilter = new BalancedPathFilter(new Random(42), LFWLoader.LABEL_PATTERN, 1, 1, 1);
-        InputSplit[] inputSplit = fileSplit.sample(pathFilter, 1);
-        RecordReader rr = new ImageRecordReader(250, 250, 3, LFWLoader.LABEL_PATTERN);
-        rr.initialize(inputSplit[0]);
+        RecordReader rr = new LFWLoader(new long[] {250, 250, 3}, true).getRecordReader(1, 1, 1, new Random(42));
         List<String> exptedLabel = rr.getLabels();
 
         RecordReader rr2 = new LFWLoader(new long[] {250, 250, 3}, true).getRecordReader(1, 1, 1, new Random(42));
@@ -63,6 +71,7 @@ public class LoaderTests {
 
     @Test
     public void testCifarLoader() {
+        ensureDataAvailable();
         File dir = new File(FilenameUtils.concat(System.getProperty("user.home"), "cifar/cifar-10-batches-bin"));
         CifarLoader cifar = new CifarLoader(false, dir);
         assertTrue(dir.exists());
@@ -71,6 +80,7 @@ public class LoaderTests {
 
     @Test
     public void testCifarInputStream() throws Exception {
+        ensureDataAvailable();
         // check train
         String subDir = "cifar/cifar-10-batches-bin/data_batch_1.bin";
         String path = FilenameUtils.concat(System.getProperty("user.home"), subDir);
@@ -23,6 +23,11 @@ import org.junit.Test;
 
 public class TestDataSets extends BaseDL4JTest {
 
+    @Override
+    public long getTimeoutMilliseconds() {
+        return 180000L;
+    }
+
     @Test
     public void testTinyImageNetExists() throws Exception {
         //Simple sanity check on extracting
@@ -130,14 +130,6 @@ public class SameDiffConv extends SameDiffLayer {
 
         SDVariable w = paramTable.get(ConvolutionParamInitializer.WEIGHT_KEY);
 
-        SDVariable[] vars;
-        if(hasBias){
-            SDVariable b = paramTable.get(ConvolutionParamInitializer.BIAS_KEY);
-            vars = new SDVariable[]{layerInput, w, b};
-        } else {
-            vars = new SDVariable[]{layerInput, w};
-        }
-
         Conv2DConfig c = Conv2DConfig.builder()
                 .kH(kernel[0]).kW(kernel[1])
                 .pH(padding[0]).pW(padding[1])
@@ -146,7 +138,13 @@ public class SameDiffConv extends SameDiffLayer {
                 .isSameMode(this.cm == ConvolutionMode.Same)
                 .build();
 
-        SDVariable conv = sameDiff.cnn().conv2d(vars, c); //TODO can't set name
+        SDVariable conv = null;
+        if(hasBias){
+            SDVariable b = paramTable.get(ConvolutionParamInitializer.BIAS_KEY);
+            conv = sameDiff.cnn().conv2d(layerInput, w, b, c);
+        } else {
+            conv = sameDiff.cnn().conv2d(layerInput, w, c);
+        }
 
         return activation.asSameDiff("out", sameDiff, conv);
     }
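
Note: together, these two SameDiffConv hunks replace the packed conv2d(SDVariable[], Conv2DConfig) call with the explicit with-bias and no-bias overloads, so the bias lookup now sits next to the call that consumes it. A minimal standalone sketch of the resulting pattern (hypothetical shapes and names, assuming the same cnn().conv2d(...) overloads the hunk uses):

    SameDiff sd = SameDiff.create();
    SDVariable input = sd.placeHolder("in", DataType.FLOAT, -1, 1, 28, 28); // NCHW
    SDVariable w = sd.var("w", DataType.FLOAT, 5, 5, 1, 20);                // [kH, kW, nIn, nOut]
    SDVariable b = sd.var("b", DataType.FLOAT, 20);
    Conv2DConfig c = Conv2DConfig.builder().kH(5).kW(5).sH(1).sW(1).build();

    boolean hasBias = true;
    SDVariable conv = hasBias
            ? sd.cnn().conv2d(input, w, b, c) // bias overload
            : sd.cnn().conv2d(input, w, c);   // no-bias overload
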
@@ -44,6 +44,11 @@ import static org.junit.Assert.*;
 
 public class TestCheckpointListener extends BaseDL4JTest {
 
+    @Override
+    public long getTimeoutMilliseconds() {
+        return 90000L;
+    }
+
     @Rule
     public TemporaryFolder tempDir = new TemporaryFolder();
 
@@ -57,7 +62,7 @@ public class TestCheckpointListener extends BaseDL4JTest {
         MultiLayerNetwork net = new MultiLayerNetwork(conf);
         net.init();
 
-        DataSetIterator iter = new IrisDataSetIterator(75,150);
+        DataSetIterator iter = new IrisDataSetIterator(25,50);
 
         return new Pair<>(net, iter);
     }
@@ -178,13 +183,13 @@ public class TestCheckpointListener extends BaseDL4JTest {
 
         CheckpointListener l = new CheckpointListener.Builder(f)
                 .keepLast(3)
-                .saveEvery(3, TimeUnit.SECONDS)
+                .saveEvery(4, TimeUnit.SECONDS)
                 .build();
         net.setListeners(l);
 
         for(int i=0; i<5; i++ ){ //10 iterations total
             net.fit(iter);
-            Thread.sleep(4000);
+            Thread.sleep(5000);
         }
 
         //Expect models saved at iterations: 2, 4, 6, 8 (iterations 0 and 1 shoud happen before first 3 seconds is up)
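
Note: the checkpoint test now saves at most every 4 seconds and sleeps 5 seconds per fit cycle, so each cycle is guaranteed to cross at least one save boundary; the old 3 s / 4 s pairing left less slack on loaded CI machines. The smaller IrisDataSetIterator(25,50) keeps the same 2 fits per epoch (50/25, as 150/75 before) while touching a third of the data. A minimal usage sketch of the time-based builder (same API as the hunk; f, net and iter stand in for the test's locals):

    CheckpointListener listener = new CheckpointListener.Builder(f) // f: existing directory
            .keepLast(3)                     // retain only the 3 newest checkpoints
            .saveEvery(4, TimeUnit.SECONDS)  // write at most one checkpoint per 4 s
            .build();
    net.setListeners(listener);
    net.fit(iter); // checkpoint files appear under f as training crosses each boundary
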
@@ -54,6 +54,11 @@ import static org.junit.Assert.*;
 @Slf4j
 public class RegressionTest100a extends BaseDL4JTest {
 
+    @Override
+    public long getTimeoutMilliseconds() {
+        return 90000L; //Most tests should be fast, but slow download may cause timeout on slow connections
+    }
+
     @Override
     public DataType getDataType(){
         return DataType.FLOAT;
@@ -52,6 +52,11 @@ import static org.junit.Assert.*;
 
 public class RegressionTest100b3 extends BaseDL4JTest {
 
+    @Override
+    public long getTimeoutMilliseconds() {
+        return 90000L; //Most tests should be fast, but slow download may cause timeout on slow connections
+    }
+
     @Override
     public DataType getDataType(){
         return DataType.FLOAT;
@@ -69,6 +69,11 @@ import org.nd4j.resources.Resources;
 
 public class RegressionTest100b4 extends BaseDL4JTest {
 
+    @Override
+    public long getTimeoutMilliseconds() {
+        return 90000L; //Most tests should be fast, but slow download may cause timeout on slow connections
+    }
+
     @Override
     public DataType getDataType() {
         return DataType.FLOAT;
@@ -123,7 +128,8 @@ public class RegressionTest100b4 extends BaseDL4JTest {
 
             assertEquals(dtype, net.getLayerWiseConfigurations().getDataType());
             assertEquals(dtype, net.params().dataType());
-            assertEquals("Test for dtype: " + dtypeName, outExp, outAct);
+            boolean eq = outExp.equalsWithEps(outAct, 0.01);
+            assertTrue("Test for dtype: " + dtypeName + "\n" + outExp + " vs " + outAct, eq);
         }
     }
 
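
Note: exact assertEquals on two INDArrays is brittle across backends, BLAS versions and minor op changes; the regression test above switches to an epsilon comparison and prints both arrays on failure. A small self-contained sketch of the difference (values illustrative only):

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;

    public class EpsEqualityDemo {
        public static void main(String[] args) {
            INDArray expected = Nd4j.create(new double[]{1.0, 2.0, 3.0});
            INDArray actual = expected.add(0.005); // small numerical drift

            // equals(...) uses a very tight default tolerance, so the drift fails it:
            System.out.println(expected.equals(actual));              // false
            // an explicit epsilon accepts anything within 0.01:
            System.out.println(expected.equalsWithEps(actual, 0.01)); // true
        }
    }
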
@@ -56,6 +56,11 @@ public class RegressionTest100b6 extends BaseDL4JTest {
         return DataType.FLOAT;
     }
 
+    @Override
+    public long getTimeoutMilliseconds() {
+        return 90000L; //Most tests should be fast, but slow download may cause timeout on slow connections
+    }
+
     @Test
     public void testCustomLayer() throws Exception {
 
@@ -106,7 +111,8 @@ public class RegressionTest100b6 extends BaseDL4JTest {
             assertEquals(dtype, net.getLayerWiseConfigurations().getDataType());
             assertEquals(dtype, net.params().dataType());
             boolean eq = outExp.equalsWithEps(outAct, 0.01);
-            assertTrue(outExp + " vs " + outAct, eq); }
+            assertTrue(outExp + " vs " + outAct, eq);
+        }
     }
 
 
@@ -28,7 +28,7 @@
         <!-- CUDA version is linked with the artifact name so cannot move to parent pom.xml -->
         <cuda.version>10.2</cuda.version>
         <cudnn.version>7.6</cudnn.version>
-        <javacpp-presets.cuda.version>1.5.2</javacpp-presets.cuda.version>
+        <javacpp-presets.cuda.version>1.5.3</javacpp-presets.cuda.version>
     </properties>
 
     <dependencyManagement>
@@ -96,7 +96,12 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
         }
     };
 
-    @Test(expected = IllegalStateException.class)
+    @Override
+    public long getTimeoutMilliseconds() {
+        return 90000L;
+    }
+
+    @Test(expected = IllegalStateException.class)
     public void fileNotFoundEndToEnd() throws Exception {
         String modelPath = "modelimport/keras/examples/foo/bar.h5";
         importEndModelTest(modelPath, null, true, true, false, false);
@@ -25,7 +25,7 @@ import org.nd4j.shade.jackson.annotation.JsonProperty;
  */
 @Deprecated
 @Getter
-@EqualsAndHashCode
+@EqualsAndHashCode(callSuper = true)
 public class EvaluationCalibration extends org.nd4j.evaluation.classification.EvaluationCalibration implements org.deeplearning4j.eval.IEvaluation<org.nd4j.evaluation.classification.EvaluationCalibration> {
 
     /**
@@ -31,6 +31,7 @@ import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.linalg.api.memory.MemoryWorkspace;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.util.ArrayUtil;
 
 import java.util.Map;
 
@@ -99,15 +100,15 @@ public class CapsuleLayer extends SameDiffLayer {
     }
 
     @Override
-    public SDVariable defineLayer(SameDiff SD, SDVariable input, Map<String, SDVariable> paramTable, SDVariable mask) {
+    public SDVariable defineLayer(SameDiff sd, SDVariable input, Map<String, SDVariable> paramTable, SDVariable mask) {
 
         // input: [mb, inputCapsules, inputCapsuleDimensions]
 
         // [mb, inputCapsules, 1, inputCapsuleDimensions, 1]
-        SDVariable expanded = SD.expandDims(SD.expandDims(input, 2), 4);
+        SDVariable expanded = sd.expandDims(sd.expandDims(input, 2), 4);
 
         // [mb, inputCapsules, capsules * capsuleDimensions, inputCapsuleDimensions, 1]
-        SDVariable tiled = SD.tile(expanded, 1, 1, capsules * capsuleDimensions, 1, 1);
+        SDVariable tiled = sd.tile(expanded, 1, 1, capsules * capsuleDimensions, 1, 1);
 
         // [1, inputCapsules, capsules * capsuleDimensions, inputCapsuleDimensions]
         SDVariable weights = paramTable.get(WEIGHT_PARAM);
@@ -119,13 +120,13 @@ public class CapsuleLayer extends SameDiffLayer {
 
         // b is the logits of the routing procedure
         // [mb, inputCapsules, capsules, 1, 1]
-        SDVariable b = SD.zerosLike(uHat).get(SDIndex.all(), SDIndex.all(), SDIndex.all(), SDIndex.interval(0, 1), SDIndex.interval(0, 1));
+        SDVariable b = sd.zerosLike(uHat).get(SDIndex.all(), SDIndex.all(), SDIndex.all(), SDIndex.interval(0, 1), SDIndex.interval(0, 1));
 
         for(int i = 0 ; i < routings ; i++){
 
             // c is the coupling coefficient, i.e. the edge weight between the 2 capsules
             // [mb, inputCapsules, capsules, 1, 1]
-            SDVariable c = CapsuleUtils.softmax(SD, b, 2, 5);
+            SDVariable c = sd.nn.softmax(b, 2);
 
             // [mb, 1, capsules, capsuleDimensions, 1]
             SDVariable s = c.times(uHat).sum(true, 1);
@@ -135,14 +136,14 @@ public class CapsuleLayer extends SameDiffLayer {
 
             // v is the per capsule activations. On the last routing iteration, this is output
             // [mb, 1, capsules, capsuleDimensions, 1]
-            SDVariable v = CapsuleUtils.squash(SD, s, 3);
+            SDVariable v = CapsuleUtils.squash(sd, s, 3);
 
             if(i == routings - 1){
-                return SD.squeeze(SD.squeeze(v, 1), 3);
+                return sd.squeeze(sd.squeeze(v, 1), 3);
             }
 
             // [mb, inputCapsules, capsules, capsuleDimensions, 1]
-            SDVariable vTiled = SD.tile(v, 1, (int) inputCapsules, 1, 1, 1);
+            SDVariable vTiled = sd.tile(v, 1, (int) inputCapsules, 1, 1, 1);
 
             // [mb, inputCapsules, capsules, 1, 1]
             b = b.plus(uHat.times(vTiled).sum(true, 3));
@@ -178,9 +178,11 @@ public class LocallyConnected1D extends SameDiffLayer {
             //Note: for same mode, bottom/right padding can be 1 more than top/left padding
             //NCW format.
             if(cm == ConvolutionMode.Same) {
-                layerInput = sameDiff.nn().pad(layerInput, new int[][]{{0, 0}, {0, 0}, {padding, paddingR}}, 0);
+                layerInput = sameDiff.nn().pad(layerInput,
+                        sameDiff.constant(Nd4j.createFromArray(new int[][]{{0, 0}, {0, 0}, {padding, paddingR}})), 0);
             } else {
-                layerInput = sameDiff.nn().pad(layerInput, new int[][]{{0, 0}, {0, 0}, {padding, padding}}, 0);
+                layerInput = sameDiff.nn().pad(layerInput,
+                        sameDiff.constant(Nd4j.createFromArray(new int[][]{{0, 0}, {0, 0}, {padding, padding}})), 0);
             }
         }
 
@@ -184,9 +184,11 @@ public class LocallyConnected2D extends SameDiffLayer {
             //Note: for same mode, bottom/right padding can be 1 more than top/left padding
             //NCHW format
             if(cm == ConvolutionMode.Same){
-                layerInput = sameDiff.nn().pad(layerInput, new int[][]{{0,0},{0,0},{padding[0], paddingBr[0]}, {padding[1], paddingBr[1]}}, 0);
+                layerInput = sameDiff.nn().pad(layerInput,
+                        sameDiff.constant(Nd4j.createFromArray(new int[][]{{0,0},{0,0},{padding[0], paddingBr[0]}, {padding[1], paddingBr[1]}})), 0.0);
             } else {
-                layerInput = sameDiff.nn().pad(layerInput, new int[][]{{0,0},{0,0},{padding[0], padding[0]}, {padding[1], padding[1]}}, 0);
+                layerInput = sameDiff.nn().pad(layerInput,
+                        sameDiff.constant(Nd4j.createFromArray(new int[][]{{0,0},{0,0},{padding[0], padding[0]}, {padding[1], padding[1]}})), 0.0);
             }
         }
 
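
Note: in both LocallyConnected hunks the padding spec moves from a raw int[][] argument into the graph as an SDVariable constant, matching the pad op's SDVariable-based signature. A minimal sketch of the call shape (hypothetical sd and shapes, same nn().pad(...) overload as the hunks):

    SameDiff sd = SameDiff.create();
    SDVariable input = sd.placeHolder("input", DataType.FLOAT, -1, 3, 10);

    // Pad one element on each side of the last (width) dimension, fill value 0:
    SDVariable paddings = sd.constant(Nd4j.createFromArray(new int[][]{{0, 0}, {0, 0}, {1, 1}}));
    SDVariable padded = sd.nn().pad(input, paddings, 0.0);
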
@@ -185,7 +185,9 @@ public class RecurrentAttentionLayer extends SameDiffLayer {
         final val R = paramTable.get(RECURRENT_WEIGHT_KEY);
         final val b = paramTable.get(BIAS_KEY);
 
-        SDVariable[] inputSlices = sameDiff.unstack(layerInput, 2);
+        long[] shape = layerInput.getShape();
+        Preconditions.checkState(shape != null, "Null shape for input placeholder");
+        SDVariable[] inputSlices = sameDiff.unstack(layerInput, 2, (int)shape[2]);
         this.timeSteps = inputSlices.length;
         SDVariable[] outputSlices = new SDVariable[timeSteps];
         SDVariable prev = null;
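
Note: unstack must know at graph-construction time how many slices it produces; with a placeholder input the two-argument form cannot infer that count, hence the explicit (int) shape[2] plus the precondition that the shape is actually available. A minimal sketch (hypothetical, fully-known shape):

    SameDiff sd = SameDiff.create();
    // [minibatch=8, channels=4, timeSteps=5]
    SDVariable x = sd.var("x", Nd4j.rand(DataType.FLOAT, 8, 4, 5));

    long[] shape = x.getShape();
    // One [8, 4] slice per time step; the slice count (5) is passed explicitly:
    SDVariable[] slices = sd.unstack(x, 2, (int) shape[2]);
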
@@ -72,7 +72,7 @@ public class SimpleRnn extends BaseRecurrentLayer<org.deeplearning4j.nn.conf.lay
         INDArray out = activateHelper(last, training, false, workspaceMgr).getFirst();
         if(storeLastForTBPTT){
             try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()){
-                tBpttStateMap.put(STATE_KEY_PREV_ACTIVATION, out.get(all(), all(), point(out.size(2)-1)));
+                tBpttStateMap.put(STATE_KEY_PREV_ACTIVATION, out.get(all(), all(), point(out.size(2)-1)).dup());
             }
         }
         return out;
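
Note: out.get(...) returns a view that shares the workspace-backed buffer of out; once the workspace scope is left, that memory can be reused and the cached TBPTT state silently overwritten. The added dup() copies the last-time-step slice into detached memory inside scopeOutOfWorkspaces(). A self-contained sketch of the view-vs-copy aliasing (plain Nd4j, no workspaces needed to see the effect):

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;
    import static org.nd4j.linalg.indexing.NDArrayIndex.*;

    public class ViewVsCopyDemo {
        public static void main(String[] args) {
            INDArray out = Nd4j.zeros(2, 3, 4);
            INDArray view = out.get(all(), all(), point(3)); // shares out's buffer
            INDArray copy = view.dup();                      // owns its own buffer

            out.addi(1.0); // mutate the original in place
            System.out.println(view.maxNumber()); // 1.0 - the view changed too
            System.out.println(copy.maxNumber()); // 0.0 - the copy is unaffected
        }
    }
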
@@ -45,15 +45,4 @@ public class CapsuleUtils {
         return x.times(squaredNorm).div(squaredNorm.plus(1.0).times(scale));
     }
 
-    /**
-     * Compute softmax along a given dimension
-     */
-    public static SDVariable softmax(SameDiff SD, SDVariable x, int dimension, int rank){
-        int[] permutation = ArrayUtil.range(0, rank);
-        permutation[0] = dimension;
-        permutation[dimension] = 0;
-
-        return SD.nn.softmax(x.permute(permutation)).permute(ArrayUtil.invertPermutation(permutation));
-    }
-
 }
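
Note: the deleted helper emulated softmax along an arbitrary dimension by permuting the target axis to the front, applying the single-argument softmax, and inverting the permutation; CapsuleLayer now calls the dimension-aware sd.nn.softmax(b, 2) directly (see the CapsuleLayer hunk above), so the helper is dead code. A minimal sketch of what the dimension argument means (hypothetical shape) - after softmax over dimension 1, sums along that axis are 1:

    SameDiff sd = SameDiff.create();
    SDVariable x = sd.var("x", Nd4j.rand(DataType.FLOAT, 2, 3, 4));

    SDVariable sm = sd.nn.softmax(x, 1); // softmax along dimension 1
    INDArray sums = sm.eval().sum(1);    // shape [2, 4], every entry ~= 1.0
    System.out.println(sums);
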
@@ -84,6 +84,13 @@
 
 
     <profiles>
+        <profile>
+            <id>testresources</id>
+            <activation>
+                <activeByDefault>true</activeByDefault>
+            </activation>
+        </profile>
+
         <profile>
             <id>test-nd4j-native</id>
             <activation>
@@ -19,6 +19,7 @@ import org.nd4j.remote.clients.serde.BinarySerializer;
 import org.nd4j.remote.clients.serde.JsonDeserializer;
 import org.nd4j.remote.clients.serde.JsonSerializer;
 import org.nd4j.remote.clients.serde.impl.IntegerSerde;
+import org.nd4j.resources.Resources;
 import org.nd4j.shade.jackson.databind.ObjectMapper;
 
 import javax.imageio.ImageIO;
@@ -65,7 +66,7 @@ public class BinaryModelServerTest extends BaseDL4JTest {
     @Test
     public void testMlnMnist_ImageInput() throws Exception {
 
-        val modelFile = new ClassPathResource("models/mnist/mnist-model.zip").getFile();
+        val modelFile = Resources.asFile("models/mnist/mnist-model.zip");
         MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(modelFile);
 
         val server = new JsonModelServer.Builder<BufferedImage, Integer>(net)
@@ -129,7 +130,7 @@ public class BinaryModelServerTest extends BaseDL4JTest {
     @Test
     public void testMlnMnist_ImageInput_Async() throws Exception {
 
-        val modelFile = new ClassPathResource("models/mnist/mnist-model.zip").getFile();
+        val modelFile = Resources.asFile("models/mnist/mnist-model.zip");
         MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(modelFile);
 
         val server = new JsonModelServer.Builder<BufferedImage, Integer>(net)
@@ -198,7 +199,7 @@ public class BinaryModelServerTest extends BaseDL4JTest {
     @Test
     public void testBinaryIn_BinaryOut() throws Exception {
 
-        val modelFile = new ClassPathResource("models/mnist/mnist-model.zip").getFile();
+        val modelFile = Resources.asFile("models/mnist/mnist-model.zip");
         MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(modelFile);
 
         val server = new JsonModelServer.Builder<BufferedImage, BufferedImage>(net)
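
Note: all three tests above swap new ClassPathResource(...).getFile() for Resources.asFile(...) (imported earlier in this diff), which resolves the model through ND4J's test-resources mechanism (typically downloading and caching the file on first use) rather than requiring it to be bundled on the classpath. Usage stays a one-liner:

    File modelFile = Resources.asFile("models/mnist/mnist-model.zip");
    MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(modelFile);
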
@@ -495,7 +495,7 @@ public class JsonModelServerTest extends BaseDL4JTest {
         SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 28*28);
         SDVariable w = sd.var("w", Nd4j.rand(DataType.FLOAT, 28*28, 10));
         SDVariable b = sd.var("b", Nd4j.rand(DataType.FLOAT, 1, 10));
-        SDVariable sm = sd.nn.softmax("softmax", in.mmul(w).add(b));
+        SDVariable sm = sd.nn.softmax("softmax", in.mmul(w).add(b), -1);
 
         val server = new JsonModelServer.Builder<float[], Integer>(sd)
                 .outputSerializer( new IntSerde())
@@ -20,6 +20,9 @@
     <name>deeplearning4j-remote</name>
 
     <profiles>
+        <profile>
+            <id>testresources</id>
+        </profile>
         <profile>
             <id>test-nd4j-native</id>
         </profile>
@@ -58,7 +58,7 @@ public class TestSameDiffUI extends BaseDL4JTest {
         SDVariable b = sd.var("b", DataType.FLOAT, 1, 4);
 
         SDVariable z = in.mmul(w).add(b);
-        SDVariable a = sd.nn().tanh(z);
+        SDVariable a = sd.math().tanh(z);
 
         LogFileWriter lfw = new LogFileWriter(f);
         lfw.writeGraphStructure(sd);
@@ -20,7 +20,10 @@ package org.deeplearning4j.integration;
 import lombok.extern.slf4j.Slf4j;
 import org.apache.commons.io.FileUtils;
 import org.deeplearning4j.datasets.iterator.MultiDataSetWrapperIterator;
+import org.deeplearning4j.integration.testcases.dl4j.*;
+import org.deeplearning4j.integration.testcases.samediff.SameDiffCNNCases;
 import org.deeplearning4j.integration.testcases.samediff.SameDiffMLPTestCases;
+import org.deeplearning4j.integration.testcases.samediff.SameDiffRNNTestCases;
 import org.deeplearning4j.nn.api.Model;
 import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
 import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
@@ -66,14 +69,36 @@ public class IntegrationTestBaselineGenerator {
         }
 
         runGeneration(
-                SameDiffMLPTestCases.getMLPMnist()
+                // DL4J integration test cases.
+
+                // CNN1DTestCases.getCnn1dTestCaseCharRNN(),
+                // CNN2DTestCases.testLenetTransferDropoutRepeatability(),
+                //// CNN2DTestCases.getCnn2DSynthetic(),
+                // CNN2DTestCases.getLenetMnist(),
+                // CNN2DTestCases.getVGG16TransferTinyImagenet(),
+                // CNN2DTestCases.getYoloHouseNumbers(),
+                // CNN3DTestCases.getCnn3dTestCaseSynthetic(),
+                // MLPTestCases.getMLPMnist(),
+                // MLPTestCases.getMLPMoon(),
+                // RNNTestCases.getRnnCharacterTestCase(),
+                // RNNTestCases.getRnnCsvSequenceClassificationTestCase1(),
+                // RNNTestCases.getRnnCsvSequenceClassificationTestCase2(),
+                // UnsupervisedTestCases.getVAEMnistAnomaly(),
+
+
+                // Samediff test cases done
+                SameDiffMLPTestCases.getMLPMnist(),
+                SameDiffMLPTestCases.getMLPMoon(),
+                SameDiffCNNCases.getLenetMnist(),
+                SameDiffCNNCases.getCnn3dSynthetic(),
+                SameDiffRNNTestCases.getRnnCsvSequenceClassificationTestCase1()
         );
 
     }
 
     private static void runGeneration(TestCase... testCases) throws Exception {
 
-        for( TestCase tc : testCases ) {
+        for (TestCase tc : testCases) {
             final ModelType modelType = tc.modelType();
 
             //Basic validation:
@@ -122,18 +147,18 @@ public class IntegrationTestBaselineGenerator {
             mln = new MultiLayerNetwork(mlc);
             mln.init();
             m = mln;
-        } else if (config instanceof ComputationGraphConfiguration){
+        } else if (config instanceof ComputationGraphConfiguration) {
             ComputationGraphConfiguration cgc = (ComputationGraphConfiguration) config;
             json = cgc.toJson();
             cg = new ComputationGraph(cgc);
             cg.init();
             m = cg;
         } else {
-            sd = (SameDiff)config;
+            sd = (SameDiff) config;
         }
 
         File savedModel = new File(testBaseDir, IntegrationTestRunner.RANDOM_INIT_UNTRAINED_MODEL_FILENAME);
-        if(modelType != ModelType.SAMEDIFF) {
+        if (modelType != ModelType.SAMEDIFF) {
             File configFile = new File(testBaseDir, "config." + (modelType == ModelType.MLN ? "mlc.json" : "cgc.json"));
             FileUtils.writeStringToFile(configFile, json, StandardCharsets.UTF_8);
             log.info("RANDOM_INIT test - saved configuration: {}", configFile.getAbsolutePath());
@@ -147,10 +172,10 @@ public class IntegrationTestBaselineGenerator {
             m = tc.getPretrainedModel();
             if (m instanceof MultiLayerNetwork) {
                 mln = (MultiLayerNetwork) m;
-            } else if(m instanceof ComputationGraph){
+            } else if (m instanceof ComputationGraph) {
                 cg = (ComputationGraph) m;
             } else {
-                sd = (SameDiff)m;
+                sd = (SameDiff) m;
             }
         }
 
@@ -158,7 +183,7 @@ public class IntegrationTestBaselineGenerator {
         //Generate predictions to compare against
         if (tc.isTestPredictions()) {
             List<Pair<INDArray[], INDArray[]>> inputs = modelType != ModelType.SAMEDIFF ? tc.getPredictionsTestData() : null;
-            List<Map<String,INDArray>> inputsSd = modelType == ModelType.SAMEDIFF ? tc.getPredictionsTestDataSameDiff() : null;
+            List<Map<String, INDArray>> inputsSd = modelType == ModelType.SAMEDIFF ? tc.getPredictionsTestDataSameDiff() : null;
             // Preconditions.checkState(inputs != null && inputs.size() > 0, "Input data is null or length 0 for test: %s", tc.getTestName());
 
 
@@ -178,7 +203,7 @@ public class IntegrationTestBaselineGenerator {
                         Nd4j.write(out, dos);
                     }
                 }
-            } else if(modelType == ModelType.CG) {
+            } else if (modelType == ModelType.CG) {
                 for (Pair<INDArray[], INDArray[]> p : inputs) {
                     INDArray[] out = cg.output(false, p.getFirst(), p.getSecond(), null);
 
@@ -192,11 +217,11 @@ public class IntegrationTestBaselineGenerator {
                 }
             } else {
                 List<String> outNames = tc.getPredictionsNamesSameDiff();
-                for( Map<String,INDArray> ph : inputsSd ){
-                    Map<String,INDArray> out = sd.output(ph, outNames);
+                for (Map<String, INDArray> ph : inputsSd) {
+                    Map<String, INDArray> out = sd.output(ph, outNames);
 
                     //Save the output...
-                    for(String s : outNames){
+                    for (String s : outNames) {
                         File f = new File(predictionsTestDir, "output_" + (count++) + "_" + s + ".bin");
                         try (DataOutputStream dos = new DataOutputStream(new FileOutputStream(f))) {
                             Nd4j.write(out.get(s), dos);
@@ -211,7 +236,7 @@ public class IntegrationTestBaselineGenerator {
         //Compute and save gradients:
         if (tc.isTestGradients()) {
             INDArray gradientFlat = null;
-            Map<String,INDArray> grad;
+            Map<String, INDArray> grad;
             if (modelType == ModelType.MLN) {
                 MultiDataSet data = tc.getGradientsTestData();
                 mln.setInput(data.getFeatures(0));
@@ -220,7 +245,7 @@ public class IntegrationTestBaselineGenerator {
                 mln.computeGradientAndScore();
                 gradientFlat = mln.getFlattenedGradients();
                 grad = m.gradient().gradientForVariable();
-            } else if(modelType == ModelType.CG) {
+            } else if (modelType == ModelType.CG) {
                 MultiDataSet data = tc.getGradientsTestData();
                 cg.setInputs(data.getFeatures());
                 cg.setLabels(data.getLabels());
@@ -229,17 +254,17 @@ public class IntegrationTestBaselineGenerator {
                 gradientFlat = cg.getFlattenedGradients();
                 grad = m.gradient().gradientForVariable();
             } else {
-                Map<String,INDArray> ph = tc.getGradientsTestDataSameDiff();
+                Map<String, INDArray> ph = tc.getGradientsTestDataSameDiff();
                 List<String> allVars = new ArrayList<>();
-                for(SDVariable v : sd.variables()){
-                    if(v.getVariableType() == VariableType.VARIABLE){
+                for (SDVariable v : sd.variables()) {
+                    if (v.getVariableType() == VariableType.VARIABLE) {
                         allVars.add(v.name());
                     }
                 }
                 grad = sd.calculateGradients(ph, allVars);
             }
 
-            if(modelType != ModelType.SAMEDIFF) {
+            if (modelType != ModelType.SAMEDIFF) {
                 File gFlatFile = new File(testBaseDir, IntegrationTestRunner.FLAT_GRADIENTS_FILENAME);
                 IntegrationTestRunner.write(gradientFlat, gFlatFile);
             }
@@ -254,25 +279,25 @@ public class IntegrationTestBaselineGenerator {
         }
 
         //Test pretraining
-        if(tc.isTestUnsupervisedTraining()){
+        if (tc.isTestUnsupervisedTraining()) {
             log.info("Performing layerwise pretraining");
             MultiDataSetIterator iter = tc.getUnsupervisedTrainData();
 
             INDArray paramsPostTraining;
-            if(modelType == ModelType.MLN){
+            if (modelType == ModelType.MLN) {
                 int[] layersToTrain = tc.getUnsupervisedTrainLayersMLN();
                 Preconditions.checkState(layersToTrain != null, "Layer indices must not be null");
                 DataSetIterator dsi = new MultiDataSetWrapperIterator(iter);
 
-                for( int i : layersToTrain){
+                for (int i : layersToTrain) {
                     mln.pretrainLayer(i, dsi);
                 }
                 paramsPostTraining = mln.params();
-            } else if(modelType == ModelType.CG) {
+            } else if (modelType == ModelType.CG) {
                 String[] layersToTrain = tc.getUnsupervisedTrainLayersCG();
                 Preconditions.checkState(layersToTrain != null, "Layer names must not be null");
 
-                for( String i : layersToTrain){
+                for (String i : layersToTrain) {
                     cg.pretrainLayer(i, iter);
                 }
                 paramsPostTraining = cg.params();
@@ -290,20 +315,20 @@ public class IntegrationTestBaselineGenerator {
             MultiDataSetIterator trainData = tc.getTrainingData();
 
             CollectScoresListener l = new CollectScoresListener(1);
-            if(modelType != ModelType.SAMEDIFF)
+            if (modelType != ModelType.SAMEDIFF)
                 m.setListeners(l);
 
             History h = null;
             if (modelType == ModelType.MLN) {
                 mln.fit(trainData);
-            } else if(modelType == ModelType.CG) {
+            } else if (modelType == ModelType.CG) {
                 cg.fit(trainData);
             } else {
                 h = sd.fit(trainData, 1);
             }
 
             double[] scores;
-            if(modelType != ModelType.SAMEDIFF){
+            if (modelType != ModelType.SAMEDIFF) {
                 scores = l.getListScore().toDoubleArray();
             } else {
                 scores = h.lossCurve().getLossValues().toDoubleVector();
@@ -314,11 +339,11 @@ public class IntegrationTestBaselineGenerator {
             FileUtils.writeStringToFile(f, String.join(",", s), StandardCharsets.UTF_8);
 
             if (tc.isTestParamsPostTraining()) {
-                if(modelType == ModelType.SAMEDIFF){
+                if (modelType == ModelType.SAMEDIFF) {
                     File p = new File(testBaseDir, IntegrationTestRunner.PARAMS_POST_TRAIN_SAMEDIFF_DIR);
                     p.mkdirs();
-                    for(SDVariable v : sd.variables()){
-                        if(v.getVariableType() == VariableType.VARIABLE){
+                    for (SDVariable v : sd.variables()) {
+                        if (v.getVariableType() == VariableType.VARIABLE) {
                             INDArray arr = v.getArr();
                             File p2 = new File(p, v.name() + ".bin");
                             IntegrationTestRunner.write(arr, p2);
@@ -331,7 +356,6 @@ public class IntegrationTestBaselineGenerator {
                 }
             }
 
-
         if (tc.isTestEvaluation()) {
             IEvaluation[] evals = tc.getNewEvaluations();
             MultiDataSetIterator iter = tc.getEvaluationTestData();
@@ -339,7 +363,7 @@ public class IntegrationTestBaselineGenerator {
             if (modelType == ModelType.MLN) {
                 DataSetIterator dsi = new MultiDataSetWrapperIterator(iter);
                 mln.doEvaluation(dsi, evals);
-            } else if(modelType == ModelType.CG){
+            } else if (modelType == ModelType.CG) {
                 cg.doEvaluation(iter, evals);
             } else {
                 evals = tc.doEvaluationSameDiff(sd, iter, evals);
@@ -16,6 +16,7 @@
 package org.deeplearning4j.integration;
 
 import org.deeplearning4j.BaseDL4JTest;
+import org.deeplearning4j.integration.testcases.samediff.SameDiffCNNCases;
 import org.deeplearning4j.integration.testcases.samediff.SameDiffMLPTestCases;
 import org.junit.Rule;
 import org.junit.Test;
@@ -37,4 +38,20 @@ public class IntegrationTestsSameDiff extends BaseDL4JTest {
         IntegrationTestRunner.runTest(SameDiffMLPTestCases.getMLPMnist(), testDir);
     }
 
+    @Test
+    public void testMLPMoon() throws Exception {
+        IntegrationTestRunner.runTest(SameDiffMLPTestCases.getMLPMoon(), testDir);
+    }
+
+    @Test
+    public void testLenetMnist() throws Exception {
+        IntegrationTestRunner.runTest(SameDiffCNNCases.getLenetMnist(), testDir);
+    }
+
+    @Test
+    public void testCnn3dSynthetic() throws Exception {
+        IntegrationTestRunner.runTest(SameDiffCNNCases.getCnn3dSynthetic(), testDir);
+    }
+
+
 }
@@ -194,6 +194,8 @@ public class CNN2DTestCases {
                 testParamsPostTraining = false; //Skip - requires saving all params (approx 500mb)
                 testEvaluation = false;
                 testOverfitting = false;
+                maxRelativeErrorOutput = 0.2;
+                minAbsErrorOutput = 0.05; //Max value is around 0.22
             }
 
             @Override
@@ -314,6 +316,7 @@ public class CNN2DTestCases {
                 ComputationGraph model = new TransferLearning.GraphBuilder(pretrained)
                         .fineTuneConfiguration(fineTuneConf)
                         .removeVertexKeepConnections("conv2d_9")
+                        .removeVertexAndConnections("outputs")
                         .addLayer("convolution2d_9",
                                 new ConvolutionLayer.Builder(1,1)
                                         .nIn(1024)
@@ -393,7 +396,7 @@ public class CNN2DTestCases {
 
             @Override
             public ModelType modelType() {
-                return ModelType.CG;
+                return ModelType.MLN;
             }
 
             @Override
@@ -77,6 +77,10 @@ public class MLPTestCases {
                 testOverfitting = true;
                 maxRelativeErrorOverfit = 2e-2;
                 minAbsErrorOverfit = 1e-2;
+                maxRelativeErrorGradients = 0.01;
+                minAbsErrorGradients = 0.05;
+                maxRelativeErrorParamsPostTraining = 0.01;
+                minAbsErrorParamsPostTraining = 0.05;
             }
 
             @Override
@@ -135,8 +139,7 @@ public class MLPTestCases {
             public IEvaluation[] getNewEvaluations(){
                 return new IEvaluation[]{
                         new Evaluation(),
-                        new ROCMultiClass(),
-                        new EvaluationCalibration()
+                        new ROCMultiClass()
                 };
             }
 
@@ -24,6 +24,7 @@ import org.nd4j.evaluation.classification.EvaluationCalibration;
 import org.nd4j.evaluation.classification.ROCMultiClass;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.dataset.api.preprocessor.CompositeMultiDataSetPreProcessor;
+import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.shade.guava.io.Files;
 import org.deeplearning4j.integration.TestCase;
 import org.deeplearning4j.integration.testcases.dl4j.misc.CharacterIterator;
@@ -91,7 +92,7 @@ public class RNNTestCases {
         }
 
         private int miniBatchSize = 32;
-        private int exampleLength = 1000;
+        private int exampleLength = 200;
 
 
         @Override
@@ -101,6 +102,7 @@ public class RNNTestCases {
 
         @Override
         public Object getConfiguration() throws Exception {
+            Nd4j.getRandom().setSeed(12345);
 
             CharacterIterator iter = CharacterIterator.getShakespeareIterator(miniBatchSize,exampleLength);
             int nOut = iter.totalOutcomes();
@@ -113,7 +115,7 @@ public class RNNTestCases {
                     .seed(12345)
                     .l2(0.001)
                     .weightInit(WeightInit.XAVIER)
-                    .updater(new RmsProp(0.1))
+                    .updater(new Adam(1e-3))
                     .list()
                     .layer(0, new LSTM.Builder().nIn(iter.inputColumns()).nOut(lstmLayerSize)
                             .activation(Activation.TANH).build())
@@ -140,7 +142,7 @@ public class RNNTestCases {
         @Override
         public MultiDataSetIterator getTrainingData() throws Exception {
             DataSetIterator iter = CharacterIterator.getShakespeareIterator(miniBatchSize,exampleLength);
-            iter = new EarlyTerminationDataSetIterator(iter, 2); //3 minibatches, 1000/200 = 5 updates per minibatch
+            iter = new EarlyTerminationDataSetIterator(iter, 2); //2 minibatches, 200/50 = 4 updates per minibatch
             return new MultiDataSetIteratorAdapter(iter);
         }
 
@@ -72,12 +72,12 @@ public class UnsupervisedTestCases {
             return new NeuralNetConfiguration.Builder()
                     .dataType(DataType.FLOAT)
                     .seed(12345)
-                    .updater(new Adam(0.05))
+                    .updater(new Adam(1e-3))
                     .weightInit(WeightInit.XAVIER)
                     .l2(1e-4)
                     .list()
                     .layer(0, new VariationalAutoencoder.Builder()
-                            .activation(Activation.LEAKYRELU)
+                            .activation(Activation.TANH)
                             .encoderLayerSizes(256, 256) //2 encoder layers, each of size 256
                             .decoderLayerSizes(256, 256) //2 decoder layers, each of size 256
                             .pzxActivationFunction(Activation.IDENTITY) //p(z|data) activation function
@ -0,0 +1,398 @@
/* ******************************************************************************
 * Copyright (c) 2020 Konduit K.K.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/
package org.deeplearning4j.integration.testcases.samediff;

import org.deeplearning4j.datasets.iterator.EarlyTerminationDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.MultiDataSetIteratorAdapter;
import org.deeplearning4j.datasets.iterator.impl.SingletonMultiDataSetIterator;
import org.deeplearning4j.integration.ModelType;
import org.deeplearning4j.integration.TestCase;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.TrainingConfig;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.evaluation.classification.EvaluationCalibration;
import org.nd4j.evaluation.classification.ROCMultiClass;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv3DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling3DConfig;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.Nesterovs;

import java.util.*;

public class SameDiffCNNCases {

    public static TestCase getLenetMnist() {
        return new TestCase() {
            {
                testName = "LenetMnistSD";
                testType = TestType.RANDOM_INIT;
                testPredictions = true;
                testTrainingCurves = true;
                testGradients = true;
                testParamsPostTraining = true;
                testEvaluation = true;
                testOverfitting = false;
            }

            @Override
            public ModelType modelType() {
                return ModelType.SAMEDIFF;
            }

            @Override
            public Object getConfiguration() throws Exception {
                Nd4j.getRandom().setSeed(12345);

                int nChannels = 1; // Number of input channels
                int outputNum = 10; // The number of possible outcomes

                SameDiff sd = SameDiff.create();
                SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 784);
                SDVariable label = sd.placeHolder("label", DataType.FLOAT, -1, outputNum);

                //input [minibatch, channels=1, height=28, width=28]
                SDVariable in4d = in.reshape(-1, nChannels, 28, 28);

                int kernelHeight = 5;
                int kernelWidth = 5;

                // w0 [kernelHeight = 5, kernelWidth = 5, inputChannels = 1, outputChannels = 20]
                // b0 [20]
                SDVariable w0 = sd.var("w0", Nd4j.rand(DataType.FLOAT, kernelHeight, kernelWidth, nChannels, 20).muli(0.01));
                SDVariable b0 = sd.var("b0", Nd4j.rand(DataType.FLOAT, 20).muli(0.01));

                SDVariable layer0 = sd.nn.relu(sd.cnn.conv2d("layer0", in4d, w0, b0, Conv2DConfig.builder()
                        .kH(kernelHeight)
                        .kW(kernelWidth)
                        .sH(1)
                        .sW(1)
                        .dataFormat("NCHW")
                        .build()), 0);

                // outputSize = (inputSize - kernelSize + 2*padding) / stride + 1
                // outputSize_H(W) = (28 - 5 + 2*0) / 1 + 1 = 24
                // [minibatch, 20, 24, 24]

                SDVariable layer1 = sd.cnn.maxPooling2d("layer1", layer0, Pooling2DConfig.builder()
                        .kH(2).kW(2)
                        .sH(2).sW(2)
                        .isNHWC(false)
                        .build());

                // outputSize = (inputSize - kernelSize + 2*padding) / stride + 1
                // outputSize_H(W) = (24 - 2 + 2*0) / 2 + 1 = 12
                // [minibatch, 20, 12, 12]

                // w2 [kernelHeight = 5, kernelWidth = 5, inputChannels = 20, outputChannels = 50]
                // b2 [50]
                SDVariable w2 = sd.var("w2", Nd4j.rand(DataType.FLOAT, kernelHeight, kernelWidth, 20, 50).muli(0.01));
                SDVariable b2 = sd.var("b2", Nd4j.rand(DataType.FLOAT, 50).muli(0.01));

                SDVariable layer2 = sd.nn.relu(sd.cnn.conv2d("layer2", layer1, w2, b2, Conv2DConfig.builder()
                        .kH(kernelHeight)
                        .kW(kernelWidth)
                        .sH(1)
                        .sW(1)
                        .dataFormat("NCHW")
                        .build()), 0);

                // outputSize = (inputSize - kernelSize + 2*padding) / stride + 1
                // outputSize_H(W) = (12 - 5 + 2*0) / 1 + 1 = 8
                // [minibatch, 50, 8, 8]

                SDVariable layer3 = sd.cnn.maxPooling2d("layer3", layer2, Pooling2DConfig.builder()
                        .kH(2).kW(2)
                        .sH(2).sW(2)
                        .isNHWC(false)
                        .build());

                // outputSize = (inputSize - kernelSize + 2*padding) / stride + 1
                // outputSize_H(W) = (8 - 2 + 2*0) / 2 + 1 = 4
                // [minibatch, 50, 4, 4]

                int channels_height_width = 4 * 4 * 50;
                SDVariable layer3_reshaped = layer3.reshape(-1, channels_height_width);

                SDVariable w4 = sd.var("w4", Nd4j.rand(DataType.FLOAT, channels_height_width, 500).muli(0.01));
                SDVariable b4 = sd.var("b4", Nd4j.rand(DataType.FLOAT, 500).muli(0.01));

                SDVariable layer4 = sd.nn.relu("layer4", layer3_reshaped.mmul(w4).add(b4), 0);

                SDVariable w5 = sd.var("w5", Nd4j.rand(DataType.FLOAT, 500, outputNum));
                SDVariable b5 = sd.var("b5", Nd4j.rand(DataType.FLOAT, outputNum));

                SDVariable out = sd.nn.softmax("out", layer4.mmul(w5).add(b5));
                SDVariable loss = sd.loss.logLoss("loss", label, out);

                //Also set the training configuration:
                sd.setTrainingConfig(TrainingConfig.builder()
                        .updater(new Adam(1e-3))
                        .l2(1e-3)
                        .dataSetFeatureMapping("in")    //features[0] -> "in" placeholder
                        .dataSetLabelMapping("label")   //labels[0] -> "label" placeholder
                        .build());

                return sd;
            }

            @Override
            public Map<String, INDArray> getGradientsTestDataSameDiff() throws Exception {
                DataSet ds = new MnistDataSetIterator(8, true, 12345).next();
                Map<String, INDArray> map = new HashMap<>();
                map.put("in", ds.getFeatures());
                map.put("label", ds.getLabels());
                return map;
            }

            @Override
            public MultiDataSetIterator getTrainingData() throws Exception {
                DataSetIterator iter = new MnistDataSetIterator(16, true, 12345);
                iter = new EarlyTerminationDataSetIterator(iter, 60);
                return new MultiDataSetIteratorAdapter(iter);
            }

            @Override
            public MultiDataSetIterator getEvaluationTestData() throws Exception {
                return new MultiDataSetIteratorAdapter(new EarlyTerminationDataSetIterator(new MnistDataSetIterator(32, false, 12345), 10));
            }

            @Override
            public List<Map<String, INDArray>> getPredictionsTestDataSameDiff() throws Exception {
                DataSetIterator iter = new MnistDataSetIterator(8, true, 12345);

                List<Map<String, INDArray>> list = new ArrayList<>();

                org.nd4j.linalg.dataset.DataSet ds = iter.next();
                ds = ds.asList().get(0);

                list.add(Collections.singletonMap("in", ds.getFeatures()));
                ds = iter.next();
                list.add(Collections.singletonMap("in", ds.getFeatures()));
                return list;
            }

            @Override
            public List<String> getPredictionsNamesSameDiff() {
                return Collections.singletonList("out");
            }

            @Override
            public IEvaluation[] getNewEvaluations() {
                return new IEvaluation[]{
                        new Evaluation(),
                        new ROCMultiClass(),
                        new EvaluationCalibration()};
            }

            @Override
            public IEvaluation[] doEvaluationSameDiff(SameDiff sd, MultiDataSetIterator iter, IEvaluation[] evaluations) {
                sd.evaluate(iter, "out", 0, evaluations);
                return evaluations;
            }
        };
    }

    public static TestCase getCnn3dSynthetic() {
        return new TestCase() {
            {
                testName = "Cnn3dSynthetic";
                testType = TestType.RANDOM_INIT;
                testPredictions = true;
                testTrainingCurves = true;
                testGradients = true;
                testParamsPostTraining = true;
                testEvaluation = true;
                testOverfitting = false;
            }

            @Override
            public ModelType modelType() {
                return ModelType.SAMEDIFF;
            }

            @Override
            public Object getConfiguration() throws Exception {
                Nd4j.getRandom().setSeed(12345);

                int nChannels = 3; // Number of input channels
                int outputNum = 10; // The number of possible outcomes

                SameDiff sd = SameDiff.create();

                //input in NCDHW: [minibatch, channels=3, depth=8, height=8, width=8]
                SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, nChannels, 8, 8, 8);
                SDVariable label = sd.placeHolder("label", DataType.FLOAT, -1, outputNum);

                // Weights for conv3d. Rank 5 with shape [kernelDepth, kernelHeight, kernelWidth, inputChannels, outputChannels]
                // [kernelDepth = 3, kernelHeight = 3, kernelWidth = 3, inputChannels = 3, outputChannels = 8]
                SDVariable w0 = sd.var("w0", Nd4j.rand(DataType.FLOAT, 3, 3, 3, nChannels, 8));
                // Optional 1D bias array with shape [outputChannels]. May be null.
                SDVariable b0 = sd.var("b0", Nd4j.rand(DataType.FLOAT, 8));

                SDVariable layer0 = sd.nn.relu(sd.cnn.conv3d("layer0", in, w0, b0, Conv3DConfig.builder()
                        .kH(3)
                        .kW(3)
                        .kD(3)
                        .sH(2)
                        .sW(2)
                        .sD(2)
                        .dataFormat("NCDHW")
                        .build()), 0);

                // outputSize = (inputSize - kernelSize + 2*padding) / stride + 1
                // outputSize_H(W)(D) = (8 - 3 + 2*0) / 2 + 1 = 3
                // [minibatch, 8, 3, 3, 3]

                SDVariable layer1 = sd.cnn.maxPooling3d("layer1", layer0, Pooling3DConfig.builder()
                        .kH(2).kW(2).kD(2)
                        .sH(2).sW(2).sD(2)
                        .isNCDHW(true)
                        .build());

                // outputSize = (inputSize - kernelSize + 2*padding) / stride + 1
                // outputSize_H(W)(D) = (3 - 2 + 2*0) / 2 + 1 = 1
                // [minibatch, 8, 1, 1, 1]

                int channels_height_width_depth = 8 * 1 * 1 * 1;

                SDVariable layer1_reshaped = layer1.reshape(-1, channels_height_width_depth);

                SDVariable w1 = sd.var("w4", Nd4j.rand(DataType.FLOAT, channels_height_width_depth, 10));
                SDVariable b1 = sd.var("b4", Nd4j.rand(DataType.FLOAT, 10));

                SDVariable out = sd.nn.softmax("out", layer1_reshaped.mmul(w1).add(b1));
                SDVariable loss = sd.loss.logLoss("loss", label, out);

                //Also set the training configuration:
                sd.setTrainingConfig(TrainingConfig.builder()
                        .updater(new Nesterovs(0.01, 0.9))
                        .dataSetFeatureMapping("in")    //features[0] -> "in" placeholder
                        .dataSetLabelMapping("label")   //labels[0] -> "label" placeholder
                        .build());

                return sd;
            }

            @Override
            public Map<String, INDArray> getGradientsTestDataSameDiff() throws Exception {
                Nd4j.getRandom().setSeed(12345);
                //NCDHW format
                INDArray arr = Nd4j.rand(new int[]{2, 3, 8, 8, 8});
                INDArray labels = org.deeplearning4j.integration.TestUtils.randomOneHot(2, 10);

                Map<String, INDArray> map = new HashMap<>();
                map.put("in", arr);
                map.put("label", labels);
                return map;
            }

            @Override
            public List<String> getPredictionsNamesSameDiff() {
                return Collections.singletonList("out");
            }

            @Override
            public List<Map<String, INDArray>> getPredictionsTestDataSameDiff() throws Exception {
                Nd4j.getRandom().setSeed(12345);

                List<Map<String, INDArray>> list = new ArrayList<>();
                INDArray arr = Nd4j.rand(new int[]{2, 3, 8, 8, 8});

                list.add(Collections.singletonMap("in", arr));

                return list;
            }

            @Override
            public MultiDataSet getGradientsTestData() throws Exception {
                Nd4j.getRandom().setSeed(12345);
                //NCDHW format
                INDArray arr = Nd4j.rand(new int[]{2, 3, 8, 8, 8});
                INDArray labels = org.deeplearning4j.integration.TestUtils.randomOneHot(2, 10);
                return new org.nd4j.linalg.dataset.MultiDataSet(arr, labels);
            }

            @Override
            public MultiDataSetIterator getTrainingData() throws Exception {
                return new SingletonMultiDataSetIterator(getGradientsTestData());
            }

            @Override
            public MultiDataSetIterator getEvaluationTestData() throws Exception {
                return getTrainingData();
            }

            @Override
            public IEvaluation[] doEvaluationSameDiff(SameDiff sd, MultiDataSetIterator iter, IEvaluation[] evaluations) {
                sd.evaluate(iter, "out", 0, evaluations);
                return evaluations;
            }

            @Override
            public IEvaluation[] getNewEvaluations() {
                return new IEvaluation[]{new Evaluation()};
            }
        };
    }
}
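As context for how these graphs get exercised: each factory returns a `TestCase` whose `getConfiguration()` builds a plain `SameDiff` instance, so training and inference can also be driven directly through the SameDiff API. A minimal sketch follows (batch sizes and the epoch count are illustrative; `fit` and `output` are assumed to carry the standard signatures of this release line):

```
// Inside a method that declares throws Exception
SameDiff sd = (SameDiff) SameDiffCNNCases.getLenetMnist().getConfiguration();

// One pass over MNIST, using the TrainingConfig set inside getConfiguration()
sd.fit(new MnistDataSetIterator(16, true, 12345), 1);

// Inference: feed the "in" placeholder, request the "out" softmax activations
Map<String, INDArray> placeholders = Collections.singletonMap(
        "in", new MnistDataSetIterator(8, false, 12345).next().getFeatures());
Map<String, INDArray> predictions = sd.output(placeholders, "out");
```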
@ -15,31 +15,50 @@
 ******************************************************************************/
 package org.deeplearning4j.integration.testcases.samediff;

+import org.datavec.api.records.reader.RecordReader;
+import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
+import org.datavec.api.split.FileSplit;
+import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
 import org.deeplearning4j.datasets.iterator.EarlyTerminationDataSetIterator;
 import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
 import org.deeplearning4j.datasets.iterator.impl.MultiDataSetIteratorAdapter;
+import org.deeplearning4j.datasets.iterator.impl.SingletonMultiDataSetIterator;
 import org.deeplearning4j.integration.ModelType;
 import org.deeplearning4j.integration.TestCase;
+import org.nd4j.autodiff.loss.LossReduce;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.autodiff.samediff.TrainingConfig;
 import org.nd4j.evaluation.IEvaluation;
 import org.nd4j.evaluation.classification.Evaluation;
+import org.nd4j.evaluation.classification.EvaluationCalibration;
+import org.nd4j.evaluation.classification.ROCMultiClass;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
+import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv3DConfig;
+import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig;
+import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling3DConfig;
 import org.nd4j.linalg.dataset.api.DataSet;
 import org.nd4j.linalg.dataset.api.MultiDataSet;
 import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
 import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
 import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.io.ClassPathResource;
 import org.nd4j.linalg.learning.config.Adam;
+import org.nd4j.linalg.learning.config.Nesterovs;
+import org.nd4j.linalg.primitives.Pair;
+import org.nd4j.resources.Resources;

+import java.io.File;
 import java.util.*;

+import static org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig.*;

 public class SameDiffMLPTestCases {

-    public static TestCase getMLPMnist(){
+    public static TestCase getMLPMnist() {
         return new TestCase() {
             {
                 testName = "MLPMnistSD";
@ -68,10 +87,10 @@ public class SameDiffMLPTestCases {
                 SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 784);
                 SDVariable label = sd.placeHolder("label", DataType.FLOAT, -1, 10);

-                SDVariable w0 = sd.var("w0", Nd4j.rand(DataType.FLOAT, 784, 256));
-                SDVariable b0 = sd.var("b0", Nd4j.rand(DataType.FLOAT, 256));
-                SDVariable w1 = sd.var("w1", Nd4j.rand(DataType.FLOAT, 256, 10));
-                SDVariable b1 = sd.var("b1", Nd4j.rand(DataType.FLOAT, 10));
+                SDVariable w0 = sd.var("w0", Nd4j.rand(DataType.FLOAT, 784, 256).muli(0.1));
+                SDVariable b0 = sd.var("b0", Nd4j.rand(DataType.FLOAT, 256).muli(0.1));
+                SDVariable w1 = sd.var("w1", Nd4j.rand(DataType.FLOAT, 256, 10).muli(0.1));
+                SDVariable b1 = sd.var("b1", Nd4j.rand(DataType.FLOAT, 10).muli(0.1));

                 SDVariable a0 = sd.nn.tanh(in.mmul(w0).add(b0));
                 SDVariable out = sd.nn.softmax("out", a0.mmul(w1).add(b1));
@ -90,7 +109,7 @@ public class SameDiffMLPTestCases {

             @Override
             public List<Map<String, INDArray>> getPredictionsTestDataSameDiff() throws Exception {
-                List<Map<String,INDArray>> out = new ArrayList<>();
+                List<Map<String, INDArray>> out = new ArrayList<>();

                 DataSetIterator iter = new MnistDataSetIterator(1, true, 12345);
                 out.add(Collections.singletonMap("in", iter.next().getFeatures()));
@ -109,7 +128,7 @@ public class SameDiffMLPTestCases {
             @Override
             public Map<String, INDArray> getGradientsTestDataSameDiff() throws Exception {
                 DataSet ds = new MnistDataSetIterator(8, true, 12345).next();
-                Map<String,INDArray> map = new HashMap<>();
+                Map<String, INDArray> map = new HashMap<>();
                 map.put("in", ds.getFeatures());
                 map.put("label", ds.getLabels());
                 return map;
@ -152,4 +171,160 @@ public class SameDiffMLPTestCases {
         };
     }

+
+    public static TestCase getMLPMoon() {
+        return new TestCase() {
+            {
+                testName = "MLPMoonSD";
+                testType = TestType.RANDOM_INIT;
+                testPredictions = true;
+                testTrainingCurves = true;
+                testGradients = true;
+                testParamsPostTraining = true;
+                testEvaluation = true;
+                testOverfitting = true;
+                maxRelativeErrorOverfit = 2e-2;
+                minAbsErrorOverfit = 1e-2;
+            }
+
+            @Override
+            public ModelType modelType() {
+                return ModelType.SAMEDIFF;
+            }
+
+            @Override
+            public Object getConfiguration() throws Exception {
+
+                int numInputs = 2;
+                int numOutputs = 2;
+                int numHiddenNodes = 20;
+                double learningRate = 0.005;
+
+                Nd4j.getRandom().setSeed(12345);
+
+                //Define the network structure:
+                SameDiff sd = SameDiff.create();
+                SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, numInputs);
+                SDVariable label = sd.placeHolder("label", DataType.FLOAT, -1, numOutputs);
+
+                SDVariable w0 = sd.var("w0", Nd4j.rand(DataType.FLOAT, numInputs, numHiddenNodes));
+                SDVariable b0 = sd.var("b0", Nd4j.rand(DataType.FLOAT, numHiddenNodes));
+                SDVariable w1 = sd.var("w1", Nd4j.rand(DataType.FLOAT, numHiddenNodes, numOutputs));
+                SDVariable b1 = sd.var("b1", Nd4j.rand(DataType.FLOAT, numOutputs));
+
+                SDVariable a0 = sd.nn.relu(in.mmul(w0).add(b0), 0);
+                SDVariable out = sd.nn.softmax("out", a0.mmul(w1).add(b1));
+                SDVariable loss = sd.loss.logLoss("loss", label, out);
+
+                //Also set the training configuration:
+                sd.setTrainingConfig(TrainingConfig.builder()
+                        .updater(new Nesterovs(learningRate, 0.9))
+                        .weightDecay(1e-3, true)
+                        .dataSetFeatureMapping("in")    //features[0] -> "in" placeholder
+                        .dataSetLabelMapping("label")   //labels[0] -> "label" placeholder
+                        .build());
+
+                return sd;
+            }
+
+            @Override
+            public List<Map<String, INDArray>> getPredictionsTestDataSameDiff() throws Exception {
+                List<Map<String, INDArray>> out = new ArrayList<>();
+
+                File f = Resources.asFile("dl4j-integration-tests/data/moon_data_eval.csv");
+
+                RecordReader rr = new CSVRecordReader();
+                rr.initialize(new FileSplit(f));
+                DataSetIterator iter = new RecordReaderDataSetIterator(rr, 1, 0, 2);
+
+                out.add(Collections.singletonMap("in", iter.next().getFeatures()));
+
+                return out;
+            }
+
+            @Override
+            public List<String> getPredictionsNamesSameDiff() throws Exception {
+                return Collections.singletonList("out");
+            }
+
+            @Override
+            public Map<String, INDArray> getGradientsTestDataSameDiff() throws Exception {
+
+                File f = Resources.asFile("dl4j-integration-tests/data/moon_data_eval.csv");
+                RecordReader rr = new CSVRecordReader();
+                rr.initialize(new FileSplit(f));
+                org.nd4j.linalg.dataset.DataSet ds = new RecordReaderDataSetIterator(rr, 5, 0, 2).next();
+
+                Map<String, INDArray> map = new HashMap<>();
+                map.put("in", ds.getFeatures());
+                map.put("label", ds.getLabels());
+                return map;
+            }
+
+            @Override
+            public MultiDataSetIterator getTrainingData() throws Exception {
+                File f = Resources.asFile("dl4j-integration-tests/data/moon_data_train.csv");
+                RecordReader rr = new CSVRecordReader();
+                rr.initialize(new FileSplit(f));
+                DataSetIterator iter = new RecordReaderDataSetIterator(rr, 32, 0, 2);
+
+                iter = new EarlyTerminationDataSetIterator(iter, 32);
+                return new MultiDataSetIteratorAdapter(iter);
+            }
+
+            @Override
+            public IEvaluation[] getNewEvaluations() {
+                return new IEvaluation[]{
+                        new Evaluation(),
+                        new ROCMultiClass(),
+                        new EvaluationCalibration()};
+            }
+
+            @Override
+            public MultiDataSetIterator getEvaluationTestData() throws Exception {
+                File f = Resources.asFile("dl4j-integration-tests/data/moon_data_eval.csv");
+                RecordReader rr = new CSVRecordReader();
+                rr.initialize(new FileSplit(f));
+                DataSetIterator iter = new RecordReaderDataSetIterator(rr, 32, 0, 2);
+                return new MultiDataSetIteratorAdapter(iter);
+            }
+
+            @Override
+            public IEvaluation[] doEvaluationSameDiff(SameDiff sd, MultiDataSetIterator iter, IEvaluation[] evaluations) {
+                sd.evaluate(iter, "out", 0, evaluations);
+                return evaluations;
+            }
+
+            @Override
+            public MultiDataSet getOverfittingData() throws Exception {
+
+                File f = Resources.asFile("dl4j-integration-tests/data/moon_data_eval.csv");
+                RecordReader rr = new CSVRecordReader();
+                rr.initialize(new FileSplit(f));
+                return new RecordReaderDataSetIterator(rr, 1, 0, 2).next().toMultiDataSet();
+            }
+
+            @Override
+            public int getOverfitNumIterations() {
+                return 200;
+            }
+        };
+    }
 }
@ -0,0 +1,289 @@
/* ******************************************************************************
 * Copyright (c) 2020 Konduit K.K.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/
package org.deeplearning4j.integration.testcases.samediff;

import org.datavec.api.records.reader.SequenceRecordReader;
import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader;
import org.datavec.api.split.NumberedFileInputSplit;
import org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.MultiDataSetIteratorAdapter;
import org.deeplearning4j.integration.ModelType;
import org.deeplearning4j.integration.TestCase;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.TrainingConfig;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.evaluation.classification.EvaluationCalibration;
import org.nd4j.evaluation.classification.ROCMultiClass;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMActivations;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMDataFormat;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMDirectionMode;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMLayerConfig;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs.LSTMLayerOutputs;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMLayerWeights;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.MultiDataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.CompositeMultiDataSetPreProcessor;
import org.nd4j.linalg.dataset.api.preprocessor.MultiDataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.MultiNormalizerStandardize;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.resources.Resources;
import org.nd4j.shade.guava.io.Files;

import java.io.File;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;

public class SameDiffRNNTestCases {

    public static TestCase getRnnCsvSequenceClassificationTestCase1() {
        return new SameDiffRNNTestCases.RnnCsvSequenceClassificationTestCase1();
    }

    protected static class RnnCsvSequenceClassificationTestCase1 extends TestCase {
        protected RnnCsvSequenceClassificationTestCase1() {
            testName = "RnnCsvSequenceClassification1";
            testType = TestType.RANDOM_INIT;
            testPredictions = true;
            testTrainingCurves = false;
            testGradients = false;
            testParamsPostTraining = false;
            testEvaluation = true;
            testOverfitting = false;   //Not much point on this one - it already fits very well...
        }


        protected MultiDataNormalization normalizer;

        protected MultiDataNormalization getNormalizer() throws Exception {
            if (normalizer != null) {
                return normalizer;
            }

            normalizer = new MultiNormalizerStandardize();
            normalizer.fit(getTrainingDataUnnormalized());

            return normalizer;
        }


        @Override
        public ModelType modelType() {
            return ModelType.SAMEDIFF;
        }


        @Override
        public Object getConfiguration() throws Exception {
            Nd4j.getRandom().setSeed(12345);

            int miniBatchSize = 10;
            int numLabelClasses = 6;
            int nIn = 60;
            int numUnits = 7;
            int timeSteps = 3;

            SameDiff sd = SameDiff.create();

            SDVariable in = sd.placeHolder("in", DataType.FLOAT, miniBatchSize, timeSteps, nIn);
            SDVariable label = sd.placeHolder("label", DataType.FLOAT, miniBatchSize, numLabelClasses);

            SDVariable cLast = sd.var("cLast", Nd4j.zeros(DataType.FLOAT, miniBatchSize, numUnits));
            SDVariable yLast = sd.var("yLast", Nd4j.zeros(DataType.FLOAT, miniBatchSize, numUnits));

            LSTMLayerConfig c = LSTMLayerConfig.builder()
                    .lstmdataformat(LSTMDataFormat.NTS)
                    .directionMode(LSTMDirectionMode.FWD)
                    .gateAct(LSTMActivations.SIGMOID)
                    .cellAct(LSTMActivations.TANH)
                    .outAct(LSTMActivations.TANH)
                    .retFullSequence(true)
                    .retLastC(true)
                    .retLastH(true)
                    .build();

            LSTMLayerOutputs outputs = new LSTMLayerOutputs(sd.rnn.lstmLayer(
                    in, cLast, yLast, null,
                    LSTMLayerWeights.builder()
                            .weights(sd.var("weights", Nd4j.rand(DataType.FLOAT, nIn, 4 * numUnits)))
                            .rWeights(sd.var("rWeights", Nd4j.rand(DataType.FLOAT, numUnits, 4 * numUnits)))
                            .peepholeWeights(sd.var("inputPeepholeWeights", Nd4j.rand(DataType.FLOAT, 3 * numUnits)))
                            .bias(sd.var("bias", Nd4j.rand(DataType.FLOAT, 4 * numUnits)))
                            .build(),
                    c), c);

            // With retFullSequence(true) and the NTS data format, the full output sequence is returned:
            // [miniBatchSize, timeSteps, numUnits]
            SDVariable layer0 = outputs.getOutput();

            // Average over the time dimension -> [miniBatchSize, numUnits]
            SDVariable layer1 = layer0.mean(1);

            SDVariable w1 = sd.var("w1", Nd4j.rand(DataType.FLOAT, numUnits, numLabelClasses));
            SDVariable b1 = sd.var("b1", Nd4j.rand(DataType.FLOAT, numLabelClasses));

            SDVariable out = sd.nn.softmax("out", layer1.mmul(w1).add(b1));
            SDVariable loss = sd.loss.logLoss("loss", label, out);

            //Also set the training configuration:
            sd.setTrainingConfig(TrainingConfig.builder()
                    .updater(new Adam(5e-2))
                    .l1(1e-3).l2(1e-3)
                    .dataSetFeatureMapping("in")    //features[0] -> "in" placeholder
                    .dataSetLabelMapping("label")   //labels[0] -> "label" placeholder
                    .build());

            return sd;
        }


        @Override
        public List<Map<String, INDArray>> getPredictionsTestDataSameDiff() throws Exception {

            MultiDataSet mds = getTrainingData().next();

            List<Map<String, INDArray>> list = new ArrayList<>();

            list.add(Collections.singletonMap("in", mds.getFeatures()[0].reshape(10, 1, 60)));
            //[batchsize, insize]

            return list;
        }

        @Override
        public List<String> getPredictionsNamesSameDiff() throws Exception {
            return Collections.singletonList("out");
        }


        @Override
        public MultiDataSetIterator getTrainingData() throws Exception {
            MultiDataSetIterator iter = getTrainingDataUnnormalized();
            MultiDataSetPreProcessor pp = multiDataSet -> {
                INDArray l = multiDataSet.getLabels(0);
                l = l.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(l.size(2) - 1));
                multiDataSet.setLabels(0, l);
                multiDataSet.setLabelsMaskArray(0, null);
            };

            iter.setPreProcessor(new CompositeMultiDataSetPreProcessor(getNormalizer(), pp));

            return iter;
        }

        protected MultiDataSetIterator getTrainingDataUnnormalized() throws Exception {
            int miniBatchSize = 10;
            int numLabelClasses = 6;

            File featuresDirTrain = Files.createTempDir();
            File labelsDirTrain = Files.createTempDir();
            Resources.copyDirectory("dl4j-integration-tests/data/uci_seq/train/features/", featuresDirTrain);
            Resources.copyDirectory("dl4j-integration-tests/data/uci_seq/train/labels/", labelsDirTrain);

            SequenceRecordReader trainFeatures = new CSVSequenceRecordReader();
            trainFeatures.initialize(new NumberedFileInputSplit(featuresDirTrain.getAbsolutePath() + "/%d.csv", 0, 449));
            SequenceRecordReader trainLabels = new CSVSequenceRecordReader();
            trainLabels.initialize(new NumberedFileInputSplit(labelsDirTrain.getAbsolutePath() + "/%d.csv", 0, 449));

            DataSetIterator trainData = new SequenceRecordReaderDataSetIterator(trainFeatures, trainLabels, miniBatchSize, numLabelClasses,
                    false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);

            MultiDataSetIterator iter = new MultiDataSetIteratorAdapter(trainData);

            return iter;
        }

        @Override
        public IEvaluation[] getNewEvaluations() {
            return new IEvaluation[]{
                    new Evaluation(),
                    new ROCMultiClass(),
                    new EvaluationCalibration()
            };
        }

        @Override
        public MultiDataSetIterator getEvaluationTestData() throws Exception {
            int miniBatchSize = 10;
            int numLabelClasses = 6;

//            File featuresDirTest = new ClassPathResource("/RnnCsvSequenceClassification/uci_seq/test/features/").getFile();
//            File labelsDirTest = new ClassPathResource("/RnnCsvSequenceClassification/uci_seq/test/labels/").getFile();
            File featuresDirTest = Files.createTempDir();
            File labelsDirTest = Files.createTempDir();
            Resources.copyDirectory("dl4j-integration-tests/data/uci_seq/test/features/", featuresDirTest);
            Resources.copyDirectory("dl4j-integration-tests/data/uci_seq/test/labels/", labelsDirTest);

            SequenceRecordReader trainFeatures = new CSVSequenceRecordReader();
            trainFeatures.initialize(new NumberedFileInputSplit(featuresDirTest.getAbsolutePath() + "/%d.csv", 0, 149));
            SequenceRecordReader trainLabels = new CSVSequenceRecordReader();
            trainLabels.initialize(new NumberedFileInputSplit(labelsDirTest.getAbsolutePath() + "/%d.csv", 0, 149));

            DataSetIterator testData = new SequenceRecordReaderDataSetIterator(trainFeatures, trainLabels, miniBatchSize, numLabelClasses,
                    false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);

            MultiDataSetIterator iter = new MultiDataSetIteratorAdapter(testData);

            MultiDataSetPreProcessor pp = multiDataSet -> {
                INDArray l = multiDataSet.getLabels(0);
                l = l.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(l.size(2) - 1));
                multiDataSet.setLabels(0, l);
                multiDataSet.setLabelsMaskArray(0, null);
            };

            iter.setPreProcessor(new CompositeMultiDataSetPreProcessor(getNormalizer(), pp));

            return iter;
        }

        @Override
        public IEvaluation[] doEvaluationSameDiff(SameDiff sd, MultiDataSetIterator iter, IEvaluation[] evaluations) {
            sd.evaluate(iter, "out", 0, evaluations);
            return evaluations;
        }
    }

}
@ -40,21 +40,21 @@ implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}
 implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-arm64"
 implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-x86"
 implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-x86_64"
-implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2'
+implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.9-1.5.3'
-implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2', classifier: "android-arm"
+implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.9-1.5.3', classifier: "android-arm"
-implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2', classifier: "android-arm64"
+implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.9-1.5.3', classifier: "android-arm64"
-implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2', classifier: "android-x86"
+implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.9-1.5.3', classifier: "android-x86"
-implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2', classifier: "android-x86_64"
+implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.9-1.5.3', classifier: "android-x86_64"
-implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2'
+implementation group: 'org.bytedeco', name: 'opencv', version: '4.3.0-1.5.3'
-implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2', classifier: "android-arm"
+implementation group: 'org.bytedeco', name: 'opencv', version: '4.3.0-1.5.3', classifier: "android-arm"
-implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2', classifier: "android-arm64"
+implementation group: 'org.bytedeco', name: 'opencv', version: '4.3.0-1.5.3', classifier: "android-arm64"
-implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2', classifier: "android-x86"
+implementation group: 'org.bytedeco', name: 'opencv', version: '4.3.0-1.5.3', classifier: "android-x86"
-implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2', classifier: "android-x86_64"
+implementation group: 'org.bytedeco', name: 'opencv', version: '4.3.0-1.5.3', classifier: "android-x86_64"
-implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2'
+implementation group: 'org.bytedeco', name: 'leptonica', version: '1.79.0-1.5.3'
-implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2', classifier: "android-arm"
+implementation group: 'org.bytedeco', name: 'leptonica', version: '1.79.0-1.5.3', classifier: "android-arm"
-implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2', classifier: "android-arm64"
+implementation group: 'org.bytedeco', name: 'leptonica', version: '1.79.0-1.5.3', classifier: "android-arm64"
-implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2', classifier: "android-x86"
+implementation group: 'org.bytedeco', name: 'leptonica', version: '1.79.0-1.5.3', classifier: "android-x86"
-implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2', classifier: "android-x86_64"
+implementation group: 'org.bytedeco', name: 'leptonica', version: '1.79.0-1.5.3', classifier: "android-x86_64"

 implementation 'com.google.code.gson:gson:2.8.2'
 annotationProcessor 'org.projectlombok:lombok:1.16.16'
@ -35,21 +35,21 @@ implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}
 implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-arm64"
 implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-x86"
 implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-x86_64"
-implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2'
+implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.9-1.5.3'
-implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2', classifier: "android-arm"
+implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.9-1.5.3', classifier: "android-arm"
-implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2', classifier: "android-arm64"
+implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.9-1.5.3', classifier: "android-arm64"
-implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2', classifier: "android-x86"
+implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.9-1.5.3', classifier: "android-x86"
-implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2', classifier: "android-x86_64"
+implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.9-1.5.3', classifier: "android-x86_64"
-implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2'
+implementation group: 'org.bytedeco', name: 'opencv', version: '4.3.0-1.5.3'
-implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2', classifier: "android-arm"
+implementation group: 'org.bytedeco', name: 'opencv', version: '4.3.0-1.5.3', classifier: "android-arm"
-implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2', classifier: "android-arm64"
+implementation group: 'org.bytedeco', name: 'opencv', version: '4.3.0-1.5.3', classifier: "android-arm64"
-implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2', classifier: "android-x86"
+implementation group: 'org.bytedeco', name: 'opencv', version: '4.3.0-1.5.3', classifier: "android-x86"
-implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2', classifier: "android-x86_64"
+implementation group: 'org.bytedeco', name: 'opencv', version: '4.3.0-1.5.3', classifier: "android-x86_64"
-implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2'
+implementation group: 'org.bytedeco', name: 'leptonica', version: '1.79.0-1.5.3'
-implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2', classifier: "android-arm"
+implementation group: 'org.bytedeco', name: 'leptonica', version: '1.79.0-1.5.3', classifier: "android-arm"
-implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2', classifier: "android-arm64"
+implementation group: 'org.bytedeco', name: 'leptonica', version: '1.79.0-1.5.3', classifier: "android-arm64"
-implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2', classifier: "android-x86"
+implementation group: 'org.bytedeco', name: 'leptonica', version: '1.79.0-1.5.3', classifier: "android-x86"
-implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2', classifier: "android-x86_64"
+implementation group: 'org.bytedeco', name: 'leptonica', version: '1.79.0-1.5.3', classifier: "android-x86_64"
 ```

 Compiling these dependencies involves a large number of files, thus it is necessary to set multiDexEnabled to true in defaultConfig.
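For reference, a minimal sketch of that setting in the app module's build.gradle (the surrounding `android`/`defaultConfig` blocks are standard Android boilerplate, not taken from this diff):

```
android {
    defaultConfig {
        // The DL4J/ND4J artifacts above push the app past the 64K DEX method limit,
        // so multidex has to be enabled explicitly.
        multiDexEnabled true
    }
}
```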
@ -43,21 +43,21 @@ implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}
 implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-arm64"
 implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-x86"
 implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-x86_64"
-implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2'
+implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.9-1.5.3'
-implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2', classifier: "android-arm"
+implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.9-1.5.3', classifier: "android-arm"
-implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2', classifier: "android-arm64"
+implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.9-1.5.3', classifier: "android-arm64"
-implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2', classifier: "android-x86"
+implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.9-1.5.3', classifier: "android-x86"
-implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2', classifier: "android-x86_64"
+implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.9-1.5.3', classifier: "android-x86_64"
-implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2'
+implementation group: 'org.bytedeco', name: 'opencv', version: '4.3.0-1.5.3'
-implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2', classifier: "android-arm"
+implementation group: 'org.bytedeco', name: 'opencv', version: '4.3.0-1.5.3', classifier: "android-arm"
-implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2', classifier: "android-arm64"
+implementation group: 'org.bytedeco', name: 'opencv', version: '4.3.0-1.5.3', classifier: "android-arm64"
-implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2', classifier: "android-x86"
+implementation group: 'org.bytedeco', name: 'opencv', version: '4.3.0-1.5.3', classifier: "android-x86"
-implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2', classifier: "android-x86_64"
+implementation group: 'org.bytedeco', name: 'opencv', version: '4.3.0-1.5.3', classifier: "android-x86_64"
-implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2'
+implementation group: 'org.bytedeco', name: 'leptonica', version: '1.79.0-1.5.3'
-implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2', classifier: "android-arm"
+implementation group: 'org.bytedeco', name: 'leptonica', version: '1.79.0-1.5.3', classifier: "android-arm"
-implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2', classifier: "android-arm64"
+implementation group: 'org.bytedeco', name: 'leptonica', version: '1.79.0-1.5.3', classifier: "android-arm64"
-implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2', classifier: "android-x86"
+implementation group: 'org.bytedeco', name: 'leptonica', version: '1.79.0-1.5.3', classifier: "android-x86"
-implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2', classifier: "android-x86_64"
+implementation group: 'org.bytedeco', name: 'leptonica', version: '1.79.0-1.5.3', classifier: "android-x86_64"
 testImplementation 'junit:junit:4.12'
 ```
@ -46,21 +46,21 @@ implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}
 implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-arm64"
 implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-x86"
 implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-x86_64"
-implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2'
+implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.9-1.5.3'
-implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2', classifier: "android-arm"
+implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.9-1.5.3', classifier: "android-arm"
-implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2', classifier: "android-arm64"
+implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.9-1.5.3', classifier: "android-arm64"
-implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2', classifier: "android-x86"
+implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.9-1.5.3', classifier: "android-x86"
-implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2', classifier: "android-x86_64"
+implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.9-1.5.3', classifier: "android-x86_64"
-implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2'
+implementation group: 'org.bytedeco', name: 'opencv', version: '4.3.0-1.5.3'
-implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2', classifier: "android-arm"
+implementation group: 'org.bytedeco', name: 'opencv', version: '4.3.0-1.5.3', classifier: "android-arm"
-implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2', classifier: "android-arm64"
+implementation group: 'org.bytedeco', name: 'opencv', version: '4.3.0-1.5.3', classifier: "android-arm64"
-implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2', classifier: "android-x86"
+implementation group: 'org.bytedeco', name: 'opencv', version: '4.3.0-1.5.3', classifier: "android-x86"
-implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2', classifier: "android-x86_64"
+implementation group: 'org.bytedeco', name: 'opencv', version: '4.3.0-1.5.3', classifier: "android-x86_64"
-implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2'
+implementation group: 'org.bytedeco', name: 'leptonica', version: '1.79.0-1.5.3'
-implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2', classifier: "android-arm"
+implementation group: 'org.bytedeco', name: 'leptonica', version: '1.79.0-1.5.3', classifier: "android-arm"
-implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2', classifier: "android-arm64"
+implementation group: 'org.bytedeco', name: 'leptonica', version: '1.79.0-1.5.3', classifier: "android-arm64"
-implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2', classifier: "android-x86"
+implementation group: 'org.bytedeco', name: 'leptonica', version: '1.79.0-1.5.3', classifier: "android-x86"
-implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2', classifier: "android-x86_64"
+implementation group: 'org.bytedeco', name: 'leptonica', version: '1.79.0-1.5.3', classifier: "android-x86_64"

 ```
@ -62,7 +62,7 @@ Alternatively, in the case of CUDA 10.2, cuDNN comes bundled with the "redist" p
 <dependency>
     <groupId>org.bytedeco</groupId>
     <artifactId>cuda-platform-redist</artifactId>
-    <version>10.2-7.6-1.5.2</version>
+    <version>10.2-7.6-1.5.3</version>
 </dependency>

 Also note that, by default, Deeplearning4j will use the fastest algorithms available according to cuDNN, but memory usage may be excessive, causing strange launch errors. When this happens, try to reduce memory usage by using the [`NO_WORKSPACE` mode settable via the network configuration](/api/{{page.version}}/org/deeplearning4j/nn/conf/layers/ConvolutionLayer.Builder.html#cudnnAlgoMode-org.deeplearning4j.nn.conf.layers.ConvolutionLayer.AlgoMode-), instead of the default of `ConvolutionLayer.AlgoMode.PREFER_FASTEST`, for example:
@ -5,7 +5,7 @@ project(mkldnn-download NONE)
|
||||||
include(ExternalProject)
|
include(ExternalProject)
|
||||||
ExternalProject_Add(mkldnn
|
ExternalProject_Add(mkldnn
|
||||||
GIT_REPOSITORY https://github.com/intel/mkl-dnn.git
|
GIT_REPOSITORY https://github.com/intel/mkl-dnn.git
|
||||||
GIT_TAG v1.2.2
|
GIT_TAG v1.3
|
||||||
SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/mkldnn-src"
|
SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/mkldnn-src"
|
||||||
BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/mkldnn-build"
|
BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/mkldnn-build"
|
||||||
CONFIGURE_COMMAND ""
|
CONFIGURE_COMMAND ""
|
||||||
|
|
|
@ -403,7 +403,6 @@ NDArray::NDArray(const std::u32string& u32string, sd::DataType dtype, sd::Launch
|
||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////////////////////////
|
||||||
// u8 string constructors
|
// u8 string constructors
|
||||||
/////////////////////////////////////////////////////////////////////////
|
|
||||||
NDArray::NDArray(const std::string& str, sd::DataType dtype, sd::LaunchContext* context) {
|
NDArray::NDArray(const std::string& str, sd::DataType dtype, sd::LaunchContext* context) {
|
||||||
|
|
||||||
if (!DataTypeUtils::isS(dtype)) {
|
if (!DataTypeUtils::isS(dtype)) {
|
||||||
|
@ -1944,7 +1943,7 @@ void NDArray::tilei(const std::vector<Nd4jLong>& reps) {
|
||||||
Nd4jLong NDArray::sizeAt(const int dim) const {
|
Nd4jLong NDArray::sizeAt(const int dim) const {
|
||||||
|
|
||||||
if (dim >= this->rankOf() || dim < -this->rankOf())
|
if (dim >= this->rankOf() || dim < -this->rankOf())
|
||||||
throw std::runtime_error("Bad size index requested");
|
throw std::runtime_error("NDArray::sizeAt: bad size index requested");
|
||||||
|
|
||||||
if (dim >= 0)
|
if (dim >= 0)
|
||||||
return shape::shapeOf(_shapeInfo)[dim];
|
return shape::shapeOf(_shapeInfo)[dim];
|
||||||
|
|
|
@ -10,7 +10,7 @@ void NDArray::applyTriplewiseLambda(NDArray& second, NDArray& third, const std::
|
||||||
throw std::runtime_error("NDArray::applyTriplewiseLambda<T> method: bother four arrays (this, second, third, target) should have the same type !");
|
throw std::runtime_error("NDArray::applyTriplewiseLambda<T> method: bother four arrays (this, second, third, target) should have the same type !");
|
||||||
|
|
||||||
if (this->lengthOf() != second.lengthOf() || this->lengthOf() != third.lengthOf() || !this->isSameShape(second) || !this->isSameShape(third)) {
|
if (this->lengthOf() != second.lengthOf() || this->lengthOf() != third.lengthOf() || !this->isSameShape(second) || !this->isSameShape(third)) {
|
||||||
nd4j_printf("applyPairwiseLambda requires both operands to have the same shape\n","");
|
nd4j_printf("applyTriplewiseLambda requires all operands to have the same shape\n","");
|
||||||
throw std::runtime_error("Shapes mismach");
|
throw std::runtime_error("Shapes mismach");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -23,6 +23,7 @@
|
||||||
#include <array/NDArrayList.h>
|
#include <array/NDArrayList.h>
|
||||||
#include <helpers/ShapeUtils.h>
|
#include <helpers/ShapeUtils.h>
|
||||||
#include <ops/declarable/CustomOperations.h>
|
#include <ops/declarable/CustomOperations.h>
|
||||||
|
#include <ops/declarable/helpers/stack.h>
|
||||||
|
|
||||||
namespace sd {
|
namespace sd {
|
||||||
NDArrayList::NDArrayList(int height, bool expandable) {
|
NDArrayList::NDArrayList(int height, bool expandable) {
|
||||||
|
@ -144,25 +145,38 @@ namespace sd {
|
||||||
|
|
||||||
NDArray* NDArrayList::stack() {
|
NDArray* NDArrayList::stack() {
|
||||||
// FIXME: this is bad for perf, but ok as poc
|
// FIXME: this is bad for perf, but ok as poc
|
||||||
sd::ops::stack op;
|
|
||||||
std::vector<NDArray*> inputs;
|
|
||||||
std::vector<double> targs;
|
|
||||||
std::vector<Nd4jLong> iargs({0});
|
|
||||||
std::vector<bool> bargs;
|
|
||||||
int numElements = _elements.load();
|
int numElements = _elements.load();
|
||||||
|
std::vector<const NDArray*> inputs(numElements);
|
||||||
for (int e = 0; e < numElements; e++) {
|
for (int e = 0; e < numElements; e++) {
|
||||||
_chunks[e]->syncToDevice();
|
_chunks[e]->syncToDevice();
|
||||||
inputs.emplace_back(_chunks[e]);
|
inputs[e] = _chunks[e];
|
||||||
}
|
}
|
||||||
|
|
||||||
iargs.push_back(_axis);
|
auto inShapeInfo = inputs[0]->getShapeInfo();
|
||||||
|
int rank = shape::rank(inShapeInfo);
|
||||||
|
NDArray* array = nullptr;
|
||||||
|
|
||||||
auto result = op.evaluate(inputs);
|
if (shape::isEmpty(inShapeInfo)) {
|
||||||
|
switch (rank) {
|
||||||
|
case 0: {
|
||||||
|
if (numElements == 1) {
|
||||||
|
array = new NDArray(inputs[0]->ordering(), {0}, ArrayOptions::dataType(inShapeInfo), inputs[0]->getContext());
|
||||||
|
} else {
|
||||||
|
array = new NDArray('c', {(Nd4jLong) numElements, 0}, ArrayOptions::dataType(inShapeInfo), inputs[0]->getContext());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
std::vector<Nd4jLong> outShape(inShapeInfo + 1, inShapeInfo + 1 + rank);
|
||||||
|
outShape.insert(outShape.begin(), (Nd4jLong) numElements);
|
||||||
|
array = new NDArray( shape::order(inShapeInfo), outShape, ArrayOptions::dataType(inShapeInfo), inputs[0]->getContext());
|
||||||
|
}
|
||||||
|
|
||||||
|
ops::helpers::stack(inputs[0]->getContext(), inputs, *array, 0);
|
||||||
|
|
||||||
auto array = new NDArray(result.at(0)->dup());
|
return array;
|
||||||
|
|
||||||
return array;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
std::pair<int,int>& NDArrayList::id() {
|
std::pair<int,int>& NDArrayList::id() {
|
||||||
|
|
|
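
A note on the rewrite above: `stack()` no longer evaluates the `stack` op and duplicates its result; it computes the output shape directly from the first chunk's shape info and writes into a freshly allocated array via `ops::helpers::stack`. A minimal sketch of that shape computation, using a plain `std::vector` in place of sd's shape-info buffers:

```
#include <cstdint>
#include <vector>

// Stand-in for the shape logic in the rewritten stack(): stacking n arrays
// of shape [d0, d1, ...] along axis 0 yields [n, d0, d1, ...], matching
// outShape.insert(outShape.begin(), numElements) above.
std::vector<int64_t> stackedShape(const std::vector<int64_t>& inShape, int64_t numElements) {
    std::vector<int64_t> outShape(inShape);          // copy [d0, d1, ...]
    outShape.insert(outShape.begin(), numElements);  // prepend n
    return outShape;
}
```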
@ -47,13 +47,13 @@ class ND4J_EXPORT GradCheck {
|
||||||
* opBP - back propagation operation
|
* opBP - back propagation operation
|
||||||
* argsHolderFF - argument holder for feed forward operation
|
* argsHolderFF - argument holder for feed forward operation
|
||||||
* argsHolderBP - argument holder for back propagation operation
|
* argsHolderBP - argument holder for back propagation operation
|
||||||
* whatArrsToCheck - specifies what output gradient arrays to check, for example {0, 1, 0} means that only second output gradient array will be checked, default value is empty array which means to check all arrays
|
* whatArrsToCheck - specifies what output gradient arrays to check, for example {0, 1, 0} means that only second output gradient array will be checked, default value is empty std::vector which means to check all arrays
|
||||||
* IdxRange - specifies indexes range over which array elements will be checked, for example {0.2, 0.7} means range [0.2*array_length, 0.7*array_length), default value is {0., 1.}
|
* IdxRange - specifies indexes range over which array elements will be checked, for example {0.2, 0.7} means range [0.2*array_length, 0.7*array_length), default value is {0., 1.}
|
||||||
* loss - type of scalar loss function, it specifies what elements values will be filled into input gradient arrays automatically, default value is SUM
|
* loss - type of scalar loss function, it specifies what elements values will be filled into input gradient arrays automatically, default value is SUM
|
||||||
|
* outArrsFFIdx - contains indexes of ff output arrays which are independent from each other; by default all arrays are treated as independent
|
||||||
*/
|
*/
|
||||||
static bool checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, const OpArgsHolder& argsHolderFF, const OpArgsHolder& argsHolderBP,
|
static bool checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, const OpArgsHolder& argsHolderFF, const OpArgsHolder& argsHolderBP,
|
||||||
const std::vector<bool>& whatArrsToCheck = std::vector<bool>(), const std::vector<double>& IdxRange = {0., 1.}, const LossFunc loss = SUM);
|
const std::vector<bool>& whatArrsToCheck = std::vector<bool>(), const std::vector<double>& IdxRange = {0., 1.}, const LossFunc loss = SUM, const std::vector<int>& outArrsFFIdx = {});
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
|
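
The new trailing `outArrsFFIdx` argument selects which feed-forward outputs contribute to the scalar loss; the empty default keeps the old behaviour of summing over all of them. A hedged usage sketch; the ops and argument holders are placeholders assumed to be set up elsewhere, not taken from this diff:

```
// opFF/opBP and the two OpArgsHolder instances are assumed to exist.
const std::vector<bool>   whatArrsToCheck = {false, true};  // check 2nd input gradient only
const std::vector<double> idxRange        = {0., 1.};       // full element range
const std::vector<int>    outArrsFFIdx    = {0};            // loss over FF output 0 only

bool ok = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP,
                               whatArrsToCheck, idxRange, GradCheck::SUM, outArrsFFIdx);
```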
@ -35,16 +35,16 @@ namespace sd {
|
||||||
|
|
||||||
|
|
||||||
class ND4J_EXPORT LoopKind {
|
class ND4J_EXPORT LoopKind {
|
||||||
|
|
||||||
public:
|
public:
|
||||||
enum Kind { SMALLARR2DX, EWS1, EWSNONZERO, RANK1, RANK2, RANK3, RANK4, RANK5, X_EWSNONZERO, Y_EWSNONZERO, Z_EWSNONZERO, COMMON, BROADCAST_SCALAR_X, BROADCAST_SCALAR_Y, BROADCAST_3D, BROADCAST_4D, BROADCAST_5D };
|
enum Kind { SMALLARR2DX, EWS1, EWSNONZERO, RANK1, RANK2, RANK3, RANK4, RANK5, X_EWSNONZERO, Y_EWSNONZERO, Z_EWSNONZERO, COMMON, BROADCAST_SCALAR_X, BROADCAST_SCALAR_Y, BROADCAST_3D, BROADCAST_4D, BROADCAST_5D };
|
||||||
|
|
||||||
static FORCEINLINE Kind deduceKindOfLoopXZ(const Nd4jLong* xShapeInfo, const Nd4jLong* zShapeInfo);
|
static FORCEINLINE Kind deduceKindOfLoopXZ(const Nd4jLong* xShapeInfo, const Nd4jLong* zShapeInfo);
|
||||||
static FORCEINLINE Kind deduceKindOfLoopXYZ(const Nd4jLong* xShapeInfo, const Nd4jLong* yShapeInfo, const Nd4jLong* zShapeInfo);
|
static FORCEINLINE Kind deduceKindOfLoopXYZ(const Nd4jLong* xShapeInfo, const Nd4jLong* yShapeInfo, const Nd4jLong* zShapeInfo);
|
||||||
static FORCEINLINE Kind deduceKindOfLoopTadXZ(const Nd4jLong* xShapeInfo, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo);
|
static FORCEINLINE Kind deduceKindOfLoopTadXZ(const Nd4jLong* xShapeInfo, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo);
|
||||||
static FORCEINLINE Kind deduceKindOfLoopTadXYZ(const Nd4jLong* xTadShapeInfo, const Nd4jLong* yTadShapeInfo, const Nd4jLong* zShapeInfo);
|
static FORCEINLINE Kind deduceKindOfLoopTadXYZ(const Nd4jLong* xTadShapeInfo, const Nd4jLong* yTadShapeInfo, const Nd4jLong* zShapeInfo);
|
||||||
static FORCEINLINE Kind deduceKindOfLoopBroadcast(const Nd4jLong* xShapeInfo, const Nd4jLong* yShapeInfo, const Nd4jLong* zShapeInfo);
|
static FORCEINLINE Kind deduceKindOfLoopBroadcast(const Nd4jLong* xShapeInfo, const Nd4jLong* yShapeInfo, const Nd4jLong* zShapeInfo);
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -59,8 +59,8 @@ LoopKind::Kind LoopKind::deduceKindOfLoopXZ(const Nd4jLong* xShapeInfo, const Nd
|
||||||
|
|
||||||
int temp;
|
int temp;
|
||||||
const bool xVectorOrC = shape::isCommonVector(xShapeInfo, temp) || xOrder == 'c';
|
const bool xVectorOrC = shape::isCommonVector(xShapeInfo, temp) || xOrder == 'c';
|
||||||
const bool zVectorOrC = shape::isCommonVector(zShapeInfo, temp) || zOrder == 'c';
|
const bool zVectorOrC = shape::isCommonVector(zShapeInfo, temp) || zOrder == 'c';
|
||||||
const bool shapesSame = shape::shapeEquals(xShapeInfo, zShapeInfo);
|
const bool shapesSame = shape::shapeEquals(xShapeInfo, zShapeInfo);
|
||||||
|
|
||||||
if (xEws == 1 && zEws == 1 && xOrder == zOrder && (shapesSame || xOrder == 'c'))
|
if (xEws == 1 && zEws == 1 && xOrder == zOrder && (shapesSame || xOrder == 'c'))
|
||||||
return EWS1;
|
return EWS1;
|
||||||
|
@ -160,7 +160,7 @@ LoopKind::Kind LoopKind::deduceKindOfLoopXYZ(const Nd4jLong* xShapeInfo, const N
|
||||||
const bool xVectorOrC = shape::isCommonVector(xShapeInfo, temp) || xOrder == 'c';
|
const bool xVectorOrC = shape::isCommonVector(xShapeInfo, temp) || xOrder == 'c';
|
||||||
const bool yVectorOrC = shape::isCommonVector(yShapeInfo, temp) || yOrder == 'c';
|
const bool yVectorOrC = shape::isCommonVector(yShapeInfo, temp) || yOrder == 'c';
|
||||||
const bool zVectorOrC = shape::isCommonVector(zShapeInfo, temp) || zOrder == 'c';
|
const bool zVectorOrC = shape::isCommonVector(zShapeInfo, temp) || zOrder == 'c';
|
||||||
const bool shapesSame = shape::shapeEquals(xShapeInfo, yShapeInfo, zShapeInfo);
|
const bool shapesSame = shape::shapeEquals(xShapeInfo, yShapeInfo, zShapeInfo);
|
||||||
|
|
||||||
if (xEws == 1 && yEws == 1 && zEws == 1 && xOrder == yOrder && xOrder == zOrder && (shapesSame || xOrder == 'c'))
|
if (xEws == 1 && yEws == 1 && zEws == 1 && xOrder == yOrder && xOrder == zOrder && (shapesSame || xOrder == 'c'))
|
||||||
return EWS1;
|
return EWS1;
|
||||||
|
@ -206,7 +206,7 @@ LoopKind::Kind LoopKind::deduceKindOfLoopTadXZ(const Nd4jLong* xShapeInfo, const
|
||||||
const bool tVectorOrC = shape::isCommonVector(tadShapeInfo, temp) || tOrder == 'c';
|
const bool tVectorOrC = shape::isCommonVector(tadShapeInfo, temp) || tOrder == 'c';
|
||||||
const bool zVectorOrC = shape::isCommonVector(zShapeInfo, temp) || zOrder == 'c';
|
const bool zVectorOrC = shape::isCommonVector(zShapeInfo, temp) || zOrder == 'c';
|
||||||
|
|
||||||
if(shape::length(tadShapeInfo) * shape::length(zShapeInfo) <= Environment::getInstance()->elementwiseThreshold() && shape::rank(xShapeInfo) == 2 && xEws == 1 && xOrder == 'c' && xRank == 2 &&
|
if(shape::length(tadShapeInfo) * shape::length(zShapeInfo) <= Environment::getInstance()->elementwiseThreshold() && xEws == 1 && xOrder == 'c' && xRank == 2 &&
|
||||||
tEws > 1 && zEws == 1 && (allC || (tVectorOrC && zVectorOrC)))
|
tEws > 1 && zEws == 1 && (allC || (tVectorOrC && zVectorOrC)))
|
||||||
return SMALLARR2DX;
|
return SMALLARR2DX;
|
||||||
if(tEws == 1 && zEws == 1 && (allC || (tVectorOrC && zVectorOrC)))
|
if(tEws == 1 && zEws == 1 && (allC || (tVectorOrC && zVectorOrC)))
|
||||||
|
@ -233,18 +233,18 @@ LoopKind::Kind LoopKind::deduceKindOfLoopTadXZ(const Nd4jLong* xShapeInfo, const
|
||||||
//////////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////////
|
||||||
LoopKind::Kind LoopKind::deduceKindOfLoopTadXYZ(const Nd4jLong* xTadShapeInfo, const Nd4jLong* yTadShapeInfo, const Nd4jLong* zShapeInfo) {
|
LoopKind::Kind LoopKind::deduceKindOfLoopTadXYZ(const Nd4jLong* xTadShapeInfo, const Nd4jLong* yTadShapeInfo, const Nd4jLong* zShapeInfo) {
|
||||||
|
|
||||||
// both tad shapes are the same, but strides and ews may be different
|
// both tad shapes are the same, but strides and ews may be different
|
||||||
|
|
||||||
const int tadRank = shape::rank(xTadShapeInfo);
|
const int tadRank = shape::rank(xTadShapeInfo);
|
||||||
|
|
||||||
const Nd4jLong xTadEws = shape::elementWiseStride(xTadShapeInfo);
|
const Nd4jLong xTadEws = shape::elementWiseStride(xTadShapeInfo);
|
||||||
const Nd4jLong yTadEws = shape::elementWiseStride(yTadShapeInfo);
|
const Nd4jLong yTadEws = shape::elementWiseStride(yTadShapeInfo);
|
||||||
const Nd4jLong zEws = shape::elementWiseStride(zShapeInfo);
|
const Nd4jLong zEws = shape::elementWiseStride(zShapeInfo);
|
||||||
|
|
||||||
const char xTadOrder = shape::order(xTadShapeInfo);
|
const char xTadOrder = shape::order(xTadShapeInfo);
|
||||||
const char yTadOrder = shape::order(xTadShapeInfo);
|
const char yTadOrder = shape::order(xTadShapeInfo);
|
||||||
const char zOrder = shape::order(zShapeInfo);
|
const char zOrder = shape::order(zShapeInfo);
|
||||||
|
|
||||||
int position;
|
int position;
|
||||||
const bool xTadVectorOrC = shape::isCommonVector(xTadShapeInfo, position) || xTadOrder == 'c';
|
const bool xTadVectorOrC = shape::isCommonVector(xTadShapeInfo, position) || xTadOrder == 'c';
|
||||||
const bool yTadVectorOrC = shape::isCommonVector(yTadShapeInfo, position) || yTadOrder == 'c';
|
const bool yTadVectorOrC = shape::isCommonVector(yTadShapeInfo, position) || yTadOrder == 'c';
|
||||||
|
@ -265,7 +265,7 @@ LoopKind::Kind LoopKind::deduceKindOfLoopTadXYZ(const Nd4jLong* xTadShapeInfo, c
|
||||||
return RANK4;
|
return RANK4;
|
||||||
if(tadRank == 5 && zEws > 0 && zVectorOrC)
|
if(tadRank == 5 && zEws > 0 && zVectorOrC)
|
||||||
return RANK5;
|
return RANK5;
|
||||||
return COMMON;
|
return COMMON;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -60,7 +60,7 @@ namespace sd {
|
||||||
static sd::NDArray* tensorDot(const sd::NDArray* a, const sd::NDArray* b, const std::vector<std::vector<Nd4jLong>>& modifA, const std::vector<std::vector<Nd4jLong>>& modifB);
|
static sd::NDArray* tensorDot(const sd::NDArray* a, const sd::NDArray* b, const std::vector<std::vector<Nd4jLong>>& modifA, const std::vector<std::vector<Nd4jLong>>& modifB);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
static void matmul(const sd::NDArray* x, const sd::NDArray* y, sd::NDArray* z, const bool transX, const bool transY);
|
static void matmul(const sd::NDArray* x, const sd::NDArray* y, sd::NDArray* z, const bool transX, const bool transY, double alpha = 1.0, double beta = 0.0);
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -372,16 +372,16 @@ NDArray* MmulHelper::dot(const NDArray* X, const NDArray* Y, sd::NDArray* Z, con
|
||||||
int xLenDim(0), yLenDim(0);
|
int xLenDim(0), yLenDim(0);
|
||||||
|
|
||||||
if(!shape::isCommonVector(X->getShapeInfo(), xLenDim))
|
if(!shape::isCommonVector(X->getShapeInfo(), xLenDim))
|
||||||
throw std::runtime_error("MmulHelper::dot cuda: X array must be vector !");
|
throw std::runtime_error("MmulHelper::dot: X array must be vector !");
|
||||||
if(!shape::isCommonVector(Y->getShapeInfo(), yLenDim))
|
if(!shape::isCommonVector(Y->getShapeInfo(), yLenDim))
|
||||||
throw std::runtime_error("MmulHelper::dot cuda: Y array must be vector !");
|
throw std::runtime_error("MmulHelper::dot: Y array must be vector !");
|
||||||
if(Z != nullptr && !Z->isScalar())
|
if(Z != nullptr && !Z->isScalar())
|
||||||
throw std::runtime_error("MmulHelper::dot cuda: Z array must be scalar !");
|
throw std::runtime_error("MmulHelper::dot: Z array must be scalar !");
|
||||||
|
|
||||||
const auto length = X->lengthOf();
|
const auto length = X->lengthOf();
|
||||||
|
|
||||||
if(Y->lengthOf() != length)
|
if(Y->lengthOf() != length)
|
||||||
throw std::runtime_error("MmulHelper::dot cuda: lengths of input vectors are different !");
|
throw std::runtime_error("MmulHelper::dot: lengths of input vectors are different !");
|
||||||
|
|
||||||
if(Z == nullptr)
|
if(Z == nullptr)
|
||||||
Z = new NDArray(DataTypeUtils::pickPairwiseResultType(X->dataType(), Y->dataType()), X->getContext());
|
Z = new NDArray(DataTypeUtils::pickPairwiseResultType(X->dataType(), Y->dataType()), X->getContext());
|
||||||
|
|
|
@ -49,7 +49,7 @@ void GradCheck::fillGradArrays(const LossFunc loss, const std::vector<NDArray*>&
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, const OpArgsHolder& argsHolderFF, const OpArgsHolder& argsHolderBP,
|
bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, const OpArgsHolder& argsHolderFF, const OpArgsHolder& argsHolderBP,
|
||||||
const std::vector<bool>& whatArrsToCheck, const std::vector<double>& idxRange, const LossFunc loss ) {
|
const std::vector<bool>& whatArrsToCheck, const std::vector<double>& idxRange, const LossFunc loss, const std::vector<int>& outArrsFFIdx) {
|
||||||
|
|
||||||
const int numInArrsFF = argsHolderFF.getNumInArrs(); // at the same time numInArrsFF = number of output arrays in opBP
|
const int numInArrsFF = argsHolderFF.getNumInArrs(); // at the same time numInArrsFF = number of output arrays in opBP
|
||||||
const int numInGradArrsBP = argsHolderBP.getNumInArrs() - numInArrsFF; // because argsHolderBP.getNumInArrs() = numInArrsFF + numInGradArrsBP
|
const int numInGradArrsBP = argsHolderBP.getNumInArrs() - numInArrsFF; // because argsHolderBP.getNumInArrs() = numInArrsFF + numInGradArrsBP
|
||||||
|
@ -82,12 +82,23 @@ bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, cons
|
||||||
int numOutArrs = outArrsFF.size();
|
int numOutArrs = outArrsFF.size();
|
||||||
double scorePlus = 0.;
|
double scorePlus = 0.;
|
||||||
|
|
||||||
for(int k = 0; k < numOutArrs; ++k) { // loop through output arrays
|
if(!outArrsFFIdx.empty()) {
|
||||||
if(loss == SUM)
|
for(const auto& k : outArrsFFIdx) { // loop through independent output arrays
|
||||||
outArrsFF.at(k)->reduceNumber(reduce::Sum, tmpScalar);
|
if(loss == SUM)
|
||||||
else
|
outArrsFF.at(k)->reduceNumber(reduce::Sum, tmpScalar);
|
||||||
outArrsFF.at(k)->reduceNumber(reduce::Mean, tmpScalar);
|
else
|
||||||
scorePlus += tmpScalar.e<double>(0);
|
outArrsFF.at(k)->reduceNumber(reduce::Mean, tmpScalar);
|
||||||
|
scorePlus += tmpScalar.e<double>(0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
for(int k = 0; k < numOutArrs; ++k) { // loop through output arrays
|
||||||
|
if(loss == SUM)
|
||||||
|
outArrsFF.at(k)->reduceNumber(reduce::Sum, tmpScalar);
|
||||||
|
else
|
||||||
|
outArrsFF.at(k)->reduceNumber(reduce::Mean, tmpScalar);
|
||||||
|
scorePlus += tmpScalar.e<double>(0);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// subtract epsilon, feed forward
|
// subtract epsilon, feed forward
|
||||||
|
@ -95,12 +106,23 @@ bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, cons
|
||||||
outArrsFF = opFF.execute(argsHolderFF);
|
outArrsFF = opFF.execute(argsHolderFF);
|
||||||
double scoreMinus = 0.;
|
double scoreMinus = 0.;
|
||||||
|
|
||||||
for(int k = 0; k < numOutArrs; ++k) { // loop through output arrays
|
if(!outArrsFFIdx.empty()) {
|
||||||
if(loss == SUM)
|
for(const auto& k : outArrsFFIdx) { // loop through independent output arrays
|
||||||
outArrsFF.at(k)->reduceNumber(reduce::Sum, tmpScalar);
|
if(loss == SUM)
|
||||||
else
|
outArrsFF.at(k)->reduceNumber(reduce::Sum, tmpScalar);
|
||||||
outArrsFF.at(k)->reduceNumber(reduce::Mean, tmpScalar);
|
else
|
||||||
scoreMinus += tmpScalar.e<double>(0);
|
outArrsFF.at(k)->reduceNumber(reduce::Mean, tmpScalar);
|
||||||
|
scoreMinus += tmpScalar.e<double>(0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
for(int k = 0; k < numOutArrs; ++k) { // loop through output arrays
|
||||||
|
if(loss == SUM)
|
||||||
|
outArrsFF.at(k)->reduceNumber(reduce::Sum, tmpScalar);
|
||||||
|
else
|
||||||
|
outArrsFF.at(k)->reduceNumber(reduce::Mean, tmpScalar);
|
||||||
|
scoreMinus += tmpScalar.e<double>(0);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// restore initial element value
|
// restore initial element value
|
||||||
|
@ -120,7 +142,7 @@ bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, cons
|
||||||
throw std::runtime_error("");
|
throw std::runtime_error("");
|
||||||
}
|
}
|
||||||
|
|
||||||
// printf("num = %.5f, ana = %.5f\n", numericalGrad, analyticGrad);
|
// printf("%lld: num = %.15f, ana = %.15f\n", j, numericalGrad, analyticGrad);
|
||||||
|
|
||||||
// calculate relative error
|
// calculate relative error
|
||||||
double relError;
|
double relError;
|
||||||
|
@ -134,7 +156,7 @@ bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, cons
|
||||||
|
|
||||||
if(math::nd4j_abs<double>(analyticGrad - numericalGrad) < MINABSERR)
|
if(math::nd4j_abs<double>(analyticGrad - numericalGrad) < MINABSERR)
|
||||||
continue;
|
continue;
|
||||||
printf("numericalGrad = %f, analyticGrad = %f \n", numericalGrad, analyticGrad);
|
printf("numericalGrad = %.15f, analyticGrad = %.15f \n", numericalGrad, analyticGrad);
|
||||||
printf("GradCheck::checkGrad: got RELERROR = %f > MAXRELERROR(%f) for input array # %i and its element at position %lld ! \n", relError, MAXRELERR, i, j);
|
printf("GradCheck::checkGrad: got RELERROR = %f > MAXRELERROR(%f) for input array # %i and its element at position %lld ! \n", relError, MAXRELERR, i, j);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
|
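
For context, the loop above skips elements whose absolute difference is below `MINABSERR` before applying the relative-error threshold. A minimal sketch of that per-element acceptance test; the symmetric relative-error formula and the threshold values are assumptions, since the diff does not show them:

```
#include <cmath>

// Per-element acceptance test mirroring the check above. maxRelErr and
// minAbsErr are placeholder values, not the library's constants, and the
// symmetric relative error is an assumed formula.
bool gradElementOk(double numericalGrad, double analyticGrad,
                   double maxRelErr = 1e-5, double minAbsErr = 1e-6) {
    const double absDiff = std::fabs(analyticGrad - numericalGrad);
    if (absDiff < minAbsErr)   // difference too small to judge relatively
        return true;
    const double relError = absDiff / (std::fabs(analyticGrad) + std::fabs(numericalGrad));
    return relError <= maxRelErr;
}
```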
@ -239,7 +239,7 @@ sd::NDArray* MmulHelper::mmul(const sd::NDArray* A, const sd::NDArray* B, sd::ND
|
||||||
|
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
void MmulHelper::matmul(const sd::NDArray* x, const sd::NDArray* y, sd::NDArray* z, const bool transX, const bool transY) {
|
void MmulHelper::matmul(const sd::NDArray* x, const sd::NDArray* y, sd::NDArray* z, const bool transX, const bool transY, double alpha, double beta) {
|
||||||
int xRank = x->rankOf();
|
int xRank = x->rankOf();
|
||||||
int yRank = y->rankOf();
|
int yRank = y->rankOf();
|
||||||
|
|
||||||
|
@ -276,7 +276,7 @@ sd::NDArray* MmulHelper::mmul(const sd::NDArray* A, const sd::NDArray* B, sd::ND
|
||||||
zT = new NDArray(z->reshape(z->ordering(), {1, z->lengthOf()}));
|
zT = new NDArray(z->reshape(z->ordering(), {1, z->lengthOf()}));
|
||||||
}
|
}
|
||||||
|
|
||||||
mmul(xT, yT, zT, 1., 0.);
|
mmul(xT, yT, zT, alpha, beta);
|
||||||
}
|
}
|
||||||
else { // rest cases - batched mmul
|
else { // rest cases - batched mmul
|
||||||
|
|
||||||
|
@ -292,7 +292,7 @@ sd::NDArray* MmulHelper::mmul(const sd::NDArray* A, const sd::NDArray* B, sd::ND
|
||||||
auto xSubArr = (*xT)(i, dimsToExclude);
|
auto xSubArr = (*xT)(i, dimsToExclude);
|
||||||
auto ySubArr = (*yT)(i, dimsToExclude);
|
auto ySubArr = (*yT)(i, dimsToExclude);
|
||||||
auto zSubArr = (*zT)(i, dimsToExclude);
|
auto zSubArr = (*zT)(i, dimsToExclude);
|
||||||
mmul(&xSubArr, &ySubArr, &zSubArr, 1., 0.);
|
mmul(&xSubArr, &ySubArr, &zSubArr, alpha, beta);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
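
The new `alpha`/`beta` parameters (defaulting to 1.0 and 0.0, which preserves the old behaviour) give `MmulHelper::matmul` standard GEMM semantics, z <- alpha * (x @ y) + beta * z. A self-contained worked example on plain 2x2 row-major arrays, not the library code:

```
#include <cstdio>

// z <- alpha * (x @ y) + beta * z on 2x2 arrays.
int main() {
    const double x[2][2] = {{1, 2}, {3, 4}};
    const double y[2][2] = {{5, 6}, {7, 8}};
    double z[2][2]       = {{1, 1}, {1, 1}};
    const double alpha = 2.0, beta = 0.5;

    for (int i = 0; i < 2; i++)
        for (int j = 0; j < 2; j++) {
            double acc = 0.0;
            for (int k = 0; k < 2; k++)
                acc += x[i][k] * y[k][j];        // (x @ y)[i][j]
            z[i][j] = alpha * acc + beta * z[i][j];
        }
    // x @ y = [[19, 22], [43, 50]], so z = [[38.5, 44.5], [86.5, 100.5]]
    printf("%g %g\n%g %g\n", z[0][0], z[0][1], z[1][0], z[1][1]);
    return 0;
}
```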
@ -253,7 +253,8 @@
|
||||||
(45, ReversePow), \
|
(45, ReversePow), \
|
||||||
(46, DivideNoNan), \
|
(46, DivideNoNan), \
|
||||||
(47, IGamma), \
|
(47, IGamma), \
|
||||||
(48, IGammac)
|
(48, IGammac), \
|
||||||
|
(49, RELUDerivative)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,43 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// Created by raver on 6/6/2018.
|
||||||
|
//
|
||||||
|
|
||||||
|
#ifndef SD_BROADCASTABLEBOOLOP_H
|
||||||
|
#define SD_BROADCASTABLEBOOLOP_H
|
||||||
|
|
||||||
|
#include <graph/Context.h>
|
||||||
|
#include "OpDescriptor.h"
|
||||||
|
#include "DeclarableOp.h"
|
||||||
|
#include "DeclarableCustomOp.h"
|
||||||
|
|
||||||
|
namespace sd {
|
||||||
|
namespace ops {
|
||||||
|
class ND4J_EXPORT BroadcastableBoolOp : public DeclarableCustomOp {
|
||||||
|
protected:
|
||||||
|
Nd4jStatus validateAndExecute(Context& block) override = 0;
|
||||||
|
public:
|
||||||
|
BroadcastableBoolOp(const char *name, int numTArgs, int numIArgs);
|
||||||
|
|
||||||
|
ShapeList *calculateOutputShape(ShapeList *inputShape, sd::graph::Context& block) override;
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
#endif //SD_BROADCASTABLEBOOLOP_H
|
|
@ -36,10 +36,14 @@ CUSTOM_OP_IMPL(matmul, 2, 1, false, 0, -2) {
|
||||||
auto y = INPUT_VARIABLE(1);
|
auto y = INPUT_VARIABLE(1);
|
||||||
auto z = OUTPUT_VARIABLE(0);
|
auto z = OUTPUT_VARIABLE(0);
|
||||||
|
|
||||||
const int iSize = (int) block.getIArguments()->size();
|
int iSize = (int) block.getIArguments()->size();
|
||||||
int transX = iSize > 0 ? INT_ARG(0) : 0;
|
int transX = iSize > 0 ? INT_ARG(0) : 0;
|
||||||
int transY = iSize > 1 ? INT_ARG(1) : 0;
|
int transY = iSize > 1 ? INT_ARG(1) : 0;
|
||||||
const int transZ = iSize > 2 ? INT_ARG(2) : 0;
|
const int transZ = iSize > 2 ? INT_ARG(2) : 0;
|
||||||
|
// optionally use alpha and beta
|
||||||
|
iSize = (int)block.getTArguments()->size();
|
||||||
|
double alpha = iSize > 0 ? T_ARG(0) : 1.0;
|
||||||
|
double beta = iSize > 1 ? T_ARG(1) : 0.0;
|
||||||
|
|
||||||
const int xRank = x->rankOf();
|
const int xRank = x->rankOf();
|
||||||
const int yRank = y->rankOf();
|
const int yRank = y->rankOf();
|
||||||
|
@ -77,7 +81,7 @@ CUSTOM_OP_IMPL(matmul, 2, 1, false, 0, -2) {
|
||||||
}
|
}
|
||||||
// ******* end of input validation ******* //
|
// ******* end of input validation ******* //
|
||||||
|
|
||||||
MmulHelper::matmul(x, y, z, transX, transY);
|
MmulHelper::matmul(x, y, z, transX, transY, alpha, beta);
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -147,11 +151,17 @@ CUSTOM_OP_IMPL(matmul_bp, 3, 2, false, 0, -2) {
|
||||||
auto dldx = OUTPUT_VARIABLE(0);
|
auto dldx = OUTPUT_VARIABLE(0);
|
||||||
auto dldy = OUTPUT_VARIABLE(1);
|
auto dldy = OUTPUT_VARIABLE(1);
|
||||||
|
|
||||||
const int iSize = (int) block.getIArguments()->size();
|
int iSize = (int) block.getIArguments()->size();
|
||||||
int transX = iSize > 0 ? INT_ARG(0) : 0;
|
int transX = iSize > 0 ? INT_ARG(0) : 0;
|
||||||
int transY = iSize > 1 ? INT_ARG(1) : 0;
|
int transY = iSize > 1 ? INT_ARG(1) : 0;
|
||||||
const int transZ = iSize > 2 ? INT_ARG(2) : 0;
|
const int transZ = iSize > 2 ? INT_ARG(2) : 0;
|
||||||
|
|
||||||
|
// optionally use alpha and beta
|
||||||
|
iSize = (int)block.getTArguments()->size();
|
||||||
|
|
||||||
|
double alpha = iSize > 0 ? T_ARG(0) : 1.0;
|
||||||
|
double beta = iSize > 1 ? T_ARG(1) : 0.0;
|
||||||
|
|
||||||
/*
|
/*
|
||||||
In: x=[a,b], y=[b,c]
|
In: x=[a,b], y=[b,c]
|
||||||
tX tY tZ x y z dz dLdx dLdy
|
tX tY tZ x y z dz dLdx dLdy
|
||||||
|
@ -164,8 +174,8 @@ F F T [a,b] [b,c] [c,a] [c,a]
|
||||||
|
|
||||||
|
|
||||||
sd::ops::matmul op;
|
sd::ops::matmul op;
|
||||||
op.execute({eps, y}, {dldx}, {}, {transZ, !transY, transX}, {});
|
op.execute({eps, y}, {dldx}, {alpha, beta}, {transZ, !transY, transX}, {});
|
||||||
op.execute({x, eps}, {dldy}, {}, {!transX, transZ, transY}, {});
|
op.execute({x, eps}, {dldy}, {alpha, beta}, {!transX, transZ, transY}, {});
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
|
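
Taken together, `matmul` now reads optional alpha and beta from the T arguments and `matmul_bp` threads them through to both gradient products. A hedged invocation sketch; the NDArray setup is assumed and not part of this diff:

```
// Assumes NDArrays x, y, z already exist with compatible shapes.
// Argument lists: inputs, outputs, T args {alpha, beta},
// I args {transX, transY, transZ}, B args.
sd::ops::matmul op;
op.execute({&x, &y}, {&z}, {2.0 /*alpha*/, 0.5 /*beta*/}, {0, 0, 0}, {});
```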
@ -23,7 +23,7 @@
|
||||||
|
|
||||||
namespace sd {
|
namespace sd {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
BROADCASTABLE_OP_IMPL(equals, 0, 0) {
|
BROADCASTABLE_BOOL_OP_IMPL(equals, 0, 0) {
|
||||||
auto x = INPUT_VARIABLE(0);
|
auto x = INPUT_VARIABLE(0);
|
||||||
auto y = INPUT_VARIABLE(1);
|
auto y = INPUT_VARIABLE(1);
|
||||||
auto z = OUTPUT_VARIABLE(0);
|
auto z = OUTPUT_VARIABLE(0);
|
||||||
|
|
|
@ -14,10 +14,10 @@
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
******************************************************************************/
|
******************************************************************************/
|
||||||
|
|
||||||
//
|
//
|
||||||
// @author raver119@gmail.com
|
// @author raver119@gmail.com
|
||||||
// modified by sgazeos@gmail.com with backprop implementation.
|
// modified by sgazeos@gmail.com with backprop implementation.
|
||||||
//
|
//
|
||||||
#include <system/op_boilerplate.h>
|
#include <system/op_boilerplate.h>
|
||||||
#if NOT_EXCLUDED(OP_floormod)
|
#if NOT_EXCLUDED(OP_floormod)
|
||||||
|
|
||||||
|
@ -31,7 +31,7 @@ namespace sd {
|
||||||
auto y = INPUT_VARIABLE(1);
|
auto y = INPUT_VARIABLE(1);
|
||||||
auto z = OUTPUT_VARIABLE(0);
|
auto z = OUTPUT_VARIABLE(0);
|
||||||
|
|
||||||
BROADCAST_CHECK_EMPTY(x,y,z);
|
BROADCAST_CHECK_EMPTY(x, y, z);
|
||||||
|
|
||||||
REQUIRE_TRUE(!y->isB(), 0, "FLOORMOD OP: you can't divide by bool array!");
|
REQUIRE_TRUE(!y->isB(), 0, "FLOORMOD OP: you can't divide by bool array!");
|
||||||
auto tZ = BroadcastHelper::broadcastApply(BROADCAST(FloorMod), x, y, z);
|
auto tZ = BroadcastHelper::broadcastApply(BROADCAST(FloorMod), x, y, z);
|
||||||
|
@ -46,15 +46,15 @@ namespace sd {
|
||||||
|
|
||||||
DECLARE_TYPES(floormod) {
|
DECLARE_TYPES(floormod) {
|
||||||
getOpDescriptor()
|
getOpDescriptor()
|
||||||
->setAllowedInputTypes(0, DataType::ANY)
|
->setAllowedInputTypes(0, DataType::ANY)
|
||||||
->setAllowedInputTypes(1, DataType::ANY)
|
->setAllowedInputTypes(1, DataType::ANY)
|
||||||
->setAllowedOutputTypes(0, DataType::INHERIT);
|
->setAllowedOutputTypes(0, DataType::INHERIT);
|
||||||
}
|
}
|
||||||
|
|
||||||
DECLARE_TYPES(floormod_bp) {
|
DECLARE_TYPES(floormod_bp) {
|
||||||
getOpDescriptor()
|
getOpDescriptor()
|
||||||
->setAllowedInputTypes(DataType::ANY)
|
->setAllowedInputTypes(DataType::ANY)
|
||||||
->setAllowedOutputTypes({ALL_FLOATS});
|
->setAllowedOutputTypes({ ALL_FLOATS });
|
||||||
}
|
}
|
||||||
|
|
||||||
CUSTOM_OP_IMPL(floormod_bp, 3, 2, false, 0, 0) {
|
CUSTOM_OP_IMPL(floormod_bp, 3, 2, false, 0, 0) {
|
||||||
|
@ -66,11 +66,11 @@ namespace sd {
|
||||||
auto gradY = OUTPUT_VARIABLE(1);
|
auto gradY = OUTPUT_VARIABLE(1);
|
||||||
gradX->assign(epsNext);
|
gradX->assign(epsNext);
|
||||||
|
|
||||||
sd::ops::floormod op;
|
NDArray temp(*epsNext);
|
||||||
auto tmpResult(op.evaluate({x, y}));
|
BroadcastHelper::broadcastApply(BROADCAST(FloorMod), x, y, &temp);
|
||||||
|
|
||||||
if (gradY->rankOf() == gradX->rankOf())
|
if (gradY->rankOf() == gradX->rankOf())
|
||||||
epsNext->applyPairwiseTransform(pairwise::Multiply, *tmpResult.at(0), *gradY);
|
epsNext->applyPairwiseTransform(pairwise::Multiply, temp, *gradY);
|
||||||
else // epsNext is greater than gradY
|
else // epsNext is greater than gradY
|
||||||
{
|
{
|
||||||
std::vector<Nd4jLong> dims(epsNext->rankOf() * 2);
|
std::vector<Nd4jLong> dims(epsNext->rankOf() * 2);
|
||||||
|
@ -78,7 +78,7 @@ namespace sd {
|
||||||
for (Nd4jLong d = 0; d < gap; d++) {
|
for (Nd4jLong d = 0; d < gap; d++) {
|
||||||
dims[d * 2 + 1] = 1;
|
dims[d * 2 + 1] = 1;
|
||||||
}
|
}
|
||||||
auto tempIn((*tmpResult.at(0))(dims));
|
auto tempIn((temp)(dims));
|
||||||
(*epsNext)(dims).applyPairwiseTransform(pairwise::Multiply, tempIn, *gradY);
|
(*epsNext)(dims).applyPairwiseTransform(pairwise::Multiply, tempIn, *gradY);
|
||||||
}
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
|
@ -92,8 +92,8 @@ namespace sd {
|
||||||
// eps always has shape of x
|
// eps always has shape of x
|
||||||
// grad always has shape of y
|
// grad always has shape of y
|
||||||
|
|
||||||
Nd4jLong *shapeE;
|
Nd4jLong* shapeE;
|
||||||
Nd4jLong *shapeG;
|
Nd4jLong* shapeG;
|
||||||
|
|
||||||
COPY_SHAPE(x, shapeE);
|
COPY_SHAPE(x, shapeE);
|
||||||
COPY_SHAPE(y, shapeG);
|
COPY_SHAPE(y, shapeG);
|
||||||
|
|
|
@ -23,7 +23,7 @@
|
||||||
|
|
||||||
namespace sd {
|
namespace sd {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
BROADCASTABLE_OP_IMPL(greater, 0, 0) {
|
BROADCASTABLE_BOOL_OP_IMPL(greater, 0, 0) {
|
||||||
auto x = INPUT_VARIABLE(0);
|
auto x = INPUT_VARIABLE(0);
|
||||||
auto y = INPUT_VARIABLE(1);
|
auto y = INPUT_VARIABLE(1);
|
||||||
auto z = OUTPUT_VARIABLE(0);
|
auto z = OUTPUT_VARIABLE(0);
|
||||||
|
|
|
@ -22,7 +22,7 @@
|
||||||
|
|
||||||
namespace sd {
|
namespace sd {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
BROADCASTABLE_OP_IMPL(greater_equal, 0, 0) {
|
BROADCASTABLE_BOOL_OP_IMPL(greater_equal, 0, 0) {
|
||||||
auto x = INPUT_VARIABLE(0);
|
auto x = INPUT_VARIABLE(0);
|
||||||
auto y = INPUT_VARIABLE(1);
|
auto y = INPUT_VARIABLE(1);
|
||||||
auto z = OUTPUT_VARIABLE(0);
|
auto z = OUTPUT_VARIABLE(0);
|
||||||
|
|
|
@ -22,7 +22,7 @@
|
||||||
|
|
||||||
namespace sd {
|
namespace sd {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
BROADCASTABLE_OP_IMPL(less, 0, 0) {
|
BROADCASTABLE_BOOL_OP_IMPL(less, 0, 0) {
|
||||||
auto x = INPUT_VARIABLE(0);
|
auto x = INPUT_VARIABLE(0);
|
||||||
auto y = INPUT_VARIABLE(1);
|
auto y = INPUT_VARIABLE(1);
|
||||||
auto z = OUTPUT_VARIABLE(0);
|
auto z = OUTPUT_VARIABLE(0);
|
||||||
|
|
|
@ -22,7 +22,7 @@
|
||||||
|
|
||||||
namespace sd {
|
namespace sd {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
BROADCASTABLE_OP_IMPL(less_equal, 0, 0) {
|
BROADCASTABLE_BOOL_OP_IMPL(less_equal, 0, 0) {
|
||||||
auto x = INPUT_VARIABLE(0);
|
auto x = INPUT_VARIABLE(0);
|
||||||
auto y = INPUT_VARIABLE(1);
|
auto y = INPUT_VARIABLE(1);
|
||||||
auto z = OUTPUT_VARIABLE(0);
|
auto z = OUTPUT_VARIABLE(0);
|
||||||
|
|
|
@ -22,7 +22,7 @@
|
||||||
|
|
||||||
namespace sd {
|
namespace sd {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
BROADCASTABLE_OP_IMPL(not_equals, 0, 0) {
|
BROADCASTABLE_BOOL_OP_IMPL(not_equals, 0, 0) {
|
||||||
auto x = INPUT_VARIABLE(0);
|
auto x = INPUT_VARIABLE(0);
|
||||||
auto y = INPUT_VARIABLE(1);
|
auto y = INPUT_VARIABLE(1);
|
||||||
auto z = OUTPUT_VARIABLE(0);
|
auto z = OUTPUT_VARIABLE(0);
|
||||||
|
|
|
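
Each of these comparison ops is migrated from `BROADCASTABLE_OP_IMPL` to `BROADCASTABLE_BOOL_OP_IMPL`, i.e. onto the new `BroadcastableBoolOp` base, since their outputs are boolean regardless of the input type. A conceptual sketch of the elementwise semantics, with broadcasting omitted; this is a stand-in, not the library's broadcast machinery:

```
#include <cstddef>
#include <vector>

// Elementwise `greater` with a bool-typed result.
std::vector<bool> greaterThan(const std::vector<float>& x, const std::vector<float>& y) {
    std::vector<bool> z(x.size());
    for (std::size_t i = 0; i < x.size(); ++i)
        z[i] = x[i] > y[i];   // bool output is why these ops need a bool base class
    return z;
}
```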
@ -14,9 +14,9 @@
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
******************************************************************************/
|
******************************************************************************/
|
||||||
|
|
||||||
//
|
//
|
||||||
// @author raver119@gmail.com
|
// @author raver119@gmail.com
|
||||||
//
|
//
|
||||||
|
|
||||||
#include <system/op_boilerplate.h>
|
#include <system/op_boilerplate.h>
|
||||||
#if NOT_EXCLUDED(OP_split_string)
|
#if NOT_EXCLUDED(OP_split_string)
|
||||||
|
@ -60,7 +60,7 @@ namespace sd {
|
||||||
|
|
||||||
// filling output indices
|
// filling output indices
|
||||||
for (uint64_t f = 0; f < cnt; f++) {
|
for (uint64_t f = 0; f < cnt; f++) {
|
||||||
for (auto v: icoords)
|
for (auto v : icoords)
|
||||||
indices->p(ic++, v);
|
indices->p(ic++, v);
|
||||||
|
|
||||||
// last index
|
// last index
|
||||||
|
@ -75,12 +75,12 @@ namespace sd {
|
||||||
for (auto e = 0L; e < input->lengthOf(); e++) {
|
for (auto e = 0L; e < input->lengthOf(); e++) {
|
||||||
auto split = StringUtils::split(input->e<std::string>(e), d);
|
auto split = StringUtils::split(input->e<std::string>(e), d);
|
||||||
|
|
||||||
for (const auto &s:split)
|
for (const auto& s : split)
|
||||||
strings.emplace_back(s);
|
strings.emplace_back(s);
|
||||||
}
|
}
|
||||||
|
|
||||||
// now once we have all strings in single vector time to fill
|
// now once we have all strings in single vector time to fill
|
||||||
auto tmp = NDArrayFactory::string({(Nd4jLong) strings.size()}, strings);
|
auto tmp = NDArrayFactory::string({ (Nd4jLong)strings.size() }, strings, input->dataType(), block.launchContext());
|
||||||
auto blen = StringUtils::byteLength(tmp) + ShapeUtils::stringBufferHeaderRequirements(strings.size());
|
auto blen = StringUtils::byteLength(tmp) + ShapeUtils::stringBufferHeaderRequirements(strings.size());
|
||||||
|
|
||||||
// for CUDA mostly
|
// for CUDA mostly
|
||||||
|
@ -129,9 +129,9 @@ namespace sd {
|
||||||
|
|
||||||
DECLARE_TYPES(compat_string_split) {
|
DECLARE_TYPES(compat_string_split) {
|
||||||
getOpDescriptor()
|
getOpDescriptor()
|
||||||
->setAllowedInputTypes({ALL_STRINGS})
|
->setAllowedInputTypes({ ALL_STRINGS })
|
||||||
->setAllowedOutputTypes(0, {ALL_INDICES})
|
->setAllowedOutputTypes(0, { ALL_INDICES })
|
||||||
->setAllowedOutputTypes(1, {ALL_STRINGS});
|
->setAllowedOutputTypes(1, { ALL_STRINGS });
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
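
The loop above delegates to `StringUtils::split`. As a conceptual stand-in (not the library implementation), a plain delimiter split assuming a non-empty delimiter might look like:

```
#include <string>
#include <vector>

// Splits s on delimiter d; d is assumed non-empty.
std::vector<std::string> split(const std::string& s, const std::string& d) {
    std::vector<std::string> out;
    std::size_t start = 0, pos;
    while ((pos = s.find(d, start)) != std::string::npos) {
        out.emplace_back(s.substr(start, pos - start));
        start = pos + d.size();
    }
    out.emplace_back(s.substr(start));   // trailing piece
    return out;
}
```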
@ -68,6 +68,7 @@ namespace sd {
|
||||||
}
|
}
|
||||||
case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array
|
case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array
|
||||||
NDArray sum;
|
NDArray sum;
|
||||||
|
sum.setContext(block.launchContext());
|
||||||
if (weights->isScalar())
|
if (weights->isScalar())
|
||||||
sum = *weights * E.lengthOf();
|
sum = *weights * E.lengthOf();
|
||||||
else
|
else
|
||||||
|
@ -201,6 +202,7 @@ namespace sd {
|
||||||
case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array
|
case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array
|
||||||
|
|
||||||
NDArray sum;
|
NDArray sum;
|
||||||
|
sum.setContext(block.launchContext());
|
||||||
if (weights->isScalar())
|
if (weights->isScalar())
|
||||||
sum = (*weights) * E.lengthOf();
|
sum = (*weights) * E.lengthOf();
|
||||||
else
|
else
|
||||||
|
|
|
@ -73,6 +73,7 @@ CUSTOM_OP_IMPL(huber_loss, 3, 1, false, 1, 1) {
|
||||||
}
|
}
|
||||||
case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array
|
case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array
|
||||||
NDArray sum;
|
NDArray sum;
|
||||||
|
sum.setContext(block.launchContext());
|
||||||
if (weights->isScalar())
|
if (weights->isScalar())
|
||||||
sum = *weights * E.lengthOf();
|
sum = *weights * E.lengthOf();
|
||||||
else
|
else
|
||||||
|
@ -216,6 +217,7 @@ DECLARE_SHAPE_FN(huber_loss) {
|
||||||
case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array
|
case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array
|
||||||
|
|
||||||
NDArray sum;
|
NDArray sum;
|
||||||
|
sum.setContext(block.launchContext());
|
||||||
if (weights->isScalar())
|
if (weights->isScalar())
|
||||||
sum = (*weights) * E.lengthOf();
|
sum = (*weights) * E.lengthOf();
|
||||||
else
|
else
|
||||||
|
|
|
@ -70,6 +70,7 @@ CUSTOM_OP_IMPL(log_loss, 3, 1, false, 1, 1) {
|
||||||
}
|
}
|
||||||
case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array
|
case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array
|
||||||
NDArray sum;
|
NDArray sum;
|
||||||
|
sum.setContext(block.launchContext());
|
||||||
if (weights->isScalar())
|
if (weights->isScalar())
|
||||||
sum = *weights * E.lengthOf();
|
sum = *weights * E.lengthOf();
|
||||||
else
|
else
|
||||||
|
@ -206,6 +207,7 @@ CUSTOM_OP_IMPL(log_loss_grad, 3, 3, false, 1, 1) {
|
||||||
case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array
|
case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array
|
||||||
|
|
||||||
NDArray sum;
|
NDArray sum;
|
||||||
|
sum.setContext(block.launchContext());
|
||||||
if (weights->isScalar())
|
if (weights->isScalar())
|
||||||
sum = (*weights) * E.lengthOf();
|
sum = (*weights) * E.lengthOf();
|
||||||
else
|
else
|
||||||
|
|
|
@ -74,6 +74,7 @@ namespace ops {
|
||||||
}
|
}
|
||||||
case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array
|
case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array
|
||||||
NDArray sum;
|
NDArray sum;
|
||||||
|
sum.setContext(block.launchContext());
|
||||||
if (weights->isScalar())
|
if (weights->isScalar())
|
||||||
sum = *weights * E.lengthOf();
|
sum = *weights * E.lengthOf();
|
||||||
else
|
else
|
||||||
|
@ -209,6 +210,7 @@ namespace ops {
|
||||||
case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array
|
case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array
|
||||||
|
|
||||||
NDArray sum;
|
NDArray sum;
|
||||||
|
sum.setContext(block.launchContext());
|
||||||
if (weights->isScalar())
|
if (weights->isScalar())
|
||||||
sum = (*weights) * E.lengthOf();
|
sum = (*weights) * E.lengthOf();
|
||||||
else
|
else
|
||||||
|
|
|
@ -143,6 +143,7 @@ namespace sd {
|
||||||
}
|
}
|
||||||
case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array
|
case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array
|
||||||
NDArray sum;
|
NDArray sum;
|
||||||
|
sum.setContext(block.launchContext());
|
||||||
if (weights->isScalar())
|
if (weights->isScalar())
|
||||||
sum = (*weights) * E.lengthOf();
|
sum = (*weights) * E.lengthOf();
|
||||||
else
|
else
|
||||||
|
@ -282,6 +283,7 @@ namespace sd {
|
||||||
case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array
|
case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array
|
||||||
|
|
||||||
NDArray sum;
|
NDArray sum;
|
||||||
|
sum.setContext(block.launchContext());
|
||||||
if (weights->isScalar())
|
if (weights->isScalar())
|
||||||
sum = (*weights) * E.lengthOf();
|
sum = (*weights) * E.lengthOf();
|
||||||
else
|
else
|
||||||
|
|
|
@ -67,6 +67,7 @@ CUSTOM_OP_IMPL(mean_sqerr_loss, 3, 1, false, 0, 1) {
|
||||||
}
|
}
|
||||||
case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array
|
case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array
|
||||||
NDArray sum;
|
NDArray sum;
|
||||||
|
sum.setContext(block.launchContext());
|
||||||
if (weights->isScalar())
|
if (weights->isScalar())
|
||||||
sum = (*weights) * E.lengthOf();
|
sum = (*weights) * E.lengthOf();
|
||||||
else
|
else
|
||||||
|
@ -200,6 +201,7 @@ CUSTOM_OP_IMPL(mean_sqerr_loss_grad, 3, 3, false, 0, 1) {
|
||||||
case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array
|
case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array
|
||||||
|
|
||||||
NDArray sum;
|
NDArray sum;
|
||||||
|
sum.setContext(block.launchContext());
|
||||||
if (weights->isScalar())
|
if (weights->isScalar())
|
||||||
sum = (*weights) * E.lengthOf();
|
sum = (*weights) * E.lengthOf();
|
||||||
else
|
else
|
||||||
|
|
|
@ -78,6 +78,7 @@ CUSTOM_OP_IMPL(sigm_cross_entropy_loss, 3, 1, false, 1, 1) {
|
||||||
}
|
}
|
||||||
case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array
|
case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array
|
||||||
NDArray sum;
|
NDArray sum;
|
||||||
|
sum.setContext(block.launchContext());
|
||||||
if (weights->isScalar())
|
if (weights->isScalar())
|
||||||
sum = (*weights) * E.lengthOf();
|
sum = (*weights) * E.lengthOf();
|
||||||
else
|
else
|
||||||
|
@ -219,6 +220,7 @@ CUSTOM_OP_IMPL(sigm_cross_entropy_loss_grad, 3, 3, false, 1, 1) {
|
||||||
}
|
}
|
||||||
case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array
|
case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array
|
||||||
NDArray sum;
|
NDArray sum;
|
||||||
|
sum.setContext(block.launchContext());
|
||||||
if (weights->isScalar())
|
if (weights->isScalar())
|
||||||
sum = (*weights) * E.lengthOf();
|
sum = (*weights) * E.lengthOf();
|
||||||
else
|
else
|
||||||
|
|
|
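
All of these loss hunks add the same fix (`sum.setContext(block.launchContext())`) to the "weighted_mean" branch described in the comments: the score is sum(E) / sum(weightsBroad), and for a scalar weight w the denominator collapses to w * lengthOf(E). A minimal numeric sketch of that reduction, independent of the NDArray types:

```
#include <numeric>
#include <vector>

// "weighted_mean": sum(E) / sum(weights), treating a single-element
// weights vector as a scalar broadcast over E.
double weightedMean(const std::vector<double>& E, const std::vector<double>& w) {
    const double num = std::accumulate(E.begin(), E.end(), 0.0);
    const double den = (w.size() == 1)
                     ? w[0] * static_cast<double>(E.size())   // scalar weight case
                     : std::accumulate(w.begin(), w.end(), 0.0);
    return num / den;
}
```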
@ -24,10 +24,10 @@
|
||||||
#include <ops/declarable/CustomOperations.h>
|
#include <ops/declarable/CustomOperations.h>
|
||||||
#include <ops/declarable/helpers/lstmLayer.h>
|
#include <ops/declarable/helpers/lstmLayer.h>
|
||||||
|
|
||||||
|
|
||||||
namespace sd {
|
namespace sd {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
CUSTOM_OP_IMPL(lstmLayer, 3, 1, false, 1, 5) {
|
CUSTOM_OP_IMPL(lstmLayer, 3, 1, false, 1, 5) {
|
||||||
|
|
||||||
|
@ -43,7 +43,7 @@ CUSTOM_OP_IMPL(lstmLayer, 3, 1, false, 1, 5) {
|
||||||
// it = σ(Wxi * xt + Wri * ht-1 + Wpi ◦ ct-1 + bi)
|
// it = σ(Wxi * xt + Wri * ht-1 + Wpi ◦ ct-1 + bi)
|
||||||
// ft = σ(Wxf * xt + Wrf * ht-1 + Wpf ◦ ct-1 + bf)
|
// ft = σ(Wxf * xt + Wrf * ht-1 + Wpf ◦ ct-1 + bf)
|
||||||
// c't = tanh(Wxc * xt + Wrc * ht-1 + bc)
|
// c't = tanh(Wxc * xt + Wrc * ht-1 + bc)
|
||||||
// ct = ft ◦ ct-1 + it ◦ c't
|
// ct = clip(ft ◦ ct-1 + it ◦ c't)
|
||||||
// ot = σ(Wxo * xt + Wro * ht-1 + Wpo ◦ ct + bo)
|
// ot = σ(Wxo * xt + Wro * ht-1 + Wpo ◦ ct + bo)
|
||||||
// ht = ot ◦ tanh(ct)
|
// ht = ot ◦ tanh(ct)
|
||||||
|
|
||||||
|
@ -72,26 +72,26 @@ CUSTOM_OP_IMPL(lstmLayer, 3, 1, false, 1, 5) {
|
||||||
// 2) [2, nOut, 4*nOut] when directionMode >= 2
|
// 2) [2, nOut, 4*nOut] when directionMode >= 2
|
||||||
|
|
||||||
// *******
|
// *******
|
||||||
// peephole weights Wp:
|
// peephole weights Wp, optional:
|
||||||
// 1) [3*nOut] when directionMode < 2
|
// 1) [3*nOut] when directionMode < 2
|
||||||
// 2) [2, 3*nOut] when directionMode >= 2
|
// 2) [2, 3*nOut] when directionMode >= 2
|
||||||
|
|
||||||
// *******
|
// *******
|
||||||
// biases b:
|
// biases b, optional:
|
||||||
// 1) [4*nOut] when directionMode < 2
|
// 1) [4*nOut] when directionMode < 2
|
||||||
// 2) [2, 4*nOut] when directionMode >= 2
|
// 2) [2, 4*nOut] when directionMode >= 2
|
||||||
|
|
||||||
// *******
|
// *******
|
||||||
// sequence length array seqLen:
|
// sequence length array seqLen, optional:
|
||||||
// 1) [bS] always
|
// 1) [bS]
|
||||||
|
|
||||||
// *******
|
// *******
|
||||||
// initial output hI:
|
// initial output hI, optional:
|
||||||
// 1) [bS, nOut] when directionMode < 2
|
// 1) [bS, nOut] when directionMode < 2
|
||||||
// 2) [2, bS, nOut] when directionMode >= 2
|
// 2) [2, bS, nOut] when directionMode >= 2
|
||||||
|
|
||||||
// *******
|
// *******
|
||||||
// initial cell state cI (same shape as in hI):
|
// initial cell state cI (same shape as in hI), optional:
|
||||||
// 1) [bS, nOut] when directionMode < 2
|
// 1) [bS, nOut] when directionMode < 2
|
||||||
// 2) [2, bS, nOut] when directionMode >= 2
|
// 2) [2, bS, nOut] when directionMode >= 2
|
||||||
|
|
||||||
|
@ -99,7 +99,7 @@ CUSTOM_OP_IMPL(lstmLayer, 3, 1, false, 1, 5) {
|
||||||
// OUTPUTS:
|
// OUTPUTS:
|
||||||
|
|
||||||
// *******
|
// *******
|
||||||
// output h:
|
// output h, optional:
|
||||||
// 1) [sL, bS, nOut] when directionMode <= 2 && dataFormat == 0
|
// 1) [sL, bS, nOut] when directionMode <= 2 && dataFormat == 0
|
||||||
// 2) [bS, sL, nOut] when directionMode <= 2 && dataFormat == 1
|
// 2) [bS, sL, nOut] when directionMode <= 2 && dataFormat == 1
|
||||||
// 3) [bS, nOut, sL] when directionMode <= 2 && dataFormat == 2
|
// 3) [bS, nOut, sL] when directionMode <= 2 && dataFormat == 2
|
||||||
|
@ -109,19 +109,19 @@ CUSTOM_OP_IMPL(lstmLayer, 3, 1, false, 1, 5) {
|
||||||
// 7) [sL, 2, bS, nOut] when directionMode == 4 && dataFormat == 3
|
// 7) [sL, 2, bS, nOut] when directionMode == 4 && dataFormat == 3
|
||||||
|
|
||||||
// *******
|
// *******
|
||||||
// output at last step hL:
|
// output at last step hL, optional:
|
||||||
// 1) [bS, nOut] when directionMode < 2
|
// 1) [bS, nOut] when directionMode < 2
|
||||||
// 2) [2, bS, nOut] when directionMode >= 2
|
// 2) [2, bS, nOut] when directionMode >= 2
|
||||||
|
|
||||||
// *******
|
// *******
|
||||||
// cell state at last step cL (same shape as in hL):
|
// cell state at last step cL (same shape as in hL), optional:
|
||||||
// 1) [bS, nOut] when directionMode < 2
|
// 1) [bS, nOut] when directionMode < 2
|
||||||
// 2) [2, bS, nOut] when directionMode >= 2
|
// 2) [2, bS, nOut] when directionMode >= 2
|
||||||
|
|
||||||
// !!! dimension 4*nOut implies order it, ft, c't, ot
|
// !!! dimension 4*nOut implies order it, ft, c't, ot
|
||||||
// !!! dimension 3*nOut implies order it, ft, ot
|
// !!! dimension 3*nOut implies order it, ft, ot
|
||||||
|
|
||||||
const auto dataFormat = INT_ARG(0); // for unidirectional: 0 = [sL, bS, nIn], 1 = [bS, sL, nIn], 2 = [bS, nIn, sL], for bidirectional: 3 = [sL, 2, bS, nOut] (for ONNX)
|
const auto dataFormat = INT_ARG(0); // for unidirectional: 0 = [sL, bS, nIn], 1 = [bS, sL, nIn], 2 = [bS, nIn, sL], for bidirectional: 3 = [sL, bS, nIn] && [sL, 2, bS, nOut] (for ONNX)
|
||||||
const auto directionMode = INT_ARG(1); // direction: 0 = fwd, 1 = bwd, 2 = bidirectional sum, 3 = bidirectional concat, 4 = bidirectional extra output dim (in conjunction with format dataFormat = 3)
|
const auto directionMode = INT_ARG(1); // direction: 0 = fwd, 1 = bwd, 2 = bidirectional sum, 3 = bidirectional concat, 4 = bidirectional extra output dim (in conjunction with format dataFormat = 3)
|
||||||
|
|
||||||
// integer numbers corresponding to activations: 0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky relu, 5= thresholded relu, 6=scaled tanh, 7=hard sigmoid, 8=ELU, 9=softsign, 10=softplus
|
// integer numbers corresponding to activations: 0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky relu, 5= thresholded relu, 6=scaled tanh, 7=hard sigmoid, 8=ELU, 9=softsign, 10=softplus
|
||||||
|
@ -135,8 +135,8 @@ CUSTOM_OP_IMPL(lstmLayer, 3, 1, false, 1, 5) {
|
||||||
const auto hasInitC = B_ARG(3); // indicates whether initial cell state is provided
|
const auto hasInitC = B_ARG(3); // indicates whether initial cell state is provided
|
||||||
const auto hasPH = B_ARG(4); // indicates whether peephole connections are present
|
const auto hasPH = B_ARG(4); // indicates whether peephole connections are present
|
||||||
const auto retFullSeq = B_ARG(5); // indicates whether to return whole time sequence h {h_0, h_1, ... , h_sL-1}
|
const auto retFullSeq = B_ARG(5); // indicates whether to return whole time sequence h {h_0, h_1, ... , h_sL-1}
|
||||||
const auto retLastH = B_ARG(6); // indicates whether to return output at last time step only, in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument)
|
const auto retLastH = B_ARG(6); // indicates whether to return output at last time step only
|
||||||
const auto retLastC = B_ARG(7); // indicates whether to return cells state at last time step only, in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument)
|
const auto retLastC = B_ARG(7); // indicates whether to return cells state at last time step only
|
||||||
|
|
||||||
const auto gateActHasAlpha = gateAct == 3 || gateAct == 4 || gateAct == 5 || gateAct == 6 || gateAct == 8;
|
const auto gateActHasAlpha = gateAct == 3 || gateAct == 4 || gateAct == 5 || gateAct == 6 || gateAct == 8;
|
||||||
const auto cellActHasAlpha = cellAct == 3 || cellAct == 4 || cellAct == 5 || cellAct == 6 || cellAct == 8;
|
const auto cellActHasAlpha = cellAct == 3 || cellAct == 4 || cellAct == 5 || cellAct == 6 || cellAct == 8;
|
||||||
|
@ -176,8 +176,8 @@ CUSTOM_OP_IMPL(lstmLayer, 3, 1, false, 1, 5) {
|
||||||
|
|
||||||
// evaluate dimensions
|
// evaluate dimensions
|
||||||
const Nd4jLong sL = dataFormat == 3 ? x->sizeAt(0) : x->sizeAt(dataFormat);
|
const Nd4jLong sL = dataFormat == 3 ? x->sizeAt(0) : x->sizeAt(dataFormat);
|
||||||
const Nd4jLong bS = dataFormat == 1 || dataFormat == 2 ? x->sizeAt(0) : x->sizeAt(-2);
|
const Nd4jLong bS = dataFormat == 1 || dataFormat == 2 ? x->sizeAt(0) : x->sizeAt(1);
|
||||||
const Nd4jLong nIn = dataFormat == 2 ? x->sizeAt(1) : x->sizeAt(-1);
|
const Nd4jLong nIn = dataFormat == 2 ? x->sizeAt(1) : x->sizeAt(2);
|
||||||
const Nd4jLong nOut = Wx->sizeAt(-1) / 4;
|
const Nd4jLong nOut = Wx->sizeAt(-1) / 4;
|
||||||
|
|
||||||
// inputs validations
|
// inputs validations
|
||||||
|
@@ -323,9 +323,9 @@ DECLARE_SHAPE_FN(lstmLayer) {
     const auto Wr = INPUT_VARIABLE(2);    // recurrent weights

     // evaluate dimensions
-    const Nd4jLong sL   = dataFormat == 0 || dataFormat == 3 ? x->sizeAt(0) : ( dataFormat == 1 ? x->sizeAt(1) : x->sizeAt(2) );
-    const Nd4jLong bS   = dataFormat == 1 || dataFormat == 2 ? x->sizeAt(0) : x->sizeAt(-2);
-    const Nd4jLong nIn  = dataFormat == 2 ? x->sizeAt(1) : x->sizeAt(-1);
+    const Nd4jLong sL   = dataFormat == 3 ? x->sizeAt(0) : x->sizeAt(dataFormat);
+    const Nd4jLong bS   = dataFormat == 1 || dataFormat == 2 ? x->sizeAt(0) : x->sizeAt(1);
+    const Nd4jLong nIn  = dataFormat == 2 ? x->sizeAt(1) : x->sizeAt(2);
     const Nd4jLong nOut = Wx->sizeAt(-1) / 4;

     DataType type;
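[Editor's note] The three unidirectional layouts all resolve to the same (sL, bS, nIn) triple, which is why the negative indices sizeAt(-2)/sizeAt(-1) could be replaced by the positive ones above. A standalone sketch of that logic (plain C++, not part of the patch):

#include <array>
#include <cassert>

// Mirrors the lstmLayer dimension evaluation: x is rank-3 with layout
// 0 -> [sL, bS, nIn], 1 -> [bS, sL, nIn], 2 -> [bS, nIn, sL], 3 -> [sL, bS, nIn].
struct Dims { long sL, bS, nIn; };

Dims evalDims(const std::array<long, 3>& xShape, int dataFormat) {
    const long sL  = dataFormat == 3 ? xShape[0] : xShape[dataFormat];
    const long bS  = (dataFormat == 1 || dataFormat == 2) ? xShape[0] : xShape[1];
    const long nIn = dataFormat == 2 ? xShape[1] : xShape[2];
    return {sL, bS, nIn};
}

int main() {
    // [bS, nIn, sL] with bS = 4, nIn = 8, sL = 16
    Dims d = evalDims({4, 8, 16}, 2);
    assert(d.sL == 16 && d.bS == 4 && d.nIn == 8);
    return 0;
}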
@@ -398,6 +398,412 @@ DECLARE_SHAPE_FN(lstmLayer) {
 }

+
+//////////////////////////////////////////////////////////////////////////
+CUSTOM_OP_IMPL(lstmLayer_bp, 4, 1, false, 1, 5) {
+
+    // equations (no peephole connections)
+    // it  = σ(Wxi * xt  +  Wri * ht-1  +  bi)
+    // ft  = σ(Wxf * xt  +  Wrf * ht-1  +  bf)
+    // c't = tanh(Wxc * xt  +  Wrc * ht-1  +  bc)
+    // ct  = ft ◦ ct-1 + it ◦ c't
+    // ot  = σ(Wxo * xt  +  Wro * ht-1  +  bo)
+    // ht  = ot ◦ tanh(ct)
+
+    // equations (peephole connections are present)
+    // it  = σ(Wxi * xt  +  Wri * ht-1  +  Wpi ◦ ct-1  +  bi)
+    // ft  = σ(Wxf * xt  +  Wrf * ht-1  +  Wpf ◦ ct-1  +  bf)
+    // c't = tanh(Wxc * xt  +  Wrc * ht-1  +  bc)
+    // ct  = clip(ft ◦ ct-1 + it ◦ c't)
+    // ot  = σ(Wxo * xt  +  Wro * ht-1  +  Wpo ◦ ct  +  bo)
+    // ht  = ot ◦ tanh(ct)
+
+    // notations:
+    // bS   - batch size
+    // sL   - sequence length, number of time steps
+    // nIn  - input size
+    // nOut - output size (hidden size)
+
+    // INPUTS:
+
+    // *******
+    // input x:
+    // 1) [sL, bS, nIn] when dataFormat == 0
+    // 2) [bS, sL, nIn] when dataFormat == 1
+    // 3) [bS, nIn, sL] when dataFormat == 2
+
+    // *******
+    // input weights Wx:
+    // 1) [nIn, 4*nOut]    when directionMode <  2
+    // 2) [2, nIn, 4*nOut] when directionMode >= 2
+
+    // *******
+    // recurrent weights Wr:
+    // 1) [nOut, 4*nOut]    when directionMode <  2
+    // 2) [2, nOut, 4*nOut] when directionMode >= 2
+
+    // *******
+    // peephole weights Wp, optional:
+    // 1) [3*nOut]    when directionMode <  2
+    // 2) [2, 3*nOut] when directionMode >= 2
+
+    // *******
+    // biases b, optional:
+    // 1) [4*nOut]    when directionMode <  2
+    // 2) [2, 4*nOut] when directionMode >= 2
+
+    // *******
+    // sequence length array seqLen, optional:
+    // 1) [bS]
+
+    // *******
+    // initial output hI, optional:
+    // 1) [bS, nOut]    when directionMode <  2
+    // 2) [2, bS, nOut] when directionMode >= 2
+
+    // *******
+    // initial cell state cI (same shape as in hI), optional:
+    // 1) [bS, nOut]    when directionMode <  2
+    // 2) [2, bS, nOut] when directionMode >= 2
+
+    // *******
+    // gradient vs. output dLdh, optional:
+    // 1) [sL, bS, nOut]    when directionMode <= 2 && dataFormat == 0
+    // 2) [bS, sL, nOut]    when directionMode <= 2 && dataFormat == 1
+    // 3) [bS, nOut, sL]    when directionMode <= 2 && dataFormat == 2
+    // 4) [sL, bS, 2*nOut]  when directionMode == 3 && dataFormat == 0
+    // 5) [bS, sL, 2*nOut]  when directionMode == 3 && dataFormat == 1
+    // 6) [bS, 2*nOut, sL]  when directionMode == 3 && dataFormat == 2
+    // 7) [sL, 2, bS, nOut] when directionMode == 4 && dataFormat == 3
+
+    // *******
+    // gradient vs. output at last time step dLdhL, optional:
+    // 1) [bS, nOut]    when directionMode <  2
+    // 2) [2, bS, nOut] when directionMode >= 2
+
+    // *******
+    // gradient vs. cell state at last time step dLdcL (same shape as in dLdhL), optional:
+    // 1) [bS, nOut]    when directionMode <  2
+    // 2) [2, bS, nOut] when directionMode >= 2
+
+    // OUTPUTS:
+
+    // *******
+    // gradient vs. input dLdx:
+    // 1) [sL, bS, nIn] when dataFormat == 0
+    // 2) [bS, sL, nIn] when dataFormat == 1
+    // 3) [bS, nIn, sL] when dataFormat == 2
+
+    // *******
+    // gradient vs. input weights dLdWx:
+    // 1) [nIn, 4*nOut]    when directionMode <  2
+    // 2) [2, nIn, 4*nOut] when directionMode >= 2
+
+    // *******
+    // gradient vs. recurrent weights dLdWr:
+    // 1) [nOut, 4*nOut]    when directionMode <  2
+    // 2) [2, nOut, 4*nOut] when directionMode >= 2
+
+    // *******
+    // gradient vs. peephole weights dLdWp, optional:
+    // 1) [3*nOut]    when directionMode <  2
+    // 2) [2, 3*nOut] when directionMode >= 2
+
+    // *******
+    // gradient vs. biases dLdb, optional:
+    // 1) [4*nOut]    when directionMode <  2
+    // 2) [2, 4*nOut] when directionMode >= 2
+
+    // *******
+    // gradient vs. sequence length array dLdsL, optional (do not calculate it!!!):
+    // 1) [bS] always
+
+    // *******
+    // gradient vs. initial output dLdhI, optional:
+    // 1) [bS, nOut]    when directionMode <  2
+    // 2) [2, bS, nOut] when directionMode >= 2
+
+    // *******
+    // gradient vs. initial cell state dLdcI (same shape as in dLdhI), optional:
+    // 1) [bS, nOut]    when directionMode <  2
+    // 2) [2, bS, nOut] when directionMode >= 2
+
+    // !!! dimension 4*nOut implies order it, ft, c't, ot
+    // !!! dimension 3*nOut implies order it, ft, ot
+
+    const auto dataFormat    = INT_ARG(0);    // for unidirectional: 0 = [sL, bS, nIn], 1 = [bS, sL, nIn], 2 = [bS, nIn, sL], for bidirectional: 3 = [sL, bS, nIn] && [sL, 2, bS, nOut] (for ONNX)
+    const auto directionMode = INT_ARG(1);    // direction: 0 = fwd, 1 = bwd, 2 = bidirectional sum, 3 = bidirectional concat, 4 = bidirectional extra output dim (in conjunction with format dataFormat = 3)
+
+    // integer numbers corresponding to activations: 0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky relu, 5=thresholded relu, 6=scaled tanh, 7=hard sigmoid, 8=ELU, 9=softsign, 10=softplus
+    const auto gateAct = INT_ARG(2);    // activation for input (i), forget (f) and output (o) gates
+    const auto cellAct = INT_ARG(3);    // activation for cell state (c)
+    const auto outAct  = INT_ARG(4);    // activation for output (h)
+
+    const auto hasBiases  = B_ARG(0);    // indicates whether biases array is provided
+    const auto hasSeqLen  = B_ARG(1);    // indicates whether seqLen array is provided
+    const auto hasInitH   = B_ARG(2);    // indicates whether initial output is provided
+    const auto hasInitC   = B_ARG(3);    // indicates whether initial cell state is provided
+    const auto hasPH      = B_ARG(4);    // indicates whether peephole connections are present
+    const auto retFullSeq = B_ARG(5);    // indicates whether gradient vs. outputs is given for whole time sequence dLdh {dLdh_0, dLdh_1, ... , dLdh_sL-1}
+    const auto retLastH   = B_ARG(6);    // indicates whether gradient vs. output at last time step (dLdhL) is given
+    const auto retLastC   = B_ARG(7);    // indicates whether gradient vs. cell state at last time step (dLdcL) is given
+
+    const auto gateActHasAlpha = gateAct == 3 || gateAct == 4 || gateAct == 5 || gateAct == 6 || gateAct == 8;
+    const auto cellActHasAlpha = cellAct == 3 || cellAct == 4 || cellAct == 5 || cellAct == 6 || cellAct == 8;
+    const auto outActHasAlpha  = outAct  == 3 || outAct  == 4 || outAct  == 5 || outAct  == 6 || outAct  == 8;
+    const auto gateActHasBeta  = gateAct == 3 || gateAct == 6;
+    const auto cellActHasBeta  = cellAct == 3 || cellAct == 6;
+    const auto outActHasBeta   = outAct  == 3 || outAct  == 6;
+
+    uint count = 1;
+    const auto cellClip  = T_ARG(0);                               // cell clipping value, if it = 0 then do not apply clipping
+    const auto gateAlpha = gateActHasAlpha ? T_ARG(count++) : 0;
+    const auto gateBeta  = gateActHasBeta  ? T_ARG(count++) : 0;
+    const auto cellAlpha = cellActHasAlpha ? T_ARG(count++) : 0;
+    const auto cellBeta  = cellActHasBeta  ? T_ARG(count++) : 0;
+    const auto outAlpha  = outActHasAlpha  ? T_ARG(count++) : 0;
+    const auto outBeta   = outActHasBeta   ? T_ARG(count++) : 0;
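[Editor's note] The t-argument layout above is positional: slot 0 is always the cell clip, and alpha/beta slots exist only for activations that need them, so the same index can mean different things for different activation choices. A standalone sketch (plain C++, not part of the patch; activation codes follow the comment above):

#include <cstdio>
#include <vector>

// Reproduces the T_ARG slot consumption above for gateAct = 3 (affine, needs
// alpha and beta) and cellAct = outAct = 0 (tanh, needs neither).
int main() {
    std::vector<double> tArgs = {1.0, 0.5, 0.2};   // {cellClip, gateAlpha, gateBeta}
    int gateAct = 3, cellAct = 0, outAct = 0;

    auto hasAlpha = [](int a) { return a == 3 || a == 4 || a == 5 || a == 6 || a == 8; };
    auto hasBeta  = [](int a) { return a == 3 || a == 6; };

    std::size_t count = 1;                                       // slot 0 is the cell clip
    double cellClip  = tArgs[0];
    double gateAlpha = hasAlpha(gateAct) ? tArgs[count++] : 0;
    double gateBeta  = hasBeta(gateAct)  ? tArgs[count++] : 0;
    double cellAlpha = hasAlpha(cellAct) ? tArgs[count++] : 0;   // stays 0, no slot consumed
    double outAlpha  = hasAlpha(outAct)  ? tArgs[count++] : 0;   // stays 0, no slot consumed
    std::printf("clip=%g gateAlpha=%g gateBeta=%g cellAlpha=%g outAlpha=%g\n",
                cellClip, gateAlpha, gateBeta, cellAlpha, outAlpha);
    return 0;
}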
+
+    REQUIRE_TRUE(dataFormat < 3 || (dataFormat == 3 && directionMode == 4), 0, "LSTM_LAYER_BP operation: if argument dataFormat = 3, then directionMode = 4, but got dataFormat = %i and directionMode = %i instead !", dataFormat, directionMode);
+    REQUIRE_TRUE(cellClip >= 0 , 0, "LSTM_LAYER_BP operation: cell clipping value should be nonnegative (>=0) !");
+    REQUIRE_TRUE(retFullSeq || retLastH || retLastC, 0, "LSTM_LAYER_BP operation: please specify at least one of three input gradient arrays: dLdh, dLdhL or dLdcL !");
+
+    const auto x  = INPUT_VARIABLE(0);    // input
+    const auto Wx = INPUT_VARIABLE(1);    // input weights
+    const auto Wr = INPUT_VARIABLE(2);    // recurrent weights
+
+    count = 3;
+    const auto b      = hasBiases  ? INPUT_VARIABLE(count++) : nullptr;    // biases
+    const auto seqLen = hasSeqLen  ? INPUT_VARIABLE(count++) : nullptr;    // seqLen vector
+    const auto hI     = hasInitH   ? INPUT_VARIABLE(count++) : nullptr;    // initial output
+    const auto cI     = hasInitC   ? INPUT_VARIABLE(count++) : nullptr;    // initial cell state
+    const auto Wp     = hasPH      ? INPUT_VARIABLE(count++) : nullptr;    // peephole weights
+    const auto dLdh   = retFullSeq ? INPUT_VARIABLE(count++) : nullptr;    // gradient vs. output
+    const auto dLdhL  = retLastH   ? INPUT_VARIABLE(count++) : nullptr;    // gradient vs. output at last time step
+    const auto dLdcL  = retLastC   ? INPUT_VARIABLE(count++) : nullptr;    // gradient vs. cell state at last time step
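[Editor's note] Because every optional array shifts the indices of the ones after it, the actual input slot of, say, dLdh depends on all the preceding boolean flags. A standalone sketch of the packing (plain C++, not part of the patch):

#include <cstdio>

// Computes the input-variable index of each optional lstmLayer_bp array,
// mirroring the count++ chain above. Inputs 0..2 are always x, Wx, Wr.
int main() {
    bool hasBiases = true, hasSeqLen = false, hasInitH = true,
         hasInitC = true, hasPH = false, retFullSeq = true;

    int count = 3;
    int bIdx      = hasBiases  ? count++ : -1;   // -1 means "not present"
    int seqLenIdx = hasSeqLen  ? count++ : -1;
    int hIIdx     = hasInitH   ? count++ : -1;
    int cIIdx     = hasInitC   ? count++ : -1;
    int WpIdx     = hasPH      ? count++ : -1;
    int dLdhIdx   = retFullSeq ? count++ : -1;

    // With these flags: b -> 3, hI -> 4, cI -> 5, dLdh -> 6; seqLen and Wp absent.
    std::printf("b=%d seqLen=%d hI=%d cI=%d Wp=%d dLdh=%d\n",
                bIdx, seqLenIdx, hIIdx, cIIdx, WpIdx, dLdhIdx);
    return 0;
}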
+
+    count = 3;
+    auto dLdx  = OUTPUT_VARIABLE(0);                                 // gradient vs. input
+    auto dLdWx = OUTPUT_NULLIFIED(1);                                // gradient vs. input weights
+    auto dLdWr = OUTPUT_NULLIFIED(2);                                // gradient vs. recurrent weights
+    auto dLdb  = hasBiases ? OUTPUT_NULLIFIED(count++) : nullptr;    // gradient vs. biases
+    auto dLdsL = hasSeqLen ? INPUT_VARIABLE(count++)   : nullptr;    // gradient vs. seqLen vector, we don't calculate it !!!
+    auto dLdhI = hasInitH  ? OUTPUT_NULLIFIED(count++) : nullptr;    // gradient vs. initial output
+    auto dLdcI = hasInitC  ? OUTPUT_NULLIFIED(count++) : nullptr;    // gradient vs. initial cell state
+    auto dLdWp = hasPH     ? OUTPUT_NULLIFIED(count)   : nullptr;    // gradient vs. peephole weights
+
+    // evaluate dimensions
+    const Nd4jLong sL   = dataFormat == 3 ? x->sizeAt(0) : x->sizeAt(dataFormat);
+    const Nd4jLong bS   = dataFormat == 1 || dataFormat == 2 ? x->sizeAt(0) : x->sizeAt(1);
+    const Nd4jLong nIn  = dataFormat == 2 ? x->sizeAt(1) : x->sizeAt(2);
+    const Nd4jLong nOut = Wx->sizeAt(-1) / 4;
+
+    // inputs validations
+    if(directionMode < 2) {    // no bidirectional
+
+        // Wx validation
+        if(Wx->rankOf() != 2 || Wx->sizeAt(0) != nIn)
+            REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of input weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nIn, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wx).c_str());
+        // Wr validation
+        if(Wr->rankOf() != 2 || Wr->sizeAt(0) != nOut || Wr->sizeAt(1) != 4*nOut)
+            REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of recurrent weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nOut, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wr).c_str());
+        // biases validation
+        if(b != nullptr && (b->rankOf() != 1 || b->sizeAt(0) != 4*nOut))
+            REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({4*nOut}).c_str(), ShapeUtils::shapeAsString(b).c_str());
+        // initial output validation
+        if(hI != nullptr && (hI->rankOf() != 2 || hI->sizeAt(0) != bS || hI->sizeAt(1) != nOut))
+            REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of initial output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS, nOut}).c_str(), ShapeUtils::shapeAsString(hI).c_str());
+        // initial cell validation
+        if(cI != nullptr && (cI->rankOf() != 2 || cI->sizeAt(0) != bS || cI->sizeAt(1) != nOut))
+            REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS, nOut}).c_str(), ShapeUtils::shapeAsString(cI).c_str());
+        // peephole weights validation
+        if(Wp != nullptr && (Wp->rankOf() != 1 || Wp->sizeAt(0) != 3*nOut))
+            REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong peephole weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({3*nOut}).c_str(), ShapeUtils::shapeAsString(Wp).c_str());
+        // gradient vs. output at last time step validation
+        if(dLdhL != nullptr && (dLdhL->rankOf() != 2 || dLdhL->sizeAt(0) != bS || dLdhL->sizeAt(1) != nOut))
+            REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of gradient vs. output at last time step, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS, nOut}).c_str(), ShapeUtils::shapeAsString(dLdhL).c_str());
+        // gradient vs. cell state at last time step validation
+        if(dLdcL != nullptr && (dLdcL->rankOf() != 2 || dLdcL->sizeAt(0) != bS || dLdcL->sizeAt(1) != nOut))
+            REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of gradient vs. cell state at last time, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS, nOut}).c_str(), ShapeUtils::shapeAsString(dLdcL).c_str());
+    }
+    else {    // bidirectional
+        // Wx validation
+        if(Wx->rankOf() != 3 || Wx->sizeAt(0) != 2 || Wx->sizeAt(1) != nIn)
+            REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of input weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, nIn, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wx).c_str());
+        // Wr validation
+        if(Wr->rankOf() != 3 || Wr->sizeAt(0) != 2 || Wr->sizeAt(1) != nOut || Wr->sizeAt(2) != 4*nOut)
+            REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of recurrent weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, nOut, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wr).c_str());
+        // biases validation
+        if(b != nullptr && (b->rankOf() != 2 || b->sizeAt(0) != 2 || b->sizeAt(1) != 4*nOut))
+            REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, 4*nOut}).c_str(), ShapeUtils::shapeAsString(b).c_str());
+        // initial output validation
+        if(hI != nullptr && (hI->rankOf() != 3 || hI->sizeAt(0) != 2 || hI->sizeAt(1) != bS || hI->sizeAt(2) != nOut))
+            REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of initial output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), ShapeUtils::shapeAsString(hI).c_str());
+        // initial cell validation
+        if(cI != nullptr && (cI->rankOf() != 3 || cI->sizeAt(0) != 2 || cI->sizeAt(1) != bS || cI->sizeAt(2) != nOut))
+            REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), ShapeUtils::shapeAsString(cI).c_str());
+        // peephole weights validation
+        if(Wp != nullptr && (Wp->rankOf() != 2 || Wp->sizeAt(0) != 2 || Wp->sizeAt(1) != 3*nOut))
+            REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong peephole weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, 3*nOut}).c_str(), ShapeUtils::shapeAsString(Wp).c_str());
+        // gradient vs. output at last time step validation
+        if(dLdhL != nullptr && (dLdhL->rankOf() != 3 || dLdhL->sizeAt(0) != 2 || dLdhL->sizeAt(1) != bS || dLdhL->sizeAt(2) != nOut))
+            REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of gradient vs. output at last time step, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), ShapeUtils::shapeAsString(dLdhL).c_str());
+        // gradient vs. cell state at last time step validation
+        if(dLdcL != nullptr && (dLdcL->rankOf() != 3 || dLdcL->sizeAt(0) != 2 || dLdcL->sizeAt(1) != bS || dLdcL->sizeAt(2) != nOut))
+            REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of gradient vs. cell state at last time, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), ShapeUtils::shapeAsString(dLdcL).c_str());
+    }
+
+    // gradient vs. output validation
+    if(dLdh) {
+        int factor = directionMode <= 2 ? 1 : 2;
+        std::vector<Nd4jLong> expdLdhShape;
+        if(dataFormat == 0)      expdLdhShape = std::vector<Nd4jLong>{sL, bS, factor*nOut};
+        else if(dataFormat == 1) expdLdhShape = std::vector<Nd4jLong>{bS, sL, factor*nOut};
+        else if(dataFormat == 2) expdLdhShape = std::vector<Nd4jLong>{bS, factor*nOut, sL};
+        else                     expdLdhShape = std::vector<Nd4jLong>{sL, 2, bS, nOut};
+        REQUIRE_TRUE(dLdh->isSameShape(expdLdhShape), 0, "LSTM_LAYER_BP operation: wrong shape of gradient vs. output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expdLdhShape).c_str(), ShapeUtils::shapeAsString(dLdh).c_str());
+    }
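[Editor's note] The same check, reduced to its shape arithmetic (standalone C++ sketch, not part of the patch): concat mode (directionMode 3) doubles the feature dimension, while the extra-dim mode (4, with dataFormat 3) inserts a separate direction axis instead.

#include <cassert>
#include <vector>

// Expected dLdh shape as a function of dataFormat and directionMode,
// mirroring the validation above.
std::vector<long> expectedDLdhShape(int dataFormat, int directionMode,
                                    long sL, long bS, long nOut) {
    const long factor = directionMode <= 2 ? 1 : 2;
    switch (dataFormat) {
        case 0:  return {sL, bS, factor * nOut};
        case 1:  return {bS, sL, factor * nOut};
        case 2:  return {bS, factor * nOut, sL};
        default: return {sL, 2, bS, nOut};   // dataFormat == 3, directionMode == 4
    }
}

int main() {
    // batch-major concat: [bS, sL, 2*nOut]
    assert((expectedDLdhShape(1, 3, 16, 4, 8) == std::vector<long>{4, 16, 16}));
    return 0;
}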
+
+    std::vector<float> params = {static_cast<float>(dataFormat), static_cast<float>(directionMode), static_cast<float>(cellClip),
+                                 static_cast<float>(gateAct), static_cast<float>(gateAlpha), static_cast<float>(gateBeta),
+                                 static_cast<float>(cellAct), static_cast<float>(cellAlpha), static_cast<float>(cellBeta),
+                                 static_cast<float>(outAct), static_cast<float>(outAlpha), static_cast<float>(outBeta)};
+
+    if(directionMode == 0) {    // forward
+        helpers::lstmLayerTimeLoopBp(x, Wx, Wr, b, seqLen, hI, cI, Wp, dLdh, dLdhL, dLdcL, params, true, dLdx, dLdWx, dLdWr, dLdb, dLdhI, dLdcI, dLdWp);
+    }
+    else if(directionMode == 1) {    // backward
+        helpers::lstmLayerTimeLoopBp(x, Wx, Wr, b, seqLen, hI, cI, Wp, dLdh, dLdhL, dLdcL, params, false, dLdx, dLdWx, dLdWr, dLdb, dLdhI, dLdcI, dLdWp);
+    }
+    else {    // bidirectional
+
+        NDArray WxFwd    = (*Wx)({0,1, 0,0, 0,0});
+        NDArray WxBwd    = (*Wx)({1,2, 0,0, 0,0});
+        NDArray dLdWxFwd = (*dLdWx)({0,1, 0,0, 0,0});
+        NDArray dLdWxBwd = (*dLdWx)({1,2, 0,0, 0,0});
+
+        NDArray WrFwd    = (*Wr)({0,1, 0,0, 0,0});
+        NDArray WrBwd    = (*Wr)({1,2, 0,0, 0,0});
+        NDArray dLdWrFwd = (*dLdWr)({0,1, 0,0, 0,0});
+        NDArray dLdWrBwd = (*dLdWr)({1,2, 0,0, 0,0});
+
+        NDArray *WpFwd(nullptr), *WpBwd(nullptr), *bFwd(nullptr), *bBwd(nullptr), *hIFwd(nullptr), *hIBwd(nullptr), *cIFwd(nullptr), *cIBwd(nullptr),
+                *dLdhFwd(nullptr), *dLdhBwd(nullptr), *dLdhLFwd(nullptr), *dLdhLBwd(nullptr), *dLdcLFwd(nullptr), *dLdcLBwd(nullptr),
+                *dLdWpFwd(nullptr), *dLdWpBwd(nullptr), *dLdbFwd(nullptr), *dLdbBwd(nullptr),
+                *dLdhIFwd(nullptr), *dLdhIBwd(nullptr), *dLdcIFwd(nullptr), *dLdcIBwd(nullptr);
+
+        if(Wp) {
+            WpFwd    = new NDArray((*Wp)({0,1, 0,0}));
+            WpBwd    = new NDArray((*Wp)({1,2, 0,0}));
+            dLdWpFwd = new NDArray((*dLdWp)({0,1, 0,0}));
+            dLdWpBwd = new NDArray((*dLdWp)({1,2, 0,0}));
+        }
+        if(b) {
+            bFwd    = new NDArray((*b)({0,1, 0,0}));
+            bBwd    = new NDArray((*b)({1,2, 0,0}));
+            dLdbFwd = new NDArray((*dLdb)({0,1, 0,0}));
+            dLdbBwd = new NDArray((*dLdb)({1,2, 0,0}));
+        }
+        if(hI) {
+            hIFwd    = new NDArray((*hI)({0,1, 0,0, 0,0}));
+            hIBwd    = new NDArray((*hI)({1,2, 0,0, 0,0}));
+            dLdhIFwd = new NDArray((*dLdhI)({0,1, 0,0, 0,0}));
+            dLdhIBwd = new NDArray((*dLdhI)({1,2, 0,0, 0,0}));
+        }
+        if(cI) {
+            cIFwd    = new NDArray((*cI)({0,1, 0,0, 0,0}));
+            cIBwd    = new NDArray((*cI)({1,2, 0,0, 0,0}));
+            dLdcIFwd = new NDArray((*dLdcI)({0,1, 0,0, 0,0}));
+            dLdcIBwd = new NDArray((*dLdcI)({1,2, 0,0, 0,0}));
+        }
+        if(dLdhL) {
+            dLdhLFwd = new NDArray((*dLdhL)({0,1, 0,0, 0,0}));
+            dLdhLBwd = new NDArray((*dLdhL)({1,2, 0,0, 0,0}));
+        }
+        if(dLdcL) {
+            dLdcLFwd = new NDArray((*dLdcL)({0,1, 0,0, 0,0}));
+            dLdcLBwd = new NDArray((*dLdcL)({1,2, 0,0, 0,0}));
+        }
+
+        // FIXME looks like sum (directionMode == 2) is impossible for backprop
+        if(dLdh) {
+            if(directionMode == 2) {    // sum
+                REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: mode for bidirectional sum and dLdh being present has no sense for backpropagation !");
+                // dLdhFwd = dLdh;
+                // dLdhBwd = new NDArray(dLdh->ordering(), dLdh->getShapeAsVector(), dLdh->dataType(), dLdh->getContext());    // automatically nullifies content
+            }
+            else if(directionMode == 3) {    // concat
+                dLdhFwd = new NDArray(dataFormat <= 1 ? (*dLdh)({0,0, 0,0, 0,nOut})      : (*dLdh)({0,0, 0,nOut, 0,0}));
+                dLdhBwd = new NDArray(dataFormat <= 1 ? (*dLdh)({0,0, 0,0, nOut,2*nOut}) : (*dLdh)({0,0, nOut,2*nOut, 0,0}));
+            }
+            else {    // directionMode == 4
+                dLdhFwd = new NDArray((*dLdh)({0,0, 0,1, 0,0, 0,0}));
+                dLdhBwd = new NDArray((*dLdh)({0,0, 1,2, 0,0, 0,0}));
+            }
+        }
+
+        helpers::lstmLayerTimeLoopBp(x, &WxFwd, &WrFwd, bFwd, seqLen, hIFwd, cIFwd, WpFwd, dLdhFwd, dLdhLFwd, dLdcLFwd, params, true,  dLdx,     &dLdWxFwd, &dLdWrFwd, dLdbFwd, dLdhIFwd, dLdcIFwd, dLdWpFwd);
+        NDArray dLdxBwd = dLdx->ulike();
+        helpers::lstmLayerTimeLoopBp(x, &WxBwd, &WrBwd, bBwd, seqLen, hIBwd, cIBwd, WpBwd, dLdhBwd, dLdhLBwd, dLdcLBwd, params, false, &dLdxBwd, &dLdWxBwd, &dLdWrBwd, dLdbBwd, dLdhIBwd, dLdcIBwd, dLdWpBwd);
+
+        *dLdx += dLdxBwd;
+
+        delete WpFwd; delete WpBwd; delete bFwd; delete bBwd; delete hIFwd; delete hIBwd; delete cIFwd; delete cIBwd;
+        delete dLdhBwd; delete dLdhLFwd; delete dLdhLBwd; delete dLdcLFwd; delete dLdcLBwd;
+        delete dLdWpFwd; delete dLdWpBwd; delete dLdbFwd; delete dLdbBwd;
+        delete dLdhIFwd; delete dLdhIBwd; delete dLdcIFwd; delete dLdcIBwd;
+
+        if(dLdhFwd != dLdh)
+            delete dLdhFwd;
+    }
+
+    return Status::OK();
+}
+
+DECLARE_TYPES(lstmLayer_bp) {
+    getOpDescriptor()
+        ->setAllowedInputTypes(sd::DataType::ANY)
+        ->setAllowedOutputTypes({ALL_FLOATS});
+}
+
+DECLARE_SHAPE_FN(lstmLayer_bp) {
+
+    const auto hasBiases = B_ARG(0);    // indicates whether biases array is provided
+    const auto hasSeqLen = B_ARG(1);    // indicates whether seqLen array is provided
+    const auto hasInitH  = B_ARG(2);    // indicates whether initial output is provided
+    const auto hasInitC  = B_ARG(3);    // indicates whether initial cell state is provided
+    const auto hasPH     = B_ARG(4);    // indicates whether peephole connections are present
+
+    int count = 3;
+    const auto x      = INPUT_VARIABLE(0);                               // input
+    const auto Wx     = INPUT_VARIABLE(1);                               // input weights
+    const auto Wr     = INPUT_VARIABLE(2);                               // recurrent weights
+    const auto b      = hasBiases ? INPUT_VARIABLE(count++) : nullptr;   // biases
+    const auto seqLen = hasSeqLen ? INPUT_VARIABLE(count++) : nullptr;   // seqLen vector
+    const auto hI     = hasInitH  ? INPUT_VARIABLE(count++) : nullptr;   // initial output
+    const auto cI     = hasInitC  ? INPUT_VARIABLE(count++) : nullptr;   // initial cell state
+    const auto Wp     = hasPH     ? INPUT_VARIABLE(count++) : nullptr;   // peephole weights
+
+    std::vector<Nd4jLong*> outShapes = {x->getShapeInfo(), Wx->getShapeInfo(), Wr->getShapeInfo()};
+
+    if(b != nullptr)
+        outShapes.push_back(b->getShapeInfo());
+    if(seqLen != nullptr)
+        outShapes.push_back(seqLen->getShapeInfo());
+    if(hI != nullptr)
+        outShapes.push_back(hI->getShapeInfo());
+    if(cI != nullptr)
+        outShapes.push_back(cI->getShapeInfo());
+    if(Wp != nullptr)
+        outShapes.push_back(Wp->getShapeInfo());
+
+    return new ShapeList(outShapes);
+}
+
 }
 }

@@ -0,0 +1,339 @@
+/*******************************************************************************
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+//
+// @author Yurii Shyrma (iuriish@yahoo.com)
+//
+
+#include <system/op_boilerplate.h>
+#if NOT_EXCLUDED(OP_lstmLayerCell)
+
+#include <ops/declarable/CustomOperations.h>
+#include <ops/declarable/helpers/lstmLayer.h>
+
+namespace sd {
+namespace ops {
+
+//////////////////////////////////////////////////////////////////////////
+CUSTOM_OP_IMPL(lstmLayerCell, 5, 2, false, 1, 3) {
+
+    // equations (no peephole connections)
+    // it  = σ(Wxi * xt  +  Wri * ht-1  +  bi)
+    // ft  = σ(Wxf * xt  +  Wrf * ht-1  +  bf)
+    // c't = tanh(Wxc * xt  +  Wrc * ht-1  +  bc)
+    // ct  = ft ◦ ct-1 + it ◦ c't
+    // ot  = σ(Wxo * xt  +  Wro * ht-1  +  bo)
+    // ht  = ot ◦ tanh(ct)
+
+    // equations (peephole connections are present)
+    // it  = σ(Wxi * xt  +  Wri * ht-1  +  Wpi ◦ ct-1  +  bi)
+    // ft  = σ(Wxf * xt  +  Wrf * ht-1  +  Wpf ◦ ct-1  +  bf)
+    // c't = tanh(Wxc * xt  +  Wrc * ht-1  +  bc)
+    // ct  = clip(ft ◦ ct-1 + it ◦ c't)
+    // ot  = σ(Wxo * xt  +  Wro * ht-1  +  Wpo ◦ ct  +  bo)
+    // ht  = ot ◦ tanh(ct)
+
+    // notations:
+    // bS   - batch size
+    // nIn  - input size
+    // nOut - output size (hidden size)
+
+    // INPUTS:
+    // input x:                          [bS, nIn] or [nIn]
+    // input weights Wx:                 [nIn, 4*nOut]
+    // recurrent weights Wr:             [nOut, 4*nOut]
+    // initial (previous) output hI:     [bS, nOut] or [nOut]
+    // initial (previous) cell state cI: [bS, nOut] or [nOut]
+    // biases b (optional):              [4*nOut]
+    // peephole weights Wp (optional):   [3*nOut]
+
+    // OUTPUTS:
+    // current output h:     [bS, nOut] or [nOut]
+    // current cell state c: [bS, nOut] or [nOut]
+
+    // !!! dimension 4*nOut implies order it, ft, c't, ot
+    // !!! dimension 3*nOut implies order it, ft, ot
+
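[Editor's note] As a cross-check of the equations above, here is a minimal single-example, no-peephole LSTM step in standalone C++ (sigmoid gates, tanh cell/output, i.e. gateAct=2, cellAct=outAct=0). It is an illustration only, not the helpers::lstmLayerCell implementation:

#include <cmath>
#include <vector>

// One LSTM step for a single example (bS == 1), no peepholes, no clipping.
// Wx/Wr pack their gate columns in the order i, f, c', o (the 4*nOut rule).
void lstmStep(const std::vector<double>& x,    // [nIn]
              const std::vector<double>& hI,   // [nOut]
              const std::vector<double>& cI,   // [nOut]
              const std::vector<std::vector<double>>& Wx,   // [nIn][4*nOut]
              const std::vector<std::vector<double>>& Wr,   // [nOut][4*nOut]
              const std::vector<double>& b,                 // [4*nOut]
              std::vector<double>& h, std::vector<double>& c) {
    const std::size_t nIn = x.size(), nOut = hI.size();
    auto sigmoid = [](double v) { return 1.0 / (1.0 + std::exp(-v)); };
    h.assign(nOut, 0.0);
    c.assign(nOut, 0.0);
    for (std::size_t j = 0; j < nOut; ++j) {
        double z[4] = {b[j], b[nOut + j], b[2*nOut + j], b[3*nOut + j]};   // i, f, c', o pre-activations
        for (int g = 0; g < 4; ++g) {
            for (std::size_t k = 0; k < nIn;  ++k) z[g] += Wx[k][g*nOut + j] * x[k];
            for (std::size_t k = 0; k < nOut; ++k) z[g] += Wr[k][g*nOut + j] * hI[k];
        }
        const double it = sigmoid(z[0]), ft = sigmoid(z[1]);
        const double cp = std::tanh(z[2]), ot = sigmoid(z[3]);
        c[j] = ft * cI[j] + it * cp;       // ct = ft ◦ ct-1 + it ◦ c't
        h[j] = ot * std::tanh(c[j]);       // ht = ot ◦ tanh(ct)
    }
}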
+    // integer numbers corresponding to activations: 0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky relu, 5=thresholded relu, 6=scaled tanh, 7=hard sigmoid, 8=ELU, 9=softsign, 10=softplus
+    const auto gateAct = INT_ARG(0);    // activation for input (i), forget (f) and output (o) gates
+    const auto cellAct = INT_ARG(1);    // activation for cell state (c)
+    const auto outAct  = INT_ARG(2);    // activation for output (h)
+
+    const auto hasBiases = B_ARG(0);    // indicates whether biases array is provided
+    const auto hasPH     = B_ARG(1);    // indicates whether peephole connections are present
+
+    const auto gateActHasAlpha = gateAct == 3 || gateAct == 4 || gateAct == 5 || gateAct == 6 || gateAct == 8;
+    const auto cellActHasAlpha = cellAct == 3 || cellAct == 4 || cellAct == 5 || cellAct == 6 || cellAct == 8;
+    const auto outActHasAlpha  = outAct  == 3 || outAct  == 4 || outAct  == 5 || outAct  == 6 || outAct  == 8;
+    const auto gateActHasBeta  = gateAct == 3 || gateAct == 6;
+    const auto cellActHasBeta  = cellAct == 3 || cellAct == 6;
+    const auto outActHasBeta   = outAct  == 3 || outAct  == 6;
+
+    uint count = 1;
+    const auto cellClip  = T_ARG(0);                               // cell clipping value, if it = 0 then do not apply clipping
+    const auto gateAlpha = gateActHasAlpha ? T_ARG(count++) : 0;
+    const auto gateBeta  = gateActHasBeta  ? T_ARG(count++) : 0;
+    const auto cellAlpha = cellActHasAlpha ? T_ARG(count++) : 0;
+    const auto cellBeta  = cellActHasBeta  ? T_ARG(count++) : 0;
+    const auto outAlpha  = outActHasAlpha  ? T_ARG(count++) : 0;
+    const auto outBeta   = outActHasBeta   ? T_ARG(count++) : 0;
+
+    count = 3;
+    const auto x  = INPUT_VARIABLE(0);                               // input
+    const auto Wx = INPUT_VARIABLE(1);                               // input weights
+    const auto Wr = INPUT_VARIABLE(2);                               // recurrent weights
+    const auto b  = hasBiases ? INPUT_VARIABLE(count++) : nullptr;   // biases
+    const auto hI = INPUT_VARIABLE(count++);                         // initial output
+    const auto cI = INPUT_VARIABLE(count++);                         // initial cell state
+    const auto Wp = hasPH ? INPUT_VARIABLE(count) : nullptr;         // peephole weights
+
+    REQUIRE_TRUE(cellClip >= 0 , 0, "LSTM_LAYER_CELL operation: cell clipping value should be nonnegative (>=0) !");
+
+    auto h = OUTPUT_VARIABLE(0);
+    auto c = OUTPUT_VARIABLE(1);
+
+    // evaluate dimensions
+    const Nd4jLong bS   = x->rankOf() == 1 ? 0 : x->sizeAt(0);
+    const Nd4jLong nIn  = x->sizeAt(-1);
+    const Nd4jLong nOut = Wx->sizeAt(-1) / 4;
+
+    // inputs validations
+    // Wx validation
+    if(Wx->rankOf() != 2 || Wx->sizeAt(0) != nIn)
+        REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL operation: wrong shape of input weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nIn, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wx).c_str());
+    // Wr validation
+    if(Wr->rankOf() != 2 || Wr->sizeAt(0) != nOut || Wr->sizeAt(1) != 4*nOut)
+        REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL operation: wrong shape of recurrent weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nOut, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wr).c_str());
+    // initial output/cell validation
+    std::vector<Nd4jLong> exphIcIShape = x->rankOf() == 1 ? std::vector<Nd4jLong>{nOut} : std::vector<Nd4jLong>{bS, nOut};
+    REQUIRE_TRUE(hI->isSameShape(exphIcIShape), 0, "LSTM_LAYER_CELL operation: wrong shape of initial output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(exphIcIShape).c_str(), ShapeUtils::shapeAsString(hI).c_str());
+    REQUIRE_TRUE(cI->isSameShape(exphIcIShape), 0, "LSTM_LAYER_CELL operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(exphIcIShape).c_str(), ShapeUtils::shapeAsString(cI).c_str());
+    // biases validation
+    if(b != nullptr && (b->rankOf() != 1 || b->sizeAt(0) != 4*nOut))
+        REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({4*nOut}).c_str(), ShapeUtils::shapeAsString(b).c_str());
+    // peephole weights validation
+    if(Wp != nullptr && (Wp->rankOf() != 1 || Wp->sizeAt(0) != 3*nOut))
+        REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL operation: wrong shape of peephole weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({3*nOut}).c_str(), ShapeUtils::shapeAsString(Wp).c_str());
+
+    std::vector<float> params = {static_cast<float>(0)/*ignore*/, static_cast<float>(0)/*ignore*/, static_cast<float>(cellClip),
+                                 static_cast<float>(gateAct), static_cast<float>(gateAlpha), static_cast<float>(gateBeta),
+                                 static_cast<float>(cellAct), static_cast<float>(cellAlpha), static_cast<float>(cellBeta),
+                                 static_cast<float>(outAct), static_cast<float>(outAlpha), static_cast<float>(outBeta)};
+
+    helpers::lstmLayerCell(x, Wx, Wr, b, hI, cI, Wp, params, h, c);
+
+    return Status::OK();
+}
+
+DECLARE_TYPES(lstmLayerCell) {
+    getOpDescriptor()
+        ->setAllowedInputTypes(sd::DataType::ANY)
+        ->setAllowedOutputTypes({ALL_FLOATS});
+}
+
+DECLARE_SHAPE_FN(lstmLayerCell) {
+
+    const auto hasBiases = B_ARG(0);    // indicates whether biases array is provided
+
+    uint count = hasBiases ? 4 : 3;
+    const auto hI = INPUT_VARIABLE(count++);    // initial output
+    const auto cI = INPUT_VARIABLE(count);      // initial cell state
+
+    return new ShapeList({hI->getShapeInfo(), cI->getShapeInfo()});
+}
+
+//////////////////////////////////////////////////////////////////////////
+CUSTOM_OP_IMPL(lstmLayerCellBp, 7, 5, false, 1, 3) {
+
+    // equations (no peephole connections)
+    // it  = σ(Wxi * xt  +  Wri * ht-1  +  bi)
+    // ft  = σ(Wxf * xt  +  Wrf * ht-1  +  bf)
+    // c't = tanh(Wxc * xt  +  Wrc * ht-1  +  bc)
+    // ct  = ft ◦ ct-1 + it ◦ c't
+    // ot  = σ(Wxo * xt  +  Wro * ht-1  +  bo)
+    // ht  = ot ◦ tanh(ct)
+
+    // equations (peephole connections are present)
+    // it  = σ(Wxi * xt  +  Wri * ht-1  +  Wpi ◦ ct-1  +  bi)
+    // ft  = σ(Wxf * xt  +  Wrf * ht-1  +  Wpf ◦ ct-1  +  bf)
+    // c't = tanh(Wxc * xt  +  Wrc * ht-1  +  bc)
+    // ct  = clip(ft ◦ ct-1 + it ◦ c't)
+    // ot  = σ(Wxo * xt  +  Wro * ht-1  +  Wpo ◦ ct  +  bo)
+    // ht  = ot ◦ tanh(ct)
+
+    // notations:
+    // bS   - batch size
+    // nIn  - input size
+    // nOut - output size (hidden size)
+
+    // INPUTS:
+    // input x:                          [bS, nIn] or [nIn]
+    // input weights Wx:                 [nIn, 4*nOut]
+    // recurrent weights Wr:             [nOut, 4*nOut]
+    // initial (previous) output hI:     [bS, nOut] or [nOut]
+    // initial (previous) cell state cI: [bS, nOut] or [nOut]
+    // gradient wrt output dLdh:         [bS, nOut] or [nOut]
+    // gradient wrt cell state dLdc:     [bS, nOut] or [nOut]
+    // peephole weights Wp (optional):   [3*nOut]
+    // biases b (optional):              [4*nOut]
+
+    // OUTPUTS:
+    // gradient wrt x dLdx:              [bS, nIn] or [nIn]
+    // gradient wrt Wx dLdWx:            [nIn, 4*nOut]
+    // gradient wrt Wr dLdWr:            [nOut, 4*nOut]
+    // gradient wrt hI dLdhI:            [bS, nOut] or [nOut]
+    // gradient wrt cI dLdcI:            [bS, nOut] or [nOut]
+    // gradient wrt b dLdb (optional):   [4*nOut]
+    // gradient wrt Wp dLdWp (optional): [3*nOut]
+
+    // !!! dimension 4*nOut implies order it, ft, c't, ot
+    // !!! dimension 3*nOut implies order it, ft, ot
+
+    // integer numbers corresponding to activations: 0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky relu, 5=thresholded relu, 6=scaled tanh, 7=hard sigmoid, 8=ELU, 9=softsign, 10=softplus
+    const auto gateAct = INT_ARG(0);    // activation for input (i), forget (f) and output (o) gates
+    const auto cellAct = INT_ARG(1);    // activation for cell state (c)
+    const auto outAct  = INT_ARG(2);    // activation for output (h)
+
+    const auto hasBiases = B_ARG(0);    // indicates whether biases array is provided
+    const auto hasPH     = B_ARG(1);    // indicates whether peephole connections are present
+
+    const auto gateActHasAlpha = gateAct == 3 || gateAct == 4 || gateAct == 5 || gateAct == 6 || gateAct == 8;
+    const auto cellActHasAlpha = cellAct == 3 || cellAct == 4 || cellAct == 5 || cellAct == 6 || cellAct == 8;
+    const auto outActHasAlpha  = outAct  == 3 || outAct  == 4 || outAct  == 5 || outAct  == 6 || outAct  == 8;
+    const auto gateActHasBeta  = gateAct == 3 || gateAct == 6;
+    const auto cellActHasBeta  = cellAct == 3 || cellAct == 6;
+    const auto outActHasBeta   = outAct  == 3 || outAct  == 6;
+
+    uint count = 1;
+    const auto cellClip  = T_ARG(0);                               // cell clipping value, if it = 0 then do not apply clipping
+    const auto gateAlpha = gateActHasAlpha ? T_ARG(count++) : 0;
+    const auto gateBeta  = gateActHasBeta  ? T_ARG(count++) : 0;
+    const auto cellAlpha = cellActHasAlpha ? T_ARG(count++) : 0;
+    const auto cellBeta  = cellActHasBeta  ? T_ARG(count++) : 0;
+    const auto outAlpha  = outActHasAlpha  ? T_ARG(count++) : 0;
+    const auto outBeta   = outActHasBeta   ? T_ARG(count++) : 0;
+
+    count = 3;
+    const auto x    = INPUT_VARIABLE(0);                               // input
+    const auto Wx   = INPUT_VARIABLE(1);                               // input weights
+    const auto Wr   = INPUT_VARIABLE(2);                               // recurrent weights
+    const auto b    = hasBiases ? INPUT_VARIABLE(count++) : nullptr;   // biases
+    const auto hI   = INPUT_VARIABLE(count++);                         // initial output
+    const auto cI   = INPUT_VARIABLE(count++);                         // initial cell state
+    const auto Wp   = hasPH ? INPUT_VARIABLE(count++) : nullptr;       // peephole weights
+    const auto dLdh = INPUT_VARIABLE(count);                           // gradient wrt output
+
+    REQUIRE_TRUE(cellClip >= 0 , 0, "LSTM_LAYER_CELL_BP operation: cell clipping value should be nonnegative (>=0) !");
+
+    count = 3;
+    auto dLdx  = OUTPUT_VARIABLE(0);
+    auto dLdWx = OUTPUT_VARIABLE(1);
+    auto dLdWr = OUTPUT_VARIABLE(2);
+    auto dLdb  = hasBiases ? OUTPUT_VARIABLE(count++) : nullptr;
+    auto dLdhI = OUTPUT_VARIABLE(count++);
+    auto dLdcI = OUTPUT_VARIABLE(count++);
+    auto dLdWp = hasPH ? OUTPUT_VARIABLE(count) : nullptr;
+
+    // evaluate dimensions
+    const Nd4jLong bS   = x->rankOf() == 1 ? 0 : x->sizeAt(0);
+    const Nd4jLong nIn  = x->sizeAt(-1);
+    const Nd4jLong nOut = Wx->sizeAt(-1) / 4;
+
+    // inputs validations
+    // Wx validation
+    if(Wx->rankOf() != 2 || Wx->sizeAt(0) != nIn)
+        REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL_BP operation: wrong shape of input weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nIn, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wx).c_str());
+    // Wr validation
+    if(Wr->rankOf() != 2 || Wr->sizeAt(0) != nOut || Wr->sizeAt(1) != 4*nOut)
+        REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL_BP operation: wrong shape of recurrent weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nOut, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wr).c_str());
+    // initial output/cell validation
+    std::vector<Nd4jLong> exphIcIShape = x->rankOf() == 1 ? std::vector<Nd4jLong>{nOut} : std::vector<Nd4jLong>{bS, nOut};
+    REQUIRE_TRUE(hI->isSameShape(exphIcIShape), 0, "LSTM_LAYER_CELL_BP operation: wrong shape of initial output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(exphIcIShape).c_str(), ShapeUtils::shapeAsString(hI).c_str());
+    REQUIRE_TRUE(cI->isSameShape(exphIcIShape), 0, "LSTM_LAYER_CELL_BP operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(exphIcIShape).c_str(), ShapeUtils::shapeAsString(cI).c_str());
+    REQUIRE_TRUE(dLdh->isSameShape(exphIcIShape), 0, "LSTM_LAYER_CELL_BP operation: wrong shape of dLdh gradient, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(exphIcIShape).c_str(), ShapeUtils::shapeAsString(dLdh).c_str());
+    // biases validation
+    if(b != nullptr && (b->rankOf() != 1 || b->sizeAt(0) != 4*nOut))
+        REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL_BP operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({4*nOut}).c_str(), ShapeUtils::shapeAsString(b).c_str());
+    if(dLdb != nullptr && (dLdb->rankOf() != 1 || dLdb->sizeAt(0) != 4*nOut))
+        REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL_BP operation: wrong shape of dLdb gradient, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({4*nOut}).c_str(), ShapeUtils::shapeAsString(dLdb).c_str());
+    // peephole weights validation
+    if(Wp != nullptr && (Wp->rankOf() != 1 || Wp->sizeAt(0) != 3*nOut))
+        REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL_BP operation: wrong shape of peephole weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({3*nOut}).c_str(), ShapeUtils::shapeAsString(Wp).c_str());
+    if(dLdWp != nullptr && (dLdWp->rankOf() != 1 || dLdWp->sizeAt(0) != 3*nOut))
+        REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL_BP operation: wrong shape of dLdWp gradient, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({3*nOut}).c_str(), ShapeUtils::shapeAsString(dLdWp).c_str());
+
+    std::vector<float> params = {static_cast<float>(0)/*ignore*/, static_cast<float>(0)/*ignore*/, static_cast<float>(cellClip),
+                                 static_cast<float>(gateAct), static_cast<float>(gateAlpha), static_cast<float>(gateBeta),
+                                 static_cast<float>(cellAct), static_cast<float>(cellAlpha), static_cast<float>(cellBeta),
+                                 static_cast<float>(outAct), static_cast<float>(outAlpha), static_cast<float>(outBeta)};
+
+    std::vector<Nd4jLong> zShape = x->rankOf() == 1 ? std::vector<Nd4jLong>({4*nOut}) : std::vector<Nd4jLong>({bS, 4*nOut});
+
+    NDArray z(x->ordering(), zShape, x->dataType(), block.launchContext());
+    NDArray a = z.ulike();
+    NDArray h = cI->ulike();
+    NDArray c = cI->ulike();
+
+    helpers::lstmLayerCell(x, Wx, Wr, b, hI, cI, Wp, params, &z, &a, &h, &c);
+
+    helpers::lstmLayerCellBp(x, Wx, Wr, b, hI, cI, Wp, dLdh, nullptr, &z, &a, &c, params, dLdx, dLdWx, dLdWr, dLdhI, dLdcI, dLdb, dLdWp);
+
+    return Status::OK();
+}
+
+DECLARE_TYPES(lstmLayerCellBp) {
+    getOpDescriptor()
+        ->setAllowedInputTypes(sd::DataType::ANY)
+        ->setAllowedOutputTypes({ALL_FLOATS});
+}
+
+DECLARE_SHAPE_FN(lstmLayerCellBp) {
+
+    const auto hasBiases = B_ARG(0);    // indicates whether biases array is provided
+    const auto hasPH     = B_ARG(1);    // indicates whether peephole connections are present
+
+    uint count = 3;
+    const auto x  = INPUT_VARIABLE(0);                               // input
+    const auto Wx = INPUT_VARIABLE(1);                               // input weights
+    const auto Wr = INPUT_VARIABLE(2);                               // recurrent weights
+    const auto b  = hasBiases ? INPUT_VARIABLE(count++) : nullptr;   // biases
+    const auto hI = INPUT_VARIABLE(count++);                         // initial output
+    const auto cI = INPUT_VARIABLE(count++);                         // initial cell state
+    const auto Wp = hasPH ? INPUT_VARIABLE(count) : nullptr;         // peephole weights
+
+    std::vector<Nd4jLong*> shapes = {x->getShapeInfo(), Wx->getShapeInfo(), Wr->getShapeInfo()};
+
+    if(b != nullptr)
+        shapes.push_back(b->getShapeInfo());
+
+    shapes.push_back(hI->getShapeInfo());
+    shapes.push_back(cI->getShapeInfo());
+
+    if(Wp != nullptr)
+        shapes.push_back(Wp->getShapeInfo());
+
+    return new ShapeList(shapes);
+}
+
+}
+}
+
+#endif
@@ -14,10 +14,11 @@
  * SPDX-License-Identifier: Apache-2.0
  ******************************************************************************/

 //
 // xw_plus_b op. Created by GS <george@skymind.io> 31.01.2018
-//
+// @author Oleg Semeniv <oleg.semeniv@gmail.com>
 //
+//

 #include <system/op_boilerplate.h>
 #if NOT_EXCLUDED(OP_xw_plus_b)
@@ -29,36 +30,115 @@
 namespace sd {
 namespace ops {
 CUSTOM_OP_IMPL(xw_plus_b, 3, 1, false, 0, 0) {
+
     auto x = INPUT_VARIABLE(0);
-    auto y = INPUT_VARIABLE(1);
     auto b = INPUT_VARIABLE(2);
     auto z = OUTPUT_VARIABLE(0);

-    REQUIRE_TRUE(x->rankOf() <= 2 && y->rankOf() <= 2 && z->rankOf() <= 2, 0, "xw_plus_b: Input and Output NDArrays should have rank less or equal to 2");
-    REQUIRE_TRUE(b->isVector() && b->lengthOf() == z->sizeAt(-1), 0, "xw_plus_b: Input vector should have proper dimension 1x%i. "
-                 "But %i != %i.", z->sizeAt(-1), b->lengthOf(), z->sizeAt(-1));
+    if (x->isEmpty() || INPUT_VARIABLE(1)->isEmpty() || b->isEmpty())
+        return Status::OK();
+
+    const bool bTranspose = (block.getIArguments()->size() > 0 ? INT_ARG(0) == 1 : false);
+
+    auto w = bTranspose ? new NDArray(INPUT_VARIABLE(1)->transpose()) : INPUT_VARIABLE(1);
+
+    REQUIRE_TRUE(x->rankOf() == 2, 0, "xw_plus_b: Input x array should have rank equal 2, but got instead %i!", x->rankOf());
+    REQUIRE_TRUE(w->rankOf() == 2, 0, "xw_plus_b: Input weights array should have rank equal 2, but got instead %i!", w->rankOf());
+    REQUIRE_TRUE(z->rankOf() == 2, 0, "xw_plus_b: Output array should have rank equal 2, but got instead %i!", z->rankOf());
+    REQUIRE_TRUE(1 == b->rankOf() && b->lengthOf() == z->sizeAt(-1), 0, "xw_plus_b: Input bias vector should be 1D and have proper dimension 1x%i."
+                 " But got rank %i, and got length %i instead %i.", z->sizeAt(-1), b->rankOf(), b->lengthOf(), z->sizeAt(-1));

     // multiply x to y
-    MmulHelper::mmul(x, y, z, 1.0, 0.0);
+    MmulHelper::mmul(x, w, z, 1.0, 0.0);

     // adding b vector
     z->addiRowVector(*b);

+    if (bTranspose)
+        delete w;
+
     return Status::OK();
 }

 DECLARE_SHAPE_FN(xw_plus_b) {
-    auto outputShape = ShapeUtils::matrixProductShape(inputShape->at(0), inputShape->at(1), false, false,
-                                                      ArrayOptions::dataType(inputShape->at(0)), block.getWorkspace());
+    auto weights = INPUT_VARIABLE(1);
+
+    const int nWeightsFormat = block.getIArguments()->size() > 0 ? INT_ARG(0) : 0;
+
+    auto weightsShape = (1 == nWeightsFormat) ? ShapeUtils::evalTranspShapeInfo(*weights, block.getWorkspace()) : inputShape->at(1);
+
+    auto outputShape = ShapeUtils::matrixProductShape(inputShape->at(0), weightsShape, false, false,
+                                                      ArrayOptions::dataType(inputShape->at(0)), block.getWorkspace());

     return SHAPELIST(CONSTANT(outputShape));
 }

 DECLARE_TYPES(xw_plus_b) {
     getOpDescriptor()
         ->setAllowedInputTypes(sd::DataType::ANY)
-        ->setAllowedOutputTypes({ALL_FLOATS});
+        ->setAllowedOutputTypes({ ALL_FLOATS });
 }
+
+CUSTOM_OP_IMPL(xw_plus_b_bp, 4, 3, false, 0, 0) {
+
+    auto x    = INPUT_VARIABLE(0);
+    auto b    = INPUT_VARIABLE(2);
+    auto dLdz = INPUT_VARIABLE(3);
+
+    auto dLdx = OUTPUT_VARIABLE(0);
+    auto dLdb = OUTPUT_VARIABLE(2);
+
+    if (x->isEmpty() || INPUT_VARIABLE(1)->isEmpty() || b->isEmpty() || dLdz->isEmpty())
+        return Status::OK();
+
+    const bool bTranspose = (block.getIArguments()->size() > 0 ? INT_ARG(0) == 1 : false);
+
+    auto w = bTranspose ? new NDArray(INPUT_VARIABLE(1)->transpose()) : INPUT_VARIABLE(1);
+
+    REQUIRE_TRUE(x->rankOf() == 2, 0, "xw_plus_b BP: Input x array should have rank equal 2, but got instead %i!", x->rankOf());
+    REQUIRE_TRUE(w->rankOf() == 2, 0, "xw_plus_b BP: Input weights array should have rank equal 2, but got instead %i!", w->rankOf());
+    REQUIRE_TRUE(dLdz->rankOf() == 2, 0, "xw_plus_b BP: Output array should have rank equal 2, but got instead %i!", dLdz->rankOf());
+    REQUIRE_TRUE(1 == b->rankOf() && b->lengthOf() == dLdz->sizeAt(-1), 0, "xw_plus_b BP: Input bias vector should be 1D and have proper dimension 1x%i."
+                 " But got rank %i, and got length %i instead %i.", dLdz->sizeAt(-1), b->rankOf(), b->lengthOf(), dLdz->sizeAt(-1));
+
+    auto dLdw = (bTranspose) ? new NDArray(OUTPUT_VARIABLE(1)->transpose()) : OUTPUT_VARIABLE(1);
+
+    // dLdb
+    dLdb->assign(dLdz->reduceAlongDimension(reduce::Sum, { 0 }));
+
+    matmul_bp mmul_bp;
+    mmul_bp.execute({ x, w, dLdz }, std::vector<NDArray*>{dLdx, dLdw}, {}, {}, {});
+
+    if (bTranspose) {
+        delete w;
+        delete dLdw;
+    }
+
+    return Status::OK();
+}
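[Editor's note] For reference, the gradients that the Sum reduction and matmul_bp compute above work out to dLdb = column-sum of dLdz, dLdx = dLdz·Wᵀ and dLdW = xᵀ·dLdz; a standalone sketch (plain C++, not part of the patch):

#include <vector>

using Matrix = std::vector<std::vector<double>>;

// Backprop of z = x*W + b: x is (bS x nIn), W is (nIn x nOut), dLdz is (bS x nOut).
void xwPlusBBp(const Matrix& x, const Matrix& W, const Matrix& dLdz,
               Matrix& dLdx, Matrix& dLdW, std::vector<double>& dLdb) {
    const std::size_t bS = x.size(), nIn = W.size(), nOut = W[0].size();
    dLdx.assign(bS, std::vector<double>(nIn, 0.0));
    dLdW.assign(nIn, std::vector<double>(nOut, 0.0));
    dLdb.assign(nOut, 0.0);
    for (std::size_t i = 0; i < bS; ++i)
        for (std::size_t j = 0; j < nOut; ++j) {
            dLdb[j] += dLdz[i][j];                   // column-sum over the batch
            for (std::size_t k = 0; k < nIn; ++k) {
                dLdx[i][k] += dLdz[i][j] * W[k][j];  // dLdx = dLdz * W^T
                dLdW[k][j] += x[i][k] * dLdz[i][j];  // dLdW = x^T * dLdz
            }
        }
}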
+
+DECLARE_SHAPE_FN(xw_plus_b_bp) {
+
+    Nd4jLong* xShapeInfo;
+    Nd4jLong* wShapeInfo;
+    Nd4jLong* bShapeInfo;
+
+    COPY_SHAPE(inputShape->at(0), xShapeInfo);
+    COPY_SHAPE(inputShape->at(1), wShapeInfo);
+    COPY_SHAPE(inputShape->at(2), bShapeInfo);
+
+    return SHAPELIST(CONSTANT(xShapeInfo), CONSTANT(wShapeInfo), CONSTANT(bShapeInfo));
+}
+
+DECLARE_TYPES(xw_plus_b_bp) {
+    getOpDescriptor()
+        ->setAllowedInputTypes(sd::DataType::ANY)
+        ->setAllowedOutputTypes({ ALL_FLOATS });
+}
 }
 }

 #endif
@@ -14,9 +14,9 @@
  * SPDX-License-Identifier: Apache-2.0
  ******************************************************************************/

 //
 // @author sgazeos@gmail.com
 //

 #include <ops/declarable/generic/helpers/BroadcastHelper.h>
 #include <ops/declarable/headers/parity_ops.h>

@@ -29,24 +29,24 @@ namespace sd {
     auto x = INPUT_VARIABLE(0);
     auto y = INPUT_VARIABLE(1);
     auto z = OUTPUT_VARIABLE(0);
-    auto z0 = NDArrayFactory::create<bool>(x->ordering(), x->getShapeAsVector());
+    auto z0 = NDArrayFactory::create<bool>(x->ordering(), x->getShapeAsVector(), block.launchContext());
     BROADCAST_CHECK_EMPTY(x, y, (&z0));

     auto tZ = BroadcastHelper::broadcastApply(BROADCAST_BOOL(GreaterThan), x, y, &z0);
     bitcast res;
-    auto status = res.execute({tZ}, {z}, {}, {DataType::UINT8}, {}, {}, false);
+    auto status = res.execute({ tZ }, { z }, {}, { DataType::UINT8 }, {}, {}, false);
     if (tZ != &z0) {
         delete tZ;
     }

     return status;
 }

 DECLARE_TYPES(compare_and_bitpack) {
     getOpDescriptor()
         ->setAllowedInputTypes(0, DataType::ANY)
         ->setAllowedInputTypes(1, DataType::ANY)
         ->setAllowedOutputTypes(0, DataType::UINT8);
 }

 DECLARE_SHAPE_FN(compare_and_bitpack) {
@@ -45,8 +45,8 @@ namespace sd {
         weights = INPUT_VARIABLE(2);
         REQUIRE_TRUE(weights->isSameShape(predictions), 0, "CONFUSION_MATRIX: Weights and predictions should have equal shape");
     }
-    auto output = OUTPUT_VARIABLE(0);
-    output->assign(0.);
+    auto output = OUTPUT_NULLIFIED(0);

     int minPrediction = predictions->reduceNumber(reduce::Min).e<int>(0);
     int minLabel = labels->reduceNumber(reduce::Min).e<int>(0);

@@ -64,11 +64,7 @@ namespace sd {
 DECLARE_SHAPE_FN(confusion_matrix) {
     auto labels = INPUT_VARIABLE(0);
     auto predictions = INPUT_VARIABLE(1);
-    auto dtype = block.dataType();
-    dtype = sd::DataType::INT64; // dtype - should be a param with int argument
-    if (block.numI() > 1)
-        dtype = (sd::DataType)INT_ARG(1);
+    auto dtype = block.numD() ? D_ARG(0) : sd::DataType::INT64;

     int numClasses = 0;

     if (block.getIArguments()->size() > 0) {
@@ -14,9 +14,9 @@
  * SPDX-License-Identifier: Apache-2.0
  ******************************************************************************/

 //
 // Created by george@skymind.io on 26.01.2018.
 //

 #include <system/op_boilerplate.h>
 #if NOT_EXCLUDED(OP_normalize_moments)

@@ -34,7 +34,7 @@ namespace sd {
     auto resVariances = OUTPUT_VARIABLE(1);

     // FIXME: double?
-    NDArray shift = NDArrayFactory::create<double>(0.);
+    NDArray shift = NDArrayFactory::create<double>(0., block.launchContext());

     if (block.getTArguments()->size() > 0) {
         shift.assign(T_ARG(0));

@@ -47,7 +47,7 @@ namespace sd {

     squareMeans.applyTransform(transform::Square, squareMeans, nullptr);
     variances->applyScalarArr(scalar::Divide, *counts, tempVariances);
     // tempVariances.printIndexedBuffer("varianced divided by count");
     tempVariances.applyPairwiseTransform(pairwise::Subtract, squareMeans, *resVariances);

     if (shift.e<double>(0) != 0) {
@ -75,8 +75,8 @@ namespace sd {
|
||||||
|
|
||||||
DECLARE_TYPES(normalize_moments) {
|
DECLARE_TYPES(normalize_moments) {
|
||||||
getOpDescriptor()
|
getOpDescriptor()
|
||||||
->setAllowedInputTypes(sd::DataType::ANY)
|
->setAllowedInputTypes(sd::DataType::ANY)
|
||||||
->setAllowedOutputTypes({ALL_FLOATS});
|
->setAllowedOutputTypes({ ALL_FLOATS });
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
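For reference, normalize_moments turns accumulated sums into mean and variance via the standard identity var = E[x^2] - (E[x])^2, which is exactly the Square/Divide/Subtract sequence in the hunks above. A minimal sketch of that identity (the op's optional shift argument is omitted):

// Sketch: mean = sum / n, variance = sumOfSquares / n - mean^2.
#include <cstdio>

void normalizeMoments(double sum, double sumOfSquares, double n,
                      double* mean, double* variance) {
    *mean = sum / n;
    *variance = sumOfSquares / n - (*mean) * (*mean);
}

int main() {
    double mean, var;
    normalizeMoments(10.0, 30.0, 4.0, &mean, &var); // e.g. samples {1,2,3,4}
    printf("mean=%.2f var=%.2f\n", mean, var);      // mean=2.50 var=1.25
}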
@@ -49,8 +49,8 @@ namespace sd {
             bool disposable = false;

             if (min == nullptr && max == nullptr && block.numT() >= 2) {
-                min = NDArrayFactory::create_(dtype);
-                max = NDArrayFactory::create_(dtype);
+                min = NDArrayFactory::create_(dtype, block.launchContext());
+                max = NDArrayFactory::create_(dtype, block.launchContext());
                 min->p(0, T_ARG(0));
                 max->p(0, T_ARG(1));
                 disposable = true;
@@ -35,111 +35,19 @@ CUSTOM_OP_IMPL(reshape, 1, 1, false, 0, -2) {
     auto z = OUTPUT_VARIABLE(0);

     //Special case: empty.reshape(<other empty shape>) -> return empty
     if (x->isEmpty()) {
         REQUIRE_TRUE(z->isEmpty(), 0, "Reshape: when input is empty, output must also be empty");
         return Status::OK(); //No op
-    }
-
-    if (block.width() == 1) {
-
-        auto arguments = block.getIArguments();
-        int argsSize = arguments->size();
-
-        int e = 1;
-        char order = (char) -(*arguments)[0];
-        if (order != 'c' && order != 'f') {
-            order = 'c'; //x->ordering();
-            e = 0;
-        }
-
-        REQUIRE_TRUE(argsSize - e >= 1, 0, "Reshape arguments should have at least 1 dimension");
-
-        std::vector<Nd4jLong> shapeNew;
-        int e2 = e;
-        for (; e < (int) arguments->size(); e++) {
-            if (arguments->at(e) == -1){
-                Nd4jLong shapeLength = 1;
-                for(; e2 < e; e2++){
-                    shapeLength *= arguments->at(e2);
-                }
-                for(e2 = e + 1; e2 < arguments->size(); e2++){
-                    shapeLength *= arguments->at(e2);
-                }
-                Nd4jLong realShape = x->lengthOf() / shapeLength;
-                shapeNew.push_back(realShape);
-            }
-            else{
-                shapeNew.push_back(arguments->at(e));
-            }
-        }
-
-        auto len = shape::prodLong(shapeNew.data(), shapeNew.size());
-        REQUIRE_TRUE(len == x->lengthOf(), 0, "Reshape: lengths before and after reshape should match, but got %i vs %i", x->lengthOf(), len);
-
-        if (Environment::getInstance()->isDebugAndVerbose()) {
-            nd4j_printv("Reshape: new shape", shapeNew);
-        }
-
-        auto xr = x->reshape(order, shapeNew);
-        z->assign(xr);
-        STORE_RESULT(*z);
-
-        return Status::OK();
-
-    } else if (block.width() == 2) {
-
-        auto s = INPUT_VARIABLE(1);
-
-        char order = 'c';
-        if (block.numI() > 0)
-            order = (char) -INT_ARG(0);
-
-        std::vector<Nd4jLong> shapeNew(s->lengthOf());
-
-        for (int e = 0; e < (int) s->lengthOf(); e++) {
-            auto dim = s->e<Nd4jLong >(e);
-            if (dim == -1){
-                Nd4jLong shapeLength = 1;
-                for(int e2 = 0; e2 < e; e2++){
-                    shapeLength *= s->e<Nd4jLong>(e2);
-                }
-                for(int e2 = e + 1; e2 < (int) s->lengthOf(); e2++){
-                    REQUIRE_TRUE(s->e<Nd4jLong>(e2) != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed.");
-                    shapeLength *= s->e<Nd4jLong>(e2);
-                }
-                Nd4jLong realShape = x->lengthOf() / shapeLength;
-                shapeNew[e] = realShape;
-            }
-            else{
-                shapeNew[e] = dim;
-            }
-        }
-
-        if (Environment::getInstance()->isDebugAndVerbose()) {
-            nd4j_printv("Reshape: new shape", shapeNew);
-        }
-
-        if (s->isScalar()) {
-            // just a scalar
-            z->assign(x);
-        } else {
-            // in some cases we might go away with simple memcpy call instead of assign call
-            if (x->ordering() == 'c' && z->ordering() == x->ordering() && shape::reshapeC(x->shapeInfo(), z->shapeInfo())) {
-                z->dataBuffer()->copyBufferFrom(*x->dataBuffer().get(), z->lengthOf() * DataTypeUtils::sizeOfElement(z->dataType()), 0, x->bufferOffset());
-            } else {
-                auto xr = x->reshape(order, shapeNew);
-                z->assign(xr);
-            }
-        }
-
-        return Status::OK();
-
     }

-    return ND4J_STATUS_BAD_INPUT;
+    REQUIRE_TRUE(x->lengthOf() == z->lengthOf(), 0, "Reshape: lengths before and after reshape should match, but got %i vs %i", x->lengthOf(), z->lengthOf());
+
+    if (Environment::getInstance()->isDebugAndVerbose())
+        nd4j_printv("Reshape: new shape", z->getShapeAsVector());
+
+    z->assign(x->reshape(z->ordering(), z->getShapeAsVector()));
+
+    return Status::OK();
 }
@@ -151,117 +59,111 @@ DECLARE_TYPES(reshape) {
     }

     DECLARE_SHAPE_FN(reshape) {
-        auto inp = inputShape->at(0);
-
-        // we can launch op using Int arguments
-        if (inputShape->size() == 1) {
-            REQUIRE_TRUE(block.numI() > 0, 0, "Reshape: new shape should be provided as NDArray or int arguments, but nothing was defined");
-            std::vector<int> *arguments = block.getIArguments();
-
-            int e = 1;
-            char order = (char) -(*arguments)[0];
-            if (order != 'c' && order != 'f') {
-                order = shape::order(inp);
-                e = 0;
+        const auto x = INPUT_VARIABLE(0);
+
+        std::vector<int> reshapeArgs;
+        std::vector<Nd4jLong> shapeNew;
+        char orderNew = 'c';
+
+        if (block.width() == 1) {
+            reshapeArgs = *block.getIArguments();
+            if(!reshapeArgs.empty()) {
+                orderNew = (char) -reshapeArgs[0];
+                if(orderNew == 'c' || orderNew == 'f')
+                    reshapeArgs.erase(reshapeArgs.begin()); // remove first element being order in this case
             }
-
-            std::vector<Nd4jLong> shapeNew;
-
-            int e2 = e;
-            for (; e < (int) arguments->size(); e++) {
-                if ((int) arguments->at(e) == -1){
-
-                    Nd4jLong shapeLength = 1;
-                    for(; e2 < e; e2 ++){
-                        shapeLength *= arguments->at(e2);
-                    }
-                    for(e2 = e + 1; e2 < arguments->size(); e2++){
-                        REQUIRE_TRUE(arguments->at(e2) != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed.");
-                        shapeLength *= arguments->at(e2);
-                    }
-
-                    if(shapeLength == 0){
-                        //Edge case for empty:
-                        shapeNew.push_back(0);
-                    } else {
-                        //Standard case
-                        Nd4jLong realShape = shape::length(inp) / shapeLength;
-                        shapeNew.push_back(realShape);
-                    }
-                }
-                else{
-                    shapeNew.push_back(arguments->at(e));
-                }
-            }
-
-            return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(ArrayOptions::dataType(inp), order, shapeNew)));
-        } else {
-            // or, with second input "as shape"
-            auto x = INPUT_VARIABLE(0);
-            auto y = INPUT_VARIABLE(1);
-
-            // special case here
-            if (y->isEmpty()) {
-                REQUIRE_TRUE(x->lengthOf() == 1, 0, "Reshape: new length doesn't match existing array");
-                return SHAPELIST(ConstantShapeHelper::getInstance()->scalarShapeInfo(ArrayOptions::dataType(inp)));
-            }
-            //Special case: empty.reshape(-1) -> return empty
-            if (x->isEmpty()) {
-                //REQUIRE_TRUE(y->lengthOf() == 1 && y->e<Nd4jLong>(0) == -1, 0, "Reshape: when input is empty, shape must be [-1]");
-                auto shapeOf = y->getBufferAsVector<Nd4jLong>();
-                Nd4jLong prod = 1;
-                bool hasNegs = false;
-                for (auto v:shapeOf) {
-                    if (v < 0) {
-                        hasNegs = true;
-                        v = 0;
-                    }
-
-                    prod *= v;
-                }
-
-                REQUIRE_TRUE(prod == 0, 0, "Reshape: in case of empty arrays reshape must return empty array as well");
-
-                // if there are -1s - we turn them into zeros
-                if (hasNegs) {
-                    for (int e = 0; e < shapeOf.size(); e++)
-                        if (shapeOf[e] < 0)
-                            shapeOf[e] = 0;
-                }
-
-                auto newShape = ShapeBuilders::createShapeInfo(ArrayOptions::dataType(inp), shape::order(inp), y->lengthOf(), shapeOf.data());
-                return SHAPELIST(CONSTANT(newShape));
-            }
-
-            std::vector<Nd4jLong> shapeNew(y->lengthOf());
-
-            for (int e = 0; e < (int) y->lengthOf(); e++) {
-                auto dim = y->e<Nd4jLong>(e);
-                if (dim == -1){
-                    Nd4jLong shapeLength = 1;
-                    for(int e2 = 0; e2 < e; e2++){
-                        shapeLength *= y->e<Nd4jLong>(e2);
-                    }
-                    for(int e2 = e + 1; e2 < (int)y->lengthOf(); e2++){
-                        REQUIRE_TRUE(y->e<Nd4jLong>(e2) != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed.");
-                        shapeLength *= y->e<Nd4jLong>(e2);
-                    }
-
-                    if(shapeLength == 0){
-                        //Edge case for empty:
-                        shapeNew[e] = 0;
-                    } else {
-                        Nd4jLong realShape = shape::length(inp) / shapeLength;
-                        shapeNew[e] = realShape;
-                    }
-                }else {
-                    shapeNew[e] = dim;
-                }
-            }
-
-            return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inp), 'c', shapeNew));
         }
+        else {
+            reshapeArgs = INPUT_VARIABLE(1)->getBufferAsVector<int>();
+            orderNew = block.numI() > 0 ? (char) -INT_ARG(0) : 'c';
+        }
+
+        REQUIRE_TRUE(!reshapeArgs.empty() || x->lengthOf() == 1, 0, "Reshape buffer should have at least 1 dimension !");
+
+        // Nd4jLong xLen = x->lengthOf();
+        // if(x->isEmpty()) {
+        //     xLen = 1;
+        //     for (uint i = 0; i < x->rankOf(); ++i)      // take into account possible empty shapes
+        //         if(x->sizeAt(i) != 0)
+        //             xLen *= x->sizeAt(i);
+        // }
+
+        // for (uint i = 0; i < reshapeArgs.size(); ++i) {
+
+        //     if (reshapeArgs[i] == -1) {
+
+        //         uint shapeLength = 1, numOfZeros = 0;
+
+        //         for(uint j = 0; j < i; ++j)
+        //             if(reshapeArgs[j] != 0)
+        //                 shapeLength *= reshapeArgs[j];
+        //             else
+        //                 ++numOfZeros;
+
+        //         for(uint j = i + 1; j < reshapeArgs.size(); ++j) {
+        //             REQUIRE_TRUE(reshapeArgs[j] != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed.");
+        //             if(reshapeArgs[j] != 0)
+        //                 shapeLength *= reshapeArgs[j];
+        //             else
+        //                 ++numOfZeros;
+        //         }
+
+        //         const auto dim = xLen / shapeLength;
+
+        //         if(x->isEmpty() && (1 == dim || 0 == numOfZeros))
+        //             shapeNew.push_back(0);
+        //         else
+        //             shapeNew.push_back(dim);
+        //     }
+        //     else
+        //         shapeNew.push_back(reshapeArgs[i]);
+        // }
+
+        Nd4jLong newShapeLen = 1;
+        int pos = -1;
+        bool newShapeEmpty = false;
+
+        for (int i = 0; i < reshapeArgs.size(); ++i) {
+
+            const int dim = reshapeArgs[i];
+
+            if (dim == -1) {
+                REQUIRE_TRUE(pos == -1, 0, "Reshape : Only one unknown dimension (-1) is allowed.");
+                pos = i;
+                shapeNew.push_back(1);
+            }
+            else if (dim == 0) {
+                shapeNew.push_back(0);
+                newShapeEmpty = true;
+            }
+            else {
+                shapeNew.push_back(dim);
+                newShapeLen *= dim;
+            }
+        }
+
+        if (pos != -1) {
+
+            Nd4jLong xLen = x->lengthOf();
+            if(x->isEmpty()) {
+                xLen = 1;
+                for (uint i = 0; i < x->rankOf(); ++i)      // take into account possible empty shapes
+                    if(x->sizeAt(i) > 0 || !newShapeEmpty)
+                        xLen *= x->sizeAt(i);
+            }
+
+            shapeNew[pos] = xLen / newShapeLen;
+        }
+
+        auto len = shape::prodLong(shapeNew.data(), shapeNew.size());
+        REQUIRE_TRUE(x->lengthOf() == len, 0, "Reshape: lengths before and after reshape should match, but got %i vs %i", x->lengthOf(), len);
+
+        return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(x->dataType(), orderNew, shapeNew));
     }

 }
 }
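The rewritten shape function above scans the requested dimensions exactly once: at most one -1 placeholder is allowed, 0 marks an empty dimension, and the placeholder is finally filled with the input length divided by the product of the known dims. A standalone sketch of that inference:

// Sketch of the -1 dimension inference used by the new DECLARE_SHAPE_FN(reshape).
#include <cassert>
#include <cstdio>
#include <vector>

std::vector<long long> inferShape(long long inputLength, std::vector<long long> dims) {
    long long known = 1;
    int pos = -1;
    for (size_t i = 0; i < dims.size(); ++i) {
        if (dims[i] == -1) {
            assert(pos == -1 && "only one unknown dimension (-1) is allowed");
            pos = static_cast<int>(i);
        } else if (dims[i] != 0) {   // 0 marks an empty dimension, excluded from the product
            known *= dims[i];
        }
    }
    if (pos != -1)
        dims[pos] = inputLength / known;
    return dims;
}

int main() {
    auto s = inferShape(24, {2, -1, 3});
    printf("%lld %lld %lld\n", s[0], s[1], s[2]); // 2 4 3
}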
@@ -14,9 +14,9 @@
  * SPDX-License-Identifier: Apache-2.0
  ******************************************************************************/

 //
 // @author George A. Shulinok <sgazeos@gmail.com>, created on 4/18/2019.
 //

 #include <system/op_boilerplate.h>
 #if NOT_EXCLUDED(OP_barnes_symmetrized)

@@ -25,20 +25,20 @@
 #include <ops/declarable/helpers/BarnesHutTsne.h>

 namespace sd {
 namespace ops {
     NDArray* rowCountsPtr = nullptr;

     CUSTOM_OP_IMPL(barnes_symmetrized, 3, 3, false, 0, -1) {
         auto rowP = INPUT_VARIABLE(0);
         auto colP = INPUT_VARIABLE(1);
         auto valP = INPUT_VARIABLE(2);
         auto N = rowP->lengthOf() - 1;
         auto outputRows = OUTPUT_VARIABLE(0);
         auto outputCols = OUTPUT_VARIABLE(1);
         auto outputVals = OUTPUT_VARIABLE(2);

         if (block.getIArguments()->size() > 0)
             N = INT_ARG(0);

         if (rowCountsPtr) {
             helpers::barnes_symmetrize(rowP, colP, valP, N, outputRows, outputCols, outputVals, rowCountsPtr);

@@ -46,33 +46,33 @@ namespace ops {
             return Status::OK();
         }
         return Status::THROW("barnes_symmetrized: Cannot loop due wrong input data.");
     }

     DECLARE_TYPES(barnes_symmetrized) {
         getOpDescriptor()
-                ->setAllowedInputTypes(0, {DataType::INT32})
-                ->setAllowedInputTypes(1, {DataType::INT32})
-                ->setAllowedInputTypes(2, {ALL_INTS, ALL_FLOATS})
-                ->setAllowedOutputTypes(1, {DataType::INT32})
-                ->setAllowedOutputTypes(1, {DataType::INT32})
-                ->setAllowedOutputTypes(2, {ALL_INTS, ALL_FLOATS})
+                ->setAllowedInputTypes(0, { DataType::INT32 })
+                ->setAllowedInputTypes(1, { DataType::INT32 })
+                ->setAllowedInputTypes(2, { ALL_INTS, ALL_FLOATS })
+                ->setAllowedOutputTypes(1, { DataType::INT32 })
+                ->setAllowedOutputTypes(1, { DataType::INT32 })
+                ->setAllowedOutputTypes(2, { ALL_INTS, ALL_FLOATS })
                 ->setSameMode(false);
     }

     DECLARE_SHAPE_FN(barnes_symmetrized) {
         auto valPShapeInfo = inputShape->at(2);
         Nd4jLong* outShapeInfo;
         auto rowP = INPUT_VARIABLE(0);
         auto colP = INPUT_VARIABLE(1);
         auto N = rowP->lengthOf() - 1;
         if (block.getIArguments()->size() > 0)
             N = INT_ARG(0);
         auto dataType = rowP->dataType(); //ArrayOptions::dataType(inputShape->at(0));
-        NDArray* rowCounts = NDArrayFactory::create_<int>('c', {N}); //rowP->dup();
+        NDArray* rowCounts = NDArrayFactory::create_<int>('c', { N }, block.launchContext()); //rowP->dup();
         //srowCounts->assign(0);
         Nd4jLong len = helpers::barnes_row_count(rowP, colP, N, *rowCounts);
         rowCounts->syncToHost();
         // rowCounts->printBuffer("Row Counts");
         if (len <= 0) throw std::runtime_error("barnes_symmetrized: Cannot allocate shape due non-positive len.");
         rowCountsPtr = rowCounts;
         //ALLOCATE(outShapeInfo, block.workspace(), shape::shapeInfoLength(2), Nd4jLong);

@@ -80,13 +80,13 @@ namespace ops {
         // outShapeInfo[2] = len;
         // ShapeUtils::updateStridesAndType(outShapeInfo, ArrayOptions::dataType(valPShapeInfo), 'c');
         //outShapeInfo = ShapeBuilders::createVectorShapeInfo(ArrayOptions::dataType(valPShapeInfo), len, block.workspace());
-        outShapeInfo = sd::ShapeBuilders::createShapeInfo(ArrayOptions::dataType(valPShapeInfo), 'c', {1, len}, block.getWorkspace());
-        auto outColsShapeInfo = sd::ShapeBuilders::createShapeInfo(dataType, 'c', {1, len}, block.getWorkspace());
-        auto outRowsShapeInfo = sd::ShapeBuilders::createShapeInfo(dataType, 'c', {1, N + 1}, block.getWorkspace());
+        outShapeInfo = sd::ShapeBuilders::createShapeInfo(ArrayOptions::dataType(valPShapeInfo), 'c', { 1, len }, block.getWorkspace());
+        auto outColsShapeInfo = sd::ShapeBuilders::createShapeInfo(dataType, 'c', { 1, len }, block.getWorkspace());
+        auto outRowsShapeInfo = sd::ShapeBuilders::createShapeInfo(dataType, 'c', { 1, N + 1 }, block.getWorkspace());
         return SHAPELIST(CONSTANT(outRowsShapeInfo), CONSTANT(outColsShapeInfo), CONSTANT(outShapeInfo));
     }

 }
 }

 #endif
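barnes_symmetrized implements the symmetrization step of Barnes-Hut t-SNE, P <- (P + P^T) / 2, over a sparse CSR-style triple (rowP, colP, valP); the row-count pass in the shape function exists only to size the sparse outputs. A dense sketch of the same math:

// Dense sketch of the symmetrization behind barnes_symmetrized: P <- (P + P^T) / 2.
// The real op does this on a sparse CSR triple, hence the row-count pass above.
#include <cstdio>
#include <vector>

void symmetrize(std::vector<std::vector<double>>& P) {
    const size_t n = P.size();
    for (size_t i = 0; i < n; ++i)
        for (size_t j = i + 1; j < n; ++j) {
            double avg = 0.5 * (P[i][j] + P[j][i]);
            P[i][j] = P[j][i] = avg;
        }
}

int main() {
    std::vector<std::vector<double>> P = {{0, 0.8}, {0.2, 0}};
    symmetrize(P);
    printf("%.2f %.2f\n", P[0][1], P[1][0]); // 0.50 0.50
}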
@@ -22,6 +22,7 @@
 #define LIBND4J_HEADERS_BROADCASTABLE_H

 #include <ops/declarable/BroadcastableOp.h>
+#include <ops/declarable/BroadcastableBoolOp.h>
 #include <ops/declarable/headers/common.h>
 #include <ops/declarable/generic/helpers/BroadcastHelper.h>

@@ -261,7 +262,7 @@ namespace sd {
  *
  */
 #if NOT_EXCLUDED(OP_equals)
-        DECLARE_BROADCASTABLE_OP(equals, 0, 0);
+        DECLARE_BROADCASTABLE_BOOL_OP(equals, 0, 0);
 #endif

 /**
@@ -269,7 +270,7 @@ namespace sd {
  * Math is: _x != _y ? (T) 1.0f : (T) 0.0f;
  */
 #if NOT_EXCLUDED(OP_not_equals)
-        DECLARE_BROADCASTABLE_OP(not_equals, 0, 0);
+        DECLARE_BROADCASTABLE_BOOL_OP(not_equals, 0, 0);
 #endif

 /**
@@ -277,7 +278,7 @@ namespace sd {
  * Math is: _x <= _y ? (T) 1.0f : (T) 0.0f;
  */
 #if NOT_EXCLUDED(OP_less_equal)
-        DECLARE_BROADCASTABLE_OP(less_equal, 0, 0);
+        DECLARE_BROADCASTABLE_BOOL_OP(less_equal, 0, 0);
 #endif

 /**
@@ -285,7 +286,7 @@ namespace sd {
  * Math is: _x >= _y ? (T) 1.0f : (T) 0.0f;
  */
 #if NOT_EXCLUDED(OP_greater_equal)
-        DECLARE_BROADCASTABLE_OP(greater_equal, 0, 0);
+        DECLARE_BROADCASTABLE_BOOL_OP(greater_equal, 0, 0);
 #endif

 /**
@@ -293,7 +294,7 @@ namespace sd {
  * Math is: _x < _y ? (T) 1.0f : (T) 0.0f;
  */
 #if NOT_EXCLUDED(OP_less)
-        DECLARE_BROADCASTABLE_OP(less, 0, 0);
+        DECLARE_BROADCASTABLE_BOOL_OP(less, 0, 0);
 #endif

 /**
@@ -301,7 +302,7 @@ namespace sd {
  * Math is: _x > _y ? (T) 1.0f : (T) 0.0f;
  */
 #if NOT_EXCLUDED(OP_greater)
-        DECLARE_BROADCASTABLE_OP(greater, 0, 0);
+        DECLARE_BROADCASTABLE_BOOL_OP(greater, 0, 0);
 #endif

 /**
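The comparison ops above move from DECLARE_BROADCASTABLE_OP to DECLARE_BROADCASTABLE_BOOL_OP, i.e. the elementwise result is materialized as a boolean array instead of inheriting the input type (matching the BROADCAST_BOOL usage in the compare_and_bitpack hunk earlier). A sketch of the semantics, reduced to a scalar right-hand side for brevity:

// Sketch of a broadcastable *bool* comparison: whatever the input dtype,
// the output is boolean.
#include <cstdio>
#include <vector>

std::vector<bool> greaterThan(const std::vector<double>& x, double y) {
    std::vector<bool> z(x.size());
    for (size_t i = 0; i < x.size(); ++i)
        z[i] = x[i] > y;   // output dtype is bool regardless of input dtype
    return z;
}

int main() {
    auto z = greaterThan({1.0, 3.0, 2.0}, 2.0);
    printf("%d %d %d\n", (int) z[0], (int) z[1], (int) z[2]); // 0 1 0
}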
@@ -867,9 +867,12 @@ namespace sd {
  * - 2D matrix MxN
  * - 1D vector with N elements
  * output value - 2D matrix NxN as multiply of matrixes and add vector
+ * Int args:
+ *   0 - optional switcher of weights format, if int arg == 1 - mkldnn, else mmul
  */
 #if NOT_EXCLUDED(OP_xw_plus_b)
         DECLARE_CUSTOM_OP(xw_plus_b, 3, 1, false, 0, 0);
+        DECLARE_CUSTOM_OP(xw_plus_b_bp, 4, 3, false, 0, 0);
 #endif

 /**
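xw_plus_b is the usual fully-connected forward step, out = x * W + b; per the new doc lines, the int arg only switches the weights layout (mkldnn vs. plain mmul), and xw_plus_b_bp adds the matching backprop op. A naive sketch of the forward math (the [M,K] x [K,N] dimension names are illustrative):

// Naive xw_plus_b forward: out = x * W + b with x [M,K], W [K,N], b [N].
#include <cstdio>
#include <vector>

using Matrix = std::vector<std::vector<double>>;

Matrix xwPlusB(const Matrix& x, const Matrix& w, const std::vector<double>& b) {
    const size_t M = x.size(), K = w.size(), N = b.size();
    Matrix out(M, std::vector<double>(N, 0.0));
    for (size_t i = 0; i < M; ++i)
        for (size_t j = 0; j < N; ++j) {
            for (size_t k = 0; k < K; ++k)
                out[i][j] += x[i][k] * w[k][j];
            out[i][j] += b[j];   // bias broadcast over rows
        }
    return out;
}

int main() {
    Matrix x = {{1, 2}}, w = {{1, 0}, {0, 1}};
    auto out = xwPlusB(x, w, {10, 20});
    printf("%.0f %.0f\n", out[0][0], out[0][1]); // 11 22
}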
@@ -149,6 +149,13 @@ namespace ops {
         DECLARE_CUSTOM_OP(lstmCell, 8, 2, false, 3, 2);
 #endif

+#if NOT_EXCLUDED(OP_lstmLayerCell)
+        DECLARE_CUSTOM_OP(lstmLayerCell, 5, 2, false, 1, 3);
+#endif
+#if NOT_EXCLUDED(OP_lstmLayerCell)
+        DECLARE_CUSTOM_OP(lstmLayerCellBp, 7, 5, false, 1, 3);
+#endif
+
 //////////////////////////////////////////////////////////////////////////
 /**
@@ -236,6 +243,11 @@ namespace ops {
         DECLARE_CUSTOM_OP(lstmLayer, 3, 1, false, 1, 5);
 #endif

+//////////////////////////////////////////////////////////////////////////
+#if NOT_EXCLUDED(OP_lstmLayer)
+        DECLARE_CUSTOM_OP(lstmLayer_bp, 4, 1, false, 1, 5);
+#endif
+
 //////////////////////////////////////////////////////////////////////////
 /**
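The lstmLayerCell/lstmLayerCellBp pair declared above follows the standard LSTM cell; the extra z and a outputs of the auxiliary forward overload (declared later in lstmLayer.h) are the pre- and post-activation gate buffers the backprop reuses. A scalar sketch of the standard cell math (the real kernels are vectorized and their activations are configurable via params):

// Scalar LSTM cell step: i/f/o are sigmoid gates, g is the tanh candidate;
//   c' = f*c + i*g,  h = o * tanh(c').
#include <cmath>
#include <cstdio>

struct LstmState { double h, c; };

double sigmoid(double v) { return 1.0 / (1.0 + std::exp(-v)); }

LstmState lstmCellStep(double x, double hPrev, double cPrev,
                       const double wx[4], const double wr[4], const double b[4]) {
    double i = sigmoid(wx[0] * x + wr[0] * hPrev + b[0]);   // input gate
    double f = sigmoid(wx[1] * x + wr[1] * hPrev + b[1]);   // forget gate
    double g = std::tanh(wx[2] * x + wr[2] * hPrev + b[2]); // candidate state
    double o = sigmoid(wx[3] * x + wr[3] * hPrev + b[3]);   // output gate
    double c = f * cPrev + i * g;
    return {o * std::tanh(c), c};
}

int main() {
    const double wx[4] = {1, 1, 1, 1}, wr[4] = {1, 1, 1, 1}, b[4] = {0, 0, 0, 0};
    LstmState s = lstmCellStep(1.0, 0.0, 0.0, wx, wr, b);
    printf("h=%.3f c=%.3f\n", s.h, s.c);
}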
@@ -142,7 +142,7 @@ namespace helpers {
     const int rowNum = input->rows();
     const int columnNum = input->columns();

-    NDArray determinant = NDArrayFactory::create<T>(1.f);
+    NDArray determinant = NDArrayFactory::create<T>(1.f, context);
     NDArray compoundMatrix = *input; // copy
     NDArray permutationMatrix(input, false, context); // has same shape as input and contiguous strides
     permutationMatrix.setIdentity();
@@ -39,7 +39,7 @@ namespace helpers {
     template <typename T>
     NDArray vmul(NDArray const& v, int n)
     {
-        NDArray res('c', {n,n}, v.dataType()); // x = matrix_new(n, n);
+        NDArray res('c', {n,n}, v.dataType(), v.getContext()); // x = matrix_new(n, n);
         T const* vBuf = v.getDataBuffer()->primaryAsT<T>();
         T* resBuf = res.dataBuffer()->primaryAsT<T>();
         auto interloop = PRAGMA_THREADS_FOR_2D {

@@ -61,7 +61,7 @@ namespace helpers {
         std::vector<NDArray> q(M);

         NDArray z = *matrix;
-        NDArray e('c', {M}, DataTypeUtils::fromT<T>()); // two internal buffers and scalar for squared norm
+        NDArray e('c', {M}, DataTypeUtils::fromT<T>(), Q->getContext()); // two internal buffers and scalar for squared norm

         for (Nd4jLong k = 0; k < N && k < M - 1; k++) { // loop for columns, but not further then row number
             e.nullify();
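The QR helper follows the classical Householder construction; judging by its "x = matrix_new(n, n)" comment, vmul appears to build the reflector H = I - 2 * v * v^T for a unit vector v (an inference, not confirmed by this hunk alone). A standalone sketch of that reflector:

// Sketch of a Householder reflector: H = I - 2 * v * v^T for unit-norm v.
// H is orthogonal and reflects across the hyperplane orthogonal to v.
#include <cstdio>
#include <vector>

std::vector<std::vector<double>> householder(const std::vector<double>& v) {
    const size_t n = v.size();
    std::vector<std::vector<double>> h(n, std::vector<double>(n, 0.0));
    for (size_t i = 0; i < n; ++i)
        for (size_t j = 0; j < n; ++j)
            h[i][j] = (i == j ? 1.0 : 0.0) - 2.0 * v[i] * v[j];
    return h;
}

int main() {
    auto h = householder({1.0, 0.0}); // reflects the x-axis
    printf("%.0f %.0f %.0f %.0f\n", h[0][0], h[0][1], h[1][0], h[1][1]); // -1 0 0 1
}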
@@ -69,9 +69,9 @@ namespace helpers {
     auto trial = (*input)(e, dimsToExclude);

     // fill up the first k elements
-    NDArray topValues = NDArrayFactory::create<T>('c', {k});
-    NDArray sortedVals = NDArrayFactory::create<T>('c', {k});
-    NDArray topIndices = NDArrayFactory::create<Nd4jLong>('c', {k});
+    NDArray topValues = NDArrayFactory::create<T>('c', {k}, input->getContext());
+    NDArray sortedVals = NDArrayFactory::create<T>('c', {k}, input->getContext());
+    NDArray topIndices = NDArrayFactory::create<Nd4jLong>('c', {k}, input->getContext());
     for (uint pos = 0; pos < k; ++pos) {
         topIndices.t<Nd4jLong>(pos) = pos;
         topValues.t<T>(pos) = trial.t<T>(pos);

@@ -144,7 +144,7 @@ namespace helpers {
     for (int i = 0; i < input->rankOf() - 1; i++)
         shapeI[i] = input->sizeAt(i);
     shapeI[input->rankOf() - 1] = k;
-    std::unique_ptr<NDArray> indices(NDArrayFactory::create_<Nd4jLong>(input->ordering(), shapeI));
+    std::unique_ptr<NDArray> indices(NDArrayFactory::create_<Nd4jLong>(input->ordering(), shapeI, context));
     NDArray* values = nullptr;
     int status = topKFunctor(context, input, values, indices.get(), k, true);
     result->assign(0);
@@ -112,7 +112,7 @@ namespace sd {
     int numThreads = 256;
     int numBlocks = sd::math::nd4j_max<int>(256, sd::math::nd4j_min<int>(1, shape::length(xShapeInfo) / numThreads));
     int workspaceSize = numBlocks * numBins;
-    auto tmp = NDArrayFactory::create<Z>('c', {workspaceSize});
+    auto tmp = NDArrayFactory::create<Z>('c', {workspaceSize}, context);

     histogramKernel<X, Z><<<numBlocks, numThreads, 32768, *context->getCudaStream()>>>(xBuffer, dxShapeInfo, zBuffer, zShapeInfo, tmp.getSpecialBuffer(), context->getReductionPointer(), numBins, reinterpret_cast<X*>(min_val), reinterpret_cast<X*>(max_val));
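The kernel above keeps one partial histogram per block (hence the numBlocks * numBins workspace) and reduces them afterwards; the binning rule itself is the usual linear mapping of [min, max) onto numBins buckets. Sketch:

// Sketch of the binning rule: map [lo, hi) linearly onto numBins buckets,
// clamping the edges. Per-block partial histograms would each run this loop
// on a slice of the input and be summed at the end.
#include <cstdio>
#include <vector>

std::vector<int> histogram(const std::vector<double>& x, double lo, double hi, int numBins) {
    std::vector<int> bins(numBins, 0);
    for (double v : x) {
        int b = static_cast<int>((v - lo) / (hi - lo) * numBins);
        if (b < 0) b = 0;
        if (b >= numBins) b = numBins - 1;   // clamp the right edge
        ++bins[b];
    }
    return bins;
}

int main() {
    auto bins = histogram({0.1, 0.4, 0.6, 0.9}, 0.0, 1.0, 2);
    printf("%d %d\n", bins[0], bins[1]); // 2 2
}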
@@ -25,7 +25,7 @@ namespace ops {
 namespace helpers {

     typedef NDArray ColorTable_t;
-    static NDArray DefaultColorTable(int depth) {
+    static NDArray DefaultColorTable(int depth, sd::LaunchContext* context) {
         //std::vector<std::vector<float>> colorTable;
         const Nd4jLong kDefaultTableLength = 10;
         const Nd4jLong kDefaultChannelLength = 4;

@@ -40,7 +40,7 @@ namespace helpers {
             0, 0, 0.5, 1, // 7: navy blue
             0, 1, 1, 1, // 8: aqua
             1, 0, 1, 1 // 9: fuchsia
-        }, DataType::FLOAT32);
+        }, DataType::FLOAT32, context);

         if (depth == 1) {
             colorTable.assign(1.f); // all to white when black and white colors

@@ -144,7 +144,7 @@ namespace helpers {
         auto channels = images->sizeAt(3);
         auto stream = context->getCudaStream();
         auto boxSize = boxes->sizeAt(1);
-        NDArray colorsTable = DefaultColorTable(channels);
+        NDArray colorsTable = DefaultColorTable(channels, context);
         if ((colors != nullptr && colors->lengthOf() > 0)) {
             colorsTable = *colors;
         }
@@ -188,7 +188,7 @@ namespace helpers {
     static void nonMaxSuppressionV2_(sd::LaunchContext* context, NDArray* boxes, NDArray* scales, int maxSize, double threshold, double scoreThreshold, NDArray* output) {
         auto stream = context->getCudaStream();
         NDArray::prepareSpecialUse({output}, {boxes, scales});
-        std::unique_ptr<NDArray> indices(NDArrayFactory::create_<I>('c', {scales->lengthOf()})); // - 1, scales->lengthOf()); //, scales->getContext());
+        std::unique_ptr<NDArray> indices(NDArrayFactory::create_<I>('c', {scales->lengthOf()}, context)); // - 1, scales->lengthOf()); //, scales->getContext());

         NDArray scores(*scales);
         Nd4jPointer extras[2] = {nullptr, stream};

@@ -198,7 +198,7 @@ namespace helpers {
         indices->tickWriteDevice();
         sortByValue(extras, indices->buffer(), indices->shapeInfo(), indices->specialBuffer(), indices->specialShapeInfo(), scores.buffer(), scores.shapeInfo(), scores.specialBuffer(), scores.specialShapeInfo(), true);
         indices->tickWriteDevice();
-        NDArray selectedIndices = NDArrayFactory::create<I>('c', {output->lengthOf()});
+        NDArray selectedIndices = NDArrayFactory::create<I>('c', {output->lengthOf()}, context);
         int numSelected = 0;
         int numBoxes = boxes->sizeAt(0);
         auto boxesBuf = reinterpret_cast<T*>(boxes->specialBuffer());

@@ -347,8 +347,8 @@ namespace helpers {
         scores->syncToDevice();
     }

-    NDArray indices = NDArrayFactory::create<I>('c', {scores->lengthOf()}); // - 1, scales->lengthOf()); //, scales->getContext());
-    NDArray startPositions = NDArrayFactory::create<I>('c', {scores->lengthOf()});
+    NDArray indices = NDArrayFactory::create<I>('c', {scores->lengthOf()}, context); // - 1, scales->lengthOf()); //, scales->getContext());
+    NDArray startPositions = NDArrayFactory::create<I>('c', {scores->lengthOf()}, context);
     NDArray selectedScores(*scores);
     Nd4jPointer extras[2] = {nullptr, stream};
     auto indexBuf = indices.dataBuffer()->specialAsT<I>();///reinterpret_cast<I*>(indices->specialBuffer());
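nonMaxSuppressionV2_ implements the usual greedy scheme the sortByValue call sets up: visit boxes in descending score order and keep one only if its IoU with every already-kept box stays at or below the threshold. A 1-D interval sketch of that loop (real boxes have four coordinates; intervals keep the IoU one-dimensional):

// Greedy non-max suppression over intervals [lo, hi) standing in for boxes.
#include <algorithm>
#include <cstdio>
#include <numeric>
#include <vector>

struct Span { double lo, hi; };

double iou(Span a, Span b) {
    double inter = std::max(0.0, std::min(a.hi, b.hi) - std::max(a.lo, b.lo));
    double uni = (a.hi - a.lo) + (b.hi - b.lo) - inter;
    return uni > 0 ? inter / uni : 0.0;
}

std::vector<int> nms(const std::vector<Span>& boxes, const std::vector<double>& scores,
                     double threshold) {
    std::vector<int> order(boxes.size());
    std::iota(order.begin(), order.end(), 0);
    // visit candidates in descending score order, like the sortByValue pass
    std::sort(order.begin(), order.end(), [&](int a, int b) { return scores[a] > scores[b]; });
    std::vector<int> kept;
    for (int i : order) {
        bool ok = true;
        for (int j : kept)
            if (iou(boxes[i], boxes[j]) > threshold) { ok = false; break; }
        if (ok) kept.push_back(i);
    }
    return kept;
}

int main() {
    std::vector<Span> boxes = {{0, 10}, {1, 11}, {20, 30}};
    auto kept = nms(boxes, {0.9, 0.8, 0.7}, 0.5);
    for (int k : kept) printf("%d ", k); // 0 2
    printf("\n");
}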
@@ -598,7 +598,7 @@ namespace helpers {
     static void lu_(LaunchContext * context, NDArray* input, NDArray* output, NDArray* permutationVectors) {
         auto n = input->sizeAt(-1);
         auto stream = context->getCudaStream();
-        NDArray iota('c', {n}, permutationVectors->dataType());// = NDArrayFactory::create(); // <int>('c', {n});
+        NDArray iota('c', {n}, permutationVectors->dataType(), context);// = NDArrayFactory::create(); // <int>('c', {n});
         iota.linspace(0); iota.syncToDevice();

         output->assign(input); // fill up output tensor with zeros

@@ -631,7 +631,7 @@ namespace helpers {
         // if (dtype != DataType::DOUBLE)
         //     dtype = DataType::FLOAT32;
         auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, DataTypeUtils::fromT<T>(), context); //, block.getWorkspace());
-        auto det = NDArrayFactory::create<T>(1);
+        auto det = NDArrayFactory::create<T>(1, context);
         auto stream = context->getCudaStream();
         NDArray::prepareSpecialUse({output}, {input});
         dim3 launchDims(256, 256, 1024);

@@ -677,7 +677,7 @@ namespace helpers {
             dtype = DataType::FLOAT32;

         auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, dtype, context); //, block.getWorkspace());
-        auto det = NDArrayFactory::create<T>(1);
+        auto det = NDArrayFactory::create<T>(1, context);
         auto stream = context->getCudaStream();
         NDArray::prepareSpecialUse({output}, {input});
         dim3 launchDims(256, 256, 1024);
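Both determinant paths above stage the input into an n x n scratch matrix, run the LU kernels, and fold the result; for a factorization P*A = L*U the determinant is (-1)^swaps times the product of U's diagonal. A minimal sketch of the fold (the factored diagonal and swap count are assumed to be already computed):

// det(A) = (-1)^rowSwaps * product of U's diagonal, given P*A = L*U.
#include <cstdio>
#include <vector>

double detFromLU(const std::vector<double>& uDiagonal, int rowSwaps) {
    double det = (rowSwaps % 2 == 0) ? 1.0 : -1.0;
    for (double d : uDiagonal)
        det *= d;
    return det;
}

int main() {
    printf("%.1f\n", detFromLU({2.0, 3.0}, 1)); // -6.0
}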
@@ -110,7 +110,7 @@ namespace helpers {
     auto resR = fullMatricies?R->ulike():matrix->ulike();
     std::vector<NDArray> q(M);
     NDArray z = *matrix;
-    NDArray e('c', {M}, DataTypeUtils::fromT<T>()); // two internal buffers and scalar for squared norm
+    NDArray e('c', {M}, DataTypeUtils::fromT<T>(), context); // two internal buffers and scalar for squared norm
     for (auto k = 0; k < N && k < M - 1; k++) { // loop for columns, but not further then row number
         e.nullify();
         z = matrixMinor<T>(context, z, k); // minor computing for current column with given matrix z (initally is a input matrix)

@@ -177,4 +177,3 @@ namespace helpers {
     }
 }
 }
@@ -167,8 +167,8 @@ namespace sd {
     auto stream = context->getCudaStream();
     indices->syncToHost();
     Nd4jLong numOfClasses = indices->e<Nd4jLong>(indices->lengthOf() - 1) + 1;
-    NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses});
-    NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numOfClasses});
+    NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses}, context);
+    NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numOfClasses}, context);

     classesRangesBegs.assign(indices->lengthOf());
     classesRangesLens.assign(0);

@@ -209,8 +209,8 @@ namespace sd {
     // NDArray classes = NDArrayFactory::create<int>('c', {numOfClasses, 2});
     output->assign(DataTypeUtils::infOrMax<T>());

-    NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numOfClasses});
-    NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses});
+    NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numOfClasses}, context);
+    NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses}, context);
     // NDArray row = NDArrayFactory::create<int>('c', {1, 2}, {(int)indices->lengthOf(), (int)0});
     // classes.applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), row, classes);
     classesRangesBegs.assign(indices->lengthOf());

@@ -158,8 +158,8 @@ namespace helpers {
     static void segmentMeanFunctor_(LaunchContext* context, NDArray* input, NDArray* indices, NDArray* output) {
         auto stream = context->getCudaStream();
         Nd4jLong numClasses = indices->e<Nd4jLong>(indices->lengthOf() - 1) + 1;
-        NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numClasses});
-        NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numClasses});
+        NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numClasses}, context);
+        NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numClasses}, context);

         classesRangesBegs.assign(indices->lengthOf());
         classesRangesLens.assign(0);

@@ -198,8 +198,8 @@ namespace helpers {
         auto stream = context->getCudaStream();
         // NDArray classes = NDArrayFactory::create<int>('c', {numOfClasses, 2});

-        NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numOfClasses});
-        NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses});
+        NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numOfClasses}, context);
+        NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses}, context);
         // NDArray row = NDArrayFactory::create<int>('c', {1, 2}, {(int)indices->lengthOf(), (int)0});
         // classes.applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), &row, &classes);
         classesRangesBegs.assign(indices->lengthOf());

@@ -314,8 +314,8 @@ namespace helpers {
         auto stream = context->getCudaStream();
         NDArray::prepareSpecialUse({output}, {input, indices, gradOut});
         auto numClasses = indices->e<int>(indices->lengthOf() - 1) + 1;
-        NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numClasses});
-        NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numClasses});
+        NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numClasses}, context);
+        NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numClasses}, context);

         classesRangesBegs.assign(indices->lengthOf());
         classesRangesLens.assign(0);

@@ -367,8 +367,8 @@ namespace helpers {
         auto stream = context->getCudaStream();
         NDArray::prepareSpecialUse({output}, {input, indices, gradOut});
         auto numClasses = indices->e<int>(indices->lengthOf() - 1) + 1;
-        NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numClasses});
-        NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numClasses});
+        NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numClasses}, context);
+        NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numClasses}, context);

         classesRangesBegs.assign(indices->lengthOf());
         classesRangesLens.assign(0);

@@ -161,8 +161,8 @@ namespace helpers {
     static void segmentMinFunctor_(LaunchContext* context, NDArray* input, NDArray* indices, NDArray* output) {
         auto stream = context->getCudaStream();
         Nd4jLong numClasses = indices->e<Nd4jLong>(indices->lengthOf() - 1) + 1;
-        NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numClasses});
-        NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numClasses});
+        NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numClasses}, context);
+        NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numClasses}, context);
         output->assign(DataTypeUtils::infOrMax<T>());
         classesRangesBegs.assign(indices->lengthOf());
         classesRangesLens.assign(0);

@@ -202,8 +202,8 @@ namespace helpers {
     static void unsortedSegmentMinFunctor_(sd::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
         auto stream = context->getCudaStream();
         // NDArray classes = NDArrayFactory::create<int>('c', {numOfClasses, 2});
-        NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numOfClasses});
-        NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses});
+        NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numOfClasses}, context);
+        NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses}, context);
         // NDArray row = NDArrayFactory::create<int>('c', {1, 2}, {(int)indices->lengthOf(), (int)0});
         // classes.applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), &row, &classes);
         output->assign(DataTypeUtils::infOrMax<T>());

@@ -122,8 +122,8 @@ namespace helpers {
     static void segmentProdFunctor_(sd::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* output) {
         auto stream = context->getCudaStream();
         Nd4jLong numClasses = indices->e<Nd4jLong>(indices->lengthOf() - 1) + 1;
-        NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numClasses});
-        NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numClasses});
+        NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numClasses}, context);
+        NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numClasses}, context);
         output->assign(1);
         classesRangesBegs.assign(indices->lengthOf());
         classesRangesLens.assign(0);

@@ -160,8 +160,8 @@ namespace helpers {
     static void unsortedSegmentProdFunctor_(sd::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
         auto stream = context->getCudaStream();
         // NDArray classes = NDArrayFactory::create<int>('c', {numOfClasses, 2});
-        NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numOfClasses});
-        NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses});
+        NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numOfClasses}, context);
+        NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses}, context);
         // NDArray row = NDArrayFactory::create<int>('c', {1, 2}, {(int)indices->lengthOf(), (int)0});
         // classes.applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), &row, &classes);
         classesRangesBegs.assign(indices->lengthOf());

@@ -86,8 +86,8 @@ namespace helpers {
     static void unsortedSegmentSqrtNFunctor_(sd::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
         auto stream = context->getCudaStream();
         // NDArray classes = NDArrayFactory::create<int>('c', {numOfClasses, 2});
-        NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numOfClasses});
-        NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses});
+        NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numOfClasses}, context);
+        NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses}, context);
         // NDArray row = NDArrayFactory::create<int>('c', {1, 2}, {(int)indices->lengthOf(), (int)0});
         // classes.applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), &row, &classes);
         classesRangesBegs.assign(indices->lengthOf());

@@ -207,8 +207,8 @@ namespace helpers {
         auto stream = context->getCudaStream();
         NDArray::prepareSpecialUse({output}, {input, indices, gradOut});
         auto numClasses = indices->e<int>(indices->lengthOf() - 1) + 1;
-        NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numClasses});
-        NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numClasses});
+        NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numClasses}, context);
+        NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numClasses}, context);

         classesRangesBegs.assign(indices->lengthOf());
         classesRangesLens.assign(0);

@@ -162,8 +162,8 @@ namespace helpers {
     static void segmentSumFunctor_(sd::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* output) {
         auto stream = context->getCudaStream();
         Nd4jLong numClasses = indices->e<Nd4jLong>(indices->lengthOf() - 1) + 1;
-        NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numClasses});
-        NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numClasses});
+        NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numClasses}, context);
+        NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numClasses}, context);

         classesRangesBegs.assign(indices->lengthOf());
         classesRangesLens.assign(0);

@@ -201,8 +201,8 @@ namespace helpers {
     static void unsortedSegmentSumFunctor_(sd::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
         auto stream = context->getCudaStream();
         // NDArray classes = NDArrayFactory::create<int>('c', {numOfClasses, 2});
-        NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numOfClasses});
-        NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses});
+        NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numOfClasses}, context);
+        NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses}, context);
         // NDArray row = NDArrayFactory::create<int>('c', {1, 2}, {(int)indices->lengthOf(), (int)0});
         // classes.applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), &row, &classes);
         classesRangesBegs.assign(indices->lengthOf());
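Every segment kernel in the hunks above precomputes the same two int arrays from the (sorted) indices: classesRangesBegs, where each class starts, and classesRangesLens, how many rows it owns; initializing the begins to indices->lengthOf() leaves empty classes pointing past the end. A sketch of that pass:

// Sketch of the classesRangesBegs / classesRangesLens computation: for sorted
// segment indices, record each class's first row and its row count. Begins
// default past-the-end, mirroring classesRangesBegs.assign(indices->lengthOf()).
#include <cstdio>
#include <vector>

void segmentRanges(const std::vector<int>& indices, int numClasses,
                   std::vector<int>& begins, std::vector<int>& lengths) {
    begins.assign(numClasses, static_cast<int>(indices.size()));
    lengths.assign(numClasses, 0);
    for (size_t i = 0; i < indices.size(); ++i) {
        int cls = indices[i];
        if (begins[cls] == static_cast<int>(indices.size()))
            begins[cls] = static_cast<int>(i);   // first occurrence of this class
        ++lengths[cls];
    }
}

int main() {
    std::vector<int> begins, lengths;
    segmentRanges({0, 0, 1, 3, 3, 3}, 4, begins, lengths);
    for (int c = 0; c < 4; ++c)
        printf("class %d: begin=%d len=%d\n", c, begins[c], lengths[c]);
}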
(File diff suppressed because it is too large.)
@@ -41,7 +41,7 @@ namespace helpers {
         length += array->lengthOf();
         pos++;
     }
-    NDArray arrayFull('c', {length}, sd::DataType::INT32);
+    NDArray arrayFull('c', {length}, sd::DataType::INT32, inputList[0]->getContext());
     cContext.setOutputArray(0, &arrayFull);
     cContext.setIArguments(&axis, 1);
@ -22,7 +22,6 @@
 #define LIBND4J_LSTMLAYER_H

 #include <ops/declarable/helpers/helpers.h>
-#include <ops/declarable/helpers/activations.h>

 namespace sd {
 namespace ops {

@ -34,6 +33,20 @@ void ND4J_EXPORT lstmLayerCell(const NDArray* x, const NDArray* Wx, const NDArra
                const std::vector<float>& params,
                NDArray* h, NDArray* c);

+//////////////////////////////////////////////////////////////////////////
+// this auxiliary feed-forward should run before backprop
+void ND4J_EXPORT lstmLayerCell(const NDArray* x, const NDArray* Wx, const NDArray* Wr,
+               const NDArray* b, const NDArray* hI, const NDArray* cI, const NDArray* Wp,
+               const std::vector<float>& params,
+               NDArray* z, NDArray* a, NDArray* h, NDArray* c);
+
+//////////////////////////////////////////////////////////////////////////
+void ND4J_EXPORT lstmLayerCellBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, const NDArray* b, const NDArray* hI, const NDArray* cI, const NDArray* Wp,
+               const NDArray* dLdh, const NDArray* dLdc,
+               const NDArray* z, const NDArray* a, const NDArray* c, const std::vector<float>& params,
+               NDArray* dLdx, NDArray* dLdWx, NDArray* dLdWr, NDArray* dLdhI, NDArray* dLdcI, NDArray* dLdb, NDArray* dLdWp);
+
 //////////////////////////////////////////////////////////////////////////
 void ND4J_EXPORT lstmLayerTimeLoop(const NDArray* x, const NDArray* Wx, const NDArray* Wr,
                const NDArray* b, const NDArray* seqLen, const NDArray* hI, const NDArray* cI, const NDArray* Wp,
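Review note: the header (LIBND4J_LSTMLAYER_H, judging by the include guard) gains a second lstmLayerCell overload that additionally returns the pre-activation gate values z and the activated gates a, which lstmLayerCellBp then consumes so the backward pass does not recompute the forward gates. A hedged sketch of the intended call order (whether dLdc may be null, and the params layout, are assumptions, not stated in this diff):

    // Forward pass caches z and a via the extended overload ...
    std::vector<float> params = {/* activation ids, alphas/betas, clipping, ... */};
    lstmLayerCell(x, Wx, Wr, b, hI, cI, Wp, params, z, a, h, c);
    // ... backward pass reuses them instead of recomputing:
    lstmLayerCellBp(x, Wx, Wr, b, hI, cI, Wp, dLdh, /*dLdc, assumed optional*/ nullptr,
                    z, a, c, params,
                    dLdx, dLdWx, dLdWr, dLdhI, dLdcI, dLdb, dLdWp);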
@ -42,71 +55,11 @@ void ND4J_EXPORT lstmLayerTimeLoop(const NDArray* x, const NDArray* Wx, const ND
                NDArray* h, NDArray* hL, NDArray* cL);

 //////////////////////////////////////////////////////////////////////////
-static FORCEINLINE void applyActivation(NDArray& x, const int opId, const float alpha, const float beta, NDArray& z) {
-    switch (opId) {
-        case 0:
-            (const_cast<NDArray&>(x)).applyTransform(transform::Tanh, z);
-            break;
-        case 1:
-            (const_cast<NDArray&>(x)).applyScalar<float>(scalar::RELU, 0, z);
-            break;
-        case 2:
-            (const_cast<NDArray&>(x)).applyTransform(transform::Sigmoid, z);
-            break;
-        case 3: {
-            ExtraArguments args({ static_cast<double>(alpha), static_cast<double>(beta)});
-            (const_cast<NDArray&>(x)).applyTransform(transform::Affine, z, &args);
-            break;
-        }
-        case 4:
-            (const_cast<NDArray&>(x)).applyScalar<float>(scalar::LeakyRELU, alpha, z);
-            break;
-        case 5:
-            helpers::thresholdRelu(x.getContext(), x, alpha, z);
-            break;
-        case 6: {
-            ExtraArguments args({ static_cast<double>(alpha), static_cast<double>(beta)});
-            (const_cast<NDArray&>(x)).applyTransform(transform::ScaledTanh, z, &args);
-            break;
-        }
-        case 7:
-            (const_cast<NDArray&>(x)).applyTransform(transform::HardSigmoid, z);
-            break;
-        case 8:
-            (const_cast<NDArray&>(x)).applyScalar<float>(scalar::ELU, alpha, z);
-            break;
-        case 9:
-            (const_cast<NDArray&>(x)).applyTransform(transform::SoftSign, z);
-            break;
-        case 10:
-            (const_cast<NDArray&>(x)).applyTransform(transform::SoftPlus, z);
-            break;
-        default:
-            throw std::invalid_argument("LSTM_LAYER operation: wrong id number of activation !");
-    }
-}
-
-//////////////////////////////////////////////////////////////////////////
-static FORCEINLINE NDArray tensorAlongTimeBatchDims(const NDArray& arr, const int dataFormat, const int t1, const int t2, const int b1, const int b2) {
-    if(dataFormat == 0 || dataFormat == 3)
-        return arr({t1,t2, b1,b2, 0,0});    // TNS: [sL, bS, nIn]
-    if(dataFormat == 1)
-        return arr({b1,b2, t1,t2, 0,0});    // NTS: [bS, sL, nIn]
-    return arr({b1,b2, 0,0, t1,t2});        // NST: [bS, nIn, sL]
-}
-
-//////////////////////////////////////////////////////////////////////////
-static FORCEINLINE int getBatchTimeTotalIndex(const int dataFormat, const int sL, const int bS, const int t, const int b) {
-    if(dataFormat == 0 || dataFormat == 3)
-        return t * bS + b;    // TNS: shape [sL, bS, nIn]
-    return b * sL + t;        // NTS, NST: shape [bS, sL, nIn], [bS, nIn, sL]
-}
+void ND4J_EXPORT lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr,
+               const NDArray* b, const NDArray* seqLen, NDArray* hI, NDArray* cI, const NDArray* Wp,
+               const NDArray* dLdh, const NDArray* dLdhL, const NDArray* dLdcL,
+               const std::vector<float>& params, const bool forward,
+               NDArray* dLdx, NDArray* dLdWx, NDArray* dLdWr, NDArray* dLdb, NDArray* dLdhI, NDArray* dLdcI, NDArray* dLdWp);

 }
 }
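Review note: the inline applyActivation / tensorAlongTimeBatchDims / getBatchTimeTotalIndex helpers leave the header (presumably moving next to the implementation; the relocated file is not shown in this view), and a full-sequence backprop entry point, lstmLayerTimeLoopBp, is declared. For reference, the activation ids the removed dispatch encoded, which the params vectors presumably still carry:

    // Activation ids used by the lstmLayer params vector (from the removed dispatch):
    //  0 = Tanh             1 = ReLU              2 = Sigmoid
    //  3 = Affine(a,b)      4 = LeakyReLU(a)      5 = ThresholdReLU(a)
    //  6 = ScaledTanh(a,b)  7 = HardSigmoid       8 = ELU(a)
    //  9 = SoftSign        10 = SoftPlus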
@ -0,0 +1,72 @@
+/*******************************************************************************
+ * Copyright (c) 2015-2018 Skymind, Inc.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+//
+// Created by raver on 6/6/2018.
+//
+
+#include <system/op_boilerplate.h>
+#include <system/pointercast.h>
+#include <ops/declarable/BroadcastableBoolOp.h>
+#include <helpers/ShapeUtils.h>
+
+namespace sd {
+namespace ops {
+    BroadcastableBoolOp::BroadcastableBoolOp(const char *name, int numTArgs, int numIArgs) : DeclarableCustomOp::DeclarableCustomOp(2, 1, name, false, numTArgs, numIArgs) {
+        //
+    }
+
+    ShapeList *BroadcastableBoolOp::calculateOutputShape(ShapeList *inputShape, sd::graph::Context &block) {
+        auto shapeList = SHAPELIST();
+        auto x = inputShape->at(0);
+        auto y = inputShape->at(1);
+        sd::DataType dtype = sd::DataType::BOOL;
+
+        if (shape::isEmpty(x) || shape::isEmpty(y)) {
+            // this is edge case, [3, 4] + [] = []
+            if ((shape::isEmpty(x) && shape::rank(x) == 0) || (shape::isEmpty(y) && shape::rank(y) == 0)) {
+                shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor::emptyDescriptor(dtype)));
+                return shapeList;
+            }
+
+            Nd4jLong *newshape = nullptr;
+            ShapeUtils::evalBroadcastShapeInfo(x, y, true, newshape, block.workspace());
+            shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(newshape, dtype)));
+        } else if (shape::isScalar(x) && shape::isScalar(y)) {
+            if (shape::rank(x) >= shape::rank(y)) {
+                shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(x, dtype)));
+            } else {
+                shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(y, dtype)));
+            }
+        } else if (shape::equalsSoft(x, y)) {
+            shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(x, dtype)));
+        } else if (shape::isScalar(x) && !shape::isScalar(y)) {
+            shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(y, dtype)));
+        } else if (!shape::isScalar(x) && shape::isScalar(y)) {
+            shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(x, dtype)));
+        } else if (ShapeUtils::areShapesBroadcastable(x, y)) {
+            Nd4jLong *newshape = nullptr;
+            ShapeUtils::evalBroadcastShapeInfo(x, y, true, newshape, block.workspace());
+            shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(newshape, dtype)));
+        } else {
+            // in this case we'll throw exception later
+            shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(x, dtype)));
+        }
+
+        return shapeList;
+    }
+}
+}
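Review note: the new BroadcastableBoolOp base class takes two inputs and one output, resolves the broadcast output shape, and pins the output data type to BOOL, so comparison-style ops do not inherit the input dtype. A sketch of how a concrete op might derive from it (the class name and override are illustrative, and the validateAndExecute signature is assumed from libnd4j's DeclarableOp; the real subclasses are not shown in this view):

    // Hypothetical comparison op built on the new base class: two inputs,
    // one BOOL output, shape resolution inherited from BroadcastableBoolOp.
    class equals_example : public BroadcastableBoolOp {
    public:
        equals_example() : BroadcastableBoolOp("equals_example", 0, 0) {}
    protected:
        Nd4jStatus validateAndExecute(sd::graph::Context &block) override {
            // ... elementwise comparison writing into the BOOL output ...
            return Status::OK();
        }
    };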
@ -368,7 +368,7 @@ PLATFORM_IMPL(lstmLayer, ENGINE_CPU) {
     REQUIRE_TRUE(hasSeqLen == false, 0, "LSTM_LAYER_MKLDNN operation: mkl dnn library doesn't support array specifying max time step per each example in batch !");
     REQUIRE_TRUE(dataFormat < 2, 0, "LSTM_LAYER_MKLDNN operation: wrong data format, only two formats are allowed for input/output tensors in mkl dnn library: TNC and NTC!");
     REQUIRE_TRUE(directionMode < 4, 0, "LSTM_LAYER_MKLDNN operation: option for bidirectional extra output dimension is not valid in mkl dnn library !");
-    REQUIRE_TRUE((retLastH && retLastC) || (!retLastH && !retLastC), 0, "LSTM_LAYER_MKLDNN operation: only two options are present: 1) calculate both output at last time and cell state at last time; 2) do not calculate both !");
+    REQUIRE_TRUE(retLastH == retLastC, 0, "LSTM_LAYER_MKLDNN operation: only two options are present: 1) calculate both output at last time and cell state at last time; 2) do not calculate both !");

     count = 0;
     auto h = retFullSeq ? OUTPUT_VARIABLE(count++) : nullptr;    // output
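Review note: for two bools, (a && b) || (!a && !b) is exactly a == b, so the tightened check is behavior-preserving. A self-contained sanity sketch:

    #include <cassert>
    int main() {
        // both forms agree on all four combinations
        for (bool a : {false, true})
            for (bool b : {false, true})
                assert(((a && b) || (!a && !b)) == (a == b));
    }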
@ -464,13 +464,21 @@ PLATFORM_IMPL(lstmLayer, ENGINE_CPU) {
 }

 PLATFORM_CHECK(lstmLayer, ENGINE_CPU) {
+    const auto dataFormat    = INT_ARG(0);    // for unidirectional: 0 = [sL, bS, nIn], 1 = [bS, sL, nIn], 2 = [bS, nIn, sL]; for bidirectional: 3 = [sL, 2, bS, nOut] (for ONNX)
+    const auto directionMode = INT_ARG(1);    // direction: 0 = fwd, 1 = bwd, 2 = bidirectional sum, 3 = bidirectional concat, 4 = bidirectional extra output dim (in conjunction with dataFormat = 3)
+
     const auto hasBiases = B_ARG(0);    // indicates whether biases array is provided
+    const auto hasSeqLen = B_ARG(1);    // indicates whether seqLen array is provided
     const auto hasInitH  = B_ARG(2);    // indicates whether initial output is provided
     const auto hasInitC  = B_ARG(3);    // indicates whether initial cell state is provided
+    const auto hasPH     = B_ARG(4);    // indicates whether peephole connections are present
     const auto retFullSeq = B_ARG(5);   // indicates whether to return whole time sequence h {h_0, h_1, ... , h_sL-1}
     const auto retLastH  = B_ARG(6);    // indicates whether to return output at last time step only, in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument)
     const auto retLastC  = B_ARG(7);    // indicates whether to return cell state at last time step only, in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument)
+
+    const auto cellClip = T_ARG(0);     // cell clipping value; if it is 0 then do not apply clipping
+
     const auto x  = INPUT_VARIABLE(0);  // input
     const auto Wx = INPUT_VARIABLE(1);  // input weights
     const auto Wr = INPUT_VARIABLE(2);  // recurrent weights
@ -495,7 +503,15 @@ PLATFORM_CHECK(lstmLayer, ENGINE_CPU) {
     DataType hLType = hL != nullptr ? hL->dataType() : xType;
     DataType cLType = cL != nullptr ? cL->dataType() : xType;

-    return block.isUseMKLDNN() && (
+    auto featuresSupported = (cellClip == 0)    // Cell clipping not supported
+        && retFullSeq                           // Always return full sequence in case of MKL DNN
+        && !hasPH                               // Peephole connections not supported in MKL DNN
+        && !hasSeqLen                           // Sequence length array not supported in MKL DNN
+        && dataFormat < 2                       // Data format - only 0 and 1 supported in MKL DNN: 0 = [sL, bS, nIn], 1 = [bS, sL, nIn]
+        && directionMode < 4                    // Direction mode - only 0-3 supported in MKL DNN (no extra dim option): 0 = fwd, 1 = bwd, 2 = bidirectional sum, 3 = bidirectional concat
+        && retLastH == retLastC;                // Return both lastH and lastC, or return neither (not just one or the other)
+
+    return block.isUseMKLDNN() && featuresSupported && (
         (xType==DataType::FLOAT32 && WxType==DataType::FLOAT32 && WrType==DataType::FLOAT32 && bType==DataType::FLOAT32 && hIType==DataType::FLOAT32 && cIType==DataType::FLOAT32 && hType==DataType::FLOAT32 && hLType==DataType::FLOAT32 && cLType==DataType::FLOAT32) ||
         (xType==DataType::HALF && WxType==DataType::HALF && WrType==DataType::HALF && bType==DataType::HALF && hIType==DataType::HALF && cIType==DataType::HALF && hType==DataType::HALF && hLType==DataType::HALF && cLType==DataType::HALF) ||
         (xType==DataType::UINT8 && WxType==DataType::INT8 && WrType==DataType::INT8 && bType==DataType::FLOAT32 && hIType==DataType::UINT8 && cIType==DataType::UINT8 && ((hType==DataType::FLOAT32 && hLType==DataType::FLOAT32 && cLType==DataType::FLOAT32) || (hType==DataType::UINT8 && hLType==DataType::UINT8 && cLType==DataType::UINT8)))
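Review note: hoisting the feature gate into featuresSupported keeps PLATFORM_CHECK in agreement with the REQUIRE_TRUE validations in PLATFORM_IMPL above, so unsupported configurations now fall back to the generic implementation instead of failing inside the MKL-DNN path. That appears to be the motivation for reading dataFormat, directionMode, hasSeqLen, hasPH and cellClip here.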
@ -32,7 +32,7 @@ namespace ops {
 namespace platforms {

 //////////////////////////////////////////////////////////////////////////
-static void matmulMKLDNN(const NDArray* x, const NDArray* y, NDArray* z, const bool transX, const bool transY) {
+static void matmulMKLDNN(const NDArray* x, const NDArray* y, NDArray* z, const bool transX, const bool transY, float alpha = 1.f, float beta = 0.f) {

     // mkl works with following
     // [M,K] x [K,N] = [M,N]
|
||||||
|
|
||||||
// Create attributes (to handle alpha and beta if necessary)
|
// Create attributes (to handle alpha and beta if necessary)
|
||||||
dnnl::primitive_attr attr; // it is empty since we have usual values for alpha (=1) and beta (=0)
|
dnnl::primitive_attr attr; // it is empty since we have usual values for alpha (=1) and beta (=0)
|
||||||
|
if (alpha != 1.f) attr.set_output_scales(0, {alpha});
|
||||||
|
if (beta != 0.f) {
|
||||||
|
dnnl::post_ops po;
|
||||||
|
po.append_sum(beta);
|
||||||
|
attr.set_post_ops(po);
|
||||||
|
}
|
||||||
|
|
||||||
// operation primitive description
|
// operation primitive description
|
||||||
dnnl::matmul::desc op_desc(x_mkl_md, y_mkl_md, z_mkl_md);
|
dnnl::matmul::desc op_desc(x_mkl_md, y_mkl_md, z_mkl_md);
|
||||||
|
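Review note: in oneDNN, alpha becomes an output scale and beta a sum post-op, which together turn the primitive into z = alpha*(x·y) + beta*z. A standalone sketch of the same attribute wiring (the engine, stream, memory descriptors and memory objects are assumed to be set up elsewhere; their names here are placeholders):

    // Assumed: x_md, y_md, z_md describe the operands; engine/stream/x_mem/y_mem/z_mem exist.
    dnnl::primitive_attr attr;
    if (alpha != 1.f) attr.set_output_scales(0, {alpha});    // scales x*y by alpha
    if (beta != 0.f) {
        dnnl::post_ops po;
        po.append_sum(beta);                                 // accumulates beta * existing z
        attr.set_post_ops(po);
    }
    dnnl::matmul::desc op_desc(x_md, y_md, z_md);
    dnnl::matmul::primitive_desc op_prim_desc(op_desc, attr, engine);    // attr applied here
    dnnl::matmul(op_prim_desc).execute(stream, {{DNNL_ARG_SRC,     x_mem},
                                                {DNNL_ARG_WEIGHTS, y_mem},
                                                {DNNL_ARG_DST,     z_mem}});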
@ -224,11 +230,16 @@ PLATFORM_IMPL(matmul, ENGINE_CPU) {
     if(x->isEmpty() || y->isEmpty())
         return Status::OK();

-    const int iSize = (int) block.getIArguments()->size();
+    int iSize = (int) block.getIArguments()->size();
     int transX = iSize > 0 ? INT_ARG(0) : 0;
     int transY = iSize > 1 ? INT_ARG(1) : 0;
     const int transZ = iSize > 2 ? INT_ARG(2) : 0;

+    // optional alpha and beta
+    iSize = (int) block.getTArguments()->size();
+    float alpha = iSize > 0 ? T_ARG(0) : 1.0;
+    float beta  = iSize > 1 ? T_ARG(1) : 0.0;
+
     const int xRank = x->rankOf();
     const int yRank = y->rankOf();
     const int zRank = z->rankOf();
|
||||||
}
|
}
|
||||||
// ******* end of input validation ******* //
|
// ******* end of input validation ******* //
|
||||||
|
|
||||||
matmulMKLDNN(x, y, z, transX, transY);
|
matmulMKLDNN(x, y, z, transX, transY, alpha, beta);
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
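Review note: with alpha and beta threaded through, the platform matmul now honors the op's optional T arguments. A hedged usage sketch at the op level (the execute overload taking input/output/t-arg/i-arg/b-arg vectors is assumed from libnd4j's DeclarableOp; x, y, z are prepared NDArray pointers):

    // Computes z = 0.5 * (x · y) + 2.0 * z, with no transposes.
    sd::ops::matmul op;
    //                        inputs  outputs  tArgs        iArgs: transX, transY, transZ
    auto status = op.execute({x, y},  {z},     {0.5, 2.0},  {0, 0, 0}, {});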
@ -276,14 +287,16 @@ PLATFORM_CHECK(matmul, ENGINE_CPU) {
     auto x = INPUT_VARIABLE(0);
     auto y = INPUT_VARIABLE(1);

-    auto z = INPUT_VARIABLE(0);
+    auto z = OUTPUT_VARIABLE(0);

     const DataType xType = x->dataType();
     const DataType yType = y->dataType();
     const DataType zType = z->dataType();

-    return block.isUseMKLDNN() && x->rankOf() < 3 &&
+    float alpha = block.numT() > 0 ? T_ARG(0) : 1.0;
+    float beta  = block.numT() > 1 ? T_ARG(1) : 0.0;
+
+    return !(z->ordering() == 'f' && beta != 0.f) && block.isUseMKLDNN() && x->rankOf() < 3 &&
         (
             (xType==DataType::FLOAT32 && yType==DataType::FLOAT32 && zType==DataType::FLOAT32) ||
             (xType==DataType::HALF && yType==DataType::HALF && zType==DataType::FLOAT32) ||
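Review note: the old z = INPUT_VARIABLE(0) was a genuine bug, since the check inspected the first input's dtype instead of the output's; OUTPUT_VARIABLE(0) fixes it. The new !(z->ordering() == 'f' && beta != 0.f) guard most likely exists because the sum post-op accumulates into the destination buffer assuming row-major ('c') layout, so column-major outputs with beta != 0 are routed to the generic matmul instead; that rationale is inferred, not stated in the diff.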
(Some files were not shown because too many files have changed in this diff.)